diff options
Diffstat (limited to 'test')
202 files changed, 14700 insertions, 4544 deletions
diff --git a/test/README.md b/test/README.md index 02bbf42ff..15b0f4c33 100644 --- a/test/README.md +++ b/test/README.md @@ -24,11 +24,11 @@ also used to run these tests in `kokoro`. To run image and integration tests, run: -`./scripts/docker_tests.sh` +`make docker-tests` To run root tests, run: -`./scripts/root_tests.sh` +`make root-tests` There are a few other interesting variations for image and integration tests: diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD new file mode 100644 index 000000000..32c139204 --- /dev/null +++ b/test/benchmarks/base/BUILD @@ -0,0 +1,34 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "base", + testonly = 1, + srcs = [ + "base.go", + ], + deps = ["//test/benchmarks/harness"], +) + +go_test( + name = "base_test", + size = "large", + srcs = [ + "size_test.go", + "startup_test.go", + "sysbench_test.go", + ], + library = ":base", + tags = [ + # Requires docker and runsc to be configured before test runs. + "manual", + "local", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], +) diff --git a/test/benchmarks/base/base.go b/test/benchmarks/base/base.go new file mode 100644 index 000000000..7bac52ff1 --- /dev/null +++ b/test/benchmarks/base/base.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package base holds base performance benchmarks. +package base + +import ( + "os" + "testing" + + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +var testHarness harness.Harness + +// TestMain is the main method for package network. +func TestMain(m *testing.M) { + testHarness.Init() + os.Exit(m.Run()) +} diff --git a/test/benchmarks/base/size_test.go b/test/benchmarks/base/size_test.go new file mode 100644 index 000000000..7d3877459 --- /dev/null +++ b/test/benchmarks/base/size_test.go @@ -0,0 +1,221 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "context" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkSizeEmpty creates N empty containers and reads memory usage from +// /proc/meminfo. +func BenchmarkSizeEmpty(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + meminfo := tools.Meminfo{} + ctx := context.Background() + containers := make([]*dockerutil.Container, 0, b.N) + + // DropCaches before the test. + harness.DropCaches(machine) + + // Check available memory on 'machine'. + cmd, args := meminfo.MakeCmd() + before, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to get meminfo: %v", err) + } + + // Make N containers. + for i := 0; i < b.N; i++ { + container := machine.GetContainer(ctx, b) + containers = append(containers, container) + if err := container.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/alpine", + }, "sh", "-c", "echo Hello && sleep 1000"); err != nil { + cleanUpContainers(ctx, containers) + b.Fatalf("failed to run container: %v", err) + } + if _, err := container.WaitForOutputSubmatch(ctx, "Hello", 5*time.Second); err != nil { + cleanUpContainers(ctx, containers) + b.Fatalf("failed to read container output: %v", err) + } + } + + // Drop caches again before second measurement. + harness.DropCaches(machine) + + // Check available memory after containers are up. + after, err := machine.RunCommand(cmd, args...) + cleanUpContainers(ctx, containers) + if err != nil { + b.Fatalf("failed to get meminfo: %v", err) + } + meminfo.Report(b, before, after) +} + +// BenchmarkSizeNginx starts N containers running Nginx, checks that they're +// serving, and checks memory used based on /proc/meminfo. +func BenchmarkSizeNginx(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + // DropCaches for the first measurement. + harness.DropCaches(machine) + + // Measure MemAvailable before creating containers. + meminfo := tools.Meminfo{} + cmd, args := meminfo.MakeCmd() + before, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo command: %v", err) + } + + // Make N Nginx containers. + ctx := context.Background() + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/nginx", + } + const port = 80 + servers := startServers(ctx, b, + serverArgs{ + machine: machine, + port: port, + runOpts: runOpts, + cmd: []string{"nginx", "-c", "/etc/nginx/nginx_gofer.conf"}, + }) + defer cleanUpContainers(ctx, servers) + + // DropCaches after servers are created. + harness.DropCaches(machine) + // Take after measurement. + after, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo command: %v", err) + } + meminfo.Report(b, before, after) +} + +// BenchmarkSizeNode starts N containers running a Node app, checks that +// they're serving, and checks memory used based on /proc/meminfo. +func BenchmarkSizeNode(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + // Make a redis instance for Node to connect. + ctx := context.Background() + redis, redisIP := redisInstance(ctx, b, machine) + defer redis.CleanUp(ctx) + + // DropCaches after redis is created. + harness.DropCaches(machine) + + // Take before measurement. + meminfo := tools.Meminfo{} + cmd, args := meminfo.MakeCmd() + before, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo commend: %v", err) + } + + // Create N Node servers. + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/node", + WorkDir: "/usr/src/app", + Links: []string{redis.MakeLink("redis")}, + } + nodeCmd := []string{"node", "index.js", redisIP.String()} + const port = 8080 + servers := startServers(ctx, b, + serverArgs{ + machine: machine, + port: port, + runOpts: runOpts, + cmd: nodeCmd, + }) + defer cleanUpContainers(ctx, servers) + + // DropCaches after servers are created. + harness.DropCaches(machine) + // Take after measurement. + cmd, args = meminfo.MakeCmd() + after, err := machine.RunCommand(cmd, args...) + if err != nil { + b.Fatalf("failed to run meminfo command: %v", err) + } + meminfo.Report(b, before, after) +} + +// serverArgs wraps args for startServers and runServerWorkload. +type serverArgs struct { + machine harness.Machine + port int + runOpts dockerutil.RunOpts + cmd []string +} + +// startServers starts b.N containers defined by 'runOpts' and 'cmd' and uses +// 'machine' to check that each is up. +func startServers(ctx context.Context, b *testing.B, args serverArgs) []*dockerutil.Container { + b.Helper() + servers := make([]*dockerutil.Container, 0, b.N) + + // Create N servers and wait until each of them is serving. + for i := 0; i < b.N; i++ { + server := args.machine.GetContainer(ctx, b) + servers = append(servers, server) + if err := server.Spawn(ctx, args.runOpts, args.cmd...); err != nil { + cleanUpContainers(ctx, servers) + b.Fatalf("failed to spawn node instance: %v", err) + } + + // Get the container IP. + servingIP, err := server.FindIP(ctx, false) + if err != nil { + cleanUpContainers(ctx, servers) + b.Fatalf("failed to get ip from server: %v", err) + } + + // Wait until the server is up. + if err := harness.WaitUntilServing(ctx, args.machine, servingIP, args.port); err != nil { + cleanUpContainers(ctx, servers) + b.Fatalf("failed to wait for serving") + } + } + return servers +} + +// cleanUpContainers cleans up a slice of containers. +func cleanUpContainers(ctx context.Context, containers []*dockerutil.Container) { + for _, c := range containers { + if c != nil { + c.CleanUp(ctx) + } + } +} diff --git a/test/benchmarks/base/startup_test.go b/test/benchmarks/base/startup_test.go new file mode 100644 index 000000000..c36a544db --- /dev/null +++ b/test/benchmarks/base/startup_test.go @@ -0,0 +1,155 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" +) + +// BenchmarkStartEmpty times startup time for an empty container. +func BenchmarkStartupEmpty(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + for i := 0; i < b.N; i++ { + container := machine.GetContainer(ctx, b) + defer container.CleanUp(ctx) + if _, err := container.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/alpine", + }, "true"); err != nil { + b.Fatalf("failed to run container: %v", err) + } + } +} + +// BenchmarkStartupNginx times startup for a Nginx instance. +// Time is measured from start until the first request is served. +func BenchmarkStartupNginx(b *testing.B) { + // The machine to hold Nginx and the Node Server. + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/nginx", + } + runServerWorkload(ctx, b, + serverArgs{ + machine: machine, + runOpts: runOpts, + port: 80, + cmd: []string{"nginx", "-c", "/etc/nginx/nginx_gofer.conf"}, + }) +} + +// BenchmarkStartupNode times startup for a Node application instance. +// Time is measured from start until the first request is served. +// Note that the Node app connects to a Redis instance before serving. +func BenchmarkStartupNode(b *testing.B) { + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer machine.CleanUp() + + ctx := context.Background() + redis, redisIP := redisInstance(ctx, b, machine) + defer redis.CleanUp(ctx) + runOpts := dockerutil.RunOpts{ + Image: "benchmarks/node", + WorkDir: "/usr/src/app", + Links: []string{redis.MakeLink("redis")}, + } + + cmd := []string{"node", "index.js", redisIP.String()} + runServerWorkload(ctx, b, + serverArgs{ + machine: machine, + port: 8080, + runOpts: runOpts, + cmd: cmd, + }) +} + +// redisInstance returns a Redis container and its reachable IP. +func redisInstance(ctx context.Context, b *testing.B, machine harness.Machine) (*dockerutil.Container, net.IP) { + b.Helper() + // Spawn a redis instance for the app to use. + redis := machine.GetNativeContainer(ctx, b) + if err := redis.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }); err != nil { + redis.CleanUp(ctx) + b.Fatalf("failed to spwan redis instance: %v", err) + } + + if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + redis.CleanUp(ctx) + b.Fatalf("failed to start redis server: %v %s", err, out) + } + redisIP, err := redis.FindIP(ctx, false) + if err != nil { + redis.CleanUp(ctx) + b.Fatalf("failed to get IP from redis instance: %v", err) + } + return redis, redisIP +} + +// runServerWorkload runs a server workload defined by 'runOpts' and 'cmd'. +// 'clientMachine' is used to connect to the server on 'serverMachine'. +func runServerWorkload(ctx context.Context, b *testing.B, args serverArgs) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := func() error { + server := args.machine.GetContainer(ctx, b) + defer func() { + b.StopTimer() + // Cleanup servers as we run so that we can go indefinitely. + server.CleanUp(ctx) + b.StartTimer() + }() + if err := server.Spawn(ctx, args.runOpts, args.cmd...); err != nil { + return fmt.Errorf("failed to spawn node instance: %v", err) + } + + servingIP, err := server.FindIP(ctx, false) + if err != nil { + return fmt.Errorf("failed to get ip from server: %v", err) + } + + // Wait until the Client sees the server as up. + if err := harness.WaitUntilServing(ctx, args.machine, servingIP, args.port); err != nil { + return fmt.Errorf("failed to wait for serving: %v", err) + } + return nil + }(); err != nil { + b.Fatal(err) + } + } +} diff --git a/test/benchmarks/base/sysbench_test.go b/test/benchmarks/base/sysbench_test.go new file mode 100644 index 000000000..6fb813640 --- /dev/null +++ b/test/benchmarks/base/sysbench_test.go @@ -0,0 +1,89 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package base + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +type testCase struct { + name string + test tools.Sysbench +} + +// BenchmarSysbench runs sysbench on the runtime. +func BenchmarkSysbench(b *testing.B) { + + testCases := []testCase{ + testCase{ + name: "CPU", + test: &tools.SysbenchCPU{ + Base: tools.SysbenchBase{ + Threads: 1, + Time: 5, + }, + MaxPrime: 50000, + }, + }, + testCase{ + name: "Memory", + test: &tools.SysbenchMemory{ + Base: tools.SysbenchBase{ + Threads: 1, + }, + BlockSize: "1M", + TotalSize: "500G", + }, + }, + testCase{ + name: "Mutex", + test: &tools.SysbenchMutex{ + Base: tools.SysbenchBase{ + Threads: 8, + }, + Loops: 1, + Locks: 10000000, + Num: 4, + }, + }, + } + + machine, err := testHarness.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer machine.CleanUp() + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + + ctx := context.Background() + sysbench := machine.GetContainer(ctx, b) + defer sysbench.CleanUp(ctx) + + out, err := sysbench.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/sysbench", + }, tc.test.MakeCmd()...) + if err != nil { + b.Fatalf("failed to run sysbench: %v: logs:%s", err, out) + } + tc.test.Report(b, out) + }) + } +} diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD index 572db665f..93b380e8a 100644 --- a/test/benchmarks/database/BUILD +++ b/test/benchmarks/database/BUILD @@ -12,15 +12,14 @@ go_library( go_test( name = "database_test", size = "enormous", - srcs = [ - "redis_test.go", - ], + srcs = ["redis_test.go"], library = ":database", tags = [ # Requires docker and runsc to be configured before test runs. "manual", "local", ], + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//test/benchmarks/harness", diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go index 394fce820..6671a4969 100644 --- a/test/benchmarks/database/redis_test.go +++ b/test/benchmarks/database/redis_test.go @@ -84,12 +84,12 @@ func BenchmarkRedis(b *testing.B) { ip, err := serverMachine.IPAddress() if err != nil { - b.Fatal("failed to get IP from server: %v", err) + b.Fatalf("failed to get IP from server: %v", err) } serverPort, err := server.FindPort(ctx, port) if err != nil { - b.Fatal("failed to get IP from server: %v", err) + b.Fatalf("failed to get IP from server: %v", err) } if err = harness.WaitUntilServing(ctx, clientMachine, ip, serverPort); err != nil { diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD index 20654d88f..45f11372b 100644 --- a/test/benchmarks/fs/BUILD +++ b/test/benchmarks/fs/BUILD @@ -22,6 +22,7 @@ go_test( "local", "manual", ], + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//test/benchmarks/harness", diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go index f4236ba37..ef1b8e4ea 100644 --- a/test/benchmarks/fs/bazel_test.go +++ b/test/benchmarks/fs/bazel_test.go @@ -62,7 +62,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { container := machine.GetContainer(ctx, b) defer container.CleanUp(ctx) - // Start a container and sleep by an order of b.N. + // Start a container and sleep. if err := container.Spawn(ctx, dockerutil.RunOpts{ Image: image, }, "sleep", fmt.Sprintf("%d", 1000000)); err != nil { @@ -70,12 +70,13 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { } // If we are running on a tmpfs, copy to /tmp which is a tmpfs. + prefix := "" if bm.tmpfs { if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, "cp", "-r", workdir, "/tmp/."); err != nil { - b.Fatal("failed to copy directory: %v %s", err, out) + b.Fatalf("failed to copy directory: %v (%s)", err, out) } - workdir = "/tmp" + workdir + prefix = "/tmp" } // Restart profiles after the copy. @@ -94,7 +95,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { b.StartTimer() got, err := container.Exec(ctx, dockerutil.ExecOpts{ - WorkDir: workdir, + WorkDir: prefix + workdir, }, "bazel", "build", "-c", "opt", target) if err != nil { b.Fatalf("build failed with: %v", err) @@ -107,7 +108,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { } // Clean bazel in case we use b.N. _, err = container.Exec(ctx, dockerutil.ExecOpts{ - WorkDir: workdir, + WorkDir: prefix + workdir, }, "bazel", "clean") if err != nil { b.Fatalf("build failed with: %v", err) diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go index bc551c582..86b863f78 100644 --- a/test/benchmarks/harness/util.go +++ b/test/benchmarks/harness/util.go @@ -23,23 +23,25 @@ import ( "gvisor.dev/gvisor/pkg/test/testutil" ) +//TODO(gvisor.dev/issue/3535): move to own package or move methods to harness struct. + // WaitUntilServing grabs a container from `machine` and waits for a server at // IP:port. func WaitUntilServing(ctx context.Context, machine Machine, server net.IP, port int) error { - var logger testutil.DefaultLogger = "netcat" + var logger testutil.DefaultLogger = "util" netcat := machine.GetNativeContainer(ctx, logger) defer netcat.CleanUp(ctx) - cmd := fmt.Sprintf("while ! nc -zv %s %d; do true; done", server, port) + cmd := fmt.Sprintf("while ! wget -q --spider http://%s:%d; do true; done", server, port) _, err := netcat.Run(ctx, dockerutil.RunOpts{ - Image: "packetdrill", + Image: "benchmarks/util", }, "sh", "-c", cmd) return err } // DropCaches drops caches on the provided machine. Requires root. func DropCaches(machine Machine) error { - if out, err := machine.RunCommand("/bin/sh", "-c", "sync | sysctl vm.drop_caches=3"); err != nil { + if out, err := machine.RunCommand("/bin/sh", "-c", "sync && sysctl vm.drop_caches=3"); err != nil { return fmt.Errorf("failed to drop caches: %v logs: %s", err, out) } return nil diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD index 6c41fc4f6..bb242d385 100644 --- a/test/benchmarks/media/BUILD +++ b/test/benchmarks/media/BUILD @@ -14,6 +14,7 @@ go_test( size = "large", srcs = ["ffmpeg_test.go"], library = ":media", + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//test/benchmarks/harness", diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD index 2430b60a7..970f52706 100644 --- a/test/benchmarks/ml/BUILD +++ b/test/benchmarks/ml/BUILD @@ -14,6 +14,7 @@ go_test( size = "large", srcs = ["tensorflow_test.go"], library = ":ml", + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//test/benchmarks/harness", diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD index d15cd55ee..472b5c387 100644 --- a/test/benchmarks/network/BUILD +++ b/test/benchmarks/network/BUILD @@ -5,8 +5,15 @@ package(licenses = ["notice"]) go_library( name = "network", testonly = 1, - srcs = ["network.go"], - deps = ["//test/benchmarks/harness"], + srcs = [ + "network.go", + "static_server.go", + ], + deps = [ + "//pkg/test/dockerutil", + "//test/benchmarks/harness", + "//test/benchmarks/tools", + ], ) go_test( @@ -17,6 +24,7 @@ go_test( "iperf_test.go", "nginx_test.go", "node_test.go", + "ruby_test.go", ], library = ":network", tags = [ @@ -24,6 +32,7 @@ go_test( "manual", "local", ], + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//pkg/test/testutil", diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go index 07833f9cd..369ab326e 100644 --- a/test/benchmarks/network/httpd_test.go +++ b/test/benchmarks/network/httpd_test.go @@ -14,22 +14,19 @@ package network import ( - "context" "fmt" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/test/benchmarks/harness" "gvisor.dev/gvisor/test/benchmarks/tools" ) // see Dockerfile '//images/benchmarks/httpd'. -var docs = map[string]string{ +var httpdDocs = map[string]string{ "notfound": "notfound", "1Kb": "latin1k.txt", "10Kb": "latin10k.txt", "100Kb": "latin100k.txt", - "1000Kb": "latin1000k.txt", "1Mb": "latin1024k.txt", "10Mb": "latin10240k.txt", } @@ -37,76 +34,57 @@ var docs = map[string]string{ // BenchmarkHttpdConcurrency iterates the concurrency argument and tests // how well the runtime under test handles requests in parallel. func BenchmarkHttpdConcurrency(b *testing.B) { - // Grab a machine for the client and server. - clientMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get client: %v", err) - } - defer clientMachine.CleanUp() - - serverMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get server: %v", err) - } - defer serverMachine.CleanUp() - // The test iterates over client concurrency, so set other parameters. - concurrency := []int{1, 5, 10, 25} + concurrency := []int{1, 25, 50, 100, 1000} for _, c := range concurrency { b.Run(fmt.Sprintf("%d", c), func(b *testing.B) { hey := &tools.Hey{ - Requests: 10000, + Requests: c * b.N, Concurrency: c, - Doc: docs["10Kb"], + Doc: httpdDocs["10Kb"], } - runHttpd(b, clientMachine, serverMachine, hey) + runHttpd(b, hey, false /* reverse */) }) } } // BenchmarkHttpdDocSize iterates over different sized payloads, testing how -// well the runtime handles different payload sizes. +// well the runtime handles sending different payload sizes. func BenchmarkHttpdDocSize(b *testing.B) { - clientMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get machine: %v", err) - } - defer clientMachine.CleanUp() - - serverMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get machine: %v", err) - } - defer serverMachine.CleanUp() + benchmarkHttpdDocSize(b, false /* reverse */) +} - for name, filename := range docs { - b.Run(name, func(b *testing.B) { - hey := &tools.Hey{ - Requests: 10000, - Concurrency: 1, - Doc: filename, - } - runHttpd(b, clientMachine, serverMachine, hey) - }) - } +// BenchmarkReverseHttpdDocSize iterates over different sized payloads, testing +// how well the runtime handles receiving different payload sizes. +func BenchmarkReverseHttpdDocSize(b *testing.B) { + benchmarkHttpdDocSize(b, true /* reverse */) } -// runHttpd runs a single test run. -func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, hey *tools.Hey) { +// benchmarkHttpdDocSize iterates through all doc sizes, running subbenchmarks +// for each size. +func benchmarkHttpdDocSize(b *testing.B, reverse bool) { b.Helper() + for name, filename := range httpdDocs { + concurrency := []int{1, 25, 50, 100, 1000} + for _, c := range concurrency { + b.Run(fmt.Sprintf("%s_%d", name, c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: c * b.N, + Concurrency: c, + Doc: filename, + } + runHttpd(b, hey, reverse) + }) + } + } +} - // Grab a container from the server. - ctx := context.Background() - server := serverMachine.GetContainer(ctx, b) - defer server.CleanUp(ctx) - - // Copy the docs to /tmp and serve from there. - cmd := "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X" +// runHttpd configures the static serving methods to run httpd. +func runHttpd(b *testing.B, hey *tools.Hey, reverse bool) { + // httpd runs on port 80. port := 80 - - // Start the server. - if err := server.Spawn(ctx, dockerutil.RunOpts{ + httpdRunOpts := dockerutil.RunOpts{ Image: "benchmarks/httpd", Ports: []int{port}, Env: []string{ @@ -117,39 +95,7 @@ func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, hey *t "APACHE_LOG_DIR=/tmp", "APACHE_PID_FILE=/tmp/apache.pid", }, - }, "sh", "-c", cmd); err != nil { - b.Fatalf("failed to start server: %v") - } - - ip, err := serverMachine.IPAddress() - if err != nil { - b.Fatalf("failed to find server ip: %v", err) - } - - servingPort, err := server.FindPort(ctx, port) - if err != nil { - b.Fatalf("failed to find server port %d: %v", port, err) - } - - // Check the server is serving. - harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) - - // Grab a client. - client := clientMachine.GetNativeContainer(ctx, b) - defer client.CleanUp(ctx) - - b.ResetTimer() - server.RestartProfiles() - for i := 0; i < b.N; i++ { - out, err := client.Run(ctx, dockerutil.RunOpts{ - Image: "benchmarks/hey", - }, hey.MakeCmd(ip, servingPort)...) - if err != nil { - b.Fatalf("run failed with: %v", err) - } - - b.StopTimer() - hey.Report(b, out) - b.StartTimer() } + httpdCmd := []string{"sh", "-c", "mkdir -p /tmp/html; cp -r /local/* /tmp/html/.; apache2 -X"} + runStaticServer(b, httpdRunOpts, httpdCmd, port, hey, reverse) } diff --git a/test/benchmarks/network/nginx_test.go b/test/benchmarks/network/nginx_test.go index 5965652a5..9ec70369b 100644 --- a/test/benchmarks/network/nginx_test.go +++ b/test/benchmarks/network/nginx_test.go @@ -14,91 +14,97 @@ package network import ( - "context" "fmt" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/test/benchmarks/harness" "gvisor.dev/gvisor/test/benchmarks/tools" ) +// see Dockerfile '//images/benchmarks/nginx'. +var nginxDocs = map[string]string{ + "notfound": "notfound", + "1Kb": "latin1k.txt", + "10Kb": "latin10k.txt", + "100Kb": "latin100k.txt", + "1Mb": "latin1024k.txt", + "10Mb": "latin10240k.txt", +} + // BenchmarkNginxConcurrency iterates the concurrency argument and tests // how well the runtime under test handles requests in parallel. -// TODO(zkoopmans): Update with different doc sizes like Httpd. func BenchmarkNginxConcurrency(b *testing.B) { - // Grab a machine for the client and server. - clientMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get client: %v", err) - } - defer clientMachine.CleanUp() - - serverMachine, err := h.GetMachine() - if err != nil { - b.Fatalf("failed to get server: %v", err) - } - defer serverMachine.CleanUp() - - concurrency := []int{1, 5, 10, 25} + concurrency := []int{1, 25, 100, 1000} for _, c := range concurrency { - b.Run(fmt.Sprintf("%d", c), func(b *testing.B) { - hey := &tools.Hey{ - Requests: 10000, - Concurrency: c, + for _, tmpfs := range []bool{true, false} { + fs := "Gofer" + if tmpfs { + fs = "Tmpfs" } - runNginx(b, clientMachine, serverMachine, hey) - }) + name := fmt.Sprintf("%d_%s", c, fs) + b.Run(name, func(b *testing.B) { + hey := &tools.Hey{ + Requests: c * b.N, + Concurrency: c, + Doc: nginxDocs["10kb"], // see Dockerfile '//images/benchmarks/nginx' and httpd_test. + } + runNginx(b, hey, false /* reverse */, tmpfs /* tmpfs */) + }) + } + } } -// runHttpd runs a single test run. -func runNginx(b *testing.B, clientMachine, serverMachine harness.Machine, hey *tools.Hey) { - b.Helper() +// BenchmarkNginxDocSize iterates over different sized payloads, testing how +// well the runtime handles sending different payload sizes. +func BenchmarkNginxDocSize(b *testing.B) { + benchmarkNginxDocSize(b, false /* reverse */, true /* tmpfs */) + benchmarkNginxDocSize(b, false /* reverse */, false /* tmpfs */) +} - // Grab a container from the server. - ctx := context.Background() - server := serverMachine.GetContainer(ctx, b) - defer server.CleanUp(ctx) +// BenchmarkReverseNginxDocSize iterates over different sized payloads, testing +// how well the runtime handles receiving different payload sizes. +func BenchmarkReverseNginxDocSize(b *testing.B) { + benchmarkNginxDocSize(b, true /* reverse */, true /* tmpfs */) +} - port := 80 - // Start the server. - if err := server.Spawn(ctx, - dockerutil.RunOpts{ - Image: "benchmarks/nginx", - Ports: []int{port}, - }); err != nil { - b.Fatalf("server failed to start: %v", err) +// benchmarkNginxDocSize iterates through all doc sizes, running subbenchmarks +// for each size. +func benchmarkNginxDocSize(b *testing.B, reverse, tmpfs bool) { + for name, filename := range nginxDocs { + concurrency := []int{1, 25, 50, 100, 1000} + for _, c := range concurrency { + fs := "Gofer" + if tmpfs { + fs = "Tmpfs" + } + benchName := fmt.Sprintf("%s_%d_%s", name, c, fs) + b.Run(benchName, func(b *testing.B) { + hey := &tools.Hey{ + Requests: c * b.N, + Concurrency: c, + Doc: filename, + } + runNginx(b, hey, reverse, tmpfs) + }) + } } +} - ip, err := serverMachine.IPAddress() - if err != nil { - b.Fatalf("failed to find server ip: %v", err) +// runNginx configures the static serving methods to run httpd. +func runNginx(b *testing.B, hey *tools.Hey, reverse, tmpfs bool) { + // nginx runs on port 80. + port := 80 + nginxRunOpts := dockerutil.RunOpts{ + Image: "benchmarks/nginx", + Ports: []int{port}, } - servingPort, err := server.FindPort(ctx, port) - if err != nil { - b.Fatalf("failed to find server port %d: %v", port, err) + nginxCmd := []string{"nginx", "-c", "/etc/nginx/nginx_gofer.conf"} + if tmpfs { + nginxCmd = []string{"sh", "-c", "mkdir -p /tmp/html && cp -a /local/* /tmp/html && nginx -c /etc/nginx/nginx.conf"} } - // Check the server is serving. - harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) - - // Grab a client. - client := clientMachine.GetNativeContainer(ctx, b) - defer client.CleanUp(ctx) - - b.ResetTimer() - server.RestartProfiles() - for i := 0; i < b.N; i++ { - out, err := client.Run(ctx, dockerutil.RunOpts{ - Image: "benchmarks/hey", - }, hey.MakeCmd(ip, servingPort)...) - if err != nil { - b.Fatalf("run failed with: %v", err) - } - b.StopTimer() - hey.Report(b, out) - b.StartTimer() - } + // Command copies nginxDocs to tmpfs serving directory and runs nginx. + runStaticServer(b, nginxRunOpts, nginxCmd, port, hey, reverse) } diff --git a/test/benchmarks/network/node_test.go b/test/benchmarks/network/node_test.go index 5b568cfe5..0f4a205b6 100644 --- a/test/benchmarks/network/node_test.go +++ b/test/benchmarks/network/node_test.go @@ -24,18 +24,16 @@ import ( "gvisor.dev/gvisor/test/benchmarks/tools" ) -// BenchmarkNode runs 10K requests using 'hey' against a Node server run on +// BenchmarkNode runs requests using 'hey' against a Node server run on // 'runtime'. The server responds to requests by grabbing some data in a // redis instance and returns the data in its reponse. The test loops through // increasing amounts of concurency for requests. func BenchmarkNode(b *testing.B) { - requests := 10000 concurrency := []int{1, 5, 10, 25} - for _, c := range concurrency { b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) { hey := &tools.Hey{ - Requests: requests, + Requests: b.N * c, // Requests b.N requests per thread. Concurrency: c, } runNode(b, hey) @@ -50,14 +48,14 @@ func runNode(b *testing.B, hey *tools.Hey) { // The machine to hold Redis and the Node Server. serverMachine, err := h.GetMachine() if err != nil { - b.Fatal("failed to get machine with: %v", err) + b.Fatalf("failed to get machine with: %v", err) } defer serverMachine.CleanUp() // The machine to run 'hey'. clientMachine, err := h.GetMachine() if err != nil { - b.Fatal("failed to get machine with: %v", err) + b.Fatalf("failed to get machine with: %v", err) } defer clientMachine.CleanUp() @@ -113,19 +111,17 @@ func runNode(b *testing.B, hey *tools.Hey) { nodeApp.RestartProfiles() b.ResetTimer() - for i := 0; i < b.N; i++ { - // the client should run on Native. - client := clientMachine.GetNativeContainer(ctx, b) - out, err := client.Run(ctx, dockerutil.RunOpts{ - Image: "benchmarks/hey", - }, heyCmd...) - if err != nil { - b.Fatalf("hey container failed: %v logs: %s", err, out) - } - - // Stop the timer to parse the data and report stats. - b.StopTimer() - hey.Report(b, out) - b.StartTimer() + // the client should run on Native. + client := clientMachine.GetNativeContainer(ctx, b) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, heyCmd...) + if err != nil { + b.Fatalf("hey container failed: %v logs: %s", err, out) } + + // Stop the timer to parse the data and report stats. + b.StopTimer() + hey.Report(b, out) + b.StartTimer() } diff --git a/test/benchmarks/network/ruby_test.go b/test/benchmarks/network/ruby_test.go new file mode 100644 index 000000000..67f63f76a --- /dev/null +++ b/test/benchmarks/network/ruby_test.go @@ -0,0 +1,134 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package network + +import ( + "context" + "fmt" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// BenchmarkRuby runs requests using 'hey' against a ruby application server. +// On start, ruby app generates some random data and pushes it to a redis +// instance. On a request, the app grabs for random entries from the redis +// server, publishes it to a document, and returns the doc to the request. +func BenchmarkRuby(b *testing.B) { + concurrency := []int{1, 5, 10, 25} + for _, c := range concurrency { + b.Run(fmt.Sprintf("Concurrency%d", c), func(b *testing.B) { + hey := &tools.Hey{ + Requests: b.N * c, // b.N requests per thread. + Concurrency: c, + } + runRuby(b, hey) + }) + } +} + +// runRuby runs the test for a given # of requests and concurrency. +func runRuby(b *testing.B, hey *tools.Hey) { + b.Helper() + // The machine to hold Redis and the Ruby Server. + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer serverMachine.CleanUp() + + // The machine to run 'hey'. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine with: %v", err) + } + defer clientMachine.CleanUp() + ctx := context.Background() + + // Spawn a redis instance for the app to use. + redis := serverMachine.GetNativeContainer(ctx, b) + if err := redis.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/redis", + }); err != nil { + b.Fatalf("failed to spwan redis instance: %v", err) + } + defer redis.CleanUp(ctx) + + if out, err := redis.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil { + b.Fatalf("failed to start redis server: %v %s", err, out) + } + redisIP, err := redis.FindIP(ctx, false) + if err != nil { + b.Fatalf("failed to get IP from redis instance: %v", err) + } + + // Ruby runs on port 9292. + const port = 9292 + + // Start-up the Ruby server. + rubyApp := serverMachine.GetContainer(ctx, b) + if err := rubyApp.Spawn(ctx, dockerutil.RunOpts{ + Image: "benchmarks/ruby", + WorkDir: "/app", + Links: []string{redis.MakeLink("redis")}, + Ports: []int{port}, + Env: []string{ + fmt.Sprintf("PORT=%d", port), + "WEB_CONCURRENCY=20", + "WEB_MAX_THREADS=20", + "RACK_ENV=production", + fmt.Sprintf("HOST=%s", redisIP), + }, + User: "nobody", + }, "sh", "-c", "/usr/bin/puma"); err != nil { + b.Fatalf("failed to spawn node instance: %v", err) + } + defer rubyApp.CleanUp(ctx) + + servingIP, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to get ip from server: %v", err) + } + + servingPort, err := rubyApp.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to port from node instance: %v", err) + } + + // Wait until the Client sees the server as up. + if err := harness.WaitUntilServing(ctx, clientMachine, servingIP, servingPort); err != nil { + b.Fatalf("failed to wait until serving: %v", err) + } + heyCmd := hey.MakeCmd(servingIP, servingPort) + rubyApp.RestartProfiles() + b.ResetTimer() + + // the client should run on Native. + client := clientMachine.GetNativeContainer(ctx, b) + defer client.CleanUp(ctx) + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, heyCmd...) + if err != nil { + b.Fatalf("hey container failed: %v logs: %s", err, out) + } + + // Stop the timer to parse the data and report stats. + b.StopTimer() + hey.Report(b, out) + b.StartTimer() +} diff --git a/test/benchmarks/network/static_server.go b/test/benchmarks/network/static_server.go new file mode 100644 index 000000000..e747a1395 --- /dev/null +++ b/test/benchmarks/network/static_server.go @@ -0,0 +1,87 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package network + +import ( + "context" + "testing" + + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/benchmarks/harness" + "gvisor.dev/gvisor/test/benchmarks/tools" +) + +// runStaticServer runs static serving workloads (httpd, nginx). +func runStaticServer(b *testing.B, serverOpts dockerutil.RunOpts, serverCmd []string, port int, hey *tools.Hey, reverse bool) { + ctx := context.Background() + + // Get two machines: a client and server. + clientMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer clientMachine.CleanUp() + + serverMachine, err := h.GetMachine() + if err != nil { + b.Fatalf("failed to get machine: %v", err) + } + defer serverMachine.CleanUp() + + // Make the containers. 'reverse=true' specifies that the client should use the + // runtime under test. + var client, server *dockerutil.Container + if reverse { + client = clientMachine.GetContainer(ctx, b) + server = serverMachine.GetNativeContainer(ctx, b) + } else { + client = clientMachine.GetNativeContainer(ctx, b) + server = serverMachine.GetContainer(ctx, b) + } + defer client.CleanUp(ctx) + defer server.CleanUp(ctx) + + // Start the server. + if err := server.Spawn(ctx, serverOpts, serverCmd...); err != nil { + b.Fatalf("failed to start server: %v", err) + } + + // Get its IP. + ip, err := serverMachine.IPAddress() + if err != nil { + b.Fatalf("failed to find server ip: %v", err) + } + + // Get the published port. + servingPort, err := server.FindPort(ctx, port) + if err != nil { + b.Fatalf("failed to find server port %d: %v", port, err) + } + + // Make sure the server is serving. + harness.WaitUntilServing(ctx, clientMachine, ip, servingPort) + b.ResetTimer() + server.RestartProfiles() + out, err := client.Run(ctx, dockerutil.RunOpts{ + Image: "benchmarks/hey", + }, hey.MakeCmd(ip, servingPort)...) + if err != nil { + b.Fatalf("run failed with: %v", err) + } + + b.StopTimer() + hey.Report(b, out) + b.StartTimer() +} diff --git a/test/benchmarks/tcp/BUILD b/test/benchmarks/tcp/BUILD new file mode 100644 index 000000000..6dde7d9e6 --- /dev/null +++ b/test/benchmarks/tcp/BUILD @@ -0,0 +1,41 @@ +load("//tools:defs.bzl", "cc_binary", "go_binary") + +package(licenses = ["notice"]) + +go_binary( + name = "tcp_proxy", + srcs = ["tcp_proxy.go"], + visibility = ["//:sandbox"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/adapters/gonet", + "//pkg/tcpip/link/fdbased", + "//pkg/tcpip/link/qdisc/fifo", + "//pkg/tcpip/network/arp", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/tcpip/transport/udp", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +# nsjoin is a trivial replacement for nsenter. This is used because nsenter is +# not available on all systems where this benchmark is run (and we aim to +# minimize external dependencies.) + +cc_binary( + name = "nsjoin", + srcs = ["nsjoin.c"], + visibility = ["//:sandbox"], +) + +sh_binary( + name = "tcp_benchmark", + srcs = ["tcp_benchmark.sh"], + data = [ + ":nsjoin", + ":tcp_proxy", + ], + visibility = ["//:sandbox"], +) diff --git a/test/benchmarks/tcp/README.md b/test/benchmarks/tcp/README.md new file mode 100644 index 000000000..38e6e69f0 --- /dev/null +++ b/test/benchmarks/tcp/README.md @@ -0,0 +1,87 @@ +# TCP Benchmarks + +This directory contains a standardized TCP benchmark. This helps to evaluate the +performance of netstack and native networking stacks under various conditions. + +## `tcp_benchmark` + +This benchmark allows TCP throughput testing under various conditions. The setup +consists of an iperf client, a client proxy, a server proxy and an iperf server. +The client proxy and server proxy abstract the network mechanism used to +communicate between the iperf client and server. + +The setup looks like the following: + +``` + +--------------+ (native) +--------------+ + | iperf client |[lo @ 10.0.0.1]------>| client proxy | + +--------------+ +--------------+ + [client.0 @ 10.0.0.2] + (netstack) | | (native) + +------+-----+ + | + [br0] + | + Network emulation applied ---> [wan.0:wan.1] + | + [br1] + | + +------+-----+ + (netstack) | | (native) + [server.0 @ 10.0.0.3] + +--------------+ +--------------+ + | iperf server |<------[lo @ 10.0.0.4]| server proxy | + +--------------+ (native) +--------------+ +``` + +Different configurations can be run using different arguments. For example: + +* Native test under normal internet conditions: `tcp_benchmark` +* Native test under ideal conditions: `tcp_benchmark --ideal` +* Netstack client under ideal conditions: `tcp_benchmark --client --ideal` +* Netstack client with 5% packet loss: `tcp_benchmark --client --ideal --loss + 5` + +Use `tcp_benchmark --help` for full arguments. + +This tool may be used to easily generate data for graphing. For example, to +generate a CSV for various latencies, you might do: + +``` +rm -f /tmp/netstack_latency.csv /tmp/native_latency.csv +latencies=$(seq 0 5 50; + seq 60 10 100; + seq 125 25 250; + seq 300 50 500) +for latency in $latencies; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --client --ideal --latency $latency) + echo $latency,$throughput,$client_cpu >> /tmp/netstack_latency.csv +done +for latency in $latencies; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --ideal --latency $latency) + echo $latency,$throughput,$client_cpu >> /tmp/native_latency.csv +done +``` + +Similarly, to generate a CSV for various levels of packet loss, the following +would be appropriate: + +``` +rm -f /tmp/netstack_loss.csv /tmp/native_loss.csv +losses=$(seq 0 0.1 1.0; + seq 1.2 0.2 2.0; + seq 2.5 0.5 5.0; + seq 6.0 1.0 10.0) +for loss in $losses; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --client --ideal --latency 10 --loss $loss) + echo $loss,$throughput,$client_cpu >> /tmp/netstack_loss.csv +done +for loss in $losses; do + read throughput client_cpu server_cpu <<< \ + $(./tcp_benchmark --duration 30 --ideal --latency 10 --loss $loss) + echo $loss,$throughput,$client_cpu >> /tmp/native_loss.csv +done +``` diff --git a/test/benchmarks/tcp/nsjoin.c b/test/benchmarks/tcp/nsjoin.c new file mode 100644 index 000000000..524b4d549 --- /dev/null +++ b/test/benchmarks/tcp/nsjoin.c @@ -0,0 +1,47 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include <errno.h> +#include <fcntl.h> +#include <sched.h> +#include <stdio.h> +#include <string.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +int main(int argc, char** argv) { + if (argc <= 2) { + fprintf(stderr, "error: must provide a namespace file.\n"); + fprintf(stderr, "usage: %s <file> [arguments...]\n", argv[0]); + return 1; + } + + int fd = open(argv[1], O_RDONLY); + if (fd < 0) { + fprintf(stderr, "error opening %s: %s\n", argv[1], strerror(errno)); + return 1; + } + if (setns(fd, 0) < 0) { + fprintf(stderr, "error joining %s: %s\n", argv[1], strerror(errno)); + return 1; + } + + execvp(argv[2], &argv[2]); + return 1; +} diff --git a/test/benchmarks/tcp/tcp_benchmark.sh b/test/benchmarks/tcp/tcp_benchmark.sh new file mode 100755 index 000000000..ef04b4ace --- /dev/null +++ b/test/benchmarks/tcp/tcp_benchmark.sh @@ -0,0 +1,392 @@ +#!/bin/bash + +# Copyright 2018 The gVisor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TCP benchmark; see README.md for documentation. + +# Fixed parameters. +iperf_port=45201 # Not likely to be privileged. +proxy_port=44000 # Ditto. +client_addr=10.0.0.1 +client_proxy_addr=10.0.0.2 +server_proxy_addr=10.0.0.3 +server_addr=10.0.0.4 +mask=8 + +# Defaults; this provides a reasonable approximation of a decent internet link. +# Parameters can be varied independently from this set to see response to +# various changes in the kind of link available. +client=false +server=false +verbose=false +gso=0 +swgso=false +mtu=1280 # 1280 is a reasonable lowest-common-denominator. +latency=10 # 10ms approximates a fast, dedicated connection. +latency_variation=1 # +/- 1ms is a relatively low amount of jitter. +loss=0.1 # 0.1% loss is non-zero, but not extremely high. +duplicate=0.1 # 0.1% means duplicates are 1/10x as frequent as losses. +duration=30 # 30s is enough time to consistent results (experimentally). +helper_dir=$(dirname $0) +netstack_opts= +disable_linux_gso= +num_client_threads=1 + +# Check for netem support. +lsmod_output=$(lsmod | grep sch_netem) +if [ "$?" != "0" ]; then + echo "warning: sch_netem may not be installed." >&2 +fi + +while [ $# -gt 0 ]; do + case "$1" in + --client) + client=true + ;; + --client_tcp_probe_file) + shift + netstack_opts="${netstack_opts} -client_tcp_probe_file=$1" + ;; + --server) + server=true + ;; + --verbose) + verbose=true + ;; + --gso) + shift + gso=$1 + ;; + --swgso) + swgso=true + ;; + --server_tcp_probe_file) + shift + netstack_opts="${netstack_opts} -server_tcp_probe_file=$1" + ;; + --ideal) + mtu=1500 # Standard ethernet. + latency=0 # No latency. + latency_variation=0 # No jitter. + loss=0 # No loss. + duplicate=0 # No duplicates. + ;; + --mtu) + shift + [ "$#" -le 0 ] && echo "no mtu provided" && exit 1 + mtu=$1 + ;; + --sack) + netstack_opts="${netstack_opts} -sack" + ;; + --cubic) + netstack_opts="${netstack_opts} -cubic" + ;; + --moderate-recv-buf) + netstack_opts="${netstack_opts} -moderate_recv_buf" + ;; + --duration) + shift + [ "$#" -le 0 ] && echo "no duration provided" && exit 1 + duration=$1 + ;; + --latency) + shift + [ "$#" -le 0 ] && echo "no latency provided" && exit 1 + latency=$1 + ;; + --latency-variation) + shift + [ "$#" -le 0 ] && echo "no latency variation provided" && exit 1 + latency_variation=$1 + ;; + --loss) + shift + [ "$#" -le 0 ] && echo "no loss probability provided" && exit 1 + loss=$1 + ;; + --duplicate) + shift + [ "$#" -le 0 ] && echo "no duplicate provided" && exit 1 + duplicate=$1 + ;; + --cpuprofile) + shift + netstack_opts="${netstack_opts} -cpuprofile=$1" + ;; + --memprofile) + shift + netstack_opts="${netstack_opts} -memprofile=$1" + ;; + --disable-linux-gso) + disable_linux_gso=1 + ;; + --num-client-threads) + shift + num_client_threads=$1 + ;; + --helpers) + shift + [ "$#" -le 0 ] && echo "no helper dir provided" && exit 1 + helper_dir=$1 + ;; + *) + echo "usage: $0 [options]" + echo "options:" + echo " --help show this message" + echo " --verbose verbose output" + echo " --client use netstack as the client" + echo " --ideal reset all network emulation" + echo " --server use netstack as the server" + echo " --mtu set the mtu (bytes)" + echo " --sack enable SACK support" + echo " --moderate-recv-buf enable TCP receive buffer auto-tuning" + echo " --cubic enable CUBIC congestion control for Netstack" + echo " --duration set the test duration (s)" + echo " --latency set the latency (ms)" + echo " --latency-variation set the latency variation" + echo " --loss set the loss probability (%)" + echo " --duplicate set the duplicate probability (%)" + echo " --helpers set the helper directory" + echo " --num-client-threads number of parallel client threads to run" + echo " --disable-linux-gso disable segmentation offload in the Linux network stack" + echo "" + echo "The output will of the script will be:" + echo " <throughput> <client-cpu-usage> <server-cpu-usage>" + exit 1 + esac + shift +done + +if [ ${verbose} == "true" ]; then + set -x +fi + +# Latency needs to be halved, since it's applied on both ways. +half_latency=$(echo ${latency}/2 | bc -l | awk '{printf "%1.2f", $0}') +half_loss=$(echo ${loss}/2 | bc -l | awk '{printf "%1.6f", $0}') +half_duplicate=$(echo ${duplicate}/2 | bc -l | awk '{printf "%1.6f", $0}') +helper_dir=${helper_dir#$(pwd)/} # Use relative paths. +proxy_binary=${helper_dir}/tcp_proxy +nsjoin_binary=${helper_dir}/nsjoin + +if [ ! -e ${proxy_binary} ]; then + echo "Could not locate ${proxy_binary}, please make sure you've built the binary" + exit 1 +fi + +if [ ! -e ${nsjoin_binary} ]; then + echo "Could not locate ${nsjoin_binary}, please make sure you've built the binary" + exit 1 +fi + +if [ $(echo ${latency_variation} | awk '{printf "%1.2f", $0}') != "0.00" ]; then + # As long as there's some jitter, then we use the paretonormal distribution. + # This will preserve the minimum RTT, but add a realistic amount of jitter to + # the connection and cause re-ordering, etc. The regular pareto distribution + # appears to an unreasonable level of delay (we want only small spikes.) + distribution="distribution paretonormal" +else + distribution="" +fi + +# Client proxy that will listen on the client's iperf target forward traffic +# using the host networking stack. +client_args="${proxy_binary} -port ${proxy_port} -forward ${server_proxy_addr}:${proxy_port}" +if ${client}; then + # Client proxy that will listen on the client's iperf target + # and forward traffic using netstack. + client_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -client \\ + -mtu ${mtu} -iface client.0 -addr ${client_proxy_addr} -mask ${mask} \\ + -forward ${server_proxy_addr}:${proxy_port} -gso=${gso} -swgso=${swgso}" +fi + +# Server proxy that will listen on the proxy port and forward to the server's +# iperf server using the host networking stack. +server_args="${proxy_binary} -port ${proxy_port} -forward ${server_addr}:${iperf_port}" +if ${server}; then + # Server proxy that will listen on the proxy port and forward to the servers' + # iperf server using netstack. + server_args="${proxy_binary} ${netstack_opts} -port ${proxy_port} -server \\ + -mtu ${mtu} -iface server.0 -addr ${server_proxy_addr} -mask ${mask} \\ + -forward ${server_addr}:${iperf_port} -gso=${gso} -swgso=${swgso}" +fi + +# Specify loss and duplicate parameters only if they are non-zero +loss_opt="" +if [ "$(echo $half_loss | bc -q)" != "0" ]; then + loss_opt="loss random ${half_loss}%" +fi +duplicate_opt="" +if [ "$(echo $half_duplicate | bc -q)" != "0" ]; then + duplicate_opt="duplicate ${half_duplicate}%" +fi + +exec unshare -U -m -n -r -f -p --mount-proc /bin/bash << EOF +set -e -m + +if [ ${verbose} == "true" ]; then + set -x +fi + +mount -t tmpfs netstack-bench /tmp + +# We may have reset the path in the unshare if the shell loaded some public +# profiles. Ensure that tools are discoverable via the parent's PATH. +export PATH=${PATH} + +# Add client, server interfaces. +ip link add client.0 type veth peer name client.1 +ip link add server.0 type veth peer name server.1 + +# Add network emulation devices. +ip link add wan.0 type veth peer name wan.1 +ip link set wan.0 up +ip link set wan.1 up + +# Enroll on the bridge. +ip link add name br0 type bridge +ip link add name br1 type bridge +ip link set client.1 master br0 +ip link set server.1 master br1 +ip link set wan.0 master br0 +ip link set wan.1 master br1 +ip link set br0 up +ip link set br1 up + +# Set the MTU appropriately. +ip link set client.0 mtu ${mtu} +ip link set server.0 mtu ${mtu} +ip link set wan.0 mtu ${mtu} +ip link set wan.1 mtu ${mtu} + +# Add appropriate latency, loss and duplication. +# +# This is added in at the point of bridge connection. +for device in wan.0 wan.1; do + # NOTE: We don't support a loss correlation as testing has shown that it + # actually doesn't work. The man page actually has a small comment about this + # "It is also possible to add a correlation, but this option is now deprecated + # due to the noticed bad behavior." For more information see netem(8). + tc qdisc add dev \$device root netem \\ + delay ${half_latency}ms ${latency_variation}ms ${distribution} \\ + ${loss_opt} ${duplicate_opt} +done + +# Start a client proxy. +touch /tmp/client.netns +unshare -n mount --bind /proc/self/ns/net /tmp/client.netns + +# Move the endpoint into the namespace. +while ip link | grep client.0 > /dev/null; do + ip link set dev client.0 netns /tmp/client.netns +done + +if ! ${client}; then + # Only add the address to NIC if netstack is not in use. Otherwise the host + # will also process the inbound SYN and send a RST back. + ${nsjoin_binary} /tmp/client.netns ip addr add ${client_proxy_addr}/${mask} dev client.0 +fi + +# Start a server proxy. +touch /tmp/server.netns +unshare -n mount --bind /proc/self/ns/net /tmp/server.netns +# Move the endpoint into the namespace. +while ip link | grep server.0 > /dev/null; do + ip link set dev server.0 netns /tmp/server.netns +done +if ! ${server}; then + # Only add the address to NIC if netstack is not in use. Otherwise the host + # will also process the inbound SYN and send a RST back. + ${nsjoin_binary} /tmp/server.netns ip addr add ${server_proxy_addr}/${mask} dev server.0 +fi + +# Add client and server addresses, and bring everything up. +${nsjoin_binary} /tmp/client.netns ip addr add ${client_addr}/${mask} dev client.0 +${nsjoin_binary} /tmp/server.netns ip addr add ${server_addr}/${mask} dev server.0 +if [ "${disable_linux_gso}" == "1" ]; then + ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 tso off + ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gro off + ${nsjoin_binary} /tmp/client.netns ethtool -K client.0 gso off + ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 tso off + ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gso off + ${nsjoin_binary} /tmp/server.netns ethtool -K server.0 gro off +fi +${nsjoin_binary} /tmp/client.netns ip link set client.0 up +${nsjoin_binary} /tmp/client.netns ip link set lo up +${nsjoin_binary} /tmp/server.netns ip link set server.0 up +${nsjoin_binary} /tmp/server.netns ip link set lo up +ip link set dev client.1 up +ip link set dev server.1 up + +${nsjoin_binary} /tmp/client.netns ${client_args} & +client_pid=\$! +${nsjoin_binary} /tmp/server.netns ${server_args} & +server_pid=\$! + +# Start the iperf server. +${nsjoin_binary} /tmp/server.netns iperf -p ${iperf_port} -s >&2 & +iperf_pid=\$! + +# Show traffic information. +if ! ${client} && ! ${server}; then + ${nsjoin_binary} /tmp/client.netns ping -c 100 -i 0.001 -W 1 ${server_addr} >&2 || true +fi + +results_file=\$(mktemp) +function cleanup { + rm -f \$results_file + kill -TERM \$client_pid + kill -TERM \$server_pid + wait \$client_pid + wait \$server_pid + kill -9 \$iperf_pid 2>/dev/null +} + +# Allow failure from this point. +set +e +trap cleanup EXIT + +# Run the benchmark, recording the results file. +while ${nsjoin_binary} /tmp/client.netns iperf \\ + -p ${proxy_port} -c ${client_addr} -t ${duration} -f m -P ${num_client_threads} 2>&1 \\ + | tee \$results_file \\ + | grep "connect failed" >/dev/null; do + sleep 0.1 # Wait for all services. +done + +# Unlink all relevant devices from the bridge. This is because when the bridge +# is deleted, the kernel may hang. It appears that this problem is fixed in +# upstream commit 1ce5cce895309862d2c35d922816adebe094fe4a. +ip link set client.1 nomaster +ip link set server.1 nomaster +ip link set wan.0 nomaster +ip link set wan.1 nomaster + +# Emit raw results. +cat \$results_file >&2 + +# Emit a useful result (final throughput). +mbits=\$(grep Mbits/sec \$results_file \\ + | sed -n -e 's/^.*[[:space:]]\\([[:digit:]]\\+\\(\\.[[:digit:]]\\+\\)\\?\\)[[:space:]]*Mbits\\/sec.*/\\1/p') +client_cpu_ticks=\$(cat /proc/\$client_pid/stat \\ + | awk '{print (\$14+\$15);}') +server_cpu_ticks=\$(cat /proc/\$server_pid/stat \\ + | awk '{print (\$14+\$15);}') +ticks_per_sec=\$(getconf CLK_TCK) +client_cpu_load=\$(bc -l <<< \$client_cpu_ticks/\$ticks_per_sec/${duration}) +server_cpu_load=\$(bc -l <<< \$server_cpu_ticks/\$ticks_per_sec/${duration}) +echo \$mbits \$client_cpu_load \$server_cpu_load +EOF diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go new file mode 100644 index 000000000..5afe10f69 --- /dev/null +++ b/test/benchmarks/tcp/tcp_proxy.go @@ -0,0 +1,458 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary tcp_proxy is a simple TCP proxy. +package main + +import ( + "encoding/gob" + "flag" + "fmt" + "io" + "log" + "math/rand" + "net" + "os" + "os/signal" + "regexp" + "runtime" + "runtime/pprof" + "strconv" + "syscall" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" + "gvisor.dev/gvisor/pkg/tcpip/link/qdisc/fifo" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" +) + +var ( + port = flag.Int("port", 0, "bind port (all addresses)") + forward = flag.String("forward", "", "forwarding target") + client = flag.Bool("client", false, "use netstack for listen") + server = flag.Bool("server", false, "use netstack for dial") + + // Netstack-specific options. + mtu = flag.Int("mtu", 1280, "mtu for network stack") + addr = flag.String("addr", "", "address for tap-based netstack") + mask = flag.Int("mask", 8, "mask size for address") + iface = flag.String("iface", "", "network interface name to bind for netstack") + sack = flag.Bool("sack", false, "enable SACK support for netstack") + moderateRecvBuf = flag.Bool("moderate_recv_buf", false, "enable TCP Receive Buffer Auto-tuning") + cubic = flag.Bool("cubic", false, "enable use of CUBIC congestion control for netstack") + gso = flag.Int("gso", 0, "GSO maximum size") + swgso = flag.Bool("swgso", false, "software-level GSO") + clientTCPProbeFile = flag.String("client_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.") + serverTCPProbeFile = flag.String("server_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.") + cpuprofile = flag.String("cpuprofile", "", "write cpu profile to the specified file.") + memprofile = flag.String("memprofile", "", "write memory profile to the specified file.") +) + +type impl interface { + dial(address string) (net.Conn, error) + listen(port int) (net.Listener, error) + printStats() +} + +type netImpl struct{} + +func (netImpl) dial(address string) (net.Conn, error) { + return net.Dial("tcp", address) +} + +func (netImpl) listen(port int) (net.Listener, error) { + return net.Listen("tcp", fmt.Sprintf(":%d", port)) +} + +func (netImpl) printStats() { +} + +const ( + nicID = 1 // Fixed. + bufSize = 4 << 20 // 4MB. +) + +type netstackImpl struct { + s *stack.Stack + addr tcpip.Address + mode string +} + +func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) { + // Get all interfaces in the namespace. + ifaces, err := net.Interfaces() + if err != nil { + return nil, fmt.Errorf("querying interfaces: %v", err) + } + + for _, iface := range ifaces { + if iface.Name != ifaceName { + continue + } + // Create the socket. + const protocol = 0x0300 // htons(ETH_P_ALL) + fds := make([]int, numChannels) + for i := range fds { + fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) + if err != nil { + return nil, fmt.Errorf("unable to create raw socket: %v", err) + } + + // Bind to the appropriate device. + ll := syscall.SockaddrLinklayer{ + Protocol: protocol, + Ifindex: iface.Index, + Pkttype: syscall.PACKET_HOST, + } + if err := syscall.Bind(fd, &ll); err != nil { + return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) + } + + // RAW Sockets by default have a very small SO_RCVBUF of 256KB, + // up it to at least 4MB to reduce packet drops. + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err) + } + + if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil { + return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err) + } + + if !*swgso && *gso != 0 { + if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { + return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) + } + } + fds[i] = fd + } + return fds, nil + } + return nil, fmt.Errorf("failed to find interface: %v", ifaceName) +} + +func newNetstackImpl(mode string) (impl, error) { + fds, err := setupNetwork(*iface, runtime.GOMAXPROCS(-1)) + if err != nil { + return nil, err + } + + // Parse details. + parsedAddr := tcpip.Address(net.ParseIP(*addr).To4()) + parsedDest := tcpip.Address("") // Filled in below. + parsedMask := tcpip.AddressMask("") // Filled in below. + switch *mask { + case 8: + parsedDest = tcpip.Address([]byte{parsedAddr[0], 0, 0, 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0, 0, 0}) + case 16: + parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], 0, 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0, 0}) + case 24: + parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], parsedAddr[2], 0}) + parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0xff, 0}) + default: + // This is just laziness; we don't expect a different mask. + return nil, fmt.Errorf("mask %d not supported", mask) + } + + // Create a new network stack. + netProtos := []stack.NetworkProtocolFactory{ipv4.NewProtocol, arp.NewProtocol} + transProtos := []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol} + s := stack.New(stack.Options{ + NetworkProtocols: netProtos, + TransportProtocols: transProtos, + }) + + // Generate a new mac for the eth device. + mac := make(net.HardwareAddr, 6) + rand.Read(mac) // Fill with random data. + mac[0] &^= 0x1 // Clear multicast bit. + mac[0] |= 0x2 // Set local assignment bit (IEEE802). + ep, err := fdbased.New(&fdbased.Options{ + FDs: fds, + MTU: uint32(*mtu), + EthernetHeader: true, + Address: tcpip.LinkAddress(mac), + // Enable checksum generation as we need to generate valid + // checksums for the veth device to deliver our packets to the + // peer. But we do want to disable checksum verification as veth + // devices do perform GRO and the linux host kernel may not + // regenerate valid checksums after GRO. + TXChecksumOffload: false, + RXChecksumOffload: true, + PacketDispatchMode: fdbased.RecvMMsg, + GSOMaxSize: uint32(*gso), + SoftwareGSOEnabled: *swgso, + }) + if err != nil { + return nil, fmt.Errorf("failed to create FD endpoint: %v", err) + } + if err := s.CreateNIC(nicID, fifo.New(ep, runtime.GOMAXPROCS(0), 1000)); err != nil { + return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err) + } + if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + return nil, fmt.Errorf("error adding ARP address to %q: %v", *iface, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, parsedAddr); err != nil { + return nil, fmt.Errorf("error adding IP address to %q: %v", *iface, err) + } + + subnet, err := tcpip.NewSubnet(parsedDest, parsedMask) + if err != nil { + return nil, fmt.Errorf("tcpip.Subnet(%s, %s): %s", parsedDest, parsedMask, err) + } + // Add default route; we only support + s.SetRouteTable([]tcpip.Route{ + { + Destination: subnet, + NIC: nicID, + }, + }) + + // Set protocol options. + { + opt := tcpip.TCPSACKEnabled(*sack) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } + } + + // Enable Receive Buffer Auto-Tuning. + { + opt := tcpip.TCPModerateReceiveBufferOption(*moderateRecvBuf) + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) + } + } + + // Set Congestion Control to cubic if requested. + if *cubic { + opt := tcpip.CongestionControlOption("cubic") + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { + return nil, fmt.Errorf("SetTransportProtocolOption(%d, &%T(%s)): %s", tcp.ProtocolNumber, opt, opt, err) + } + } + + return netstackImpl{ + s: s, + addr: parsedAddr, + mode: mode, + }, nil +} + +func (n netstackImpl) dial(address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + if host == "" { + // A host must be provided for the dial. + return nil, fmt.Errorf("no host provided") + } + portNumber, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + addr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(net.ParseIP(host).To4()), + Port: uint16(portNumber), + } + conn, err := gonet.DialTCP(n.s, addr, ipv4.ProtocolNumber) + if err != nil { + return nil, err + } + return conn, nil +} + +func (n netstackImpl) listen(port int) (net.Listener, error) { + addr := tcpip.FullAddress{ + NIC: nicID, + Port: uint16(port), + } + listener, err := gonet.ListenTCP(n.s, addr, ipv4.ProtocolNumber) + if err != nil { + return nil, err + } + return listener, nil +} + +var zeroFieldsRegexp = regexp.MustCompile(`\s*[a-zA-Z0-9]*:0`) + +func (n netstackImpl) printStats() { + // Don't show zero fields. + stats := zeroFieldsRegexp.ReplaceAllString(fmt.Sprintf("%+v", n.s.Stats()), "") + log.Printf("netstack %s Stats: %+v\n", n.mode, stats) +} + +// installProbe installs a TCP Probe function that will dump endpoint +// state to the specified file. It also returns a close func() that +// can be used to close the probeFile. +func (n netstackImpl) installProbe(probeFileName string) (close func()) { + // Install Probe to dump out end point state. + probeFile, err := os.Create(probeFileName) + if err != nil { + log.Fatalf("failed to create tcp_probe file %s: %v", probeFileName, err) + } + probeEncoder := gob.NewEncoder(probeFile) + // Install a TCP Probe. + n.s.AddTCPProbe(func(state stack.TCPEndpointState) { + probeEncoder.Encode(state) + }) + return func() { probeFile.Close() } +} + +func main() { + flag.Parse() + if *port == 0 { + log.Fatalf("no port provided") + } + if *forward == "" { + log.Fatalf("no forward provided") + } + // Seed the random number generator to ensure that we are given MAC addresses that don't + // for the case of the client and server stack. + rand.Seed(time.Now().UTC().UnixNano()) + + if *cpuprofile != "" { + f, err := os.Create(*cpuprofile) + if err != nil { + log.Fatal("could not create CPU profile: ", err) + } + defer func() { + if err := f.Close(); err != nil { + log.Print("error closing CPU profile: ", err) + } + }() + if err := pprof.StartCPUProfile(f); err != nil { + log.Fatal("could not start CPU profile: ", err) + } + defer pprof.StopCPUProfile() + } + + var ( + in impl + out impl + err error + ) + if *server { + in, err = newNetstackImpl("server") + if *serverTCPProbeFile != "" { + defer in.(netstackImpl).installProbe(*serverTCPProbeFile)() + } + + } else { + in = netImpl{} + } + if err != nil { + log.Fatalf("netstack error: %v", err) + } + if *client { + out, err = newNetstackImpl("client") + if *clientTCPProbeFile != "" { + defer out.(netstackImpl).installProbe(*clientTCPProbeFile)() + } + } else { + out = netImpl{} + } + if err != nil { + log.Fatalf("netstack error: %v", err) + } + + // Dial forward before binding. + var next net.Conn + for { + next, err = out.dial(*forward) + if err == nil { + break + } + time.Sleep(50 * time.Millisecond) + log.Printf("connect failed retrying: %v", err) + } + + // Bind once to the server socket. + listener, err := in.listen(*port) + if err != nil { + // Should not happen, everything must be bound by this time + // this proxy is started. + log.Fatalf("unable to listen: %v", err) + } + log.Printf("client=%v, server=%v, ready.", *client, *server) + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGTERM) + go func() { + <-sigs + if *cpuprofile != "" { + pprof.StopCPUProfile() + } + if *memprofile != "" { + f, err := os.Create(*memprofile) + if err != nil { + log.Fatal("could not create memory profile: ", err) + } + defer func() { + if err := f.Close(); err != nil { + log.Print("error closing memory profile: ", err) + } + }() + runtime.GC() // get up-to-date statistics + if err := pprof.WriteHeapProfile(f); err != nil { + log.Fatalf("Unable to write heap profile: %v", err) + } + } + os.Exit(0) + }() + + for { + // Forward all connections. + inConn, err := listener.Accept() + if err != nil { + // This should not happen; we are listening + // successfully. Exhausted all available FDs? + log.Fatalf("accept error: %v", err) + } + log.Printf("incoming connection established.") + + // Copy both ways. + go io.Copy(inConn, next) + go io.Copy(next, inConn) + + // Print stats every second. + go func() { + t := time.NewTicker(time.Second) + defer t.Stop() + for { + <-t.C + in.printStats() + out.printStats() + } + }() + + for { + // Dial again. + next, err = out.dial(*forward) + if err == nil { + break + } + } + } +} diff --git a/test/benchmarks/tools/BUILD b/test/benchmarks/tools/BUILD index 4358551bc..e5734d85c 100644 --- a/test/benchmarks/tools/BUILD +++ b/test/benchmarks/tools/BUILD @@ -9,7 +9,9 @@ go_library( "fio.go", "hey.go", "iperf.go", + "meminfo.go", "redis.go", + "sysbench.go", "tools.go", ], visibility = ["//:sandbox"], @@ -23,7 +25,9 @@ go_test( "fio_test.go", "hey_test.go", "iperf_test.go", + "meminfo_test.go", "redis_test.go", + "sysbench_test.go", ], library = ":tools", ) diff --git a/test/benchmarks/tools/hey.go b/test/benchmarks/tools/hey.go index 699497c64..b1e20e356 100644 --- a/test/benchmarks/tools/hey.go +++ b/test/benchmarks/tools/hey.go @@ -25,7 +25,7 @@ import ( // Hey is for the client application 'hey'. type Hey struct { - Requests int + Requests int // Note: requests cannot be less than concurrency. Concurrency int Doc string } diff --git a/test/benchmarks/tools/meminfo.go b/test/benchmarks/tools/meminfo.go new file mode 100644 index 000000000..2414a96a7 --- /dev/null +++ b/test/benchmarks/tools/meminfo.go @@ -0,0 +1,60 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "regexp" + "strconv" + "testing" +) + +// Meminfo wraps measurements of MemAvailable using /proc/meminfo. +type Meminfo struct { +} + +// MakeCmd returns a command for checking meminfo. +func (*Meminfo) MakeCmd() (string, []string) { + return "cat", []string{"/proc/meminfo"} +} + +// Report takes two reads of meminfo, parses them, and reports the difference +// divided by b.N. +func (*Meminfo) Report(b *testing.B, before, after string) { + b.Helper() + + beforeVal, err := parseMemAvailable(before) + if err != nil { + b.Fatalf("could not parse before value %s: %v", before, err) + } + + afterVal, err := parseMemAvailable(after) + if err != nil { + b.Fatalf("could not parse before value %s: %v", before, err) + } + val := 1024 * ((beforeVal - afterVal) / float64(b.N)) + b.ReportMetric(val, "average_container_size_bytes") +} + +var memInfoRE = regexp.MustCompile(`MemAvailable:\s*(\d+)\skB\n`) + +// parseMemAvailable grabs the MemAvailable number from /proc/meminfo. +func parseMemAvailable(data string) (float64, error) { + match := memInfoRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0, fmt.Errorf("couldn't find MemAvailable in %s", data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/meminfo_test.go b/test/benchmarks/tools/meminfo_test.go new file mode 100644 index 000000000..ba803540f --- /dev/null +++ b/test/benchmarks/tools/meminfo_test.go @@ -0,0 +1,84 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" +) + +// TestMeminfo checks the Meminfo parser on sample output. +func TestMeminfo(t *testing.T) { + sampleData := ` +MemTotal: 16337408 kB +MemFree: 3742696 kB +MemAvailable: 9319948 kB +Buffers: 1433884 kB +Cached: 4607036 kB +SwapCached: 45284 kB +Active: 8288376 kB +Inactive: 2685928 kB +Active(anon): 4724912 kB +Inactive(anon): 1047940 kB +Active(file): 3563464 kB +Inactive(file): 1637988 kB +Unevictable: 326940 kB +Mlocked: 48 kB +SwapTotal: 33292284 kB +SwapFree: 32865736 kB +Dirty: 708 kB +Writeback: 0 kB +AnonPages: 4304204 kB +Mapped: 975424 kB +Shmem: 910292 kB +KReclaimable: 744532 kB +Slab: 1058448 kB +SReclaimable: 744532 kB +SUnreclaim: 313916 kB +KernelStack: 25188 kB +PageTables: 65300 kB +NFS_Unstable: 0 kB +Bounce: 0 kB +WritebackTmp: 0 kB +CommitLimit: 41460988 kB +Committed_AS: 22859492 kB +VmallocTotal: 34359738367 kB +VmallocUsed: 63088 kB +VmallocChunk: 0 kB +Percpu: 9248 kB +HardwareCorrupted: 0 kB +AnonHugePages: 786432 kB +ShmemHugePages: 0 kB +ShmemPmdMapped: 0 kB +FileHugePages: 0 kB +FilePmdMapped: 0 kB +HugePages_Total: 0 +HugePages_Free: 0 +HugePages_Rsvd: 0 +HugePages_Surp: 0 +Hugepagesize: 2048 kB +Hugetlb: 0 kB +DirectMap4k: 5408532 kB +DirectMap2M: 11241472 kB +DirectMap1G: 1048576 kB +` + want := 9319948.0 + got, err := parseMemAvailable(sampleData) + if err != nil { + t.Fatalf("parseMemAvailable failed: %v", err) + } + if got != want { + t.Fatalf("parseMemAvailable got %f, want %f", got, want) + } +} diff --git a/test/benchmarks/tools/redis.go b/test/benchmarks/tools/redis.go index db32460ec..c899ae0d4 100644 --- a/test/benchmarks/tools/redis.go +++ b/test/benchmarks/tools/redis.go @@ -56,7 +56,6 @@ func (r *Redis) Report(b *testing.B, output string) { func (r *Redis) parseOperation(data string) (float64, error) { re := regexp.MustCompile(fmt.Sprintf(`"%s( .*)?","(\d*\.\d*)"`, r.Operation)) match := re.FindStringSubmatch(data) - // If no match, simply don't add it to the result map. if len(match) < 3 { return 0.0, fmt.Errorf("could not find %s in %s", r.Operation, data) } diff --git a/test/benchmarks/tools/sysbench.go b/test/benchmarks/tools/sysbench.go new file mode 100644 index 000000000..6b2f75ca2 --- /dev/null +++ b/test/benchmarks/tools/sysbench.go @@ -0,0 +1,245 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "fmt" + "regexp" + "strconv" + "strings" + "testing" +) + +var warmup = "sysbench --threads=8 --memory-total-size=5G memory run > /dev/null &&" + +// Sysbench represents a 'sysbench' command. +type Sysbench interface { + MakeCmd() []string // Makes a sysbench command. + flags() []string + Report(*testing.B, string) // Reports results contained in string. +} + +// SysbenchBase is the top level struct for sysbench and holds top-level arguments +// for sysbench. See: 'sysbench --help' +type SysbenchBase struct { + Threads int // number of Threads for the test. + Time int // time limit for test in seconds. +} + +// baseFlags returns top level flags. +func (s *SysbenchBase) baseFlags() []string { + var ret []string + if s.Threads > 0 { + ret = append(ret, fmt.Sprintf("--threads=%d", s.Threads)) + } + if s.Time > 0 { + ret = append(ret, fmt.Sprintf("--time=%d", s.Time)) + } + return ret +} + +// SysbenchCPU is for 'sysbench [flags] cpu run' and holds CPU specific arguments. +type SysbenchCPU struct { + Base SysbenchBase + MaxPrime int // upper limit for primes generator [10000]. +} + +// MakeCmd makes commands for SysbenchCPU. +func (s *SysbenchCPU) MakeCmd() []string { + cmd := []string{warmup, "sysbench"} + cmd = append(cmd, s.flags()...) + cmd = append(cmd, "cpu run") + return []string{"sh", "-c", strings.Join(cmd, " ")} +} + +// flags makes flags for SysbenchCPU cmds. +func (s *SysbenchCPU) flags() []string { + cmd := s.Base.baseFlags() + if s.MaxPrime > 0 { + return append(cmd, fmt.Sprintf("--cpu-max-prime=%d", s.MaxPrime)) + } + return cmd +} + +// Report reports the relevant metrics for SysbenchCPU. +func (s *SysbenchCPU) Report(b *testing.B, output string) { + b.Helper() + result, err := s.parseEvents(output) + if err != nil { + b.Fatalf("parsing CPU events from %s failed: %v", output, err) + } + b.ReportMetric(result, "cpu_events_per_second") +} + +var cpuEventsPerSecondRE = regexp.MustCompile(`events per second:\s*(\d*.?\d*)\n`) + +// parseEvents parses cpu events per second. +func (s *SysbenchCPU) parseEvents(data string) (float64, error) { + match := cpuEventsPerSecondRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("could not find events per second: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +// SysbenchMemory is for 'sysbench [FLAGS] memory run' and holds Memory specific arguments. +type SysbenchMemory struct { + Base SysbenchBase + BlockSize string // size of test memory block [1K]. + TotalSize string // size of data to transfer [100G]. + Scope string // memory access scope {global, local} [global]. + HugeTLB bool // allocate memory from HugeTLB [off]. + OperationType string // type of memory ops {read, write, none} [write]. + AccessMode string // access mode {seq, rnd} [seq]. +} + +// MakeCmd makes commands for SysbenchMemory. +func (s *SysbenchMemory) MakeCmd() []string { + cmd := []string{warmup, "sysbench"} + cmd = append(cmd, s.flags()...) + cmd = append(cmd, "memory run") + return []string{"sh", "-c", strings.Join(cmd, " ")} +} + +// flags makes flags for SysbenchMemory cmds. +func (s *SysbenchMemory) flags() []string { + cmd := s.Base.baseFlags() + if s.BlockSize != "" { + cmd = append(cmd, fmt.Sprintf("--memory-block-size=%s", s.BlockSize)) + } + if s.TotalSize != "" { + cmd = append(cmd, fmt.Sprintf("--memory-total-size=%s", s.TotalSize)) + } + if s.Scope != "" { + cmd = append(cmd, fmt.Sprintf("--memory-scope=%s", s.Scope)) + } + if s.HugeTLB { + cmd = append(cmd, "--memory-hugetlb=on") + } + if s.OperationType != "" { + cmd = append(cmd, fmt.Sprintf("--memory-oper=%s", s.OperationType)) + } + if s.AccessMode != "" { + cmd = append(cmd, fmt.Sprintf("--memory-access-mode=%s", s.AccessMode)) + } + return cmd +} + +// Report reports the relevant metrics for SysbenchMemory. +func (s *SysbenchMemory) Report(b *testing.B, output string) { + b.Helper() + result, err := s.parseOperations(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, "operations_per_second") +} + +var memoryOperationsRE = regexp.MustCompile(`Total\soperations:\s+\d*\s*\((\d*\.\d*)\sper\ssecond\)`) + +// parseOperations parses memory operations per second form sysbench memory ouput. +func (s *SysbenchMemory) parseOperations(data string) (float64, error) { + match := memoryOperationsRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("couldn't find memory operations per second: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +// SysbenchMutex is for 'sysbench [FLAGS] mutex run' and holds Mutex specific arguments. +type SysbenchMutex struct { + Base SysbenchBase + Num int // total size of mutex array [4096]. + Locks int // number of mutex locks per thread [50K]. + Loops int // number of loops to do outside mutex lock [10K]. +} + +// MakeCmd makes commands for SysbenchMutex. +func (s *SysbenchMutex) MakeCmd() []string { + cmd := []string{warmup, "sysbench"} + cmd = append(cmd, s.flags()...) + cmd = append(cmd, "mutex run") + return []string{"sh", "-c", strings.Join(cmd, " ")} +} + +// flags makes flags for SysbenchMutex commands. +func (s *SysbenchMutex) flags() []string { + var cmd []string + cmd = append(cmd, s.Base.baseFlags()...) + if s.Num > 0 { + cmd = append(cmd, fmt.Sprintf("--mutex-num=%d", s.Num)) + } + if s.Locks > 0 { + cmd = append(cmd, fmt.Sprintf("--mutex-locks=%d", s.Locks)) + } + if s.Loops > 0 { + cmd = append(cmd, fmt.Sprintf("--mutex-loops=%d", s.Loops)) + } + return cmd +} + +// Report parses and reports relevant sysbench mutex metrics. +func (s *SysbenchMutex) Report(b *testing.B, output string) { + b.Helper() + + result, err := s.parseExecutionTime(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, "average_execution_time_secs") + + result, err = s.parseDeviation(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result, "stdev_execution_time_secs") + + result, err = s.parseLatency(output) + if err != nil { + b.Fatalf("parsing result %s failed with err: %v", output, err) + } + b.ReportMetric(result/1000, "average_latency_secs") +} + +var executionTimeRE = regexp.MustCompile(`execution time \(avg/stddev\):\s*(\d*.?\d*)/(\d*.?\d*)`) + +// parseExecutionTime parses threads fairness average execution time from sysbench output. +func (s *SysbenchMutex) parseExecutionTime(data string) (float64, error) { + match := executionTimeRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("could not find execution time average: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} + +// parseDeviation parses threads fairness stddev time from sysbench output. +func (s *SysbenchMutex) parseDeviation(data string) (float64, error) { + match := executionTimeRE.FindStringSubmatch(data) + if len(match) < 3 { + return 0.0, fmt.Errorf("could not find execution time deviation: %s", data) + } + return strconv.ParseFloat(match[2], 64) +} + +var averageLatencyRE = regexp.MustCompile(`avg:[^\n^\d]*(\d*\.?\d*)`) + +// parseLatency parses latency from sysbench output. +func (s *SysbenchMutex) parseLatency(data string) (float64, error) { + match := averageLatencyRE.FindStringSubmatch(data) + if len(match) < 2 { + return 0.0, fmt.Errorf("could not find average latency: %s", data) + } + return strconv.ParseFloat(match[1], 64) +} diff --git a/test/benchmarks/tools/sysbench_test.go b/test/benchmarks/tools/sysbench_test.go new file mode 100644 index 000000000..850d1939e --- /dev/null +++ b/test/benchmarks/tools/sysbench_test.go @@ -0,0 +1,169 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "testing" +) + +// TestSysbenchCpu tests parses on sample 'sysbench cpu' output. +func TestSysbenchCpu(t *testing.T) { + sampleData := ` +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Prime numbers limit: 10000 + +Initializing worker threads... + +Threads started! + +CPU speed: + events per second: 9093.38 + +General statistics: + total time: 10.0007s + total number of events: 90949 + +Latency (ms): + min: 0.64 + avg: 0.88 + max: 24.65 + 95th percentile: 1.55 + sum: 79936.91 + +Threads fairness: + events (avg/stddev): 11368.6250/831.38 + execution time (avg/stddev): 9.9921/0.01 +` + sysbench := SysbenchCPU{} + want := 9093.38 + if got, err := sysbench.parseEvents(sampleData); err != nil { + t.Fatalf("parse cpu events failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } +} + +// TestSysbenchMemory tests parsers on sample 'sysbench memory' output. +func TestSysbenchMemory(t *testing.T) { + sampleData := ` +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Running memory speed test with the following options: + block size: 1KiB + total size: 102400MiB + operation: write + scope: global + +Initializing worker threads... + +Threads started! + +Total operations: 47999046 (9597428.64 per second) + +46874.07 MiB transferred (9372.49 MiB/sec) + + +General statistics: + total time: 5.0001s + total number of events: 47999046 + +Latency (ms): + min: 0.00 + avg: 0.00 + max: 0.21 + 95th percentile: 0.00 + sum: 33165.91 + +Threads fairness: + events (avg/stddev): 5999880.7500/111242.52 + execution time (avg/stddev): 4.1457/0.09 +` + sysbench := SysbenchMemory{} + want := 9597428.64 + if got, err := sysbench.parseOperations(sampleData); err != nil { + t.Fatalf("parse memory ops failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } +} + +// TestSysbenchMutex tests parsers on sample 'sysbench mutex' output. +func TestSysbenchMutex(t *testing.T) { + sampleData := ` +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +The 'mutex' test requires a command argument. See 'sysbench mutex help' +root@ec078132e294:/# sysbench mutex --threads=8 run +sysbench 1.0.11 (using system LuaJIT 2.1.0-beta3) + +Running the test with following options: +Number of threads: 8 +Initializing random number generator from current time + + +Initializing worker threads... + +Threads started! + + +General statistics: + total time: 0.2320s + total number of events: 8 + +Latency (ms): + min: 152.35 + avg: 192.48 + max: 231.41 + 95th percentile: 231.53 + sum: 1539.83 + +Threads fairness: + events (avg/stddev): 1.0000/0.00 + execution time (avg/stddev): 0.1925/0.04 +` + + sysbench := SysbenchMutex{} + want := .1925 + if got, err := sysbench.parseExecutionTime(sampleData); err != nil { + t.Fatalf("parse mutex time failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } + + want = 0.04 + if got, err := sysbench.parseDeviation(sampleData); err != nil { + t.Fatalf("parse mutex deviation failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } + + want = 192.48 + if got, err := sysbench.parseLatency(sampleData); err != nil { + t.Fatalf("parse mutex time failed: %v", err) + } else if want != got { + t.Fatalf("got: %f want: %f", got, want) + } +} diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 6fe6d304f..8425abecb 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -64,9 +64,10 @@ func TestLifeCycle(t *testing.T) { defer d.CleanUp(ctx) // Start the container. + port := 80 if err := d.Create(ctx, dockerutil.RunOpts{ Image: "basic/nginx", - Ports: []int{80}, + Ports: []int{port}, }); err != nil { t.Fatalf("docker create failed: %v", err) } @@ -74,16 +75,15 @@ func TestLifeCycle(t *testing.T) { t.Fatalf("docker start failed: %v", err) } - // Test that container is working. - port, err := d.FindPort(ctx, 80) + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("docker.FindPort(80) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } client := http.Client{Timeout: defaultWait} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Errorf("http request failed: %v", err) } @@ -105,27 +105,28 @@ func TestPauseResume(t *testing.T) { defer d.CleanUp(ctx) // Start the container. + port := 8080 if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/python", - Ports: []int{8080}, // See Dockerfile. + Ports: []int{port}, // See Dockerfile. }); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(ctx, 8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check that container is working. client := http.Client{Timeout: defaultWait} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Error("http request failed:", err) } @@ -135,7 +136,7 @@ func TestPauseResume(t *testing.T) { // Check if container is paused. client = http.Client{Timeout: 10 * time.Millisecond} // Don't wait a minute. - switch _, err := client.Get(fmt.Sprintf("http://localhost:%d", port)); v := err.(type) { + switch _, err := client.Get(fmt.Sprintf("http://%s:%d", ip.String(), port)); v := err.(type) { case nil: t.Errorf("http req expected to fail but it succeeded") case net.Error: @@ -151,13 +152,13 @@ func TestPauseResume(t *testing.T) { } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. client = http.Client{Timeout: defaultWait} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Error("http request failed:", err) } } @@ -167,14 +168,22 @@ func TestCheckpointRestore(t *testing.T) { t.Skip("Pause/resume is not supported.") } + // TODO(gvisor.dev/issue/3373): Remove after implementing. + if usingVFS2, err := dockerutil.UsingVFS2(); usingVFS2 { + t.Skip("CheckpointRestore not implemented in VFS2.") + } else if err != nil { + t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err) + } + ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) defer d.CleanUp(ctx) // Start the container. + port := 8080 if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/python", - Ports: []int{8080}, // See Dockerfile. + Ports: []int{port}, // See Dockerfile. }); err != nil { t.Fatalf("docker run failed: %v", err) } @@ -192,20 +201,20 @@ func TestCheckpointRestore(t *testing.T) { t.Fatalf("docker restore failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(ctx, 8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("docker.FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Check if container is working again. client := http.Client{Timeout: defaultWait} - if err := httpRequestSucceeds(client, "localhost", port); err != nil { + if err := httpRequestSucceeds(client, ip.String(), port); err != nil { t.Error("http request failed:", err) } } @@ -467,6 +476,24 @@ func TestHostOverlayfsRewindDir(t *testing.T) { } } +// Basic test for linkat(2). Syscall tests requires CAP_DAC_READ_SEARCH and it +// cannot use tricks like userns as root. For this reason, run a basic link test +// to ensure some coverage. +func TestLink(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + if got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/linktest", + WorkDir: "/root", + }, "./link_test"); err != nil { + t.Fatalf("docker run failed: %v", err) + } else if got != "" { + t.Errorf("test failed:\n%s", got) + } +} + func TestMain(m *testing.M) { dockerutil.EnsureSupportedDockerVersion() flag.Parse() diff --git a/test/fuse/BUILD b/test/fuse/BUILD new file mode 100644 index 000000000..8e31fdd41 --- /dev/null +++ b/test/fuse/BUILD @@ -0,0 +1,73 @@ +load("//test/runner:defs.bzl", "syscall_test") + +package(licenses = ["notice"]) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:stat_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:open_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:release_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:mknod_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:symlink_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:readlink_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:mkdir_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:read_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:write_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:rmdir_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:readdir_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:create_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:unlink_test", +) + +syscall_test( + fuse = "True", + test = "//test/fuse/linux:setstat_test", +) diff --git a/test/fuse/README.md b/test/fuse/README.md new file mode 100644 index 000000000..65add57e2 --- /dev/null +++ b/test/fuse/README.md @@ -0,0 +1,188 @@ +# gVisor FUSE Test Suite + +This is an integration test suite for fuse(4) filesystem. It runs under gVisor +sandbox container with VFS2 and FUSE function enabled. + +This document describes the framework of FUSE integration test, how to use it, +and the guidelines that should be followed when adding new testing features. + +## Integration Test Framework + +By inheriting the `FuseTest` class defined in `linux/fuse_base.h`, every test +fixture can run in an environment with `mount_point_` mounted by a fake FUSE +server. It creates a `socketpair(2)` to send and receive control commands and +data between the client and the server. Because the FUSE server runs in the +background thread, gTest cannot catch its assertion failure immediately. Thus, +`TearDown()` function sends command to the FUSE server to check if all gTest +assertion in the server are successful and all requests and preset responses are +consumed. + +## Communication Diagram + +Diagram below describes how a testing thread communicates with the FUSE server +to achieve integration test. + +For the following diagram, `>` means entering the function, `<` is leaving the +function, and `=` indicates sequentially entering and leaving. Not necessarily +follow exactly the below diagram due to the nature of a multi-threaded system, +however, it is still helpful to know when the client waits for the server to +complete a command and when the server awaits the next instruction. + +``` +| Client (Testing Thread) | Server (FUSE Server Thread) +| | +| >TEST_F() | +| >SetUp() | +| =MountFuse() | +| >SetUpFuseServer() | +| [create communication socket]| +| =fork() | =fork() +| [wait server complete] | +| | =ServerConsumeFuseInit() +| | =ServerCompleteWith() +| <SetUpFuseServer() | +| <SetUp() | +| [testing main] | +| | >ServerFuseLoop() +| | [poll on socket and fd] +| >SetServerResponse() | +| [write data to socket] | +| [wait server complete] | +| | [socket event occurs] +| | >ServerHandleCommand() +| | >ServerReceiveResponse() +| | [read data from socket] +| | [save data to memory] +| | <ServerReceiveResponse() +| | =ServerCompleteWith() +| <SetServerResponse() | +| | <ServerHandleCommand() +| >[Do fs operation] | +| [wait for fs response] | +| | [fd event occurs] +| | >ServerProcessFuseRequest() +| | =[read fs request] +| | =[save fs request to memory] +| | =[write fs response] +| <[Do fs operation] | +| | <ServerProcessFuseRequest() +| | +| =[Test fs operation result] | +| | +| >GetServerActualRequest() | +| [write data to socket] | +| [wait data from server] | +| | [socket event occurs] +| | >ServerHandleCommand() +| | >ServerSendReceivedRequest() +| | [write data to socket] +| [read data from socket] | +| [wait server complete] | +| | <ServerSendReceivedRequest() +| | =ServerCompleteWith() +| <GetServerActualRequest() | +| | <ServerHandleCommand() +| | +| =[Test actual request] | +| | +| >TearDown() | +| ... | +| >GetServerNumUnsentResponses() | +| [write data to socket] | +| [wait server complete] | +| | [socket event arrive] +| | >ServerHandleCommand() +| | >ServerSendData() +| | [write data to socket] +| | <ServerSendData() +| | =ServerCompleteWith() +| [read data from socket] | +| [test if all succeeded] | +| <GetServerNumUnsentResponses() | +| | <ServerHandleCommand() +| =UnmountFuse() | +| <TearDown() | +| <TEST_F() | +``` + +## Running the tests + +Based on syscall tests, FUSE tests generate targets only with vfs2 and fuse +enabled. The corresponding targets end in `_fuse`. + +For example, to run fuse test in `stat_test.cc`: + +```bash +$ bazel test //test/fuse:stat_test_runsc_ptrace_vfs2_fuse +``` + +Test all targets tagged with fuse: + +```bash +$ bazel test --test_tag_filters=fuse //test/fuse/... +``` + +## Writing a new FUSE test + +1. Add test targets in `BUILD` and `linux/BUILD`. +2. Inherit your test from `FuseTest` base class. It allows you to: + - Fork a fake FUSE server in background during each test setup. + - Create a pair of sockets for communication and provide utility + functions. + - Stop FUSE server and check if error occurs in it after test completes. +3. Build the expected opcode-response pairs of your FUSE operation. +4. Call `SetServerResponse()` to preset the next expected opcode and response. +5. Do real filesystem operations (FUSE is mounted at `mount_point_`). +6. Check FUSE response and/or errors. +7. Retrieve FUSE request by `GetServerActualRequest()`. +8. Check if the request is as expected. + +A few customized matchers used in syscalls test are encouraged to test the +outcome of filesystem operations. Such as: + +```cc +SyscallSucceeds() +SyscallSucceedsWithValue(...) +SyscallFails() +SyscallFailsWithErrno(...) +``` + +Please refer to [test/syscalls/README.md](../syscalls/README.md) for further +details. + +## Writing a new FuseTestCmd + +A `FuseTestCmd` is a control protocol used in the communication between the +testing thread and the FUSE server. Such commands are sent from the testing +thread to the FUSE server to set up, control, or inspect the behavior of the +FUSE server in response to a sequence of FUSE requests. + +The lifecycle of a command contains following steps: + +1. The testing thread sends a `FuseTestCmd` via socket and waits for + completion. +2. The FUSE server receives the command and does corresponding action. +3. (Optional) The testing thread reads data from socket. +4. The FUSE server sends a success indicator via socket after processing. +5. The testing thread gets the success signal and continues testing. + +The success indicator, i.e. `WaitServerComplete()`, is crucial at the end of +each `FuseTestCmd` sent from the testing thread. Because we don't want to begin +filesystem operation if the requests have not been completely set up. Also, to +test FUSE interactions in a sequential manner, concurrent requests are not +supported now. + +To add a new `FuseTestCmd`, one must comply with following format: + +1. Add a new `FuseTestCmd` enum class item defined in `linux/fuse_base.h` +2. Add a `SetServerXXX()` or `GetServerXXX()` public function in `FuseTest`. + This is how the testing thread will call to send control message. Define how + many bytes you want to send along with the command and what you will expect + to receive. Finally it should block and wait for a success indicator from + the FUSE server. +3. Add a handler logic in the switch condition of `ServerHandleCommand()`. Use + `ServerSendData()` or declare a new private function such as + `ServerReceiveXXX()` or `ServerSendXXX()`. It is mandatory to set it private + since only the FUSE server (forked from `FuseTest` base class) can call it. + This is the server part of the specific `FuseTestCmd` and the format of the + data should be consistent with what the client expects in the previous step. diff --git a/test/fuse/linux/BUILD b/test/fuse/linux/BUILD new file mode 100644 index 000000000..7673252ec --- /dev/null +++ b/test/fuse/linux/BUILD @@ -0,0 +1,230 @@ +load("//tools:defs.bzl", "cc_binary", "cc_library", "gtest") + +package( + default_visibility = ["//:sandbox"], + licenses = ["notice"], +) + +cc_binary( + name = "stat_test", + testonly = 1, + srcs = ["stat_test.cc"], + deps = [ + gtest, + ":fuse_fd_util", + "//test/util:cleanup", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "open_test", + testonly = 1, + srcs = ["open_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "release_test", + testonly = 1, + srcs = ["release_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "mknod_test", + testonly = 1, + srcs = ["mknod_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "symlink_test", + testonly = 1, + srcs = ["symlink_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "readlink_test", + testonly = 1, + srcs = ["readlink_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "mkdir_test", + testonly = 1, + srcs = ["mkdir_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "setstat_test", + testonly = 1, + srcs = ["setstat_test.cc"], + deps = [ + gtest, + ":fuse_fd_util", + "//test/util:cleanup", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "rmdir_test", + testonly = 1, + srcs = ["rmdir_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "readdir_test", + testonly = 1, + srcs = ["readdir_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_library( + name = "fuse_base", + testonly = 1, + srcs = ["fuse_base.cc"], + hdrs = ["fuse_base.h"], + deps = [ + gtest, + "//test/util:fuse_util", + "//test/util:posix_error", + "//test/util:temp_path", + "//test/util:test_util", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "fuse_fd_util", + testonly = 1, + srcs = ["fuse_fd_util.cc"], + hdrs = ["fuse_fd_util.h"], + deps = [ + gtest, + ":fuse_base", + "//test/util:cleanup", + "//test/util:file_descriptor", + "//test/util:fuse_util", + "//test/util:posix_error", + ], +) + +cc_binary( + name = "read_test", + testonly = 1, + srcs = ["read_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "write_test", + testonly = 1, + srcs = ["write_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "create_test", + testonly = 1, + srcs = ["create_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fs_util", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "unlink_test", + testonly = 1, + srcs = ["unlink_test.cc"], + deps = [ + gtest, + ":fuse_base", + "//test/util:fuse_util", + "//test/util:temp_umask", + "//test/util:test_main", + "//test/util:test_util", + ], +) diff --git a/test/fuse/linux/create_test.cc b/test/fuse/linux/create_test.cc new file mode 100644 index 000000000..9a0219a58 --- /dev/null +++ b/test/fuse/linux/create_test.cc @@ -0,0 +1,128 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/temp_umask.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class CreateTest : public FuseTest { + protected: + const std::string test_file_name_ = "test_file"; + const mode_t mode = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(CreateTest, CreateFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_name_); + + // Ensure the file doesn't exist. + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = -ENOENT, + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_LOOKUP, iov_out); + + // creat(2) is equal to open(2) with open_flags O_CREAT | O_WRONLY | O_TRUNC. + const mode_t new_mask = S_IWGRP | S_IWOTH; + const int open_flags = O_CREAT | O_WRONLY | O_TRUNC; + out_header.error = 0; + out_header.len = sizeof(struct fuse_out_header) + + sizeof(struct fuse_entry_out) + sizeof(struct fuse_open_out); + struct fuse_entry_out entry_payload = DefaultEntryOut(mode & ~new_mask, 2); + struct fuse_open_out out_payload = { + .fh = 1, + .open_flags = open_flags, + }; + iov_out = FuseGenerateIovecs(out_header, entry_payload, out_payload); + SetServerResponse(FUSE_CREATE, iov_out); + + // kernfs generates a successive FUSE_OPEN after the file is created. Linux's + // fuse kernel module will not send this FUSE_OPEN after creat(2). + out_header.len = + sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out); + iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_OPEN, iov_out); + + int fd; + TempUmask mask(new_mask); + EXPECT_THAT(fd = creat(test_file_path.c_str(), mode), SyscallSucceeds()); + EXPECT_THAT(fcntl(fd, F_GETFL), + SyscallSucceedsWithValue(open_flags & O_ACCMODE)); + + struct fuse_in_header in_header; + struct fuse_create_in in_payload; + std::vector<char> name(test_file_name_.size() + 1); + auto iov_in = FuseGenerateIovecs(in_header, in_payload, name); + + // Skip the request of FUSE_LOOKUP. + SkipServerActualRequest(); + + // Get the first FUSE_CREATE. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload) + + test_file_name_.size() + 1); + EXPECT_EQ(in_header.opcode, FUSE_CREATE); + EXPECT_EQ(in_payload.flags, open_flags); + EXPECT_EQ(in_payload.mode, mode & ~new_mask); + EXPECT_EQ(in_payload.umask, new_mask); + EXPECT_EQ(std::string(name.data()), test_file_name_); + + // Get the successive FUSE_OPEN. + struct fuse_open_in in_payload_open; + iov_in = FuseGenerateIovecs(in_header, in_payload_open); + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload_open)); + EXPECT_EQ(in_header.opcode, FUSE_OPEN); + EXPECT_EQ(in_payload_open.flags, open_flags & O_ACCMODE); + + EXPECT_THAT(close(fd), SyscallSucceeds()); + // Skip the FUSE_RELEASE. + SkipServerActualRequest(); +} + +TEST_F(CreateTest, CreateFileAlreadyExists) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_name_); + + const int open_flags = O_CREAT | O_EXCL; + + SetServerInodeLookup(test_file_name_); + + EXPECT_THAT(open(test_file_path.c_str(), mode, open_flags), + SyscallFailsWithErrno(EEXIST)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/fuse_base.cc b/test/fuse/linux/fuse_base.cc new file mode 100644 index 000000000..5b45804e1 --- /dev/null +++ b/test/fuse/linux/fuse_base.cc @@ -0,0 +1,447 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/fuse/linux/fuse_base.h" + +#include <fcntl.h> +#include <linux/fuse.h> +#include <poll.h> +#include <sys/mount.h> +#include <sys/socket.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/strings/str_format.h" +#include "test/util/fuse_util.h" +#include "test/util/posix_error.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +void FuseTest::SetUp() { + MountFuse(); + SetUpFuseServer(); +} + +void FuseTest::TearDown() { + EXPECT_EQ(GetServerNumUnconsumedRequests(), 0); + EXPECT_EQ(GetServerNumUnsentResponses(), 0); + UnmountFuse(); +} + +// Sends 3 parts of data to the FUSE server: +// 1. The `kSetResponse` command +// 2. The expected opcode +// 3. The fake FUSE response +// Then waits for the FUSE server to notify its completion. +void FuseTest::SetServerResponse(uint32_t opcode, + std::vector<struct iovec>& iovecs) { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSetResponse); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(write)(sock_[0], &opcode, sizeof(opcode)), + SyscallSucceedsWithValue(sizeof(opcode))); + + EXPECT_THAT(RetryEINTR(writev)(sock_[0], iovecs.data(), iovecs.size()), + SyscallSucceeds()); + + WaitServerComplete(); +} + +// Waits for the FUSE server to finish its blocking job and check if it +// completes without errors. +void FuseTest::WaitServerComplete() { + uint32_t success; + EXPECT_THAT(RetryEINTR(read)(sock_[0], &success, sizeof(success)), + SyscallSucceedsWithValue(sizeof(success))); + ASSERT_EQ(success, 1); +} + +// Sends the `kGetRequest` command to the FUSE server, then reads the next +// request into iovec struct. The order of calling this function should be +// the same as the one of SetServerResponse(). +void FuseTest::GetServerActualRequest(std::vector<struct iovec>& iovecs) { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kGetRequest); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(readv)(sock_[0], iovecs.data(), iovecs.size()), + SyscallSucceeds()); + + WaitServerComplete(); +} + +// Sends a FuseTestCmd command to the FUSE server, reads from the socket, and +// returns the corresponding data. +uint32_t FuseTest::GetServerData(uint32_t cmd) { + uint32_t data; + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(read)(sock_[0], &data, sizeof(data)), + SyscallSucceedsWithValue(sizeof(data))); + + WaitServerComplete(); + return data; +} + +uint32_t FuseTest::GetServerNumUnconsumedRequests() { + return GetServerData( + static_cast<uint32_t>(FuseTestCmd::kGetNumUnconsumedRequests)); +} + +uint32_t FuseTest::GetServerNumUnsentResponses() { + return GetServerData( + static_cast<uint32_t>(FuseTestCmd::kGetNumUnsentResponses)); +} + +uint32_t FuseTest::GetServerTotalReceivedBytes() { + return GetServerData( + static_cast<uint32_t>(FuseTestCmd::kGetTotalReceivedBytes)); +} + +// Sends the `kSkipRequest` command to the FUSE server, which would skip +// current stored request data. +void FuseTest::SkipServerActualRequest() { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSkipRequest); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + WaitServerComplete(); +} + +// Sends the `kSetInodeLookup` command, expected mode, and the path of the +// inode to create under the mount point. +void FuseTest::SetServerInodeLookup(const std::string& path, mode_t mode, + uint64_t size) { + uint32_t cmd = static_cast<uint32_t>(FuseTestCmd::kSetInodeLookup); + EXPECT_THAT(RetryEINTR(write)(sock_[0], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + EXPECT_THAT(RetryEINTR(write)(sock_[0], &mode, sizeof(mode)), + SyscallSucceedsWithValue(sizeof(mode))); + + EXPECT_THAT(RetryEINTR(write)(sock_[0], &size, sizeof(size)), + SyscallSucceedsWithValue(sizeof(size))); + + // Pad 1 byte for null-terminate c-string. + EXPECT_THAT(RetryEINTR(write)(sock_[0], path.c_str(), path.size() + 1), + SyscallSucceedsWithValue(path.size() + 1)); + + WaitServerComplete(); +} + +void FuseTest::MountFuse(const char* mountOpts) { + EXPECT_THAT(dev_fd_ = open("/dev/fuse", O_RDWR), SyscallSucceeds()); + + std::string mount_opts = absl::StrFormat("fd=%d,%s", dev_fd_, mountOpts); + mount_point_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + EXPECT_THAT(mount("fuse", mount_point_.path().c_str(), "fuse", + MS_NODEV | MS_NOSUID, mount_opts.c_str()), + SyscallSucceeds()); +} + +void FuseTest::UnmountFuse() { + EXPECT_THAT(umount(mount_point_.path().c_str()), SyscallSucceeds()); + // TODO(gvisor.dev/issue/3330): ensure the process is terminated successfully. +} + +// Consumes the first FUSE request and returns the corresponding PosixError. +PosixError FuseTest::ServerConsumeFuseInit( + const struct fuse_init_out* out_payload) { + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + RETURN_ERROR_IF_SYSCALL_FAIL( + RetryEINTR(read)(dev_fd_, buf.data(), buf.size())); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_init_out), + .error = 0, + .unique = 2, + }; + // Returns a fake fuse_init_out with 7.0 version to avoid ECONNREFUSED + // error in the initialization of FUSE connection. + auto iov_out = FuseGenerateIovecs( + out_header, *const_cast<struct fuse_init_out*>(out_payload)); + + RETURN_ERROR_IF_SYSCALL_FAIL( + RetryEINTR(writev)(dev_fd_, iov_out.data(), iov_out.size())); + return NoError(); +} + +// Reads 1 expected opcode and a fake response from socket and save them into +// the serial buffer of this testing instance. +void FuseTest::ServerReceiveResponse() { + ssize_t len; + uint32_t opcode; + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + EXPECT_THAT(RetryEINTR(read)(sock_[1], &opcode, sizeof(opcode)), + SyscallSucceedsWithValue(sizeof(opcode))); + + EXPECT_THAT(len = RetryEINTR(read)(sock_[1], buf.data(), buf.size()), + SyscallSucceeds()); + + responses_.AddMemBlock(opcode, buf.data(), len); +} + +// Writes 1 byte of success indicator through socket. +void FuseTest::ServerCompleteWith(bool success) { + uint32_t data = success ? 1 : 0; + ServerSendData(data); +} + +// ServerFuseLoop is the implementation of the fake FUSE server. Monitors 2 +// file descriptors: /dev/fuse and sock_[1]. Events from /dev/fuse are FUSE +// requests and events from sock_[1] are FUSE testing commands, leading by +// a FuseTestCmd data to indicate the command. +void FuseTest::ServerFuseLoop() { + const int nfds = 2; + struct pollfd fds[nfds] = { + { + .fd = dev_fd_, + .events = POLL_IN | POLLHUP | POLLERR | POLLNVAL, + }, + { + .fd = sock_[1], + .events = POLL_IN | POLLHUP | POLLERR | POLLNVAL, + }, + }; + + while (true) { + ASSERT_THAT(poll(fds, nfds, -1), SyscallSucceeds()); + + for (int fd_idx = 0; fd_idx < nfds; ++fd_idx) { + if (fds[fd_idx].revents == 0) continue; + + ASSERT_EQ(fds[fd_idx].revents, POLL_IN); + if (fds[fd_idx].fd == sock_[1]) { + ServerHandleCommand(); + } else if (fds[fd_idx].fd == dev_fd_) { + ServerProcessFuseRequest(); + } + } + } +} + +// SetUpFuseServer creates 1 socketpair and fork the process. The parent thread +// becomes testing thread and the child thread becomes the FUSE server running +// in background. These 2 threads are connected via socketpair. sock_[0] is +// opened in testing thread and sock_[1] is opened in the FUSE server. +void FuseTest::SetUpFuseServer(const struct fuse_init_out* payload) { + ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, 0, sock_), SyscallSucceeds()); + + switch (fork()) { + case -1: + GTEST_FAIL(); + return; + case 0: + break; + default: + ASSERT_THAT(close(sock_[1]), SyscallSucceeds()); + WaitServerComplete(); + return; + } + + // Begin child thread, i.e. the FUSE server. + ASSERT_THAT(close(sock_[0]), SyscallSucceeds()); + ServerCompleteWith(ServerConsumeFuseInit(payload).ok()); + ServerFuseLoop(); + _exit(0); +} + +void FuseTest::ServerSendData(uint32_t data) { + EXPECT_THAT(RetryEINTR(write)(sock_[1], &data, sizeof(data)), + SyscallSucceedsWithValue(sizeof(data))); +} + +// Reads FuseTestCmd sent from testing thread and routes to correct handler. +// Since each command should be a blocking operation, a `ServerCompleteWith()` +// is required after the switch keyword. +void FuseTest::ServerHandleCommand() { + uint32_t cmd; + EXPECT_THAT(RetryEINTR(read)(sock_[1], &cmd, sizeof(cmd)), + SyscallSucceedsWithValue(sizeof(cmd))); + + switch (static_cast<FuseTestCmd>(cmd)) { + case FuseTestCmd::kSetResponse: + ServerReceiveResponse(); + break; + case FuseTestCmd::kSetInodeLookup: + ServerReceiveInodeLookup(); + break; + case FuseTestCmd::kGetRequest: + ServerSendReceivedRequest(); + break; + case FuseTestCmd::kGetTotalReceivedBytes: + ServerSendData(static_cast<uint32_t>(requests_.UsedBytes())); + break; + case FuseTestCmd::kGetNumUnconsumedRequests: + ServerSendData(static_cast<uint32_t>(requests_.RemainingBlocks())); + break; + case FuseTestCmd::kGetNumUnsentResponses: + ServerSendData(static_cast<uint32_t>(responses_.RemainingBlocks())); + break; + case FuseTestCmd::kSkipRequest: + ServerSkipReceivedRequest(); + break; + default: + FAIL() << "Unknown FuseTestCmd " << cmd; + break; + } + + ServerCompleteWith(!HasFailure()); +} + +// Reads the expected file mode and the path of one file. Crafts a basic +// `fuse_entry_out` memory block and inserts into a map for future use. +// The FUSE server will always return this response if a FUSE_LOOKUP +// request with this specific path comes in. +void FuseTest::ServerReceiveInodeLookup() { + mode_t mode; + uint64_t size; + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + + EXPECT_THAT(RetryEINTR(read)(sock_[1], &mode, sizeof(mode)), + SyscallSucceedsWithValue(sizeof(mode))); + + EXPECT_THAT(RetryEINTR(read)(sock_[1], &size, sizeof(size)), + SyscallSucceedsWithValue(sizeof(size))); + + EXPECT_THAT(RetryEINTR(read)(sock_[1], buf.data(), buf.size()), + SyscallSucceeds()); + + std::string path(buf.data()); + + uint32_t out_len = + sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out); + struct fuse_out_header out_header = { + .len = out_len, + .error = 0, + }; + struct fuse_entry_out out_payload = DefaultEntryOut(mode, nodeid_); + // Since this is only used in test, nodeid_ is simply increased by 1 to + // comply with the unqiueness of different path. + ++nodeid_; + + // Set the size. + out_payload.attr.size = size; + + memcpy(buf.data(), &out_header, sizeof(out_header)); + memcpy(buf.data() + sizeof(out_header), &out_payload, sizeof(out_payload)); + lookups_.AddMemBlock(FUSE_LOOKUP, buf.data(), out_len); + lookup_map_[path] = lookups_.Next(); +} + +// Sends the received request pointed by current cursor and advances cursor. +void FuseTest::ServerSendReceivedRequest() { + if (requests_.End()) { + FAIL() << "No more received request."; + return; + } + auto mem_block = requests_.Next(); + EXPECT_THAT( + RetryEINTR(write)(sock_[1], requests_.DataAtOffset(mem_block.offset), + mem_block.len), + SyscallSucceedsWithValue(mem_block.len)); +} + +// Skip the request pointed by current cursor. +void FuseTest::ServerSkipReceivedRequest() { + if (requests_.End()) { + FAIL() << "No more received request."; + return; + } + requests_.Next(); +} + +// Handles FUSE request. Reads request from /dev/fuse, checks if it has the +// same opcode as expected, and responds with the saved fake FUSE response. +// The FUSE request is copied to the serial buffer and can be retrieved one- +// by-one by calling GetServerActualRequest from testing thread. +void FuseTest::ServerProcessFuseRequest() { + ssize_t len; + std::vector<char> buf(FUSE_MIN_READ_BUFFER); + + // Read FUSE request. + EXPECT_THAT(len = RetryEINTR(read)(dev_fd_, buf.data(), buf.size()), + SyscallSucceeds()); + fuse_in_header* in_header = reinterpret_cast<fuse_in_header*>(buf.data()); + + // Check if this is a preset FUSE_LOOKUP path. + if (in_header->opcode == FUSE_LOOKUP) { + std::string path(buf.data() + sizeof(struct fuse_in_header)); + auto it = lookup_map_.find(path); + if (it != lookup_map_.end()) { + // Matches a preset path. Reply with fake data and skip saving the + // request. + ServerRespondFuseSuccess(lookups_, it->second, in_header->unique); + return; + } + } + + requests_.AddMemBlock(in_header->opcode, buf.data(), len); + + if (in_header->opcode == FUSE_RELEASE || in_header->opcode == FUSE_RELEASEDIR) + return; + // Check if there is a corresponding response. + if (responses_.End()) { + GTEST_NONFATAL_FAILURE_("No more FUSE response is expected"); + ServerRespondFuseError(in_header->unique); + return; + } + auto mem_block = responses_.Next(); + if (in_header->opcode != mem_block.opcode) { + std::string message = absl::StrFormat("Expect opcode %d but got %d", + mem_block.opcode, in_header->opcode); + GTEST_NONFATAL_FAILURE_(message.c_str()); + // We won't get correct response if opcode is not expected. Send error + // response here to avoid wrong parsing by VFS. + ServerRespondFuseError(in_header->unique); + return; + } + + // Write FUSE response. + ServerRespondFuseSuccess(responses_, mem_block, in_header->unique); +} + +void FuseTest::ServerRespondFuseSuccess(FuseMemBuffer& mem_buf, + const FuseMemBlock& block, + uint64_t unique) { + fuse_out_header* out_header = + reinterpret_cast<fuse_out_header*>(mem_buf.DataAtOffset(block.offset)); + + // Patch `unique` in fuse_out_header to avoid EINVAL caused by responding + // with an unknown `unique`. + out_header->unique = unique; + EXPECT_THAT(RetryEINTR(write)(dev_fd_, out_header, block.len), + SyscallSucceedsWithValue(block.len)); +} + +void FuseTest::ServerRespondFuseError(uint64_t unique) { + fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = ENOSYS, + .unique = unique, + }; + EXPECT_THAT(RetryEINTR(write)(dev_fd_, &out_header, sizeof(out_header)), + SyscallSucceedsWithValue(sizeof(out_header))); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/fuse_base.h b/test/fuse/linux/fuse_base.h new file mode 100644 index 000000000..6ad296ca2 --- /dev/null +++ b/test/fuse/linux/fuse_base.h @@ -0,0 +1,251 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_FUSE_FUSE_BASE_H_ +#define GVISOR_TEST_FUSE_FUSE_BASE_H_ + +#include <linux/fuse.h> +#include <string.h> +#include <sys/stat.h> +#include <sys/uio.h> + +#include <iostream> +#include <unordered_map> +#include <vector> + +#include "gtest/gtest.h" +#include "test/util/posix_error.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +constexpr char kMountOpts[] = "rootmode=755,user_id=0,group_id=0"; + +constexpr struct fuse_init_out kDefaultFUSEInitOutPayload = {.major = 7}; + +// Internal commands used to communicate between testing thread and the FUSE +// server. See test/fuse/README.md for further detail. +enum class FuseTestCmd { + kSetResponse = 0, + kSetInodeLookup, + kGetRequest, + kGetNumUnconsumedRequests, + kGetNumUnsentResponses, + kGetTotalReceivedBytes, + kSkipRequest, +}; + +// Holds the information of a memory block in a serial buffer. +struct FuseMemBlock { + uint32_t opcode; + size_t offset; + size_t len; +}; + +// A wrapper of a simple serial buffer that can be used with read(2) and +// write(2). Contains a cursor to indicate accessing. This class is not thread- +// safe and can only be used in single-thread version. +class FuseMemBuffer { + public: + FuseMemBuffer() : cursor_(0) { + // To read from /dev/fuse, a buffer needs at least FUSE_MIN_READ_BUFFER + // bytes to avoid EINVAL. FuseMemBuffer holds memory that can accommodate + // a sequence of FUSE request/response, so it is initiated with double + // minimal requirement. + mem_.resize(FUSE_MIN_READ_BUFFER * 2); + } + + // Returns whether there is no memory block. + bool Empty() { return blocks_.empty(); } + + // Returns if there is no more remaining memory blocks. + bool End() { return cursor_ == blocks_.size(); } + + // Returns how many bytes that have been received. + size_t UsedBytes() { + return Empty() ? 0 : blocks_.back().offset + blocks_.back().len; + } + + // Returns the available bytes remains in the serial buffer. + size_t AvailBytes() { return mem_.size() - UsedBytes(); } + + // Appends a memory block information that starts at the tail of the serial + // buffer. /dev/fuse requires at least FUSE_MIN_READ_BUFFER bytes to read, or + // it will issue EINVAL. If it is not enough, just double the buffer length. + void AddMemBlock(uint32_t opcode, void* data, size_t len) { + if (AvailBytes() < FUSE_MIN_READ_BUFFER) { + mem_.resize(mem_.size() << 1); + } + size_t offset = UsedBytes(); + memcpy(mem_.data() + offset, data, len); + blocks_.push_back(FuseMemBlock{opcode, offset, len}); + } + + // Returns the memory address at a specific offset. Used with read(2) or + // write(2). + char* DataAtOffset(size_t offset) { return mem_.data() + offset; } + + // Returns current memory block pointed by the cursor and increase by 1. + FuseMemBlock Next() { + if (End()) { + std::cerr << "Buffer is already exhausted." << std::endl; + return FuseMemBlock{}; + } + return blocks_[cursor_++]; + } + + // Returns the number of the blocks that has not been requested. + size_t RemainingBlocks() { return blocks_.size() - cursor_; } + + private: + size_t cursor_; + std::vector<FuseMemBlock> blocks_; + std::vector<char> mem_; +}; + +// FuseTest base class is useful in FUSE integration test. Inherit this class +// to automatically set up a fake FUSE server and use the member functions +// to manipulate with it. Refer to test/fuse/README.md for detailed explanation. +class FuseTest : public ::testing::Test { + public: + // nodeid_ is the ID of a fake inode. We starts from 2 since 1 is occupied by + // the mount point. + FuseTest() : nodeid_(2) {} + void SetUp() override; + void TearDown() override; + + // Called by the testing thread to set up a fake response for an expected + // opcode via socket. This can be used multiple times to define a sequence of + // expected FUSE reactions. + void SetServerResponse(uint32_t opcode, std::vector<struct iovec>& iovecs); + + // Called by the testing thread to install a fake path under the mount point. + // e.g. a file under /mnt/dir/file and moint point is /mnt, then it will look + // up "dir/file" in this case. + // + // It sets a fixed response to the FUSE_LOOKUP requests issued with this + // path, pretending there is an inode and avoid ENOENT when testing. If mode + // is not given, it creates a regular file with mode 0600. + void SetServerInodeLookup(const std::string& path, + mode_t mode = S_IFREG | S_IRUSR | S_IWUSR, + uint64_t size = 512); + + // Called by the testing thread to ask the FUSE server for its next received + // FUSE request. Be sure to use the corresponding struct of iovec to receive + // data from server. + void GetServerActualRequest(std::vector<struct iovec>& iovecs); + + // Called by the testing thread to query the number of unconsumed requests in + // the requests_ serial buffer of the FUSE server. TearDown() ensures all + // FUSE requests received by the FUSE server were consumed by the testing + // thread. + uint32_t GetServerNumUnconsumedRequests(); + + // Called by the testing thread to query the number of unsent responses in + // the responses_ serial buffer of the FUSE server. TearDown() ensures all + // preset FUSE responses were sent out by the FUSE server. + uint32_t GetServerNumUnsentResponses(); + + // Called by the testing thread to ask the FUSE server for its total received + // bytes from /dev/fuse. + uint32_t GetServerTotalReceivedBytes(); + + // Called by the testing thread to ask the FUSE server to skip stored + // request data. + void SkipServerActualRequest(); + + protected: + TempPath mount_point_; + + // Opens /dev/fuse and inherit the file descriptor for the FUSE server. + void MountFuse(const char* mountOpts = kMountOpts); + + // Creates a socketpair for communication and forks FUSE server. + void SetUpFuseServer( + const struct fuse_init_out* payload = &kDefaultFUSEInitOutPayload); + + // Unmounts the mountpoint of the FUSE server. + void UnmountFuse(); + + private: + // Sends a FuseTestCmd and gets a uint32_t data from the FUSE server. + inline uint32_t GetServerData(uint32_t cmd); + + // Waits for FUSE server to complete its processing. Complains if the FUSE + // server responds any failure during tests. + void WaitServerComplete(); + + // The FUSE server stays here and waits next command or FUSE request until it + // is terminated. + void ServerFuseLoop(); + + // Used by the FUSE server to tell testing thread if it is OK to proceed next + // command. Will be issued after processing each FuseTestCmd. + void ServerCompleteWith(bool success); + + // Consumes the first FUSE request when mounting FUSE. Replies with a + // response with empty payload. + PosixError ServerConsumeFuseInit(const struct fuse_init_out* payload); + + // A command switch that dispatch different FuseTestCmd to its handler. + void ServerHandleCommand(); + + // The FUSE server side's corresponding code of `SetServerResponse()`. + // Handles `kSetResponse` command. Saves the fake response into its output + // memory queue. + void ServerReceiveResponse(); + + // The FUSE server side's corresponding code of `SetServerInodeLookup()`. + // Handles `kSetInodeLookup` command. Receives an expected file mode and + // file path under the mount point. + void ServerReceiveInodeLookup(); + + // The FUSE server side's corresponding code of `GetServerActualRequest()`. + // Handles `kGetRequest` command. Sends the next received request pointed by + // the cursor. + void ServerSendReceivedRequest(); + + // Sends a uint32_t data via socket. + inline void ServerSendData(uint32_t data); + + // The FUSE server side's corresponding code of `SkipServerActualRequest()`. + // Handles `kSkipRequest` command. Skip the request pointed by current cursor. + void ServerSkipReceivedRequest(); + + // Handles FUSE request sent to /dev/fuse by its saved responses. + void ServerProcessFuseRequest(); + + // Responds to FUSE request with a saved data. + void ServerRespondFuseSuccess(FuseMemBuffer& mem_buf, + const FuseMemBlock& block, uint64_t unique); + + // Responds an error header to /dev/fuse when bad thing happens. + void ServerRespondFuseError(uint64_t unique); + + int dev_fd_; + int sock_[2]; + + uint64_t nodeid_; + std::unordered_map<std::string, FuseMemBlock> lookup_map_; + + FuseMemBuffer requests_; + FuseMemBuffer responses_; + FuseMemBuffer lookups_; +}; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_FUSE_FUSE_BASE_H_ diff --git a/test/fuse/linux/fuse_fd_util.cc b/test/fuse/linux/fuse_fd_util.cc new file mode 100644 index 000000000..30d1157bb --- /dev/null +++ b/test/fuse/linux/fuse_fd_util.cc @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/fuse/linux/fuse_fd_util.h" + +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/uio.h> + +#include <string> +#include <vector> + +#include "test/util/cleanup.h" +#include "test/util/file_descriptor.h" +#include "test/util/fuse_util.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr<FileDescriptor> FuseFdTest::OpenPath(const std::string &path, + uint32_t flags, uint64_t fh) { + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload = { + .fh = fh, + .open_flags = flags, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_OPEN, iov_out); + + auto res = Open(path.c_str(), flags); + if (res.ok()) { + SkipServerActualRequest(); + } + return res; +} + +Cleanup FuseFdTest::CloseFD(FileDescriptor &fd) { + return Cleanup([&] { + close(fd.release()); + SkipServerActualRequest(); + }); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/fuse_fd_util.h b/test/fuse/linux/fuse_fd_util.h new file mode 100644 index 000000000..066185c94 --- /dev/null +++ b/test/fuse/linux/fuse_fd_util.h @@ -0,0 +1,48 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_ +#define GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_ + +#include <fcntl.h> +#include <sys/stat.h> +#include <sys/types.h> + +#include <string> + +#include "test/fuse/linux/fuse_base.h" +#include "test/util/cleanup.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +class FuseFdTest : public FuseTest { + public: + // Sets the FUSE server to respond to a FUSE_OPEN with corresponding flags and + // fh. Then does a real file system open on the absolute path to get an fd. + PosixErrorOr<FileDescriptor> OpenPath(const std::string &path, + uint32_t flags = O_RDONLY, + uint64_t fh = 1); + + // Returns a cleanup object that closes the fd when it is destroyed. After + // the close is done, tells the FUSE server to skip this FUSE_RELEASE. + Cleanup CloseFD(FileDescriptor &fd); +}; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_FUSE_FUSE_FD_UTIL_H_ diff --git a/test/fuse/linux/mkdir_test.cc b/test/fuse/linux/mkdir_test.cc new file mode 100644 index 000000000..9647cb93f --- /dev/null +++ b/test/fuse/linux/mkdir_test.cc @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/temp_umask.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class MkdirTest : public FuseTest { + protected: + const std::string test_dir_ = "test_dir"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(MkdirTest, CreateDir) { + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_dir_); + const mode_t new_umask = 0077; + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFDIR | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKDIR, iov_out); + TempUmask mask(new_umask); + ASSERT_THAT(mkdir(test_dir_path_.c_str(), 0777), SyscallSucceeds()); + + struct fuse_in_header in_header; + struct fuse_mkdir_in in_payload; + std::vector<char> actual_dir(test_dir_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, in_payload, actual_dir); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, + sizeof(in_header) + sizeof(in_payload) + test_dir_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_MKDIR); + EXPECT_EQ(in_payload.mode & 0777, perms_ & ~new_umask); + EXPECT_EQ(in_payload.umask, new_umask); + EXPECT_EQ(std::string(actual_dir.data()), test_dir_); +} + +TEST_F(MkdirTest, FileTypeError) { + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_dir_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKDIR, iov_out); + ASSERT_THAT(mkdir(test_dir_path_.c_str(), 0777), SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/mknod_test.cc b/test/fuse/linux/mknod_test.cc new file mode 100644 index 000000000..74c74d76b --- /dev/null +++ b/test/fuse/linux/mknod_test.cc @@ -0,0 +1,107 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/temp_umask.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class MknodTest : public FuseTest { + protected: + const std::string test_file_ = "test_file"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(MknodTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + const mode_t new_umask = 0077; + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKNOD, iov_out); + TempUmask mask(new_umask); + ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0), SyscallSucceeds()); + + struct fuse_in_header in_header; + struct fuse_mknod_in in_payload; + std::vector<char> actual_file(test_file_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, in_payload, actual_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, + sizeof(in_header) + sizeof(in_payload) + test_file_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_MKNOD); + EXPECT_EQ(in_payload.mode & 0777, perms_ & ~new_umask); + EXPECT_EQ(in_payload.umask, new_umask); + EXPECT_EQ(in_payload.rdev, 0); + EXPECT_EQ(std::string(actual_file.data()), test_file_); +} + +TEST_F(MknodTest, FileTypeError) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + // server return directory instead of regular file should cause an error. + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFDIR | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKNOD, iov_out); + ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0), + SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +TEST_F(MknodTest, NodeIDError) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = + DefaultEntryOut(S_IFREG | perms_, FUSE_ROOT_ID); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_MKNOD, iov_out); + ASSERT_THAT(mknod(test_file_path.c_str(), perms_, 0), + SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/open_test.cc b/test/fuse/linux/open_test.cc new file mode 100644 index 000000000..4b0c4a805 --- /dev/null +++ b/test/fuse/linux/open_test.cc @@ -0,0 +1,128 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class OpenTest : public FuseTest { + // OpenTest doesn't care the release request when close a fd, + // so doesn't check leftover requests when tearing down. + void TearDown() { UnmountFuse(); } + + protected: + const std::string test_file_ = "test_file"; + const mode_t regular_file_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; + + struct fuse_open_out out_payload_ = { + .fh = 1, + .open_flags = O_RDWR, + }; +}; + +TEST_F(OpenTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, regular_file_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload_); + SetServerResponse(FUSE_OPEN, iov_out); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR)); + + struct fuse_in_header in_header; + struct fuse_open_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_OPEN); + EXPECT_EQ(in_payload.flags, O_RDWR); + EXPECT_THAT(fcntl(fd.get(), F_GETFL), SyscallSucceedsWithValue(O_RDWR)); +} + +TEST_F(OpenTest, SetNoOpen) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, regular_file_); + + // ENOSYS indicates open is not implemented. + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + .error = -ENOSYS, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload_); + SetServerResponse(FUSE_OPEN, iov_out); + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR)); + SkipServerActualRequest(); + + // check open doesn't send new request. + uint32_t recieved_before = GetServerTotalReceivedBytes(); + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path.c_str(), O_RDWR)); + EXPECT_EQ(GetServerTotalReceivedBytes(), recieved_before); +} + +TEST_F(OpenTest, OpenFail) { + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + .error = -ENOENT, + }; + + auto iov_out = FuseGenerateIovecs(out_header, out_payload_); + SetServerResponse(FUSE_OPENDIR, iov_out); + ASSERT_THAT(open(mount_point_.path().c_str(), O_RDWR), + SyscallFailsWithErrno(ENOENT)); + + struct fuse_in_header in_header; + struct fuse_open_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_OPENDIR); + EXPECT_EQ(in_payload.flags, O_RDWR); +} + +TEST_F(OpenTest, DirectoryFlagOnRegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + + SetServerInodeLookup(test_file_, regular_file_); + ASSERT_THAT(open(test_file_path.c_str(), O_RDWR | O_DIRECTORY), + SyscallFailsWithErrno(ENOTDIR)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/read_test.cc b/test/fuse/linux/read_test.cc new file mode 100644 index 000000000..88fc299d8 --- /dev/null +++ b/test/fuse/linux/read_test.cc @@ -0,0 +1,390 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class ReadTest : public FuseTest { + void SetUp() override { + FuseTest::SetUp(); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + // TearDown overrides the parent's function + // to skip checking the unconsumed release request at the end. + void TearDown() override { UnmountFuse(); } + + protected: + const std::string test_file_ = "test_file"; + const mode_t test_file_mode_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; + const uint64_t test_fh_ = 1; + const uint32_t open_flag_ = O_RDWR; + + std::string test_file_path_; + + PosixErrorOr<FileDescriptor> OpenTestFile(const std::string &path, + uint64_t size = 512) { + SetServerInodeLookup(test_file_, test_file_mode_, size); + + struct fuse_out_header out_header_open = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload_open = { + .fh = test_fh_, + .open_flags = open_flag_, + }; + auto iov_out_open = FuseGenerateIovecs(out_header_open, out_payload_open); + SetServerResponse(FUSE_OPEN, iov_out_open); + + auto res = Open(path.c_str(), open_flag_); + if (res.ok()) { + SkipServerActualRequest(); + } + return res; + } +}; + +class ReadTestSmallMaxRead : public ReadTest { + void SetUp() override { + MountFuse(mountOpts); + SetUpFuseServer(); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + protected: + constexpr static char mountOpts[] = + "rootmode=755,user_id=0,group_id=0,max_read=4096"; + // 4096 is hard-coded as the max_read in mount options. + const int size_fragment = 4096; +}; + +TEST_F(ReadTest, ReadWhole) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the read. + const int n_read = 5; + std::vector<char> data(n_read); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(n_read)); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); + EXPECT_EQ(buf, data); +} + +TEST_F(ReadTest, ReadPartial) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the read. + const int n_data = 10; + std::vector<char> data(n_data); + RandomizeBuffer(data.data(), data.size()); + // Note: due to read ahead, current read implementation will treat any + // response that is longer than requested as correct (i.e. not reach the EOF). + // Therefore, the test below should make sure the size to read does not exceed + // n_data. + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + + std::vector<char> buf(n_data); + + // Read 1 bytes. + SetServerResponse(FUSE_READ, iov_out_read); + EXPECT_THAT(read(fd.get(), buf.data(), 1), SyscallSucceedsWithValue(1)); + + // Check the 1-byte read request. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); + + // Read 3 bytes. + SetServerResponse(FUSE_READ, iov_out_read); + EXPECT_THAT(read(fd.get(), buf.data(), 3), SyscallSucceedsWithValue(3)); + + // Check the 3-byte read request. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_payload_read.offset, 1); + + // Read 5 bytes. + SetServerResponse(FUSE_READ, iov_out_read); + EXPECT_THAT(read(fd.get(), buf.data(), 5), SyscallSucceedsWithValue(5)); + + // Check the 5-byte read request. + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_payload_read.offset, 4); +} + +TEST_F(ReadTest, PRead) { + const int file_size = 512; + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, file_size)); + + // Prepare for the read. + const int n_read = 5; + std::vector<char> data(n_read); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read some bytes. + std::vector<char> buf(n_read); + const int offset_read = file_size >> 1; + EXPECT_THAT(pread(fd.get(), buf.data(), n_read, offset_read), + SyscallSucceedsWithValue(n_read)); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, offset_read); + EXPECT_EQ(buf, data); +} + +TEST_F(ReadTest, ReadZero) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Issue the read. + std::vector<char> buf; + EXPECT_THAT(read(fd.get(), buf.data(), 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(ReadTest, ReadShort) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the short read. + const int n_read = 5; + std::vector<char> data(n_read >> 1); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(data.size())); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); + std::vector<char> short_buf(buf.begin(), buf.begin() + data.size()); + EXPECT_EQ(short_buf, data); +} + +TEST_F(ReadTest, ReadShortEOF) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the short read. + struct fuse_out_header out_header_read = { + .len = static_cast<uint32_t>(sizeof(struct fuse_out_header)), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read); + SetServerResponse(FUSE_READ, iov_out_read); + + // Read the whole "file". + const int n_read = 10; + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), SyscallSucceedsWithValue(0)); + + // Check the read request. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, 0); +} + +TEST_F(ReadTestSmallMaxRead, ReadSmallMaxRead) { + const int n_fragment = 10; + const int n_read = size_fragment * n_fragment; + + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_read)); + + // Prepare for the read. + std::vector<char> data(size_fragment); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + + for (int i = 0; i < n_fragment; ++i) { + SetServerResponse(FUSE_READ, iov_out_read); + } + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(n_read)); + + ASSERT_EQ(GetServerNumUnsentResponses(), 0); + ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment); + + // Check each read segment. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + + for (int i = 0; i < n_fragment; ++i) { + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, i * size_fragment); + EXPECT_EQ(in_payload_read.size, size_fragment); + + auto it = buf.begin() + i * size_fragment; + EXPECT_EQ(std::vector<char>(it, it + size_fragment), data); + } +} + +TEST_F(ReadTestSmallMaxRead, ReadSmallMaxReadShort) { + const int n_fragment = 10; + const int n_read = size_fragment * n_fragment; + + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_read)); + + // Prepare for the read. + std::vector<char> data(size_fragment); + RandomizeBuffer(data.data(), data.size()); + struct fuse_out_header out_header_read = { + .len = + static_cast<uint32_t>(sizeof(struct fuse_out_header) + data.size()), + }; + auto iov_out_read = FuseGenerateIovecs(out_header_read, data); + + for (int i = 0; i < n_fragment - 1; ++i) { + SetServerResponse(FUSE_READ, iov_out_read); + } + + // The last fragment is a short read. + std::vector<char> half_data(data.begin(), data.begin() + (data.size() >> 1)); + struct fuse_out_header out_header_read_short = { + .len = static_cast<uint32_t>(sizeof(struct fuse_out_header) + + half_data.size()), + }; + auto iov_out_read_short = + FuseGenerateIovecs(out_header_read_short, half_data); + SetServerResponse(FUSE_READ, iov_out_read_short); + + // Read the whole "file". + std::vector<char> buf(n_read); + EXPECT_THAT(read(fd.get(), buf.data(), n_read), + SyscallSucceedsWithValue(n_read - (data.size() >> 1))); + + ASSERT_EQ(GetServerNumUnsentResponses(), 0); + ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment); + + // Check each read segment. + struct fuse_in_header in_header_read; + struct fuse_read_in in_payload_read; + auto iov_in = FuseGenerateIovecs(in_header_read, in_payload_read); + + for (int i = 0; i < n_fragment; ++i) { + GetServerActualRequest(iov_in); + EXPECT_EQ(in_payload_read.fh, test_fh_); + EXPECT_EQ(in_header_read.len, + sizeof(in_header_read) + sizeof(in_payload_read)); + EXPECT_EQ(in_header_read.opcode, FUSE_READ); + EXPECT_EQ(in_payload_read.offset, i * size_fragment); + EXPECT_EQ(in_payload_read.size, size_fragment); + + auto it = buf.begin() + i * size_fragment; + if (i != n_fragment - 1) { + EXPECT_EQ(std::vector<char>(it, it + data.size()), data); + } else { + EXPECT_EQ(std::vector<char>(it, it + half_data.size()), half_data); + } + } +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/readdir_test.cc b/test/fuse/linux/readdir_test.cc new file mode 100644 index 000000000..2afb4b062 --- /dev/null +++ b/test/fuse/linux/readdir_test.cc @@ -0,0 +1,193 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <dirent.h> +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <linux/unistd.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +#define FUSE_NAME_OFFSET offsetof(struct fuse_dirent, name) +#define FUSE_DIRENT_ALIGN(x) \ + (((x) + sizeof(uint64_t) - 1) & ~(sizeof(uint64_t) - 1)) +#define FUSE_DIRENT_SIZE(d) FUSE_DIRENT_ALIGN(FUSE_NAME_OFFSET + (d)->namelen) + +namespace gvisor { +namespace testing { + +namespace { + +class ReaddirTest : public FuseTest { + public: + void fill_fuse_dirent(char *buf, const char *name, uint64_t ino) { + size_t namelen = strlen(name); + size_t entlen = FUSE_NAME_OFFSET + namelen; + size_t entlen_padded = FUSE_DIRENT_ALIGN(entlen); + struct fuse_dirent *dirent; + + dirent = reinterpret_cast<struct fuse_dirent *>(buf); + dirent->ino = ino; + dirent->namelen = namelen; + memcpy(dirent->name, name, namelen); + memset(dirent->name + namelen, 0, entlen_padded - entlen); + } + + protected: + const std::string test_dir_name_ = "test_dir"; +}; + +TEST_F(ReaddirTest, SingleEntry) { + const std::string test_dir_path = + JoinPath(mount_point_.path().c_str(), test_dir_name_); + + const uint64_t ino_dir = 1024; + // We need to make sure the test dir is a directory that can be found. + mode_t expected_mode = + S_IFDIR | S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH; + struct fuse_attr dir_attr = { + .ino = ino_dir, + .size = 512, + .blocks = 4, + .mode = expected_mode, + .blksize = 4096, + }; + + // We need to make sure the test dir is a directory that can be found. + struct fuse_out_header lookup_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out lookup_payload = { + .nodeid = 1, + .entry_valid = true, + .attr_valid = true, + .attr = dir_attr, + }; + + struct fuse_out_header open_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out open_payload = { + .fh = 1, + }; + auto iov_out = FuseGenerateIovecs(lookup_header, lookup_payload); + SetServerResponse(FUSE_LOOKUP, iov_out); + + iov_out = FuseGenerateIovecs(open_header, open_payload); + SetServerResponse(FUSE_OPENDIR, iov_out); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_dir_path.c_str(), O_RDONLY)); + + // The open command makes two syscalls. Lookup the dir file and open. + // We don't need to inspect those headers in this test. + SkipServerActualRequest(); // LOOKUP. + SkipServerActualRequest(); // OPENDIR. + + // Readdir test code. + std::string dot = "."; + std::string dot_dot = ".."; + std::string test_file = "testFile"; + + // Figure out how many dirents to send over and allocate them appropriately. + // Each dirent has a dynamic name and a static metadata part. The dirent size + // is aligned to being a multiple of 8. + size_t dot_file_dirent_size = + FUSE_DIRENT_ALIGN(dot.length() + FUSE_NAME_OFFSET); + size_t dot_dot_file_dirent_size = + FUSE_DIRENT_ALIGN(dot_dot.length() + FUSE_NAME_OFFSET); + size_t test_file_dirent_size = + FUSE_DIRENT_ALIGN(test_file.length() + FUSE_NAME_OFFSET); + + // Create an appropriately sized payload. + size_t readdir_payload_size = + test_file_dirent_size + dot_file_dirent_size + dot_dot_file_dirent_size; + std::vector<char> readdir_payload_vec(readdir_payload_size); + char *readdir_payload = readdir_payload_vec.data(); + + // Use fake ino for other directories. + fill_fuse_dirent(readdir_payload, dot.c_str(), ino_dir - 2); + fill_fuse_dirent(readdir_payload + dot_file_dirent_size, dot_dot.c_str(), + ino_dir - 1); + fill_fuse_dirent( + readdir_payload + dot_file_dirent_size + dot_dot_file_dirent_size, + test_file.c_str(), ino_dir); + + struct fuse_out_header readdir_header = { + .len = uint32_t(sizeof(struct fuse_out_header) + readdir_payload_size), + }; + struct fuse_out_header readdir_header_break = { + .len = uint32_t(sizeof(struct fuse_out_header)), + }; + + iov_out = FuseGenerateIovecs(readdir_header, readdir_payload_vec); + SetServerResponse(FUSE_READDIR, iov_out); + + iov_out = FuseGenerateIovecs(readdir_header_break); + SetServerResponse(FUSE_READDIR, iov_out); + + std::vector<char> buf(4090, 0); + int nread, off = 0, i = 0; + EXPECT_THAT( + nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()), + SyscallSucceeds()); + for (; off < nread;) { + struct dirent64 *ent = (struct dirent64 *)(buf.data() + off); + off += ent->d_reclen; + switch (i++) { + case 0: + EXPECT_EQ(std::string(ent->d_name), dot); + break; + case 1: + EXPECT_EQ(std::string(ent->d_name), dot_dot); + break; + case 2: + EXPECT_EQ(std::string(ent->d_name), test_file); + break; + } + } + + EXPECT_THAT( + nread = syscall(__NR_getdents64, fd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(0)); + + SkipServerActualRequest(); // READDIR. + SkipServerActualRequest(); // READDIR with no data. + + // Clean up. + fd.reset(-1); + + struct fuse_in_header in_header; + struct fuse_release_in in_payload; + + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_RELEASEDIR); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/readlink_test.cc b/test/fuse/linux/readlink_test.cc new file mode 100644 index 000000000..2cba8fc23 --- /dev/null +++ b/test/fuse/linux/readlink_test.cc @@ -0,0 +1,85 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class ReadlinkTest : public FuseTest { + protected: + const std::string test_file_ = "test_file_"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(ReadlinkTest, ReadSymLink) { + const std::string symlink_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFLNK | perms_); + + struct fuse_out_header out_header = { + .len = static_cast<uint32_t>(sizeof(struct fuse_out_header)) + + static_cast<uint32_t>(test_file_.length()) + 1, + }; + std::string link = test_file_; + auto iov_out = FuseGenerateIovecs(out_header, link); + SetServerResponse(FUSE_READLINK, iov_out); + const std::string actual_link = + ASSERT_NO_ERRNO_AND_VALUE(ReadLink(symlink_path)); + + struct fuse_in_header in_header; + auto iov_in = FuseGenerateIovecs(in_header); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header)); + EXPECT_EQ(in_header.opcode, FUSE_READLINK); + EXPECT_EQ(0, memcmp(actual_link.c_str(), link.data(), link.size())); + + // next readlink should have link cached, so shouldn't have new request to + // server. + uint32_t recieved_before = GetServerTotalReceivedBytes(); + ASSERT_NO_ERRNO(ReadLink(symlink_path)); + EXPECT_EQ(GetServerTotalReceivedBytes(), recieved_before); +} + +TEST_F(ReadlinkTest, NotSymlink) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | perms_); + + std::vector<char> buf(PATH_MAX + 1); + ASSERT_THAT(readlink(test_file_path.c_str(), buf.data(), PATH_MAX), + SyscallFailsWithErrno(EINVAL)); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/release_test.cc b/test/fuse/linux/release_test.cc new file mode 100644 index 000000000..b5adb0870 --- /dev/null +++ b/test/fuse/linux/release_test.cc @@ -0,0 +1,74 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/mount.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class ReleaseTest : public FuseTest { + protected: + const std::string test_file_ = "test_file"; +}; + +TEST_F(ReleaseTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload = { + .fh = 1, + .open_flags = O_RDWR, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_OPEN, iov_out); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_path, O_RDWR)); + SkipServerActualRequest(); + ASSERT_THAT(close(fd.release()), SyscallSucceeds()); + + struct fuse_in_header in_header; + struct fuse_release_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_RELEASE); + EXPECT_EQ(in_payload.flags, O_RDWR); + EXPECT_EQ(in_payload.fh, 1); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/rmdir_test.cc b/test/fuse/linux/rmdir_test.cc new file mode 100644 index 000000000..e3200e446 --- /dev/null +++ b/test/fuse/linux/rmdir_test.cc @@ -0,0 +1,100 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class RmDirTest : public FuseTest { + protected: + const std::string test_dir_name_ = "test_dir"; + const std::string test_subdir_ = "test_subdir"; + const mode_t test_dir_mode_ = S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(RmDirTest, NormalRmDir) { + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_dir_name_); + + SetServerInodeLookup(test_dir_name_, test_dir_mode_); + + // RmDir code. + struct fuse_out_header rmdir_header = { + .len = sizeof(struct fuse_out_header), + }; + + auto iov_out = FuseGenerateIovecs(rmdir_header); + SetServerResponse(FUSE_RMDIR, iov_out); + + ASSERT_THAT(rmdir(test_dir_path_.c_str()), SyscallSucceeds()); + + struct fuse_in_header in_header; + std::vector<char> actual_dirname(test_dir_name_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, actual_dirname); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_dir_name_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_RMDIR); + EXPECT_EQ(std::string(actual_dirname.data()), test_dir_name_); +} + +TEST_F(RmDirTest, NormalRmDirSubdir) { + SetServerInodeLookup(test_subdir_, S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO); + const std::string test_dir_path_ = + JoinPath(mount_point_.path().c_str(), test_subdir_, test_dir_name_); + SetServerInodeLookup(test_dir_name_, test_dir_mode_); + + // RmDir code. + struct fuse_out_header rmdir_header = { + .len = sizeof(struct fuse_out_header), + }; + + auto iov_out = FuseGenerateIovecs(rmdir_header); + SetServerResponse(FUSE_RMDIR, iov_out); + + ASSERT_THAT(rmdir(test_dir_path_.c_str()), SyscallSucceeds()); + + struct fuse_in_header in_header; + std::vector<char> actual_dirname(test_dir_name_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, actual_dirname); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_dir_name_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_RMDIR); + EXPECT_EQ(std::string(actual_dirname.data()), test_dir_name_); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/setstat_test.cc b/test/fuse/linux/setstat_test.cc new file mode 100644 index 000000000..68301c775 --- /dev/null +++ b/test/fuse/linux/setstat_test.cc @@ -0,0 +1,338 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> +#include <utime.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_fd_util.h" +#include "test/util/cleanup.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class SetStatTest : public FuseFdTest { + public: + void SetUp() override { + FuseFdTest::SetUp(); + test_dir_path_ = JoinPath(mount_point_.path(), test_dir_); + test_file_path_ = JoinPath(mount_point_.path(), test_file_); + } + + protected: + const uint64_t fh = 23; + const std::string test_dir_ = "testdir"; + const std::string test_file_ = "testfile"; + const mode_t test_dir_mode_ = S_IFDIR | S_IRUSR | S_IWUSR | S_IXUSR; + const mode_t test_file_mode_ = S_IFREG | S_IRUSR | S_IWUSR | S_IXUSR; + + std::string test_dir_path_; + std::string test_file_path_; +}; + +TEST_F(SetStatTest, ChmodDir) { + // Set up fixture. + SetServerInodeLookup(test_dir_, test_dir_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + mode_t set_mode = S_IRGRP | S_IWGRP | S_IXGRP; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(set_mode, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(chmod(test_dir_path_.c_str(), set_mode), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_MODE); + EXPECT_EQ(in_payload.mode, S_IFDIR | set_mode); +} + +TEST_F(SetStatTest, ChownDir) { + // Set up fixture. + SetServerInodeLookup(test_dir_, test_dir_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(test_dir_mode_, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(chown(test_dir_path_.c_str(), 1025, 1025), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_UID | FATTR_GID); + EXPECT_EQ(in_payload.uid, 1025); + EXPECT_EQ(in_payload.gid, 1025); +} + +TEST_F(SetStatTest, TruncateFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(truncate(test_file_path_.c_str(), 321), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_SIZE); + EXPECT_EQ(in_payload.size, 321); +} + +TEST_F(SetStatTest, UtimeFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + time_t expected_atime = 1597159766, expected_mtime = 1597159765; + struct utimbuf times = { + .actime = expected_atime, + .modtime = expected_mtime, + }; + EXPECT_THAT(utime(test_file_path_.c_str(), ×), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_ATIME | FATTR_MTIME); + EXPECT_EQ(in_payload.atime, expected_atime); + EXPECT_EQ(in_payload.mtime, expected_mtime); +} + +TEST_F(SetStatTest, UtimesFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(test_file_mode_, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + struct timeval expected_times[2] = { + { + .tv_sec = 1597159766, + .tv_usec = 234945, + }, + { + .tv_sec = 1597159765, + .tv_usec = 232341, + }, + }; + EXPECT_THAT(utimes(test_file_path_.c_str(), expected_times), + SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_ATIME | FATTR_MTIME); + EXPECT_EQ(in_payload.atime, expected_times[0].tv_sec); + EXPECT_EQ(in_payload.atimensec, expected_times[0].tv_usec * 1000); + EXPECT_EQ(in_payload.mtime, expected_times[1].tv_sec); + EXPECT_EQ(in_payload.mtimensec, expected_times[1].tv_usec * 1000); +} + +TEST_F(SetStatTest, FtruncateFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(test_file_mode_, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(ftruncate(fd.get(), 321), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_SIZE | FATTR_FH); + EXPECT_EQ(in_payload.fh, fh); + EXPECT_EQ(in_payload.size, 321); +} + +TEST_F(SetStatTest, FchmodFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + mode_t set_mode = S_IROTH | S_IWOTH | S_IXOTH; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(set_mode, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(fchmod(fd.get(), set_mode), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_MODE | FATTR_FH); + EXPECT_EQ(in_payload.fh, fh); + EXPECT_EQ(in_payload.mode, S_IFREG | set_mode); +} + +TEST_F(SetStatTest, FchownFile) { + // Set up fixture. + SetServerInodeLookup(test_file_, test_file_mode_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDWR, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + .error = 0, + }; + struct fuse_attr_out out_payload = { + .attr = DefaultFuseAttr(S_IFREG | S_IRUSR | S_IWUSR | S_IXUSR, 2), + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SETATTR, iov_out); + + // Make syscall. + EXPECT_THAT(fchown(fd.get(), 1025, 1025), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_setattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.len, sizeof(in_header) + sizeof(in_payload)); + EXPECT_EQ(in_header.opcode, FUSE_SETATTR); + EXPECT_EQ(in_header.uid, 0); + EXPECT_EQ(in_header.gid, 0); + EXPECT_EQ(in_payload.valid, FATTR_UID | FATTR_GID | FATTR_FH); + EXPECT_EQ(in_payload.fh, fh); + EXPECT_EQ(in_payload.uid, 1025); + EXPECT_EQ(in_payload.gid, 1025); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/stat_test.cc b/test/fuse/linux/stat_test.cc new file mode 100644 index 000000000..6f032cac1 --- /dev/null +++ b/test/fuse/linux/stat_test.cc @@ -0,0 +1,219 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <sys/uio.h> +#include <unistd.h> + +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_fd_util.h" +#include "test/util/cleanup.h" +#include "test/util/fs_util.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class StatTest : public FuseFdTest { + public: + void SetUp() override { + FuseFdTest::SetUp(); + test_file_path_ = JoinPath(mount_point_.path(), test_file_); + } + + protected: + bool StatsAreEqual(struct stat expected, struct stat actual) { + // Device number will be dynamically allocated by kernel, we cannot know in + // advance. + actual.st_dev = expected.st_dev; + return memcmp(&expected, &actual, sizeof(struct stat)) == 0; + } + + const std::string test_file_ = "testfile"; + const mode_t expected_mode = S_IFREG | S_IRUSR | S_IWUSR; + const uint64_t fh = 23; + + std::string test_file_path_; +}; + +TEST_F(StatTest, StatNormal) { + // Set up fixture. + struct fuse_attr attr = DefaultFuseAttr(expected_mode, 1); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + }; + struct fuse_attr_out out_payload = { + .attr = attr, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_GETATTR, iov_out); + + // Make syscall. + struct stat stat_buf; + EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf), SyscallSucceeds()); + + // Check filesystem operation result. + struct stat expected_stat = { + .st_ino = attr.ino, + .st_nlink = attr.nlink, + .st_mode = expected_mode, + .st_uid = attr.uid, + .st_gid = attr.gid, + .st_rdev = attr.rdev, + .st_size = static_cast<off_t>(attr.size), + .st_blksize = attr.blksize, + .st_blocks = static_cast<blkcnt_t>(attr.blocks), + .st_atim = (struct timespec){.tv_sec = static_cast<int>(attr.atime), + .tv_nsec = attr.atimensec}, + .st_mtim = (struct timespec){.tv_sec = static_cast<int>(attr.mtime), + .tv_nsec = attr.mtimensec}, + .st_ctim = (struct timespec){.tv_sec = static_cast<int>(attr.ctime), + .tv_nsec = attr.ctimensec}, + }; + EXPECT_TRUE(StatsAreEqual(stat_buf, expected_stat)); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, 0); + EXPECT_EQ(in_payload.fh, 0); +} + +TEST_F(StatTest, StatNotFound) { + // Set up fixture. + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = -ENOENT, + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_GETATTR, iov_out); + + // Make syscall. + struct stat stat_buf; + EXPECT_THAT(stat(mount_point_.path().c_str(), &stat_buf), + SyscallFailsWithErrno(ENOENT)); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, 0); + EXPECT_EQ(in_payload.fh, 0); +} + +TEST_F(StatTest, FstatNormal) { + // Set up fixture. + SetServerInodeLookup(test_file_); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDONLY, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_attr attr = DefaultFuseAttr(expected_mode, 2); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + }; + struct fuse_attr_out out_payload = { + .attr = attr, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_GETATTR, iov_out); + + // Make syscall. + struct stat stat_buf; + EXPECT_THAT(fstat(fd.get(), &stat_buf), SyscallSucceeds()); + + // Check filesystem operation result. + struct stat expected_stat = { + .st_ino = attr.ino, + .st_nlink = attr.nlink, + .st_mode = expected_mode, + .st_uid = attr.uid, + .st_gid = attr.gid, + .st_rdev = attr.rdev, + .st_size = static_cast<off_t>(attr.size), + .st_blksize = attr.blksize, + .st_blocks = static_cast<blkcnt_t>(attr.blocks), + .st_atim = (struct timespec){.tv_sec = static_cast<int>(attr.atime), + .tv_nsec = attr.atimensec}, + .st_mtim = (struct timespec){.tv_sec = static_cast<int>(attr.mtime), + .tv_nsec = attr.mtimensec}, + .st_ctim = (struct timespec){.tv_sec = static_cast<int>(attr.ctime), + .tv_nsec = attr.ctimensec}, + }; + EXPECT_TRUE(StatsAreEqual(stat_buf, expected_stat)); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, 0); + EXPECT_EQ(in_payload.fh, 0); +} + +TEST_F(StatTest, StatByFileHandle) { + // Set up fixture. + SetServerInodeLookup(test_file_, expected_mode, 0); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenPath(test_file_path_, O_RDONLY, fh)); + auto close_fd = CloseFD(fd); + + struct fuse_attr attr = DefaultFuseAttr(expected_mode, 2, 0); + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_attr_out), + }; + struct fuse_attr_out out_payload = { + .attr = attr, + }; + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_GETATTR, iov_out); + + // Make syscall. + std::vector<char> buf(1); + // Since this is an empty file, it won't issue FUSE_READ. But a FUSE_GETATTR + // will be issued before read completes. + EXPECT_THAT(read(fd.get(), buf.data(), buf.size()), SyscallSucceeds()); + + // Check FUSE request. + struct fuse_in_header in_header; + struct fuse_getattr_in in_payload; + auto iov_in = FuseGenerateIovecs(in_header, in_payload); + + GetServerActualRequest(iov_in); + EXPECT_EQ(in_header.opcode, FUSE_GETATTR); + EXPECT_EQ(in_payload.getattr_flags, FUSE_GETATTR_FH); + EXPECT_EQ(in_payload.fh, fh); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/symlink_test.cc b/test/fuse/linux/symlink_test.cc new file mode 100644 index 000000000..2c3a52987 --- /dev/null +++ b/test/fuse/linux/symlink_test.cc @@ -0,0 +1,88 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class SymlinkTest : public FuseTest { + protected: + const std::string target_file_ = "target_file_"; + const std::string symlink_ = "symlink_"; + const mode_t perms_ = S_IRWXU | S_IRWXG | S_IRWXO; +}; + +TEST_F(SymlinkTest, CreateSymLink) { + const std::string symlink_path = + JoinPath(mount_point_.path().c_str(), symlink_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFLNK | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SYMLINK, iov_out); + ASSERT_THAT(symlink(target_file_.c_str(), symlink_path.c_str()), + SyscallSucceeds()); + + struct fuse_in_header in_header; + std::vector<char> actual_target_file(target_file_.length() + 1); + std::vector<char> actual_symlink(symlink_.length() + 1); + auto iov_in = + FuseGenerateIovecs(in_header, actual_symlink, actual_target_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, + sizeof(in_header) + symlink_.length() + target_file_.length() + 2); + EXPECT_EQ(in_header.opcode, FUSE_SYMLINK); + EXPECT_EQ(std::string(actual_target_file.data()), target_file_); + EXPECT_EQ(std::string(actual_symlink.data()), symlink_); +} + +TEST_F(SymlinkTest, FileTypeError) { + const std::string symlink_path = + JoinPath(mount_point_.path().c_str(), symlink_); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_entry_out), + }; + struct fuse_entry_out out_payload = DefaultEntryOut(S_IFREG | perms_, 5); + auto iov_out = FuseGenerateIovecs(out_header, out_payload); + SetServerResponse(FUSE_SYMLINK, iov_out); + ASSERT_THAT(symlink(target_file_.c_str(), symlink_path.c_str()), + SyscallFailsWithErrno(EIO)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/unlink_test.cc b/test/fuse/linux/unlink_test.cc new file mode 100644 index 000000000..13efbf7c7 --- /dev/null +++ b/test/fuse/linux/unlink_test.cc @@ -0,0 +1,107 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/mount.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class UnlinkTest : public FuseTest { + protected: + const std::string test_file_ = "test_file"; + const std::string test_subdir_ = "test_subdir"; +}; + +TEST_F(UnlinkTest, RegularFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_UNLINK, iov_out); + + ASSERT_THAT(unlink(test_file_path.c_str()), SyscallSucceeds()); + struct fuse_in_header in_header; + std::vector<char> unlinked_file(test_file_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, unlinked_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_file_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_UNLINK); + EXPECT_EQ(std::string(unlinked_file.data()), test_file_); +} + +TEST_F(UnlinkTest, RegularFileSubDir) { + SetServerInodeLookup(test_subdir_, S_IFDIR | S_IRWXU | S_IRWXG | S_IRWXO); + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_subdir_, test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_UNLINK, iov_out); + + ASSERT_THAT(unlink(test_file_path.c_str()), SyscallSucceeds()); + struct fuse_in_header in_header; + std::vector<char> unlinked_file(test_file_.length() + 1); + auto iov_in = FuseGenerateIovecs(in_header, unlinked_file); + GetServerActualRequest(iov_in); + + EXPECT_EQ(in_header.len, sizeof(in_header) + test_file_.length() + 1); + EXPECT_EQ(in_header.opcode, FUSE_UNLINK); + EXPECT_EQ(std::string(unlinked_file.data()), test_file_); +} + +TEST_F(UnlinkTest, NoFile) { + const std::string test_file_path = + JoinPath(mount_point_.path().c_str(), test_file_); + SetServerInodeLookup(test_file_, S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO); + + struct fuse_out_header out_header = { + .len = sizeof(struct fuse_out_header), + .error = -ENOENT, + }; + auto iov_out = FuseGenerateIovecs(out_header); + SetServerResponse(FUSE_UNLINK, iov_out); + + ASSERT_THAT(unlink(test_file_path.c_str()), SyscallFailsWithErrno(ENOENT)); + SkipServerActualRequest(); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/fuse/linux/write_test.cc b/test/fuse/linux/write_test.cc new file mode 100644 index 000000000..1a62beb96 --- /dev/null +++ b/test/fuse/linux/write_test.cc @@ -0,0 +1,303 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <errno.h> +#include <fcntl.h> +#include <linux/fuse.h> +#include <sys/stat.h> +#include <sys/statfs.h> +#include <sys/types.h> +#include <unistd.h> + +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "test/fuse/linux/fuse_base.h" +#include "test/util/fuse_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +class WriteTest : public FuseTest { + void SetUp() override { + FuseTest::SetUp(); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + // TearDown overrides the parent's function + // to skip checking the unconsumed release request at the end. + void TearDown() override { UnmountFuse(); } + + protected: + const std::string test_file_ = "test_file"; + const mode_t test_file_mode_ = S_IFREG | S_IRWXU | S_IRWXG | S_IRWXO; + const uint64_t test_fh_ = 1; + const uint32_t open_flag_ = O_RDWR; + + std::string test_file_path_; + + PosixErrorOr<FileDescriptor> OpenTestFile(const std::string &path, + uint64_t size = 512) { + SetServerInodeLookup(test_file_, test_file_mode_, size); + + struct fuse_out_header out_header_open = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_open_out), + }; + struct fuse_open_out out_payload_open = { + .fh = test_fh_, + .open_flags = open_flag_, + }; + auto iov_out_open = FuseGenerateIovecs(out_header_open, out_payload_open); + SetServerResponse(FUSE_OPEN, iov_out_open); + + auto res = Open(path.c_str(), open_flag_); + if (res.ok()) { + SkipServerActualRequest(); + } + return res; + } +}; + +class WriteTestSmallMaxWrite : public WriteTest { + void SetUp() override { + MountFuse(); + SetUpFuseServer(&fuse_init_payload); + test_file_path_ = JoinPath(mount_point_.path().c_str(), test_file_); + } + + protected: + const static uint32_t max_write_ = 4096; + constexpr static struct fuse_init_out fuse_init_payload = { + .major = 7, + .max_write = max_write_, + }; + + const uint32_t size_fragment = max_write_; +}; + +TEST_F(WriteTest, WriteNormal) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the write. + const int n_write = 10; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = n_write, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), + SyscallSucceedsWithValue(n_write)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, 0); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTest, WriteShort) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the write. + const int n_write = 10, n_written = 5; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = n_written, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), + SyscallSucceedsWithValue(n_written)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, 0); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTest, WriteShortZero) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Prepare for the write. + const int n_write = 10; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = 0, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), SyscallFailsWithErrno(EIO)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, 0); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTest, WriteZero) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_)); + + // Issue the write. + std::vector<char> buf(0); + EXPECT_THAT(write(fd.get(), buf.data(), 0), SyscallSucceedsWithValue(0)); +} + +TEST_F(WriteTest, PWrite) { + const int file_size = 512; + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, file_size)); + + // Prepare for the write. + const int n_write = 10; + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = n_write, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + SetServerResponse(FUSE_WRITE, iov_out_write); + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + const int offset_write = file_size >> 1; + EXPECT_THAT(pwrite(fd.get(), buf.data(), n_write, offset_write), + SyscallSucceedsWithValue(n_write)); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(n_write); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, offset_write); + EXPECT_EQ(in_payload_write.size, n_write); + EXPECT_EQ(buf, payload_buf); +} + +TEST_F(WriteTestSmallMaxWrite, WriteSmallMaxWrie) { + const int n_fragment = 10; + const int n_write = size_fragment * n_fragment; + + auto fd = ASSERT_NO_ERRNO_AND_VALUE(OpenTestFile(test_file_path_, n_write)); + + // Prepare for the write. + struct fuse_out_header out_header_write = { + .len = sizeof(struct fuse_out_header) + sizeof(struct fuse_write_out), + }; + struct fuse_write_out out_payload_write = { + .size = size_fragment, + }; + auto iov_out_write = FuseGenerateIovecs(out_header_write, out_payload_write); + + for (int i = 0; i < n_fragment; ++i) { + SetServerResponse(FUSE_WRITE, iov_out_write); + } + + // Issue the write. + std::vector<char> buf(n_write); + RandomizeBuffer(buf.data(), buf.size()); + EXPECT_THAT(write(fd.get(), buf.data(), n_write), + SyscallSucceedsWithValue(n_write)); + + ASSERT_EQ(GetServerNumUnsentResponses(), 0); + ASSERT_EQ(GetServerNumUnconsumedRequests(), n_fragment); + + // Check the write request. + struct fuse_in_header in_header_write; + struct fuse_write_in in_payload_write; + std::vector<char> payload_buf(size_fragment); + auto iov_in_write = + FuseGenerateIovecs(in_header_write, in_payload_write, payload_buf); + + for (int i = 0; i < n_fragment; ++i) { + GetServerActualRequest(iov_in_write); + + EXPECT_EQ(in_payload_write.fh, test_fh_); + EXPECT_EQ(in_header_write.len, + sizeof(in_header_write) + sizeof(in_payload_write)); + EXPECT_EQ(in_header_write.opcode, FUSE_WRITE); + EXPECT_EQ(in_payload_write.offset, i * size_fragment); + EXPECT_EQ(in_payload_write.size, size_fragment); + + auto it = buf.begin() + i * size_fragment; + EXPECT_EQ(std::vector<char>(it, it + size_fragment), payload_buf); + } +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/image/image_test.go b/test/image/image_test.go index ac6186688..968e62f63 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -63,8 +63,8 @@ func TestHelloWorld(t *testing.T) { } } -func runHTTPRequest(port int) error { - url := fmt.Sprintf("http://localhost:%d/not-found", port) +func runHTTPRequest(ip string, port int) error { + url := fmt.Sprintf("http://%s:%d/not-found", ip, port) resp, err := http.Get(url) if err != nil { return fmt.Errorf("error reaching http server: %v", err) @@ -73,7 +73,7 @@ func runHTTPRequest(port int) error { return fmt.Errorf("Wrong response code, got: %d, want: %d", resp.StatusCode, want) } - url = fmt.Sprintf("http://localhost:%d/latin10k.txt", port) + url = fmt.Sprintf("http://%s:%d/latin10k.txt", ip, port) resp, err = http.Get(url) if err != nil { return fmt.Errorf("Error reaching http server: %v", err) @@ -95,13 +95,13 @@ func runHTTPRequest(port int) error { return nil } -func testHTTPServer(t *testing.T, port int) { +func testHTTPServer(t *testing.T, ip string, port int) { const requests = 10 ch := make(chan error, requests) for i := 0; i < requests; i++ { go func() { start := time.Now() - err := runHTTPRequest(port) + err := runHTTPRequest(ip, port) log.Printf("Response time %v: %v", time.Since(start).String(), err) ch <- err }() @@ -110,7 +110,7 @@ func testHTTPServer(t *testing.T, port int) { for i := 0; i < requests; i++ { err := <-ch if err != nil { - t.Errorf("testHTTPServer(%d) failed: %v", port, err) + t.Errorf("testHTTPServer(%s, %d) failed: %v", ip, port, err) } } } @@ -121,27 +121,28 @@ func TestHttpd(t *testing.T) { defer d.CleanUp(ctx) // Start the container. + port := 80 opts := dockerutil.RunOpts{ Image: "basic/httpd", - Ports: []int{80}, + Ports: []int{port}, } d.CopyFiles(&opts, "/usr/local/apache2/htdocs", "test/image/latin10k.txt") if err := d.Spawn(ctx, opts); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 80 is mapped to. - port, err := d.FindPort(ctx, 80) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(80) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Errorf("WaitForHTTP() timeout: %v", err) } - testHTTPServer(t, port) + testHTTPServer(t, ip.String(), port) } func TestNginx(t *testing.T) { @@ -150,27 +151,28 @@ func TestNginx(t *testing.T) { defer d.CleanUp(ctx) // Start the container. + port := 80 opts := dockerutil.RunOpts{ Image: "basic/nginx", - Ports: []int{80}, + Ports: []int{port}, } d.CopyFiles(&opts, "/usr/share/nginx/html", "test/image/latin10k.txt") if err := d.Spawn(ctx, opts); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 80 is mapped to. - port, err := d.FindPort(ctx, 80) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(80) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Errorf("WaitForHTTP() timeout: %v", err) } - testHTTPServer(t, port) + testHTTPServer(t, ip.String(), port) } func TestMysql(t *testing.T) { @@ -218,26 +220,27 @@ func TestTomcat(t *testing.T) { defer d.CleanUp(ctx) // Start the server. + port := 8080 if err := d.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/tomcat", - Ports: []int{8080}, + Ports: []int{port}, }); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(ctx, 8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running. - if err := testutil.WaitForHTTP(port, defaultWait); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, defaultWait); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) + url := fmt.Sprintf("http://%s:%d", ip.String(), port) resp, err := http.Get(url) if err != nil { t.Errorf("Error reaching http server: %v", err) @@ -253,28 +256,29 @@ func TestRuby(t *testing.T) { defer d.CleanUp(ctx) // Execute the ruby workload. + port := 8080 opts := dockerutil.RunOpts{ Image: "basic/ruby", - Ports: []int{8080}, + Ports: []int{port}, } d.CopyFiles(&opts, "/src", "test/image/ruby.rb", "test/image/ruby.sh") if err := d.Spawn(ctx, opts, "/src/ruby.sh"); err != nil { t.Fatalf("docker run failed: %v", err) } - // Find where port 8080 is mapped to. - port, err := d.FindPort(ctx, 8080) + // Find container IP address. + ip, err := d.FindIP(ctx, false) if err != nil { - t.Fatalf("FindPort(8080) failed: %v", err) + t.Fatalf("docker.FindIP failed: %v", err) } // Wait until it's up and running, 'gem install' can take some time. - if err := testutil.WaitForHTTP(port, time.Minute); err != nil { + if err := testutil.WaitForHTTP(ip.String(), port, time.Minute); err != nil { t.Fatalf("WaitForHTTP() timeout: %v", err) } // Ensure that content is being served. - url := fmt.Sprintf("http://localhost:%d", port) + url := fmt.Sprintf("http://%s:%d", ip.String(), port) resp, err := http.Get(url) if err != nil { t.Errorf("error reaching http server: %v", err) diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go index 5737ee317..b45d448b8 100644 --- a/test/iptables/filter_input.go +++ b/test/iptables/filter_input.go @@ -15,6 +15,7 @@ package iptables import ( + "context" "errors" "fmt" "net" @@ -53,7 +54,7 @@ func init() { } // FilterInputDropUDP tests that we can drop UDP traffic. -type FilterInputDropUDP struct{} +type FilterInputDropUDP struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropUDP) Name() string { @@ -61,15 +62,17 @@ func (FilterInputDropUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropUDP) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { return err } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -79,12 +82,12 @@ func (FilterInputDropUDP) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropUDP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputDropOnlyUDP tests that "-p udp -j DROP" only affects UDP traffic. -type FilterInputDropOnlyUDP struct{} +type FilterInputDropOnlyUDP struct{ baseCase } // Name implements TestCase.Name. func (FilterInputDropOnlyUDP) Name() string { @@ -92,13 +95,13 @@ func (FilterInputDropOnlyUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropOnlyUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-j", "DROP"); err != nil { return err } // Listen for a TCP connection, which should be allowed. - if err := listenTCP(acceptPort, sendloopDuration); err != nil { + if err := listenTCP(ctx, acceptPort); err != nil { return fmt.Errorf("failed to establish a connection %v", err) } @@ -106,14 +109,14 @@ func (FilterInputDropOnlyUDP) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropOnlyUDP) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropOnlyUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Try to establish a TCP connection with the container, which should // succeed. - return connectTCP(ip, acceptPort, sendloopDuration) + return connectTCP(ctx, ip, acceptPort) } // FilterInputDropUDPPort tests that we can drop UDP traffic by port. -type FilterInputDropUDPPort struct{} +type FilterInputDropUDPPort struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropUDPPort) Name() string { @@ -121,15 +124,17 @@ func (FilterInputDropUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -139,13 +144,13 @@ func (FilterInputDropUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropUDPPort) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputDropDifferentUDPPort tests that dropping traffic for a single UDP port // doesn't drop packets on other ports. -type FilterInputDropDifferentUDPPort struct{} +type FilterInputDropDifferentUDPPort struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropDifferentUDPPort) Name() string { @@ -153,13 +158,13 @@ func (FilterInputDropDifferentUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropDifferentUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for UDP packets on another port. - if err := listenUDP(acceptPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, acceptPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) } @@ -167,12 +172,12 @@ func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP, ipv6 bool) err } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDropDifferentUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputDropTCPDestPort tests that connections are not accepted on specified source ports. -type FilterInputDropTCPDestPort struct{} +type FilterInputDropTCPDestPort struct{ baseCase } // Name implements TestCase.Name. func (FilterInputDropTCPDestPort) Name() string { @@ -180,33 +185,36 @@ func (FilterInputDropTCPDestPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropTCPDestPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropTCPDestPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for TCP packets on drop port. - if err := listenTCP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, dropPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropTCPDestPort) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropTCPDestPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Ensure we cannot connect to the container. - for start := time.Now(); time.Since(start) < sendloopDuration; { - if err := connectTCP(ip, dropPort, sendloopDuration-time.Since(start)); err == nil { - return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort) - } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { + return fmt.Errorf("expected not to connect, but was able to connect on port %d", dropPort) } - return nil } // FilterInputDropTCPSrcPort tests that connections are not accepted on specified source ports. -type FilterInputDropTCPSrcPort struct{} +type FilterInputDropTCPSrcPort struct{ baseCase } // Name implements TestCase.Name. func (FilterInputDropTCPSrcPort) Name() string { @@ -214,34 +222,37 @@ func (FilterInputDropTCPSrcPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropTCPSrcPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropTCPSrcPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Drop anything from an ephemeral port. if err := filterTable(ipv6, "-A", "INPUT", "-p", "tcp", "-m", "tcp", "--sport", "1024:65535", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but was", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropTCPSrcPort) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropTCPSrcPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Ensure we cannot connect to the container. - for start := time.Now(); time.Since(start) < sendloopDuration; { - if err := connectTCP(ip, acceptPort, sendloopDuration-time.Since(start)); err == nil { - return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort) - } + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { + return fmt.Errorf("expected not to connect, but was able to connect on port %d", acceptPort) } - return nil } // FilterInputDropAll tests that we can drop all traffic to the INPUT chain. -type FilterInputDropAll struct{} +type FilterInputDropAll struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDropAll) Name() string { @@ -249,15 +260,17 @@ func (FilterInputDropAll) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDropAll) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDropAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-j", "DROP"); err != nil { return err } // Listen for all packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets should have been dropped, but got a packet") - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -267,15 +280,15 @@ func (FilterInputDropAll) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputDropAll) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputDropAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputMultiUDPRules verifies that multiple UDP rules are applied // correctly. This has the added benefit of testing whether we're serializing // rules correctly -- if we do it incorrectly, the iptables tool will // misunderstand and save the wrong tables. -type FilterInputMultiUDPRules struct{} +type FilterInputMultiUDPRules struct{ baseCase } // Name implements TestCase.Name. func (FilterInputMultiUDPRules) Name() string { @@ -283,7 +296,7 @@ func (FilterInputMultiUDPRules) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputMultiUDPRules) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputMultiUDPRules) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"}, {"-A", "INPUT", "-p", "udp", "-m", "udp", "--destination-port", fmt.Sprintf("%d", acceptPort), "-j", "ACCEPT"}, @@ -293,14 +306,14 @@ func (FilterInputMultiUDPRules) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputMultiUDPRules) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputMultiUDPRules) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputRequireProtocolUDP checks that "-m udp" requires "-p udp" to be // specified. -type FilterInputRequireProtocolUDP struct{} +type FilterInputRequireProtocolUDP struct{ baseCase } // Name implements TestCase.Name. func (FilterInputRequireProtocolUDP) Name() string { @@ -308,20 +321,20 @@ func (FilterInputRequireProtocolUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputRequireProtocolUDP) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputRequireProtocolUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-m", "udp", "--destination-port", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err == nil { return errors.New("expected iptables to fail with out \"-p udp\", but succeeded") } return nil } -func (FilterInputRequireProtocolUDP) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputRequireProtocolUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputCreateUserChain tests chain creation. -type FilterInputCreateUserChain struct{} +type FilterInputCreateUserChain struct{ baseCase } // Name implements TestCase.Name. func (FilterInputCreateUserChain) Name() string { @@ -329,7 +342,7 @@ func (FilterInputCreateUserChain) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputCreateUserChain) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputCreateUserChain) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ // Create a chain. {"-N", chainName}, @@ -340,13 +353,13 @@ func (FilterInputCreateUserChain) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputCreateUserChain) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputCreateUserChain) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputDefaultPolicyAccept tests the default ACCEPT policy. -type FilterInputDefaultPolicyAccept struct{} +type FilterInputDefaultPolicyAccept struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDefaultPolicyAccept) Name() string { @@ -354,21 +367,21 @@ func (FilterInputDefaultPolicyAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDefaultPolicyAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Set the default policy to accept, then receive a packet. if err := filterTable(ipv6, "-P", "INPUT", "ACCEPT"); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDefaultPolicyAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputDefaultPolicyDrop tests the default DROP policy. -type FilterInputDefaultPolicyDrop struct{} +type FilterInputDefaultPolicyDrop struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDefaultPolicyDrop) Name() string { @@ -376,15 +389,17 @@ func (FilterInputDefaultPolicyDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDefaultPolicyDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-P", "INPUT", "DROP"); err != nil { return err } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -394,13 +409,13 @@ func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP, ipv6 bool) error } // LocalAction implements TestCase.LocalAction. -func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDefaultPolicyDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes // the underflow rule (i.e. default policy) to be executed. -type FilterInputReturnUnderflow struct{} +type FilterInputReturnUnderflow struct{ containerCase } // Name implements TestCase.Name. func (FilterInputReturnUnderflow) Name() string { @@ -408,7 +423,7 @@ func (FilterInputReturnUnderflow) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputReturnUnderflow) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputReturnUnderflow) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Add a RETURN rule followed by an unconditional accept, and set the // default policy to DROP. rules := [][]string{ @@ -422,16 +437,16 @@ func (FilterInputReturnUnderflow) ContainerAction(ip net.IP, ipv6 bool) error { // We should receive packets, as the RETURN rule will trigger the default // ACCEPT policy. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputReturnUnderflow) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputReturnUnderflow) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputSerializeJump verifies that we can serialize jumps. -type FilterInputSerializeJump struct{} +type FilterInputSerializeJump struct{ baseCase } // Name implements TestCase.Name. func (FilterInputSerializeJump) Name() string { @@ -439,7 +454,7 @@ func (FilterInputSerializeJump) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputSerializeJump) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputSerializeJump) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Write a JUMP rule, the serialize it with `-L`. rules := [][]string{ {"-N", chainName}, @@ -450,13 +465,13 @@ func (FilterInputSerializeJump) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputSerializeJump) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputSerializeJump) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputJumpBasic jumps to a chain and executes a rule there. -type FilterInputJumpBasic struct{} +type FilterInputJumpBasic struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpBasic) Name() string { @@ -464,7 +479,7 @@ func (FilterInputJumpBasic) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpBasic) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpBasic) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-P", "INPUT", "DROP"}, {"-N", chainName}, @@ -476,16 +491,16 @@ func (FilterInputJumpBasic) ContainerAction(ip net.IP, ipv6 bool) error { } // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpBasic) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpBasic) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputJumpReturn jumps, returns, and executes a rule. -type FilterInputJumpReturn struct{} +type FilterInputJumpReturn struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpReturn) Name() string { @@ -493,7 +508,7 @@ func (FilterInputJumpReturn) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpReturn) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpReturn) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-N", chainName}, {"-P", "INPUT", "ACCEPT"}, @@ -506,16 +521,16 @@ func (FilterInputJumpReturn) ContainerAction(ip net.IP, ipv6 bool) error { } // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpReturn) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpReturn) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets. -type FilterInputJumpReturnDrop struct{} +type FilterInputJumpReturnDrop struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpReturnDrop) Name() string { @@ -523,7 +538,7 @@ func (FilterInputJumpReturnDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpReturnDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-N", chainName}, {"-A", "INPUT", "-j", chainName}, @@ -535,9 +550,11 @@ func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP, ipv6 bool) error { } // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets on port %d should have been dropped, but got a packet", dropPort) - } else if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + } else if !errors.Is(err, context.DeadlineExceeded) { return fmt.Errorf("error reading: %v", err) } @@ -547,12 +564,12 @@ func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpReturnDrop) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (FilterInputJumpReturnDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal. -type FilterInputJumpBuiltin struct{} +type FilterInputJumpBuiltin struct{ baseCase } // Name implements TestCase.Name. func (FilterInputJumpBuiltin) Name() string { @@ -560,7 +577,7 @@ func (FilterInputJumpBuiltin) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpBuiltin) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpBuiltin) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "INPUT", "-j", "OUTPUT"); err == nil { return fmt.Errorf("iptables should be unable to jump to a built-in chain") } @@ -568,13 +585,13 @@ func (FilterInputJumpBuiltin) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpBuiltin) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpBuiltin) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // FilterInputJumpTwice jumps twice, then returns twice and executes a rule. -type FilterInputJumpTwice struct{} +type FilterInputJumpTwice struct{ containerCase } // Name implements TestCase.Name. func (FilterInputJumpTwice) Name() string { @@ -582,7 +599,7 @@ func (FilterInputJumpTwice) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputJumpTwice) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputJumpTwice) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { const chainName2 = chainName + "2" rules := [][]string{ {"-P", "INPUT", "DROP"}, @@ -598,17 +615,17 @@ func (FilterInputJumpTwice) ContainerAction(ip net.IP, ipv6 bool) error { // UDP packets should jump and return twice, eventually hitting the // ACCEPT rule. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputJumpTwice) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputJumpTwice) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputDestination verifies that we can filter packets via `-d // <ipaddr>`. -type FilterInputDestination struct{} +type FilterInputDestination struct{ containerCase } // Name implements TestCase.Name. func (FilterInputDestination) Name() string { @@ -616,7 +633,7 @@ func (FilterInputDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputDestination) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { addrs, err := localAddrs(ipv6) if err != nil { return err @@ -632,17 +649,17 @@ func (FilterInputDestination) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputDestination) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputInvertDestination verifies that we can filter packets via `! -d // <ipaddr>`. -type FilterInputInvertDestination struct{} +type FilterInputInvertDestination struct{ containerCase } // Name implements TestCase.Name. func (FilterInputInvertDestination) Name() string { @@ -650,7 +667,7 @@ func (FilterInputInvertDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputInvertDestination) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputInvertDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets not bound // for 127.0.0.1. rules := [][]string{ @@ -661,17 +678,17 @@ func (FilterInputInvertDestination) ContainerAction(ip net.IP, ipv6 bool) error return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputInvertDestination) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputInvertDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputSource verifies that we can filter packets via `-s // <ipaddr>`. -type FilterInputSource struct{} +type FilterInputSource struct{ containerCase } // Name implements TestCase.Name. func (FilterInputSource) Name() string { @@ -679,7 +696,7 @@ func (FilterInputSource) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputSource) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputSource) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets from this // machine. rules := [][]string{ @@ -690,17 +707,17 @@ func (FilterInputSource) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputSource) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // FilterInputInvertSource verifies that we can filter packets via `! -s // <ipaddr>`. -type FilterInputInvertSource struct{} +type FilterInputInvertSource struct{ containerCase } // Name implements TestCase.Name. func (FilterInputInvertSource) Name() string { @@ -708,7 +725,7 @@ func (FilterInputInvertSource) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterInputInvertSource) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterInputInvertSource) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Make INPUT's default action DROP, then ACCEPT all packets not bound // for 127.0.0.1. rules := [][]string{ @@ -719,10 +736,10 @@ func (FilterInputInvertSource) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterInputInvertSource) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (FilterInputInvertSource) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } diff --git a/test/iptables/filter_output.go b/test/iptables/filter_output.go index c1d83b471..32bf2a992 100644 --- a/test/iptables/filter_output.go +++ b/test/iptables/filter_output.go @@ -15,6 +15,8 @@ package iptables import ( + "context" + "errors" "fmt" "net" ) @@ -44,7 +46,7 @@ func init() { // FilterOutputDropTCPDestPort tests that connections are not accepted on // specified source ports. -type FilterOutputDropTCPDestPort struct{} +type FilterOutputDropTCPDestPort struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropTCPDestPort) Name() string { @@ -52,22 +54,28 @@ func (FilterOutputDropTCPDestPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPDestPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropTCPDestPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", "1024:65535", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputDropTCPDestPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) } @@ -76,7 +84,7 @@ func (FilterOutputDropTCPDestPort) LocalAction(ip net.IP, ipv6 bool) error { // FilterOutputDropTCPSrcPort tests that connections are not accepted on // specified source ports. -type FilterOutputDropTCPSrcPort struct{} +type FilterOutputDropTCPSrcPort struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropTCPSrcPort) Name() string { @@ -84,22 +92,28 @@ func (FilterOutputDropTCPSrcPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPSrcPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropTCPSrcPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--sport", fmt.Sprintf("%d", dropPort), "-j", "DROP"); err != nil { return err } // Listen for TCP packets on drop port. - if err := listenTCP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, dropPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", dropPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, dropPort, sendloopDuration); err == nil { +func (FilterOutputDropTCPSrcPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, dropPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", dropPort) } @@ -107,7 +121,7 @@ func (FilterOutputDropTCPSrcPort) LocalAction(ip net.IP, ipv6 bool) error { } // FilterOutputAcceptTCPOwner tests that TCP connections from uid owner are accepted. -type FilterOutputAcceptTCPOwner struct{} +type FilterOutputAcceptTCPOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputAcceptTCPOwner) Name() string { @@ -115,22 +129,22 @@ func (FilterOutputAcceptTCPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptTCPOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputAcceptTCPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { return err } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptTCPOwner) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputAcceptTCPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // FilterOutputDropTCPOwner tests that TCP connections from uid owner are dropped. -type FilterOutputDropTCPOwner struct{} +type FilterOutputDropTCPOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropTCPOwner) Name() string { @@ -138,22 +152,28 @@ func (FilterOutputDropTCPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropTCPOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropTCPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should be dropped, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropTCPOwner) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputDropTCPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should be dropped, but got accepted", acceptPort) } @@ -161,7 +181,7 @@ func (FilterOutputDropTCPOwner) LocalAction(ip net.IP, ipv6 bool) error { } // FilterOutputAcceptUDPOwner tests that UDP packets from uid owner are accepted. -type FilterOutputAcceptUDPOwner struct{} +type FilterOutputAcceptUDPOwner struct{ localCase } // Name implements TestCase.Name. func (FilterOutputAcceptUDPOwner) Name() string { @@ -169,23 +189,23 @@ func (FilterOutputAcceptUDPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptUDPOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputAcceptUDPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "ACCEPT"); err != nil { return err } // Send UDP packets on acceptPort. - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptUDPOwner) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterOutputAcceptUDPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Listen for UDP packets on acceptPort. - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // FilterOutputDropUDPOwner tests that UDP packets from uid owner are dropped. -type FilterOutputDropUDPOwner struct{} +type FilterOutputDropUDPOwner struct{ localCase } // Name implements TestCase.Name. func (FilterOutputDropUDPOwner) Name() string { @@ -193,20 +213,24 @@ func (FilterOutputDropUDPOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropUDPOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropUDPOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "--uid-owner", "root", "-j", "DROP"); err != nil { return err } // Send UDP packets on dropPort. - return sendUDPLoop(ip, dropPort, sendloopDuration) + return sendUDPLoop(ctx, ip, dropPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropUDPOwner) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropUDPOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Listen for UDP packets on dropPort. - if err := listenUDP(dropPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, dropPort); err == nil { return fmt.Errorf("packets should not be received") + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil @@ -214,7 +238,7 @@ func (FilterOutputDropUDPOwner) LocalAction(ip net.IP, ipv6 bool) error { // FilterOutputOwnerFail tests that without uid/gid option, owner rule // will fail. -type FilterOutputOwnerFail struct{} +type FilterOutputOwnerFail struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputOwnerFail) Name() string { @@ -222,7 +246,7 @@ func (FilterOutputOwnerFail) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputOwnerFail) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputOwnerFail) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-m", "owner", "-j", "ACCEPT"); err == nil { return fmt.Errorf("Invalid argument") } @@ -231,13 +255,13 @@ func (FilterOutputOwnerFail) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (FilterOutputOwnerFail) LocalAction(ip net.IP, ipv6 bool) error { +func (FilterOutputOwnerFail) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // no-op. return nil } // FilterOutputAcceptGIDOwner tests that TCP connections from gid owner are accepted. -type FilterOutputAcceptGIDOwner struct{} +type FilterOutputAcceptGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputAcceptGIDOwner) Name() string { @@ -245,22 +269,22 @@ func (FilterOutputAcceptGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputAcceptGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputAcceptGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "ACCEPT"); err != nil { return err } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputAcceptGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputAcceptGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // FilterOutputDropGIDOwner tests that TCP connections from gid owner are dropped. -type FilterOutputDropGIDOwner struct{} +type FilterOutputDropGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputDropGIDOwner) Name() string { @@ -268,22 +292,28 @@ func (FilterOutputDropGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDropGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDropGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "owner", "--gid-owner", "root", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDropGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputDropGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -291,7 +321,7 @@ func (FilterOutputDropGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { } // FilterOutputInvertGIDOwner tests that TCP connections from gid owner are dropped. -type FilterOutputInvertGIDOwner struct{} +type FilterOutputInvertGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInvertGIDOwner) Name() string { @@ -299,7 +329,7 @@ func (FilterOutputInvertGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInvertGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--gid-owner", "root", "-j", "ACCEPT"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"}, @@ -309,16 +339,22 @@ func (FilterOutputInvertGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputInvertGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -326,7 +362,7 @@ func (FilterOutputInvertGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { } // FilterOutputInvertUIDOwner tests that TCP connections from gid owner are dropped. -type FilterOutputInvertUIDOwner struct{} +type FilterOutputInvertUIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInvertUIDOwner) Name() string { @@ -334,7 +370,7 @@ func (FilterOutputInvertUIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertUIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInvertUIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "-j", "DROP"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "ACCEPT"}, @@ -344,17 +380,17 @@ func (FilterOutputInvertUIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertUIDOwner) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputInvertUIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // FilterOutputInvertUIDAndGIDOwner tests that TCP connections from uid and gid // owner are dropped. -type FilterOutputInvertUIDAndGIDOwner struct{} +type FilterOutputInvertUIDAndGIDOwner struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInvertUIDAndGIDOwner) Name() string { @@ -362,7 +398,7 @@ func (FilterOutputInvertUIDAndGIDOwner) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-p", "tcp", "-m", "owner", "!", "--uid-owner", "root", "!", "--gid-owner", "root", "-j", "ACCEPT"}, {"-A", "OUTPUT", "-p", "tcp", "-j", "DROP"}, @@ -372,16 +408,22 @@ func (FilterOutputInvertUIDAndGIDOwner) ContainerAction(ip net.IP, ipv6 bool) er } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -390,7 +432,7 @@ func (FilterOutputInvertUIDAndGIDOwner) LocalAction(ip net.IP, ipv6 bool) error // FilterOutputDestination tests that we can selectively allow packets to // certain destinations. -type FilterOutputDestination struct{} +type FilterOutputDestination struct{ localCase } // Name implements TestCase.Name. func (FilterOutputDestination) Name() string { @@ -398,7 +440,7 @@ func (FilterOutputDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputDestination) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "-d", ip.String(), "-j", "ACCEPT"}, {"-P", "OUTPUT", "DROP"}, @@ -407,17 +449,17 @@ func (FilterOutputDestination) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputDestination) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInvertDestination tests that we can selectively allow packets // not headed for a particular destination. -type FilterOutputInvertDestination struct{} +type FilterOutputInvertDestination struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInvertDestination) Name() string { @@ -425,7 +467,7 @@ func (FilterOutputInvertDestination) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInvertDestination) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInvertDestination) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { rules := [][]string{ {"-A", "OUTPUT", "!", "-d", localIP(ipv6), "-j", "ACCEPT"}, {"-P", "OUTPUT", "DROP"}, @@ -434,17 +476,17 @@ func (FilterOutputInvertDestination) ContainerAction(ip net.IP, ipv6 bool) error return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInvertDestination) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputInvertDestination) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInterfaceAccept tests that packets are sent via interface // matching the iptables rule. -type FilterOutputInterfaceAccept struct{} +type FilterOutputInterfaceAccept struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterfaceAccept) Name() string { @@ -452,7 +494,7 @@ func (FilterOutputInterfaceAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterfaceAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { ifname, ok := getInterfaceName() if !ok { return fmt.Errorf("no interface is present, except loopback") @@ -461,17 +503,17 @@ func (FilterOutputInterfaceAccept) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceAccept) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputInterfaceAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInterfaceDrop tests that packets are not sent via interface // matching the iptables rule. -type FilterOutputInterfaceDrop struct{} +type FilterOutputInterfaceDrop struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterfaceDrop) Name() string { @@ -479,7 +521,7 @@ func (FilterOutputInterfaceDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterfaceDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { ifname, ok := getInterfaceName() if !ok { return fmt.Errorf("no interface is present, except loopback") @@ -488,13 +530,17 @@ func (FilterOutputInterfaceDrop) ContainerAction(ip net.IP, ipv6 bool) error { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceDrop) LocalAction(ip net.IP, ipv6 bool) error { - if err := listenUDP(acceptPort, sendloopDuration); err == nil { +func (FilterOutputInterfaceDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil @@ -502,7 +548,7 @@ func (FilterOutputInterfaceDrop) LocalAction(ip net.IP, ipv6 bool) error { // FilterOutputInterface tests that packets are sent via interface which is // not matching the interface name in the iptables rule. -type FilterOutputInterface struct{} +type FilterOutputInterface struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterface) Name() string { @@ -510,22 +556,22 @@ func (FilterOutputInterface) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterface) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterface) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "lo", "-j", "DROP"); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterface) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (FilterOutputInterface) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // FilterOutputInterfaceBeginsWith tests that packets are not sent via an // interface which begins with the given interface name. -type FilterOutputInterfaceBeginsWith struct{} +type FilterOutputInterfaceBeginsWith struct{ localCase } // Name implements TestCase.Name. func (FilterOutputInterfaceBeginsWith) Name() string { @@ -533,18 +579,22 @@ func (FilterOutputInterfaceBeginsWith) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceBeginsWith) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterfaceBeginsWith) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "udp", "-o", "e+", "-j", "DROP"); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP, ipv6 bool) error { - if err := listenUDP(acceptPort, sendloopDuration); err == nil { +func (FilterOutputInterfaceBeginsWith) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { return fmt.Errorf("packets should not be received on port %v, but are received", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil @@ -552,7 +602,7 @@ func (FilterOutputInterfaceBeginsWith) LocalAction(ip net.IP, ipv6 bool) error { // FilterOutputInterfaceInvertDrop tests that we selectively do not send // packets via interface not matching the interface name. -type FilterOutputInterfaceInvertDrop struct{} +type FilterOutputInterfaceInvertDrop struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInterfaceInvertDrop) Name() string { @@ -560,22 +610,28 @@ func (FilterOutputInterfaceInvertDrop) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceInvertDrop) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterfaceInvertDrop) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "DROP"); err != nil { return err } // Listen for TCP packets on accept port. - if err := listenTCP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenTCP(timedCtx, acceptPort); err == nil { return fmt.Errorf("connection on port %d should not be accepted, but got accepted", acceptPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP, ipv6 bool) error { - if err := connectTCP(ip, acceptPort, sendloopDuration); err == nil { +func (FilterOutputInterfaceInvertDrop) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := connectTCP(timedCtx, ip, acceptPort); err == nil { return fmt.Errorf("connection destined to port %d should not be accepted, but got accepted", acceptPort) } @@ -584,7 +640,7 @@ func (FilterOutputInterfaceInvertDrop) LocalAction(ip net.IP, ipv6 bool) error { // FilterOutputInterfaceInvertAccept tests that we can selectively send packets // not matching the specific outgoing interface. -type FilterOutputInterfaceInvertAccept struct{} +type FilterOutputInterfaceInvertAccept struct{ baseCase } // Name implements TestCase.Name. func (FilterOutputInterfaceInvertAccept) Name() string { @@ -592,16 +648,16 @@ func (FilterOutputInterfaceInvertAccept) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (FilterOutputInterfaceInvertAccept) ContainerAction(ip net.IP, ipv6 bool) error { +func (FilterOutputInterfaceInvertAccept) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := filterTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "!", "-o", "lo", "-j", "ACCEPT"); err != nil { return err } // Listen for TCP packets on accept port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (FilterOutputInterfaceInvertAccept) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (FilterOutputInterfaceInvertAccept) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } diff --git a/test/iptables/iptables.go b/test/iptables/iptables.go index dfbd80cd1..c2a03f54c 100644 --- a/test/iptables/iptables.go +++ b/test/iptables/iptables.go @@ -16,6 +16,7 @@ package iptables import ( + "context" "fmt" "net" "time" @@ -29,7 +30,11 @@ const IPExchangePort = 2349 const TerminalStatement = "Finished!" // TestTimeout is the timeout used for all tests. -const TestTimeout = 10 * time.Minute +const TestTimeout = 10 * time.Second + +// NegativeTimeout is the time tests should wait to establish the negative +// case, i.e. that connections are not made. +const NegativeTimeout = 2 * time.Second // A TestCase contains one action to run in the container and one to run // locally. The actions run concurrently and each must succeed for the test @@ -40,10 +45,60 @@ type TestCase interface { // ContainerAction runs inside the container. It receives the IP of the // local process. - ContainerAction(ip net.IP, ipv6 bool) error + ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error // LocalAction runs locally. It receives the IP of the container. - LocalAction(ip net.IP, ipv6 bool) error + LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error + + // ContainerSufficient indicates whether ContainerAction's return value + // alone indicates whether the test succeeded. + ContainerSufficient() bool + + // LocalSufficient indicates whether LocalAction's return value alone + // indicates whether the test succeeded. + LocalSufficient() bool +} + +// baseCase provides defaults for ContainerSufficient and LocalSufficient when +// both actions are required to finish. +type baseCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (baseCase) ContainerSufficient() bool { + return false +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (baseCase) LocalSufficient() bool { + return false +} + +// localCase provides defaults for ContainerSufficient and LocalSufficient when +// only the local action is required to finish. +type localCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (localCase) ContainerSufficient() bool { + return false +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (localCase) LocalSufficient() bool { + return true +} + +// containerCase provides defaults for ContainerSufficient and LocalSufficient +// when only the container action is required to finish. +type containerCase struct{} + +// ContainerSufficient implements TestCase.ContainerSufficient. +func (containerCase) ContainerSufficient() bool { + return true +} + +// LocalSufficient implements TestCase.LocalSufficient. +func (containerCase) LocalSufficient() bool { + return false } // Tests maps test names to TestCase. diff --git a/test/iptables/iptables_test.go b/test/iptables/iptables_test.go index fda5f694f..398f70ecd 100644 --- a/test/iptables/iptables_test.go +++ b/test/iptables/iptables_test.go @@ -16,9 +16,11 @@ package iptables import ( "context" + "errors" "fmt" "net" "reflect" + "sync" "testing" "gvisor.dev/gvisor/pkg/test/dockerutil" @@ -46,19 +48,36 @@ func singleTest(t *testing.T, test TestCase) { } } +// TODO(gvisor.dev/issue/3549): IPv6 NAT support. +func ipv4Test(t *testing.T, test TestCase) { + t.Run("IPv4", func(t *testing.T) { + iptablesTest(t, test, false) + }) +} + func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { if _, ok := Tests[test.Name()]; !ok { t.Fatalf("no test found with name %q. Has it been registered?", test.Name()) } - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) + // Wait for the local and container goroutines to finish. + var wg sync.WaitGroup + defer wg.Wait() - // TODO(gvisor.dev/issue/170): Skipping IPv6 gVisor tests. - if ipv6 && dockerutil.Runtime() != "runc" { - t.Skip("gVisor ip6tables not yet implemented") - } + ctx, cancel := context.WithTimeout(context.Background(), TestTimeout) + defer cancel() + + d := dockerutil.MakeContainer(ctx, t) + defer func() { + if logs, err := d.Logs(context.Background()); err != nil { + t.Logf("Failed to retrieve container logs.") + } else { + t.Logf("=== Container logs: ===\n%s", logs) + } + // Use a new context, as cleanup should run even when we + // timeout. + d.CleanUp(context.Background()) + }() // Create and start the container. opts := dockerutil.RunOpts{ @@ -86,15 +105,44 @@ func iptablesTest(t *testing.T, test TestCase, ipv6 bool) { } // Run our side of the test. - if err := test.LocalAction(ip, ipv6); err != nil { - t.Fatalf("LocalAction failed: %v", err) - } - - // Wait for the final statement. This structure has the side effect - // that all container logs will appear within the individual test - // context. - if _, err := d.WaitForOutput(ctx, TerminalStatement, TestTimeout); err != nil { - t.Fatalf("test failed: %v", err) + errCh := make(chan error, 2) + wg.Add(1) + go func() { + defer wg.Done() + if err := test.LocalAction(ctx, ip, ipv6); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("LocalAction failed: %v", err) + } else { + errCh <- nil + } + if test.LocalSufficient() { + errCh <- nil + } + }() + + // Run the container side. + wg.Add(1) + go func() { + defer wg.Done() + // Wait for the final statement. This structure has the side + // effect that all container logs will appear within the + // individual test context. + if _, err := d.WaitForOutput(ctx, TerminalStatement, TestTimeout); err != nil && !errors.Is(err, context.Canceled) { + errCh <- fmt.Errorf("ContainerAction failed: %v", err) + } else { + errCh <- nil + } + if test.ContainerSufficient() { + errCh <- nil + } + }() + + for i := 0; i < 2; i++ { + select { + case err := <-errCh: + if err != nil { + t.Fatal(err) + } + } } } @@ -268,75 +316,75 @@ func TestInputInvertDestination(t *testing.T) { singleTest(t, FilterInputInvertDestination{}) } -func TestOutputDestination(t *testing.T) { +func TestFilterOutputDestination(t *testing.T) { singleTest(t, FilterOutputDestination{}) } -func TestOutputInvertDestination(t *testing.T) { +func TestFilterOutputInvertDestination(t *testing.T) { singleTest(t, FilterOutputInvertDestination{}) } func TestNATPreRedirectUDPPort(t *testing.T) { - singleTest(t, NATPreRedirectUDPPort{}) + ipv4Test(t, NATPreRedirectUDPPort{}) } func TestNATPreRedirectTCPPort(t *testing.T) { - singleTest(t, NATPreRedirectTCPPort{}) + ipv4Test(t, NATPreRedirectTCPPort{}) } func TestNATPreRedirectTCPOutgoing(t *testing.T) { - singleTest(t, NATPreRedirectTCPOutgoing{}) + ipv4Test(t, NATPreRedirectTCPOutgoing{}) } func TestNATOutRedirectTCPIncoming(t *testing.T) { - singleTest(t, NATOutRedirectTCPIncoming{}) + ipv4Test(t, NATOutRedirectTCPIncoming{}) } func TestNATOutRedirectUDPPort(t *testing.T) { - singleTest(t, NATOutRedirectUDPPort{}) + ipv4Test(t, NATOutRedirectUDPPort{}) } func TestNATOutRedirectTCPPort(t *testing.T) { - singleTest(t, NATOutRedirectTCPPort{}) + ipv4Test(t, NATOutRedirectTCPPort{}) } func TestNATDropUDP(t *testing.T) { - singleTest(t, NATDropUDP{}) + ipv4Test(t, NATDropUDP{}) } func TestNATAcceptAll(t *testing.T) { - singleTest(t, NATAcceptAll{}) + ipv4Test(t, NATAcceptAll{}) } func TestNATOutRedirectIP(t *testing.T) { - singleTest(t, NATOutRedirectIP{}) + ipv4Test(t, NATOutRedirectIP{}) } func TestNATOutDontRedirectIP(t *testing.T) { - singleTest(t, NATOutDontRedirectIP{}) + ipv4Test(t, NATOutDontRedirectIP{}) } func TestNATOutRedirectInvert(t *testing.T) { - singleTest(t, NATOutRedirectInvert{}) + ipv4Test(t, NATOutRedirectInvert{}) } func TestNATPreRedirectIP(t *testing.T) { - singleTest(t, NATPreRedirectIP{}) + ipv4Test(t, NATPreRedirectIP{}) } func TestNATPreDontRedirectIP(t *testing.T) { - singleTest(t, NATPreDontRedirectIP{}) + ipv4Test(t, NATPreDontRedirectIP{}) } func TestNATPreRedirectInvert(t *testing.T) { - singleTest(t, NATPreRedirectInvert{}) + ipv4Test(t, NATPreRedirectInvert{}) } func TestNATRedirectRequiresProtocol(t *testing.T) { - singleTest(t, NATRedirectRequiresProtocol{}) + ipv4Test(t, NATRedirectRequiresProtocol{}) } func TestNATLoopbackSkipsPrerouting(t *testing.T) { - singleTest(t, NATLoopbackSkipsPrerouting{}) + ipv4Test(t, NATLoopbackSkipsPrerouting{}) } func TestInputSource(t *testing.T) { @@ -373,9 +421,9 @@ func TestFilterAddrs(t *testing.T) { } func TestNATPreOriginalDst(t *testing.T) { - singleTest(t, NATPreOriginalDst{}) + ipv4Test(t, NATPreOriginalDst{}) } func TestNATOutOriginalDst(t *testing.T) { - singleTest(t, NATOutOriginalDst{}) + ipv4Test(t, NATOutOriginalDst{}) } diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go index 5125fe47b..a6ec5cca3 100644 --- a/test/iptables/iptables_util.go +++ b/test/iptables/iptables_util.go @@ -15,6 +15,7 @@ package iptables import ( + "context" "encoding/binary" "errors" "fmt" @@ -70,7 +71,7 @@ func tableRules(ipv6 bool, table string, argsList [][]string) error { // listenUDP listens on a UDP port and returns the value of net.Conn.Read() for // the first read on that port. -func listenUDP(port int, timeout time.Duration) error { +func listenUDP(ctx context.Context, port int) error { localAddr := net.UDPAddr{ Port: port, } @@ -79,68 +80,53 @@ func listenUDP(port int, timeout time.Duration) error { return err } defer conn.Close() - conn.SetDeadline(time.Now().Add(timeout)) - _, err = conn.Read([]byte{0}) - return err -} -// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified -// over a duration. -func sendUDPLoop(ip net.IP, port int, duration time.Duration) error { - conn, err := connectUDP(ip, port) - if err != nil { - return err - } - defer conn.Close() - loopUDP(conn, duration) - return nil -} + ch := make(chan error) + go func() { + _, err = conn.Read([]byte{0}) + ch <- err + }() -// spawnUDPLoop works like sendUDPLoop, but returns immediately and sends -// packets in another goroutine. -func spawnUDPLoop(ip net.IP, port int, duration time.Duration) error { - conn, err := connectUDP(ip, port) - if err != nil { + select { + case err := <-ch: return err + case <-ctx.Done(): + return ctx.Err() } - go func() { - defer conn.Close() - loopUDP(conn, duration) - }() - return nil } -func connectUDP(ip net.IP, port int) (net.Conn, error) { +// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified +// over a duration. +func sendUDPLoop(ctx context.Context, ip net.IP, port int) error { remote := net.UDPAddr{ IP: ip, Port: port, } conn, err := net.DialUDP("udp", nil, &remote) if err != nil { - return nil, err + return err } - return conn, nil -} + defer conn.Close() -func loopUDP(conn net.Conn, duration time.Duration) { - to := time.After(duration) - for timedOut := false; !timedOut; { + for { // This may return an error (connection refused) if the remote // hasn't started listening yet or they're dropping our // packets. So we ignore Write errors and depend on the remote // to report a failure if it doesn't get a packet it needs. conn.Write([]byte{0}) select { - case <-to: - timedOut = true - default: - time.Sleep(200 * time.Millisecond) + case <-ctx.Done(): + // Being cancelled or timing out isn't an error, as we + // cannot tell with UDP whether we succeeded. + return nil + // Continue looping. + case <-time.After(200 * time.Millisecond): } } } // listenTCP listens for connections on a TCP port. -func listenTCP(port int, timeout time.Duration) error { +func listenTCP(ctx context.Context, port int) error { localAddr := net.TCPAddr{ Port: port, } @@ -153,17 +139,23 @@ func listenTCP(port int, timeout time.Duration) error { defer lConn.Close() // Accept connections on port. - lConn.SetDeadline(time.Now().Add(timeout)) - conn, err := lConn.AcceptTCP() - if err != nil { + ch := make(chan error) + go func() { + conn, err := lConn.AcceptTCP() + ch <- err + conn.Close() + }() + + select { + case err := <-ch: return err + case <-ctx.Done(): + return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err()) } - conn.Close() - return nil } // connectTCP connects to the given IP and port from an ephemeral local address. -func connectTCP(ip net.IP, port int, timeout time.Duration) error { +func connectTCP(ctx context.Context, ip net.IP, port int) error { contAddr := net.TCPAddr{ IP: ip, Port: port, @@ -171,13 +163,14 @@ func connectTCP(ip net.IP, port int, timeout time.Duration) error { // The container may not be listening when we first connect, so retry // upon error. callback := func() error { - conn, err := net.DialTimeout("tcp", contAddr.String(), timeout) + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", contAddr.String()) if conn != nil { conn.Close() } return err } - if err := testutil.Poll(callback, timeout); err != nil { + if err := testutil.PollContext(ctx, callback); err != nil { return fmt.Errorf("timed out waiting to connect IP on port %v, most recent error: %v", port, err) } diff --git a/test/iptables/nat.go b/test/iptables/nat.go index b7fea2527..dd9a18339 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -15,11 +15,11 @@ package iptables import ( + "context" "errors" "fmt" "net" "syscall" - "time" ) const redirectPort = 42 @@ -46,7 +46,7 @@ func init() { } // NATPreRedirectUDPPort tests that packets are redirected to different port. -type NATPreRedirectUDPPort struct{} +type NATPreRedirectUDPPort struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectUDPPort) Name() string { @@ -54,12 +54,12 @@ func (NATPreRedirectUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } - if err := listenUDP(redirectPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, redirectPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err) } @@ -67,12 +67,12 @@ func (NATPreRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectUDPPort) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATPreRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATPreRedirectTCPPort tests that connections are redirected on specified ports. -type NATPreRedirectTCPPort struct{} +type NATPreRedirectTCPPort struct{ baseCase } // Name implements TestCase.Name. func (NATPreRedirectTCPPort) Name() string { @@ -80,23 +80,23 @@ func (NATPreRedirectTCPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } // Listen for TCP packets on redirect port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectTCPPort) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, dropPort, sendloopDuration) +func (NATPreRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, dropPort) } // NATPreRedirectTCPOutgoing verifies that outgoing TCP connections aren't // affected by PREROUTING connection tracking. -type NATPreRedirectTCPOutgoing struct{} +type NATPreRedirectTCPOutgoing struct{ baseCase } // Name implements TestCase.Name. func (NATPreRedirectTCPOutgoing) Name() string { @@ -104,24 +104,24 @@ func (NATPreRedirectTCPOutgoing) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectTCPOutgoing) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectTCPOutgoing) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect all incoming TCP traffic to a closed port. if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } // Establish a connection to the host process. - return connectTCP(ip, acceptPort, sendloopDuration) + return connectTCP(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectTCPOutgoing) LocalAction(ip net.IP, ipv6 bool) error { - return listenTCP(acceptPort, sendloopDuration) +func (NATPreRedirectTCPOutgoing) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenTCP(ctx, acceptPort) } // NATOutRedirectTCPIncoming verifies that incoming TCP connections aren't // affected by OUTPUT connection tracking. -type NATOutRedirectTCPIncoming struct{} +type NATOutRedirectTCPIncoming struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectTCPIncoming) Name() string { @@ -129,23 +129,23 @@ func (NATOutRedirectTCPIncoming) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectTCPIncoming) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectTCPIncoming) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect all outgoing TCP traffic to a closed port. if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } // Establish a connection to the host process. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectTCPIncoming) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (NATOutRedirectTCPIncoming) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // NATOutRedirectUDPPort tests that packets are redirected to different port. -type NATOutRedirectUDPPort struct{} +type NATOutRedirectUDPPort struct{ containerCase } // Name implements TestCase.Name. func (NATOutRedirectUDPPort) Name() string { @@ -153,19 +153,19 @@ func (NATOutRedirectUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { - return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) +func (NATOutRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectUDPPort) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATDropUDP tests that packets are not received in ports other than redirect // port. -type NATDropUDP struct{} +type NATDropUDP struct{ containerCase } // Name implements TestCase.Name. func (NATDropUDP) Name() string { @@ -173,25 +173,29 @@ func (NATDropUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATDropUDP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } - if err := listenUDP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { return fmt.Errorf("packets on port %d should have been redirected to port %d", acceptPort, redirectPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (NATDropUDP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATAcceptAll tests that all UDP packets are accepted. -type NATAcceptAll struct{} +type NATAcceptAll struct{ containerCase } // Name implements TestCase.Name. func (NATAcceptAll) Name() string { @@ -199,12 +203,12 @@ func (NATAcceptAll) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATAcceptAll) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATAcceptAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { return err } - if err := listenUDP(acceptPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, acceptPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) } @@ -212,13 +216,13 @@ func (NATAcceptAll) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATAcceptAll) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATAcceptAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATOutRedirectIP uses iptables to select packets based on destination IP and // redirects them. -type NATOutRedirectIP struct{} +type NATOutRedirectIP struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectIP) Name() string { @@ -226,9 +230,9 @@ func (NATOutRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect OUTPUT packets to a listening localhost port. - return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-d", nowhereIP(ipv6), "-p", "udp", @@ -236,14 +240,14 @@ func (NATOutRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATOutDontRedirectIP tests that iptables matching with "-d" does not match // packets it shouldn't. -type NATOutDontRedirectIP struct{} +type NATOutDontRedirectIP struct{ localCase } // Name implements TestCase.Name. func (NATOutDontRedirectIP) Name() string { @@ -251,20 +255,20 @@ func (NATOutDontRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutDontRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "OUTPUT", "-d", localIP(ipv6), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATOutDontRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (NATOutDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // NATOutRedirectInvert tests that iptables can match with "! -d". -type NATOutRedirectInvert struct{} +type NATOutRedirectInvert struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectInvert) Name() string { @@ -272,13 +276,13 @@ func (NATOutRedirectInvert) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect OUTPUT packets to a listening localhost port. dest := "192.0.2.2" if ipv6 { dest = "2001:db8::2" } - return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "!", "-d", dest, "-p", "udp", @@ -286,14 +290,14 @@ func (NATOutRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectInvert) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATPreRedirectIP tests that we can use iptables to select packets based on // destination IP and redirect them. -type NATPreRedirectIP struct{} +type NATPreRedirectIP struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectIP) Name() string { @@ -301,7 +305,7 @@ func (NATPreRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { addrs, err := localAddrs(ipv6) if err != nil { return err @@ -314,17 +318,17 @@ func (NATPreRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { if err := natTableRules(ipv6, rules); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (NATPreRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // NATPreDontRedirectIP tests that iptables matching with "-d" does not match // packets it shouldn't. -type NATPreDontRedirectIP struct{} +type NATPreDontRedirectIP struct{ containerCase } // Name implements TestCase.Name. func (NATPreDontRedirectIP) Name() string { @@ -332,20 +336,20 @@ func (NATPreDontRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreDontRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreDontRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATPreDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATPreRedirectInvert tests that iptables can match with "! -d". -type NATPreRedirectInvert struct{} +type NATPreRedirectInvert struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectInvert) Name() string { @@ -353,21 +357,21 @@ func (NATPreRedirectInvert) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "!", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectInvert) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (NATPreRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a // protocol to be specified with -p. -type NATRedirectRequiresProtocol struct{} +type NATRedirectRequiresProtocol struct{ baseCase } // Name implements TestCase.Name. func (NATRedirectRequiresProtocol) Name() string { @@ -375,7 +379,7 @@ func (NATRedirectRequiresProtocol) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATRedirectRequiresProtocol) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil { return errors.New("expected an error using REDIRECT --to-ports without a protocol") } @@ -383,13 +387,13 @@ func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATRedirectRequiresProtocol) LocalAction(ip net.IP, ipv6 bool) error { +func (NATRedirectRequiresProtocol) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATOutRedirectTCPPort tests that connections are redirected on specified ports. -type NATOutRedirectTCPPort struct{} +type NATOutRedirectTCPPort struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectTCPPort) Name() string { @@ -397,12 +401,11 @@ func (NATOutRedirectTCPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } - timeout := 20 * time.Second localAddr := net.TCPAddr{ IP: net.ParseIP(localIP(ipv6)), Port: acceptPort, @@ -416,9 +419,7 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { defer lConn.Close() // Accept connections on port. - lConn.SetDeadline(time.Now().Add(timeout)) - err = connectTCP(ip, dropPort, timeout) - if err != nil { + if err := connectTCP(ctx, ip, dropPort); err != nil { return err } @@ -432,13 +433,13 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectTCPPort) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { return nil } // NATLoopbackSkipsPrerouting tests that packets sent via loopback aren't // affected by PREROUTING rules. -type NATLoopbackSkipsPrerouting struct{} +type NATLoopbackSkipsPrerouting struct{ baseCase } // Name implements TestCase.Name. func (NATLoopbackSkipsPrerouting) Name() string { @@ -446,7 +447,7 @@ func (NATLoopbackSkipsPrerouting) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect anything sent to localhost to an unused port. dest := []byte{127, 0, 0, 1} if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { @@ -457,24 +458,24 @@ func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP, ipv6 bool) error { // loopback traffic, the connection would fail. sendCh := make(chan error) go func() { - sendCh <- connectTCP(dest, acceptPort, sendloopDuration) + sendCh <- connectTCP(ctx, dest, acceptPort) }() - if err := listenTCP(acceptPort, sendloopDuration); err != nil { + if err := listenTCP(ctx, acceptPort); err != nil { return err } return <-sendCh } // LocalAction implements TestCase.LocalAction. -func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP, ipv6 bool) error { +func (NATLoopbackSkipsPrerouting) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATPreOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination // of PREROUTING NATted packets. -type NATPreOriginalDst struct{} +type NATPreOriginalDst struct{ baseCase } // Name implements TestCase.Name. func (NATPreOriginalDst) Name() string { @@ -482,7 +483,7 @@ func (NATPreOriginalDst) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect incoming TCP connections to acceptPort. if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", @@ -495,17 +496,17 @@ func (NATPreOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { if err != nil { return err } - return listenForRedirectedConn(ipv6, addrs) + return listenForRedirectedConn(ctx, ipv6, addrs) } // LocalAction implements TestCase.LocalAction. -func (NATPreOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, dropPort, sendloopDuration) +func (NATPreOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, dropPort) } // NATOutOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination // of OUTBOUND NATted packets. -type NATOutOriginalDst struct{} +type NATOutOriginalDst struct{ baseCase } // Name implements TestCase.Name. func (NATOutOriginalDst) Name() string { @@ -513,7 +514,7 @@ func (NATOutOriginalDst) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect incoming TCP connections to acceptPort. if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { return err @@ -521,22 +522,22 @@ func (NATOutOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { connCh := make(chan error) go func() { - connCh <- connectTCP(ip, dropPort, sendloopDuration) + connCh <- connectTCP(ctx, ip, dropPort) }() - if err := listenForRedirectedConn(ipv6, []net.IP{ip}); err != nil { + if err := listenForRedirectedConn(ctx, ipv6, []net.IP{ip}); err != nil { return err } return <-connCh } // LocalAction implements TestCase.LocalAction. -func (NATOutOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } -func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { +func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.IP) error { // The net package doesn't give guarantee access to the connection's // underlying FD, and thus we cannot call getsockopt. We have to use // traditional syscalls for SO_ORIGINAL_DST. @@ -572,16 +573,32 @@ func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { return err } - connfd, _, err := syscall.Accept(sockfd) - if err != nil { + // Block on accept() in another goroutine. + connCh := make(chan int) + errCh := make(chan error) + go func() { + connFD, _, err := syscall.Accept(sockfd) + if err != nil { + errCh <- err + } + connCh <- connFD + }() + + // Wait for accept() to return or for the context to finish. + var connFD int + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: return err + case connFD = <-connCh: } - defer syscall.Close(connfd) + defer syscall.Close(connFD) // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST // indicates the packet was sent to originalDst:dropPort. if ipv6 { - got, err := originalDestination6(connfd) + got, err := originalDestination6(connFD) if err != nil { return err } @@ -598,7 +615,7 @@ func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { } return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) } else { - got, err := originalDestination4(connfd) + got, err := originalDestination4(connFD) if err != nil { return err } @@ -619,26 +636,22 @@ func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { // loopbackTests runs an iptables rule and ensures that packets sent to // dest:dropPort are received by localhost:acceptPort. -func loopbackTest(ipv6 bool, dest net.IP, args ...string) error { +func loopbackTest(ctx context.Context, ipv6 bool, dest net.IP, args ...string) error { if err := natTable(ipv6, args...); err != nil { return err } - sendCh := make(chan error) - listenCh := make(chan error) + sendCh := make(chan error, 1) + listenCh := make(chan error, 1) go func() { - sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration) + sendCh <- sendUDPLoop(ctx, dest, dropPort) }() go func() { - listenCh <- listenUDP(acceptPort, sendloopDuration) + listenCh <- listenUDP(ctx, acceptPort) }() select { case err := <-listenCh: - if err != nil { - return err - } - case <-time.After(sendloopDuration): - return errors.New("timed out") + return err + case err := <-sendCh: + return err } - // sendCh will always take the full sendloop time. - return <-sendCh } diff --git a/test/iptables/runner/main.go b/test/iptables/runner/main.go index 69d3ef121..9ae2d1b4d 100644 --- a/test/iptables/runner/main.go +++ b/test/iptables/runner/main.go @@ -16,6 +16,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -46,7 +47,9 @@ func main() { } // Run the test. - if err := test.ContainerAction(ip, *ipv6); err != nil { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if err := test.ContainerAction(ctx, ip, *ipv6); err != nil { log.Fatalf("Failed running test %q: %v", *name, err) } diff --git a/test/packetdrill/BUILD b/test/packetdrill/BUILD index dfcd55f60..49642f282 100644 --- a/test/packetdrill/BUILD +++ b/test/packetdrill/BUILD @@ -1,4 +1,5 @@ -load("defs.bzl", "packetdrill_test") +load("//tools:defs.bzl", "bzl_library") +load("//test/packetdrill:defs.bzl", "packetdrill_test") package(licenses = ["notice"]) @@ -36,3 +37,9 @@ packetdrill_test( name = "tcp_defer_accept_timeout_test", scripts = ["tcp_defer_accept_timeout.pkt"], ) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/test/packetimpact/README.md b/test/packetimpact/README.md index f46c67a0c..ffa96ba98 100644 --- a/test/packetimpact/README.md +++ b/test/packetimpact/README.md @@ -30,7 +30,7 @@ $ make load-packetimpact Run a test, e.g. `fin_wait2_timeout`, against Linux: ```bash -$ bazel test //test/packetimpact/tests:fin_wait2_timeout_linux_test +$ bazel test //test/packetimpact/tests:fin_wait2_timeout_native_test ``` Run the same test, but against gVisor: diff --git a/test/packetimpact/dut/BUILD b/test/packetimpact/dut/BUILD index 3ce63c2c6..ccf1c735f 100644 --- a/test/packetimpact/dut/BUILD +++ b/test/packetimpact/dut/BUILD @@ -16,3 +16,13 @@ cc_binary( "//test/packetimpact/proto:posix_server_cc_proto", ], ) + +cc_binary( + name = "posix_server_dynamic", + srcs = ["posix_server.cc"], + deps = [ + grpcpp, + "//test/packetimpact/proto:posix_server_cc_grpc_proto", + "//test/packetimpact/proto:posix_server_cc_proto", + ], +) diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc index 29d4cc6fe..4de8540f6 100644 --- a/test/packetimpact/dut/posix_server.cc +++ b/test/packetimpact/dut/posix_server.cc @@ -21,6 +21,7 @@ #include <string.h> #include <sys/socket.h> #include <sys/types.h> +#include <time.h> #include <unistd.h> #include <iostream> @@ -28,6 +29,7 @@ #include "include/grpcpp/security/server_credentials.h" #include "include/grpcpp/server_builder.h" +#include "include/grpcpp/server_context.h" #include "test/packetimpact/proto/posix_server.grpc.pb.h" #include "test/packetimpact/proto/posix_server.pb.h" @@ -108,18 +110,20 @@ } class PosixImpl final : public posix_server::Posix::Service { - ::grpc::Status Accept(grpc_impl::ServerContext *context, + ::grpc::Status Accept(grpc::ServerContext *context, const ::posix_server::AcceptRequest *request, ::posix_server::AcceptResponse *response) override { sockaddr_storage addr; socklen_t addrlen = sizeof(addr); response->set_fd(accept(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen)); - response->set_errno_(errno); + if (response->fd() < 0) { + response->set_errno_(errno); + } return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); } - ::grpc::Status Bind(grpc_impl::ServerContext *context, + ::grpc::Status Bind(grpc::ServerContext *context, const ::posix_server::BindRequest *request, ::posix_server::BindResponse *response) override { if (!request->has_addr()) { @@ -136,19 +140,23 @@ class PosixImpl final : public posix_server::Posix::Service { response->set_ret( bind(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), addr_len)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Close(grpc_impl::ServerContext *context, + ::grpc::Status Close(grpc::ServerContext *context, const ::posix_server::CloseRequest *request, ::posix_server::CloseResponse *response) override { response->set_ret(close(request->fd())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Connect(grpc_impl::ServerContext *context, + ::grpc::Status Connect(grpc::ServerContext *context, const ::posix_server::ConnectRequest *request, ::posix_server::ConnectResponse *response) override { if (!request->has_addr()) { @@ -164,32 +172,38 @@ class PosixImpl final : public posix_server::Posix::Service { response->set_ret(connect(request->sockfd(), reinterpret_cast<sockaddr *>(&addr), addr_len)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Fcntl(grpc_impl::ServerContext *context, + ::grpc::Status Fcntl(grpc::ServerContext *context, const ::posix_server::FcntlRequest *request, ::posix_server::FcntlResponse *response) override { response->set_ret(::fcntl(request->fd(), request->cmd(), request->arg())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } ::grpc::Status GetSockName( - grpc_impl::ServerContext *context, + grpc::ServerContext *context, const ::posix_server::GetSockNameRequest *request, ::posix_server::GetSockNameResponse *response) override { sockaddr_storage addr; socklen_t addrlen = sizeof(addr); response->set_ret(getsockname( request->sockfd(), reinterpret_cast<sockaddr *>(&addr), &addrlen)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return sockaddr_to_proto(addr, addrlen, response->mutable_addr()); } ::grpc::Status GetSockOpt( - grpc_impl::ServerContext *context, + grpc::ServerContext *context, const ::posix_server::GetSockOptRequest *request, ::posix_server::GetSockOptResponse *response) override { switch (request->type()) { @@ -226,15 +240,19 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown SockOpt Type"); } - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Listen(grpc_impl::ServerContext *context, + ::grpc::Status Listen(grpc::ServerContext *context, const ::posix_server::ListenRequest *request, ::posix_server::ListenResponse *response) override { response->set_ret(listen(request->sockfd(), request->backlog())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } @@ -243,7 +261,9 @@ class PosixImpl final : public posix_server::Posix::Service { ::posix_server::SendResponse *response) override { response->set_ret(::send(request->sockfd(), request->buf().data(), request->buf().size(), request->flags())); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } @@ -264,12 +284,14 @@ class PosixImpl final : public posix_server::Posix::Service { response->set_ret(::sendto(request->sockfd(), request->buf().data(), request->buf().size(), request->flags(), reinterpret_cast<sockaddr *>(&addr), addr_len)); - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } ::grpc::Status SetSockOpt( - grpc_impl::ServerContext *context, + grpc::ServerContext *context, const ::posix_server::SetSockOptRequest *request, ::posix_server::SetSockOptResponse *response) override { switch (request->optval().val_case()) { @@ -286,9 +308,9 @@ class PosixImpl final : public posix_server::Posix::Service { break; } case ::posix_server::SockOptVal::kTimeval: { - timeval tv = {.tv_sec = static_cast<__time_t>( + timeval tv = {.tv_sec = static_cast<time_t>( request->optval().timeval().seconds()), - .tv_usec = static_cast<__suseconds_t>( + .tv_usec = static_cast<suseconds_t>( request->optval().timeval().microseconds())}; response->set_ret(setsockopt(request->sockfd(), request->level(), request->optname(), &tv, sizeof(tv))); @@ -298,16 +320,29 @@ class PosixImpl final : public posix_server::Posix::Service { return ::grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Unknown SockOpt Type"); } - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } - ::grpc::Status Socket(grpc_impl::ServerContext *context, + ::grpc::Status Socket(grpc::ServerContext *context, const ::posix_server::SocketRequest *request, ::posix_server::SocketResponse *response) override { response->set_fd( socket(request->domain(), request->type(), request->protocol())); - response->set_errno_(errno); + if (response->fd() < 0) { + response->set_errno_(errno); + } + return ::grpc::Status::OK; + } + + ::grpc::Status Shutdown(grpc::ServerContext *context, + const ::posix_server::ShutdownRequest *request, + ::posix_server::ShutdownResponse *response) override { + if (shutdown(request->fd(), request->how()) < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } @@ -320,7 +355,9 @@ class PosixImpl final : public posix_server::Posix::Service { if (response->ret() >= 0) { response->set_buf(buf.data(), response->ret()); } - response->set_errno_(errno); + if (response->ret() < 0) { + response->set_errno_(errno); + } return ::grpc::Status::OK; } }; diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto index ccd20b10d..f32ed54ef 100644 --- a/test/packetimpact/proto/posix_server.proto +++ b/test/packetimpact/proto/posix_server.proto @@ -188,6 +188,15 @@ message SocketResponse { int32 errno_ = 2; // "errno" may fail to compile in c++. } +message ShutdownRequest { + int32 fd = 1; + int32 how = 2; +} + +message ShutdownResponse { + int32 errno_ = 1; // "errno" may fail to compile in c++. +} + message RecvRequest { int32 sockfd = 1; int32 len = 2; @@ -225,6 +234,8 @@ service Posix { rpc SetSockOpt(SetSockOptRequest) returns (SetSockOptResponse); // Call socket() on the DUT. rpc Socket(SocketRequest) returns (SocketResponse); + // Call shutdown() on the DUT. + rpc Shutdown(ShutdownRequest) returns (ShutdownResponse); // Call recv() on the DUT. rpc Recv(RecvRequest) returns (RecvResponse); } diff --git a/test/packetimpact/runner/BUILD b/test/packetimpact/runner/BUILD index bad4f0183..605dd4972 100644 --- a/test/packetimpact/runner/BUILD +++ b/test/packetimpact/runner/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_test") +load("//tools:defs.bzl", "bzl_library", "go_library", "go_test") package( default_visibility = ["//test/packetimpact:__subpackages__"], @@ -7,12 +7,28 @@ package( go_test( name = "packetimpact_test", - srcs = ["packetimpact_test.go"], + srcs = [ + "packetimpact_test.go", + ], tags = [ # Not intended to be run directly. "local", "manual", ], + deps = [":runner"], +) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//test/packetimpact:__subpackages__"], +) + +go_library( + name = "runner", + testonly = True, + srcs = ["dut.go"], + visibility = ["//test/packetimpact:__subpackages__"], deps = [ "//pkg/test/dockerutil", "//test/packetimpact/netdevs", diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 79b3c9162..f56d3c42e 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -23,8 +23,9 @@ def _packetimpact_test_impl(ctx): transitive_files = [] if hasattr(ctx.attr._test_runner, "data_runfiles"): transitive_files.append(ctx.attr._test_runner.data_runfiles.files) + files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server runfiles = ctx.runfiles( - files = [test_runner] + ctx.files.testbench_binary + ctx.files._posix_server_binary, + files = files, transitive_files = depset(transitive = transitive_files), collect_default = True, collect_data = True, @@ -38,7 +39,7 @@ _packetimpact_test = rule( cfg = "target", default = ":packetimpact_test", ), - "_posix_server_binary": attr.label( + "_posix_server": attr.label( cfg = "target", default = "//test/packetimpact/dut:posix_server", ), @@ -61,12 +62,12 @@ PACKETIMPACT_TAGS = [ "packetimpact", ] -def packetimpact_linux_test( +def packetimpact_native_test( name, testbench_binary, expect_failure = False, **kwargs): - """Add a packetimpact test on linux. + """Add a native packetimpact test. Args: name: name of the test @@ -76,9 +77,9 @@ def packetimpact_linux_test( """ expect_failure_flag = ["--expect_failure"] if expect_failure else [] _packetimpact_test( - name = name + "_linux_test", + name = name + "_native_test", testbench_binary = testbench_binary, - flags = ["--dut_platform", "linux"] + expect_failure_flag, + flags = ["--native"] + expect_failure_flag, tags = PACKETIMPACT_TAGS, **kwargs ) @@ -102,21 +103,21 @@ def packetimpact_netstack_test( _packetimpact_test( name = name + "_netstack_test", testbench_binary = testbench_binary, - # This is the default runtime unless - # "--test_arg=--runtime=OTHER_RUNTIME" is used to override the value. - flags = ["--dut_platform", "netstack", "--runtime=runsc-d"] + expect_failure_flag, + # Note that a distinct runtime must be provided in the form + # --test_arg=--runtime=other when invoking bazel. + flags = expect_failure_flag, tags = PACKETIMPACT_TAGS, **kwargs ) -def packetimpact_go_test(name, size = "small", pure = True, expect_linux_failure = False, expect_netstack_failure = False, **kwargs): +def packetimpact_go_test(name, size = "small", pure = True, expect_native_failure = False, expect_netstack_failure = False, **kwargs): """Add packetimpact tests written in go. Args: name: name of the test size: size of the test pure: make a static go binary - expect_linux_failure: the test must fail for Linux + expect_native_failure: the test must fail natively expect_netstack_failure: the test must fail for Netstack **kwargs: all the other args, forwarded to go_test """ @@ -125,15 +126,16 @@ def packetimpact_go_test(name, size = "small", pure = True, expect_linux_failure name = testbench_binary, size = size, pure = pure, + nogo = False, # FIXME(gvisor.dev/issue/3374): Not working with all build systems. tags = [ "local", "manual", ], **kwargs ) - packetimpact_linux_test( + packetimpact_native_test( name = name, - expect_failure = expect_linux_failure, + expect_failure = expect_native_failure, testbench_binary = testbench_binary, ) packetimpact_netstack_test( diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go new file mode 100644 index 000000000..59bb68eb1 --- /dev/null +++ b/test/packetimpact/runner/dut.go @@ -0,0 +1,442 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package runner starts docker containers and networking for a packetimpact test. +package runner + +import ( + "context" + "flag" + "fmt" + "io/ioutil" + "log" + "math/rand" + "net" + "os" + "os/exec" + "path" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/docker/docker/api/types/mount" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/test/packetimpact/netdevs" +) + +// stringList implements flag.Value. +type stringList []string + +// String implements flag.Value.String. +func (l *stringList) String() string { + return strings.Join(*l, ",") +} + +// Set implements flag.Value.Set. +func (l *stringList) Set(value string) error { + *l = append(*l, value) + return nil +} + +var ( + native = false + testbenchBinary = "" + tshark = false + extraTestArgs = stringList{} + expectFailure = false + + // DutAddr is the IP addres for DUT. + DutAddr = net.IPv4(0, 0, 0, 10) + testbenchAddr = net.IPv4(0, 0, 0, 20) +) + +// RegisterFlags defines flags and associates them with the package-level +// exported variables above. It should be called by tests in their init +// functions. +func RegisterFlags(fs *flag.FlagSet) { + fs.BoolVar(&native, "native", false, "whether the test should be run natively") + fs.StringVar(&testbenchBinary, "testbench_binary", "", "path to the testbench binary") + fs.BoolVar(&tshark, "tshark", false, "use more verbose tshark in logs instead of tcpdump") + fs.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench") + fs.BoolVar(&expectFailure, "expect_failure", false, "expect that the test will fail when run") +} + +// CtrlPort is the port that posix_server listens on. +const CtrlPort = "40000" + +// logger implements testutil.Logger. +// +// Labels logs based on their source and formats multi-line logs. +type logger string + +// Name implements testutil.Logger.Name. +func (l logger) Name() string { + return string(l) +} + +// Logf implements testutil.Logger.Logf. +func (l logger) Logf(format string, args ...interface{}) { + lines := strings.Split(fmt.Sprintf(format, args...), "\n") + log.Printf("%s: %s", l, lines[0]) + for _, line := range lines[1:] { + log.Printf("%*s %s", len(l), "", line) + } +} + +// TestWithDUT runs a packetimpact test with the given information. +func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Container) DUT, containerAddr net.IP) { + if testbenchBinary == "" { + t.Fatal("--testbench_binary is missing") + } + dockerutil.EnsureSupportedDockerVersion() + + // Create the networks needed for the test. One control network is needed for + // the gRPC control packets and one test network on which to transmit the test + // packets. + ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet")) + testNet := dockerutil.NewNetwork(ctx, logger("testNet")) + for _, dn := range []*dockerutil.Network{ctrlNet, testNet} { + for { + if err := createDockerNetwork(ctx, dn); err != nil { + t.Log("creating docker network:", err) + const wait = 100 * time.Millisecond + t.Logf("sleeping %s and will try creating docker network again", wait) + // This can fail if another docker network claimed the same IP so we'll + // just try again. + time.Sleep(wait) + continue + } + break + } + dn := dn + t.Cleanup(func() { + if err := dn.Cleanup(ctx); err != nil { + t.Errorf("unable to cleanup container %s: %s", dn.Name, err) + } + }) + // Sanity check. + if inspect, err := dn.Inspect(ctx); err != nil { + t.Fatalf("failed to inspect network %s: %v", dn.Name, err) + } else if inspect.Name != dn.Name { + t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name) + } + } + + tmpDir, err := ioutil.TempDir("", "container-output") + if err != nil { + t.Fatal("creating temp dir:", err) + } + t.Cleanup(func() { + if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil { + t.Errorf("unable to copy container output files: %s", err) + } + if err := os.RemoveAll(tmpDir); err != nil { + t.Errorf("failed to remove tmpDir %s: %s", tmpDir, err) + } + }) + + const testOutputDir = "/tmp/testoutput" + + // Create the Docker container for the DUT. + var dut *dockerutil.Container + if native { + dut = dockerutil.MakeNativeContainer(ctx, logger("dut")) + } else { + dut = dockerutil.MakeContainer(ctx, logger("dut")) + } + t.Cleanup(func() { + dut.CleanUp(ctx) + }) + + runOpts := dockerutil.RunOpts{ + Image: "packetimpact", + CapAdd: []string{"NET_ADMIN"}, + Mounts: []mount.Mount{{ + Type: mount.TypeBind, + Source: tmpDir, + Target: testOutputDir, + ReadOnly: false, + }}, + } + + device := mkDevice(dut) + remoteIPv6, remoteMAC, dutDeviceID, dutTestNetDev := device.Prepare(ctx, t, runOpts, ctrlNet, testNet, containerAddr) + + // Create the Docker container for the testbench. + testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench")) + + tbb := path.Base(testbenchBinary) + containerTestbenchBinary := filepath.Join("/packetimpact", tbb) + testbench.CopyFiles(&runOpts, "/packetimpact", filepath.Join("test/packetimpact/tests", tbb)) + + // snifferNetDev is a network device on the test orchestrator that we will + // run sniffer (tcpdump or tshark) on and inject traffic to, not to be + // confused with the device on the DUT. + const snifferNetDev = "eth2" + // Run tcpdump in the test bench unbuffered, without DNS resolution, just on + // the interface with the test packets. + snifferArgs := []string{ + "tcpdump", + "-S", "-vvv", "-U", "-n", + "-i", snifferNetDev, + "-w", testOutputDir + "/dump.pcap", + } + snifferRegex := "tcpdump: listening.*\n" + if tshark { + // Run tshark in the test bench unbuffered, without DNS resolution, just on + // the interface with the test packets. + snifferArgs = []string{ + "tshark", "-V", "-l", "-n", "-i", snifferNetDev, + "-o", "tcp.check_checksum:TRUE", + "-o", "udp.check_checksum:TRUE", + } + snifferRegex = "Capturing on.*\n" + } + + if err := StartContainer( + ctx, + runOpts, + testbench, + testbenchAddr, + []*dockerutil.Network{ctrlNet, testNet}, + snifferArgs..., + ); err != nil { + t.Fatalf("failed to start docker container for testbench sniffer: %s", err) + } + // Kill so that it will flush output. + t.Cleanup(func() { + time.Sleep(1 * time.Second) + testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0]) + }) + + if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil { + t.Fatalf("sniffer on %s never listened: %s", dut.Name, err) + } + + // When the Linux kernel receives a SYN-ACK for a SYN it didn't send, it + // will respond with an RST. In most packetimpact tests, the SYN is sent + // by the raw socket and the kernel knows nothing about the connection, this + // behavior will break lots of TCP related packetimpact tests. To prevent + // this, we can install the following iptables rules. The raw socket that + // packetimpact tests use will still be able to see everything. + for _, bin := range []string{"iptables", "ip6tables"} { + if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", snifferNetDev, "-p", "tcp", "-j", "DROP"); err != nil { + t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs) + } + } + + // FIXME(b/156449515): Some piece of the system has a race. The old + // bash script version had a sleep, so we have one too. The race should + // be fixed and this sleep removed. + time.Sleep(time.Second) + + // Start a packetimpact test on the test bench. The packetimpact test sends + // and receives packets and also sends POSIX socket commands to the + // posix_server to be executed on the DUT. + testArgs := []string{containerTestbenchBinary} + testArgs = append(testArgs, extraTestArgs...) + testArgs = append(testArgs, + "--posix_server_ip", AddressInSubnet(DutAddr, *ctrlNet.Subnet).String(), + "--posix_server_port", CtrlPort, + "--remote_ipv4", AddressInSubnet(DutAddr, *testNet.Subnet).String(), + "--local_ipv4", AddressInSubnet(testbenchAddr, *testNet.Subnet).String(), + "--remote_ipv6", remoteIPv6.String(), + "--remote_mac", remoteMAC.String(), + "--remote_interface_id", fmt.Sprintf("%d", dutDeviceID), + "--local_device", snifferNetDev, + "--remote_device", dutTestNetDev, + fmt.Sprintf("--native=%t", native), + ) + testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) + if (err != nil) != expectFailure { + var dutLogs string + if logs, err := device.Logs(ctx); err != nil { + dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err) + } else { + dutLogs = logs + } + + t.Errorf(`test error: %v, expect failure: %t + +%s + +====== Begin of Testbench Logs ====== + +%s + +====== End of Testbench Logs ======`, + err, expectFailure, dutLogs, testbenchLogs) + } +} + +// DUT describes how to setup/teardown the dut for packetimpact tests. +type DUT interface { + // Prepare prepares the dut, starts posix_server and returns the IPv6, MAC + // address, the interface ID, and the interface name for the testNet on DUT. + Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network, containerAddr net.IP) (net.IP, net.HardwareAddr, uint32, string) + // Logs retrieves the logs from the dut. + Logs(ctx context.Context) (string, error) +} + +// DockerDUT describes a docker based DUT. +type DockerDUT struct { + c *dockerutil.Container +} + +// NewDockerDUT creates a docker based DUT. +func NewDockerDUT(c *dockerutil.Container) DUT { + return &DockerDUT{ + c: c, + } +} + +// Prepare implements DUT.Prepare. +func (dut *DockerDUT) Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network, containerAddr net.IP) (net.IP, net.HardwareAddr, uint32, string) { + const containerPosixServerBinary = "/packetimpact/posix_server" + dut.c.CopyFiles(&runOpts, "/packetimpact", "test/packetimpact/dut/posix_server") + + if err := StartContainer( + ctx, + runOpts, + dut.c, + containerAddr, + []*dockerutil.Network{ctrlNet, testNet}, + containerPosixServerBinary, + "--ip=0.0.0.0", + "--port="+CtrlPort, + ); err != nil { + t.Fatalf("failed to start docker container for DUT: %s", err) + } + + if _, err := dut.c.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil { + t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.c.Name, err) + } + + dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut.c, AddressInSubnet(containerAddr, *testNet.Subnet)) + if err != nil { + t.Fatal(err) + } + + remoteMAC := dutDeviceInfo.MAC + remoteIPv6 := dutDeviceInfo.IPv6Addr + // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if + // needed. + if remoteIPv6 == nil { + if _, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil { + t.Fatalf("unable to ip addr add on container %s: %s", dut.c.Name, err) + } + // Now try again, to make sure that it worked. + _, dutDeviceInfo, err = deviceByIP(ctx, dut.c, AddressInSubnet(containerAddr, *testNet.Subnet)) + if err != nil { + t.Fatal(err) + } + remoteIPv6 = dutDeviceInfo.IPv6Addr + if remoteIPv6 == nil { + t.Fatalf("unable to set IPv6 address on container %s", dut.c.Name) + } + } + const testNetDev = "eth2" + + return remoteIPv6, dutDeviceInfo.MAC, dutDeviceInfo.ID, testNetDev +} + +// Logs implements DUT.Logs. +func (dut *DockerDUT) Logs(ctx context.Context) (string, error) { + logs, err := dut.c.Logs(ctx) + if err != nil { + return "", err + } + return fmt.Sprintf(`====== Begin of DUT Logs ====== + +%s + +====== End of DUT Logs ======`, logs), nil +} + +// AddNetworks connects docker network with the container and assigns the specific IP. +func AddNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error { + for _, dn := range networks { + ip := AddressInSubnet(addr, *dn.Subnet) + // Connect to the network with the specified IP address. + if err := dn.Connect(ctx, d, ip.String(), ""); err != nil { + return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err) + } + } + return nil +} + +// AddressInSubnet combines the subnet provided with the address and returns a +// new address. The return address bits come from the subnet where the mask is 1 +// and from the ip address where the mask is 0. +func AddressInSubnet(addr net.IP, subnet net.IPNet) net.IP { + var octets []byte + for i := 0; i < 4; i++ { + octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i]))) + } + return net.IP(octets) +} + +// deviceByIP finds a deviceInfo and device name from an IP address. +func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) { + out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show") + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w\n%s", d.Name, err, out) + } + devs, err := netdevs.ParseDevices(out) + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w\n%s", d.Name, err, out) + } + testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs) + if err != nil { + return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err) + } + return testDevice, deviceInfo, nil +} + +// createDockerNetwork makes a randomly-named network that will start with the +// namePrefix. The network will be a random /24 subnet. +func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error { + randSource := rand.NewSource(time.Now().UnixNano()) + r1 := rand.New(randSource) + // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. + ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0) + n.Subnet = &net.IPNet{ + IP: ip, + Mask: ip.DefaultMask(), + } + return n.Create(ctx) +} + +// StartContainer will create a container instance from runOpts, connect it +// with the specified docker networks and start executing the specified cmd. +func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerutil.Container, containerAddr net.IP, ns []*dockerutil.Network, cmd ...string) error { + conf, hostconf, netconf := c.ConfigsFrom(runOpts, cmd...) + _ = netconf + hostconf.AutoRemove = true + hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} + + if err := c.CreateFrom(ctx, conf, hostconf, nil); err != nil { + return fmt.Errorf("unable to create container %s: %w", c.Name, err) + } + + if err := AddNetworks(ctx, c, containerAddr, ns); err != nil { + return fmt.Errorf("unable to connect the container with the networks: %w", err) + } + + if err := c.Start(ctx); err != nil { + return fmt.Errorf("unable to start container %s: %w", c.Name, err) + } + return nil +} diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go index 74e1e6def..c598bfc29 100644 --- a/test/packetimpact/runner/packetimpact_test.go +++ b/test/packetimpact/runner/packetimpact_test.go @@ -18,372 +18,15 @@ package packetimpact_test import ( "context" "flag" - "fmt" - "io/ioutil" - "log" - "math/rand" - "net" - "os" - "os/exec" - "path" - "strings" "testing" - "time" - "github.com/docker/docker/api/types/mount" - "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/test/packetimpact/netdevs" + "gvisor.dev/gvisor/test/packetimpact/runner" ) -// stringList implements flag.Value. -type stringList []string - -// String implements flag.Value.String. -func (l *stringList) String() string { - return strings.Join(*l, ",") -} - -// Set implements flag.Value.Set. -func (l *stringList) Set(value string) error { - *l = append(*l, value) - return nil -} - -var ( - dutPlatform = flag.String("dut_platform", "", "either \"linux\" or \"netstack\"") - testbenchBinary = flag.String("testbench_binary", "", "path to the testbench binary") - tshark = flag.Bool("tshark", false, "use more verbose tshark in logs instead of tcpdump") - extraTestArgs = stringList{} - expectFailure = flag.Bool("expect_failure", false, "expect that the test will fail when run") - - dutAddr = net.IPv4(0, 0, 0, 10) - testbenchAddr = net.IPv4(0, 0, 0, 20) -) - -const ctrlPort = "40000" - -// logger implements testutil.Logger. -// -// Labels logs based on their source and formats multi-line logs. -type logger string - -// Name implements testutil.Logger.Name. -func (l logger) Name() string { - return string(l) -} - -// Logf implements testutil.Logger.Logf. -func (l logger) Logf(format string, args ...interface{}) { - lines := strings.Split(fmt.Sprintf(format, args...), "\n") - log.Printf("%s: %s", l, lines[0]) - for _, line := range lines[1:] { - log.Printf("%*s %s", len(l), "", line) - } +func init() { + runner.RegisterFlags(flag.CommandLine) } func TestOne(t *testing.T) { - flag.Var(&extraTestArgs, "extra_test_arg", "extra arguments to pass to the testbench") - flag.Parse() - if *dutPlatform != "linux" && *dutPlatform != "netstack" { - t.Fatal("--dut_platform should be either linux or netstack") - } - if *testbenchBinary == "" { - t.Fatal("--testbench_binary is missing") - } - if *dutPlatform == "netstack" { - if _, err := dockerutil.RuntimePath(); err != nil { - t.Fatal("--runtime is missing or invalid with --dut_platform=netstack:", err) - } - } - dockerutil.EnsureSupportedDockerVersion() - ctx := context.Background() - - // Create the networks needed for the test. One control network is needed for - // the gRPC control packets and one test network on which to transmit the test - // packets. - ctrlNet := dockerutil.NewNetwork(ctx, logger("ctrlNet")) - testNet := dockerutil.NewNetwork(ctx, logger("testNet")) - for _, dn := range []*dockerutil.Network{ctrlNet, testNet} { - for { - if err := createDockerNetwork(ctx, dn); err != nil { - t.Log("creating docker network:", err) - const wait = 100 * time.Millisecond - t.Logf("sleeping %s and will try creating docker network again", wait) - // This can fail if another docker network claimed the same IP so we'll - // just try again. - time.Sleep(wait) - continue - } - break - } - defer func(dn *dockerutil.Network) { - if err := dn.Cleanup(ctx); err != nil { - t.Errorf("unable to cleanup container %s: %s", dn.Name, err) - } - }(dn) - // Sanity check. - inspect, err := dn.Inspect(ctx) - if err != nil { - t.Fatalf("failed to inspect network %s: %v", dn.Name, err) - } else if inspect.Name != dn.Name { - t.Fatalf("name mismatch for network want: %s got: %s", dn.Name, inspect.Name) - } - - } - - tmpDir, err := ioutil.TempDir("", "container-output") - if err != nil { - t.Fatal("creating temp dir:", err) - } - defer os.RemoveAll(tmpDir) - - const testOutputDir = "/tmp/testoutput" - - // Create the Docker container for the DUT. - dut := dockerutil.MakeContainer(ctx, logger("dut")) - if *dutPlatform == "linux" { - dut = dockerutil.MakeNativeContainer(ctx, logger("dut")) - } - - runOpts := dockerutil.RunOpts{ - Image: "packetimpact", - CapAdd: []string{"NET_ADMIN"}, - Mounts: []mount.Mount{mount.Mount{ - Type: mount.TypeBind, - Source: tmpDir, - Target: testOutputDir, - ReadOnly: false, - }}, - } - - const containerPosixServerBinary = "/packetimpact/posix_server" - dut.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/dut/posix_server") - - conf, hostconf, _ := dut.ConfigsFrom(runOpts, containerPosixServerBinary, "--ip=0.0.0.0", "--port="+ctrlPort) - hostconf.AutoRemove = true - hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} - - if err := dut.CreateFrom(ctx, conf, hostconf, nil); err != nil { - t.Fatalf("unable to create container %s: %v", dut.Name, err) - } - - defer dut.CleanUp(ctx) - - // Add ctrlNet as eth1 and testNet as eth2. - const testNetDev = "eth2" - if err := addNetworks(ctx, dut, dutAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil { - t.Fatal(err) - } - - if err := dut.Start(ctx); err != nil { - t.Fatalf("unable to start container %s: %s", dut.Name, err) - } - - if _, err := dut.WaitForOutput(ctx, "Server listening.*\n", 60*time.Second); err != nil { - t.Fatalf("%s on container %s never listened: %s", containerPosixServerBinary, dut.Name, err) - } - - dutTestDevice, dutDeviceInfo, err := deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet)) - if err != nil { - t.Fatal(err) - } - - remoteMAC := dutDeviceInfo.MAC - remoteIPv6 := dutDeviceInfo.IPv6Addr - // Netstack as DUT doesn't assign IPv6 addresses automatically so do it if - // needed. - if remoteIPv6 == nil { - if _, err := dut.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "add", netdevs.MACToIP(remoteMAC).String(), "scope", "link", "dev", dutTestDevice); err != nil { - t.Fatalf("unable to ip addr add on container %s: %s", dut.Name, err) - } - // Now try again, to make sure that it worked. - _, dutDeviceInfo, err = deviceByIP(ctx, dut, addressInSubnet(dutAddr, *testNet.Subnet)) - if err != nil { - t.Fatal(err) - } - remoteIPv6 = dutDeviceInfo.IPv6Addr - if remoteIPv6 == nil { - t.Fatal("unable to set IPv6 address on container", dut.Name) - } - } - - // Create the Docker container for the testbench. - testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench")) - - tbb := path.Base(*testbenchBinary) - containerTestbenchBinary := "/packetimpact/" + tbb - runOpts = dockerutil.RunOpts{ - Image: "packetimpact", - CapAdd: []string{"NET_ADMIN"}, - Mounts: []mount.Mount{mount.Mount{ - Type: mount.TypeBind, - Source: tmpDir, - Target: testOutputDir, - ReadOnly: false, - }}, - } - testbench.CopyFiles(&runOpts, "/packetimpact", "/test/packetimpact/tests/"+tbb) - - // Run tcpdump in the test bench unbuffered, without DNS resolution, just on - // the interface with the test packets. - snifferArgs := []string{ - "tcpdump", - "-S", "-vvv", "-U", "-n", - "-i", testNetDev, - "-w", testOutputDir + "/dump.pcap", - } - snifferRegex := "tcpdump: listening.*\n" - if *tshark { - // Run tshark in the test bench unbuffered, without DNS resolution, just on - // the interface with the test packets. - snifferArgs = []string{ - "tshark", "-V", "-l", "-n", "-i", testNetDev, - "-o", "tcp.check_checksum:TRUE", - "-o", "udp.check_checksum:TRUE", - } - snifferRegex = "Capturing on.*\n" - } - - defer func() { - if err := exec.Command("/bin/cp", "-r", tmpDir, os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")).Run(); err != nil { - t.Error("unable to copy container output files:", err) - } - }() - - conf, hostconf, _ = testbench.ConfigsFrom(runOpts, snifferArgs...) - hostconf.AutoRemove = true - hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} - - if err := testbench.CreateFrom(ctx, conf, hostconf, nil); err != nil { - t.Fatalf("unable to create container %s: %s", testbench.Name, err) - } - defer testbench.CleanUp(ctx) - - // Add ctrlNet as eth1 and testNet as eth2. - if err := addNetworks(ctx, testbench, testbenchAddr, []*dockerutil.Network{ctrlNet, testNet}); err != nil { - t.Fatal(err) - } - - if err := testbench.Start(ctx); err != nil { - t.Fatalf("unable to start container %s: %s", testbench.Name, err) - } - - // Kill so that it will flush output. - defer func() { - time.Sleep(1 * time.Second) - testbench.Exec(ctx, dockerutil.ExecOpts{}, "killall", snifferArgs[0]) - }() - - if _, err := testbench.WaitForOutput(ctx, snifferRegex, 60*time.Second); err != nil { - t.Fatalf("sniffer on %s never listened: %s", dut.Name, err) - } - - // Because the Linux kernel receives the SYN-ACK but didn't send the SYN it - // will issue an RST. To prevent this IPtables can be used to filter out all - // incoming packets. The raw socket that packetimpact tests use will still see - // everything. - for _, bin := range []string{"iptables", "ip6tables"} { - if logs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, bin, "-A", "INPUT", "-i", testNetDev, "-p", "tcp", "-j", "DROP"); err != nil { - t.Fatalf("unable to Exec %s on container %s: %s, logs from testbench:\n%s", bin, testbench.Name, err, logs) - } - } - - // FIXME(b/156449515): Some piece of the system has a race. The old - // bash script version had a sleep, so we have one too. The race should - // be fixed and this sleep removed. - time.Sleep(time.Second) - - // Start a packetimpact test on the test bench. The packetimpact test sends - // and receives packets and also sends POSIX socket commands to the - // posix_server to be executed on the DUT. - testArgs := []string{containerTestbenchBinary} - testArgs = append(testArgs, extraTestArgs...) - testArgs = append(testArgs, - "--posix_server_ip", addressInSubnet(dutAddr, *ctrlNet.Subnet).String(), - "--posix_server_port", ctrlPort, - "--remote_ipv4", addressInSubnet(dutAddr, *testNet.Subnet).String(), - "--local_ipv4", addressInSubnet(testbenchAddr, *testNet.Subnet).String(), - "--remote_ipv6", remoteIPv6.String(), - "--remote_mac", remoteMAC.String(), - "--remote_interface_id", fmt.Sprintf("%d", dutDeviceInfo.ID), - "--device", testNetDev, - "--dut_type", *dutPlatform, - ) - testbenchLogs, err := testbench.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) - if (err != nil) != *expectFailure { - var dutLogs string - if logs, err := dut.Logs(ctx); err != nil { - dutLogs = fmt.Sprintf("failed to fetch DUT logs: %s", err) - } else { - dutLogs = logs - } - - t.Errorf(`test error: %v, expect failure: %t - -====== Begin of DUT Logs ====== - -%s - -====== End of DUT Logs ====== - -====== Begin of Testbench Logs ====== - -%s - -====== End of Testbench Logs ======`, - err, *expectFailure, dutLogs, testbenchLogs) - } -} - -func addNetworks(ctx context.Context, d *dockerutil.Container, addr net.IP, networks []*dockerutil.Network) error { - for _, dn := range networks { - ip := addressInSubnet(addr, *dn.Subnet) - // Connect to the network with the specified IP address. - if err := dn.Connect(ctx, d, ip.String(), ""); err != nil { - return fmt.Errorf("unable to connect container %s to network %s: %w", d.Name, dn.Name, err) - } - } - return nil -} - -// addressInSubnet combines the subnet provided with the address and returns a -// new address. The return address bits come from the subnet where the mask is 1 -// and from the ip address where the mask is 0. -func addressInSubnet(addr net.IP, subnet net.IPNet) net.IP { - var octets []byte - for i := 0; i < 4; i++ { - octets = append(octets, (subnet.IP.To4()[i]&subnet.Mask[i])+(addr.To4()[i]&(^subnet.Mask[i]))) - } - return net.IP(octets) -} - -// createDockerNetwork makes a randomly-named network that will start with the -// namePrefix. The network will be a random /24 subnet. -func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error { - randSource := rand.NewSource(time.Now().UnixNano()) - r1 := rand.New(randSource) - // Class C, 192.0.0.0 to 223.255.255.255, transitionally has mask 24. - ip := net.IPv4(byte(r1.Intn(224-192)+192), byte(r1.Intn(256)), byte(r1.Intn(256)), 0) - n.Subnet = &net.IPNet{ - IP: ip, - Mask: ip.DefaultMask(), - } - return n.Create(ctx) -} - -// deviceByIP finds a deviceInfo and device name from an IP address. -func deviceByIP(ctx context.Context, d *dockerutil.Container, ip net.IP) (string, netdevs.DeviceInfo, error) { - out, err := d.Exec(ctx, dockerutil.ExecOpts{}, "ip", "addr", "show") - if err != nil { - return "", netdevs.DeviceInfo{}, fmt.Errorf("listing devices on %s container: %w", d.Name, err) - } - devs, err := netdevs.ParseDevices(out) - if err != nil { - return "", netdevs.DeviceInfo{}, fmt.Errorf("parsing devices from %s container: %w", d.Name, err) - } - testDevice, deviceInfo, err := netdevs.FindDeviceByIP(ip, devs) - if err != nil { - return "", netdevs.DeviceInfo{}, fmt.Errorf("can't find deviceInfo for container %s: %w", d.Name, err) - } - return testDevice, deviceInfo, nil + runner.TestWithDUT(context.Background(), t, runner.NewDockerDUT, runner.DutAddr) } diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 3af5f83fd..a90046f69 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -615,7 +615,7 @@ func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Du if errs == nil { return nil, fmt.Errorf("got no frames matching %v during %s", layers, timeout) } - return nil, fmt.Errorf("got no frames matching %v during %s: got %w", layers, timeout, errs) + return nil, fmt.Errorf("got frames %w want %v during %s", errs, layers, timeout) } if conn.match(layers, gotLayers) { for i, s := range conn.layerStates { diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index 73c532e75..6165ab293 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -16,11 +16,13 @@ package testbench import ( "context" + "encoding/binary" "flag" "net" "strconv" "syscall" "testing" + "time" pb "gvisor.dev/gvisor/test/packetimpact/proto/posix_server_go_proto" @@ -700,3 +702,43 @@ func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, fl } return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) } + +// SetSockLingerOption sets SO_LINGER socket option on the DUT. +func (dut *DUT) SetSockLingerOption(t *testing.T, sockfd int32, timeout time.Duration, enable bool) { + var linger unix.Linger + if enable { + linger.Onoff = 1 + } + linger.Linger = int32(timeout / time.Second) + + buf := make([]byte, 8) + binary.LittleEndian.PutUint32(buf, uint32(linger.Onoff)) + binary.LittleEndian.PutUint32(buf[4:], uint32(linger.Linger)) + dut.SetSockOpt(t, sockfd, unix.SOL_SOCKET, unix.SO_LINGER, buf) +} + +// Shutdown calls shutdown on the DUT and causes a fatal test failure if it doesn't +// succeed. If more control over the timeout or error handling is needed, use +// ShutdownWithErrno. +func (dut *DUT) Shutdown(t *testing.T, fd, how int32) error { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) + defer cancel() + return dut.ShutdownWithErrno(ctx, t, fd, how) +} + +// ShutdownWithErrno calls shutdown on the DUT. +func (dut *DUT) ShutdownWithErrno(ctx context.Context, t *testing.T, fd, how int32) error { + t.Helper() + + req := pb.ShutdownRequest{ + Fd: fd, + How: how, + } + resp, err := dut.posixServer.Shutdown(ctx, &req) + if err != nil { + t.Fatalf("failed to call Shutdown: %s", err) + } + return syscall.Errno(resp.GetErrno_()) +} diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 24aa46cce..a35562ca8 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -775,7 +775,7 @@ func (l *IPv6FragmentExtHdr) String() string { type ICMPv6 struct { LayerBase Type *header.ICMPv6Type - Code *byte + Code *header.ICMPv6Code Checksum *uint16 Payload []byte } @@ -823,6 +823,12 @@ func ICMPv6Type(v header.ICMPv6Type) *header.ICMPv6Type { return &v } +// ICMPv6Code is a helper routine that allocates a new ICMPv6Type value to store +// v and returns a pointer to it. +func ICMPv6Code(v header.ICMPv6Code) *header.ICMPv6Code { + return &v +} + // Byte is a helper routine that allocates a new byte value to store // v and returns a pointer to it. func Byte(v byte) *byte { @@ -834,7 +840,7 @@ func parseICMPv6(b []byte) (Layer, layerParser) { h := header.ICMPv6(b) icmpv6 := ICMPv6{ Type: ICMPv6Type(h.Type()), - Code: Byte(h.Code()), + Code: ICMPv6Code(h.Code()), Checksum: Uint16(h.Checksum()), Payload: h.NDPPayload(), } @@ -861,11 +867,17 @@ func ICMPv4Type(t header.ICMPv4Type) *header.ICMPv4Type { return &t } +// ICMPv4Code is a helper routine that allocates a new header.ICMPv4Code value +// to store t and returns a pointer to it. +func ICMPv4Code(t header.ICMPv4Code) *header.ICMPv4Code { + return &t +} + // ICMPv4 can construct and match an ICMPv4 encapsulation. type ICMPv4 struct { LayerBase Type *header.ICMPv4Type - Code *uint8 + Code *header.ICMPv4Code Checksum *uint16 } @@ -881,7 +893,7 @@ func (l *ICMPv4) ToBytes() ([]byte, error) { h.SetType(*l.Type) } if l.Code != nil { - h.SetCode(byte(*l.Code)) + h.SetCode(*l.Code) } if l.Checksum != nil { h.SetChecksum(*l.Checksum) @@ -901,7 +913,7 @@ func parseICMPv4(b []byte) (Layer, layerParser) { h := header.ICMPv4(b) icmpv4 := ICMPv4{ Type: ICMPv4Type(h.Type()), - Code: Uint8(h.Code()), + Code: ICMPv4Code(h.Code()), Checksum: Uint16(h.Checksum()), } return &icmpv4, parsePayload diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go index a2a763034..eca0780b5 100644 --- a/test/packetimpact/testbench/layers_test.go +++ b/test/packetimpact/testbench/layers_test.go @@ -594,7 +594,7 @@ func TestIPv6ExtHdrOptions(t *testing.T) { }, &ICMPv6{ Type: ICMPv6Type(header.ICMPv6ParamProblem), - Code: Byte(0), + Code: ICMPv6Code(header.ICMPv6ErroneousHeader), Checksum: Uint16(0x5f98), Payload: []byte{0x00, 0x00, 0x00, 0x06}, }, diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go index 57e822725..193bb2dc8 100644 --- a/test/packetimpact/testbench/rawsockets.go +++ b/test/packetimpact/testbench/rawsockets.go @@ -139,7 +139,7 @@ type Injector struct { func NewInjector(t *testing.T) (Injector, error) { t.Helper() - ifInfo, err := net.InterfaceByName(Device) + ifInfo, err := net.InterfaceByName(LocalDevice) if err != nil { return Injector{}, err } diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go index 242464e3a..0073a1361 100644 --- a/test/packetimpact/testbench/testbench.go +++ b/test/packetimpact/testbench/testbench.go @@ -27,10 +27,13 @@ import ( ) var ( - // DUTType is the type of device under test. - DUTType = "" - // Device is the local device on the test network. - Device = "" + // Native indicates that the test is being run natively. + Native = false + // LocalDevice is the device that testbench uses to inject traffic. + LocalDevice = "" + // RemoteDevice is the device name on the DUT, individual tests can + // use the name to construct tests. + RemoteDevice = "" // LocalIPv4 is the local IPv4 address on the test network. LocalIPv4 = "" @@ -80,8 +83,9 @@ func RegisterFlags(fs *flag.FlagSet) { fs.StringVar(&RemoteIPv4, "remote_ipv4", RemoteIPv4, "remote IPv4 address for test packets") fs.StringVar(&RemoteIPv6, "remote_ipv6", RemoteIPv6, "remote IPv6 address for test packets") fs.StringVar(&RemoteMAC, "remote_mac", RemoteMAC, "remote mac address for test packets") - fs.StringVar(&Device, "device", Device, "local device for test packets") - fs.StringVar(&DUTType, "dut_type", DUTType, "type of device under test") + fs.StringVar(&LocalDevice, "local_device", LocalDevice, "local device to inject traffic") + fs.StringVar(&RemoteDevice, "remote_device", RemoteDevice, "remote device on the DUT") + fs.BoolVar(&Native, "native", Native, "whether the test is running natively") fs.Uint64Var(&RemoteInterfaceID, "remote_interface_id", RemoteInterfaceID, "remote interface ID for test packets") } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 0c2a05380..94731c64b 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -40,8 +40,6 @@ packetimpact_go_test( packetimpact_go_test( name = "udp_recv_mcast_bcast", srcs = ["udp_recv_mcast_bcast_test.go"], - # TODO(b/152813495): Fix netstack then remove the line below. - expect_netstack_failure = True, deps = [ "//pkg/tcpip", "//pkg/tcpip/header", @@ -168,8 +166,8 @@ packetimpact_go_test( ) packetimpact_go_test( - name = "tcp_close_wait_ack", - srcs = ["tcp_close_wait_ack_test.go"], + name = "tcp_unacc_seq_ack", + srcs = ["tcp_unacc_seq_ack_test.go"], deps = [ "//pkg/tcpip/header", "//pkg/tcpip/seqnum", @@ -262,6 +260,28 @@ packetimpact_go_test( ) packetimpact_go_test( + name = "tcp_timewait_reset", + srcs = ["tcp_timewait_reset_test.go"], + # TODO(b/168523247): Fix netstack then remove the line below. + expect_netstack_failure = True, + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_queue_send_in_syn_sent", + srcs = ["tcp_queue_send_in_syn_sent_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( name = "icmpv6_param_problem", srcs = ["icmpv6_param_problem_test.go"], # TODO(b/153485026): Fix netstack then remove the line below. @@ -310,3 +330,23 @@ packetimpact_go_test( "@org_golang_x_sys//unix:go_default_library", ], ) + +packetimpact_go_test( + name = "tcp_linger", + srcs = ["tcp_linger_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_go_test( + name = "tcp_rcv_buf_space", + srcs = ["tcp_rcv_buf_space_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go index b5f94ad4b..a24c85566 100644 --- a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go +++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go @@ -67,7 +67,7 @@ func TestIPv6FragmentReassembly(t *testing.T) { rIP := tcpip.Address(net.ParseIP(testbench.RemoteIPv6).To16()) icmpv6 := testbench.ICMPv6{ Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), - Code: testbench.Byte(0), + Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), Payload: icmpv6EchoPayload, } icmpv6Bytes, err := icmpv6.ToBytes() @@ -89,7 +89,7 @@ func TestIPv6FragmentReassembly(t *testing.T) { }, &testbench.ICMPv6{ Type: testbench.ICMPv6Type(header.ICMPv6EchoRequest), - Code: testbench.Byte(0), + Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), Payload: icmpv6EchoPayload, Checksum: &cksum, }) @@ -116,7 +116,7 @@ func TestIPv6FragmentReassembly(t *testing.T) { }, &testbench.ICMPv6{ Type: testbench.ICMPv6Type(header.ICMPv6EchoReply), - Code: testbench.Byte(0), + Code: testbench.ICMPv6Code(header.ICMPv6UnusedCode), }, }, time.Second) if err != nil { diff --git a/test/packetimpact/tests/ipv6_unknown_options_action_test.go b/test/packetimpact/tests/ipv6_unknown_options_action_test.go index d7d63cbd2..e79d74476 100644 --- a/test/packetimpact/tests/ipv6_unknown_options_action_test.go +++ b/test/packetimpact/tests/ipv6_unknown_options_action_test.go @@ -172,7 +172,7 @@ func TestIPv6UnknownOptionAction(t *testing.T) { &testbench.IPv6{}, &testbench.ICMPv6{ Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem), - Code: testbench.Byte(2), + Code: testbench.ICMPv6Code(header.ICMPv6UnknownOption), Payload: icmpv6Payload, }, }, time.Second) diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go deleted file mode 100644 index e6a96f214..000000000 --- a/test/packetimpact/tests/tcp_close_wait_ack_test.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package tcp_close_wait_ack_test - -import ( - "flag" - "fmt" - "testing" - "time" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/seqnum" - "gvisor.dev/gvisor/test/packetimpact/testbench" -) - -func init() { - testbench.RegisterFlags(flag.CommandLine) -} - -func TestCloseWaitAck(t *testing.T) { - for _, tt := range []struct { - description string - makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP - seqNumOffset seqnum.Size - expectAck bool - }{ - {"OTW", generateOTWSeqSegment, 0, false}, - {"OTW", generateOTWSeqSegment, 1, true}, - {"OTW", generateOTWSeqSegment, 2, true}, - {"ACK", generateUnaccACKSegment, 0, false}, - {"ACK", generateUnaccACKSegment, 1, true}, - {"ACK", generateUnaccACKSegment, 2, true}, - } { - t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) { - dut := testbench.NewDUT(t) - defer dut.TearDown() - listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) - defer dut.Close(t, listenFd) - conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) - defer conn.Close(t) - - conn.Connect(t) - acceptFd, _ := dut.Accept(t, listenFd) - - // Send a FIN to DUT to intiate the active close - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) - gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) - if err != nil { - t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) - } - windowSize := seqnum.Size(*gotTCP.WindowSize) - - // Send a segment with OTW Seq / unacc ACK and expect an ACK back - conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")}) - gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) - if tt.expectAck && err != nil { - t.Fatalf("expected an ack but got none: %s", err) - } - if !tt.expectAck && gotAck != nil { - t.Fatalf("expected no ack but got one: %s", gotAck) - } - - // Now let's verify DUT is indeed in CLOSE_WAIT - dut.Close(t, acceptFd) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { - t.Fatalf("expected DUT to send a FIN: %s", err) - } - // Ack the FIN from DUT - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - // Send some extra data to DUT - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")}) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { - t.Fatalf("expected DUT to send an RST: %s", err) - } - }) - } -} - -// generateOTWSeqSegment generates an segment with -// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only -// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the -// receiver. -func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) - otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) - return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} -} - -// generateUnaccACKSegment generates an segment with -// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable -// when seqNumOffset is 0, otherwise an ACK is expected from the receiver. -func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.RemoteSeqNum(t) - unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) - return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} -} diff --git a/test/packetimpact/tests/tcp_linger_test.go b/test/packetimpact/tests/tcp_linger_test.go new file mode 100644 index 000000000..b9a0409aa --- /dev/null +++ b/test/packetimpact/tests/tcp_linger_test.go @@ -0,0 +1,270 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_linger_test + +import ( + "context" + "flag" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func createSocket(t *testing.T, dut testbench.DUT) (int32, int32, testbench.TCPIPv4) { + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + return acceptFD, listenFD, conn +} + +func closeAll(t *testing.T, dut testbench.DUT, listenFD int32, conn testbench.TCPIPv4) { + conn.Close(t) + dut.Close(t, listenFD) + dut.TearDown() +} + +// lingerDuration is the timeout value used with SO_LINGER socket option. +const lingerDuration = 3 * time.Second + +// TestTCPLingerZeroTimeout tests when SO_LINGER is set with zero timeout. DUT +// should send RST-ACK when socket is closed. +func TestTCPLingerZeroTimeout(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, 0, true) + dut.Close(t, acceptFD) + + // If the linger timeout is set to zero, the DUT should send a RST. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected RST-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerOff tests when SO_LINGER is not set. DUT should send FIN-ACK +// when socket is closed. +func TestTCPLingerOff(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.Close(t, acceptFD) + + // If SO_LINGER is not set, DUT should send a FIN-ACK. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerNonZeroTimeout tests when SO_LINGER is set with non-zero timeout. +// DUT should close the socket after timeout. +func TestTCPLingerNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"WithNonZeroLinger", true}, + {"WithoutLinger", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} + +// TestTCPLingerSendNonZeroTimeout tests when SO_LINGER is set with non-zero +// timeout and send a packet. DUT should close the socket after timeout. +func TestTCPLingerSendNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"WithSendNonZeroLinger", true}, + {"WithoutLinger", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Send data. + sampleData := []byte("Sample Data") + dut.Send(t, acceptFD, sampleData, 0) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + samplePayload := &testbench.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} + +// TestTCPLingerShutdownZeroTimeout tests SO_LINGER with shutdown() and zero +// timeout. DUT should send RST-ACK when socket is closed. +func TestTCPLingerShutdownZeroTimeout(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, 0, true) + dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR) + dut.Close(t, acceptFD) + + // Shutdown will send FIN-ACK with read/write option. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + + // If the linger timeout is set to zero, the DUT should send a RST. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected RST-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) +} + +// TestTCPLingerShutdownSendNonZeroTimeout tests SO_LINGER with shutdown() and +// non-zero timeout. DUT should close the socket after timeout. +func TestTCPLingerShutdownSendNonZeroTimeout(t *testing.T) { + for _, tt := range []struct { + description string + lingerOn bool + }{ + {"shutdownRDWR", true}, + {"shutdownRDWR", false}, + } { + t.Run(tt.description, func(t *testing.T) { + // Create a socket, listen, TCP connect, and accept. + dut := testbench.NewDUT(t) + acceptFD, listenFD, conn := createSocket(t, dut) + defer closeAll(t, dut, listenFD, conn) + + dut.SetSockLingerOption(t, acceptFD, lingerDuration, tt.lingerOn) + + // Send data. + sampleData := []byte("Sample Data") + dut.Send(t, acceptFD, sampleData, 0) + + dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR) + + // Increase timeout as Close will take longer time to + // return when SO_LINGER is set with non-zero timeout. + timeout := lingerDuration + 1*time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + start := time.Now() + dut.CloseWithErrno(ctx, t, acceptFD) + end := time.Now() + diff := end.Sub(start) + + if tt.lingerOn && diff < lingerDuration { + t.Errorf("expected close to return after %v seconds, but returned sooner", lingerDuration) + } else if !tt.lingerOn && diff > 1*time.Second { + t.Errorf("expected close to return within a second, but returned later") + } + + samplePayload := &testbench.Payload{Bytes: sampleData} + if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { + t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) + } + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + }) + } +} + +func TestTCPLingerNonEstablished(t *testing.T) { + dut := testbench.NewDUT(t) + newFD := dut.Socket(t, unix.AF_INET, unix.SOCK_STREAM, unix.IPPROTO_TCP) + dut.SetSockLingerOption(t, newFD, lingerDuration, true) + + // As the socket is in the initial state, Close() should not linger + // and return immediately. + start := time.Now() + dut.CloseWithErrno(context.Background(), t, newFD) + diff := time.Since(start) + + if diff > lingerDuration { + t.Errorf("expected close to return within %s, but returned after %s", lingerDuration, diff) + } + dut.TearDown() +} diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go index 900352fa1..2f57dff19 100644 --- a/test/packetimpact/tests/tcp_network_unreachable_test.go +++ b/test/packetimpact/tests/tcp_network_unreachable_test.go @@ -72,7 +72,9 @@ func TestTCPSynSentUnreachable(t *testing.T) { if !ok { t.Fatalf("expected %s to be TCP", tcpLayers[tcpLayer]) } - var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), Code: testbench.Uint8(header.ICMPv4HostUnreachable)} + var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), + Code: testbench.ICMPv4Code(header.ICMPv4HostUnreachable)} layers = append(layers, &icmpv4, ip, tcp) rawConn.SendFrameStateless(t, layers) @@ -126,7 +128,7 @@ func TestTCPSynSentUnreachable6(t *testing.T) { } var icmpv6 testbench.ICMPv6 = testbench.ICMPv6{ Type: testbench.ICMPv6Type(header.ICMPv6DstUnreachable), - Code: testbench.Uint8(header.ICMPv6NetworkUnreachable), + Code: testbench.ICMPv6Code(header.ICMPv6NetworkUnreachable), // Per RFC 4443 3.1, the payload contains 4 zeroed bytes. Payload: []byte{0, 0, 0, 0}, } diff --git a/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go new file mode 100644 index 000000000..0ec8fd748 --- /dev/null +++ b/test/packetimpact/tests/tcp_queue_send_in_syn_sent_test.go @@ -0,0 +1,133 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_queue_send_in_syn_sent_test + +import ( + "context" + "errors" + "flag" + "net" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestQueueSendInSynSent tests send behavior when the TCP state +// is SYN-SENT. +// It tests for 2 variants when in SYN_SENT state and: +// (1) DUT blocks on send and complete handshake +// (2) DUT blocks on send and receive a TCP RST. +func TestQueueSendInSynSent(t *testing.T) { + for _, tt := range []struct { + description string + reset bool + }{ + {description: "Complete handshake", reset: false}, + {description: "Send RST", reset: true}, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + + socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4)) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + dut.SetNonBlocking(t, socket, true) + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { + t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) + } + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + t.Fatalf("expected a SYN from DUT, but got none: %s", err) + } + if _, err := dut.SendWithErrno(context.Background(), t, socket, sampleData, 0); err != syscall.Errno(unix.EWOULDBLOCK) { + t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err) + } + + // Test blocking write. + dut.SetNonBlocking(t, socket, false) + + var wg sync.WaitGroup + defer wg.Wait() + wg.Add(1) + var block sync.WaitGroup + block.Add(1) + go func() { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + + block.Done() + // Issue SEND call in SYN-SENT, this should be queued for + // process until the connection is established. + n, err := dut.SendWithErrno(ctx, t, socket, sampleData, 0) + if tt.reset { + if err != syscall.Errno(unix.ECONNREFUSED) { + t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err) + } + if n != -1 { + t.Errorf("expected return value %d, got %d", -1, n) + } + return + } + if n != int32(len(sampleData)) { + t.Errorf("failed to send on DUT: %s", err) + } + }() + + // Wait for the goroutine to be scheduled and before it + // blocks on endpoint send. + block.Wait() + // The following sleep is used to prevent the connection + // from being established before we are blocked on send. + time.Sleep(100 * time.Millisecond) + + if tt.reset { + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + return + } + + // Bring the connection to Established. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) + + // Expect the data from the DUT's enqueued send request. + // + // On Linux, this can be piggybacked with the ACK completing the + // handshake. On gVisor, getting such a piggyback is a bit more + // complicated because the actual data enqueuing occurs in the + // callers of endpoint Write. + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, samplePayload, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + + // Send sample payload and expect an ACK to ensure connection is still ESTABLISHED. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK from DUT, but got none: %s", err) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_rcv_buf_space_test.go b/test/packetimpact/tests/tcp_rcv_buf_space_test.go new file mode 100644 index 000000000..cfbba1e8e --- /dev/null +++ b/test/packetimpact/tests/tcp_rcv_buf_space_test.go @@ -0,0 +1,80 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_rcv_buf_space_test + +import ( + "context" + "flag" + "syscall" + "testing" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestReduceRecvBuf tests that a packet within window is still dropped +// if the available buffer space drops below the size of the incoming +// segment. +func TestReduceRecvBuf(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFd) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFd, _ := dut.Accept(t, listenFd) + defer dut.Close(t, acceptFd) + + // Set a small receive buffer for the test. + const rcvBufSz = 4096 + dut.SetSockOptInt(t, acceptFd, unix.SOL_SOCKET, unix.SO_RCVBUF, rcvBufSz) + + // Retrieve the actual buffer. + bufSz := dut.GetSockOptInt(t, acceptFd, unix.SOL_SOCKET, unix.SO_RCVBUF) + + // Generate a payload of 1 more than the actual buffer size used by the + // DUT. + sampleData := testbench.GenerateRandomPayload(t, int(bufSz)+1) + // Send and receive sample data to the dut. + const pktSize = 1400 + for payload := sampleData; len(payload) != 0; { + payloadBytes := pktSize + if l := len(payload); l < payloadBytes { + payloadBytes = l + } + + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, []testbench.Layer{&testbench.Payload{Bytes: payload[:payloadBytes]}}...) + payload = payload[payloadBytes:] + } + + // First read should read < len(sampleData) + if ret, _, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(len(sampleData)), 0); ret == -1 || int(ret) == len(sampleData) { + t.Fatalf("dut.RecvWithErrno(ctx, t, %d, %d, 0) = %d,_, %s", acceptFd, int32(len(sampleData)), ret, err) + } + + // Second read should return EAGAIN as the last segment should have been + // dropped due to it exceeding the receive buffer space available in the + // socket. + if ret, got, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(len(sampleData)), syscall.MSG_DONTWAIT); got != nil || ret != -1 || err != syscall.EAGAIN { + t.Fatalf("expected no packets but got: %s", got) + } +} diff --git a/test/packetimpact/tests/tcp_reordering_test.go b/test/packetimpact/tests/tcp_reordering_test.go index 8742819ca..b4aeaab57 100644 --- a/test/packetimpact/tests/tcp_reordering_test.go +++ b/test/packetimpact/tests/tcp_reordering_test.go @@ -54,13 +54,13 @@ func TestReorderingWindow(t *testing.T) { acceptFd, _ := dut.Accept(t, listenFd) defer dut.Close(t, acceptFd) - if tb.DUTType == "linux" { + if tb.Native { // Linux has changed its handling of reordering, force the old behavior. dut.SetSockOpt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno")) } pls := dut.GetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG) - if tb.DUTType == "netstack" { + if !tb.Native { // netstack does not impliment TCP_MAXSEG correctly. Fake it // here. Netstack uses the max SACK size which is 32. The MSS // option is 8 bytes, making the total 36 bytes. @@ -141,7 +141,7 @@ func TestReorderingWindow(t *testing.T) { } } - if tb.DUTType == "netstack" { + if !tb.Native { // The window should now be halved, so we should receive any // more, even if we send them. dut.Send(t, acceptFd, payload, 0) diff --git a/test/packetimpact/tests/tcp_timewait_reset_test.go b/test/packetimpact/tests/tcp_timewait_reset_test.go new file mode 100644 index 000000000..2f76a6531 --- /dev/null +++ b/test/packetimpact/tests/tcp_timewait_reset_test.go @@ -0,0 +1,68 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_timewait_reset_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +// TestTimeWaitReset tests handling of RST when in TIME_WAIT state. +func TestTimeWaitReset(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Trigger active close. + dut.Close(t, acceptFD) + + _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected a FIN: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // Send a FIN, DUT should transition to TIME_WAIT from FIN_WAIT2. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK for our FIN: %s", err) + } + + // Send a RST, the DUT should transition to CLOSED from TIME_WAIT. + // This is the default Linux behavior, it can be changed to ignore RSTs via + // sysctl net.ipv4.tcp_rfc1337. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // The DUT should reply with RST to our ACK as the state should have + // transitioned to CLOSED. + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected a RST: %s", err) + } +} diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go new file mode 100644 index 000000000..d078bbf15 --- /dev/null +++ b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go @@ -0,0 +1,234 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_unacc_seq_ack_test + +import ( + "flag" + "fmt" + "syscall" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.RegisterFlags(flag.CommandLine) +} + +func TestEstablishedUnaccSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + expectAck bool + restoreSeq bool + }{ + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: true, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: true, restoreSeq: false}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: false, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: false, restoreSeq: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + dut.Accept(t, listenFD) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected ack %s", err) + } + windowSize := seqnum.Size(*gotTCP.WindowSize) + + origSeq := *conn.LocalSeqNum(t) + // Send a segment with OTW Seq / unacc ACK. + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), samplePayload) + if tt.restoreSeq { + // Restore the local sequence number to ensure that the incoming + // ACK matches the TCP layer state. + *conn.LocalSeqNum(t) = origSeq + } + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Fatalf("expected an ack but got none: %s", err) + } + if err == nil && !tt.expectAck && gotAck != nil { + t.Fatalf("expected no ack but got one: %s", gotAck) + } + }) + } +} + +func TestPassiveCloseUnaccSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + expectAck bool + }{ + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: false}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: false}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Send a FIN to DUT to intiate the passive close. + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) + } + windowSize := seqnum.Size(*gotTCP.WindowSize) + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + // Send a segment with OTW Seq / unacc ACK. + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), samplePayload) + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Errorf("expected an ack but got none: %s", err) + } + if err == nil && !tt.expectAck && gotAck != nil { + t.Errorf("expected no ack but got one: %s", gotAck) + } + + // Now let's verify DUT is indeed in CLOSE_WAIT + dut.Close(t, acceptFD) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + t.Fatalf("expected DUT to send a FIN: %s", err) + } + // Ack the FIN from DUT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + // Send some extra data to DUT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, samplePayload) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + t.Fatalf("expected DUT to send an RST: %s", err) + } + }) + } +} + +func TestActiveCloseUnaccpSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + restoreSeq bool + }{ + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, restoreSeq: false}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, restoreSeq: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + defer dut.TearDown() + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Trigger active close. + dut.Shutdown(t, acceptFD, syscall.SHUT_WR) + + // Get to FIN_WAIT2 + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected a FIN: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + + sendUnaccSeqAck := func(state string) { + t.Helper() + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + origSeq := *conn.LocalSeqNum(t) + // Send a segment with OTW Seq / unacc ACK. + conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, seqnum.Size(*gotTCP.WindowSize)), samplePayload) + if tt.restoreSeq { + // Restore the local sequence number to ensure that the + // incoming ACK matches the TCP layer state. + *conn.LocalSeqNum(t) = origSeq + } + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected an ack in %s state, but got none: %s", state, err) + } + } + + sendUnaccSeqAck("FIN_WAIT2") + + // Send a FIN to DUT to get to TIME_WAIT + conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected an ACK for our fin and DUT should enter TIME_WAIT: %s", err) + } + + sendUnaccSeqAck("TIME_WAIT") + }) + } +} + +// generateOTWSeqSegment generates an segment with +// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only +// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the +// receiver. +func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) + otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) + return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} +} + +// generateUnaccACKSegment generates an segment with +// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable +// when seqNumOffset is 0, otherwise an ACK is expected from the receiver. +func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { + lastAcceptable := conn.RemoteSeqNum(t) + unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) + return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} +} diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go index d30177e64..3d2791a6e 100644 --- a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go +++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go @@ -53,6 +53,7 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) { t, testbench.IPv4{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To4()))}, testbench.UDP{}, + &testbench.Payload{Bytes: []byte("test payload")}, ) ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) @@ -76,14 +77,15 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) { net.IPv6interfacelocalallnodes, net.IPv6linklocalallnodes, net.IPv6linklocalallrouters, - net.ParseIP("fe01::42"), - net.ParseIP("fe02::4242"), + net.ParseIP("ff01::42"), + net.ParseIP("ff02::4242"), } { t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) { conn.SendIPv6( t, testbench.IPv6{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To16()))}, testbench.UDP{}, + &testbench.Payload{Bytes: []byte("test payload")}, ) ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go index b47ddb6c3..df35d16c8 100644 --- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go +++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go @@ -62,9 +62,13 @@ func (e icmpError) String() string { func (e icmpError) ToICMPv4() *testbench.ICMPv4 { switch e { case portUnreachable: - return &testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), Code: testbench.Uint8(header.ICMPv4PortUnreachable)} + return &testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), + Code: testbench.ICMPv4Code(header.ICMPv4PortUnreachable)} case timeToLiveExceeded: - return &testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4TimeExceeded), Code: testbench.Uint8(header.ICMPv4TTLExceeded)} + return &testbench.ICMPv4{ + Type: testbench.ICMPv4Type(header.ICMPv4TimeExceeded), + Code: testbench.ICMPv4Code(header.ICMPv4TTLExceeded)} } return nil } diff --git a/test/perf/BUILD b/test/perf/BUILD index 471d8c2ab..b763be50e 100644 --- a/test/perf/BUILD +++ b/test/perf/BUILD @@ -3,33 +3,40 @@ load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) syscall_test( + debug = False, test = "//test/perf/linux:clock_getres_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:clock_gettime_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:death_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:epoll_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:fork_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:futex_benchmark", ) syscall_test( size = "enormous", + debug = False, shard_count = 10, tags = ["nogotsan"], test = "//test/perf/linux:getdents_benchmark", @@ -37,81 +44,96 @@ syscall_test( syscall_test( size = "large", + debug = False, test = "//test/perf/linux:getpid_benchmark", ) syscall_test( size = "enormous", + debug = False, tags = ["nogotsan"], test = "//test/perf/linux:gettid_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:mapping_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:open_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:pipe_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:randread_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:read_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:sched_yield_benchmark", ) syscall_test( size = "large", + debug = False, test = "//test/perf/linux:send_recv_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:seqwrite_benchmark", ) syscall_test( size = "enormous", + debug = False, test = "//test/perf/linux:signal_benchmark", ) syscall_test( + debug = False, test = "//test/perf/linux:sleep_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:stat_benchmark", ) syscall_test( size = "enormous", add_overlay = True, + debug = False, test = "//test/perf/linux:unlink_benchmark", ) syscall_test( size = "large", add_overlay = True, + debug = False, test = "//test/perf/linux:write_benchmark", ) diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD index b4e907826..dd1d2438c 100644 --- a/test/perf/linux/BUILD +++ b/test/perf/linux/BUILD @@ -354,3 +354,19 @@ cc_binary( "//test/util:test_util", ], ) + +cc_binary( + name = "open_read_close_benchmark", + testonly = 1, + srcs = [ + "open_read_close_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + ], +) diff --git a/test/perf/linux/getdents_benchmark.cc b/test/perf/linux/getdents_benchmark.cc index d8e81fa8c..9030eb356 100644 --- a/test/perf/linux/getdents_benchmark.cc +++ b/test/perf/linux/getdents_benchmark.cc @@ -105,7 +105,7 @@ void BM_GetdentsSameFD(benchmark::State& state) { state.SetItemsProcessed(state.iterations()); } -BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 16)->UseRealTime(); +BENCHMARK(BM_GetdentsSameFD)->Range(1, 1 << 12)->UseRealTime(); // Creates a directory containing `files` files, and reads all the directory // entries from the directory using a new FD each time. diff --git a/test/perf/linux/open_read_close_benchmark.cc b/test/perf/linux/open_read_close_benchmark.cc new file mode 100644 index 000000000..8b023a3d8 --- /dev/null +++ b/test/perf/linux/open_read_close_benchmark.cc @@ -0,0 +1,61 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_OpenReadClose(benchmark::State& state) { + const int size = state.range(0); + std::vector<TempPath> cache; + for (int i = 0; i < size; i++) { + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), "some content", 0644)); + cache.emplace_back(std::move(path)); + } + + char buf[1]; + unsigned int seed = 1; + for (auto _ : state) { + const int chosen = rand_r(&seed) % size; + int fd = open(cache[chosen].path().c_str(), O_RDONLY); + TEST_CHECK(fd != -1); + TEST_CHECK(read(fd, buf, 1) == 1); + close(fd); + } +} + +// Gofer dentry cache is 1000 by default. Go over it to force files to be closed +// for real. +BENCHMARK(BM_OpenReadClose)->Range(1000, 16384)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/root/crictl_test.go b/test/root/crictl_test.go index df91fa0fe..11ac5cb52 100644 --- a/test/root/crictl_test.go +++ b/test/root/crictl_test.go @@ -418,7 +418,7 @@ func setup(t *testing.T, version string) (*criutil.Crictl, func(), error) { // care about the docker runtime name. config = v2Template default: - t.Fatalf("unknown version: %d", version) + t.Fatalf("unknown version: %s", version) } t.Logf("Using config: %s", config) diff --git a/test/root/root.go b/test/root/root.go index 0f1d29faf..441fa5e2e 100644 --- a/test/root/root.go +++ b/test/root/root.go @@ -17,5 +17,5 @@ // docker, containerd, and crictl installed. To run these tests from the // project root directory: // -// ./scripts/root_tests.sh +// make root-tests package root diff --git a/test/runner/BUILD b/test/runner/BUILD index 63c7ec83a..582d2946d 100644 --- a/test/runner/BUILD +++ b/test/runner/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary") +load("//tools:defs.bzl", "bzl_library", "go_binary") package(licenses = ["notice"]) @@ -21,3 +21,9 @@ go_binary( "@org_golang_x_sys//unix:go_default_library", ], ) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], +) diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index c92392b35..032ebd04e 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -62,7 +62,8 @@ def _syscall_test( overlay = False, add_uds_tree = False, vfs2 = False, - fuse = False): + fuse = False, + debug = True): # Prepend "runsc" to non-native platform names. full_platform = platform if platform == "native" else "runsc_" + platform @@ -111,6 +112,8 @@ def _syscall_test( "--add-uds-tree=" + str(add_uds_tree), "--vfs2=" + str(vfs2), "--fuse=" + str(fuse), + "--strace=" + str(debug), + "--debug=" + str(debug), ] # Call the rule above. @@ -132,8 +135,9 @@ def syscall_test( add_overlay = False, add_uds_tree = False, add_hostinet = False, - vfs2 = False, + vfs2 = True, fuse = False, + debug = True, tags = None): """syscall_test is a macro that will create targets for all platforms. @@ -150,31 +154,12 @@ def syscall_test( if not tags: tags = [] - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = "native", - use_tmpfs = False, - add_uds_tree = add_uds_tree, - tags = list(tags), - ) - - for (platform, platform_tags) in platforms.items(): - _syscall_test( - test = test, - shard_count = shard_count, - size = size, - platform = platform, - use_tmpfs = use_tmpfs, - add_uds_tree = add_uds_tree, - tags = platform_tags + tags, - ) - vfs2_tags = list(tags) if vfs2: # Add tag to easily run VFS2 tests with --test_tag_filters=vfs2 vfs2_tags.append("vfs2") + if fuse: + vfs2_tags.append("fuse") else: # Don't automatically run tests tests not yet passing. @@ -191,19 +176,31 @@ def syscall_test( add_uds_tree = add_uds_tree, tags = platforms[default_platform] + vfs2_tags, vfs2 = True, + fuse = fuse, + ) + if fuse: + # Only generate *_vfs2_fuse target if fuse parameter is enabled. + return + + _syscall_test( + test = test, + shard_count = shard_count, + size = size, + platform = "native", + use_tmpfs = False, + add_uds_tree = add_uds_tree, + tags = list(tags), ) - if vfs2 and fuse: + for (platform, platform_tags) in platforms.items(): _syscall_test( test = test, shard_count = shard_count, size = size, - platform = default_platform, + platform = platform, use_tmpfs = use_tmpfs, add_uds_tree = add_uds_tree, - tags = platforms[default_platform] + vfs2_tags, - vfs2 = True, - fuse = True, + tags = platform_tags + tags, ) # TODO(gvisor.dev/issue/1487): Enable VFS2 overlay tests. diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go index 869169ad5..e4445e01b 100644 --- a/test/runner/gtest/gtest.go +++ b/test/runner/gtest/gtest.go @@ -146,10 +146,13 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr) } - out = []byte(strings.Trim(string(out), "\n")) + benches := strings.Trim(string(out), "\n") + if len(benches) == 0 { + return t, nil + } // Parse benchmark output. - for _, line := range strings.Split(string(out), "\n") { + for _, line := range strings.Split(benches, "\n") { // Strip comments. line = strings.Split(line, "#")[0] @@ -163,6 +166,5 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes benchmark: true, }) } - return t, nil } diff --git a/test/runner/runner.go b/test/runner/runner.go index bc4b39cbb..22d535f8d 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -106,11 +106,14 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + cmd.SysProcAttr = &syscall.SysProcAttr{} + + if specutils.HasCapabilities(capability.CAP_SYS_ADMIN) { + cmd.SysProcAttr.Cloneflags |= syscall.CLONE_NEWUTS + } if specutils.HasCapabilities(capability.CAP_NET_ADMIN) { - cmd.SysProcAttr = &syscall.SysProcAttr{ - Cloneflags: syscall.CLONE_NEWNET, - } + cmd.SysProcAttr.Cloneflags |= syscall.CLONE_NEWNET } if err := cmd.Run(); err != nil { @@ -172,13 +175,14 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { args = append(args, "-fsgofer-host-uds") } - undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR") - if ok { - tdir := filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1)) - if err := os.MkdirAll(tdir, 0755); err != nil { + testLogDir := "" + if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { + // Create log directory dedicated for this test. + testLogDir = filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1)) + if err := os.MkdirAll(testLogDir, 0755); err != nil { return fmt.Errorf("could not create test dir: %v", err) } - debugLogDir, err := ioutil.TempDir(tdir, "runsc") + debugLogDir, err := ioutil.TempDir(testLogDir, "runsc") if err != nil { return fmt.Errorf("could not create temp dir: %v", err) } @@ -227,10 +231,10 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { dArgs := append([]string{}, args...) dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) go func(dArgs []string) { - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() + debug := exec.Command(*runscPath, dArgs...) + debug.Stdout = os.Stdout + debug.Stderr = os.Stderr + debug.Run() done <- true }(dArgs) @@ -245,17 +249,17 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { dArgs = append(args, "debug", fmt.Sprintf("--signal=%d", syscall.SIGTERM), id) - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() + signal := exec.Command(*runscPath, dArgs...) + signal.Stdout = os.Stdout + signal.Stderr = os.Stderr + signal.Run() }() err = cmd.Run() - if err == nil { + if err == nil && len(testLogDir) > 0 { // If the test passed, then we erase the log directory. This speeds up // uploading logs in continuous integration & saves on disk space. - os.RemoveAll(undeclaredOutputsDir) + os.RemoveAll(testLogDir) } return err diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD index 3be123d94..22b526f59 100644 --- a/test/runtimes/BUILD +++ b/test/runtimes/BUILD @@ -1,39 +1,46 @@ +load("//tools:defs.bzl", "bzl_library") load("//test/runtimes:defs.bzl", "runtime_test") package(licenses = ["notice"]) runtime_test( name = "go1.12", - exclude_file = "exclude_go1.12.csv", + exclude_file = "exclude/go1.12.csv", lang = "go", - shard_count = 5, + shard_count = 8, ) runtime_test( name = "java11", batch = 100, - exclude_file = "exclude_java11.csv", + exclude_file = "exclude/java11.csv", lang = "java", - shard_count = 20, + shard_count = 16, ) runtime_test( name = "nodejs12.4.0", - exclude_file = "exclude_nodejs12.4.0.csv", + exclude_file = "exclude/nodejs12.4.0.csv", lang = "nodejs", - shard_count = 10, + shard_count = 8, ) runtime_test( name = "php7.3.6", - exclude_file = "exclude_php7.3.6.csv", + exclude_file = "exclude/php7.3.6.csv", lang = "php", - shard_count = 5, + shard_count = 8, ) runtime_test( name = "python3.7.3", - exclude_file = "exclude_python3.7.3.csv", + exclude_file = "exclude/python3.7.3.csv", lang = "python", - shard_count = 5, + shard_count = 8, +) + +bzl_library( + name = "defs_bzl", + srcs = ["defs.bzl"], + visibility = ["//visibility:private"], ) diff --git a/test/runtimes/README.md b/test/runtimes/README.md new file mode 100644 index 000000000..9dda1a728 --- /dev/null +++ b/test/runtimes/README.md @@ -0,0 +1,62 @@ +# gVisor Runtime Tests + +App Engine uses gvisor to sandbox application containers. The runtime tests aim +to test `runsc` compatibility with these +[standard runtimes](https://cloud.google.com/appengine/docs/standard/runtimes). +The test itself runs the language-defined tests inside the sandboxed standard +runtime container. + +Note: [Ruby runtime](https://cloud.google.com/appengine/docs/standard/ruby) is +currently in beta mode and so we do not run tests for it yet. + +### Testing Locally + +To run runtime tests individually from a given runtime, use the following table. + +Language | Version | Download Image | Run Test(s) +-------- | ------- | ------------------------------------------- | ----------- +Go | 1.12 | `make -C images load-runtimes_go1.12` | If the test name ends with `.go`, it is an on-disk test: <br> `docker run --runtime=runsc -it gvisor.dev/images/runtimes/go1.12 ( cd /usr/local/go/test ; go run run.go -v -- <TEST_NAME>... )` <br> Otherwise it is a tool test: <br> `docker run --runtime=runsc -it gvisor.dev/images/runtimes/go1.12 go tool dist test -v -no-rebuild ^TEST1$\|^TEST2$...` +Java | 11 | `make -C images load-runtimes_java11` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/java11 jtreg -agentvm -dir:/root/test/jdk -noreport -timeoutFactor:20 -verbose:summary <TEST_NAME>...` +NodeJS | 12.4.0 | `make -C images load-runtimes_nodejs12.4.0` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/nodejs12.4.0 python tools/test.py --timeout=180 <TEST_NAME>...` +Php | 7.3.6 | `make -C images load-runtimes_php7.3.6` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/php7.3.6 make test "TESTS=<TEST_NAME>..."` +Python | 3.7.3 | `make -C images load-runtimes_python3.7.3` | `docker run --runtime=runsc -it gvisor.dev/images/runtimes/python3.7.3 ./python -m test <TEST_NAME>...` + +To run an entire runtime test locally, use the following table. + +Note: java runtime test take 1+ hours with 16 cores. + +Language | Version | Running the test suite +-------- | ------- | ---------------------------------------- +Go | 1.12 | `make go1.12-runtime-tests{_vfs2}` +Java | 11 | `make java11-runtime-tests{_vfs2}` +NodeJS | 12.4.0 | `make nodejs12.4.0-runtime-tests{_vfs2}` +Php | 7.3.6 | `make php7.3.6-runtime-tests{_vfs2}` +Python | 3.7.3 | `make python3.7.3-runtime-tests{_vfs2}` + +#### Clean Up + +Sometimes when runtime tests fail or when the testing container itself crashes +unexpectedly, the containers are not removed or sometimes do not even exit. This +can cause some docker commands like `docker system prune` to hang forever. + +Here are some helpful commands (should be executed in order): + +```bash +docker ps -a # Lists all docker processes; useful when investigating hanging containers. +docker kill $(docker ps -a -q) # Kills all running containers. +docker rm $(docker ps -a -q) # Removes all exited containers. +docker system prune # Remove unused data. +``` + +### Testing Infrastructure + +There are 3 components to this tests infrastructure: + +- [`runner`](runner) - This is the test entrypoint. This is the binary is + invoked by `bazel test`. The runner spawns the target runtime container + using `runsc` and then copies over the `proctor` binary into the container. +- [`proctor`](proctor) - This binary acts as our agent inside the container + which communicates with the runner and actually executes tests. +- [`exclude`](exclude) - Holds a CSV file for each language runtime containing + the full path of tests that should be excluded from running along with a + reason for exclusion. diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl index db22029a8..702522d86 100644 --- a/test/runtimes/defs.bzl +++ b/test/runtimes/defs.bzl @@ -55,9 +55,13 @@ _runtime_test = rule( ), "_runner": attr.label( default = "//test/runtimes/runner:runner", + executable = True, + cfg = "target", ), "_proctor": attr.label( default = "//test/runtimes/proctor:proctor", + executable = True, + cfg = "target", ), }, test = True, diff --git a/test/runtimes/exclude_go1.12.csv b/test/runtimes/exclude/go1.12.csv index 81e02cf64..81e02cf64 100644 --- a/test/runtimes/exclude_go1.12.csv +++ b/test/runtimes/exclude/go1.12.csv diff --git a/test/runtimes/exclude_java11.csv b/test/runtimes/exclude/java11.csv index 4d62f7d3a..f779df8d5 100644 --- a/test/runtimes/exclude_java11.csv +++ b/test/runtimes/exclude/java11.csv @@ -1,9 +1,11 @@ test name,bug id,comment com/sun/crypto/provider/Cipher/PBE/PKCS12Cipher.java,,Fails in Docker +com/sun/jdi/InvokeHangTest.java,https://bugs.openjdk.java.net/browse/JDK-8218463, com/sun/jdi/NashornPopFrameTest.java,, com/sun/jdi/ProcessAttachTest.java,, com/sun/management/HotSpotDiagnosticMXBean/CheckOrigin.java,,Fails in Docker com/sun/management/OperatingSystemMXBean/GetCommittedVirtualMemorySize.java,, +com/sun/management/ThreadMXBean/ThreadCpuTimeArray.java,,Test assumes high CPU clock precision com/sun/management/UnixOperatingSystemMXBean/GetMaxFileDescriptorCount.sh,, com/sun/tools/attach/AttachSelf.java,, com/sun/tools/attach/BasicTests.java,, @@ -18,8 +20,16 @@ java/lang/ClassLoader/nativeLibrary/NativeLibraryTest.java,,Fails in Docker java/lang/module/ModuleDescriptorTest.java,, java/lang/String/nativeEncoding/StringPlatformChars.java,, java/net/CookieHandler/B6791927.java,,java.lang.RuntimeException: Expiration date shouldn't be 0 +java/net/ipv6tests/TcpTest.java,,java.net.ConnectException: Connection timed out (Connection timed out) +java/net/ipv6tests/UdpTest.java,,Times out +java/net/Inet6Address/B6558853.java,,Times out +java/net/InetAddress/CheckJNI.java,,java.net.ConnectException: Connection timed out (Connection timed out) java/net/InterfaceAddress/NetworkPrefixLength.java,b/78507103, +java/net/MulticastSocket/B6425815.java,,java.net.SocketException: Protocol not available (Error getting socket option) +java/net/MulticastSocket/B6427403.java,,java.net.SocketException: Protocol not available java/net/MulticastSocket/MulticastTTL.java,, +java/net/MulticastSocket/NetworkInterfaceEmptyGetInetAddressesTest.java,,java.net.SocketException: Protocol not available (Error getting socket option) +java/net/MulticastSocket/NoLoopbackPackets.java,,java.net.SocketException: Protocol not available java/net/MulticastSocket/Promiscuous.java,, java/net/MulticastSocket/SetLoopbackMode.java,, java/net/MulticastSocket/SetTTLAndGetTTL.java,, @@ -27,6 +37,7 @@ java/net/MulticastSocket/Test.java,, java/net/MulticastSocket/TestDefaults.java,, java/net/MulticastSocket/TimeToLive.java,, java/net/NetworkInterface/NetworkInterfaceStreamTest.java,, +java/net/Socket/LinkLocal.java,,java.net.SocketTimeoutException: Receive timed out java/net/Socket/SetSoLinger.java,b/78527327,SO_LINGER is not yet supported java/net/Socket/UrgentDataTest.java,b/111515323, java/net/SocketOption/OptionsTest.java,,Fails in Docker @@ -167,6 +178,10 @@ sun/management/jmxremote/bootstrap/RmiSslBootstrapTest.sh,, sun/management/jmxremote/startstop/JMXStartStopTest.java,, sun/management/jmxremote/startstop/JMXStatusPerfCountersTest.java,, sun/management/jmxremote/startstop/JMXStatusTest.java,, +sun/management/jdp/JdpDefaultsTest.java,, +sun/management/jdp/JdpJmxRemoteDynamicPortTest.java,, +sun/management/jdp/JdpOffTest.java,, +sun/management/jdp/JdpSpecificAddressTest.java,, sun/text/resources/LocaleDataTest.java,, sun/tools/jcmd/TestJcmdSanity.java,, sun/tools/jhsdb/AlternateHashingTest.java,, diff --git a/test/runtimes/exclude/nodejs12.4.0.csv b/test/runtimes/exclude/nodejs12.4.0.csv new file mode 100644 index 000000000..749fb9482 --- /dev/null +++ b/test/runtimes/exclude/nodejs12.4.0.csv @@ -0,0 +1,58 @@ +test name,bug id,comment +async-hooks/test-statwatcher.js,https://github.com/nodejs/node/issues/21425,Check for fix inclusion in nodejs releases after 2020-03-29 +benchmark/test-benchmark-fs.js,, +benchmark/test-benchmark-napi.js,, +doctool/test-make-doc.js,b/68848110,Expected to fail. +internet/test-dgram-multicast-set-interface-lo.js,b/162798882, +internet/test-doctool-versions.js,, +internet/test-uv-threadpool-schedule.js,, +parallel/test-cluster-dgram-reuse.js,b/64024294, +parallel/test-dgram-bind-fd.js,b/132447356, +parallel/test-dgram-socket-buffer-size.js,b/68847921, +parallel/test-dns-channel-timeout.js,b/161893056, +parallel/test-fs-access.js,, +parallel/test-fs-watchfile.js,,Flaky - File already exists error +parallel/test-fs-write-stream.js,b/166819807,Flaky +parallel/test-fs-write-stream-double-close,b/166819807,Flaky +parallel/test-fs-write-stream-throw-type-error.js,b/166819807,Flaky +parallel/test-http-writable-true-after-close.js,,Flaky - Mismatched <anonymous> function calls. Expected exactly 1 actual 2 +parallel/test-os.js,b/63997097, +parallel/test-net-server-listen-options.js,,Flaky - EADDRINUSE +parallel/test-process-uid-gid.js,, +parallel/test-tls-cli-min-version-1.0.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-min-version-1.1.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-min-version-1.2.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-min-version-1.3.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-max-version-1.2.js,,Flaky - EADDRINUSE +parallel/test-tls-cli-max-version-1.3.js,,Flaky - EADDRINUSE +parallel/test-tls-min-max-version.js,,Flaky - EADDRINUSE +pseudo-tty/test-assert-colors.js,b/162801321, +pseudo-tty/test-assert-no-color.js,b/162801321, +pseudo-tty/test-assert-position-indicator.js,b/162801321, +pseudo-tty/test-async-wrap-getasyncid-tty.js,b/162801321, +pseudo-tty/test-fatal-error.js,b/162801321, +pseudo-tty/test-handle-wrap-isrefed-tty.js,b/162801321, +pseudo-tty/test-readable-tty-keepalive.js,b/162801321, +pseudo-tty/test-set-raw-mode-reset-process-exit.js,b/162801321, +pseudo-tty/test-set-raw-mode-reset-signal.js,b/162801321, +pseudo-tty/test-set-raw-mode-reset.js,b/162801321, +pseudo-tty/test-stderr-stdout-handle-sigwinch.js,b/162801321, +pseudo-tty/test-stdout-read.js,b/162801321, +pseudo-tty/test-tty-color-support.js,b/162801321, +pseudo-tty/test-tty-isatty.js,b/162801321, +pseudo-tty/test-tty-stdin-call-end.js,b/162801321, +pseudo-tty/test-tty-stdin-end.js,b/162801321, +pseudo-tty/test-stdin-write.js,b/162801321, +pseudo-tty/test-tty-stdout-end.js,b/162801321, +pseudo-tty/test-tty-stdout-resize.js,b/162801321, +pseudo-tty/test-tty-stream-constructors.js,b/162801321, +pseudo-tty/test-tty-window-size.js,b/162801321, +pseudo-tty/test-tty-wrap.js,b/162801321, +pummel/test-heapdump-http2.js,,Flaky +pummel/test-net-pingpong.js,, +pummel/test-vm-memleak.js,b/162799436, +pummel/test-watch-file.js,,Flaky - Timeout +sequential/test-child-process-pass-fd.js,b/63926391,Flaky +sequential/test-https-connect-localport.js,,Flaky - EADDRINUSE +sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout +tick-processor/test-tick-processor-builtin.js,, diff --git a/test/runtimes/exclude_php7.3.6.csv b/test/runtimes/exclude/php7.3.6.csv index 0bef786c0..a73f3bcfb 100644 --- a/test/runtimes/exclude_php7.3.6.csv +++ b/test/runtimes/exclude/php7.3.6.csv @@ -8,23 +8,32 @@ ext/mbstring/tests/bug77165.phpt,, ext/mbstring/tests/bug77454.phpt,, ext/mbstring/tests/mb_convert_encoding_leak.phpt,, ext/mbstring/tests/mb_strrpos_encoding_3rd_param.phpt,, +ext/session/tests/session_module_name_variation4.phpt,,Flaky ext/session/tests/session_set_save_handler_class_018.phpt,, ext/session/tests/session_set_save_handler_iface_003.phpt,, +ext/session/tests/session_set_save_handler_sid_001.phpt,, ext/session/tests/session_set_save_handler_variation4.phpt,, -ext/session/tests/session_set_save_handler_variation5.phpt,, -ext/standard/tests/file/filetype_variation.phpt,, -ext/standard/tests/file/fopen_variation19.phpt,, +ext/standard/tests/file/disk.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_free_space_basic.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_free_space_error.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_free_space_variation.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_total_space_basic.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_total_space_error.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/disk_total_space_variation.phpt,https://bugs.php.net/bug.php?id=80018, +ext/standard/tests/file/fopen_variation19.phpt,b/162894964, +ext/standard/tests/file/lstat_stat_variation14.phpt,,Flaky ext/standard/tests/file/php_fd_wrapper_01.phpt,, ext/standard/tests/file/php_fd_wrapper_02.phpt,, ext/standard/tests/file/php_fd_wrapper_03.phpt,, ext/standard/tests/file/php_fd_wrapper_04.phpt,, -ext/standard/tests/file/realpath_bug77484.phpt,, +ext/standard/tests/file/realpath_bug77484.phpt,b/162894969, ext/standard/tests/file/rename_variation.phpt,b/68717309, -ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,, -ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,, +ext/standard/tests/file/symlink_link_linkinfo_is_link_variation4.phpt,b/162895341, +ext/standard/tests/file/symlink_link_linkinfo_is_link_variation8.phpt,b/162896223, ext/standard/tests/general_functions/escapeshellarg_bug71270.phpt,, ext/standard/tests/general_functions/escapeshellcmd_bug71270.phpt,, -ext/standard/tests/network/bug20134.phpt,, +ext/standard/tests/streams/proc_open_bug60120.phpt,,Flaky until php-src 3852a35fdbcb +ext/standard/tests/streams/proc_open_bug69900.phpt,,Flaky ext/standard/tests/streams/stream_socket_sendto.phpt,, ext/standard/tests/strings/007.phpt,, sapi/cli/tests/upload_2G.phpt,, @@ -34,4 +43,4 @@ tests/output/stream_isatty_in-out-err.phpt,, tests/output/stream_isatty_in-out.phpt,b/68720299, tests/output/stream_isatty_out-err.phpt,b/68720311, tests/output/stream_isatty_out.phpt,b/68720325, -Zend/tests/concat_003.phpt,, +Zend/tests/concat_003.phpt,b/162896021, diff --git a/test/runtimes/exclude/python3.7.3.csv b/test/runtimes/exclude/python3.7.3.csv new file mode 100644 index 000000000..8760f8951 --- /dev/null +++ b/test/runtimes/exclude/python3.7.3.csv @@ -0,0 +1,21 @@ +test name,bug id,comment +test_asyncio,,Fails on Docker. +test_asyncore,b/162973328, +test_epoll,b/162983393, +test_fcntl,b/162978767,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode. +test_httplib,b/163000009,OSError: [Errno 98] Address already in use +test_imaplib,b/162979661, +test_logging,b/162980079, +test_multiprocessing_fork,,Flaky. Sometimes times out. +test_multiprocessing_forkserver,,Flaky. Sometimes times out. +test_multiprocessing_main_handling,,Flaky. Sometimes times out. +test_multiprocessing_spawn,,Flaky. Sometimes times out. +test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted +test_pty,b/162979921, +test_readline,b/162980389,TestReadline hangs forever +test_resource,b/76174079, +test_selectors,b/76116849,OSError not raised with epoll +test_smtplib,b/162980434,unclosed sockets +test_signal,,Flaky - signal: alarm clock +test_socket,b/75983380, +test_subprocess,b/162980831, diff --git a/test/runtimes/exclude_nodejs12.4.0.csv b/test/runtimes/exclude_nodejs12.4.0.csv deleted file mode 100644 index 4eb0a4807..000000000 --- a/test/runtimes/exclude_nodejs12.4.0.csv +++ /dev/null @@ -1,51 +0,0 @@ -test name,bug id,comment -benchmark/test-benchmark-fs.js,, -benchmark/test-benchmark-module.js,, -benchmark/test-benchmark-napi.js,, -doctool/test-make-doc.js,b/68848110,Expected to fail. -fixtures/test-error-first-line-offset.js,, -fixtures/test-fs-readfile-error.js,, -fixtures/test-fs-stat-sync-overflow.js,, -internet/test-dgram-broadcast-multi-process.js,, -internet/test-dgram-multicast-multi-process.js,, -internet/test-dgram-multicast-set-interface-lo.js,, -internet/test-doctool-versions.js,, -internet/test-uv-threadpool-schedule.js,, -parallel/test-cluster-dgram-reuse.js,b/64024294, -parallel/test-dgram-bind-fd.js,b/132447356, -parallel/test-dgram-create-socket-handle-fd.js,b/132447238, -parallel/test-dgram-createSocket-type.js,b/68847739, -parallel/test-dgram-socket-buffer-size.js,b/68847921, -parallel/test-fs-access.js,, -parallel/test-fs-write-stream-double-close.js,, -parallel/test-fs-write-stream-throw-type-error.js,b/110226209, -parallel/test-fs-write-stream.js,, -parallel/test-http2-respond-file-error-pipe-offset.js,, -parallel/test-os.js,, -parallel/test-process-uid-gid.js,, -pseudo-tty/test-assert-colors.js,, -pseudo-tty/test-assert-no-color.js,, -pseudo-tty/test-assert-position-indicator.js,, -pseudo-tty/test-async-wrap-getasyncid-tty.js,, -pseudo-tty/test-fatal-error.js,, -pseudo-tty/test-handle-wrap-isrefed-tty.js,, -pseudo-tty/test-readable-tty-keepalive.js,, -pseudo-tty/test-set-raw-mode-reset-process-exit.js,, -pseudo-tty/test-set-raw-mode-reset-signal.js,, -pseudo-tty/test-set-raw-mode-reset.js,, -pseudo-tty/test-stderr-stdout-handle-sigwinch.js,, -pseudo-tty/test-stdout-read.js,, -pseudo-tty/test-tty-color-support.js,, -pseudo-tty/test-tty-isatty.js,, -pseudo-tty/test-tty-stdin-call-end.js,, -pseudo-tty/test-tty-stdin-end.js,, -pseudo-tty/test-stdin-write.js,, -pseudo-tty/test-tty-stdout-end.js,, -pseudo-tty/test-tty-stdout-resize.js,, -pseudo-tty/test-tty-stream-constructors.js,, -pseudo-tty/test-tty-window-size.js,, -pseudo-tty/test-tty-wrap.js,, -pummel/test-net-pingpong.js,, -pummel/test-vm-memleak.js,, -sequential/test-net-bytes-per-incoming-chunk-overhead.js,,flaky - timeout -tick-processor/test-tick-processor-builtin.js,, diff --git a/test/runtimes/exclude_python3.7.3.csv b/test/runtimes/exclude_python3.7.3.csv deleted file mode 100644 index 2b9947212..000000000 --- a/test/runtimes/exclude_python3.7.3.csv +++ /dev/null @@ -1,27 +0,0 @@ -test name,bug id,comment -test_asynchat,b/76031995,SO_REUSEADDR -test_asyncio,,Fails on Docker. -test_asyncore,b/76031995,SO_REUSEADDR -test_epoll,, -test_fcntl,,fcntl invalid argument -- artificial test to make sure something works in 64 bit mode. -test_ftplib,,Fails in Docker -test_httplib,b/76031995,SO_REUSEADDR -test_imaplib,, -test_logging,, -test_multiprocessing_fork,,Flaky. Sometimes times out. -test_multiprocessing_forkserver,,Flaky. Sometimes times out. -test_multiprocessing_main_handling,,Flaky. Sometimes times out. -test_multiprocessing_spawn,,Flaky. Sometimes times out. -test_nntplib,b/76031995,tests should not set SO_REUSEADDR -test_poplib,,Fails on Docker -test_posix,b/76174079,posix.sched_get_priority_min not implemented + posix.sched_rr_get_interval not permitted -test_pty,b/76157709,out of pty devices -test_readline,b/76157709,out of pty devices -test_resource,b/76174079, -test_selectors,b/76116849,OSError not raised with epoll -test_smtplib,b/76031995,SO_REUSEADDR and unclosed sockets -test_socket,b/75983380, -test_ssl,b/76031995,SO_REUSEADDR -test_subprocess,, -test_support,b/76031995,SO_REUSEADDR -test_telnetlib,b/76031995,SO_REUSEADDR diff --git a/test/runtimes/proctor/BUILD b/test/runtimes/proctor/BUILD index f76e2ddc0..f66371265 100644 --- a/test/runtimes/proctor/BUILD +++ b/test/runtimes/proctor/BUILD @@ -1,28 +1,10 @@ -load("//tools:defs.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) go_binary( name = "proctor", - srcs = [ - "go.go", - "java.go", - "nodejs.go", - "php.go", - "proctor.go", - "python.go", - ], - pure = True, + srcs = ["main.go"], visibility = ["//test/runtimes:__pkg__"], -) - -go_test( - name = "proctor_test", - size = "small", - srcs = ["proctor_test.go"], - library = ":proctor", - pure = True, - deps = [ - "//pkg/test/testutil", - ], + deps = ["//test/runtimes/proctor/lib"], ) diff --git a/test/runtimes/proctor/lib/BUILD b/test/runtimes/proctor/lib/BUILD new file mode 100644 index 000000000..0c8367dfe --- /dev/null +++ b/test/runtimes/proctor/lib/BUILD @@ -0,0 +1,24 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "lib", + srcs = [ + "go.go", + "java.go", + "lib.go", + "nodejs.go", + "php.go", + "python.go", + ], + visibility = ["//test/runtimes/proctor:__pkg__"], +) + +go_test( + name = "lib_test", + size = "small", + srcs = ["lib_test.go"], + library = ":lib", + deps = ["//pkg/test/testutil"], +) diff --git a/test/runtimes/proctor/go.go b/test/runtimes/proctor/lib/go.go index d0ae844e6..5c48fb60b 100644 --- a/test/runtimes/proctor/go.go +++ b/test/runtimes/proctor/lib/go.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "fmt" @@ -59,7 +59,7 @@ func (goRunner) ListTests() ([]string, error) { } // Go tests on disk. - diskSlice, err := search(goTestDir, goTestRegEx) + diskSlice, err := Search(goTestDir, goTestRegEx) if err != nil { return nil, err } diff --git a/test/runtimes/proctor/java.go b/test/runtimes/proctor/lib/java.go index d456fa681..3105011ff 100644 --- a/test/runtimes/proctor/java.go +++ b/test/runtimes/proctor/lib/java.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "fmt" diff --git a/test/runtimes/proctor/proctor.go b/test/runtimes/proctor/lib/lib.go index 9e0642424..f2ba82498 100644 --- a/test/runtimes/proctor/proctor.go +++ b/test/runtimes/proctor/lib/lib.go @@ -12,20 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Binary proctor runs the test for a particular runtime. It is meant to be -// included in Docker images for all runtime tests. -package main +// Package lib contains proctor functions. +package lib import ( - "flag" "fmt" - "log" "os" "os/exec" "os/signal" "path/filepath" "regexp" - "strings" "syscall" ) @@ -42,66 +38,8 @@ type TestRunner interface { TestCmds(tests []string) []*exec.Cmd } -var ( - runtime = flag.String("runtime", "", "name of runtime") - list = flag.Bool("list", false, "list all available tests") - testNames = flag.String("tests", "", "run a subset of the available tests") - pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") -) - -func main() { - flag.Parse() - - if *pause { - pauseAndReap() - panic("pauseAndReap should never return") - } - - if *runtime == "" { - log.Fatalf("runtime flag must be provided") - } - - tr, err := testRunnerForRuntime(*runtime) - if err != nil { - log.Fatalf("%v", err) - } - - // List tests. - if *list { - tests, err := tr.ListTests() - if err != nil { - log.Fatalf("failed to list tests: %v", err) - } - for _, test := range tests { - fmt.Println(test) - } - return - } - - var tests []string - if *testNames == "" { - // Run every test. - tests, err = tr.ListTests() - if err != nil { - log.Fatalf("failed to get all tests: %v", err) - } - } else { - // Run subset of test. - tests = strings.Split(*testNames, ",") - } - - // Run tests. - cmds := tr.TestCmds(tests) - for _, cmd := range cmds { - cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr - if err := cmd.Run(); err != nil { - log.Fatalf("FAIL: %v", err) - } - } -} - -// testRunnerForRuntime returns a new TestRunner for the given runtime. -func testRunnerForRuntime(runtime string) (TestRunner, error) { +// TestRunnerForRuntime returns a new TestRunner for the given runtime. +func TestRunnerForRuntime(runtime string) (TestRunner, error) { switch runtime { case "go": return goRunner{}, nil @@ -117,8 +55,8 @@ func testRunnerForRuntime(runtime string) (TestRunner, error) { return nil, fmt.Errorf("invalid runtime %q", runtime) } -// pauseAndReap is like init. It runs forever and reaps any children. -func pauseAndReap() { +// PauseAndReap is like init. It runs forever and reaps any children. +func PauseAndReap() { // Get notified of any new children. ch := make(chan os.Signal, 1) signal.Notify(ch, syscall.SIGCHLD) @@ -138,9 +76,9 @@ func pauseAndReap() { } } -// search is a helper function to find tests in the given directory that match +// Search is a helper function to find tests in the given directory that match // the regex. -func search(root string, testFilter *regexp.Regexp) ([]string, error) { +func Search(root string, testFilter *regexp.Regexp) ([]string, error) { var testSlice []string err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { diff --git a/test/runtimes/proctor/proctor_test.go b/test/runtimes/proctor/lib/lib_test.go index 6ef2de085..1193d2e28 100644 --- a/test/runtimes/proctor/proctor_test.go +++ b/test/runtimes/proctor/lib/lib_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "io/ioutil" @@ -47,7 +47,7 @@ func TestSearchEmptyDir(t *testing.T) { var want []string testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) - got, err := search(td, testFilter) + got, err := Search(td, testFilter) if err != nil { t.Errorf("search error: %v", err) } @@ -116,7 +116,7 @@ func TestSearch(t *testing.T) { } testFilter := regexp.MustCompile(`^test-[^-].+\.tc$`) - got, err := search(td, testFilter) + got, err := Search(td, testFilter) if err != nil { t.Errorf("search error: %v", err) } diff --git a/test/runtimes/proctor/nodejs.go b/test/runtimes/proctor/lib/nodejs.go index 23d6edc72..320597aa5 100644 --- a/test/runtimes/proctor/nodejs.go +++ b/test/runtimes/proctor/lib/nodejs.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "os/exec" @@ -32,7 +32,7 @@ var _ TestRunner = nodejsRunner{} // ListTests implements TestRunner.ListTests. func (nodejsRunner) ListTests() ([]string, error) { - testSlice, err := search(nodejsTestDir, nodejsTestRegEx) + testSlice, err := Search(nodejsTestDir, nodejsTestRegEx) if err != nil { return nil, err } @@ -41,6 +41,6 @@ func (nodejsRunner) ListTests() ([]string, error) { // TestCmds implements TestRunner.TestCmds. func (nodejsRunner) TestCmds(tests []string) []*exec.Cmd { - args := append([]string{filepath.Join("tools", "test.py")}, tests...) + args := append([]string{filepath.Join("tools", "test.py"), "--timeout=180"}, tests...) return []*exec.Cmd{exec.Command("/usr/bin/python", args...)} } diff --git a/test/runtimes/proctor/php.go b/test/runtimes/proctor/lib/php.go index 6a83d64e3..b67a60a97 100644 --- a/test/runtimes/proctor/php.go +++ b/test/runtimes/proctor/lib/php.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "os/exec" @@ -29,7 +29,7 @@ var _ TestRunner = phpRunner{} // ListTests implements TestRunner.ListTests. func (phpRunner) ListTests() ([]string, error) { - testSlice, err := search(".", phpTestRegEx) + testSlice, err := Search(".", phpTestRegEx) if err != nil { return nil, err } diff --git a/test/runtimes/proctor/python.go b/test/runtimes/proctor/lib/python.go index 7c598801b..429bfd850 100644 --- a/test/runtimes/proctor/python.go +++ b/test/runtimes/proctor/lib/python.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "fmt" diff --git a/test/runtimes/proctor/main.go b/test/runtimes/proctor/main.go new file mode 100644 index 000000000..e5607ac92 --- /dev/null +++ b/test/runtimes/proctor/main.go @@ -0,0 +1,85 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary proctor runs the test for a particular runtime. It is meant to be +// included in Docker images for all runtime tests. +package main + +import ( + "flag" + "fmt" + "log" + "os" + "strings" + + "gvisor.dev/gvisor/test/runtimes/proctor/lib" +) + +var ( + runtime = flag.String("runtime", "", "name of runtime") + list = flag.Bool("list", false, "list all available tests") + testNames = flag.String("tests", "", "run a subset of the available tests") + pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children") +) + +func main() { + flag.Parse() + + if *pause { + lib.PauseAndReap() + panic("pauseAndReap should never return") + } + + if *runtime == "" { + log.Fatalf("runtime flag must be provided") + } + + tr, err := lib.TestRunnerForRuntime(*runtime) + if err != nil { + log.Fatalf("%v", err) + } + + // List tests. + if *list { + tests, err := tr.ListTests() + if err != nil { + log.Fatalf("failed to list tests: %v", err) + } + for _, test := range tests { + fmt.Println(test) + } + return + } + + var tests []string + if *testNames == "" { + // Run every test. + tests, err = tr.ListTests() + if err != nil { + log.Fatalf("failed to get all tests: %v", err) + } + } else { + // Run subset of test. + tests = strings.Split(*testNames, ",") + } + + // Run tests. + cmds := tr.TestCmds(tests) + for _, cmd := range cmds { + cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr + if err := cmd.Run(); err != nil { + log.Fatalf("FAIL: %v", err) + } + } +} diff --git a/test/runtimes/runner/BUILD b/test/runtimes/runner/BUILD index dc0d5d5b4..70cc01594 100644 --- a/test/runtimes/runner/BUILD +++ b/test/runtimes/runner/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_binary", "go_test") +load("//tools:defs.bzl", "go_binary") package(licenses = ["notice"]) @@ -7,16 +7,5 @@ go_binary( testonly = 1, srcs = ["main.go"], visibility = ["//test/runtimes:__pkg__"], - deps = [ - "//pkg/log", - "//pkg/test/dockerutil", - "//pkg/test/testutil", - ], -) - -go_test( - name = "exclude_test", - size = "small", - srcs = ["exclude_test.go"], - library = ":runner", + deps = ["//test/runtimes/runner/lib"], ) diff --git a/test/runtimes/runner/lib/BUILD b/test/runtimes/runner/lib/BUILD new file mode 100644 index 000000000..d308f41b0 --- /dev/null +++ b/test/runtimes/runner/lib/BUILD @@ -0,0 +1,22 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "lib", + testonly = 1, + srcs = ["lib.go"], + visibility = ["//test/runtimes/runner:__pkg__"], + deps = [ + "//pkg/log", + "//pkg/test/dockerutil", + "//pkg/test/testutil", + ], +) + +go_test( + name = "lib_test", + size = "small", + srcs = ["exclude_test.go"], + library = ":lib", +) diff --git a/test/runtimes/runner/exclude_test.go b/test/runtimes/runner/lib/exclude_test.go index 67c2170c8..f996e895b 100644 --- a/test/runtimes/runner/exclude_test.go +++ b/test/runtimes/runner/lib/exclude_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package main +package lib import ( "flag" @@ -20,6 +20,8 @@ import ( "testing" ) +var excludeFile = flag.String("exclude_file", "", "file to test (standard format)") + func TestMain(m *testing.M) { flag.Parse() os.Exit(m.Run()) @@ -27,7 +29,7 @@ func TestMain(m *testing.M) { // Test that the exclude file parses without error. func TestExcludelist(t *testing.T) { - ex, err := getExcludes() + ex, err := getExcludes(*excludeFile) if err != nil { t.Fatalf("error parsing exclude file: %v", err) } diff --git a/test/runtimes/runner/lib/lib.go b/test/runtimes/runner/lib/lib.go new file mode 100644 index 000000000..78285cb0e --- /dev/null +++ b/test/runtimes/runner/lib/lib.go @@ -0,0 +1,185 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package lib provides utilities for runner. +package lib + +import ( + "context" + "encoding/csv" + "fmt" + "io" + "os" + "sort" + "strings" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/test/dockerutil" + "gvisor.dev/gvisor/pkg/test/testutil" +) + +// RunTests is a helper that is called by main. It exists so that we can run +// defered functions before exiting. It returns an exit code that should be +// passed to os.Exit. +func RunTests(lang, image, excludeFile string, batchSize int, timeout time.Duration) int { + // Get tests to exclude.. + excludes, err := getExcludes(excludeFile) + if err != nil { + fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error()) + return 1 + } + + // Construct the shared docker instance. + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(lang)) + defer d.CleanUp(ctx) + + if err := testutil.TouchShardStatusFile(); err != nil { + fmt.Fprintf(os.Stderr, "error touching status shard file: %v\n", err) + return 1 + } + + // Get a slice of tests to run. This will also start a single Docker + // container that will be used to run each test. The final test will + // stop the Docker container. + tests, err := getTests(ctx, d, lang, image, batchSize, timeout, excludes) + if err != nil { + fmt.Fprintf(os.Stderr, "%s\n", err.Error()) + return 1 + } + + m := testing.MainStart(testDeps{}, tests, nil, nil) + return m.Run() +} + +// getTests executes all tests as table tests. +func getTests(ctx context.Context, d *dockerutil.Container, lang, image string, batchSize int, timeout time.Duration, excludes map[string]struct{}) ([]testing.InternalTest, error) { + // Start the container. + opts := dockerutil.RunOpts{ + Image: fmt.Sprintf("runtimes/%s", image), + } + d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor") + if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil { + return nil, fmt.Errorf("docker run failed: %v", err) + } + + // Get a list of all tests in the image. + list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--list") + if err != nil { + return nil, fmt.Errorf("docker exec failed: %v", err) + } + + // Calculate a subset of tests to run corresponding to the current + // shard. + tests := strings.Fields(list) + sort.Strings(tests) + indices, err := testutil.TestIndicesForShard(len(tests)) + if err != nil { + return nil, fmt.Errorf("TestsForShard() failed: %v", err) + } + + var itests []testing.InternalTest + for i := 0; i < len(indices); i += batchSize { + var tcs []string + end := i + batchSize + if end > len(indices) { + end = len(indices) + } + for _, tc := range indices[i:end] { + // Add test if not excluded. + if _, ok := excludes[tests[tc]]; ok { + log.Infof("Skipping test case %s\n", tests[tc]) + continue + } + tcs = append(tcs, tests[tc]) + } + itests = append(itests, testing.InternalTest{ + Name: strings.Join(tcs, ", "), + F: func(t *testing.T) { + var ( + now = time.Now() + done = make(chan struct{}) + output string + err error + ) + + go func() { + fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n")) + output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", lang, "--tests", strings.Join(tcs, ",")) + close(done) + }() + + select { + case <-done: + if err == nil { + fmt.Printf("PASS: (%v)\n\n", time.Since(now)) + return + } + t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output) + case <-time.After(timeout): + t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output) + } + }, + }) + } + + return itests, nil +} + +// getBlacklist reads the exclude file and returns a set of test names to +// exclude. +func getExcludes(excludeFile string) (map[string]struct{}, error) { + excludes := make(map[string]struct{}) + if excludeFile == "" { + return excludes, nil + } + f, err := os.Open(excludeFile) + if err != nil { + return nil, err + } + defer f.Close() + + r := csv.NewReader(f) + + // First line is header. Skip it. + if _, err := r.Read(); err != nil { + return nil, err + } + + for { + record, err := r.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + excludes[record[0]] = struct{}{} + } + return excludes, nil +} + +// testDeps implements testing.testDeps (an unexported interface), and is +// required to use testing.MainStart. +type testDeps struct{} + +func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil } +func (f testDeps) StartCPUProfile(io.Writer) error { return nil } +func (f testDeps) StopCPUProfile() {} +func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil } +func (f testDeps) ImportPath() string { return "" } +func (f testDeps) StartTestLog(io.Writer) {} +func (f testDeps) StopTestLog() error { return nil } diff --git a/test/runtimes/runner/main.go b/test/runtimes/runner/main.go index 948e7cf9c..ec79a22c2 100644 --- a/test/runtimes/runner/main.go +++ b/test/runtimes/runner/main.go @@ -16,20 +16,12 @@ package main import ( - "context" - "encoding/csv" "flag" "fmt" - "io" "os" - "sort" - "strings" - "testing" "time" - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/test/dockerutil" - "gvisor.dev/gvisor/pkg/test/testutil" + "gvisor.dev/gvisor/test/runtimes/runner/lib" ) var ( @@ -37,169 +29,14 @@ var ( image = flag.String("image", "", "docker image with runtime tests") excludeFile = flag.String("exclude_file", "", "file containing list of tests to exclude, in CSV format with fields: test name, bug id, comment") batchSize = flag.Int("batch", 50, "number of test cases run in one command") + timeout = flag.Duration("timeout", 90*time.Minute, "batch timeout") ) -// Wait time for each test to run. -const timeout = 90 * time.Minute - func main() { flag.Parse() if *lang == "" || *image == "" { fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n") os.Exit(1) } - os.Exit(runTests()) -} - -// runTests is a helper that is called by main. It exists so that we can run -// defered functions before exiting. It returns an exit code that should be -// passed to os.Exit. -func runTests() int { - // Get tests to exclude.. - excludes, err := getExcludes() - if err != nil { - fmt.Fprintf(os.Stderr, "Error getting exclude list: %s\n", err.Error()) - return 1 - } - - // Construct the shared docker instance. - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, testutil.DefaultLogger(*lang)) - defer d.CleanUp(ctx) - - if err := testutil.TouchShardStatusFile(); err != nil { - fmt.Fprintf(os.Stderr, "error touching status shard file: %v\n", err) - return 1 - } - - // Get a slice of tests to run. This will also start a single Docker - // container that will be used to run each test. The final test will - // stop the Docker container. - tests, err := getTests(ctx, d, excludes) - if err != nil { - fmt.Fprintf(os.Stderr, "%s\n", err.Error()) - return 1 - } - - m := testing.MainStart(testDeps{}, tests, nil, nil) - return m.Run() -} - -// getTests executes all tests as table tests. -func getTests(ctx context.Context, d *dockerutil.Container, excludes map[string]struct{}) ([]testing.InternalTest, error) { - // Start the container. - opts := dockerutil.RunOpts{ - Image: fmt.Sprintf("runtimes/%s", *image), - } - d.CopyFiles(&opts, "/proctor", "test/runtimes/proctor/proctor") - if err := d.Spawn(ctx, opts, "/proctor/proctor", "--pause"); err != nil { - return nil, fmt.Errorf("docker run failed: %v", err) - } - - // Get a list of all tests in the image. - list, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--list") - if err != nil { - return nil, fmt.Errorf("docker exec failed: %v", err) - } - - // Calculate a subset of tests to run corresponding to the current - // shard. - tests := strings.Fields(list) - sort.Strings(tests) - indices, err := testutil.TestIndicesForShard(len(tests)) - if err != nil { - return nil, fmt.Errorf("TestsForShard() failed: %v", err) - } - - var itests []testing.InternalTest - for i := 0; i < len(indices); i += *batchSize { - var tcs []string - end := i + *batchSize - if end > len(indices) { - end = len(indices) - } - for _, tc := range indices[i:end] { - // Add test if not excluded. - if _, ok := excludes[tests[tc]]; ok { - log.Infof("Skipping test case %s\n", tests[tc]) - continue - } - tcs = append(tcs, tests[tc]) - } - itests = append(itests, testing.InternalTest{ - Name: strings.Join(tcs, ", "), - F: func(t *testing.T) { - var ( - now = time.Now() - done = make(chan struct{}) - output string - err error - ) - - go func() { - fmt.Printf("RUNNING the following in a batch\n%s\n", strings.Join(tcs, "\n")) - output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/proctor/proctor", "--runtime", *lang, "--tests", strings.Join(tcs, ",")) - close(done) - }() - - select { - case <-done: - if err == nil { - fmt.Printf("PASS: (%v)\n\n", time.Since(now)) - return - } - t.Errorf("FAIL: (%v):\n%s\n", time.Since(now), output) - case <-time.After(timeout): - t.Errorf("TIMEOUT: (%v):\n%s\n", time.Since(now), output) - } - }, - }) - } - - return itests, nil + os.Exit(lib.RunTests(*lang, *image, *excludeFile, *batchSize, *timeout)) } - -// getBlacklist reads the exclude file and returns a set of test names to -// exclude. -func getExcludes() (map[string]struct{}, error) { - excludes := make(map[string]struct{}) - if *excludeFile == "" { - return excludes, nil - } - f, err := os.Open(*excludeFile) - if err != nil { - return nil, err - } - defer f.Close() - - r := csv.NewReader(f) - - // First line is header. Skip it. - if _, err := r.Read(); err != nil { - return nil, err - } - - for { - record, err := r.Read() - if err == io.EOF { - break - } - if err != nil { - return nil, err - } - excludes[record[0]] = struct{}{} - } - return excludes, nil -} - -// testDeps implements testing.testDeps (an unexported interface), and is -// required to use testing.MainStart. -type testDeps struct{} - -func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil } -func (f testDeps) StartCPUProfile(io.Writer) error { return nil } -func (f testDeps) StopCPUProfile() {} -func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil } -func (f testDeps) ImportPath() string { return "" } -func (f testDeps) StartTestLog(io.Writer) {} -func (f testDeps) StopTestLog() error { return nil } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index c19b30b4a..96a775456 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -4,97 +4,84 @@ package(licenses = ["notice"]) syscall_test( test = "//test/syscalls/linux:32bit_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:accept_bind_stream_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:accept_bind_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:access_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:affinity_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:aio_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:alarm_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:arch_prctl_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:bad_test", - vfs2 = "True", ) syscall_test( size = "large", add_overlay = True, test = "//test/syscalls/linux:bind_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:brk_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_capability_test", - vfs2 = "True", ) syscall_test( size = "large", + # Produce too many logs in the debug mode. + debug = False, shard_count = 50, # Takes too long for TSAN. Since this is kind of a stress test that doesn't # involve much concurrency, TSAN's usefulness here is limited anyway. tags = ["nogotsan"], test = "//test/syscalls/linux:socket_stress_test", - vfs2 = "True", + vfs2 = False, ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:chdir_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:chmod_test", - vfs2 = "True", ) syscall_test( @@ -102,116 +89,96 @@ syscall_test( add_overlay = True, test = "//test/syscalls/linux:chown_test", use_tmpfs = True, # chwon tests require gofer to be running as root. - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:chroot_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:clock_getres_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:clock_gettime_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:clock_nanosleep_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:concurrency_test", - vfs2 = "True", ) syscall_test( add_uds_tree = True, test = "//test/syscalls/linux:connect_external_test", use_tmpfs = True, - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:creat_test", - vfs2 = "True", ) syscall_test( fuse = "True", test = "//test/syscalls/linux:dev_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:dup_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:epoll_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:eventfd_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:exceptions_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:exec_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:exec_binary_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:exit_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fadvise64_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fallocate_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fault_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fchdir_test", - vfs2 = "True", ) syscall_test( @@ -223,204 +190,176 @@ syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:flock_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fork_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fpsig_fork_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fpsig_nested_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fsync_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:futex_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getcpu_host_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getcpu_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:getdents_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getrandom_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getrusage_test", - vfs2 = "True", ) syscall_test( size = "medium", - add_overlay = False, # TODO(gvisor.dev/issue/317): enable when fixed. + add_overlay = True, test = "//test/syscalls/linux:inotify_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:ioctl_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:iptables_test", - vfs2 = "True", +) + +syscall_test( + test = "//test/syscalls/linux:ip6tables_test", ) syscall_test( size = "large", shard_count = 5, test = "//test/syscalls/linux:itimer_test", - vfs2 = "True", +) + +syscall_test( + test = "//test/syscalls/linux:kcov_test", ) syscall_test( test = "//test/syscalls/linux:kill_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:link_test", use_tmpfs = True, # gofer needs CAP_DAC_READ_SEARCH to use AT_EMPTY_PATH with linkat(2) - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:lseek_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:madvise_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:memory_accounting_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:mempolicy_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:mincore_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:mkdir_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:mknod_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:mmap_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:mount_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:mremap_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:msync_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:munmap_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:network_namespace_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:open_create_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:open_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:packet_socket_raw_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:packet_socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:partial_bad_buffer_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:pause_test", - vfs2 = "True", ) syscall_test( @@ -428,7 +367,6 @@ syscall_test( # Takes too long under gotsan to run. tags = ["nogotsan"], test = "//test/syscalls/linux:ping_socket_test", - vfs2 = "True", ) syscall_test( @@ -436,229 +374,188 @@ syscall_test( add_overlay = True, shard_count = 5, test = "//test/syscalls/linux:pipe_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:poll_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:ppoll_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:prctl_setuid_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:prctl_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:pread64_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:preadv_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:preadv2_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:priority_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:proc_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_pid_oomscore_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_pid_smaps_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_pid_uid_gid_map_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:pselect_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:ptrace_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:pty_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:pty_root_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:pwritev2_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:pwrite64_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:raw_socket_hdrincl_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:raw_socket_icmp_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:raw_socket_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:read_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:readahead_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:readv_socket_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:readv_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:rename_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:rlimits_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:rseq_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:rtsignal_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:signalfd_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sched_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sched_yield_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:seccomp_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:select_test", - vfs2 = "True", ) syscall_test( shard_count = 20, test = "//test/syscalls/linux:semaphore_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sendfile_socket_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sendfile_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:splice_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sigaction_test", - vfs2 = "True", ) # TODO(b/119826902): Enable once the test passes in runsc. @@ -666,62 +563,52 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:sigiret_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sigprocmask_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:sigstop_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sigtimedwait_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:shm_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_abstract_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_abstract_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_domain_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_domain_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:socket_filesystem_non_blocking_test", - vfs2 = "True", ) syscall_test( @@ -729,14 +616,12 @@ syscall_test( add_overlay = True, shard_count = 50, test = "//test/syscalls/linux:socket_filesystem_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_inet_loopback_test", - vfs2 = "True", ) syscall_test( @@ -745,122 +630,116 @@ syscall_test( # Takes too long for TSAN. Creates a lot of TCP sockets. tags = ["nogotsan"], test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_ip_tcp_loopback_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_ip_udp_loopback_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_ip_udp_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test", - vfs2 = "True", +) + +syscall_test( + size = "medium", + # Takes too long under gotsan to run. + tags = ["nogotsan"], + test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_nogotsan_test", +) + +syscall_test( + test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_netlink_test", +) + +syscall_test( + test = "//test/syscalls/linux:socket_ipv6_udp_unbound_loopback_netlink_test", ) syscall_test( test = "//test/syscalls/linux:socket_ip_unbound_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netdevice_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netlink_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netlink_route_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netlink_uevent_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_blocking_local_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_blocking_ip_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_non_stream_blocking_local_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_non_stream_blocking_udp_test", - vfs2 = "True", ) syscall_test( size = "large", test = "//test/syscalls/linux:socket_stream_blocking_local_test", - vfs2 = "True", ) syscall_test( size = "large", test = "//test/syscalls/linux:socket_stream_blocking_tcp_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_stream_local_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_stream_nonblock_local_test", - vfs2 = "True", ) syscall_test( @@ -868,13 +747,11 @@ syscall_test( size = "enormous", shard_count = 5, test = "//test/syscalls/linux:socket_unix_dgram_local_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_dgram_non_blocking_test", - vfs2 = "True", ) syscall_test( @@ -882,7 +759,6 @@ syscall_test( add_overlay = True, shard_count = 50, test = "//test/syscalls/linux:socket_unix_pair_test", - vfs2 = "True", ) syscall_test( @@ -890,134 +766,112 @@ syscall_test( size = "enormous", shard_count = 5, test = "//test/syscalls/linux:socket_unix_seqpacket_local_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_stream_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_unbound_abstract_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_unbound_dgram_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_unbound_filesystem_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 10, test = "//test/syscalls/linux:socket_unix_unbound_seqpacket_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_unix_unbound_stream_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:statfs_test", - vfs2 = "True", + use_tmpfs = True, # Test specifically relies on TEST_TMPDIR to be tmpfs. ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:stat_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:stat_times_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sticky_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:symlink_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sync_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sync_file_range_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sysinfo_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:syslog_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sysret_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 10, test = "//test/syscalls/linux:tcp_socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:tgkill_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:timerfd_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:timers_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:time_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:tkill_test", - vfs2 = "True", ) syscall_test( @@ -1027,18 +881,15 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:tuntap_test", - vfs2 = "True", ) syscall_test( add_hostinet = True, test = "//test/syscalls/linux:tuntap_hostinet_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:udp_bind_test", - vfs2 = "True", ) syscall_test( @@ -1046,80 +897,65 @@ syscall_test( add_hostinet = True, shard_count = 10, test = "//test/syscalls/linux:udp_socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:uidgid_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:uname_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:unlink_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:unshare_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:utimes_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:vdso_clock_gettime_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:vdso_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:vsyscall_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:vfork_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:wait_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:write_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_unix_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_tcp_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_udp_test", - vfs2 = "True", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 66a31cd28..d9dbe2267 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -22,6 +22,7 @@ exports_files( "socket_ipv4_tcp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_external_networking_test.cc", "socket_ipv4_udp_unbound_loopback.cc", + "socket_ipv4_udp_unbound_loopback_nogotsan.cc", "tcp_socket.cc", "udp_bind.cc", "udp_socket.cc", @@ -1030,6 +1031,24 @@ cc_binary( ) cc_binary( + name = "ip6tables_test", + testonly = 1, + srcs = [ + "ip6tables.cc", + ], + linkstatic = 1, + deps = [ + ":iptables_types", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "itimer_test", testonly = 1, srcs = ["itimer.cc"], @@ -1050,6 +1069,20 @@ cc_binary( ) cc_binary( + name = "kcov_test", + testonly = 1, + srcs = ["kcov.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:file_descriptor", + gtest, + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "kill_test", testonly = 1, srcs = ["kill.cc"], @@ -1634,12 +1667,14 @@ cc_binary( "//test/util:cleanup", "//test/util:file_descriptor", "//test/util:fs_util", + "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", gtest, "//test/util:memory_util", "//test/util:posix_error", + "//test/util:proc_util", "//test/util:temp_path", "//test/util:test_util", "//test/util:thread_util", @@ -1862,6 +1897,7 @@ cc_binary( srcs = ["readahead.cc"], linkstatic = 1, deps = [ + ":socket_test_util", "//test/util:file_descriptor", gtest, "//test/util:temp_path", @@ -1953,6 +1989,7 @@ cc_binary( gtest, "//test/util:logging", "//test/util:multiprocess_util", + "//test/util:posix_error", "//test/util:test_main", "//test/util:test_util", ], @@ -2377,12 +2414,50 @@ cc_library( ":socket_test_util", "@com_google_absl//absl/memory", gtest, + "//test/util:posix_error", "//test/util:test_util", ], alwayslink = 1, ) cc_library( + name = "socket_ipv4_udp_unbound_netlink_test_cases", + testonly = 1, + srcs = [ + "socket_ipv4_udp_unbound_netlink.cc", + ], + hdrs = [ + "socket_ipv4_udp_unbound_netlink.h", + ], + deps = [ + ":socket_netlink_route_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:cleanup", + gtest, + ], + alwayslink = 1, +) + +cc_library( + name = "socket_ipv6_udp_unbound_netlink_test_cases", + testonly = 1, + srcs = [ + "socket_ipv6_udp_unbound_netlink.cc", + ], + hdrs = [ + "socket_ipv6_udp_unbound_netlink.h", + ], + deps = [ + ":socket_netlink_route_util", + ":socket_test_util", + "//test/util:capability_util", + gtest, + ], + alwayslink = 1, +) + +cc_library( name = "socket_ipv4_udp_unbound_external_networking_test_cases", testonly = 1, srcs = [ @@ -2719,6 +2794,55 @@ cc_binary( ) cc_binary( + name = "socket_ipv4_udp_unbound_loopback_nogotsan_test", + testonly = 1, + srcs = [ + "socket_ipv4_udp_unbound_loopback_nogotsan.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_test_util", + gtest, + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/memory", + ], +) + +cc_binary( + name = "socket_ipv4_udp_unbound_loopback_netlink_test", + testonly = 1, + srcs = [ + "socket_ipv4_udp_unbound_loopback_netlink.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_ipv4_udp_unbound_netlink_test_cases", + ":socket_test_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( + name = "socket_ipv6_udp_unbound_loopback_netlink_test", + testonly = 1, + srcs = [ + "socket_ipv6_udp_unbound_loopback_netlink.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_ipv6_udp_unbound_netlink_test_cases", + ":socket_test_util", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "socket_ip_unbound_test", testonly = 1, srcs = [ @@ -3547,15 +3671,12 @@ cc_binary( ], ) -cc_library( - name = "udp_socket_test_cases", +cc_binary( + name = "udp_socket_test", testonly = 1, - srcs = [ - "udp_socket_errqueue_test_case.cc", - "udp_socket_test_cases.cc", - ], - hdrs = ["udp_socket_test_cases.h"], + srcs = ["udp_socket.cc"], defines = select_system(), + linkstatic = 1, deps = [ ":ip_socket_test_util", ":socket_test_util", @@ -3570,17 +3691,6 @@ cc_library( "//test/util:test_util", "//test/util:thread_util", ], - alwayslink = 1, -) - -cc_binary( - name = "udp_socket_test", - testonly = 1, - srcs = ["udp_socket.cc"], - linkstatic = 1, - deps = [ - ":udp_socket_test_cases", - ], ) cc_binary( diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc index 18d2f22c1..3797fd4c8 100644 --- a/test/syscalls/linux/exec_binary.cc +++ b/test/syscalls/linux/exec_binary.cc @@ -1042,6 +1042,13 @@ class ElfInterpreterStaticTest // Statically linked ELF with a statically linked ELF interpreter. TEST_P(ElfInterpreterStaticTest, Test) { + // TODO(gvisor.dev/issue/3721): Test has been observed to segfault on 5.X + // kernels. + if (!IsRunningOnGvisor()) { + auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion()); + SKIP_IF(version.major > 4); + } + const std::vector<char> segment_suffix = std::get<0>(GetParam()); const int expected_errno = std::get<1>(GetParam()); diff --git a/test/syscalls/linux/fallocate.cc b/test/syscalls/linux/fallocate.cc index cabc2b751..edd23e063 100644 --- a/test/syscalls/linux/fallocate.cc +++ b/test/syscalls/linux/fallocate.cc @@ -179,6 +179,12 @@ TEST_F(AllocateTest, FallocateOtherFDs) { auto sock0 = FileDescriptor(socks[0]); auto sock1 = FileDescriptor(socks[1]); EXPECT_THAT(fallocate(sock0.get(), 0, 0, 10), SyscallFailsWithErrno(ENODEV)); + + int pipefds[2]; + ASSERT_THAT(pipe(pipefds), SyscallSucceeds()); + EXPECT_THAT(fallocate(pipefds[1], 0, 0, 10), SyscallFailsWithErrno(ESPIPE)); + close(pipefds[0]); + close(pipefds[1]); } } // namespace diff --git a/test/syscalls/linux/flock.cc b/test/syscalls/linux/flock.cc index 638a93979..549141cbb 100644 --- a/test/syscalls/linux/flock.cc +++ b/test/syscalls/linux/flock.cc @@ -185,7 +185,7 @@ TEST_F(FlockTest, TestMultipleHolderSharedExclusive) { ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); } -TEST_F(FlockTest, TestSharedLockFailExclusiveHolder) { +TEST_F(FlockTest, TestSharedLockFailExclusiveHolderNonblocking) { // This test will verify that a shared lock is denied while // someone holds an exclusive lock. ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), @@ -203,7 +203,33 @@ TEST_F(FlockTest, TestSharedLockFailExclusiveHolder) { ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); } -TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolder) { +void trivial_handler(int signum) {} + +TEST_F(FlockTest, TestSharedLockFailExclusiveHolderBlocking_NoRandomSave) { + const DisableSave ds; // Timing-related. + + // This test will verify that a shared lock is denied while + // someone holds an exclusive lock. + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), + SyscallSucceedsWithValue(0)); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); + + // Register a signal handler for SIGALRM and set an alarm that will go off + // while blocking in the subsequent flock() call. This will interrupt flock() + // and cause it to return EINTR. + struct sigaction act = {}; + act.sa_handler = trivial_handler; + ASSERT_THAT(sigaction(SIGALRM, &act, NULL), SyscallSucceeds()); + ASSERT_THAT(ualarm(10000, 0), SyscallSucceeds()); + ASSERT_THAT(flock(fd.get(), LOCK_SH), SyscallFailsWithErrno(EINTR)); + + // Unlock + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); +} + +TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderNonblocking) { // This test will verify that an exclusive lock is denied while // someone already holds an exclsuive lock. ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), @@ -221,6 +247,30 @@ TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolder) { ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); } +TEST_F(FlockTest, TestExclusiveLockFailExclusiveHolderBlocking_NoRandomSave) { + const DisableSave ds; // Timing-related. + + // This test will verify that an exclusive lock is denied while + // someone already holds an exclsuive lock. + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_EX | LOCK_NB), + SyscallSucceedsWithValue(0)); + + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(test_file_name_, O_RDWR)); + + // Register a signal handler for SIGALRM and set an alarm that will go off + // while blocking in the subsequent flock() call. This will interrupt flock() + // and cause it to return EINTR. + struct sigaction act = {}; + act.sa_handler = trivial_handler; + ASSERT_THAT(sigaction(SIGALRM, &act, NULL), SyscallSucceeds()); + ASSERT_THAT(ualarm(10000, 0), SyscallSucceeds()); + ASSERT_THAT(flock(fd.get(), LOCK_EX), SyscallFailsWithErrno(EINTR)); + + // Unlock + ASSERT_THAT(flock(test_file_fd_.get(), LOCK_UN), SyscallSucceedsWithValue(0)); +} + TEST_F(FlockTest, TestMultipleHolderSharedExclusiveUpgrade) { // This test will verify that we cannot obtain an exclusive lock while // a shared lock is held by another descriptor, then verify that an upgrade diff --git a/test/syscalls/linux/inotify.cc b/test/syscalls/linux/inotify.cc index 220874aeb..e4392a450 100644 --- a/test/syscalls/linux/inotify.cc +++ b/test/syscalls/linux/inotify.cc @@ -18,6 +18,7 @@ #include <sys/epoll.h> #include <sys/inotify.h> #include <sys/ioctl.h> +#include <sys/sendfile.h> #include <sys/time.h> #include <sys/xattr.h> @@ -464,7 +465,9 @@ TEST(Inotify, ConcurrentFileDeletionAndWatchRemoval) { for (int i = 0; i < 100; ++i) { FileDescriptor file_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(filename, O_CREAT, S_IRUSR | S_IWUSR)); - file_fd.reset(); // Close before unlinking (although save is disabled). + // Close before unlinking (although S/R is disabled). Some filesystems + // cannot restore an open fd on an unlinked file. + file_fd.reset(); EXPECT_THAT(unlink(filename.c_str()), SyscallSucceeds()); } }; @@ -1255,10 +1258,7 @@ TEST(Inotify, MknodGeneratesCreateEvent) { InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); const TempPath file1(root.path() + "/file1"); - const int rc = mknod(file1.path().c_str(), S_IFREG, 0); - // mknod(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0); - ASSERT_THAT(rc, SyscallSucceeds()); + ASSERT_THAT(mknod(file1.path().c_str(), S_IFREG, 0), SyscallSucceeds()); const std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); @@ -1288,6 +1288,10 @@ TEST(Inotify, SymlinkGeneratesCreateEvent) { } TEST(Inotify, LinkGeneratesAttribAndCreateEvents) { + // Inotify does not work properly with hard links in gofer and overlay fs. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); + const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath file1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); @@ -1300,11 +1304,8 @@ TEST(Inotify, LinkGeneratesAttribAndCreateEvents) { const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - const int rc = link(file1.path().c_str(), link1.path().c_str()); - // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0 && - (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); + ASSERT_THAT(link(file1.path().c_str(), link1.path().c_str()), + SyscallSucceeds()); const std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); @@ -1333,66 +1334,70 @@ TEST(Inotify, UtimesGeneratesAttribEvent) { } TEST(Inotify, HardlinksReuseSameWatch) { + // Inotify does not work properly with hard links in gofer and overlay fs. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); + const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); - TempPath file1 = + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(root.path())); - TempPath link1(root.path() + "/link1"); - const int rc = link(file1.path().c_str(), link1.path().c_str()); - // link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0 && - (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); + + TempPath file2(root.path() + "/file2"); + ASSERT_THAT(link(file.path().c_str(), file2.path().c_str()), + SyscallSucceeds()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); const int root_wd = ASSERT_NO_ERRNO_AND_VALUE( InotifyAddWatch(fd.get(), root.path(), IN_ALL_EVENTS)); - const int file1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), file1.path(), IN_ALL_EVENTS)); - const int link1_wd = ASSERT_NO_ERRNO_AND_VALUE( - InotifyAddWatch(fd.get(), link1.path(), IN_ALL_EVENTS)); + const int file_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), file.path(), IN_ALL_EVENTS)); + const int file2_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(fd.get(), file2.path(), IN_ALL_EVENTS)); // The watch descriptors for watches on different links to the same file // should be identical. - EXPECT_NE(root_wd, file1_wd); - EXPECT_EQ(file1_wd, link1_wd); + EXPECT_NE(root_wd, file_wd); + EXPECT_EQ(file_wd, file2_wd); - FileDescriptor file1_fd = - ASSERT_NO_ERRNO_AND_VALUE(Open(file1.path(), O_WRONLY)); + FileDescriptor file_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); ASSERT_THAT(events, - AreUnordered({Event(IN_OPEN, root_wd, Basename(file1.path())), - Event(IN_OPEN, file1_wd)})); + AreUnordered({Event(IN_OPEN, root_wd, Basename(file.path())), + Event(IN_OPEN, file_wd)})); // For the next step, we want to ensure all fds to the file are closed. Do // that now and drain the resulting events. - file1_fd.reset(); + file_fd.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, - Are({Event(IN_CLOSE_WRITE, root_wd, Basename(file1.path())), - Event(IN_CLOSE_WRITE, file1_wd)})); + ASSERT_THAT( + events, + AreUnordered({Event(IN_CLOSE_WRITE, root_wd, Basename(file.path())), + Event(IN_CLOSE_WRITE, file_wd)})); // Try removing the link and let's see what events show up. Note that after // this, we still have a link to the file so the watch shouldn't be // automatically removed. - const std::string link1_path = link1.reset(); + const std::string file2_path = file2.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); - ASSERT_THAT(events, Are({Event(IN_ATTRIB, link1_wd), - Event(IN_DELETE, root_wd, Basename(link1_path))})); + ASSERT_THAT(events, + AreUnordered({Event(IN_ATTRIB, file2_wd), + Event(IN_DELETE, root_wd, Basename(file2_path))})); // Now remove the other link. Since this is the last link to the file, the // watch should be automatically removed. - const std::string file1_path = file1.reset(); + const std::string file_path = file.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(fd.get())); ASSERT_THAT( events, - AreUnordered({Event(IN_ATTRIB, file1_wd), Event(IN_DELETE_SELF, file1_wd), - Event(IN_IGNORED, file1_wd), - Event(IN_DELETE, root_wd, Basename(file1_path))})); + AreUnordered({Event(IN_ATTRIB, file_wd), Event(IN_DELETE_SELF, file_wd), + Event(IN_IGNORED, file_wd), + Event(IN_DELETE, root_wd, Basename(file_path))})); } // Calling mkdir within "parent/child" should generate an event for child, but @@ -1681,6 +1686,60 @@ TEST(Inotify, EpollNoDeadlock) { } } +TEST(Inotify, Fallocate) { + const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR)); + + const FileDescriptor inotify_fd = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(inotify_fd.get(), file.path(), IN_ALL_EVENTS)); + + // Do an arbitrary modification with fallocate. + ASSERT_THAT(RetryEINTR(fallocate)(fd.get(), 0, 0, 123), SyscallSucceeds()); + std::vector<Event> events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); + EXPECT_THAT(events, Are({Event(IN_MODIFY, wd)})); +} + +TEST(Inotify, Sendfile) { + SKIP_IF(IsRunningWithVFS1()); + + const TempPath root = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(root.path(), "x", 0644)); + const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor in = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + const FileDescriptor out = + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY)); + + // Create separate inotify instances for the in and out fds. If both watches + // were on the same instance, we would have discrepancies between Linux and + // gVisor (order of events, duplicate events), which is not that important + // since inotify is asynchronous anyway. + const FileDescriptor in_inotify = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const FileDescriptor out_inotify = + ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); + const int in_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(in_inotify.get(), in_file.path(), IN_ALL_EVENTS)); + const int out_wd = ASSERT_NO_ERRNO_AND_VALUE( + InotifyAddWatch(out_inotify.get(), out_file.path(), IN_ALL_EVENTS)); + + ASSERT_THAT(sendfile(out.get(), in.get(), /*offset=*/nullptr, 1), + SyscallSucceeds()); + + // Expect a single access event and a single modify event. + std::vector<Event> in_events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(in_inotify.get())); + std::vector<Event> out_events = + ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(out_inotify.get())); + EXPECT_THAT(in_events, Are({Event(IN_ACCESS, in_wd)})); + EXPECT_THAT(out_events, Are({Event(IN_MODIFY, out_wd)})); +} + // On Linux, inotify behavior is not very consistent with splice(2). We try our // best to emulate Linux for very basic calls to splice. TEST(Inotify, SpliceOnWatchTarget) { @@ -1749,17 +1808,17 @@ TEST(Inotify, SpliceOnInotifyFD) { // Watches on a parent should not be triggered by actions on a hard link to one // of its children that has a different parent. TEST(Inotify, LinkOnOtherParent) { + // Inotify does not work properly with hard links in gofer and overlay fs. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); + const TempPath dir1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath dir2 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir1.path())); std::string link_path = NewTempAbsPathInDir(dir2.path()); - const int rc = link(file.path().c_str(), link_path.c_str()); - // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0 && - (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); + ASSERT_THAT(link(file.path().c_str(), link_path.c_str()), SyscallSucceeds()); const FileDescriptor inotify_fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); @@ -1768,13 +1827,18 @@ TEST(Inotify, LinkOnOtherParent) { // Perform various actions on the link outside of dir1, which should trigger // no inotify events. - const FileDescriptor fd = + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(link_path.c_str(), O_RDWR)); int val = 0; ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds()); ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds()); ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds()); + + // Close before unlinking; some filesystems cannot restore an open fd on an + // unlinked file. + fd.reset(); ASSERT_THAT(unlink(link_path.c_str()), SyscallSucceeds()); + const std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); EXPECT_THAT(events, Are({})); @@ -1879,14 +1943,22 @@ TEST(Inotify, IncludeUnlinkedFile_NoRandomSave) { ASSERT_THAT(write(fd.get(), &val, sizeof(val)), SyscallSucceeds()); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({ - Event(IN_ATTRIB, file_wd), - Event(IN_DELETE, dir_wd, Basename(file.path())), - Event(IN_ACCESS, dir_wd, Basename(file.path())), - Event(IN_ACCESS, file_wd), - Event(IN_MODIFY, dir_wd, Basename(file.path())), - Event(IN_MODIFY, file_wd), - })); + EXPECT_THAT(events, AnyOf(Are({ + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_ACCESS, dir_wd, Basename(file.path())), + Event(IN_ACCESS, file_wd), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }), + Are({ + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + Event(IN_ACCESS, dir_wd, Basename(file.path())), + Event(IN_ACCESS, file_wd), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }))); fd.reset(); events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); @@ -1929,7 +2001,7 @@ TEST(Inotify, ExcludeUnlink_NoRandomSave) { ASSERT_THAT(read(fd.get(), &val, sizeof(val)), SyscallSucceeds()); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({ + EXPECT_THAT(events, AreUnordered({ Event(IN_ATTRIB, file_wd), Event(IN_DELETE, dir_wd, Basename(file.path())), })); @@ -1990,21 +2062,21 @@ TEST(Inotify, ExcludeUnlinkDirectory_NoRandomSave) { // We need to disable S/R because there are filesystems where we cannot re-open // fds to an unlinked file across S/R, e.g. gofer-backed filesytems. TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) { - const DisableSave ds; + // Inotify does not work properly with hard links in gofer and overlay fs. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); // TODO(gvisor.dev/issue/1624): This test fails on VFS1. SKIP_IF(IsRunningWithVFS1()); + const DisableSave ds; + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); std::string path1 = file.path(); std::string path2 = NewTempAbsPathInDir(dir.path()); + ASSERT_THAT(link(path1.c_str(), path2.c_str()), SyscallSucceeds()); - const int rc = link(path1.c_str(), path2.c_str()); - // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(IsRunningOnGvisor() && rc != 0 && - (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); const FileDescriptor fd1 = ASSERT_NO_ERRNO_AND_VALUE(Open(path1.c_str(), O_RDWR)); const FileDescriptor fd2 = @@ -2036,6 +2108,15 @@ TEST(Inotify, ExcludeUnlinkMultipleChildren_NoRandomSave) { // We need to disable S/R because there are filesystems where we cannot re-open // fds to an unlinked file across S/R, e.g. gofer-backed filesytems. TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { + // TODO(gvisor.dev/issue/1624): Fails on VFS1. + SKIP_IF(IsRunningWithVFS1()); + + // NOTE(gvisor.dev/issue/3654): In the gofer filesystem, we do not allow + // setting attributes through an fd if the file at the open path has been + // deleted. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(GetAbsoluteTestTmpdir()))); + const DisableSave ds; const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); @@ -2045,18 +2126,6 @@ TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path().c_str(), O_RDWR)); - // NOTE(b/157163751): Create another link before unlinking. This is needed for - // the gofer filesystem in gVisor, where open fds will not work once the link - // count hits zero. In VFS2, we end up skipping the gofer test anyway, because - // hard links are not supported for gofer fs. - if (IsRunningOnGvisor()) { - std::string link_path = NewTempAbsPath(); - const int rc = link(file.path().c_str(), link_path.c_str()); - // NOTE(b/34861058): link(2) is only supported on tmpfs in the sandbox. - SKIP_IF(rc != 0 && (errno == EPERM || errno == ENOENT)); - ASSERT_THAT(rc, SyscallSucceeds()); - } - const FileDescriptor inotify_fd = ASSERT_NO_ERRNO_AND_VALUE(InotifyInit1(IN_NONBLOCK)); const int dir_wd = ASSERT_NO_ERRNO_AND_VALUE(InotifyAddWatch( @@ -2072,12 +2141,18 @@ TEST(Inotify, ExcludeUnlinkInodeEvents_NoRandomSave) { ASSERT_THAT(ftruncate(fd.get(), 12345), SyscallSucceeds()); std::vector<Event> events = ASSERT_NO_ERRNO_AND_VALUE(DrainEvents(inotify_fd.get())); - EXPECT_THAT(events, Are({ - Event(IN_ATTRIB, file_wd), - Event(IN_DELETE, dir_wd, Basename(file.path())), - Event(IN_MODIFY, dir_wd, Basename(file.path())), - Event(IN_MODIFY, file_wd), - })); + EXPECT_THAT(events, AnyOf(Are({ + Event(IN_ATTRIB, file_wd), + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }), + Are({ + Event(IN_DELETE, dir_wd, Basename(file.path())), + Event(IN_ATTRIB, file_wd), + Event(IN_MODIFY, dir_wd, Basename(file.path())), + Event(IN_MODIFY, file_wd), + }))); const struct timeval times[2] = {{1, 0}, {2, 0}}; ASSERT_THAT(futimes(fd.get(), times), SyscallSucceeds()); diff --git a/test/syscalls/linux/ip6tables.cc b/test/syscalls/linux/ip6tables.cc new file mode 100644 index 000000000..f08f2dc55 --- /dev/null +++ b/test/syscalls/linux/ip6tables.cc @@ -0,0 +1,239 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/capability.h> +#include <sys/socket.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/iptables.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +constexpr char kNatTablename[] = "nat"; +constexpr char kErrorTarget[] = "ERROR"; +constexpr size_t kEmptyStandardEntrySize = + sizeof(struct ip6t_entry) + sizeof(struct xt_standard_target); +constexpr size_t kEmptyErrorEntrySize = + sizeof(struct ip6t_entry) + sizeof(struct xt_error_target); + +TEST(IP6TablesBasic, FailSockoptNonRaw) { + // Even if the user has CAP_NET_RAW, they shouldn't be able to use the + // ip6tables sockopts with a non-raw socket. + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + EXPECT_THAT(getsockopt(sock, SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(ENOPROTOOPT)); + + EXPECT_THAT(close(sock), SyscallSucceeds()); +} + +TEST(IP6TablesBasic, GetInfoErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info) - 1; + EXPECT_THAT(getsockopt(sock, SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IP6TablesBasic, GetEntriesErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ip6t_get_entries entries = {}; + socklen_t entries_size = sizeof(struct ip6t_get_entries) - 1; + snprintf(entries.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + EXPECT_THAT( + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_ENTRIES, &entries, &entries_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IP6TablesBasic, GetRevision) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW), + SyscallSucceeds()); + + struct xt_get_revision rev = { + .name = "REDIRECT", + .revision = 0, + }; + socklen_t rev_len = sizeof(rev); + + // TODO(gvisor.dev/issue/3549): IPv6 redirect support. + const int retval = + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_REVISION_TARGET, &rev, &rev_len); + if (IsRunningOnGvisor()) { + EXPECT_THAT(retval, SyscallFailsWithErrno(ENOPROTOOPT)); + return; + } + + // Revision 0 exists. + EXPECT_THAT(retval, SyscallSucceeds()); + EXPECT_EQ(rev.revision, 0); + + // Revisions > 0 don't exist. + rev.revision = 1; + EXPECT_THAT( + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallFailsWithErrno(EPROTONOSUPPORT)); +} + +// This tests the initial state of a machine with empty ip6tables via +// getsockopt(IP6T_SO_GET_INFO). We don't have a guarantee that the iptables are +// empty when running in native, but we can test that gVisor has the same +// initial state that a newly-booted Linux machine would have. +TEST(IP6TablesTest, InitialInfo) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)); + + // Get info via sockopt. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + ASSERT_THAT( + getsockopt(sock.get(), SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallSucceeds()); + + // The nat table supports PREROUTING, and OUTPUT. + unsigned int valid_hooks = + (1 << NF_IP6_PRE_ROUTING) | (1 << NF_IP6_LOCAL_OUT) | + (1 << NF_IP6_POST_ROUTING) | (1 << NF_IP6_LOCAL_IN); + EXPECT_EQ(info.valid_hooks, valid_hooks); + + // Each chain consists of an empty entry with a standard target.. + EXPECT_EQ(info.hook_entry[NF_IP6_PRE_ROUTING], 0); + EXPECT_EQ(info.hook_entry[NF_IP6_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.hook_entry[NF_IP6_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.hook_entry[NF_IP6_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // The underflow points are the same as the entry points. + EXPECT_EQ(info.underflow[NF_IP6_PRE_ROUTING], 0); + EXPECT_EQ(info.underflow[NF_IP6_LOCAL_IN], kEmptyStandardEntrySize); + EXPECT_EQ(info.underflow[NF_IP6_LOCAL_OUT], kEmptyStandardEntrySize * 2); + EXPECT_EQ(info.underflow[NF_IP6_POST_ROUTING], kEmptyStandardEntrySize * 3); + + // One entry for each chain, plus an error entry at the end. + EXPECT_EQ(info.num_entries, 5); + + EXPECT_EQ(info.size, 4 * kEmptyStandardEntrySize + kEmptyErrorEntrySize); + EXPECT_EQ(strcmp(info.name, kNatTablename), 0); +} + +// This tests the initial state of a machine with empty ip6tables via +// getsockopt(IP6T_SO_GET_ENTRIES). We don't have a guarantee that the iptables +// are empty when running in native, but we can test that gVisor has the same +// initial state that a newly-booted Linux machine would have. +TEST(IP6TablesTest, InitialEntries) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET6, SOCK_RAW, IPPROTO_RAW)); + + // Get info via sockopt. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + ASSERT_THAT( + getsockopt(sock.get(), SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallSucceeds()); + + // Use info to get entries. + socklen_t entries_size = sizeof(struct ip6t_get_entries) + info.size; + struct ip6t_get_entries* entries = + static_cast<struct ip6t_get_entries*>(malloc(entries_size)); + snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + entries->size = info.size; + ASSERT_THAT(getsockopt(sock.get(), SOL_IPV6, IP6T_SO_GET_ENTRIES, entries, + &entries_size), + SyscallSucceeds()); + + // Verify the name and size. + ASSERT_EQ(info.size, entries->size); + ASSERT_EQ(strcmp(entries->name, kNatTablename), 0); + + // Verify that the entrytable is 4 entries with accept targets and no matches + // followed by a single error target. + size_t entry_offset = 0; + while (entry_offset < entries->size) { + struct ip6t_entry* entry = reinterpret_cast<struct ip6t_entry*>( + reinterpret_cast<char*>(entries->entrytable) + entry_offset); + + // ipv6 should be zeroed. + struct ip6t_ip6 zeroed = {}; + ASSERT_EQ(memcmp(static_cast<void*>(&zeroed), + static_cast<void*>(&entry->ipv6), sizeof(zeroed)), + 0); + + // target_offset should be zero. + EXPECT_EQ(entry->target_offset, sizeof(ip6t_entry)); + + if (entry_offset < kEmptyStandardEntrySize * 4) { + // The first 4 entries are standard targets + struct xt_standard_target* target = + reinterpret_cast<struct xt_standard_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyStandardEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, ""), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + // This is what's returned for an accept verdict. I don't know why. + EXPECT_EQ(target->verdict, -NF_ACCEPT - 1); + } else { + // The last entry is an error target + struct xt_error_target* target = + reinterpret_cast<struct xt_error_target*>(entry->elems); + EXPECT_EQ(entry->next_offset, kEmptyErrorEntrySize); + EXPECT_EQ(target->target.u.user.target_size, sizeof(*target)); + EXPECT_EQ(strcmp(target->target.u.user.name, kErrorTarget), 0); + EXPECT_EQ(target->target.u.user.revision, 0); + EXPECT_EQ(strcmp(target->errorname, kErrorTarget), 0); + } + + entry_offset += entry->next_offset; + break; + } + + free(entries); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/iptables.cc b/test/syscalls/linux/iptables.cc index b8e4ece64..7ee10bbde 100644 --- a/test/syscalls/linux/iptables.cc +++ b/test/syscalls/linux/iptables.cc @@ -67,12 +67,82 @@ TEST(IPTablesBasic, FailSockoptNonRaw) { struct ipt_getinfo info = {}; snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); socklen_t info_size = sizeof(info); - EXPECT_THAT(getsockopt(sock, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + EXPECT_THAT(getsockopt(sock, SOL_IP, IPT_SO_GET_INFO, &info, &info_size), SyscallFailsWithErrno(ENOPROTOOPT)); ASSERT_THAT(close(sock), SyscallSucceeds()); } +TEST(IPTablesBasic, GetInfoErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info) - 1; + ASSERT_THAT(getsockopt(sock, SOL_IP, IPT_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IPTablesBasic, GetEntriesErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ipt_get_entries entries = {}; + socklen_t entries_size = sizeof(struct ipt_get_entries) - 1; + snprintf(entries.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + ASSERT_THAT( + getsockopt(sock, SOL_IP, IPT_SO_GET_ENTRIES, &entries, &entries_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IPTablesBasic, OriginalDstErrors) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_STREAM, 0), SyscallSucceeds()); + + // Sockets not affected by NAT should fail to find an original destination. + struct sockaddr_in addr = {}; + socklen_t addr_len = sizeof(addr); + EXPECT_THAT(getsockopt(sock, SOL_IP, SO_ORIGINAL_DST, &addr, &addr_len), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST(IPTablesBasic, GetRevision) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), + SyscallSucceeds()); + + struct xt_get_revision rev = { + .name = "REDIRECT", + .revision = 0, + }; + socklen_t rev_len = sizeof(rev); + + // Revision 0 exists. + EXPECT_THAT( + getsockopt(sock, SOL_IP, IPT_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallSucceeds()); + EXPECT_EQ(rev.revision, 0); + + // Revisions > 0 don't exist. + rev.revision = 1; + EXPECT_THAT( + getsockopt(sock, SOL_IP, IPT_SO_GET_REVISION_TARGET, &rev, &rev_len), + SyscallFailsWithErrno(EPROTONOSUPPORT)); +} + // Fixture for iptables tests. class IPTablesTest : public ::testing::Test { protected: @@ -112,7 +182,7 @@ TEST_F(IPTablesTest, InitialState) { struct ipt_getinfo info = {}; snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); socklen_t info_size = sizeof(info); - ASSERT_THAT(getsockopt(s_, IPPROTO_IP, SO_GET_INFO, &info, &info_size), + ASSERT_THAT(getsockopt(s_, SOL_IP, IPT_SO_GET_INFO, &info, &info_size), SyscallSucceeds()); // The nat table supports PREROUTING, and OUTPUT. @@ -148,7 +218,7 @@ TEST_F(IPTablesTest, InitialState) { snprintf(entries->name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); entries->size = info.size; ASSERT_THAT( - getsockopt(s_, IPPROTO_IP, SO_GET_ENTRIES, entries, &entries_size), + getsockopt(s_, SOL_IP, IPT_SO_GET_ENTRIES, entries, &entries_size), SyscallSucceeds()); // Verify the name and size. diff --git a/test/syscalls/linux/iptables.h b/test/syscalls/linux/iptables.h index 0719c60a4..d0fc10fea 100644 --- a/test/syscalls/linux/iptables.h +++ b/test/syscalls/linux/iptables.h @@ -27,27 +27,32 @@ #include <linux/netfilter/x_tables.h> #include <linux/netfilter_ipv4.h> +#include <linux/netfilter_ipv6.h> #include <net/if.h> #include <netinet/ip.h> #include <stdint.h> +// +// IPv4 ABI. +// + #define ipt_standard_target xt_standard_target #define ipt_entry_target xt_entry_target #define ipt_error_target xt_error_target enum SockOpts { // For setsockopt. - BASE_CTL = 64, - SO_SET_REPLACE = BASE_CTL, - SO_SET_ADD_COUNTERS, - SO_SET_MAX = SO_SET_ADD_COUNTERS, + IPT_BASE_CTL = 64, + IPT_SO_SET_REPLACE = IPT_BASE_CTL, + IPT_SO_SET_ADD_COUNTERS = IPT_BASE_CTL + 1, + IPT_SO_SET_MAX = IPT_SO_SET_ADD_COUNTERS, // For getsockopt. - SO_GET_INFO = BASE_CTL, - SO_GET_ENTRIES, - SO_GET_REVISION_MATCH, - SO_GET_REVISION_TARGET, - SO_GET_MAX = SO_GET_REVISION_TARGET + IPT_SO_GET_INFO = IPT_BASE_CTL, + IPT_SO_GET_ENTRIES = IPT_BASE_CTL + 1, + IPT_SO_GET_REVISION_MATCH = IPT_BASE_CTL + 2, + IPT_SO_GET_REVISION_TARGET = IPT_BASE_CTL + 3, + IPT_SO_GET_MAX = IPT_SO_GET_REVISION_TARGET }; // ipt_ip specifies basic matching criteria that can be applied by examining @@ -115,7 +120,7 @@ struct ipt_entry { unsigned char elems[0]; }; -// Passed to getsockopt(SO_GET_INFO). +// Passed to getsockopt(IPT_SO_GET_INFO). struct ipt_getinfo { // The name of the table. The user only fills this in, the rest is filled in // when returning from getsockopt. Currently "nat" and "mangle" are supported. @@ -127,7 +132,7 @@ struct ipt_getinfo { unsigned int valid_hooks; // The offset into the entry table for each valid hook. The entry table is - // returned by getsockopt(SO_GET_ENTRIES). + // returned by getsockopt(IPT_SO_GET_ENTRIES). unsigned int hook_entry[NF_IP_NUMHOOKS]; // For each valid hook, the underflow is the offset into the entry table to @@ -142,14 +147,14 @@ struct ipt_getinfo { unsigned int underflow[NF_IP_NUMHOOKS]; // The number of entries in the entry table returned by - // getsockopt(SO_GET_ENTRIES). + // getsockopt(IPT_SO_GET_ENTRIES). unsigned int num_entries; - // The size of the entry table returned by getsockopt(SO_GET_ENTRIES). + // The size of the entry table returned by getsockopt(IPT_SO_GET_ENTRIES). unsigned int size; }; -// Passed to getsockopt(SO_GET_ENTRIES). +// Passed to getsockopt(IPT_SO_GET_ENTRIES). struct ipt_get_entries { // The name of the table. The user fills this in. Currently "nat" and "mangle" // are supported. @@ -195,4 +200,103 @@ struct ipt_replace { struct ipt_entry entries[0]; }; +// +// IPv6 ABI. +// + +enum SockOpts6 { + // For setsockopt. + IP6T_BASE_CTL = 64, + IP6T_SO_SET_REPLACE = IP6T_BASE_CTL, + IP6T_SO_SET_ADD_COUNTERS = IP6T_BASE_CTL + 1, + IP6T_SO_SET_MAX = IP6T_SO_SET_ADD_COUNTERS, + + // For getsockopt. + IP6T_SO_GET_INFO = IP6T_BASE_CTL, + IP6T_SO_GET_ENTRIES = IP6T_BASE_CTL + 1, + IP6T_SO_GET_REVISION_MATCH = IP6T_BASE_CTL + 4, + IP6T_SO_GET_REVISION_TARGET = IP6T_BASE_CTL + 5, + IP6T_SO_GET_MAX = IP6T_SO_GET_REVISION_TARGET +}; + +// ip6t_ip6 specifies basic matching criteria that can be applied by examining +// only the IP header of a packet. +struct ip6t_ip6 { + // Source IP address. + struct in6_addr src; + + // Destination IP address. + struct in6_addr dst; + + // Source IP address mask. + struct in6_addr smsk; + + // Destination IP address mask. + struct in6_addr dmsk; + + // Input interface. + char iniface[IFNAMSIZ]; + + // Output interface. + char outiface[IFNAMSIZ]; + + // Input interface mask. + unsigned char iniface_mask[IFNAMSIZ]; + + // Output interface mask. + unsigned char outiface_mask[IFNAMSIZ]; + + // Transport protocol. + uint16_t proto; + + // TOS. + uint8_t tos; + + // Flags. + uint8_t flags; + + // Inverse flags. + uint8_t invflags; +}; + +// ip6t_entry is an ip6tables rule. +struct ip6t_entry { + // Basic matching information used to match a packet's IP header. + struct ip6t_ip6 ipv6; + + // A caching field that isn't used by userspace. + unsigned int nfcache; + + // The number of bytes between the start of this entry and the rule's target. + uint16_t target_offset; + + // The total size of this rule, from the beginning of the entry to the end of + // the target. + uint16_t next_offset; + + // A return pointer not used by userspace. + unsigned int comefrom; + + // Counters for packets and bytes, which we don't yet implement. + struct xt_counters counters; + + // The data for all this rules matches followed by the target. This runs + // beyond the value of sizeof(struct ip6t_entry). + unsigned char elems[0]; +}; + +// Passed to getsockopt(IP6T_SO_GET_ENTRIES). +struct ip6t_get_entries { + // The name of the table. + char name[XT_TABLE_MAXNAMELEN]; + + // The size of the entry table in bytes. The user fills this in with the value + // from struct ipt_getinfo.size. + unsigned int size; + + // The entries for the given table. This will run past the size defined by + // sizeof(struct ip6t_get_entries). + struct ip6t_entry entrytable[0]; +}; + #endif // GVISOR_TEST_SYSCALLS_IPTABLES_TYPES_H_ diff --git a/test/syscalls/linux/kcov.cc b/test/syscalls/linux/kcov.cc new file mode 100644 index 000000000..6afcb4e75 --- /dev/null +++ b/test/syscalls/linux/kcov.cc @@ -0,0 +1,73 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/errno.h> +#include <sys/ioctl.h> +#include <sys/mman.h> + +#include "gtest/gtest.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// For this test to work properly, it must be run with coverage enabled. On +// native Linux, this involves compiling the kernel with kcov enabled. For +// gVisor, we need to enable the Go coverage tool, e.g. +// bazel test --collect_coverage_data --instrumentation_filter=//pkg/... <test>. +TEST(KcovTest, Kcov) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + constexpr int kSize = 4096; + constexpr int KCOV_INIT_TRACE = 0x80086301; + constexpr int KCOV_ENABLE = 0x6364; + constexpr int KCOV_DISABLE = 0x6365; + + int fd; + ASSERT_THAT(fd = open("/sys/kernel/debug/kcov", O_RDWR), + AnyOf(SyscallSucceeds(), SyscallFailsWithErrno(ENOENT))); + + // Kcov not available. + SKIP_IF(errno == ENOENT); + + ASSERT_THAT(ioctl(fd, KCOV_INIT_TRACE, kSize), SyscallSucceeds()); + uint64_t* area = (uint64_t*)mmap(nullptr, kSize * sizeof(uint64_t), + PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + + ASSERT_TRUE(area != MAP_FAILED); + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallSucceeds()); + + for (int i = 0; i < 10; i++) { + // Make some syscalls to generate coverage data. + ASSERT_THAT(ioctl(fd, KCOV_ENABLE, 0), SyscallFailsWithErrno(EINVAL)); + } + + uint64_t num_pcs = *(uint64_t*)(area); + EXPECT_GT(num_pcs, 0); + for (uint64_t i = 1; i <= num_pcs; i++) { + // Verify that PCs are in the standard kernel range. + EXPECT_GT(area[i], 0xffffffff7fffffffL); + } + + ASSERT_THAT(ioctl(fd, KCOV_DISABLE, 0), SyscallSucceeds()); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc index f8b7f7938..4a450742b 100644 --- a/test/syscalls/linux/memfd.cc +++ b/test/syscalls/linux/memfd.cc @@ -14,12 +14,10 @@ #include <errno.h> #include <fcntl.h> -#include <linux/magic.h> #include <linux/memfd.h> #include <linux/unistd.h> #include <string.h> #include <sys/mman.h> -#include <sys/statfs.h> #include <sys/syscall.h> #include <vector> @@ -53,6 +51,7 @@ namespace { #define F_SEAL_GROW 0x0004 #define F_SEAL_WRITE 0x0008 +using ::gvisor::testing::IsTmpfs; using ::testing::StartsWith; const std::string kMemfdName = "some-memfd"; @@ -444,20 +443,6 @@ TEST(MemfdTest, SealsAreInodeLevelProperties) { EXPECT_THAT(ftruncate(memfd3.get(), kPageSize), SyscallFailsWithErrno(EPERM)); } -PosixErrorOr<bool> IsTmpfs(const std::string& path) { - struct statfs stat; - if (statfs(path.c_str(), &stat)) { - if (errno == ENOENT) { - // Nothing at path, don't raise this as an error. Instead, just report no - // tmpfs at path. - return false; - } - return PosixError(errno, - absl::StrFormat("statfs(\"%s\", %#p)", path, &stat)); - } - return stat.f_type == TMPFS_MAGIC; -} - // Tmpfs files also support seals, but are created with F_SEAL_SEAL. TEST(MemfdTest, TmpfsFilesHaveSealSeal) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs("/tmp"))); diff --git a/test/syscalls/linux/mkdir.cc b/test/syscalls/linux/mkdir.cc index 4036a9275..27758203d 100644 --- a/test/syscalls/linux/mkdir.cc +++ b/test/syscalls/linux/mkdir.cc @@ -82,6 +82,13 @@ TEST_F(MkdirTest, FailsOnDirWithoutWritePerms) { SyscallFailsWithErrno(EACCES)); } +TEST_F(MkdirTest, MkdirAtEmptyPath) { + ASSERT_THAT(mkdir(dirname_.c_str(), 0777), SyscallSucceeds()); + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dirname_, O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(mkdirat(fd.get(), "", 0777), SyscallFailsWithErrno(ENOENT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/mknod.cc b/test/syscalls/linux/mknod.cc index 05dfb375a..ae65d366b 100644 --- a/test/syscalls/linux/mknod.cc +++ b/test/syscalls/linux/mknod.cc @@ -14,6 +14,7 @@ #include <errno.h> #include <fcntl.h> +#include <sys/socket.h> #include <sys/stat.h> #include <sys/types.h> #include <sys/un.h> @@ -103,6 +104,27 @@ TEST(MknodTest, UnimplementedTypesReturnError) { ASSERT_THAT(mknod(path.c_str(), S_IFBLK, 0), SyscallFailsWithErrno(EPERM)); } +TEST(MknodTest, Socket) { + SKIP_IF(IsRunningOnGvisor() && IsRunningWithVFS1()); + + ASSERT_THAT(chdir(GetAbsoluteTestTmpdir().c_str()), SyscallSucceeds()); + + auto filename = NewTempRelPath(); + + ASSERT_THAT(mknod(filename.c_str(), S_IFSOCK | S_IRUSR | S_IWUSR, 0), + SyscallSucceeds()); + + int sk; + ASSERT_THAT(sk = socket(AF_UNIX, SOCK_SEQPACKET, 0), SyscallSucceeds()); + FileDescriptor fd(sk); + + struct sockaddr_un addr = {.sun_family = AF_UNIX}; + absl::SNPrintF(addr.sun_path, sizeof(addr.sun_path), "%s", filename.c_str()); + ASSERT_THAT(connect(sk, (struct sockaddr *)&addr, sizeof(addr)), + SyscallFailsWithErrno(ECONNREFUSED)); + ASSERT_THAT(unlink(filename.c_str()), SyscallSucceeds()); +} + TEST(MknodTest, Fifo) { const std::string fifo = NewTempAbsPath(); ASSERT_THAT(mknod(fifo.c_str(), S_IFIFO | S_IRUSR | S_IWUSR, 0), @@ -184,6 +206,14 @@ TEST(MknodTest, FifoTruncNoOp) { EXPECT_THAT(ftruncate(wfd.get(), 0), SyscallFailsWithErrno(EINVAL)); } +TEST(MknodTest, MknodAtEmptyPath) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(mknodat(fd.get(), "", S_IFREG | 0777, 0), + SyscallFailsWithErrno(ENOENT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/mmap.cc b/test/syscalls/linux/mmap.cc index 6d3227ab6..e52c9cbcb 100644 --- a/test/syscalls/linux/mmap.cc +++ b/test/syscalls/linux/mmap.cc @@ -43,6 +43,8 @@ #include "test/util/temp_path.h" #include "test/util/test_util.h" +using ::testing::AnyOf; +using ::testing::Eq; using ::testing::Gt; namespace gvisor { @@ -296,7 +298,8 @@ TEST_F(MMapTest, MapDevZeroSegfaultAfterUnmap) { }; EXPECT_THAT(InForkedProcess(rest), - IsPosixErrorOkAndHolds(W_EXITCODE(0, SIGSEGV))); + IsPosixErrorOkAndHolds(AnyOf(Eq(W_EXITCODE(0, SIGSEGV)), + Eq(W_EXITCODE(0, 128 + SIGSEGV))))); } TEST_F(MMapTest, MapDevZeroUnaligned) { diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc index 97e8d0f7e..3aab25b23 100644 --- a/test/syscalls/linux/mount.cc +++ b/test/syscalls/linux/mount.cc @@ -147,8 +147,15 @@ TEST(MountTest, UmountDetach) { // Unmount the tmpfs. mount.Release()(); - const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_EQ(before.st_ino, after2.st_ino); + // Only check for inode number equality if the directory is not in overlayfs. + // If xino option is not enabled and if all overlayfs layers do not belong to + // the same filesystem then "the value of st_ino for directory objects may not + // be persistent and could change even while the overlay filesystem is + // mounted." -- Documentation/filesystems/overlayfs.txt + if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { + const struct stat after2 = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); + EXPECT_EQ(before.st_ino, after2.st_ino); + } // Can still read file after unmounting. std::vector<char> buf(sizeof(kContents)); @@ -213,8 +220,15 @@ TEST(MountTest, MountTmpfs) { } // Now that dir is unmounted again, we should have the old inode back. - const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); - EXPECT_EQ(before.st_ino, after.st_ino); + // Only check for inode number equality if the directory is not in overlayfs. + // If xino option is not enabled and if all overlayfs layers do not belong to + // the same filesystem then "the value of st_ino for directory objects may not + // be persistent and could change even while the overlay filesystem is + // mounted." -- Documentation/filesystems/overlayfs.txt + if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(dir.path()))) { + const struct stat after = ASSERT_NO_ERRNO_AND_VALUE(Stat(dir.path())); + EXPECT_EQ(before.st_ino, after.st_ino); + } } TEST(MountTest, MountTmpfsMagicValIgnored) { @@ -326,6 +340,14 @@ TEST(MountTest, MountFuseFilesystemNoDevice) { SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled()); auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + // Before kernel version 4.16-rc6, FUSE mount is protected by + // capable(CAP_SYS_ADMIN). After this version, it uses + // ns_capable(CAP_SYS_ADMIN) to protect. Before the 4.16 kernel, it was not + // allowed to mount fuse file systems without the global CAP_SYS_ADMIN. + int res = mount("", dir.path().c_str(), "fuse", 0, ""); + SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM); + EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, ""), SyscallFailsWithErrno(EINVAL)); } @@ -339,6 +361,12 @@ TEST(MountTest, MountFuseFilesystem) { std::string mopts = "fd=" + std::to_string(fd.get()); auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + // See comments in MountFuseFilesystemNoDevice for the reason why we skip + // EPERM when running on Linux. + int res = mount("", dir.path().c_str(), "fuse", 0, ""); + SKIP_IF(!IsRunningOnGvisor() && res == -1 && errno == EPERM); + auto const mount = ASSERT_NO_ERRNO_AND_VALUE(Mount("", dir.path(), "fuse", 0, mopts, 0)); } diff --git a/test/syscalls/linux/open.cc b/test/syscalls/linux/open.cc index bf350946b..77f390f3c 100644 --- a/test/syscalls/linux/open.cc +++ b/test/syscalls/linux/open.cc @@ -27,6 +27,7 @@ #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" +#include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -95,6 +96,38 @@ TEST_F(OpenTest, OTruncAndReadOnlyFile) { Open(dirpath.c_str(), O_TRUNC | O_RDONLY, 0666)); } +TEST_F(OpenTest, OCreateDirectory) { + SKIP_IF(IsRunningWithVFS1()); + auto dirpath = GetAbsoluteTestTmpdir(); + + // Normal case: existing directory. + ASSERT_THAT(open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); + // Trailing separator on existing directory. + ASSERT_THAT(open(dirpath.append("/").c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); + // Trailing separator on non-existing directory. + ASSERT_THAT(open(JoinPath(dirpath, "non-existent").append("/").c_str(), + O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); + // "." special case. + ASSERT_THAT(open(JoinPath(dirpath, ".").c_str(), O_RDWR | O_CREAT, 0666), + SyscallFailsWithErrno(EISDIR)); +} + +TEST_F(OpenTest, MustCreateExisting) { + auto dirPath = GetAbsoluteTestTmpdir(); + + // Existing directory. + ASSERT_THAT(open(dirPath.c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), + SyscallFailsWithErrno(EEXIST)); + + // Existing file. + auto newFile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dirPath)); + ASSERT_THAT(open(newFile.path().c_str(), O_RDWR | O_CREAT | O_EXCL, 0666), + SyscallFailsWithErrno(EEXIST)); +} + TEST_F(OpenTest, ReadOnly) { char buf; const FileDescriptor ro_file = @@ -115,6 +148,26 @@ TEST_F(OpenTest, WriteOnly) { EXPECT_THAT(write(wo_file.get(), &buf, 1), SyscallSucceedsWithValue(1)); } +TEST_F(OpenTest, CreateWithAppend) { + std::string data = "text"; + std::string new_file = NewTempAbsPath(); + const FileDescriptor file = ASSERT_NO_ERRNO_AND_VALUE( + Open(new_file, O_WRONLY | O_APPEND | O_CREAT, 0666)); + EXPECT_THAT(write(file.get(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(lseek(file.get(), 0, SEEK_SET), SyscallSucceeds()); + EXPECT_THAT(write(file.get(), data.c_str(), data.size()), + SyscallSucceedsWithValue(data.size())); + + // Check that the size of the file is correct and that the offset has been + // incremented to that size. + struct stat s0; + EXPECT_THAT(fstat(file.get(), &s0), SyscallSucceeds()); + EXPECT_EQ(s0.st_size, 2 * data.size()); + EXPECT_THAT(lseek(file.get(), 0, SEEK_CUR), + SyscallSucceedsWithValue(2 * data.size())); +} + TEST_F(OpenTest, ReadWrite) { char buf; const FileDescriptor rw_file = @@ -356,6 +409,13 @@ TEST_F(OpenTest, FileNotDirectory) { SyscallFailsWithErrno(ENOTDIR)); } +TEST_F(OpenTest, SymlinkDirectory) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string link = NewTempAbsPath(); + ASSERT_THAT(symlink(dir.path().c_str(), link.c_str()), SyscallSucceeds()); + ASSERT_NO_ERRNO(Open(link, O_RDONLY | O_DIRECTORY)); +} + TEST_F(OpenTest, Null) { char c = '\0'; ASSERT_THAT(open(&c, O_RDONLY), SyscallFailsWithErrno(ENOENT)); diff --git a/test/syscalls/linux/open_create.cc b/test/syscalls/linux/open_create.cc index 51eacf3f2..78c36f98f 100644 --- a/test/syscalls/linux/open_create.cc +++ b/test/syscalls/linux/open_create.cc @@ -88,21 +88,21 @@ TEST(CreateTest, CreateExclusively) { SyscallFailsWithErrno(EEXIST)); } -TEST(CreateTeast, CreatWithOTrunc) { +TEST(CreateTest, CreatWithOTrunc) { std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC, 0666), SyscallFailsWithErrno(EISDIR)); } -TEST(CreateTeast, CreatDirWithOTruncAndReadOnly) { +TEST(CreateTest, CreatDirWithOTruncAndReadOnly) { std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncd"); ASSERT_THAT(mkdir(dirpath.c_str(), 0777), SyscallSucceeds()); ASSERT_THAT(open(dirpath.c_str(), O_CREAT | O_TRUNC | O_RDONLY, 0666), SyscallFailsWithErrno(EISDIR)); } -TEST(CreateTeast, CreatFileWithOTruncAndReadOnly) { +TEST(CreateTest, CreatFileWithOTruncAndReadOnly) { std::string dirpath = JoinPath(GetAbsoluteTestTmpdir(), "truncfile"); int dirfd; ASSERT_THAT(dirfd = open(dirpath.c_str(), O_RDWR | O_CREAT, 0666), @@ -149,6 +149,116 @@ TEST(CreateTest, OpenCreateROThenRW) { EXPECT_THAT(WriteFd(fd2.get(), &c, 1), SyscallSucceedsWithValue(1)); } +TEST(CreateTest, ChmodReadToWriteBetweenOpens_NoRandomSave) { + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be + // cleared for the same reason. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0400)); + + const FileDescriptor rfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + + // Cannot restore after making permissions more restrictive. + const DisableSave ds; + ASSERT_THAT(fchmod(rfd.get(), 0200), SyscallSucceeds()); + + EXPECT_THAT(open(file.path().c_str(), O_RDONLY), + SyscallFailsWithErrno(EACCES)); + + const FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + +TEST(CreateTest, ChmodWriteToReadBetweenOpens_NoRandomSave) { + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + const TempPath file = + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileMode(0200)); + + const FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_WRONLY)); + + // Cannot restore after making permissions more restrictive. + const DisableSave ds; + ASSERT_THAT(fchmod(wfd.get(), 0400), SyscallSucceeds()); + + EXPECT_THAT(open(file.path().c_str(), O_WRONLY), + SyscallFailsWithErrno(EACCES)); + + const FileDescriptor rfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + +TEST(CreateTest, CreateWithReadFlagNotAllowedByMode_NoRandomSave) { + // The only time we can open a file with flags forbidden by its permissions + // is when we are creating the file. We cannot re-open with the same flags, + // so we cannot restore an fd obtained from such an operation. + const DisableSave ds; + + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. CAP_DAC_READ_SEARCH needs to be + // cleared for the same reason. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_READ_SEARCH, false)); + + // Create and open a file with read flag but without read permissions. + const std::string path = NewTempAbsPath(); + const FileDescriptor rfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CREAT | O_RDONLY, 0222)); + + EXPECT_THAT(open(path.c_str(), O_RDONLY), SyscallFailsWithErrno(EACCES)); + const FileDescriptor wfd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_WRONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + +TEST(CreateTest, CreateWithWriteFlagNotAllowedByMode_NoRandomSave) { + // The only time we can open a file with flags forbidden by its permissions + // is when we are creating the file. We cannot re-open with the same flags, + // so we cannot restore an fd obtained from such an operation. + const DisableSave ds; + + // Make sure we don't have CAP_DAC_OVERRIDE, since that allows the user to + // override file read/write permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + // Create and open a file with write flag but without write permissions. + const std::string path = NewTempAbsPath(); + const FileDescriptor wfd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_CREAT | O_WRONLY, 0444)); + + EXPECT_THAT(open(path.c_str(), O_WRONLY), SyscallFailsWithErrno(EACCES)); + const FileDescriptor rfd = ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDONLY)); + + char c = 'x'; + EXPECT_THAT(write(wfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + c = 0; + EXPECT_THAT(read(rfd.get(), &c, 1), SyscallSucceedsWithValue(1)); + EXPECT_EQ(c, 'x'); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc index a11a03415..b558e3a01 100644 --- a/test/syscalls/linux/packet_socket_raw.cc +++ b/test/syscalls/linux/packet_socket_raw.cc @@ -14,9 +14,7 @@ #include <arpa/inet.h> #include <linux/capability.h> -#ifndef __fuchsia__ #include <linux/filter.h> -#endif // __fuchsia__ #include <linux/if_arp.h> #include <linux/if_packet.h> #include <net/ethernet.h> @@ -618,8 +616,6 @@ TEST_P(RawPacketTest, GetSocketErrorBind) { } } -#ifndef __fuchsia__ - TEST_P(RawPacketTest, SetSocketDetachFilterNoInstalledFilter) { // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. // @@ -647,7 +643,26 @@ TEST_P(RawPacketTest, GetSocketDetachFilter) { SyscallFailsWithErrno(ENOPROTOOPT)); } -#endif // __fuchsia__ +TEST_P(RawPacketTest, SetAndGetSocketLinger) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int level = SOL_SOCKET; + int type = SO_LINGER; + + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT(setsockopt(s_, level, type, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + ASSERT_THAT(getsockopt(s_, level, type, &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); +} INSTANTIATE_TEST_SUITE_P(AllInetTests, RawPacketTest, ::testing::Values(ETH_P_IP, ETH_P_ALL)); diff --git a/test/syscalls/linux/pipe.cc b/test/syscalls/linux/pipe.cc index 34291850d..c097c9187 100644 --- a/test/syscalls/linux/pipe.cc +++ b/test/syscalls/linux/pipe.cc @@ -13,7 +13,9 @@ // limitations under the License. #include <fcntl.h> /* Obtain O_* constant definitions */ +#include <linux/magic.h> #include <sys/ioctl.h> +#include <sys/statfs.h> #include <sys/uio.h> #include <unistd.h> @@ -198,6 +200,16 @@ TEST_P(PipeTest, NonBlocking) { SyscallFailsWithErrno(EWOULDBLOCK)); } +TEST(PipeTest, StatFS) { + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + struct statfs st; + EXPECT_THAT(fstatfs(fds[0], &st), SyscallSucceeds()); + EXPECT_EQ(st.f_type, PIPEFS_MAGIC); + EXPECT_EQ(st.f_bsize, getpagesize()); + EXPECT_EQ(st.f_namelen, NAME_MAX); +} + TEST(Pipe2Test, CloExec) { int fds[2]; ASSERT_THAT(pipe2(fds, O_CLOEXEC), SyscallSucceeds()); diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index 04c5161f5..f675dc430 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -153,7 +153,7 @@ TEST(PrctlTest, PDeathSig) { // Enable tracing, then raise SIGSTOP and expect our parent to suppress // it. TEST_CHECK(ptrace(PTRACE_TRACEME, 0, 0, 0) >= 0); - raise(SIGSTOP); + TEST_CHECK(raise(SIGSTOP) == 0); // Sleep until killed by our parent death signal. sleep(3) is // async-signal-safe, absl::SleepFor isn't. while (true) { diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index d6b875dbf..e8fcc4439 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -16,6 +16,7 @@ #include <errno.h> #include <fcntl.h> #include <limits.h> +#include <linux/magic.h> #include <sched.h> #include <signal.h> #include <stddef.h> @@ -26,6 +27,7 @@ #include <sys/mman.h> #include <sys/prctl.h> #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/utsname.h> #include <syscall.h> #include <unistd.h> @@ -45,6 +47,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/container/node_hash_set.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" @@ -61,6 +64,7 @@ #include "test/util/fs_util.h" #include "test/util/memory_util.h" #include "test/util/posix_error.h" +#include "test/util/proc_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -670,6 +674,23 @@ TEST(ProcSelfMaps, Mprotect) { 3 * kPageSize, PROT_READ))); } +TEST(ProcSelfMaps, SharedAnon) { + const Mapping m = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ, MAP_SHARED | MAP_ANONYMOUS)); + + const auto proc_self_maps = + ASSERT_NO_ERRNO_AND_VALUE(GetContents("/proc/self/maps")); + for (const auto& line : absl::StrSplit(proc_self_maps, '\n')) { + const auto entry = ASSERT_NO_ERRNO_AND_VALUE(ParseProcMapsLine(line)); + if (entry.start <= m.addr() && m.addr() < entry.end) { + // cf. proc(5), "/proc/[pid]/map_files/" + EXPECT_EQ(entry.filename, "/dev/zero (deleted)"); + return; + } + } + FAIL() << "no maps entry containing mapping at " << m.ptr(); +} + TEST(ProcSelfFd, OpenFd) { int pipe_fds[2]; ASSERT_THAT(pipe2(pipe_fds, O_CLOEXEC), SyscallSucceeds()); @@ -692,6 +713,30 @@ TEST(ProcSelfFd, OpenFd) { ASSERT_THAT(close(pipe_fds[1]), SyscallSucceeds()); } +static void CheckFdDirGetdentsDuplicates(const std::string& path) { + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path.c_str(), O_RDONLY | O_DIRECTORY)); + // Open a FD whose value is supposed to be much larger than + // the number of FDs opened by current process. + auto newfd = fcntl(fd.get(), F_DUPFD, 1024); + EXPECT_GE(newfd, 1024); + auto fd_closer = Cleanup([newfd]() { close(newfd); }); + auto fd_files = ASSERT_NO_ERRNO_AND_VALUE(ListDir(path.c_str(), false)); + absl::node_hash_set<std::string> fd_files_dedup(fd_files.begin(), + fd_files.end()); + EXPECT_EQ(fd_files.size(), fd_files_dedup.size()); +} + +// This is a regression test for gvisor.dev/issues/3894 +TEST(ProcSelfFd, GetdentsDuplicates) { + CheckFdDirGetdentsDuplicates("/proc/self/fd"); +} + +// This is a regression test for gvisor.dev/issues/3894 +TEST(ProcSelfFdInfo, GetdentsDuplicates) { + CheckFdDirGetdentsDuplicates("/proc/self/fdinfo"); +} + TEST(ProcSelfFdInfo, CorrectFds) { // Make sure there is at least one open file. auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -735,8 +780,12 @@ TEST(ProcSelfFdInfo, Flags) { } TEST(ProcSelfExe, Absolute) { - auto exe = ASSERT_NO_ERRNO_AND_VALUE( - ReadLink(absl::StrCat("/proc/", getpid(), "/exe"))); + auto exe = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/exe")); + EXPECT_EQ(exe[0], '/'); +} + +TEST(ProcSelfCwd, Absolute) { + auto exe = ASSERT_NO_ERRNO_AND_VALUE(ReadLink("/proc/self/cwd")); EXPECT_EQ(exe[0], '/'); } @@ -771,17 +820,12 @@ TEST(ProcCpuinfo, DeniesWriteNonRoot) { constexpr int kNobody = 65534; EXPECT_THAT(syscall(SYS_setuid, kNobody), SyscallSucceeds()); EXPECT_THAT(open("/proc/cpuinfo", O_WRONLY), SyscallFailsWithErrno(EACCES)); - // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in - // kernfs. - if (!IsRunningOnGvisor() || IsRunningWithVFS1()) { - EXPECT_THAT(truncate("/proc/cpuinfo", 123), - SyscallFailsWithErrno(EACCES)); - } + EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EACCES)); }); } // With root privileges, it is possible to open /proc/cpuinfo with write mode, -// but all write operations will return EIO. +// but all write operations should fail. TEST(ProcCpuinfo, DeniesWriteRoot) { // VFS1 does not behave differently for root/non-root. SKIP_IF(IsRunningWithVFS1()); @@ -790,16 +834,10 @@ TEST(ProcCpuinfo, DeniesWriteRoot) { int fd; EXPECT_THAT(fd = open("/proc/cpuinfo", O_WRONLY), SyscallSucceeds()); if (fd > 0) { - EXPECT_THAT(write(fd, "x", 1), SyscallFailsWithErrno(EIO)); - EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFailsWithErrno(EIO)); - } - // TODO(gvisor.dev/issue/1193): Properly support setting size attributes in - // kernfs. - if (!IsRunningOnGvisor() || IsRunningWithVFS1()) { - if (fd > 0) { - EXPECT_THAT(ftruncate(fd, 123), SyscallFailsWithErrno(EIO)); - } - EXPECT_THAT(truncate("/proc/cpuinfo", 123), SyscallFailsWithErrno(EIO)); + // Truncate is not tested--it may succeed on some kernels without doing + // anything. + EXPECT_THAT(write(fd, "x", 1), SyscallFails()); + EXPECT_THAT(pwrite(fd, "x", 1, 123), SyscallFails()); } } @@ -1439,6 +1477,16 @@ TEST(ProcPidExe, Subprocess) { EXPECT_EQ(actual, expected_absolute_path); } +// /proc/PID/cwd points to the correct directory. +TEST(ProcPidCwd, Subprocess) { + auto want = ASSERT_NO_ERRNO_AND_VALUE(GetCWD()); + + char got[PATH_MAX + 1] = {}; + ASSERT_THAT(ReadlinkWhileRunning("cwd", got, sizeof(got)), + SyscallSucceedsWithValue(Gt(0))); + EXPECT_EQ(got, want); +} + // Test whether /proc/PID/ files can be read for a running process. TEST(ProcPidFile, SubprocessRunning) { char buf[1]; @@ -2159,6 +2207,18 @@ TEST(Proc, PidTidIOAccounting) { noop.Join(); } +TEST(Proc, Statfs) { + struct statfs st; + EXPECT_THAT(statfs("/proc", &st), SyscallSucceeds()); + if (IsRunningWithVFS1()) { + EXPECT_EQ(st.f_type, ANON_INODE_FS_MAGIC); + } else { + EXPECT_EQ(st.f_type, PROC_SUPER_MAGIC); + } + EXPECT_EQ(st.f_bsize, getpagesize()); + EXPECT_EQ(st.f_namelen, NAME_MAX); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 3377b65cf..23677e296 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -39,6 +39,7 @@ namespace testing { namespace { constexpr const char kProcNet[] = "/proc/net"; +constexpr const char kIpForward[] = "/proc/sys/net/ipv4/ip_forward"; TEST(ProcNetSymlinkTarget, FileMode) { struct stat s; @@ -477,6 +478,84 @@ TEST(ProcNetSnmp, CheckSnmp) { EXPECT_EQ(value_count, 1); } +TEST(ProcSysNetIpv4Recovery, Exists) { + EXPECT_THAT(open("/proc/sys/net/ipv4/tcp_recovery", O_RDONLY), + SyscallSucceeds()); +} + +TEST(ProcSysNetIpv4Recovery, CanReadAndWrite) { + // TODO(b/162988252): Enable save/restore for this test after the bug is + // fixed. + DisableSave ds; + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + auto const fd = ASSERT_NO_ERRNO_AND_VALUE( + Open("/proc/sys/net/ipv4/tcp_recovery", O_RDWR)); + + char buf[10] = {'\0'}; + char to_write = '2'; + + // Check initial value is set to 1. + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(to_write) + 1)); + EXPECT_EQ(strcmp(buf, "1\n"), 0); + + // Set tcp_recovery to one of the allowed constants. + EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0), + SyscallSucceedsWithValue(sizeof(to_write))); + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(to_write) + 1)); + EXPECT_EQ(strcmp(buf, "2\n"), 0); + + // Set tcp_recovery to any random value. + char kMessage[] = "100"; + EXPECT_THAT(PwriteFd(fd.get(), kMessage, strlen(kMessage), 0), + SyscallSucceedsWithValue(strlen(kMessage))); + EXPECT_THAT(PreadFd(fd.get(), buf, sizeof(kMessage), 0), + SyscallSucceedsWithValue(sizeof(kMessage))); + EXPECT_EQ(strcmp(buf, "100\n"), 0); +} + +TEST(ProcSysNetIpv4IpForward, Exists) { + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDONLY)); +} + +TEST(ProcSysNetIpv4IpForward, DefaultValueEqZero) { + // Test is only valid in sandbox. Not hermetic in native tests + // running on a arbitrary machine. + SKIP_IF(!IsRunningOnGvisor()); + auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDONLY)); + + char buf = 101; + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_EQ(buf, '0') << "unexpected ip_forward: " << buf; +} + +TEST(ProcSysNetIpv4IpForward, CanReadAndWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability((CAP_DAC_OVERRIDE)))); + + auto const fd = ASSERT_NO_ERRNO_AND_VALUE(Open(kIpForward, O_RDWR)); + + char buf; + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_TRUE(buf == '0' || buf == '1') << "unexpected ip_forward: " << buf; + + // constexpr char to_write = '1'; + char to_write = (buf == '1') ? '0' : '1'; + EXPECT_THAT(PwriteFd(fd.get(), &to_write, sizeof(to_write), 0), + SyscallSucceedsWithValue(sizeof(to_write))); + + buf = 0; + EXPECT_THAT(PreadFd(fd.get(), &buf, sizeof(buf), 0), + SyscallSucceedsWithValue(sizeof(buf))); + EXPECT_EQ(buf, to_write); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index f9392b9e0..0b174e2be 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -51,6 +51,7 @@ using ::testing::AnyOf; using ::testing::Contains; using ::testing::Eq; using ::testing::Not; +using SubprocessCallback = std::function<void()>; // Tests Unix98 pseudoterminals. // @@ -389,15 +390,15 @@ TEST(PtyTrunc, Truncate) { // (f)truncate should. FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC)); - int n = ASSERT_NO_ERRNO_AND_VALUE(SlaveID(master)); + int n = ASSERT_NO_ERRNO_AND_VALUE(ReplicaID(master)); std::string spath = absl::StrCat("/dev/pts/", n); - FileDescriptor slave = + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC)); EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(ftruncate(slave.get(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(replica.get(), 0), SyscallFailsWithErrno(EINVAL)); } TEST(BasicPtyTest, StatUnopenedMaster) { @@ -453,16 +454,16 @@ void ExpectReadable(const FileDescriptor& fd, int expected, char* buf) { EXPECT_EQ(expected, n); } -TEST(BasicPtyTest, OpenMasterSlave) { +TEST(BasicPtyTest, OpenMasterReplica) { FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master)); } -// The slave entry in /dev/pts/ disappears when the master is closed, even if -// the slave is still open. -TEST(BasicPtyTest, SlaveEntryGoneAfterMasterClose) { +// The replica entry in /dev/pts/ disappears when the master is closed, even if +// the replica is still open. +TEST(BasicPtyTest, ReplicaEntryGoneAfterMasterClose) { FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master)); // Get pty index. int index = -1; @@ -482,12 +483,12 @@ TEST(BasicPtyTest, Getdents) { FileDescriptor master1 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); int index1 = -1; ASSERT_THAT(ioctl(master1.get(), TIOCGPTN, &index1), SyscallSucceeds()); - FileDescriptor slave1 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master1)); + FileDescriptor replica1 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master1)); FileDescriptor master2 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); int index2 = -1; ASSERT_THAT(ioctl(master2.get(), TIOCGPTN, &index2), SyscallSucceeds()); - FileDescriptor slave2 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master2)); + FileDescriptor replica2 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master2)); // The directory contains ptmx, index1, and index2. (Plus any additional PTYs // unrelated to this test.) @@ -519,59 +520,60 @@ class PtyTest : public ::testing::Test { protected: void SetUp() override { master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master_)); } void DisableCanonical() { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); t.c_lflag &= ~ICANON; - EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); } void EnableCanonical() { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); t.c_lflag |= ICANON; - EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); } - // Master and slave ends of the PTY. Non-blocking. + // Master and replica ends of the PTY. Non-blocking. FileDescriptor master_; - FileDescriptor slave_; + FileDescriptor replica_; }; -// Master to slave sanity test. -TEST_F(PtyTest, WriteMasterToSlave) { - // N.B. by default, the slave reads nothing until the master writes a newline. +// Master to replica sanity test. +TEST_F(PtyTest, WriteMasterToReplica) { + // N.B. by default, the replica reads nothing until the master writes a + // newline. constexpr char kBuf[] = "hello\n"; EXPECT_THAT(WriteFd(master_.get(), kBuf, sizeof(kBuf) - 1), SyscallSucceedsWithValue(sizeof(kBuf) - 1)); - // Linux moves data from the master to the slave via async work scheduled via - // tty_flip_buffer_push. Since it is asynchronous, the data may not be + // Linux moves data from the master to the replica via async work scheduled + // via tty_flip_buffer_push. Since it is asynchronous, the data may not be // available for reading immediately. Instead we must poll and assert that it // becomes available "soon". char buf[sizeof(kBuf)] = {}; - ExpectReadable(slave_, sizeof(buf) - 1, buf); + ExpectReadable(replica_, sizeof(buf) - 1, buf); EXPECT_EQ(memcmp(buf, kBuf, sizeof(kBuf)), 0); } -// Slave to master sanity test. -TEST_F(PtyTest, WriteSlaveToMaster) { - // N.B. by default, the master reads nothing until the slave writes a newline, - // and the master gets a carriage return. +// Replica to master sanity test. +TEST_F(PtyTest, WriteReplicaToMaster) { + // N.B. by default, the master reads nothing until the replica writes a + // newline, and the master gets a carriage return. constexpr char kInput[] = "hello\n"; constexpr char kExpected[] = "hello\r\n"; - EXPECT_THAT(WriteFd(slave_.get(), kInput, sizeof(kInput) - 1), + EXPECT_THAT(WriteFd(replica_.get(), kInput, sizeof(kInput) - 1), SyscallSucceedsWithValue(sizeof(kInput) - 1)); - // Linux moves data from the master to the slave via async work scheduled via - // tty_flip_buffer_push. Since it is asynchronous, the data may not be + // Linux moves data from the master to the replica via async work scheduled + // via tty_flip_buffer_push. Since it is asynchronous, the data may not be // available for reading immediately. Instead we must poll and assert that it // becomes available "soon". @@ -587,32 +589,33 @@ TEST_F(PtyTest, WriteInvalidUTF8) { SyscallSucceedsWithValue(sizeof(c))); } -// Both the master and slave report the standard default termios settings. +// Both the master and replica report the standard default termios settings. // -// Note that TCGETS on the master actually redirects to the slave (see comment +// Note that TCGETS on the master actually redirects to the replica (see comment // on MasterTermiosUnchangable). TEST_F(PtyTest, DefaultTermios) { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); EXPECT_EQ(t, DefaultTermios()); EXPECT_THAT(ioctl(master_.get(), TCGETS, &t), SyscallSucceeds()); EXPECT_EQ(t, DefaultTermios()); } -// Changing termios from the master actually affects the slave. +// Changing termios from the master actually affects the replica. // -// TCSETS on the master actually redirects to the slave (see comment on +// TCSETS on the master actually redirects to the replica (see comment on // MasterTermiosUnchangable). -TEST_F(PtyTest, TermiosAffectsSlave) { +TEST_F(PtyTest, TermiosAffectsReplica) { struct kernel_termios master_termios = {}; EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds()); master_termios.c_lflag ^= ICANON; EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); - struct kernel_termios slave_termios = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &slave_termios), SyscallSucceeds()); - EXPECT_EQ(master_termios, slave_termios); + struct kernel_termios replica_termios = {}; + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &replica_termios), + SyscallSucceeds()); + EXPECT_EQ(master_termios, replica_termios); } // The master end of the pty has termios: @@ -627,7 +630,7 @@ TEST_F(PtyTest, TermiosAffectsSlave) { // // (From drivers/tty/pty.c:unix98_pty_init) // -// All termios control ioctls on the master actually redirect to the slave +// All termios control ioctls on the master actually redirect to the replica // (drivers/tty/tty_ioctl.c:tty_mode_ioctl), making it impossible to change the // master termios. // @@ -640,7 +643,7 @@ TEST_F(PtyTest, MasterTermiosUnchangable) { EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); ExpectReadable(master_, 1, &c); EXPECT_EQ(c, '\r'); // ICRNL had no effect! @@ -653,15 +656,15 @@ TEST_F(PtyTest, TermiosICRNL) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= ICRNL; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\r'; ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); EXPECT_EQ(c, '\n'); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ONLCR rewrites output \n to \r\n. @@ -669,42 +672,42 @@ TEST_F(PtyTest, TermiosONLCR) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= ONLCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\n'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Extra byte for NUL for EXPECT_STREQ. char buf[3] = {}; ExpectReadable(master_, 2, buf); EXPECT_STREQ(buf, "\r\n"); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, TermiosIGNCR) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= IGNCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\r'; ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. - ASSERT_THAT(PollAndReadFd(slave_.get(), &c, 1, kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), &c, 1, kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); } -// Test that we can successfully poll for readable data from the slave. -TEST_F(PtyTest, TermiosPollSlave) { +// Test that we can successfully poll for readable data from the replica. +TEST_F(PtyTest, TermiosPollReplica) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= IGNCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); absl::Notification notify; - int sfd = slave_.get(); + int sfd = replica_.get(); ScopedThread th([sfd, ¬ify]() { notify.Notify(); @@ -753,33 +756,33 @@ TEST_F(PtyTest, TermiosPollMaster) { absl::SleepFor(absl::Seconds(1)); char s[] = "foo\n"; - ASSERT_THAT(WriteFd(slave_.get(), s, strlen(s) + 1), SyscallSucceeds()); + ASSERT_THAT(WriteFd(replica_.get(), s, strlen(s) + 1), SyscallSucceeds()); } TEST_F(PtyTest, TermiosINLCR) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= INLCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\n'; ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); EXPECT_EQ(c, '\r'); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, TermiosONOCR) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= ONOCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); // The terminal is at column 0, so there should be no CR to read. char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout), @@ -789,7 +792,7 @@ TEST_F(PtyTest, TermiosONOCR) { // out of the other end. constexpr char kInput[] = "foo\r"; constexpr int kInputSize = sizeof(kInput) - 1; - ASSERT_THAT(WriteFd(slave_.get(), kInput, kInputSize), + ASSERT_THAT(WriteFd(replica_.get(), kInput, kInputSize), SyscallSucceedsWithValue(kInputSize)); char buf[kInputSize] = {}; @@ -800,7 +803,7 @@ TEST_F(PtyTest, TermiosONOCR) { ExpectFinished(master_); // Terminal should be at column 0 again, so no CR can be read. - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout), @@ -811,11 +814,11 @@ TEST_F(PtyTest, TermiosOCRNL) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= OCRNL; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); // The terminal is at column 0, so there should be no CR to read. char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); ExpectReadable(master_, 1, &c); EXPECT_EQ(c, '\n'); @@ -831,24 +834,24 @@ TEST_F(PtyTest, VEOLTermination) { ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)), SyscallSucceedsWithValue(sizeof(kInput))); char buf[sizeof(kInput)] = {}; - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(kInput), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); // Set the EOL character to '=' and write it. constexpr char delim = '='; struct kernel_termios t = DefaultTermios(); t.c_cc[VEOL] = delim; - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now we can read, as sending EOL caused the line to become available. - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_EQ(memcmp(buf, kInput, sizeof(kInput)), 0); - ExpectReadable(slave_, 1, buf); + ExpectReadable(replica_, 1, buf); EXPECT_EQ(buf[0], '='); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that we can write more than the 4096 character limit, then a @@ -864,9 +867,9 @@ TEST_F(PtyTest, CanonBigWrite) { // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize, buf); + ExpectReadable(replica_, kMaxLineSize, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that data written in canonical mode can be read immediately once @@ -880,15 +883,15 @@ TEST_F(PtyTest, SwitchCanonToNoncanon) { // Nothing available yet. char buf[sizeof(kInput)] = {}; - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(kInput), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); DisableCanonical(); - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchCanonToNonCanonNewline) { @@ -901,10 +904,10 @@ TEST_F(PtyTest, SwitchCanonToNonCanonNewline) { // We can read the line. char buf[sizeof(kInput)] = {}; - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { @@ -917,7 +920,7 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), SyscallSucceedsWithValue(kWriteLen)); // Wait for the input queue to fill. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); constexpr char delim = '\n'; ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); @@ -925,12 +928,12 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize - 1, buf); + ExpectReadable(replica_, kMaxLineSize - 1, buf); // We can also read the remaining characters. - ExpectReadable(slave_, 6, buf); + ExpectReadable(replica_, 6, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) { @@ -942,15 +945,15 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) { ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput) - 1), SyscallSucceedsWithValue(sizeof(kInput) - 1)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput) - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(kInput) - 1)); EnableCanonical(); // We can read the line. char buf[sizeof(kInput)] = {}; - ExpectReadable(slave_, sizeof(kInput) - 1, buf); + ExpectReadable(replica_, sizeof(kInput) - 1, buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) { @@ -964,14 +967,14 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) { ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), SyscallSucceedsWithValue(kWriteLen)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); EnableCanonical(); // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize - 1, buf); + ExpectReadable(replica_, kMaxLineSize - 1, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that we can write over the 4095 noncanonical limit, then read out @@ -990,17 +993,17 @@ TEST_F(PtyTest, NoncanonBigWrite) { } // We should be able to read out everything. Sleep a bit so that Linux has a - // chance to move data from the master to the slave. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + // chance to move data from the master to the replica. + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); for (int i = 0; i < kInputSize; i++) { // This makes too many syscalls for save/restore. const DisableSave ds; char c; - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); ASSERT_EQ(c, kInput); } - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON doesn't make input available until a line delimiter is typed. @@ -1015,18 +1018,18 @@ TEST_F(PtyTest, TermiosICANONNewline) { char buf[5] = {}; // Nothing available yet. - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(input), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); char delim = '\n'; ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now it is available. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(input) + 1)); - ExpectReadable(slave_, sizeof(input) + 1, buf); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(input) + 1)); + ExpectReadable(replica_, sizeof(input) + 1, buf); EXPECT_STREQ(buf, "abc\n"); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON doesn't make input available until a line delimiter is typed. @@ -1041,16 +1044,16 @@ TEST_F(PtyTest, TermiosICANONEOF) { char buf[4] = {}; // Nothing available yet. - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(input), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); char delim = ControlCharacter('D'); ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now it is available. Note that ^D is not included. - ExpectReadable(slave_, sizeof(input), buf); + ExpectReadable(replica_, sizeof(input), buf); EXPECT_STREQ(buf, "abc"); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON limits us to 4096 bytes including a terminating character. Anything @@ -1076,12 +1079,12 @@ TEST_F(PtyTest, CanonDiscard) { // There should be multiple truncated lines available to read. for (int i = 0; i < kIter; i++) { char buf[kInputSize] = {}; - ExpectReadable(slave_, kMaxLineSize, buf); + ExpectReadable(replica_, kMaxLineSize, buf); EXPECT_EQ(buf[kMaxLineSize - 1], delim); EXPECT_EQ(buf[kMaxLineSize - 2], kInput); } - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, CanonMultiline) { @@ -1096,15 +1099,15 @@ TEST_F(PtyTest, CanonMultiline) { // Get the first line. char line1[8] = {}; - ExpectReadable(slave_, sizeof(kInput1) - 1, line1); + ExpectReadable(replica_, sizeof(kInput1) - 1, line1); EXPECT_STREQ(line1, kInput1); // Get the second line. char line2[8] = {}; - ExpectReadable(slave_, sizeof(kInput2) - 1, line2); + ExpectReadable(replica_, sizeof(kInput2) - 1, line2); EXPECT_STREQ(line2, kInput2); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) { @@ -1121,15 +1124,15 @@ TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) { SyscallSucceedsWithValue(sizeof(kInput2) - 1)); ASSERT_NO_ERRNO( - WaitUntilReceived(slave_.get(), sizeof(kInput1) + sizeof(kInput2) - 2)); + WaitUntilReceived(replica_.get(), sizeof(kInput1) + sizeof(kInput2) - 2)); EnableCanonical(); // Get all together as one line. char line[9] = {}; - ExpectReadable(slave_, 8, line); + ExpectReadable(replica_, 8, line); EXPECT_STREQ(line, kExpected); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchTwiceMultiline) { @@ -1146,15 +1149,15 @@ TEST_F(PtyTest, SwitchTwiceMultiline) { // All written characters have to make it into the input queue before // canonical mode is re-enabled. If the final '!' character hasn't been // enqueued before canonical mode is re-enabled, it won't be readable. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kExpected.size())); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kExpected.size())); EnableCanonical(); // Get all together as one line. char line[10] = {}; - ExpectReadable(slave_, 9, line); + ExpectReadable(replica_, 9, line); EXPECT_STREQ(line, kExpected.c_str()); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, QueueSize) { @@ -1162,7 +1165,7 @@ TEST_F(PtyTest, QueueSize) { constexpr char kInput1[] = "GO\n"; ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1), SyscallSucceedsWithValue(sizeof(kInput1) - 1)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(kInput1) - 1)); // Ensure that writing more (beyond what is readable) does not impact the // readable size. @@ -1171,7 +1174,7 @@ TEST_F(PtyTest, QueueSize) { ASSERT_THAT(WriteFd(master_.get(), input, kMaxLineSize), SyscallSucceedsWithValue(kMaxLineSize)); int inputBufSize = ASSERT_NO_ERRNO_AND_VALUE( - WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1)); + WaitUntilReceived(replica_.get(), sizeof(kInput1) - 1)); EXPECT_EQ(inputBufSize, sizeof(kInput1) - 1); } @@ -1196,9 +1199,9 @@ TEST_F(PtyTest, PartialBadBuffer) { EXPECT_THAT(WriteFd(master_.get(), kBuf, size), SyscallSucceedsWithValue(size)); - // Read from the slave into bad_buffer. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), size)); - EXPECT_THAT(ReadFd(slave_.get(), bad_buffer, size), + // Read from the replica into bad_buffer. + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), size)); + EXPECT_THAT(ReadFd(replica_.get(), bad_buffer, size), SyscallFailsWithErrno(EFAULT)); EXPECT_THAT(munmap(addr, 2 * kPageSize), SyscallSucceeds()) << addr; @@ -1218,16 +1221,16 @@ TEST_F(PtyTest, SimpleEcho) { TEST_F(PtyTest, GetWindowSize) { struct winsize ws; - ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &ws), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCGWINSZ, &ws), SyscallSucceeds()); EXPECT_EQ(ws.ws_row, 0); EXPECT_EQ(ws.ws_col, 0); } -TEST_F(PtyTest, SetSlaveWindowSize) { +TEST_F(PtyTest, SetReplicaWindowSize) { constexpr uint16_t kRows = 343; constexpr uint16_t kCols = 2401; struct winsize ws = {.ws_row = kRows, .ws_col = kCols}; - ASSERT_THAT(ioctl(slave_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); struct winsize retrieved_ws = {}; ASSERT_THAT(ioctl(master_.get(), TIOCGWINSZ, &retrieved_ws), @@ -1243,7 +1246,7 @@ TEST_F(PtyTest, SetMasterWindowSize) { ASSERT_THAT(ioctl(master_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); struct winsize retrieved_ws = {}; - ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &retrieved_ws), + ASSERT_THAT(ioctl(replica_.get(), TIOCGWINSZ, &retrieved_ws), SyscallSucceeds()); EXPECT_EQ(retrieved_ws.ws_row, kRows); EXPECT_EQ(retrieved_ws.ws_col, kCols); @@ -1253,7 +1256,7 @@ class JobControlTest : public ::testing::Test { protected: void SetUp() override { master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master_)); // Make this a session leader, which also drops the controlling terminal. // In the gVisor test environment, this test will be run as the session @@ -1263,61 +1266,82 @@ class JobControlTest : public ::testing::Test { } } - // Master and slave ends of the PTY. Non-blocking. + PosixError RunInChild(SubprocessCallback childFunc) { + pid_t child = fork(); + if (!child) { + childFunc(); + _exit(0); + } + int wstatus; + if (waitpid(child, &wstatus, 0) != child) { + return PosixError( + errno, absl::StrCat("child failed with wait status: ", wstatus)); + } + return PosixError(wstatus, "process returned"); + } + + // Master and replica ends of the PTY. Non-blocking. FileDescriptor master_; - FileDescriptor slave_; + FileDescriptor replica_; }; TEST_F(JobControlTest, SetTTYMaster) { - ASSERT_THAT(ioctl(master_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(master_.get(), TIOCSCTTY, 0)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(!replica_.get(), TIOCSCTTY, 0)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYNonLeader) { // Fork a process that won't be the session leader. - pid_t child = fork(); - if (!child) { - // We shouldn't be able to set the terminal. - TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 0)); - _exit(0); - } - - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + auto res = + RunInChild([=]() { TEST_PCHECK(ioctl(replica_.get(), TIOCSCTTY, 0)); }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYBadArg) { - // Despite the man page saying arg should be 0 here, Linux doesn't actually - // check. - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 1), SyscallSucceeds()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYDifferentSession) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Fork, join a new session, and try to steal the parent's controlling - // terminal, which should fail. - pid_t child = fork(); - if (!child) { + auto res = RunInChild([=]() { TEST_PCHECK(setsid() >= 0); - // We shouldn't be able to steal the terminal. - TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 1)); - _exit(0); - } + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + // Fork, join a new session, and try to steal the parent's controlling + // terminal, which should fail. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(setsid() >= 0); + // We shouldn't be able to steal the terminal. + TEST_PCHECK(ioctl(replica_.get(), TIOCSCTTY, 1)); + _exit(0); + } + + int gcwstatus; + TEST_PCHECK(waitpid(grandchild, &gcwstatus, 0) == grandchild); + TEST_PCHECK(gcwstatus == 0); + }); } TEST_F(JobControlTest, ReleaseTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSCTTY, 0), SyscallSucceeds()); // Make sure we're ignoring SIGHUP, which will be sent to this process once we // disconnect they TTY. @@ -1327,48 +1351,60 @@ TEST_F(JobControlTest, ReleaseTTY) { sigemptyset(&sa.sa_mask); struct sigaction old_sa; EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds()); - EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallSucceeds()); EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); } TEST_F(JobControlTest, ReleaseUnsetTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); + ASSERT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); } TEST_F(JobControlTest, ReleaseWrongTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - ASSERT_THAT(ioctl(master_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + TEST_PCHECK(ioctl(master_.get(), TIOCNOTTY) < 0 && errno == ENOTTY); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, ReleaseTTYNonLeader) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!ioctl(slave_.get(), TIOCNOTTY)); - _exit(0); - } + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!ioctl(replica_.get(), TIOCNOTTY)); + _exit(0); + } - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(wstatus == 0); + }); + ASSERT_NO_ERRNO(ret); } TEST_F(JobControlTest, ReleaseTTYDifferentSession) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - pid_t child = fork(); - if (!child) { - // Join a new session, then try to disconnect. + auto ret = RunInChild([=]() { TEST_PCHECK(setsid() >= 0); - TEST_PCHECK(ioctl(slave_.get(), TIOCNOTTY)); - _exit(0); - } - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + pid_t grandchild = fork(); + if (!grandchild) { + // Join a new session, then try to disconnect. + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(replica_.get(), TIOCNOTTY)); + _exit(0); + } + + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(wstatus == 0); + }); + ASSERT_NO_ERRNO(ret); } // Used by the child process spawned in ReleaseTTYSignals to track received @@ -1387,7 +1423,7 @@ void sig_handler(int signum) { received |= signum; } // - Checks that thread 1 got both signals // - Checks that thread 2 didn't get any signals. TEST_F(JobControlTest, ReleaseTTYSignals) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSCTTY, 0), SyscallSucceeds()); received = 0; struct sigaction sa = {}; @@ -1439,7 +1475,7 @@ TEST_F(JobControlTest, ReleaseTTYSignals) { // Release the controlling terminal, sending SIGHUP and SIGCONT to all other // processes in this process group. - EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallSucceeds()); EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); @@ -1456,20 +1492,21 @@ TEST_F(JobControlTest, ReleaseTTYSignals) { } TEST_F(JobControlTest, GetForegroundProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - pid_t foreground_pgid; - pid_t pid; - ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), - SyscallSucceeds()); - ASSERT_THAT(pid = getpid(), SyscallSucceeds()); - - ASSERT_EQ(foreground_pgid, pid); + auto res = RunInChild([=]() { + pid_t pid, foreground_pgid; + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); + TEST_PCHECK(!ioctl(replica_.get(), TIOCGPGRP, &foreground_pgid)); + TEST_PCHECK((pid = getpid()) >= 0); + TEST_PCHECK(pid == foreground_pgid); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { // At this point there's no controlling terminal, so TIOCGPGRP should fail. pid_t foreground_pgid; - ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), + ASSERT_THAT(ioctl(replica_.get(), TIOCGPGRP, &foreground_pgid), SyscallFailsWithErrno(ENOTTY)); } @@ -1479,113 +1516,125 @@ TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { // - sets that child as the foreground process group // - kills its child and sets itself as the foreground process group. TEST_F(JobControlTest, SetForegroundProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - sa.sa_flags = 0; - sigemptyset(&sa.sa_mask); - sigaction(SIGTTOU, &sa, NULL); - - // Set ourself as the foreground process group. - ASSERT_THAT(tcsetpgrp(slave_.get(), getpgid(0)), SyscallSucceeds()); - - // Create a new process that just waits to be signaled. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!pause()); - // We should never reach this. - _exit(1); - } - - // Make the child its own process group, then make it the controlling process - // group of the terminal. - ASSERT_THAT(setpgid(child, child), SyscallSucceeds()); - ASSERT_THAT(tcsetpgrp(slave_.get(), child), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. + struct sigaction sa = {}; + sa.sa_handler = SIG_IGN; + sa.sa_flags = 0; + sigemptyset(&sa.sa_mask); + sigaction(SIGTTOU, &sa, NULL); + + // Set ourself as the foreground process group. + TEST_PCHECK(!tcsetpgrp(replica_.get(), getpgid(0))); + + // Create a new process that just waits to be signaled. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!pause()); + // We should never reach this. + _exit(1); + } - // Sanity check - we're still the controlling session. - ASSERT_EQ(getsid(0), getsid(child)); + // Make the child its own process group, then make it the controlling + // process group of the terminal. + TEST_PCHECK(!setpgid(grandchild, grandchild)); + TEST_PCHECK(!tcsetpgrp(replica_.get(), grandchild)); - // Signal the child, wait for it to exit, then retake the terminal. - ASSERT_THAT(kill(child, SIGTERM), SyscallSucceeds()); - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_TRUE(WIFSIGNALED(wstatus)); - ASSERT_EQ(WTERMSIG(wstatus), SIGTERM); + // Sanity check - we're still the controlling session. + TEST_PCHECK(getsid(0) == getsid(grandchild)); - // Set ourself as the foreground process. - pid_t pgid; - ASSERT_THAT(pgid = getpgid(0), SyscallSucceeds()); - ASSERT_THAT(tcsetpgrp(slave_.get(), pgid), SyscallSucceeds()); + // Signal the child, wait for it to exit, then retake the terminal. + TEST_PCHECK(!kill(grandchild, SIGTERM)); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(WIFSIGNALED(wstatus)); + TEST_PCHECK(WTERMSIG(wstatus) == SIGTERM); + + // Set ourself as the foreground process. + pid_t pgid; + TEST_PCHECK(pgid = getpgid(0) == 0); + TEST_PCHECK(!tcsetpgrp(replica_.get(), pgid)); + }); } TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) { pid_t pid = getpid(); - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), + ASSERT_THAT(ioctl(replica_.get(), TIOCSPGRP, &pid), SyscallFailsWithErrno(ENOTTY)); } TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - pid_t pid = -1; - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), - SyscallFailsWithErrno(EINVAL)); + pid_t pid = -1; + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &pid) && errno == EINVAL); + }); + ASSERT_NO_ERRNO(ret); } TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Create a new process, put it in a new process group, make that group the - // foreground process group, then have the process wait. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!setpgid(0, 0)); - _exit(0); - } + auto ret = RunInChild([=]() { + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + // Create a new process, put it in a new process group, make that group the + // foreground process group, then have the process wait. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!setpgid(0, 0)); + _exit(0); + } - // Wait for the child to exit. - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - // The child's process group doesn't exist anymore - this should fail. - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), - SyscallFailsWithErrno(ESRCH)); + // Wait for the child to exit. + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + // The child's process group doesn't exist anymore - this should fail. + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) != 0 && + errno == ESRCH); + }); } TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - int sync_setsid[2]; - int sync_exit[2]; - ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds()); - ASSERT_THAT(pipe(sync_exit), SyscallSucceeds()); + int sync_setsid[2]; + int sync_exit[2]; + TEST_PCHECK(pipe(sync_setsid) >= 0); + TEST_PCHECK(pipe(sync_exit) >= 0); - // Create a new process and put it in a new session. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(setsid() >= 0); - // Tell the parent we're in a new session. - char c = 'c'; - TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1); - TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1); - _exit(0); - } + // Create a new process and put it in a new session. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(setsid() >= 0); + // Tell the parent we're in a new session. + char c = 'c'; + TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1); + TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1); + _exit(0); + } - // Wait for the child to tell us it's in a new session. - char c = 'c'; - ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1)); + // Wait for the child to tell us it's in a new session. + char c = 'c'; + TEST_PCHECK(ReadFd(sync_setsid[0], &c, 1) == 1); - // Child is in a new session, so we can't make it the foregroup process group. - EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), - SyscallFailsWithErrno(EPERM)); + // Child is in a new session, so we can't make it the foregroup process + // group. + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) && + errno == EPERM); - EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1)); + TEST_PCHECK(WriteFd(sync_exit[1], &c, 1) == 1); - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - EXPECT_TRUE(WIFEXITED(wstatus)); - EXPECT_EQ(WEXITSTATUS(wstatus), 0); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(WIFEXITED(wstatus)); + TEST_PCHECK(!WEXITSTATUS(wstatus)); + }); + ASSERT_NO_ERRNO(ret); } // Verify that we don't hang when creating a new session from an orphaned diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc index 1d7dbefdb..4ac648729 100644 --- a/test/syscalls/linux/pty_root.cc +++ b/test/syscalls/linux/pty_root.cc @@ -50,10 +50,10 @@ TEST(JobControlRootTest, StealTTY) { FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(master)); - // Make slave the controlling terminal. - ASSERT_THAT(ioctl(slave.get(), TIOCSCTTY, 0), SyscallSucceeds()); + // Make replica the controlling terminal. + ASSERT_THAT(ioctl(replica.get(), TIOCSCTTY, 0), SyscallSucceeds()); // Fork, join a new session, and try to steal the parent's controlling // terminal, which should succeed when we have CAP_SYS_ADMIN and pass an arg @@ -62,9 +62,9 @@ TEST(JobControlRootTest, StealTTY) { if (!child) { ASSERT_THAT(setsid(), SyscallSucceeds()); // We shouldn't be able to steal the terminal with the wrong arg value. - TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); + TEST_PCHECK(ioctl(replica.get(), TIOCSCTTY, 0)); // We should be able to steal it if we are true root. - TEST_PCHECK(true_root == !ioctl(slave.get(), TIOCSCTTY, 1)); + TEST_PCHECK(true_root == !ioctl(replica.get(), TIOCSCTTY, 1)); _exit(0); } diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index 8d6e5c913..54709371c 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -13,9 +13,7 @@ // limitations under the License. #include <linux/capability.h> -#ifndef __fuchsia__ #include <linux/filter.h> -#endif // __fuchsia__ #include <netinet/in.h> #include <netinet/ip.h> #include <netinet/ip6.h> @@ -815,8 +813,6 @@ void RawSocketTest::ReceiveBufFrom(int sock, char* recv_buf, ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(sock, recv_buf, recv_buf_len)); } -#ifndef __fuchsia__ - TEST_P(RawSocketTest, SetSocketDetachFilterNoInstalledFilter) { // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. if (IsRunningOnGvisor()) { @@ -838,8 +834,6 @@ TEST_P(RawSocketTest, GetSocketDetachFilter) { SyscallFailsWithErrno(ENOPROTOOPT)); } -#endif // __fuchsia__ - // AF_INET6+SOCK_RAW+IPPROTO_RAW sockets can be created, but not written to. TEST(RawSocketTest, IPv6ProtoRaw) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 3de898df7..1b9dbc584 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -416,6 +416,28 @@ TEST_F(RawSocketICMPTest, BindConnectSendAndReceive) { ASSERT_NO_FATAL_FAILURE(ExpectICMPSuccess(icmp)); } +// Set and get SO_LINGER. +TEST_F(RawSocketICMPTest, SetAndGetSocketLinger) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int level = SOL_SOCKET; + int type = SO_LINGER; + + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT(setsockopt(s_, level, type, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + ASSERT_THAT(getsockopt(s_, level, type, &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); +} + void RawSocketICMPTest::ExpectICMPSuccess(const struct icmphdr& icmp) { // We're going to receive both the echo request and reply, but the order is // indeterminate. diff --git a/test/syscalls/linux/readahead.cc b/test/syscalls/linux/readahead.cc index 09703b5c1..71073bb3c 100644 --- a/test/syscalls/linux/readahead.cc +++ b/test/syscalls/linux/readahead.cc @@ -16,6 +16,7 @@ #include <fcntl.h> #include "gtest/gtest.h" +#include "test/syscalls/linux/socket_test_util.h" #include "test/util/file_descriptor.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -29,7 +30,15 @@ TEST(ReadaheadTest, InvalidFD) { EXPECT_THAT(readahead(-1, 1, 1), SyscallFailsWithErrno(EBADF)); } +TEST(ReadaheadTest, UnsupportedFile) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, 0)); + ASSERT_THAT(readahead(sock.get(), 1, 1), SyscallFailsWithErrno(EINVAL)); +} + TEST(ReadaheadTest, InvalidOffset) { + // This test is not valid for some Linux Kernels. + SKIP_IF(!IsRunningOnGvisor()); const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); @@ -79,6 +88,8 @@ TEST(ReadaheadTest, WriteOnly) { } TEST(ReadaheadTest, InvalidSize) { + // This test is not valid on some Linux kernels. + SKIP_IF(!IsRunningOnGvisor()); const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR)); diff --git a/test/syscalls/linux/rename.cc b/test/syscalls/linux/rename.cc index 833c0dc4f..5458f54ad 100644 --- a/test/syscalls/linux/rename.cc +++ b/test/syscalls/linux/rename.cc @@ -170,6 +170,9 @@ TEST(RenameTest, FileOverwritesFile) { } TEST(RenameTest, DirectoryOverwritesDirectoryLinkCount) { + // Directory link counts are synthetic on overlay filesystems. + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))); + auto parent1 = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); EXPECT_THAT(Links(parent1.path()), IsPosixErrorOkAndHolds(2)); diff --git a/test/syscalls/linux/rseq.cc b/test/syscalls/linux/rseq.cc index 4bfb1ff56..94f9154a0 100644 --- a/test/syscalls/linux/rseq.cc +++ b/test/syscalls/linux/rseq.cc @@ -24,6 +24,7 @@ #include "test/syscalls/linux/rseq/uapi.h" #include "test/util/logging.h" #include "test/util/multiprocess_util.h" +#include "test/util/posix_error.h" #include "test/util/test_util.h" namespace gvisor { @@ -31,6 +32,9 @@ namespace testing { namespace { +using ::testing::AnyOf; +using ::testing::Eq; + // Syscall test for rseq (restartable sequences). // // We must be very careful about how these tests are written. Each thread may @@ -98,7 +102,7 @@ void RunChildTest(std::string test_case, int want_status) { int status = 0; ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), SyscallSucceeds()); - ASSERT_EQ(status, want_status); + ASSERT_THAT(status, AnyOf(Eq(want_status), Eq(128 + want_status))); } // Test that rseq must be aligned. diff --git a/test/syscalls/linux/rseq/rseq.cc b/test/syscalls/linux/rseq/rseq.cc index f036db26d..6f5d38bba 100644 --- a/test/syscalls/linux/rseq/rseq.cc +++ b/test/syscalls/linux/rseq/rseq.cc @@ -74,84 +74,95 @@ int TestUnaligned() { // Sanity test that registration works. int TestRegister() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } return 0; -}; +} // Registration can't be done twice. int TestDoubleRegister() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != EBUSY) { + ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != EBUSY) { return 1; } return 0; -}; +} // Registration can be done again after unregister. int TestRegisterUnregister() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0); - sys_errno(ret) != 0) { + ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, 0); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } return 0; -}; +} // The pointer to rseq must match on register/unregister. int TestUnregisterDifferentPtr() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } struct rseq r2 = {}; - if (int ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0); - sys_errno(ret) != EINVAL) { + + ret = sys_rseq(&r2, sizeof(r2), kRseqFlagUnregister, 0); + if (sys_errno(ret) != EINVAL) { return 1; } return 0; -}; +} // The signature must match on register/unregister. int TestUnregisterDifferentSignature() { constexpr int kSignature = 0; struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kSignature); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kSignature); + if (sys_errno(ret) != 0) { return 1; } - if (int ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1); - sys_errno(ret) != EPERM) { + ret = sys_rseq(&r, sizeof(r), kRseqFlagUnregister, kSignature + 1); + if (sys_errno(ret) != EPERM) { return 1; } return 0; -}; +} // The CPU ID is initialized. int TestCPU() { struct rseq r = {}; r.cpu_id = kRseqCPUIDUninitialized; - if (int ret = sys_rseq(&r, sizeof(r), 0, 0); sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, 0); + if (sys_errno(ret) != 0) { return 1; } @@ -163,13 +174,13 @@ int TestCPU() { } return 0; -}; +} // Critical section is eventually aborted. int TestAbort() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -185,13 +196,13 @@ int TestAbort() { rseq_loop(&r, &cs); return 0; -}; +} // Abort may be before the critical section. int TestAbortBefore() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -207,13 +218,13 @@ int TestAbortBefore() { rseq_loop(&r, &cs); return 0; -}; +} // Signature must match. int TestAbortSignature() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + if (sys_errno(ret) != 0) { return 1; } @@ -229,13 +240,13 @@ int TestAbortSignature() { rseq_loop(&r, &cs); return 1; -}; +} // Abort must not be in the critical section. int TestAbortPreCommit() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature + 1); + if (sys_errno(ret) != 0) { return 1; } @@ -251,13 +262,13 @@ int TestAbortPreCommit() { rseq_loop(&r, &cs); return 1; -}; +} // rseq.rseq_cs is cleared on abort. int TestAbortClearsCS() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -277,13 +288,13 @@ int TestAbortClearsCS() { } return 0; -}; +} // rseq.rseq_cs is cleared on abort outside of critical section. int TestInvalidAbortClearsCS() { struct rseq r = {}; - if (int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); - sys_errno(ret) != 0) { + int ret = sys_rseq(&r, sizeof(r), 0, kRseqSignature); + if (sys_errno(ret) != 0) { return 1; } @@ -306,7 +317,7 @@ int TestInvalidAbortClearsCS() { } return 0; -}; +} // Exit codes: // 0 - Pass diff --git a/test/syscalls/linux/rseq/test.h b/test/syscalls/linux/rseq/test.h index 3b7bb74b1..ff0dd6e48 100644 --- a/test/syscalls/linux/rseq/test.h +++ b/test/syscalls/linux/rseq/test.h @@ -20,22 +20,20 @@ namespace testing { // Test cases supported by rseq binary. -inline constexpr char kRseqTestUnaligned[] = "unaligned"; -inline constexpr char kRseqTestRegister[] = "register"; -inline constexpr char kRseqTestDoubleRegister[] = "double-register"; -inline constexpr char kRseqTestRegisterUnregister[] = "register-unregister"; -inline constexpr char kRseqTestUnregisterDifferentPtr[] = - "unregister-different-ptr"; -inline constexpr char kRseqTestUnregisterDifferentSignature[] = +constexpr char kRseqTestUnaligned[] = "unaligned"; +constexpr char kRseqTestRegister[] = "register"; +constexpr char kRseqTestDoubleRegister[] = "double-register"; +constexpr char kRseqTestRegisterUnregister[] = "register-unregister"; +constexpr char kRseqTestUnregisterDifferentPtr[] = "unregister-different-ptr"; +constexpr char kRseqTestUnregisterDifferentSignature[] = "unregister-different-signature"; -inline constexpr char kRseqTestCPU[] = "cpu"; -inline constexpr char kRseqTestAbort[] = "abort"; -inline constexpr char kRseqTestAbortBefore[] = "abort-before"; -inline constexpr char kRseqTestAbortSignature[] = "abort-signature"; -inline constexpr char kRseqTestAbortPreCommit[] = "abort-precommit"; -inline constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs"; -inline constexpr char kRseqTestInvalidAbortClearsCS[] = - "invalid-abort-clears-cs"; +constexpr char kRseqTestCPU[] = "cpu"; +constexpr char kRseqTestAbort[] = "abort"; +constexpr char kRseqTestAbortBefore[] = "abort-before"; +constexpr char kRseqTestAbortSignature[] = "abort-signature"; +constexpr char kRseqTestAbortPreCommit[] = "abort-precommit"; +constexpr char kRseqTestAbortClearsCS[] = "abort-clears-cs"; +constexpr char kRseqTestInvalidAbortClearsCS[] = "invalid-abort-clears-cs"; } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/sendfile.cc b/test/syscalls/linux/sendfile.cc index 64123e904..a8bfb01f1 100644 --- a/test/syscalls/linux/sendfile.cc +++ b/test/syscalls/linux/sendfile.cc @@ -198,7 +198,39 @@ TEST(SendFileTest, SendAndUpdateFileOffset) { EXPECT_EQ(absl::string_view(kData, kHalfDataSize), absl::string_view(actual, bytes_sent)); - // Verify that the input file offset has been updated + // Verify that the input file offset has been updated. + ASSERT_THAT(read(inf.get(), &actual, kDataSize - bytes_sent), + SyscallSucceedsWithValue(kHalfDataSize)); + EXPECT_EQ( + absl::string_view(kData + kDataSize - bytes_sent, kDataSize - bytes_sent), + absl::string_view(actual, kHalfDataSize)); +} + +TEST(SendFileTest, SendToDevZeroAndUpdateFileOffset) { + // Create temp files. + // Test input string length must be > 2 AND even. + constexpr char kData[] = "The slings and arrows of outrageous fortune,"; + constexpr int kDataSize = sizeof(kData) - 1; + constexpr int kHalfDataSize = kDataSize / 2; + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode)); + + // Open the input file as read only. + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Open /dev/zero as write only. + const FileDescriptor outf = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY)); + + // Send data and verify that sendfile returns the correct value. + int bytes_sent; + EXPECT_THAT( + bytes_sent = sendfile(outf.get(), inf.get(), nullptr, kHalfDataSize), + SyscallSucceedsWithValue(kHalfDataSize)); + + char actual[kHalfDataSize]; + // Verify that the input file offset has been updated. ASSERT_THAT(read(inf.get(), &actual, kDataSize - bytes_sent), SyscallSucceedsWithValue(kHalfDataSize)); EXPECT_EQ( @@ -250,7 +282,7 @@ TEST(SendFileTest, SendAndUpdateFileOffsetFromNonzeroStartingPoint) { EXPECT_EQ(absl::string_view(kData + kQuarterDataSize, kHalfDataSize), absl::string_view(actual, bytes_sent)); - // Verify that the input file offset has been updated + // Verify that the input file offset has been updated. ASSERT_THAT(read(inf.get(), &actual, kQuarterDataSize), SyscallSucceedsWithValue(kQuarterDataSize)); @@ -501,6 +533,22 @@ TEST(SendFileTest, SendPipeWouldBlock) { SyscallFailsWithErrno(EWOULDBLOCK)); } +TEST(SendFileTest, SendPipeEOF) { + // Create and open an empty input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor inf = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Setup the output named pipe. + int fds[2]; + ASSERT_THAT(pipe2(fds, O_NONBLOCK), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + EXPECT_THAT(sendfile(wfd.get(), inf.get(), nullptr, 123), + SyscallSucceedsWithValue(0)); +} + TEST(SendFileTest, SendPipeBlocks) { // Create temp file. constexpr char kData[] = diff --git a/test/syscalls/linux/shm.cc b/test/syscalls/linux/shm.cc index c7fdbb924..d6e8b3e59 100644 --- a/test/syscalls/linux/shm.cc +++ b/test/syscalls/linux/shm.cc @@ -29,6 +29,8 @@ namespace testing { namespace { using ::testing::_; +using ::testing::AnyOf; +using ::testing::Eq; const uint64_t kAllocSize = kPageSize * 128ULL; @@ -394,7 +396,8 @@ TEST(ShmDeathTest, SegmentNotAccessibleAfterDetach) { }; EXPECT_THAT(InForkedProcess(rest), - IsPosixErrorOkAndHolds(W_EXITCODE(0, SIGSEGV))); + IsPosixErrorOkAndHolds(AnyOf(Eq(W_EXITCODE(0, SIGSEGV)), + Eq(W_EXITCODE(0, 128 + SIGSEGV))))); } TEST(ShmTest, RequestingSegmentSmallerThanSHMMINFails) { diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc index c20cd3fcc..e680d3dd7 100644 --- a/test/syscalls/linux/socket.cc +++ b/test/syscalls/linux/socket.cc @@ -14,6 +14,7 @@ #include <sys/socket.h> #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/types.h> #include <unistd.h> @@ -26,6 +27,9 @@ namespace gvisor { namespace testing { +// From linux/magic.h, but we can't depend on linux headers here. +#define SOCKFS_MAGIC 0x534F434B + TEST(SocketTest, UnixSocketPairProtocol) { int socks[2]; ASSERT_THAT(socketpair(AF_UNIX, SOCK_STREAM, PF_UNIX, socks), @@ -94,6 +98,19 @@ TEST(SocketTest, UnixSocketStat) { } } +TEST(SocketTest, UnixSocketStatFS) { + SKIP_IF(IsRunningWithVFS1()); + + FileDescriptor bound = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX)); + + struct statfs st; + EXPECT_THAT(fstatfs(bound.get(), &st), SyscallSucceeds()); + EXPECT_EQ(st.f_type, SOCKFS_MAGIC); + EXPECT_EQ(st.f_bsize, getpagesize()); + EXPECT_EQ(st.f_namelen, NAME_MAX); +} + using SocketOpenTest = ::testing::TestWithParam<int>; // UDS cannot be opened. diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index f7d6139f1..5d39e6fbd 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -462,7 +462,8 @@ TEST_P(AllSocketPairTest, SendTimeoutDefault) { TEST_P(AllSocketPairTest, SetGetSendTimeout) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - timeval tv = {.tv_sec = 89, .tv_usec = 42000}; + // tv_usec should be a multiple of 4000 to work on most systems. + timeval tv = {.tv_sec = 89, .tv_usec = 44000}; EXPECT_THAT( setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)), SyscallSucceeds()); @@ -472,8 +473,8 @@ TEST_P(AllSocketPairTest, SetGetSendTimeout) { EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &actual_tv, &len), SyscallSucceeds()); - EXPECT_EQ(actual_tv.tv_sec, 89); - EXPECT_EQ(actual_tv.tv_usec, 42000); + EXPECT_EQ(actual_tv.tv_sec, tv.tv_sec); + EXPECT_EQ(actual_tv.tv_usec, tv.tv_usec); } TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) { @@ -484,8 +485,9 @@ TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) { int64_t extra_data; } ABSL_ATTRIBUTE_PACKED; + // tv_usec should be a multiple of 4000 to work on most systems. timeval_with_extra tv_extra = { - .tv = {.tv_sec = 0, .tv_usec = 123000}, + .tv = {.tv_sec = 0, .tv_usec = 124000}, }; EXPECT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, @@ -497,8 +499,8 @@ TEST_P(AllSocketPairTest, SetGetSendTimeoutLargerArg) { EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_SNDTIMEO, &actual_tv, &len), SyscallSucceeds()); - EXPECT_EQ(actual_tv.tv.tv_sec, 0); - EXPECT_EQ(actual_tv.tv.tv_usec, 123000); + EXPECT_EQ(actual_tv.tv.tv_sec, tv_extra.tv.tv_sec); + EXPECT_EQ(actual_tv.tv.tv_usec, tv_extra.tv.tv_usec); } TEST_P(AllSocketPairTest, SendTimeoutAllowsWrite) { diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc index 6a232238d..6cd67123d 100644 --- a/test/syscalls/linux/socket_generic_stress.cc +++ b/test/syscalls/linux/socket_generic_stress.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <poll.h> #include <stdio.h> #include <sys/ioctl.h> #include <sys/socket.h> @@ -29,6 +30,9 @@ namespace testing { using ConnectStressTest = SocketPairTest; TEST_P(ConnectStressTest, Reset65kTimes) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + for (int i = 0; i < 1 << 16; ++i) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -37,6 +41,14 @@ TEST_P(ConnectStressTest, Reset65kTimes) { char sent_data[100] = {}; ASSERT_THAT(write(sockets->first_fd(), sent_data, sizeof(sent_data)), SyscallSucceedsWithValue(sizeof(sent_data))); + // Poll the other FD to make sure that the data is in the receive buffer + // before closing it to ensure a RST is triggered. + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = sockets->second_fd(), + .events = POLL_IN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); } } @@ -58,7 +70,54 @@ INSTANTIATE_TEST_SUITE_P( // a persistent listener (if applicable). using PersistentListenerConnectStressTest = SocketPairTest; -TEST_P(PersistentListenerConnectStressTest, 65kTimes) { +TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + + for (int i = 0; i < 1 << 16; ++i) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); + if (GetParam().type == SOCK_STREAM) { + // Poll the other FD to make sure that we see the FIN from the other + // side before closing the second_fd. This ensures that the first_fd + // enters TIME-WAIT and not second_fd. + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = sockets->second_fd(), + .events = POLL_IN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + } + ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); + } +} + +TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + + for (int i = 0; i < 1 << 16; ++i) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); + if (GetParam().type == SOCK_STREAM) { + // Poll the other FD to make sure that we see the FIN from the other + // side before closing the first_fd. This ensures that the second_fd + // enters TIME-WAIT and not first_fd. + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = sockets->first_fd(), + .events = POLL_IN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + } + ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); + } +} + +TEST_P(PersistentListenerConnectStressTest, 65kTimesClose) { + // TODO(b/165912341): These are too slow on KVM platform with nested virt. + SKIP_IF(GvisorPlatform() == Platform::kKVM); + for (int i = 0; i < 1 << 16; ++i) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 18b9e4b70..11fcec443 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -97,11 +97,9 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) { ASSERT_THAT(socketpair(AF_INET6, 0, 0, fd), SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - // Invalid AF will return ENOAFSUPPORT. - ASSERT_THAT(socketpair(AF_MAX, 0, 0, fd), - SyscallFailsWithErrno(EAFNOSUPPORT)); - ASSERT_THAT(socketpair(8675309, 0, 0, fd), - SyscallFailsWithErrno(EAFNOSUPPORT)); + // Invalid AF will fail. + ASSERT_THAT(socketpair(AF_MAX, 0, 0, fd), SyscallFails()); + ASSERT_THAT(socketpair(8675309, 0, 0, fd), SyscallFails()); } enum class Operation { @@ -116,7 +114,8 @@ std::string OperationToString(Operation operation) { return "Bind"; case Operation::Connect: return "Connect"; - case Operation::SendTo: + // Operation::SendTo is the default. + default: return "SendTo"; } } @@ -861,36 +860,38 @@ TEST_P(SocketInetLoopbackTest, TCPResetAfterClose) { SyscallSucceedsWithValue(0)); } -// This test is disabled under random save as the the restore run -// results in the stack.Seed() being different which can cause -// sequence number of final connect to be one that is considered -// old and can cause the test to be flaky. -TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { - auto const& param = GetParam(); - TestAddress const& listener = param.listener; - TestAddress const& connector = param.connector; - +// setupTimeWaitClose sets up a socket endpoint in TIME_WAIT state. +// Callers can choose to perform active close on either ends of the connection +// and also specify if they want to enabled SO_REUSEADDR. +void setupTimeWaitClose(const TestAddress* listener, + const TestAddress* connector, bool reuse, + bool accept_close, sockaddr_storage* listen_addr, + sockaddr_storage* conn_bound_addr) { // Create the listening socket. - const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener->family(), SOCK_STREAM, IPPROTO_TCP)); + if (reuse) { + ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + } + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(listen_addr), + listener->addr_len), SyscallSucceeds()); ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); // Get the port bound by the listening socket. - socklen_t addrlen = listener.addr_len; + socklen_t addrlen = listener->addr_len; ASSERT_THAT(getsockname(listen_fd.get(), - reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + reinterpret_cast<sockaddr*>(listen_addr), &addrlen), SyscallSucceeds()); uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener->family(), *listen_addr)); // Connect to the listening socket. FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + Socket(connector->family(), SOCK_STREAM, IPPROTO_TCP)); // We disable saves after this point as a S/R causes the netstack seed // to be regenerated which changes what ports/ISN is picked for a given @@ -901,11 +902,12 @@ TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { // // TODO(gvisor.dev/issue/940): S/R portSeed/portHint DisableSave ds; - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + + sockaddr_storage conn_addr = connector->addr; + ASSERT_NO_ERRNO(SetAddrPort(connector->family(), &conn_addr, port)); ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len), + connector->addr_len), SyscallSucceeds()); // Accept the connection. @@ -913,33 +915,146 @@ TEST_P(SocketInetLoopbackTest, TCPTimeWaitTest_NoRandomSave) { ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); // Get the address/port bound by the connecting socket. - sockaddr_storage conn_bound_addr; - socklen_t conn_addrlen = connector.addr_len; + socklen_t conn_addrlen = connector->addr_len; ASSERT_THAT( - getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + getsockname(conn_fd.get(), reinterpret_cast<sockaddr*>(conn_bound_addr), &conn_addrlen), SyscallSucceeds()); - // close the accept FD to trigger TIME_WAIT on the accepted socket which - // should cause the conn_fd to follow CLOSE_WAIT->LAST_ACK->CLOSED instead of - // TIME_WAIT. - accepted.reset(); - absl::SleepFor(absl::Seconds(1)); - conn_fd.reset(); + FileDescriptor active_closefd, passive_closefd; + if (accept_close) { + active_closefd = std::move(accepted); + passive_closefd = std::move(conn_fd); + } else { + active_closefd = std::move(conn_fd); + passive_closefd = std::move(accepted); + } + + // shutdown to trigger TIME_WAIT. + ASSERT_THAT(shutdown(active_closefd.get(), SHUT_RDWR), SyscallSucceeds()); + { + const int kTimeout = 10000; + struct pollfd pfd = { + .fd = passive_closefd.get(), + .events = POLLIN, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + ASSERT_EQ(pfd.revents, POLLIN); + } + ScopedThread t([&]() { + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + struct pollfd pfd = { + .fd = active_closefd.get(), + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + }); + + passive_closefd.reset(); + t.Join(); + active_closefd.reset(); + // This sleep is needed to reduce flake to ensure that the passive-close + // ensures the state transitions to CLOSE from LAST_ACK. absl::SleepFor(absl::Seconds(1)); +} - // Now bind and connect a new socket and verify that we can immediately - // rebind the address bound by the conn_fd as it never entered TIME_WAIT. - const FileDescriptor conn_fd2 = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); +// These tests are disabled under random save as the the restore run +// results in the stack.Seed() being different which can cause +// sequence number of final connect to be one that is considered +// old and can cause the test to be flaky. +// +// Test re-binding of client and server bound addresses when the older +// connection is in TIME_WAIT. +TEST_P(SocketInetLoopbackTest, TCPPassiveCloseNoTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, false /*reuse*/, + true /*accept_close*/, &listen_addr, &conn_bound_addr); - ASSERT_THAT(bind(conn_fd2.get(), - reinterpret_cast<sockaddr*>(&conn_bound_addr), conn_addrlen), + // Now bind a new socket and verify that we can immediately rebind the address + // bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), + + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + param.listener.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketInetLoopbackTest, + TCPPassiveCloseNoTimeWaitReuseTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, true /*reuse*/, + true /*accept_close*/, &listen_addr, &conn_bound_addr); + + FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.listener.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(listen_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + param.listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Now bind and connect new socket and verify that we can immediately rebind + // the address bound by the conn_fd as it never entered TIME_WAIT. + const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), + SyscallSucceeds()); + + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(param.listener.family(), listen_addr)); + sockaddr_storage conn_addr = param.connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(param.connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), - conn_addrlen), + param.connector.addr_len), + SyscallSucceeds()); +} + +TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, false /*reuse*/, + false /*accept_close*/, &listen_addr, &conn_bound_addr); + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(SocketInetLoopbackTest, TCPActiveCloseTimeWaitReuseTest_NoRandomSave) { + auto const& param = GetParam(); + sockaddr_storage listen_addr, conn_bound_addr; + listen_addr = param.listener.addr; + setupTimeWaitClose(¶m.listener, ¶m.connector, true /*reuse*/, + false /*accept_close*/, &listen_addr, &conn_bound_addr); + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(param.connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(setsockopt(conn_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), SyscallSucceeds()); + ASSERT_THAT(bind(conn_fd.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), + param.connector.addr_len), + SyscallFailsWithErrno(EADDRINUSE)); } TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { @@ -996,6 +1111,86 @@ TEST_P(SocketInetLoopbackTest, AcceptedInheritsTCPUserTimeout) { EXPECT_EQ(get, kUserTimeout); } +TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) { + auto const& param = GetParam(); + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(listen_fd.get(), SOMAXCONN), SyscallSucceeds()); + + // Get the port bound by the listening socket. + { + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + &addrlen), + SyscallSucceeds()); + } + + const uint16_t port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + + // Connect to the listening socket. + FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT(RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + // Trigger a RST by turning linger off and closing the socket. + struct linger opt = { + .l_onoff = 1, + .l_linger = 0, + }; + ASSERT_THAT( + setsockopt(conn_fd.get(), SOL_SOCKET, SO_LINGER, &opt, sizeof(opt)), + SyscallSucceeds()); + ASSERT_THAT(close(conn_fd.release()), SyscallSucceeds()); + + if (IsRunningOnGvisor()) { + // Gvisor packet procssing is asynchronous and can take a bit of time in + // some cases so we give it a bit of time to process the RST packet before + // calling accept. + // + // There is nothing to poll() on so we have no choice but to use a sleep + // here. + absl::SleepFor(absl::Milliseconds(100)); + } + + sockaddr_storage accept_addr; + socklen_t addrlen = sizeof(accept_addr); + + auto accept_fd = ASSERT_NO_ERRNO_AND_VALUE(Accept( + listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen)); + ASSERT_EQ(addrlen, listener.addr_len); + + // TODO(gvisor.dev/issue/3812): Remove after SO_ERROR is fixed. + if (IsRunningOnGvisor()) { + char buf[10]; + ASSERT_THAT(ReadFd(accept_fd.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(ECONNRESET)); + } else { + int err; + socklen_t optlen = sizeof(err); + ASSERT_THAT( + getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen), + SyscallSucceeds()); + ASSERT_EQ(err, ECONNRESET); + ASSERT_EQ(optlen, sizeof(err)); + } +} + // TODO(gvisor.dev/issue/1688): Partially completed passive endpoints are not // saved. Enable S/R once issue is fixed. TEST_P(SocketInetLoopbackTest, TCPDeferAccept_NoRandomSave) { @@ -2469,6 +2664,44 @@ TEST_P(SocketMultiProtocolInetLoopbackTest, V4EphemeralPortReservedReuseAddr) { SyscallSucceeds()); } +TEST_P(SocketMultiProtocolInetLoopbackTest, + MultipleBindsAllowedNoListeningReuseAddr) { + const auto& param = GetParam(); + // UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP + // this is only permitted if there is no other listening socket. + SKIP_IF(param.type != SOCK_STREAM); + // Bind the v4 loopback on a v4 socket. + const TestAddress& test_addr = V4Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + + // Now create a socket and bind it to the same port, this should + // succeed since there is no listening socket for the same port. + FileDescriptor second_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(second_fd.get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(bind(second_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len), + SyscallSucceeds()); +} + TEST_P(SocketMultiProtocolInetLoopbackTest, PortReuseTwoSockets) { auto const& param = GetParam(); TestAddress const& test_addr = V4Loopback(); diff --git a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc index 791e2bd51..1a0b53394 100644 --- a/test/syscalls/linux/socket_inet_loopback_nogotsan.cc +++ b/test/syscalls/linux/socket_inet_loopback_nogotsan.cc @@ -168,6 +168,71 @@ INSTANTIATE_TEST_SUITE_P( TestParam{V6Loopback(), V6Loopback()}), DescribeTestParam); +struct ProtocolTestParam { + std::string description; + int type; +}; + +std::string DescribeProtocolTestParam( + ::testing::TestParamInfo<ProtocolTestParam> const& info) { + return info.param.description; +} + +using SocketMultiProtocolInetLoopbackTest = + ::testing::TestWithParam<ProtocolTestParam>; + +TEST_P(SocketMultiProtocolInetLoopbackTest, + BindAvoidsListeningPortsReuseAddr_NoRandomSave) { + const auto& param = GetParam(); + // UDP sockets are allowed to bind/listen on the port w/ SO_REUSEADDR, for TCP + // this is only permitted if there is no other listening socket. + SKIP_IF(param.type != SOCK_STREAM); + + DisableSave ds; // Too many syscalls. + + // A map of port to file descriptor binding the port. + std::map<uint16_t, FileDescriptor> listen_sockets; + + // Exhaust all ephemeral ports. + while (true) { + // Bind the v4 loopback on a v4 socket. + TestAddress const& test_addr = V4Loopback(); + sockaddr_storage bound_addr = test_addr.addr; + FileDescriptor bound_fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(test_addr.family(), param.type, 0)); + + ASSERT_THAT(setsockopt(bound_fd.get(), SOL_SOCKET, SO_REUSEADDR, + &kSockOptOn, sizeof(kSockOptOn)), + SyscallSucceeds()); + + int ret = bind(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + test_addr.addr_len); + if (ret != 0) { + ASSERT_EQ(errno, EADDRINUSE); + break; + } + // Get the port that we bound. + socklen_t bound_addr_len = test_addr.addr_len; + ASSERT_THAT( + getsockname(bound_fd.get(), reinterpret_cast<sockaddr*>(&bound_addr), + &bound_addr_len), + SyscallSucceeds()); + uint16_t port = reinterpret_cast<sockaddr_in*>(&bound_addr)->sin_port; + + // Newly bound port should not already be in use by a listening socket. + ASSERT_EQ(listen_sockets.find(port), listen_sockets.end()); + auto fd = bound_fd.get(); + listen_sockets.insert(std::make_pair(port, std::move(bound_fd))); + ASSERT_THAT(listen(fd, SOMAXCONN), SyscallSucceeds()); + } +} + +INSTANTIATE_TEST_SUITE_P( + AllFamilies, SocketMultiProtocolInetLoopbackTest, + ::testing::Values(ProtocolTestParam{"TCP", SOCK_STREAM}, + ProtocolTestParam{"UDP", SOCK_DGRAM}), + DescribeProtocolTestParam); + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index c2ecb639f..f4b69c46c 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -34,6 +34,9 @@ namespace gvisor { namespace testing { +using ::testing::AnyOf; +using ::testing::Eq; + TEST_P(TCPSocketPairTest, TcpInfoSucceeds) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -800,6 +803,9 @@ TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { // Linux and Netstack both default to a 60s TCP_LINGER2 timeout. constexpr int kDefaultTCPLingerTimeout = 60; +// On Linux, the maximum linger2 timeout was changed from 60sec to 120sec. +constexpr int kMaxTCPLingerTimeout = 120; +constexpr int kOldMaxTCPLingerTimeout = 60; TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -813,26 +819,45 @@ TEST_P(TCPSocketPairTest, TCPLingerTimeoutDefault) { EXPECT_EQ(get, kDefaultTCPLingerTimeout); } -TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZeroOrLess) { +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutLessThanZero) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); - constexpr int kZero = 0; - EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, - sizeof(kZero)), - SyscallSucceedsWithValue(0)); - constexpr int kNegative = -1234; EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kNegative, sizeof(kNegative)), SyscallSucceedsWithValue(0)); + int get = INT_MAX; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_EQ(get, -1); +} + +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutZero) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + constexpr int kZero = 0; + EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kZero, + sizeof(kZero)), + SyscallSucceedsWithValue(0)); + int get = -1; + socklen_t get_len = sizeof(get); + EXPECT_THAT( + getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(get_len, sizeof(get)); + EXPECT_THAT(get, + AnyOf(Eq(kMaxTCPLingerTimeout), Eq(kOldMaxTCPLingerTimeout))); } -TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) { +TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveMax) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); // Values above the net.ipv4.tcp_fin_timeout are capped to tcp_fin_timeout // on linux (defaults to 60 seconds on linux). - constexpr int kAboveDefault = kDefaultTCPLingerTimeout + 1; + constexpr int kAboveDefault = kMaxTCPLingerTimeout + 1; EXPECT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &kAboveDefault, sizeof(kAboveDefault)), SyscallSucceedsWithValue(0)); @@ -843,7 +868,12 @@ TEST_P(TCPSocketPairTest, SetTCPLingerTimeoutAboveDefault) { getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_LINGER2, &get, &get_len), SyscallSucceedsWithValue(0)); EXPECT_EQ(get_len, sizeof(get)); - EXPECT_EQ(get, kDefaultTCPLingerTimeout); + if (IsRunningOnGvisor()) { + EXPECT_EQ(get, kMaxTCPLingerTimeout); + } else { + EXPECT_THAT(get, + AnyOf(Eq(kMaxTCPLingerTimeout), Eq(kOldMaxTCPLingerTimeout))); + } } TEST_P(TCPSocketPairTest, SetTCPLingerTimeout) { @@ -1050,5 +1080,124 @@ TEST_P(TCPSocketPairTest, TCPResetDuringClose_NoRandomSave) { } } +// Test setsockopt and getsockopt for a socket with SO_LINGER option. +TEST_P(TCPSocketPairTest, SetAndGetLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Check getsockopt before SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_THAT(got_len, sizeof(got_linger)); + struct linger want_linger = {}; + EXPECT_EQ(0, memcmp(&want_linger, &got_linger, got_len)); + + // Set and get SO_LINGER with negative values. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = -3; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(sl.l_onoff, got_linger.l_onoff); + // Linux returns a different value as it uses HZ to convert the seconds to + // jiffies which overflows for negative values. We want to be compatible with + // linux for getsockopt return value. + if (IsRunningOnGvisor()) { + EXPECT_EQ(sl.l_linger, got_linger.l_linger); + } + + // Set and get SO_LINGER option with positive values. + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); +} + +// Test socket to disable SO_LINGER option. +TEST_P(TCPSocketPairTest, SetOffLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Set the SO_LINGER option. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); + + sl.l_onoff = 0; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set to zero. + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); +} + +// Test close on dup'd socket with SO_LINGER option set. +TEST_P(TCPSocketPairTest, CloseWithLingerOption) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + // Set the SO_LINGER option. + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + // Check getsockopt after SO_LINGER option is set. + struct linger got_linger = {-1, -1}; + socklen_t got_len = sizeof(got_linger); + ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &got_len), + SyscallSucceeds()); + ASSERT_EQ(got_len, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, got_len)); + + FileDescriptor dupFd = FileDescriptor(dup(sockets->first_fd())); + ASSERT_THAT(close(sockets->release_first_fd()), SyscallSucceeds()); + char buf[10] = {}; + // Write on dupFd should succeed as socket will not be closed until + // all references are removed. + ASSERT_THAT(RetryEINTR(write)(dupFd.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sockets->first_fd(), buf, sizeof(buf)), + SyscallFailsWithErrno(EBADF)); + + // Close the socket. + dupFd.reset(); + // Write on dupFd should fail as all references for socket are removed. + ASSERT_THAT(RetryEINTR(write)(dupFd.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(EBADF)); +} } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc index edb86aded..3f2c0fdf2 100644 --- a/test/syscalls/linux/socket_ip_udp_generic.cc +++ b/test/syscalls/linux/socket_ip_udp_generic.cc @@ -435,8 +435,10 @@ TEST_P(UDPSocketPairTest, TOSRecvMismatch) { // Test that an IPv4 socket does not support the IPv6 TClass option. TEST_P(UDPSocketPairTest, TClassRecvMismatch) { - // This should only test AF_INET sockets for the mismatch behavior. - SKIP_IF(GetParam().domain != AF_INET); + // This should only test AF_INET6 sockets for the mismatch behavior. + SKIP_IF(GetParam().domain != AF_INET6); + // IPV6_RECVTCLASS is only valid for SOCK_DGRAM and SOCK_RAW. + SKIP_IF(GetParam().type != SOCK_DGRAM | GetParam().type != SOCK_RAW); auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); @@ -448,5 +450,27 @@ TEST_P(UDPSocketPairTest, TClassRecvMismatch) { SyscallFailsWithErrno(EOPNOTSUPP)); } +// Test the SO_LINGER option can be set/get on udp socket. +TEST_P(UDPSocketPairTest, SetAndGetSocketLinger) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + int level = SOL_SOCKET; + int type = SO_LINGER; + + struct linger sl; + sl.l_onoff = 1; + sl.l_linger = 5; + ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + ASSERT_THAT( + getsockopt(sockets->first_fd(), level, type, &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&sl, &got_linger, length)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ip_unbound.cc b/test/syscalls/linux/socket_ip_unbound.cc index 1c7b0cf90..8f7ccc868 100644 --- a/test/syscalls/linux/socket_ip_unbound.cc +++ b/test/syscalls/linux/socket_ip_unbound.cc @@ -217,6 +217,8 @@ TEST_P(IPUnboundSocketTest, InvalidLargeTOS) { } TEST_P(IPUnboundSocketTest, CheckSkipECN) { + // Test is inconsistant on different kernels. + SKIP_IF(!IsRunningOnGvisor()); auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); int set = 0xFF; socklen_t set_sz = sizeof(set); diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index de0f5f01b..a72c76c97 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -27,6 +27,7 @@ #include "absl/memory/memory.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/posix_error.h" #include "test/util/test_util.h" namespace gvisor { @@ -73,9 +74,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNoGroup) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } // Check that not setting a default send interface prevents multicast packets @@ -207,8 +208,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackAddr) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -262,8 +264,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackNic) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -317,8 +320,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddr) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -372,8 +376,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNic) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -431,8 +436,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrConnect) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -490,8 +496,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicConnect) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -545,8 +552,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelf) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -600,8 +608,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelf) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -659,9 +668,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfConnect) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } // Check that multicast works when the default send interface is configured by @@ -717,9 +726,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfConnect) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } // Check that multicast works when the default send interface is configured by @@ -775,8 +784,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfAddrSelfNoLoop) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -834,8 +844,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastLoopbackIfNicSelfNoLoop) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket1->get(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket1->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -907,9 +918,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropAddr) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } // Check that dropping a group membership prevents multicast packets from being @@ -965,9 +976,9 @@ TEST_P(IPv4UDPUnboundSocketTest, IpMulticastDropNic) { // Check that we did not receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIfZero) { @@ -1319,9 +1330,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionOnTwoSockets) { // Check that we received the multicast packet on both sockets. for (auto& sockets : socket_pairs) { char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, + sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } } @@ -1398,9 +1409,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { // Check that we received the multicast packet on both sockets. for (auto& sockets : socket_pairs) { char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT( - RetryEINTR(recv)(sockets->second_fd(), recv_buf, sizeof(recv_buf), 0), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, + sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } } @@ -1421,9 +1432,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestMcastReceptionWhenDroppingMemberships) { char recv_buf[sizeof(send_buf)] = {}; for (auto& sockets : socket_pairs) { - ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), recv_buf, - sizeof(recv_buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + ASSERT_THAT(RecvMsgTimeout(sockets->second_fd(), recv_buf, + sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } } } @@ -1474,9 +1485,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenJoinThenReceive) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1518,9 +1529,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenNoJoinThenNoReceive) { // Check that we don't receive the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } // Check that a socket can bind to a multicast address and still send out @@ -1568,9 +1579,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToMcastThenSend) { // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1615,9 +1626,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenReceive) { // Check that we received the multicast packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1666,9 +1677,9 @@ TEST_P(IPv4UDPUnboundSocketTest, TestBindToBcastThenSend) { // Check that we received the packet. char recv_buf[sizeof(send_buf)] = {}; - ASSERT_THAT(RetryEINTR(recv)(socket2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(recv_buf))); + ASSERT_THAT( + RecvMsgTimeout(socket2->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(recv_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } @@ -1726,17 +1737,17 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrDistribution_NoRandomSave) { // of the other sockets to have received it, but we will check that later. char recv_buf[sizeof(send_buf)] = {}; EXPECT_THAT( - RetryEINTR(recv)(last->get(), recv_buf, sizeof(recv_buf), MSG_DONTWAIT), - SyscallSucceedsWithValue(sizeof(send_buf))); + RecvMsgTimeout(last->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(send_buf))); EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); } // Verify that no other messages were received. for (auto& socket : sockets) { char recv_buf[kMessageSize] = {}; - EXPECT_THAT(RetryEINTR(recv)(socket->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT(RecvMsgTimeout(socket->get(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), + PosixErrorIs(EAGAIN, ::testing::_)); } } @@ -2113,45 +2124,12 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) { // balancing (REUSEPORT) instead of the most recently bound socket // (REUSEADDR). char recv_buf[kMessageSize] = {}; - EXPECT_THAT(RetryEINTR(recv)(receiver1->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(kMessageSize)); - EXPECT_THAT(RetryEINTR(recv)(receiver2->get(), recv_buf, sizeof(recv_buf), - MSG_DONTWAIT), - SyscallSucceedsWithValue(kMessageSize)); -} - -// Check that connect returns EADDRNOTAVAIL when out of local ephemeral ports. -// We disable S/R because this test creates a large number of sockets. -TEST_P(IPv4UDPUnboundSocketTest, UDPConnectPortExhaustion_NoRandomSave) { - auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - constexpr int kClients = 65536; - // Bind the first socket to the loopback and take note of the selected port. - auto addr = V4Loopback(); - ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len), - SyscallSucceeds()); - socklen_t addr_len = addr.addr_len; - ASSERT_THAT(getsockname(receiver1->get(), - reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), - SyscallSucceeds()); - EXPECT_EQ(addr_len, addr.addr_len); - - // Disable cooperative S/R as we are making too many syscalls. - DisableSave ds; - std::vector<std::unique_ptr<FileDescriptor>> sockets; - for (int i = 0; i < kClients; i++) { - auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); - - int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), - addr.addr_len); - if (ret == 0) { - sockets.push_back(std::move(s)); - continue; - } - ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN)); - break; - } + EXPECT_THAT(RecvMsgTimeout(receiver1->get(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), + IsPosixErrorOkAndHolds(kMessageSize)); + EXPECT_THAT(RecvMsgTimeout(receiver2->get(), recv_buf, sizeof(recv_buf), + 1 /*timeout*/), + IsPosixErrorOkAndHolds(kMessageSize)); } // Test that socket will receive packet info control message. @@ -2452,5 +2430,105 @@ TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBuf) { ASSERT_EQ(quarter_sz, val); } + +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) { + auto sender_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind the first FD to the loopback. This is an alternative to + // IP_MULTICAST_IF for setting the default send interface. + auto sender_addr = V4Loopback(); + ASSERT_THAT( + bind(sender_socket->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + + // Bind the second FD to the v4 any address to ensure that we can receive the + // multicast packet. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(receiver_socket->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver_socket->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Register to receive multicast packets. + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + + // Register to receive IP packet info. + const int one = 1; + ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_PKTINFO, &one, + sizeof(one)), + SyscallSucceeds()); + + // Send a multicast packet. + auto send_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sender_socket->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + msghdr recv_msg = {}; + iovec recv_iov = {}; + char recv_buf[sizeof(send_buf)]; + char recv_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {}; + size_t cmsg_data_len = sizeof(in_pktinfo); + recv_iov.iov_base = recv_buf; + recv_iov.iov_len = sizeof(recv_buf); + recv_msg.msg_iov = &recv_iov; + recv_msg.msg_iovlen = 1; + recv_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + recv_msg.msg_control = recv_cmsg_buf; + ASSERT_THAT(RetryEINTR(recvmsg)(receiver_socket->get(), &recv_msg, 0), + SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + + // Check the IP_PKTINFO control message. + cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IP); + EXPECT_EQ(cmsg->cmsg_type, IP_PKTINFO); + + // Get loopback index. + ifreq ifr = {}; + absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo"); + ASSERT_THAT(ioctl(receiver_socket->get(), SIOCGIFINDEX, &ifr), + SyscallSucceeds()); + ASSERT_NE(ifr.ifr_ifindex, 0); + + in_pktinfo received_pktinfo = {}; + memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo)); + EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex); + if (IsRunningOnGvisor()) { + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, group.imr_multiaddr.s_addr); + } else { + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK)); + } + EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, group.imr_multiaddr.s_addr); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc index d690d9564..b206137eb 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.cc @@ -42,7 +42,9 @@ TestAddress V4EmptyAddress() { } void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { - got_if_infos_ = false; + // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its + // IPv4 address on eth0. + found_net_interfaces_ = false; // Get interface list. ASSERT_NO_ERRNO(if_helper_.Load()); @@ -71,7 +73,7 @@ void IPv4UDPUnboundExternalNetworkingSocketTest::SetUp() { } eth_if_addr_ = *reinterpret_cast<const sockaddr_in*>(eth_if_addr); - got_if_infos_ = true; + found_net_interfaces_ = true; } // Verifies that a newly instantiated UDP socket does not have the @@ -110,6 +112,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, SetUDPBroadcast) { // the destination port number. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastReceivedOnExpectedPort) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr2 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -185,9 +188,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // not a unicast address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastReceivedOnExpectedAddresses) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto rcvr1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -272,6 +273,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // (UDPBroadcastSendRecvOnSocketBoundToAny). TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastSendRecvOnSocketBoundToBroadcast) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Enable SO_BROADCAST. @@ -313,6 +315,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // (UDPBroadcastSendRecvOnSocketBoundToBroadcast). TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, UDPBroadcastSendRecvOnSocketBoundToAny) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Enable SO_BROADCAST. @@ -351,6 +354,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Verifies that a UDP broadcast fails to send on a socket with SO_BROADCAST // disabled. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendBroadcast) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); // Broadcast a test message without having enabled SO_BROADCAST on the sending @@ -401,6 +405,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // multicast on gVisor. SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(!found_net_interfaces_); + auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); @@ -435,6 +441,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // Check that multicast packets will be delivered to the sending socket without // setting an interface. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) { + SKIP_IF(!found_net_interfaces_); auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); @@ -478,6 +485,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelf) { // set interface and IP_MULTICAST_LOOP disabled. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSelfLoopOff) { + SKIP_IF(!found_net_interfaces_); auto socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto bind_addr = V4Any(); @@ -528,6 +536,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) { // multicast on gVisor. SKIP_IF(IsRunningOnGvisor()); + SKIP_IF(!found_net_interfaces_); + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -566,6 +576,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastNoGroup) { // Check that multicast packets will be delivered to another socket without // setting an interface. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -613,6 +624,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticast) { // set interface and IP_MULTICAST_LOOP disabled on the sending socket. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastSenderNoLoop) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -664,6 +676,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // setting an interface and IP_MULTICAST_LOOP disabled on the receiving socket. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastReceiverNoLoop) { + SKIP_IF(!found_net_interfaces_); + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -716,6 +730,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // and both will receive data on it when bound to the ANY address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwoBoundToAny) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -782,6 +797,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // and both will receive data on it when bound to the multicast address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwoBoundToMulticastAddress) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -851,6 +867,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // multicast address, both will receive data. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, TestSendMulticastToTwoBoundToAnyAndMulticastAddress) { + SKIP_IF(!found_net_interfaces_); auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); std::unique_ptr<FileDescriptor> receivers[2] = { ASSERT_NO_ERRNO_AND_VALUE(NewSocket()), @@ -924,6 +941,8 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // is not a multicast address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, IpMulticastLoopbackFromAddr) { + SKIP_IF(!found_net_interfaces_); + auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -991,9 +1010,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // interface, a multicast packet sent out uses the latter as its source address. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, IpMulticastLoopbackIfNicAndAddr) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); + SKIP_IF(!found_net_interfaces_); // Create receiver, bind to ANY and join the multicast group. auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); @@ -1059,9 +1076,7 @@ TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, // another interface. TEST_P(IPv4UDPUnboundExternalNetworkingSocketTest, IpMulticastLoopbackBindToOneIfSetMcastIfToAnother) { - // FIXME(b/137899561): Linux instance for syscall tests sometimes misses its - // IPv4 address on eth0. - SKIP_IF(!got_if_infos_); + SKIP_IF(!found_net_interfaces_); // FIXME (b/137790511): When bound to one interface it is not possible to set // IP_MULTICAST_IF to a different interface. diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h index 10b90b1e0..0e9e70e8e 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_external_networking.h @@ -29,9 +29,9 @@ class IPv4UDPUnboundExternalNetworkingSocketTest : public SimpleSocketTest { IfAddrHelper if_helper_; - // got_if_infos_ is set to false if SetUp() could not obtain all interface - // infos that we need. - bool got_if_infos_; + // found_net_interfaces_ is set to false if SetUp() could not obtain + // all interface infos that we need. + bool found_net_interfaces_; // Interface infos. int lo_if_idx_; diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_netlink.cc new file mode 100644 index 000000000..8052bf404 --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_netlink.cc @@ -0,0 +1,32 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <vector> + +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +INSTANTIATE_TEST_SUITE_P( + IPv4UDPSockets, IPv4UDPUnboundSocketNetlinkTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc new file mode 100644 index 000000000..bcbd2feac --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_loopback_nogotsan.cc @@ -0,0 +1,94 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <sys/socket.h> +#include <sys/types.h> + +#include "gtest/gtest.h" +#include "absl/memory/memory.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv4 UDP sockets. +using IPv4UDPUnboundSocketNogotsanTest = SimpleSocketTest; + +// Check that connect returns EAGAIN when out of local ephemeral ports. +// We disable S/R because this test creates a large number of sockets. +TEST_P(IPv4UDPUnboundSocketNogotsanTest, + UDPConnectPortExhaustion_NoRandomSave) { + auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + constexpr int kClients = 65536; + // Bind the first socket to the loopback and take note of the selected port. + auto addr = V4Loopback(); + ASSERT_THAT(bind(receiver1->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len), + SyscallSucceeds()); + socklen_t addr_len = addr.addr_len; + ASSERT_THAT(getsockname(receiver1->get(), + reinterpret_cast<sockaddr*>(&addr.addr), &addr_len), + SyscallSucceeds()); + EXPECT_EQ(addr_len, addr.addr_len); + + // Disable cooperative S/R as we are making too many syscalls. + DisableSave ds; + std::vector<std::unique_ptr<FileDescriptor>> sockets; + for (int i = 0; i < kClients; i++) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int ret = connect(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), + addr.addr_len); + if (ret == 0) { + sockets.push_back(std::move(s)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EAGAIN)); + break; + } +} + +// Check that bind returns EADDRINUSE when out of local ephemeral ports. +// We disable S/R because this test creates a large number of sockets. +TEST_P(IPv4UDPUnboundSocketNogotsanTest, UDPBindPortExhaustion_NoRandomSave) { + auto receiver1 = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + constexpr int kClients = 65536; + auto addr = V4Loopback(); + // Disable cooperative S/R as we are making too many syscalls. + DisableSave ds; + std::vector<std::unique_ptr<FileDescriptor>> sockets; + for (int i = 0; i < kClients; i++) { + auto s = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + int ret = + bind(s->get(), reinterpret_cast<sockaddr*>(&addr.addr), addr.addr_len); + if (ret == 0) { + sockets.push_back(std::move(s)); + continue; + } + ASSERT_THAT(ret, SyscallFailsWithErrno(EADDRINUSE)); + break; + } +} + +INSTANTIATE_TEST_SUITE_P( + IPv4UDPSockets, IPv4UDPUnboundSocketNogotsanTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv4UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc new file mode 100644 index 000000000..49a0f06d9 --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.cc @@ -0,0 +1,225 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h" + +#include <arpa/inet.h> +#include <poll.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/util/capability_util.h" +#include "test/util/cleanup.h" + +namespace gvisor { +namespace testing { + +constexpr size_t kSendBufSize = 200; + +// Checks that the loopback interface considers itself bound to all IPs in an +// associated subnet. +TEST_P(IPv4UDPUnboundSocketNetlinkTest, JoinSubnet) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Add an IP address to the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in_addr addr; + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr)); + ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr))); + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(addr)] { + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses + // via netlink is supported. + EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); + } else { + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); + } + }); + + auto snd_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto rcv_sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Send from an unassigned address but an address that is in the subnet + // associated with the loopback interface. + TestAddress sender_addr("V4NotAssignd1"); + sender_addr.addr.ss_family = AF_INET; + sender_addr.addr_len = sizeof(sockaddr_in); + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.2", + &(reinterpret_cast<sockaddr_in*>(&sender_addr.addr) + ->sin_addr.s_addr))); + ASSERT_THAT( + bind(snd_sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + + // Send the packet to an unassigned address but an address that is in the + // subnet associated with the loopback interface. + TestAddress receiver_addr("V4NotAssigned2"); + receiver_addr.addr.ss_family = AF_INET; + receiver_addr.addr_len = sizeof(sockaddr_in); + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.254", + &(reinterpret_cast<sockaddr_in*>(&receiver_addr.addr) + ->sin_addr.s_addr))); + ASSERT_THAT( + bind(rcv_sock->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(rcv_sock->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + ASSERT_EQ(receiver_addr_len, receiver_addr.addr_len); + char send_buf[kSendBufSize]; + RandomizeBuffer(send_buf, kSendBufSize); + ASSERT_THAT( + RetryEINTR(sendto)(snd_sock->get(), send_buf, kSendBufSize, 0, + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceedsWithValue(kSendBufSize)); + + // Check that we received the packet. + char recv_buf[kSendBufSize] = {}; + ASSERT_THAT(RetryEINTR(recv)(rcv_sock->get(), recv_buf, kSendBufSize, 0), + SyscallSucceedsWithValue(kSendBufSize)); + ASSERT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize)); +} + +// Tests that broadcast packets are delivered to all interested sockets +// (wildcard and broadcast address specified sockets). +// +// Note, we cannot test the IPv4 Broadcast (255.255.255.255) because we do +// not have a route to it. +TEST_P(IPv4UDPUnboundSocketNetlinkTest, ReuseAddrSubnetDirectedBroadcast) { + constexpr uint16_t kPort = 9876; + // Wait up to 20 seconds for the data. + constexpr int kPollTimeoutMs = 20000; + // Number of sockets per socket type. + constexpr int kNumSocketsPerType = 2; + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Add an IP address to the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in_addr addr; + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.1", &addr)); + ASSERT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET, + 24 /* prefixlen */, &addr, sizeof(addr))); + Cleanup defer_addr_removal = Cleanup( + [loopback_link = std::move(loopback_link), addr = std::move(addr)] { + if (IsRunningOnGvisor()) { + // TODO(gvisor.dev/issue/3921): Remove this once deleting addresses + // via netlink is supported. + EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, sizeof(addr)), + PosixErrorIs(EOPNOTSUPP, ::testing::_)); + } else { + EXPECT_NO_ERRNO(LinkDelLocalAddr(loopback_link.index, AF_INET, + /*prefixlen=*/24, &addr, + sizeof(addr))); + } + }); + + TestAddress broadcast_address("SubnetBroadcastAddress"); + broadcast_address.addr.ss_family = AF_INET; + broadcast_address.addr_len = sizeof(sockaddr_in); + auto broadcast_address_in = + reinterpret_cast<sockaddr_in*>(&broadcast_address.addr); + ASSERT_EQ(1, inet_pton(AF_INET, "192.0.2.255", + &broadcast_address_in->sin_addr.s_addr)); + broadcast_address_in->sin_port = htons(kPort); + + TestAddress any_address = V4Any(); + reinterpret_cast<sockaddr_in*>(&any_address.addr)->sin_port = htons(kPort); + + // We create sockets bound to both the wildcard address and the broadcast + // address to make sure both of these types of "broadcast interested" sockets + // receive broadcast packets. + std::vector<std::unique_ptr<FileDescriptor>> socks; + for (bool bind_wildcard : {false, true}) { + // Create multiple sockets for each type of "broadcast interested" + // socket so we can test that all sockets receive the broadcast packet. + for (int i = 0; i < kNumSocketsPerType; i++) { + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto idx = socks.size(); + + ASSERT_THAT(setsockopt(sock->get(), SOL_SOCKET, SO_REUSEADDR, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)) + << "socks[" << idx << "]"; + + ASSERT_THAT(setsockopt(sock->get(), SOL_SOCKET, SO_BROADCAST, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)) + << "socks[" << idx << "]"; + + if (bind_wildcard) { + ASSERT_THAT( + bind(sock->get(), reinterpret_cast<sockaddr*>(&any_address.addr), + any_address.addr_len), + SyscallSucceeds()) + << "socks[" << idx << "]"; + } else { + ASSERT_THAT(bind(sock->get(), + reinterpret_cast<sockaddr*>(&broadcast_address.addr), + broadcast_address.addr_len), + SyscallSucceeds()) + << "socks[" << idx << "]"; + } + + socks.push_back(std::move(sock)); + } + } + + char send_buf[kSendBufSize]; + RandomizeBuffer(send_buf, kSendBufSize); + + // Broadcasts from each socket should be received by every socket (including + // the sending socket). + for (int w = 0; w < socks.size(); w++) { + auto& w_sock = socks[w]; + ASSERT_THAT( + RetryEINTR(sendto)(w_sock->get(), send_buf, kSendBufSize, 0, + reinterpret_cast<sockaddr*>(&broadcast_address.addr), + broadcast_address.addr_len), + SyscallSucceedsWithValue(kSendBufSize)) + << "write socks[" << w << "]"; + + // Check that we received the packet on all sockets. + for (int r = 0; r < socks.size(); r++) { + auto& r_sock = socks[r]; + + struct pollfd poll_fd = {r_sock->get(), POLLIN, 0}; + EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)) + << "write socks[" << w << "] & read socks[" << r << "]"; + + char recv_buf[kSendBufSize] = {}; + EXPECT_THAT(RetryEINTR(recv)(r_sock->get(), recv_buf, kSendBufSize, 0), + SyscallSucceedsWithValue(kSendBufSize)) + << "write socks[" << w << "] & read socks[" << r << "]"; + EXPECT_EQ(0, memcmp(send_buf, recv_buf, kSendBufSize)) + << "write socks[" << w << "] & read socks[" << r << "]"; + } + } +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h new file mode 100644 index 000000000..73e7836d5 --- /dev/null +++ b/test/syscalls/linux/socket_ipv4_udp_unbound_netlink.h @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_NETLINK_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_NETLINK_UTIL_H_ + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv4 UDP sockets. +using IPv4UDPUnboundSocketNetlinkTest = SimpleSocketTest; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_NETLINK_UTIL_H_ diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_loopback_netlink.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback_netlink.cc new file mode 100644 index 000000000..17021ff82 --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_loopback_netlink.cc @@ -0,0 +1,32 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <vector> + +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +INSTANTIATE_TEST_SUITE_P( + IPv6UDPSockets, IPv6UDPUnboundSocketNetlinkTest, + ::testing::ValuesIn(ApplyVec<SocketKind>(IPv6UDPUnboundSocket, + AllBitwiseCombinations(List<int>{ + 0, SOCK_NONBLOCK})))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc new file mode 100644 index 000000000..2ee218231 --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc @@ -0,0 +1,53 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h" + +#include <arpa/inet.h> + +#include "gtest/gtest.h" +#include "test/syscalls/linux/socket_netlink_route_util.h" +#include "test/util/capability_util.h" + +namespace gvisor { +namespace testing { + +// Checks that the loopback interface does not consider itself bound to all IPs +// in an associated subnet. +TEST_P(IPv6UDPUnboundSocketNetlinkTest, JoinSubnet) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + + // Add an IP address to the loopback interface. + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + struct in6_addr addr; + EXPECT_EQ(1, inet_pton(AF_INET6, "2001:db8::1", &addr)); + EXPECT_NO_ERRNO(LinkAddLocalAddr(loopback_link.index, AF_INET6, + /*prefixlen=*/64, &addr, sizeof(addr))); + + // Binding to an unassigned address but an address that is in the subnet + // associated with the loopback interface should fail. + TestAddress sender_addr("V6NotAssignd1"); + sender_addr.addr.ss_family = AF_INET6; + sender_addr.addr_len = sizeof(sockaddr_in6); + EXPECT_EQ(1, inet_pton(AF_INET6, "2001:db8::2", + reinterpret_cast<sockaddr_in6*>(&sender_addr.addr) + ->sin6_addr.s6_addr)); + auto sock = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + EXPECT_THAT(bind(sock->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h new file mode 100644 index 000000000..88098be82 --- /dev/null +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.h @@ -0,0 +1,29 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_NETLINK_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_NETLINK_UTIL_H_ + +#include "test/syscalls/linux/socket_test_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to IPv6 UDP sockets. +using IPv6UDPUnboundSocketNetlinkTest = SimpleSocketTest; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_NETLINK_UTIL_H_ diff --git a/test/syscalls/linux/socket_netlink_route_util.cc b/test/syscalls/linux/socket_netlink_route_util.cc index bde1dbb4d..7a0bad4cb 100644 --- a/test/syscalls/linux/socket_netlink_route_util.cc +++ b/test/syscalls/linux/socket_netlink_route_util.cc @@ -26,6 +26,62 @@ namespace { constexpr uint32_t kSeq = 12345; +// Types of address modifications that may be performed on an interface. +enum class LinkAddrModification { + kAdd, + kDelete, +}; + +// Populates |hdr| with appripriate values for the modification type. +PosixError PopulateNlmsghdr(LinkAddrModification modification, + struct nlmsghdr* hdr) { + switch (modification) { + case LinkAddrModification::kAdd: + hdr->nlmsg_type = RTM_NEWADDR; + hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + return NoError(); + case LinkAddrModification::kDelete: + hdr->nlmsg_type = RTM_DELADDR; + hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; + return NoError(); + } + + return PosixError(EINVAL); +} + +// Adds or removes the specified address from the specified interface. +PosixError LinkModifyLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen, + LinkAddrModification modification) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifaddrmsg ifaddr; + char attrbuf[512]; + }; + + struct request req = {}; + PosixError err = PopulateNlmsghdr(modification, &req.hdr); + if (!err.ok()) { + return err; + } + req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); + req.hdr.nlmsg_seq = kSeq; + req.ifaddr.ifa_index = index; + req.ifaddr.ifa_family = family; + req.ifaddr.ifa_prefixlen = prefixlen; + + struct rtattr* rta = reinterpret_cast<struct rtattr*>( + reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); + rta->rta_type = IFA_LOCAL; + rta->rta_len = RTA_LENGTH(addrlen); + req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); + memcpy(RTA_DATA(rta), addr, addrlen); + + return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +} + } // namespace PosixError DumpLinks( @@ -84,31 +140,14 @@ PosixErrorOr<Link> LoopbackLink() { PosixError LinkAddLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen) { - ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, NetlinkBoundSocket(NETLINK_ROUTE)); - - struct request { - struct nlmsghdr hdr; - struct ifaddrmsg ifaddr; - char attrbuf[512]; - }; - - struct request req = {}; - req.hdr.nlmsg_len = NLMSG_LENGTH(sizeof(req.ifaddr)); - req.hdr.nlmsg_type = RTM_NEWADDR; - req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK; - req.hdr.nlmsg_seq = kSeq; - req.ifaddr.ifa_index = index; - req.ifaddr.ifa_family = family; - req.ifaddr.ifa_prefixlen = prefixlen; - - struct rtattr* rta = reinterpret_cast<struct rtattr*>( - reinterpret_cast<int8_t*>(&req) + NLMSG_ALIGN(req.hdr.nlmsg_len)); - rta->rta_type = IFA_LOCAL; - rta->rta_len = RTA_LENGTH(addrlen); - req.hdr.nlmsg_len = NLMSG_ALIGN(req.hdr.nlmsg_len) + RTA_LENGTH(addrlen); - memcpy(RTA_DATA(rta), addr, addrlen); + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kAdd); +} - return NetlinkRequestAckOrError(fd, kSeq, &req, req.hdr.nlmsg_len); +PosixError LinkDelLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen) { + return LinkModifyLocalAddr(index, family, prefixlen, addr, addrlen, + LinkAddrModification::kDelete); } PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change) { diff --git a/test/syscalls/linux/socket_netlink_route_util.h b/test/syscalls/linux/socket_netlink_route_util.h index 149c4a7f6..e5badca70 100644 --- a/test/syscalls/linux/socket_netlink_route_util.h +++ b/test/syscalls/linux/socket_netlink_route_util.h @@ -43,6 +43,10 @@ PosixErrorOr<Link> LoopbackLink(); PosixError LinkAddLocalAddr(int index, int family, int prefixlen, const void* addr, int addrlen); +// LinkDelLocalAddr removes IFA_LOCAL attribute on the interface. +PosixError LinkDelLocalAddr(int index, int family, int prefixlen, + const void* addr, int addrlen); + // LinkChangeFlags changes interface flags. E.g. IFF_UP. PosixError LinkChangeFlags(int index, unsigned int flags, unsigned int change); diff --git a/test/syscalls/linux/socket_test_util.cc b/test/syscalls/linux/socket_test_util.cc index 53b678e94..e11792309 100644 --- a/test/syscalls/linux/socket_test_util.cc +++ b/test/syscalls/linux/socket_test_util.cc @@ -753,6 +753,20 @@ PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size) { return ret; } +PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size, + int timeout) { + fd_set rfd; + struct timeval to = {.tv_sec = timeout, .tv_usec = 0}; + FD_ZERO(&rfd); + FD_SET(sock, &rfd); + + int ret; + RETURN_ERROR_IF_SYSCALL_FAIL(ret = select(1, &rfd, NULL, NULL, &to)); + RETURN_ERROR_IF_SYSCALL_FAIL( + ret = RetryEINTR(recv)(sock, buf, buf_size, MSG_DONTWAIT)); + return ret; +} + void RecvNoData(int sock) { char data = 0; struct iovec iov; diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h index 734b48b96..468bc96e0 100644 --- a/test/syscalls/linux/socket_test_util.h +++ b/test/syscalls/linux/socket_test_util.h @@ -467,6 +467,10 @@ PosixError FreeAvailablePort(int port); // SendMsg converts a buffer to an iovec and adds it to msg before sending it. PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size); +// RecvMsgTimeout calls select on sock with timeout and then calls recv on sock. +PosixErrorOr<int> RecvMsgTimeout(int sock, char buf[], int buf_size, + int timeout); + // RecvNoData checks that no data is receivable on sock. void RecvNoData(int sock); diff --git a/test/syscalls/linux/socket_unix_stream.cc b/test/syscalls/linux/socket_unix_stream.cc index 99e77b89e..1edcb15a7 100644 --- a/test/syscalls/linux/socket_unix_stream.cc +++ b/test/syscalls/linux/socket_unix_stream.cc @@ -103,6 +103,24 @@ TEST_P(StreamUnixSocketPairTest, Sendto) { SyscallFailsWithErrno(EISCONN)); } +TEST_P(StreamUnixSocketPairTest, SetAndGetSocketLinger) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct linger sl = {1, 5}; + EXPECT_THAT( + setsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceedsWithValue(0)); + + struct linger got_linger = {}; + socklen_t length = sizeof(sl); + EXPECT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, SO_LINGER, + &got_linger, &length), + SyscallSucceedsWithValue(0)); + + ASSERT_EQ(length, sizeof(got_linger)); + EXPECT_EQ(0, memcmp(&got_linger, &sl, length)); +} + INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, StreamUnixSocketPairTest, ::testing::ValuesIn(IncludeReversals(VecCat<SocketPairKind>( diff --git a/test/syscalls/linux/splice.cc b/test/syscalls/linux/splice.cc index 08fc4b1b7..a1d2b9b11 100644 --- a/test/syscalls/linux/splice.cc +++ b/test/syscalls/linux/splice.cc @@ -298,6 +298,23 @@ TEST(SpliceTest, ToPipe) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } +TEST(SpliceTest, ToPipeEOF) { + // Create and open an empty input file. + const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor in_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY)); + + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + // Splice from the empty file to the pipe. + EXPECT_THAT(splice(in_fd.get(), nullptr, wfd.get(), nullptr, 123, 0), + SyscallSucceedsWithValue(0)); +} + TEST(SpliceTest, ToPipeOffset) { // Open the input file. const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -342,7 +359,7 @@ TEST(SpliceTest, FromPipe) { ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), SyscallSucceedsWithValue(kPageSize)); - // Open the input file. + // Open the output file. const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const FileDescriptor out_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); @@ -364,6 +381,40 @@ TEST(SpliceTest, FromPipe) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } +TEST(SpliceTest, FromPipeMultiple) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + const FileDescriptor wfd(fds[1]); + + std::string buf = "abcABC123"; + ASSERT_THAT(write(wfd.get(), buf.c_str(), buf.size()), + SyscallSucceedsWithValue(buf.size())); + + // Open the output file. + const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + const FileDescriptor out_fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDWR)); + + // Splice from the pipe to the output file over several calls. + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3, 0), + SyscallSucceedsWithValue(3)); + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3, 0), + SyscallSucceedsWithValue(3)); + EXPECT_THAT(splice(rfd.get(), nullptr, out_fd.get(), nullptr, 3, 0), + SyscallSucceedsWithValue(3)); + + // Reset cursor to zero so that we can check the contents. + ASSERT_THAT(lseek(out_fd.get(), 0, SEEK_SET), SyscallSucceedsWithValue(0)); + + // Contents should be equal. + std::vector<char> rbuf(buf.size()); + ASSERT_THAT(read(out_fd.get(), rbuf.data(), rbuf.size()), + SyscallSucceedsWithValue(rbuf.size())); + EXPECT_EQ(memcmp(rbuf.data(), buf.c_str(), buf.size()), 0); +} + TEST(SpliceTest, FromPipeOffset) { // Create a new pipe. int fds[2]; @@ -693,6 +744,34 @@ TEST(SpliceTest, FromPipeMaxFileSize) { EXPECT_EQ(memcmp(rbuf.data(), buf.data(), buf.size()), 0); } +TEST(SpliceTest, FromPipeToDevZero) { + // Create a new pipe. + int fds[2]; + ASSERT_THAT(pipe(fds), SyscallSucceeds()); + const FileDescriptor rfd(fds[0]); + FileDescriptor wfd(fds[1]); + + // Fill with some random data. + std::vector<char> buf(kPageSize); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT(write(wfd.get(), buf.data(), buf.size()), + SyscallSucceedsWithValue(kPageSize)); + + const FileDescriptor zero = + ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/zero", O_WRONLY)); + + // Close the write end to prevent blocking below. + wfd.reset(); + + // Splice to /dev/zero. The first call should empty the pipe, and the return + // value should not exceed the number of bytes available for reading. + EXPECT_THAT( + splice(rfd.get(), nullptr, zero.get(), nullptr, kPageSize + 123, 0), + SyscallSucceedsWithValue(kPageSize)); + EXPECT_THAT(splice(rfd.get(), nullptr, zero.get(), nullptr, 1, 0), + SyscallSucceedsWithValue(0)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 2503960f3..92260b1e1 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -97,6 +97,11 @@ TEST_F(StatTest, FstatatSymlink) { } TEST_F(StatTest, Nlinks) { + // Skip this test if we are testing overlayfs because overlayfs does not + // (intentionally) return the correct nlink value for directories. + // See fs/overlayfs/inode.c:ovl_getattr(). + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))); + TempPath basedir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); // Directory is initially empty, it should contain 2 links (one from itself, @@ -328,20 +333,23 @@ TEST_F(StatTest, LeadingDoubleSlash) { // Test that a rename doesn't change the underlying file. TEST_F(StatTest, StatDoesntChangeAfterRename) { - const TempPath old_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + const TempPath old_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); const TempPath new_path(NewTempAbsPath()); struct stat st_old = {}; struct stat st_new = {}; - ASSERT_THAT(stat(old_dir.path().c_str(), &st_old), SyscallSucceeds()); - ASSERT_THAT(rename(old_dir.path().c_str(), new_path.path().c_str()), + ASSERT_THAT(stat(old_file.path().c_str(), &st_old), SyscallSucceeds()); + ASSERT_THAT(rename(old_file.path().c_str(), new_path.path().c_str()), SyscallSucceeds()); ASSERT_THAT(stat(new_path.path().c_str(), &st_new), SyscallSucceeds()); EXPECT_EQ(st_old.st_nlink, st_new.st_nlink); EXPECT_EQ(st_old.st_dev, st_new.st_dev); - EXPECT_EQ(st_old.st_ino, st_new.st_ino); + // Overlay filesystems may synthesize directory inode numbers on the fly. + if (!ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))) { + EXPECT_EQ(st_old.st_ino, st_new.st_ino); + } EXPECT_EQ(st_old.st_mode, st_new.st_mode); EXPECT_EQ(st_old.st_uid, st_new.st_uid); EXPECT_EQ(st_old.st_gid, st_new.st_gid); @@ -378,7 +386,9 @@ TEST_F(StatTest, LinkCountsWithRegularFileChild) { // This test verifies that inodes remain around when there is an open fd // after link count hits 0. -TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) { +// +// It is marked NoSave because we don't support saving unlinked files. +TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoSave) { // Setting the enviornment variable GVISOR_GOFER_UNCACHED to any value // will prevent this test from running, see the tmpfs lifecycle. // @@ -387,9 +397,6 @@ TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) { const char* uncached_gofer = getenv("GVISOR_GOFER_UNCACHED"); SKIP_IF(uncached_gofer != nullptr); - // We don't support saving unlinked files. - const DisableSave ds; - const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); const TempPath child = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( dir.path(), "hello", TempPath::kDefaultFileMode)); @@ -432,6 +439,11 @@ TEST_F(StatTest, ZeroLinksOpenFdRegularFileChild_NoRandomSave) { // Test link counts with a directory as the child. TEST_F(StatTest, LinkCountsWithDirChild) { + // Skip this test if we are testing overlayfs because overlayfs does not + // (intentionally) return the correct nlink value for directories. + // See fs/overlayfs/inode.c:ovl_getattr(). + SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(IsOverlayfs(GetAbsoluteTestTmpdir()))); + const TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); // Before a child is added the two links are "." and the link from the parent. diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc index aca51d30f..f0fb166bd 100644 --- a/test/syscalls/linux/statfs.cc +++ b/test/syscalls/linux/statfs.cc @@ -13,6 +13,7 @@ // limitations under the License. #include <fcntl.h> +#include <linux/magic.h> #include <sys/statfs.h> #include <unistd.h> @@ -43,14 +44,10 @@ TEST(StatfsTest, InternalTmpfs) { TEST(StatfsTest, InternalDevShm) { struct statfs st; EXPECT_THAT(statfs("/dev/shm", &st), SyscallSucceeds()); -} - -TEST(StatfsTest, NameLen) { - struct statfs st; - EXPECT_THAT(statfs("/dev/shm", &st), SyscallSucceeds()); // This assumes that /dev/shm is tmpfs. - EXPECT_EQ(st.f_namelen, NAME_MAX); + // Note: We could be an overlay on some configurations. + EXPECT_TRUE(st.f_type == TMPFS_MAGIC || st.f_type == OVERLAYFS_SUPER_MAGIC); } TEST(FstatfsTest, CannotStatBadFd) { diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index a17ff62e9..4d9eba7f0 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -218,6 +218,36 @@ TEST(SymlinkTest, PreadFromSymlink) { EXPECT_THAT(unlink(linkname.c_str()), SyscallSucceeds()); } +TEST(SymlinkTest, PwriteToSymlink) { + std::string name = NewTempAbsPath(); + int fd; + ASSERT_THAT(fd = open(name.c_str(), O_CREAT, 0644), SyscallSucceeds()); + ASSERT_THAT(close(fd), SyscallSucceeds()); + + std::string linkname = NewTempAbsPath(); + ASSERT_THAT(symlink(name.c_str(), linkname.c_str()), SyscallSucceeds()); + + ASSERT_THAT(fd = open(linkname.c_str(), O_WRONLY), SyscallSucceeds()); + + const int data_size = 10; + const std::string data = std::string(data_size, 'a'); + EXPECT_THAT(pwrite64(fd, data.c_str(), data.size(), 0), + SyscallSucceedsWithValue(data.size())); + + ASSERT_THAT(close(fd), SyscallSucceeds()); + ASSERT_THAT(fd = open(name.c_str(), O_RDONLY), SyscallSucceeds()); + + char buf[data_size + 1]; + EXPECT_THAT(pread64(fd, buf, data.size(), 0), SyscallSucceeds()); + buf[data.size()] = '\0'; + EXPECT_STREQ(buf, data.c_str()); + + ASSERT_THAT(close(fd), SyscallSucceeds()); + + EXPECT_THAT(unlink(name.c_str()), SyscallSucceeds()); + EXPECT_THAT(unlink(linkname.c_str()), SyscallSucceeds()); +} + TEST(SymlinkTest, SymlinkAtDegradedPermissions_NoRandomSave) { // Drop capabilities that allow us to override file and directory permissions. ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); @@ -297,6 +327,16 @@ TEST(SymlinkTest, FollowUpdatesATime) { EXPECT_LT(st_before_follow.st_atime, st_after_follow.st_atime); } +TEST(SymlinkTest, SymlinkAtEmptyPath) { + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + auto fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(dir.path(), O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(symlinkat(file.path().c_str(), fd.get(), ""), + SyscallFailsWithErrno(ENOENT)); +} + class ParamSymlinkTest : public ::testing::TestWithParam<std::string> {}; // Test that creating an existing symlink with creat will create the target. diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 0cea7d11f..e0981e28a 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -13,9 +13,9 @@ // limitations under the License. #include <fcntl.h> -#ifndef __fuchsia__ +#ifdef __linux__ #include <linux/filter.h> -#endif // __fuchsia__ +#endif // __linux__ #include <netinet/in.h> #include <netinet/tcp.h> #include <poll.h> @@ -720,6 +720,30 @@ TEST_P(TcpSocketTest, TcpSCMPriority) { ASSERT_EQ(cmsg, nullptr); } +TEST_P(TcpSocketTest, TimeWaitPollHUP) { + shutdown(s_, SHUT_RDWR); + ScopedThread t([&]() { + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + struct pollfd pfd = { + .fd = s_, + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + }); + shutdown(t_, SHUT_RDWR); + t.Join(); + // At this point s_ should be in TIME-WAIT and polling for POLLHUP should + // return with 1 FD. + constexpr int kTimeout = 10000; + constexpr int16_t want_events = POLLHUP; + struct pollfd pfd = { + .fd = s_, + .events = want_events, + }; + ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, TcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); @@ -1562,7 +1586,7 @@ TEST_P(SimpleTcpSocketTest, SetTCPWindowClampAboveHalfMinRcvBuf) { } } -#ifndef __fuchsia__ +#ifdef __linux__ // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. // gVisor currently silently ignores attaching a filter. @@ -1596,6 +1620,8 @@ TEST_P(SimpleTcpSocketTest, SetSocketAttachDetachFilter) { SyscallSucceeds()); } +#endif // __linux__ + TEST_P(SimpleTcpSocketTest, SetSocketDetachFilterNoInstalledFilter) { // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. SKIP_IF(IsRunningOnGvisor()); @@ -1617,7 +1643,35 @@ TEST_P(SimpleTcpSocketTest, GetSocketDetachFilter) { SyscallFailsWithErrno(ENOPROTOOPT)); } -#endif // __fuchsia__ +TEST_P(SimpleTcpSocketTest, CloseNonConnectedLingerOption) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + constexpr int kLingerTimeout = 10; // Seconds. + + // Set the SO_LINGER option. + struct linger sl = { + .l_onoff = 1, + .l_linger = kLingerTimeout, + }; + ASSERT_THAT(setsockopt(s.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)), + SyscallSucceeds()); + + struct pollfd poll_fd = { + .fd = s.get(), + .events = POLLHUP, + }; + constexpr int kPollTimeoutMs = 0; + ASSERT_THAT(RetryEINTR(poll)(&poll_fd, 1, kPollTimeoutMs), + SyscallSucceedsWithValue(1)); + + auto const start_time = absl::Now(); + EXPECT_THAT(close(s.release()), SyscallSucceeds()); + auto const end_time = absl::Now(); + + // Close() should not linger and return immediately. + ASSERT_LT((end_time - start_time), absl::Seconds(kLingerTimeout)); +} INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); diff --git a/test/syscalls/linux/truncate.cc b/test/syscalls/linux/truncate.cc index c988c6380..bfc95ed38 100644 --- a/test/syscalls/linux/truncate.cc +++ b/test/syscalls/linux/truncate.cc @@ -196,6 +196,26 @@ TEST(TruncateTest, FtruncateNonWriteable) { EXPECT_THAT(ftruncate(fd.get(), 0), SyscallFailsWithErrno(EINVAL)); } +// ftruncate(2) should succeed as long as the file descriptor is writeable, +// regardless of whether the file permissions allow writing. +TEST(TruncateTest, FtruncateWithoutWritePermission_NoRandomSave) { + // Drop capabilities that allow us to override file permissions. + ASSERT_NO_ERRNO(SetCapability(CAP_DAC_OVERRIDE, false)); + + // The only time we can open a file with flags forbidden by its permissions + // is when we are creating the file. We cannot re-open with the same flags, + // so we cannot restore an fd obtained from such an operation. + const DisableSave ds; + auto path = NewTempAbsPath(); + const FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(path, O_RDWR | O_CREAT, 0444)); + + // In goferfs, ftruncate may be converted to a remote truncate operation that + // unavoidably requires write permission. + SKIP_IF(IsRunningOnGvisor() && !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(path))); + ASSERT_THAT(ftruncate(fd.get(), 100), SyscallSucceeds()); +} + TEST(TruncateTest, TruncateNonExist) { EXPECT_THAT(truncate("/foo/bar", 0), SyscallFailsWithErrno(ENOENT)); } diff --git a/test/syscalls/linux/udp_socket.cc b/test/syscalls/linux/udp_socket.cc index 7a8ac30a4..1a7673317 100644 --- a/test/syscalls/linux/udp_socket.cc +++ b/test/syscalls/linux/udp_socket.cc @@ -12,13 +12,1845 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "test/syscalls/linux/udp_socket_test_cases.h" +#include <arpa/inet.h> +#include <fcntl.h> + +#include <ctime> + +#ifdef __linux__ +#include <linux/errqueue.h> +#include <linux/filter.h> +#endif // __linux__ +#include <netinet/in.h> +#include <poll.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> + +#include "absl/strings/str_format.h" +#ifndef SIOCGSTAMP +#include <linux/sockios.h> +#endif + +#include "gtest/gtest.h" +#include "absl/base/macros.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" namespace gvisor { namespace testing { namespace { +// Fixture for tests parameterized by the address family to use (AF_INET and +// AF_INET6) when creating sockets. +class UdpSocketTest + : public ::testing::TestWithParam<gvisor::testing::AddressFamily> { + protected: + // Creates two sockets that will be used by test cases. + void SetUp() override; + + // Binds the socket bind_ to the loopback and updates bind_addr_. + PosixError BindLoopback(); + + // Binds the socket bind_ to Any and updates bind_addr_. + PosixError BindAny(); + + // Binds given socket to address addr and updates. + PosixError BindSocket(int socket, struct sockaddr* addr); + + // Return initialized Any address to port 0. + struct sockaddr_storage InetAnyAddr(); + + // Return initialized Loopback address to port 0. + struct sockaddr_storage InetLoopbackAddr(); + + // Disconnects socket sockfd. + void Disconnect(int sockfd); + + // Get family for the test. + int GetFamily(); + + // Socket used by Bind methods + FileDescriptor bind_; + + // Second socket used for tests. + FileDescriptor sock_; + + // Address for bind_ socket. + struct sockaddr* bind_addr_; + + // Initialized to the length based on GetFamily(). + socklen_t addrlen_; + + // Storage for bind_addr_. + struct sockaddr_storage bind_addr_storage_; + + private: + // Helper to initialize addrlen_ for the test case. + socklen_t GetAddrLength(); +}; + +// Gets a pointer to the port component of the given address. +uint16_t* Port(struct sockaddr_storage* addr) { + switch (addr->ss_family) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr); + return &sin->sin_port; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); + return &sin6->sin6_port; + } + } + + return nullptr; +} + +// Sets addr port to "port". +void SetPort(struct sockaddr_storage* addr, uint16_t port) { + switch (addr->ss_family) { + case AF_INET: { + auto sin = reinterpret_cast<struct sockaddr_in*>(addr); + sin->sin_port = port; + break; + } + case AF_INET6: { + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); + sin6->sin6_port = port; + break; + } + } +} + +void UdpSocketTest::SetUp() { + addrlen_ = GetAddrLength(); + + bind_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_)); + bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + + sock_ = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); +} + +int UdpSocketTest::GetFamily() { + if (GetParam() == AddressFamily::kIpv4) { + return AF_INET; + } + return AF_INET6; +} + +PosixError UdpSocketTest::BindLoopback() { + bind_addr_storage_ = InetLoopbackAddr(); + struct sockaddr* bind_addr_ = + reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + return BindSocket(bind_.get(), bind_addr_); +} + +PosixError UdpSocketTest::BindAny() { + bind_addr_storage_ = InetAnyAddr(); + struct sockaddr* bind_addr_ = + reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); + return BindSocket(bind_.get(), bind_addr_); +} + +PosixError UdpSocketTest::BindSocket(int socket, struct sockaddr* addr) { + socklen_t len = sizeof(bind_addr_storage_); + + // Bind, then check that we get the right address. + RETURN_ERROR_IF_SYSCALL_FAIL(bind(socket, addr, addrlen_)); + + RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(socket, addr, &len)); + + if (addrlen_ != len) { + return PosixError( + EINVAL, + absl::StrFormat("getsockname len: %u expected: %u", len, addrlen_)); + } + return PosixError(0); +} + +socklen_t UdpSocketTest::GetAddrLength() { + struct sockaddr_storage addr; + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + return sizeof(*sin); + } + + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + return sizeof(*sin6); +} + +sockaddr_storage UdpSocketTest::InetAnyAddr() { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_addr.s_addr = htonl(INADDR_ANY); + sin->sin_port = htons(0); + return addr; + } + + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr = IN6ADDR_ANY_INIT; + sin6->sin6_port = htons(0); + return addr; +} + +sockaddr_storage UdpSocketTest::InetLoopbackAddr() { + struct sockaddr_storage addr; + memset(&addr, 0, sizeof(addr)); + reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); + + if (GetFamily() == AF_INET) { + auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + sin->sin_port = htons(0); + return addr; + } + auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); + sin6->sin6_addr = in6addr_loopback; + sin6->sin6_port = htons(0); + return addr; +} + +void UdpSocketTest::Disconnect(int sockfd) { + sockaddr_storage addr_storage = InetAnyAddr(); + sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + socklen_t addrlen = sizeof(addr_storage); + + addr->sa_family = AF_UNSPEC; + ASSERT_THAT(connect(sockfd, addr, addrlen), SyscallSucceeds()); + + // Check that after disconnect the socket is bound to the ANY address. + EXPECT_THAT(getsockname(sockfd, addr, &addrlen), SyscallSucceeds()); + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + struct in6_addr loopback = IN6ADDR_ANY_INIT; + + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, Creation) { + FileDescriptor sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + EXPECT_THAT(close(sock.release()), SyscallSucceeds()); + + sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, 0)); + EXPECT_THAT(close(sock.release()), SyscallSucceeds()); + + ASSERT_THAT(socket(GetFamily(), SOCK_STREAM, IPPROTO_UDP), SyscallFails()); +} + +TEST_P(UdpSocketTest, Getsockname) { + // Check that we're not bound. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_), + 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, Getpeername) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that we're not connected. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Connect, then check that we get the right address. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, SendNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Do send & write, they must fail. + char buf[512]; + EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), + SyscallFailsWithErrno(EDESTADDRREQ)); + + EXPECT_THAT(write(sock_.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(EDESTADDRREQ)); + + // Use sendto. + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Check that we're bound now. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_NE(*Port(&addr), 0); +} + +TEST_P(UdpSocketTest, ConnectBinds) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're bound now. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_NE(*Port(&addr), 0); +} + +TEST_P(UdpSocketTest, ReceiveNotBound) { + char buf[512]; + EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, Bind) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Try to bind again. + EXPECT_THAT(bind(bind_.get(), bind_addr_, addrlen_), + SyscallFailsWithErrno(EINVAL)); + + // Check that we're still bound to the original address. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); +} + +TEST_P(UdpSocketTest, BindInUse) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Try to bind again. + EXPECT_THAT(bind(sock_.get(), bind_addr_, addrlen_), + SyscallFailsWithErrno(EADDRINUSE)); +} + +TEST_P(UdpSocketTest, ReceiveAfterConnect) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send from sock_ to bind_ + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + for (int i = 0; i < 2; i++) { + // Connet sock_ to bound address. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + + // Send from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0, + reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Disconnect sock_. + struct sockaddr unspec = {}; + unspec.sa_family = AF_UNSPEC; + ASSERT_THAT(connect(sock_.get(), &unspec, sizeof(unspec.sa_family)), + SyscallSucceeds()); + } +} + +TEST_P(UdpSocketTest, Connect) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); + + // Try to bind after connect. + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallFailsWithErrno(EINVAL)); + + struct sockaddr_storage bind2_storage = InetLoopbackAddr(); + struct sockaddr* bind2_addr = + reinterpret_cast<struct sockaddr*>(&bind2_storage); + FileDescriptor bind2 = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); + ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr)); + + // Try to connect again. + EXPECT_THAT(connect(sock_.get(), bind2_addr, addrlen_), SyscallSucceeds()); + + // Check that peer name changed. + peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0); +} + +TEST_P(UdpSocketTest, ConnectAnyZero) { + // TODO(138658473): Enable when we can connect to port 0 with gVisor. + SKIP_IF(IsRunningOnGvisor()); + + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, ConnectAnyWithPort) { + ASSERT_NO_ERRNO(BindAny()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { + // TODO(138658473): Enable when we can connect to port 0 with gVisor. + SKIP_IF(IsRunningOnGvisor()); + struct sockaddr_storage any = InetAnyAddr(); + EXPECT_THAT( + connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), + SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); + + Disconnect(sock_.get()); +} + +TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { + ASSERT_NO_ERRNO(BindAny()); + EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr)); + + Disconnect(sock_.get()); +} + +TEST_P(UdpSocketTest, DisconnectAfterBind) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Bind to the next port above bind_. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr)); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct sockaddr_storage unspec = {}; + unspec.ss_family = AF_UNSPEC; + EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), + sizeof(unspec.ss_family)), + SyscallSucceeds()); + + // Check that we're still bound. + socklen_t addrlen = sizeof(unspec); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0); + + addrlen = sizeof(addr); + EXPECT_THAT(getpeername(sock_.get(), addr, &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { + ASSERT_NO_ERRNO(BindAny()); + + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + socklen_t addrlen = sizeof(addr); + + // Connect the socket. + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(getsockname(bind_.get(), addr, &addrlen), SyscallSucceeds()); + + // If the socket is bound to ANY and connected to a loopback address, + // getsockname() has to return the loopback address. + if (GetParam() == AddressFamily::kIpv4) { + auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); + } else { + auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); + struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; + EXPECT_EQ(addrlen, sizeof(*addr_out)); + EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); + } +} + +TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage any_storage = InetAnyAddr(); + struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); + + ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); + + // Connect the socket. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + Disconnect(sock_.get()); + + // Check that we're still bound. + struct sockaddr_storage addr; + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(memcmp(&addr, any, addrlen), 0); + + addrlen = sizeof(addr); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, Disconnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage any_storage = InetAnyAddr(); + struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); + SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); + + for (int i = 0; i < 2; i++) { + // Try to connect again. + EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Check that we're connected to the right peer. + struct sockaddr_storage peer; + socklen_t peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallSucceeds()); + EXPECT_EQ(peerlen, addrlen_); + EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); + + // Try to disconnect. + struct sockaddr_storage addr = {}; + addr.ss_family = AF_UNSPEC; + EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr), + sizeof(addr.ss_family)), + SyscallSucceeds()); + + peerlen = sizeof(peer); + EXPECT_THAT( + getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), + SyscallFailsWithErrno(ENOTCONN)); + + // Check that we're still bound. + socklen_t addrlen = sizeof(addr); + EXPECT_THAT( + getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), + SyscallSucceeds()); + EXPECT_EQ(addrlen, addrlen_); + EXPECT_EQ(*Port(&addr), *Port(&any_storage)); + } +} + +TEST_P(UdpSocketTest, ConnectBadAddress) { + struct sockaddr addr = {}; + addr.sa_family = GetFamily(); + ASSERT_THAT(connect(sock_.get(), &addr, sizeof(addr.sa_family)), + SyscallFailsWithErrno(EINVAL)); +} + +TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + struct sockaddr_storage addr_storage = InetAnyAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send to a different destination than we're connected to. + char buf[512]; + EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); +} + +TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + // Connect to loopback:bind_addr_+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from bind_ to sock_. + ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {sock_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout*/ 1000), + SyscallSucceedsWithValue(1)); + + // Receive the packet. + char received[3]; + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { + // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Set sock to non-blocking. + int opts = 0; + ASSERT_THAT(opts = fcntl(sock_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(sock_.get(), F_SETFL, opts | O_NONBLOCK), + SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from bind_ to sock_. + ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {sock_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // Receive the packet. + char received[3]; + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallSucceedsWithValue(0)); + EXPECT_THAT(read(sock_.get(), received, sizeof(received)), + SyscallFailsWithErrno(EAGAIN)); +} + +TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send some data to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, SendAndReceiveConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data. + char received[sizeof(buf)]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, ReceiveFromNotConnected) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+2. + struct sockaddr_storage addr2_storage = InetLoopbackAddr(); + struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); + ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Check that the data isn't received because it was sent from a different + // address than we're connected. + EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReceiveBeforeConnect) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Bind sock to loopback:bind_addr_port+2. + struct sockaddr_storage addr2_storage = InetLoopbackAddr(); + struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); + SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); + ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Receive the data. It works because it was sent before the connect. + char received[sizeof(buf)]; + EXPECT_THAT( + RecvMsgTimeout(bind_.get(), received, sizeof(received), 1 /*timeout*/), + IsPosixErrorOkAndHolds(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + + // Send again. This time it should not be received. + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + EXPECT_THAT(recv(bind_.get(), buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReceiveFrom) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind sock to loopback:bind_addr_port+1. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + + // Send some data from sock to bind_. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + + // Receive the data and sender address. + char received[sizeof(buf)]; + struct sockaddr_storage addr2; + socklen_t addr2len = sizeof(addr2); + EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0, + reinterpret_cast<sockaddr*>(&addr2), &addr2len), + SyscallSucceedsWithValue(sizeof(received))); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); + EXPECT_EQ(addr2len, addrlen_); + EXPECT_EQ(memcmp(addr, &addr2, addrlen_), 0); +} + +TEST_P(UdpSocketTest, Listen) { + ASSERT_THAT(listen(sock_.get(), SOMAXCONN), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + +TEST_P(UdpSocketTest, Accept) { + ASSERT_THAT(accept(sock_.get(), nullptr, nullptr), + SyscallFailsWithErrno(EOPNOTSUPP)); +} + +// This test validates that a read shutdown with pending data allows the read +// to proceed with the data before returning EAGAIN. +TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + // Bind to loopback:bind_addr_port+1 and connect to bind_addr_. + ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Verify that we get EWOULDBLOCK when there is nothing to read. + char received[512]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + const char* buf = "abc"; + EXPECT_THAT(write(sock_.get(), buf, 3), SyscallSucceedsWithValue(3)); + + int opts = 0; + ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_THAT(fcntl(bind_.get(), F_SETFL, opts | O_NONBLOCK), + SyscallSucceeds()); + ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); + ASSERT_NE(opts & O_NONBLOCK, 0); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // We should get the data even though read has been shutdown. + EXPECT_THAT( + RecvMsgTimeout(bind_.get(), received, 2 /*buf_size*/, 1 /*timeout*/), + IsPosixErrorOkAndHolds(2)); + + // Because we read less than the entire packet length, since it's a packet + // based socket any subsequent reads should return EWOULDBLOCK. + EXPECT_THAT(recv(bind_.get(), received, 1, 0), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +// This test is validating that even after a socket is shutdown if it's +// reconnected it will reset the shutdown state. +TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { + char received[512]; + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then try to shutdown again. + ASSERT_NO_ERRNO(BindLoopback()); + + // Connect to loopback:bind_addr_port+1. + struct sockaddr_storage addr_storage = InetLoopbackAddr(); + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); + SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); + ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); +} + +TEST_P(UdpSocketTest, ReadShutdown) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + + char received[512]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then try to shutdown again. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { + // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without + // MSG_DONTWAIT blocks indefinitely. + SKIP_IF(IsRunningWithHostinet()); + ASSERT_NO_ERRNO(BindLoopback()); + + char received[512]; + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Connect the socket, then shutdown from another thread. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + ScopedThread t([&] { + absl::SleepFor(absl::Milliseconds(200)); + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + }); + EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); + t.Join(); + + EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(0)); +} + +TEST_P(UdpSocketTest, WriteShutdown) { + ASSERT_NO_ERRNO(BindLoopback()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallSucceeds()); +} + +TEST_P(UdpSocketTest, SynchronousReceive) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send some data to bind_ from another thread. + char buf[512]; + RandomizeBuffer(buf, sizeof(buf)); + + // Receive the data prior to actually starting the other thread. + char received[512]; + EXPECT_THAT( + RetryEINTR(recv)(bind_.get(), received, sizeof(received), MSG_DONTWAIT), + SyscallFailsWithErrno(EWOULDBLOCK)); + + // Start the thread. + ScopedThread t([&] { + absl::SleepFor(absl::Milliseconds(200)); + ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, this->bind_addr_, + this->addrlen_), + SyscallSucceedsWithValue(sizeof(buf))); + }); + + EXPECT_THAT(RetryEINTR(recv)(bind_.get(), received, sizeof(received), 0), + SyscallSucceedsWithValue(512)); + EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(psize)); + } + + // Receive the data as 3 separate packets. + char received[6 * psize]; + for (int i = 0; i < 3; ++i) { + EXPECT_THAT(recv(bind_.get(), received + i * psize, 3 * psize, 0), + SyscallSucceedsWithValue(psize)); + } + EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Direct writes from sock to bind_. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Send 2 packets from sock to bind_, where each packet's data consists of + // 2 discontiguous iovecs. + constexpr size_t kPieceSize = 100; + char buf[4 * kPieceSize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 2; i++) { + struct iovec iov[2]; + for (int j = 0; j < 2; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + ASSERT_THAT(writev(sock_.get(), iov, 2), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + + // Receive the data as 2 separate packets. + char received[6 * kPieceSize]; + for (int i = 0; i < 2; i++) { + struct iovec iov[3]; + for (int j = 0; j < 3; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + ASSERT_THAT(readv(bind_.get(), iov, 3), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); +} + +TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Send 2 packets from sock to bind_, where each packet's data consists of + // 2 discontiguous iovecs. + constexpr size_t kPieceSize = 100; + char buf[4 * kPieceSize]; + RandomizeBuffer(buf, sizeof(buf)); + + for (int i = 0; i < 2; i++) { + struct iovec iov[2]; + for (int j = 0; j < 2; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + struct msghdr msg = {}; + msg.msg_name = bind_addr_; + msg.msg_namelen = addrlen_; + msg.msg_iov = iov; + msg.msg_iovlen = 2; + ASSERT_THAT(sendmsg(sock_.get(), &msg, 0), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + + // Receive the data as 2 separate packets. + char received[6 * kPieceSize]; + for (int i = 0; i < 2; i++) { + struct iovec iov[3]; + for (int j = 0; j < 3; j++) { + iov[j].iov_base = reinterpret_cast<void*>( + reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); + iov[j].iov_len = kPieceSize; + } + struct msghdr msg = {}; + msg.msg_iov = iov; + msg.msg_iovlen = 3; + ASSERT_THAT(recvmsg(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(2 * kPieceSize)); + } + EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); +} + +TEST_P(UdpSocketTest, FIONREADShutdown) { + ASSERT_NO_ERRNO(BindLoopback()); + + int n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); +} + +TEST_P(UdpSocketTest, FIONREADWriteShutdown) { + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + const char str[] = "abc"; + ASSERT_THAT(send(bind_.get(), str, sizeof(str), 0), + SyscallSucceedsWithValue(sizeof(str))); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, sizeof(str)); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, sizeof(str)); +} + +// NOTE: Do not use `FIONREAD` as test name because it will be replaced by the +// corresponding macro and become `0x541B`. +TEST_P(UdpSocketTest, Fionread) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that the bound socket with an empty buffer reports an empty first + // packet. + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(psize)); + + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // Check that regardless of how many packets are in the queue, the size + // reported is that of a single packet. + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, psize); + } +} + +TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { + ASSERT_NO_ERRNO(BindLoopback()); + + // Check that the bound socket with an empty buffer reports an empty first + // packet. + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + // Send 3 packets from sock to bind_. + constexpr int psize = 100; + char buf[3 * psize]; + RandomizeBuffer(buf, sizeof(buf)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + for (int i = 0; i < 3; ++i) { + ASSERT_THAT( + sendto(sock_.get(), buf + i * psize, 0, 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(0)); + + // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet + // socket does not cause a poll event to be triggered. + if (!IsRunningWithHostinet()) { + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + } + + // Check that regardless of how many packets are in the queue, the size + // reported is that of a single packet. + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + } +} + +TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { + int n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + ASSERT_NO_ERRNO(BindLoopback()); + + // A UDP socket must be connected before it can be shutdown. + ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + const char str[] = "abc"; + ASSERT_THAT(send(bind_.get(), str, 0, 0), SyscallSucceedsWithValue(0)); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); + + EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); + + n = -1; + EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); + EXPECT_EQ(n, 0); +} + +TEST_P(UdpSocketTest, SoNoCheckOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoNoCheck) { + // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = kSockOptOn; + socklen_t optlen = sizeof(v); + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOn); + ASSERT_EQ(optlen, sizeof(v)); + + v = kSockOptOff; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), + SyscallSucceeds()); + v = -1; + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +#ifdef __linux__ +TEST_P(UdpSocketTest, ErrorQueue) { + char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; + msghdr msg; + memset(&msg, 0, sizeof(msg)); + iovec iov; + memset(&iov, 0, sizeof(iov)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. + EXPECT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, MSG_ERRQUEUE), + SyscallFailsWithErrno(EAGAIN)); +} +#endif // __linux__ + +TEST_P(UdpSocketTest, SoTimestampOffByDefault) { + // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by + // hostinet. + SKIP_IF(IsRunningWithHostinet()); + + int v = -1; + socklen_t optlen = sizeof(v); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), + SyscallSucceeds()); + ASSERT_EQ(v, kSockOptOff); + ASSERT_EQ(optlen, sizeof(v)); +} + +TEST_P(UdpSocketTest, SoTimestamp) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + int v = 1; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + + char buf[3]; + // Send zero length packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg; + memset(&msg, 0, sizeof(msg)); + iovec iov; + memset(&iov, 0, sizeof(iov)); + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(0)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); + ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); + ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); + + struct timeval tv = {}; + memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); + + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // There should be nothing to get via ioctl. + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, WriteShutdownNotConnected) { + EXPECT_THAT(shutdown(bind_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); +} + +TEST_P(UdpSocketTest, TimestampIoctl) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); +} + +TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { + // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + struct timeval tv = {}; + ASSERT_THAT(ioctl(sock_.get(), SIOCGSTAMP, &tv), + SyscallFailsWithErrno(ENOENT)); +} + +// Test that the timestamp accessed via SIOCGSTAMP is still accessible after +// SO_TIMESTAMP is enabled and used to retrieve a timestamp. +TEST_P(UdpSocketTest, TimestampIoctlPersistence) { + // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not + // supported by hostinet. + SKIP_IF(IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + char buf[3]; + // Send packet from sock to bind_. + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), + SyscallSucceedsWithValue(sizeof(buf))); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + struct pollfd pfd = {bind_.get(), POLLIN, 0}; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be no control messages. + char recv_buf[sizeof(buf)]; + ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); + + // A nonzero timeval should be available via ioctl. + struct timeval tv = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); + ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); + + // Enable SO_TIMESTAMP and send a message. + int v = 1; + EXPECT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), + SyscallSucceeds()); + ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), + SyscallSucceedsWithValue(0)); + + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), + SyscallSucceedsWithValue(1)); + + // There should be a message for SO_TIMESTAMP. + char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; + msghdr msg = {}; + iovec iov = {}; + msg.msg_iov = &iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsgbuf; + msg.msg_controllen = sizeof(cmsgbuf); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), + SyscallSucceedsWithValue(0)); + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + ASSERT_NE(cmsg, nullptr); + + // The ioctl should return the exact same values as before. + struct timeval tv2 = {}; + ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv2), SyscallSucceeds()); + ASSERT_EQ(tv.tv_sec, tv2.tv_sec); + ASSERT_EQ(tv.tv_usec, tv2.tv_usec); +} + +// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on +// outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SetAndReceiveTOS) { + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + + // Set socket TOS. + int sent_level = recv_level; + int sent_type = IP_TOS; + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + } + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + ASSERT_THAT(setsockopt(sock_.get(), sent_level, sent_type, &sent_tos, + sizeof(sent_tos)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_type == IPV6_TCLASS) { + cmsg_data_len = sizeof(int); + } + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = received_cmsgbuf.size(); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the +// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or +// IPV6_RECVTCLASS will create the corresponding control message. +TEST_P(UdpSocketTest, SendAndReceiveTOS) { + // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. + SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); + + ASSERT_NO_ERRNO(BindLoopback()); + ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); + + // Allow socket to receive control message. + int recv_level = SOL_IP; + int recv_type = IP_RECVTOS; + if (GetParam() != AddressFamily::kIpv4) { + recv_level = SOL_IPV6; + recv_type = IPV6_RECVTCLASS; + } + int recv_opt = kSockOptOn; + ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &recv_opt, + sizeof(recv_opt)), + SyscallSucceeds()); + + // Prepare message to send. + constexpr size_t kDataLength = 1024; + int sent_level = recv_level; + int sent_type = IP_TOS; + int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. + + struct msghdr sent_msg = {}; + struct iovec sent_iov = {}; + char sent_data[kDataLength]; + sent_iov.iov_base = &sent_data[0]; + sent_iov.iov_len = kDataLength; + sent_msg.msg_iov = &sent_iov; + sent_msg.msg_iovlen = 1; + size_t cmsg_data_len = sizeof(int8_t); + if (sent_level == SOL_IPV6) { + sent_type = IPV6_TCLASS; + cmsg_data_len = sizeof(int); + } + std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + sent_msg.msg_control = &sent_cmsgbuf[0]; + sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + + // Manually add control message. + struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg); + sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len); + sent_cmsg->cmsg_level = sent_level; + sent_cmsg->cmsg_type = sent_type; + *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; + + ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + // Receive message. + struct msghdr received_msg = {}; + struct iovec received_iov = {}; + char received_data[kDataLength]; + received_iov.iov_base = &received_data[0]; + received_iov.iov_len = kDataLength; + received_msg.msg_iov = &received_iov; + received_msg.msg_iovlen = 1; + std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); + received_msg.msg_control = &received_cmsgbuf[0]; + received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), + SyscallSucceedsWithValue(kDataLength)); + + struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, sent_level); + EXPECT_EQ(cmsg->cmsg_type, sent_type); + int8_t received_tos = 0; + memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); + EXPECT_EQ(received_tos, sent_tos); +} + +TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { + // Discover minimum buffer size by setting it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + int min = 0; + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + + // Bind bind_ to loopback. + ASSERT_NO_ERRNO(BindLoopback()); + + { + // Send data of size min and verify that it's received. + std::vector<char> buf(min); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + std::vector<char> received(buf.size()); + EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), + IsPosixErrorOkAndHolds(received.size())); + } + + { + // Send data of size min + 1 and verify that its received. Both linux and + // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer + // is currently empty. + std::vector<char> buf(min + 1); + RandomizeBuffer(buf.data(), buf.size()); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + + std::vector<char> received(buf.size()); + ASSERT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), + IsPosixErrorOkAndHolds(received.size())); + } +} + +// Test that receive buffer limits are enforced. +TEST_P(UdpSocketTest, RecvBufLimits) { + // Bind s_ to loopback. + ASSERT_NO_ERRNO(BindLoopback()); + + int min = 0; + { + // Discover minimum buffer size by trying to set it to zero. + constexpr int kRcvBufSz = 0; + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, + sizeof(kRcvBufSz)), + SyscallSucceeds()); + + socklen_t min_len = sizeof(min); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), + SyscallSucceeds()); + } + + // Now set the limit to min * 4. + int new_rcv_buf_sz = min * 4; + if (!IsRunningOnGvisor() || IsRunningWithHostinet()) { + // Linux doubles the value specified so just set to min * 2. + new_rcv_buf_sz = min * 2; + } + + ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, + sizeof(new_rcv_buf_sz)), + SyscallSucceeds()); + int rcv_buf_sz = 0; + { + socklen_t rcv_buf_len = sizeof(rcv_buf_sz); + ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, + &rcv_buf_len), + SyscallSucceeds()); + } + + { + std::vector<char> buf(min); + RandomizeBuffer(buf.data(), buf.size()); + + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + int sent = 4; + if (IsRunningOnGvisor() && !IsRunningWithHostinet()) { + // Linux seems to drop the 4th packet even though technically it should + // fit in the receive buffer. + ASSERT_THAT( + sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), + SyscallSucceedsWithValue(buf.size())); + sent++; + } + + for (int i = 0; i < sent - 1; i++) { + // Receive the data. + std::vector<char> received(buf.size()); + EXPECT_THAT(RecvMsgTimeout(bind_.get(), received.data(), received.size(), + 1 /*timeout*/), + IsPosixErrorOkAndHolds(received.size())); + EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); + } + + // The last receive should fail with EAGAIN as the last packet should have + // been dropped due to lack of space in the receive buffer. + std::vector<char> received(buf.size()); + EXPECT_THAT( + recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + } +} + +#ifdef __linux__ + +// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. +// gVisor currently silently ignores attaching a filter. +TEST_P(UdpSocketTest, SetSocketDetachFilter) { + // Program generated using sudo tcpdump -i lo udp and port 1234 -dd + struct sock_filter code[] = { + {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, + {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000011}, + {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, + {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, + {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, + {0x15, 0, 8, 0x00000011}, {0x28, 0, 0, 0x00000014}, + {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, + {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, + {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, + {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, + }; + struct sock_fprog bpf = { + .len = ABSL_ARRAYSIZE(code), + .filter = code, + }; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), + SyscallSucceeds()); + + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallSucceeds()); +} + +#endif // __linux__ + +TEST_P(UdpSocketTest, SetSocketDetachFilterNoInstalledFilter) { + // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. + SKIP_IF(IsRunningOnGvisor()); + constexpr int val = 0; + ASSERT_THAT( + setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), + SyscallFailsWithErrno(ENOENT)); +} + +TEST_P(UdpSocketTest, GetSocketDetachFilter) { + int val = 0; + socklen_t val_len = sizeof(val); + ASSERT_THAT( + getsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), + SyscallFailsWithErrno(ENOPROTOOPT)); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, UdpSocketTest, ::testing::Values(AddressFamily::kIpv4, AddressFamily::kIpv6, diff --git a/test/syscalls/linux/udp_socket_errqueue_test_case.cc b/test/syscalls/linux/udp_socket_errqueue_test_case.cc deleted file mode 100644 index 54a0594f7..000000000 --- a/test/syscalls/linux/udp_socket_errqueue_test_case.cc +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef __fuchsia__ - -#include <arpa/inet.h> -#include <fcntl.h> -#include <linux/errqueue.h> -#include <netinet/in.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/udp_socket_test_cases.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -TEST_P(UdpSocketTest, ErrorQueue) { - char cmsgbuf[CMSG_SPACE(sizeof(sock_extended_err))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. - EXPECT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, MSG_ERRQUEUE), - SyscallFailsWithErrno(EAGAIN)); -} - -} // namespace testing -} // namespace gvisor - -#endif // __fuchsia__ diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc deleted file mode 100644 index 60c48ed6e..000000000 --- a/test/syscalls/linux/udp_socket_test_cases.cc +++ /dev/null @@ -1,1781 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "test/syscalls/linux/udp_socket_test_cases.h" - -#include <arpa/inet.h> -#include <fcntl.h> -#ifndef __fuchsia__ -#include <linux/filter.h> -#endif // __fuchsia__ -#include <netinet/in.h> -#include <poll.h> -#include <sys/ioctl.h> -#include <sys/socket.h> -#include <sys/types.h> - -#include "absl/strings/str_format.h" -#ifndef SIOCGSTAMP -#include <linux/sockios.h> -#endif - -#include "gtest/gtest.h" -#include "absl/base/macros.h" -#include "absl/time/clock.h" -#include "absl/time/time.h" -#include "test/syscalls/linux/ip_socket_test_util.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/syscalls/linux/unix_domain_socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" -#include "test/util/test_util.h" -#include "test/util/thread_util.h" - -namespace gvisor { -namespace testing { - -// Gets a pointer to the port component of the given address. -uint16_t* Port(struct sockaddr_storage* addr) { - switch (addr->ss_family) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr); - return &sin->sin_port; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); - return &sin6->sin6_port; - } - } - - return nullptr; -} - -// Sets addr port to "port". -void SetPort(struct sockaddr_storage* addr, uint16_t port) { - switch (addr->ss_family) { - case AF_INET: { - auto sin = reinterpret_cast<struct sockaddr_in*>(addr); - sin->sin_port = port; - break; - } - case AF_INET6: { - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(addr); - sin6->sin6_port = port; - break; - } - } -} - -void UdpSocketTest::SetUp() { - addrlen_ = GetAddrLength(); - - bind_ = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); - memset(&bind_addr_storage_, 0, sizeof(bind_addr_storage_)); - bind_addr_ = reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); - - sock_ = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); -} - -int UdpSocketTest::GetFamily() { - if (GetParam() == AddressFamily::kIpv4) { - return AF_INET; - } - return AF_INET6; -} - -PosixError UdpSocketTest::BindLoopback() { - bind_addr_storage_ = InetLoopbackAddr(); - struct sockaddr* bind_addr_ = - reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); - return BindSocket(bind_.get(), bind_addr_); -} - -PosixError UdpSocketTest::BindAny() { - bind_addr_storage_ = InetAnyAddr(); - struct sockaddr* bind_addr_ = - reinterpret_cast<struct sockaddr*>(&bind_addr_storage_); - return BindSocket(bind_.get(), bind_addr_); -} - -PosixError UdpSocketTest::BindSocket(int socket, struct sockaddr* addr) { - socklen_t len = sizeof(bind_addr_storage_); - - // Bind, then check that we get the right address. - RETURN_ERROR_IF_SYSCALL_FAIL(bind(socket, addr, addrlen_)); - - RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(socket, addr, &len)); - - if (addrlen_ != len) { - return PosixError( - EINVAL, - absl::StrFormat("getsockname len: %u expected: %u", len, addrlen_)); - } - return PosixError(0); -} - -socklen_t UdpSocketTest::GetAddrLength() { - struct sockaddr_storage addr; - if (GetFamily() == AF_INET) { - auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); - return sizeof(*sin); - } - - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); - return sizeof(*sin6); -} - -sockaddr_storage UdpSocketTest::InetAnyAddr() { - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); - - if (GetFamily() == AF_INET) { - auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); - sin->sin_addr.s_addr = htonl(INADDR_ANY); - sin->sin_port = htons(0); - return addr; - } - - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); - sin6->sin6_addr = IN6ADDR_ANY_INIT; - sin6->sin6_port = htons(0); - return addr; -} - -sockaddr_storage UdpSocketTest::InetLoopbackAddr() { - struct sockaddr_storage addr; - memset(&addr, 0, sizeof(addr)); - reinterpret_cast<struct sockaddr*>(&addr)->sa_family = GetFamily(); - - if (GetFamily() == AF_INET) { - auto sin = reinterpret_cast<struct sockaddr_in*>(&addr); - sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); - sin->sin_port = htons(0); - return addr; - } - auto sin6 = reinterpret_cast<struct sockaddr_in6*>(&addr); - sin6->sin6_addr = in6addr_loopback; - sin6->sin6_port = htons(0); - return addr; -} - -void UdpSocketTest::Disconnect(int sockfd) { - sockaddr_storage addr_storage = InetAnyAddr(); - sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - socklen_t addrlen = sizeof(addr_storage); - - addr->sa_family = AF_UNSPEC; - ASSERT_THAT(connect(sockfd, addr, addrlen), SyscallSucceeds()); - - // Check that after disconnect the socket is bound to the ANY address. - EXPECT_THAT(getsockname(sockfd, addr, &addrlen), SyscallSucceeds()); - if (GetParam() == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_ANY)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - struct in6_addr loopback = IN6ADDR_ANY_INIT; - - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } -} - -TEST_P(UdpSocketTest, Creation) { - FileDescriptor sock = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); - EXPECT_THAT(close(sock.release()), SyscallSucceeds()); - - sock = ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, 0)); - EXPECT_THAT(close(sock.release()), SyscallSucceeds()); - - ASSERT_THAT(socket(GetFamily(), SOCK_STREAM, IPPROTO_UDP), SyscallFails()); -} - -TEST_P(UdpSocketTest, Getsockname) { - // Check that we're not bound. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_EQ(memcmp(&addr, reinterpret_cast<struct sockaddr*>(&any), addrlen_), - 0); - - ASSERT_NO_ERRNO(BindLoopback()); - - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); -} - -TEST_P(UdpSocketTest, Getpeername) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Check that we're not connected. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Connect, then check that we get the right address. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); -} - -TEST_P(UdpSocketTest, SendNotConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Do send & write, they must fail. - char buf[512]; - EXPECT_THAT(send(sock_.get(), buf, sizeof(buf), 0), - SyscallFailsWithErrno(EDESTADDRREQ)); - - EXPECT_THAT(write(sock_.get(), buf, sizeof(buf)), - SyscallFailsWithErrno(EDESTADDRREQ)); - - // Use sendto. - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ConnectBinds) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect the socket. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Check that we're bound now. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_NE(*Port(&addr), 0); -} - -TEST_P(UdpSocketTest, ReceiveNotBound) { - char buf[512]; - EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, Bind) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Try to bind again. - EXPECT_THAT(bind(bind_.get(), bind_addr_, addrlen_), - SyscallFailsWithErrno(EINVAL)); - - // Check that we're still bound to the original address. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(bind_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, bind_addr_, addrlen_), 0); -} - -TEST_P(UdpSocketTest, BindInUse) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Try to bind again. - EXPECT_THAT(bind(sock_.get(), bind_addr_, addrlen_), - SyscallFailsWithErrno(EADDRINUSE)); -} - -TEST_P(UdpSocketTest, ReceiveAfterConnect) { - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Send from sock_ to bind_ - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveAfterDisconnect) { - ASSERT_NO_ERRNO(BindLoopback()); - - for (int i = 0; i < 2; i++) { - // Connet sock_ to bound address. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - - // Send from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(bind_.get(), buf, sizeof(buf), 0, - reinterpret_cast<sockaddr*>(&addr), addrlen), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Disconnect sock_. - struct sockaddr unspec = {}; - unspec.sa_family = AF_UNSPEC; - ASSERT_THAT(connect(sock_.get(), &unspec, sizeof(unspec.sa_family)), - SyscallSucceeds()); - } -} - -TEST_P(UdpSocketTest, Connect) { - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); - - // Try to bind after connect. - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - bind(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallFailsWithErrno(EINVAL)); - - struct sockaddr_storage bind2_storage = InetLoopbackAddr(); - struct sockaddr* bind2_addr = - reinterpret_cast<struct sockaddr*>(&bind2_storage); - FileDescriptor bind2 = - ASSERT_NO_ERRNO_AND_VALUE(Socket(GetFamily(), SOCK_DGRAM, IPPROTO_UDP)); - ASSERT_NO_ERRNO(BindSocket(bind2.get(), bind2_addr)); - - // Try to connect again. - EXPECT_THAT(connect(sock_.get(), bind2_addr, addrlen_), SyscallSucceeds()); - - // Check that peer name changed. - peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, bind2_addr, addrlen_), 0); -} - -TEST_P(UdpSocketTest, ConnectAnyZero) { - // TODO(138658473): Enable when we can connect to port 0 with gVisor. - SKIP_IF(IsRunningOnGvisor()); - - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, ConnectAnyWithPort) { - ASSERT_NO_ERRNO(BindAny()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAny) { - // TODO(138658473): Enable when we can connect to port 0 with gVisor. - SKIP_IF(IsRunningOnGvisor()); - struct sockaddr_storage any = InetAnyAddr(); - EXPECT_THAT( - connect(sock_.get(), reinterpret_cast<struct sockaddr*>(&any), addrlen_), - SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); - - Disconnect(sock_.get()); -} - -TEST_P(UdpSocketTest, DisconnectAfterConnectAnyWithPort) { - ASSERT_NO_ERRNO(BindAny()); - EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(*Port(&bind_addr_storage_), *Port(&addr)); - - Disconnect(sock_.get()); -} - -TEST_P(UdpSocketTest, DisconnectAfterBind) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Bind to the next port above bind_. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_NO_ERRNO(BindSocket(sock_.get(), addr)); - - // Connect the socket. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct sockaddr_storage unspec = {}; - unspec.ss_family = AF_UNSPEC; - EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), - sizeof(unspec.ss_family)), - SyscallSucceeds()); - - // Check that we're still bound. - socklen_t addrlen = sizeof(unspec); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&unspec), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(addr, &unspec, addrlen_), 0); - - addrlen = sizeof(addr); - EXPECT_THAT(getpeername(sock_.get(), addr, &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, BindToAnyConnnectToLocalhost) { - ASSERT_NO_ERRNO(BindAny()); - - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - socklen_t addrlen = sizeof(addr); - - // Connect the socket. - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(getsockname(bind_.get(), addr, &addrlen), SyscallSucceeds()); - - // If the socket is bound to ANY and connected to a loopback address, - // getsockname() has to return the loopback address. - if (GetParam() == AddressFamily::kIpv4) { - auto addr_out = reinterpret_cast<struct sockaddr_in*>(addr); - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(addr_out->sin_addr.s_addr, htonl(INADDR_LOOPBACK)); - } else { - auto addr_out = reinterpret_cast<struct sockaddr_in6*>(addr); - struct in6_addr loopback = IN6ADDR_LOOPBACK_INIT; - EXPECT_EQ(addrlen, sizeof(*addr_out)); - EXPECT_EQ(memcmp(&addr_out->sin6_addr, &loopback, sizeof(in6_addr)), 0); - } -} - -TEST_P(UdpSocketTest, DisconnectAfterBindToAny) { - ASSERT_NO_ERRNO(BindLoopback()); - - struct sockaddr_storage any_storage = InetAnyAddr(); - struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); - SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); - - ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); - - // Connect the socket. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - Disconnect(sock_.get()); - - // Check that we're still bound. - struct sockaddr_storage addr; - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(memcmp(&addr, any, addrlen), 0); - - addrlen = sizeof(addr); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, Disconnect) { - ASSERT_NO_ERRNO(BindLoopback()); - - struct sockaddr_storage any_storage = InetAnyAddr(); - struct sockaddr* any = reinterpret_cast<struct sockaddr*>(&any_storage); - SetPort(&any_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_NO_ERRNO(BindSocket(sock_.get(), any)); - - for (int i = 0; i < 2; i++) { - // Try to connect again. - EXPECT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Check that we're connected to the right peer. - struct sockaddr_storage peer; - socklen_t peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallSucceeds()); - EXPECT_EQ(peerlen, addrlen_); - EXPECT_EQ(memcmp(&peer, bind_addr_, addrlen_), 0); - - // Try to disconnect. - struct sockaddr_storage addr = {}; - addr.ss_family = AF_UNSPEC; - EXPECT_THAT(connect(sock_.get(), reinterpret_cast<sockaddr*>(&addr), - sizeof(addr.ss_family)), - SyscallSucceeds()); - - peerlen = sizeof(peer); - EXPECT_THAT( - getpeername(sock_.get(), reinterpret_cast<sockaddr*>(&peer), &peerlen), - SyscallFailsWithErrno(ENOTCONN)); - - // Check that we're still bound. - socklen_t addrlen = sizeof(addr); - EXPECT_THAT( - getsockname(sock_.get(), reinterpret_cast<sockaddr*>(&addr), &addrlen), - SyscallSucceeds()); - EXPECT_EQ(addrlen, addrlen_); - EXPECT_EQ(*Port(&addr), *Port(&any_storage)); - } -} - -TEST_P(UdpSocketTest, ConnectBadAddress) { - struct sockaddr addr = {}; - addr.sa_family = GetFamily(); - ASSERT_THAT(connect(sock_.get(), &addr, sizeof(addr.sa_family)), - SyscallFailsWithErrno(EINVAL)); -} - -TEST_P(UdpSocketTest, SendToAddressOtherThanConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - struct sockaddr_storage addr_storage = InetAnyAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Send to a different destination than we're connected to. - char buf[512]; - EXPECT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, addr, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowed) { - // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - // Connect to loopback:bind_addr_+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:bind_addr_+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from bind_ to sock_. - ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {sock_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout*/ 1000), - SyscallSucceedsWithValue(1)); - - // Receive the packet. - char received[3]; - EXPECT_THAT(read(sock_.get(), received, sizeof(received)), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ZerolengthWriteAllowedNonBlockRead) { - // TODO(gvisor.dev/issue/1202): Hostinet does not support zero length writes. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:bind_addr_port+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - // Set sock to non-blocking. - int opts = 0; - ASSERT_THAT(opts = fcntl(sock_.get(), F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(sock_.get(), F_SETFL, opts | O_NONBLOCK), - SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from bind_ to sock_. - ASSERT_THAT(write(bind_.get(), buf, 0), SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {sock_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // Receive the packet. - char received[3]; - EXPECT_THAT(read(sock_.get(), received, sizeof(received)), - SyscallSucceedsWithValue(0)); - EXPECT_THAT(read(sock_.get(), received, sizeof(received)), - SyscallFailsWithErrno(EAGAIN)); -} - -TEST_P(UdpSocketTest, SendAndReceiveNotConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send some data to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, SendAndReceiveConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:TestPort+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, ReceiveFromNotConnected) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:bind_addr_port+2. - struct sockaddr_storage addr2_storage = InetLoopbackAddr(); - struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); - SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); - ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Check that the data isn't received because it was sent from a different - // address than we're connected. - EXPECT_THAT(recv(sock_.get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveBeforeConnect) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Bind sock to loopback:bind_addr_port+2. - struct sockaddr_storage addr2_storage = InetLoopbackAddr(); - struct sockaddr* addr2 = reinterpret_cast<struct sockaddr*>(&addr2_storage); - SetPort(&addr2_storage, *Port(&bind_addr_storage_) + 2); - ASSERT_THAT(bind(sock_.get(), addr2, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Connect to loopback:TestPort+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Receive the data. It works because it was sent before the connect. - char received[sizeof(buf)]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - - // Send again. This time it should not be received. - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - EXPECT_THAT(recv(bind_.get(), buf, sizeof(buf), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReceiveFrom) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind sock to loopback:TestPort+1. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - - // Send some data from sock to bind_. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - - // Receive the data and sender address. - char received[sizeof(buf)]; - struct sockaddr_storage addr2; - socklen_t addr2len = sizeof(addr2); - EXPECT_THAT(recvfrom(bind_.get(), received, sizeof(received), 0, - reinterpret_cast<sockaddr*>(&addr2), &addr2len), - SyscallSucceedsWithValue(sizeof(received))); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); - EXPECT_EQ(addr2len, addrlen_); - EXPECT_EQ(memcmp(addr, &addr2, addrlen_), 0); -} - -TEST_P(UdpSocketTest, Listen) { - ASSERT_THAT(listen(sock_.get(), SOMAXCONN), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -TEST_P(UdpSocketTest, Accept) { - ASSERT_THAT(accept(sock_.get(), nullptr, nullptr), - SyscallFailsWithErrno(EOPNOTSUPP)); -} - -// This test validates that a read shutdown with pending data allows the read -// to proceed with the data before returning EAGAIN. -TEST_P(UdpSocketTest, ReadShutdownNonblockPendingData) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - // Bind to loopback:bind_addr_port+1 and connect to bind_addr_. - ASSERT_THAT(bind(sock_.get(), addr, addrlen_), SyscallSucceeds()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Verify that we get EWOULDBLOCK when there is nothing to read. - char received[512]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - const char* buf = "abc"; - EXPECT_THAT(write(sock_.get(), buf, 3), SyscallSucceedsWithValue(3)); - - int opts = 0; - ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); - ASSERT_THAT(fcntl(bind_.get(), F_SETFL, opts | O_NONBLOCK), - SyscallSucceeds()); - ASSERT_THAT(opts = fcntl(bind_.get(), F_GETFL), SyscallSucceeds()); - ASSERT_NE(opts & O_NONBLOCK, 0); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // We should get the data even though read has been shutdown. - EXPECT_THAT(recv(bind_.get(), received, 2, 0), SyscallSucceedsWithValue(2)); - - // Because we read less than the entire packet length, since it's a packet - // based socket any subsequent reads should return EWOULDBLOCK. - EXPECT_THAT(recv(bind_.get(), received, 1, 0), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -// This test is validating that even after a socket is shutdown if it's -// reconnected it will reset the shutdown state. -TEST_P(UdpSocketTest, ReadShutdownSameSocketResetsShutdownState) { - char received[512]; - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_NO_ERRNO(BindLoopback()); - - // Connect to loopback:bind_addr_port+1. - struct sockaddr_storage addr_storage = InetLoopbackAddr(); - struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_storage); - SetPort(&addr_storage, *Port(&bind_addr_storage_) + 1); - ASSERT_THAT(connect(bind_.get(), addr, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); -} - -TEST_P(UdpSocketTest, ReadShutdown) { - // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without - // MSG_DONTWAIT blocks indefinitely. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - - char received[512]; - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then try to shutdown again. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, ReadShutdownDifferentThread) { - // TODO(gvisor.dev/issue/1202): Calling recv() after shutdown without - // MSG_DONTWAIT blocks indefinitely. - SKIP_IF(IsRunningWithHostinet()); - ASSERT_NO_ERRNO(BindLoopback()); - - char received[512]; - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Connect the socket, then shutdown from another thread. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - EXPECT_THAT(recv(sock_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); - }); - EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); - t.Join(); - - EXPECT_THAT(RetryEINTR(recv)(sock_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(0)); -} - -TEST_P(UdpSocketTest, WriteShutdown) { - ASSERT_NO_ERRNO(BindLoopback()); - EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - EXPECT_THAT(shutdown(sock_.get(), SHUT_WR), SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, SynchronousReceive) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send some data to bind_ from another thread. - char buf[512]; - RandomizeBuffer(buf, sizeof(buf)); - - // Receive the data prior to actually starting the other thread. - char received[512]; - EXPECT_THAT( - RetryEINTR(recv)(bind_.get(), received, sizeof(received), MSG_DONTWAIT), - SyscallFailsWithErrno(EWOULDBLOCK)); - - // Start the thread. - ScopedThread t([&] { - absl::SleepFor(absl::Milliseconds(200)); - ASSERT_THAT(sendto(sock_.get(), buf, sizeof(buf), 0, this->bind_addr_, - this->addrlen_), - SyscallSucceedsWithValue(sizeof(buf))); - }); - - EXPECT_THAT(RetryEINTR(recv)(bind_.get(), received, sizeof(received), 0), - SyscallSucceedsWithValue(512)); - EXPECT_EQ(memcmp(buf, received, sizeof(buf)), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendRecv) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send 3 packets from sock to bind_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 3; ++i) { - ASSERT_THAT( - sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(psize)); - } - - // Receive the data as 3 separate packets. - char received[6 * psize]; - for (int i = 0; i < 3; ++i) { - EXPECT_THAT(recv(bind_.get(), received + i * psize, 3 * psize, 0), - SyscallSucceedsWithValue(psize)); - } - EXPECT_EQ(memcmp(buf, received, 3 * psize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_WritevReadv) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Direct writes from sock to bind_. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Send 2 packets from sock to bind_, where each packet's data consists of - // 2 discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(writev(sock_.get(), iov, 2), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - ASSERT_THAT(readv(bind_.get(), iov, 3), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, BoundaryPreserved_SendMsgRecvMsg) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Send 2 packets from sock to bind_, where each packet's data consists of - // 2 discontiguous iovecs. - constexpr size_t kPieceSize = 100; - char buf[4 * kPieceSize]; - RandomizeBuffer(buf, sizeof(buf)); - - for (int i = 0; i < 2; i++) { - struct iovec iov[2]; - for (int j = 0; j < 2; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(buf) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_name = bind_addr_; - msg.msg_namelen = addrlen_; - msg.msg_iov = iov; - msg.msg_iovlen = 2; - ASSERT_THAT(sendmsg(sock_.get(), &msg, 0), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - - // Receive the data as 2 separate packets. - char received[6 * kPieceSize]; - for (int i = 0; i < 2; i++) { - struct iovec iov[3]; - for (int j = 0; j < 3; j++) { - iov[j].iov_base = reinterpret_cast<void*>( - reinterpret_cast<uintptr_t>(received) + (i + 2 * j) * kPieceSize); - iov[j].iov_len = kPieceSize; - } - struct msghdr msg = {}; - msg.msg_iov = iov; - msg.msg_iovlen = 3; - ASSERT_THAT(recvmsg(bind_.get(), &msg, 0), - SyscallSucceedsWithValue(2 * kPieceSize)); - } - EXPECT_EQ(memcmp(buf, received, 4 * kPieceSize), 0); -} - -TEST_P(UdpSocketTest, FIONREADShutdown) { - ASSERT_NO_ERRNO(BindLoopback()); - - int n = -1; - EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(sock_.get(), SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(sock_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, FIONREADWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - ASSERT_NO_ERRNO(BindLoopback()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(bind_.get(), str, sizeof(str), 0), - SyscallSucceedsWithValue(sizeof(str))); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, sizeof(str)); -} - -// NOTE: Do not use `FIONREAD` as test name because it will be replaced by the -// corresponding macro and become `0x541B`. -TEST_P(UdpSocketTest, Fionread) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from sock to bind_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - for (int i = 0; i < 3; ++i) { - ASSERT_THAT( - sendto(sock_.get(), buf + i * psize, psize, 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(psize)); - - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, psize); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthPacket) { - ASSERT_NO_ERRNO(BindLoopback()); - - // Check that the bound socket with an empty buffer reports an empty first - // packet. - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - // Send 3 packets from sock to bind_. - constexpr int psize = 100; - char buf[3 * psize]; - RandomizeBuffer(buf, sizeof(buf)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - for (int i = 0; i < 3; ++i) { - ASSERT_THAT( - sendto(sock_.get(), buf + i * psize, 0, 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(0)); - - // TODO(gvisor.dev/issue/2726): sending a zero-length message to a hostinet - // socket does not cause a poll event to be triggered. - if (!IsRunningWithHostinet()) { - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - } - - // Check that regardless of how many packets are in the queue, the size - // reported is that of a single packet. - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - } -} - -TEST_P(UdpSocketTest, FIONREADZeroLengthWriteShutdown) { - int n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - ASSERT_NO_ERRNO(BindLoopback()); - - // A UDP socket must be connected before it can be shutdown. - ASSERT_THAT(connect(bind_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - const char str[] = "abc"; - ASSERT_THAT(send(bind_.get(), str, 0, 0), SyscallSucceedsWithValue(0)); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); - - EXPECT_THAT(shutdown(bind_.get(), SHUT_RD), SyscallSucceeds()); - - n = -1; - EXPECT_THAT(ioctl(bind_.get(), FIONREAD, &n), SyscallSucceedsWithValue(0)); - EXPECT_EQ(n, 0); -} - -TEST_P(UdpSocketTest, SoNoCheckOffByDefault) { - // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - - int v = -1; - socklen_t optlen = sizeof(v); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoNoCheck) { - // TODO(gvisor.dev/issue/1202): SO_NO_CHECK socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - - int v = kSockOptOn; - socklen_t optlen = sizeof(v); - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), - SyscallSucceeds()); - v = -1; - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOn); - ASSERT_EQ(optlen, sizeof(v)); - - v = kSockOptOff; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, optlen), - SyscallSucceeds()); - v = -1; - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_NO_CHECK, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoTimestampOffByDefault) { - // TODO(gvisor.dev/issue/1202): SO_TIMESTAMP socket option not supported by - // hostinet. - SKIP_IF(IsRunningWithHostinet()); - - int v = -1; - socklen_t optlen = sizeof(v); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, &optlen), - SyscallSucceeds()); - ASSERT_EQ(v, kSockOptOff); - ASSERT_EQ(optlen, sizeof(v)); -} - -TEST_P(UdpSocketTest, SoTimestamp) { - // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not - // supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - int v = 1; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - - char buf[3]; - // Send zero length packet from sock to bind_. - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), - SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg; - memset(&msg, 0, sizeof(msg)); - iovec iov; - memset(&iov, 0, sizeof(iov)); - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), - SyscallSucceedsWithValue(0)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET); - ASSERT_EQ(cmsg->cmsg_type, SO_TIMESTAMP); - ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(struct timeval))); - - struct timeval tv = {}; - memcpy(&tv, CMSG_DATA(cmsg), sizeof(struct timeval)); - - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // There should be nothing to get via ioctl. - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), - SyscallFailsWithErrno(ENOENT)); -} - -TEST_P(UdpSocketTest, WriteShutdownNotConnected) { - EXPECT_THAT(shutdown(bind_.get(), SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); -} - -TEST_P(UdpSocketTest, TimestampIoctl) { - // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from sock to bind_. - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); -} - -TEST_P(UdpSocketTest, TimestampIoctlNothingRead) { - // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - struct timeval tv = {}; - ASSERT_THAT(ioctl(sock_.get(), SIOCGSTAMP, &tv), - SyscallFailsWithErrno(ENOENT)); -} - -// Test that the timestamp accessed via SIOCGSTAMP is still accessible after -// SO_TIMESTAMP is enabled and used to retrieve a timestamp. -TEST_P(UdpSocketTest, TimestampIoctlPersistence) { - // TODO(gvisor.dev/issue/1202): ioctl() and SO_TIMESTAMP socket option are not - // supported by hostinet. - SKIP_IF(IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - char buf[3]; - // Send packet from sock to bind_. - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, sizeof(buf)), - SyscallSucceedsWithValue(sizeof(buf))); - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), - SyscallSucceedsWithValue(0)); - - struct pollfd pfd = {bind_.get(), POLLIN, 0}; - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // There should be no control messages. - char recv_buf[sizeof(buf)]; - ASSERT_NO_FATAL_FAILURE(RecvNoCmsg(bind_.get(), recv_buf, sizeof(recv_buf))); - - // A nonzero timeval should be available via ioctl. - struct timeval tv = {}; - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv), SyscallSucceeds()); - ASSERT_TRUE(tv.tv_sec != 0 || tv.tv_usec != 0); - - // Enable SO_TIMESTAMP and send a message. - int v = 1; - EXPECT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_TIMESTAMP, &v, sizeof(v)), - SyscallSucceeds()); - ASSERT_THAT(RetryEINTR(write)(sock_.get(), buf, 0), - SyscallSucceedsWithValue(0)); - - ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, /*timeout=*/1000), - SyscallSucceedsWithValue(1)); - - // There should be a message for SO_TIMESTAMP. - char cmsgbuf[CMSG_SPACE(sizeof(struct timeval))]; - msghdr msg = {}; - iovec iov = {}; - msg.msg_iov = &iov; - msg.msg_iovlen = 1; - msg.msg_control = cmsgbuf; - msg.msg_controllen = sizeof(cmsgbuf); - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &msg, 0), - SyscallSucceedsWithValue(0)); - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); - ASSERT_NE(cmsg, nullptr); - - // The ioctl should return the exact same values as before. - struct timeval tv2 = {}; - ASSERT_THAT(ioctl(bind_.get(), SIOCGSTAMP, &tv2), SyscallSucceeds()); - ASSERT_EQ(tv.tv_sec, tv2.tv_sec); - ASSERT_EQ(tv.tv_usec, tv2.tv_usec); -} - -// Test that a socket with IP_TOS or IPV6_TCLASS set will set the TOS byte on -// outgoing packets, and that a receiving socket with IP_RECVTOS or -// IPV6_RECVTCLASS will create the corresponding control message. -TEST_P(UdpSocketTest, SetAndReceiveTOS) { - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Allow socket to receive control message. - int recv_level = SOL_IP; - int recv_type = IP_RECVTOS; - if (GetParam() != AddressFamily::kIpv4) { - recv_level = SOL_IPV6; - recv_type = IPV6_RECVTCLASS; - } - ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &kSockOptOn, - sizeof(kSockOptOn)), - SyscallSucceeds()); - - // Set socket TOS. - int sent_level = recv_level; - int sent_type = IP_TOS; - if (sent_level == SOL_IPV6) { - sent_type = IPV6_TCLASS; - } - int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. - ASSERT_THAT(setsockopt(sock_.get(), sent_level, sent_type, &sent_tos, - sizeof(sent_tos)), - SyscallSucceeds()); - - // Prepare message to send. - constexpr size_t kDataLength = 1024; - struct msghdr sent_msg = {}; - struct iovec sent_iov = {}; - char sent_data[kDataLength]; - sent_iov.iov_base = &sent_data[0]; - sent_iov.iov_len = kDataLength; - sent_msg.msg_iov = &sent_iov; - sent_msg.msg_iovlen = 1; - - ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - // Receive message. - struct msghdr received_msg = {}; - struct iovec received_iov = {}; - char received_data[kDataLength]; - received_iov.iov_base = &received_data[0]; - received_iov.iov_len = kDataLength; - received_msg.msg_iov = &received_iov; - received_msg.msg_iovlen = 1; - size_t cmsg_data_len = sizeof(int8_t); - if (sent_type == IPV6_TCLASS) { - cmsg_data_len = sizeof(int); - } - std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - received_msg.msg_control = &received_cmsgbuf[0]; - received_msg.msg_controllen = received_cmsgbuf.size(); - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); - EXPECT_EQ(cmsg->cmsg_level, sent_level); - EXPECT_EQ(cmsg->cmsg_type, sent_type); - int8_t received_tos = 0; - memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); - EXPECT_EQ(received_tos, sent_tos); -} - -// Test that sendmsg with IP_TOS and IPV6_TCLASS control messages will set the -// TOS byte on outgoing packets, and that a receiving socket with IP_RECVTOS or -// IPV6_RECVTCLASS will create the corresponding control message. -TEST_P(UdpSocketTest, SendAndReceiveTOS) { - // TODO(b/146661005): Setting TOS via cmsg not supported for netstack. - SKIP_IF(IsRunningOnGvisor() && !IsRunningWithHostinet()); - - ASSERT_NO_ERRNO(BindLoopback()); - ASSERT_THAT(connect(sock_.get(), bind_addr_, addrlen_), SyscallSucceeds()); - - // Allow socket to receive control message. - int recv_level = SOL_IP; - int recv_type = IP_RECVTOS; - if (GetParam() != AddressFamily::kIpv4) { - recv_level = SOL_IPV6; - recv_type = IPV6_RECVTCLASS; - } - int recv_opt = kSockOptOn; - ASSERT_THAT(setsockopt(bind_.get(), recv_level, recv_type, &recv_opt, - sizeof(recv_opt)), - SyscallSucceeds()); - - // Prepare message to send. - constexpr size_t kDataLength = 1024; - int sent_level = recv_level; - int sent_type = IP_TOS; - int sent_tos = IPTOS_LOWDELAY; // Choose some TOS value. - - struct msghdr sent_msg = {}; - struct iovec sent_iov = {}; - char sent_data[kDataLength]; - sent_iov.iov_base = &sent_data[0]; - sent_iov.iov_len = kDataLength; - sent_msg.msg_iov = &sent_iov; - sent_msg.msg_iovlen = 1; - size_t cmsg_data_len = sizeof(int8_t); - if (sent_level == SOL_IPV6) { - sent_type = IPV6_TCLASS; - cmsg_data_len = sizeof(int); - } - std::vector<char> sent_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - sent_msg.msg_control = &sent_cmsgbuf[0]; - sent_msg.msg_controllen = CMSG_LEN(cmsg_data_len); - - // Manually add control message. - struct cmsghdr* sent_cmsg = CMSG_FIRSTHDR(&sent_msg); - sent_cmsg->cmsg_len = CMSG_LEN(cmsg_data_len); - sent_cmsg->cmsg_level = sent_level; - sent_cmsg->cmsg_type = sent_type; - *(int8_t*)CMSG_DATA(sent_cmsg) = sent_tos; - - ASSERT_THAT(RetryEINTR(sendmsg)(sock_.get(), &sent_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - // Receive message. - struct msghdr received_msg = {}; - struct iovec received_iov = {}; - char received_data[kDataLength]; - received_iov.iov_base = &received_data[0]; - received_iov.iov_len = kDataLength; - received_msg.msg_iov = &received_iov; - received_msg.msg_iovlen = 1; - std::vector<char> received_cmsgbuf(CMSG_SPACE(cmsg_data_len)); - received_msg.msg_control = &received_cmsgbuf[0]; - received_msg.msg_controllen = CMSG_LEN(cmsg_data_len); - ASSERT_THAT(RetryEINTR(recvmsg)(bind_.get(), &received_msg, 0), - SyscallSucceedsWithValue(kDataLength)); - - struct cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg); - ASSERT_NE(cmsg, nullptr); - EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); - EXPECT_EQ(cmsg->cmsg_level, sent_level); - EXPECT_EQ(cmsg->cmsg_type, sent_type); - int8_t received_tos = 0; - memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos)); - EXPECT_EQ(received_tos, sent_tos); -} - -TEST_P(UdpSocketTest, RecvBufLimitsEmptyRcvBuf) { - // Discover minimum buffer size by setting it to zero. - constexpr int kRcvBufSz = 0; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, - sizeof(kRcvBufSz)), - SyscallSucceeds()); - - int min = 0; - socklen_t min_len = sizeof(min); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), - SyscallSucceeds()); - - // Bind bind_ to loopback. - ASSERT_NO_ERRNO(BindLoopback()); - - { - // Send data of size min and verify that it's received. - std::vector<char> buf(min); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); - } - - { - // Send data of size min + 1 and verify that its received. Both linux and - // Netstack accept a dgram that exceeds rcvBuf limits if the receive buffer - // is currently empty. - std::vector<char> buf(min + 1); - RandomizeBuffer(buf.data(), buf.size()); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); - } -} - -// Test that receive buffer limits are enforced. -TEST_P(UdpSocketTest, RecvBufLimits) { - // Bind s_ to loopback. - ASSERT_NO_ERRNO(BindLoopback()); - - int min = 0; - { - // Discover minimum buffer size by trying to set it to zero. - constexpr int kRcvBufSz = 0; - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &kRcvBufSz, - sizeof(kRcvBufSz)), - SyscallSucceeds()); - - socklen_t min_len = sizeof(min); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &min, &min_len), - SyscallSucceeds()); - } - - // Now set the limit to min * 4. - int new_rcv_buf_sz = min * 4; - if (!IsRunningOnGvisor() || IsRunningWithHostinet()) { - // Linux doubles the value specified so just set to min * 2. - new_rcv_buf_sz = min * 2; - } - - ASSERT_THAT(setsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &new_rcv_buf_sz, - sizeof(new_rcv_buf_sz)), - SyscallSucceeds()); - int rcv_buf_sz = 0; - { - socklen_t rcv_buf_len = sizeof(rcv_buf_sz); - ASSERT_THAT(getsockopt(bind_.get(), SOL_SOCKET, SO_RCVBUF, &rcv_buf_sz, - &rcv_buf_len), - SyscallSucceeds()); - } - - { - std::vector<char> buf(min); - RandomizeBuffer(buf.data(), buf.size()); - - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - int sent = 4; - if (IsRunningOnGvisor() && !IsRunningWithHostinet()) { - // Linux seems to drop the 4th packet even though technically it should - // fit in the receive buffer. - ASSERT_THAT( - sendto(sock_.get(), buf.data(), buf.size(), 0, bind_addr_, addrlen_), - SyscallSucceedsWithValue(buf.size())); - sent++; - } - - for (int i = 0; i < sent - 1; i++) { - // Receive the data. - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallSucceedsWithValue(received.size())); - EXPECT_EQ(memcmp(buf.data(), received.data(), buf.size()), 0); - } - - // The last receive should fail with EAGAIN as the last packet should have - // been dropped due to lack of space in the receive buffer. - std::vector<char> received(buf.size()); - EXPECT_THAT( - recv(bind_.get(), received.data(), received.size(), MSG_DONTWAIT), - SyscallFailsWithErrno(EAGAIN)); - } -} - -#ifndef __fuchsia__ - -// TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. -// gVisor currently silently ignores attaching a filter. -TEST_P(UdpSocketTest, SetSocketDetachFilter) { - // Program generated using sudo tcpdump -i lo udp and port 1234 -dd - struct sock_filter code[] = { - {0x28, 0, 0, 0x0000000c}, {0x15, 0, 6, 0x000086dd}, - {0x30, 0, 0, 0x00000014}, {0x15, 0, 15, 0x00000011}, - {0x28, 0, 0, 0x00000036}, {0x15, 12, 0, 0x000004d2}, - {0x28, 0, 0, 0x00000038}, {0x15, 10, 11, 0x000004d2}, - {0x15, 0, 10, 0x00000800}, {0x30, 0, 0, 0x00000017}, - {0x15, 0, 8, 0x00000011}, {0x28, 0, 0, 0x00000014}, - {0x45, 6, 0, 0x00001fff}, {0xb1, 0, 0, 0x0000000e}, - {0x48, 0, 0, 0x0000000e}, {0x15, 2, 0, 0x000004d2}, - {0x48, 0, 0, 0x00000010}, {0x15, 0, 1, 0x000004d2}, - {0x6, 0, 0, 0x00040000}, {0x6, 0, 0, 0x00000000}, - }; - struct sock_fprog bpf = { - .len = ABSL_ARRAYSIZE(code), - .filter = code, - }; - ASSERT_THAT( - setsockopt(sock_.get(), SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)), - SyscallSucceeds()); - - constexpr int val = 0; - ASSERT_THAT( - setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), - SyscallSucceeds()); -} - -TEST_P(UdpSocketTest, SetSocketDetachFilterNoInstalledFilter) { - // TODO(gvisor.dev/2746): Support SO_ATTACH_FILTER/SO_DETACH_FILTER. - SKIP_IF(IsRunningOnGvisor()); - constexpr int val = 0; - ASSERT_THAT( - setsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, sizeof(val)), - SyscallFailsWithErrno(ENOENT)); -} - -TEST_P(UdpSocketTest, GetSocketDetachFilter) { - int val = 0; - socklen_t val_len = sizeof(val); - ASSERT_THAT( - getsockopt(sock_.get(), SOL_SOCKET, SO_DETACH_FILTER, &val, &val_len), - SyscallFailsWithErrno(ENOPROTOOPT)); -} - -#endif // __fuchsia__ - -} // namespace testing -} // namespace gvisor diff --git a/test/syscalls/linux/udp_socket_test_cases.h b/test/syscalls/linux/udp_socket_test_cases.h deleted file mode 100644 index f7e25c805..000000000 --- a/test/syscalls/linux/udp_socket_test_cases.h +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ -#define THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ - -#include <sys/socket.h> - -#include "gtest/gtest.h" -#include "test/syscalls/linux/socket_test_util.h" -#include "test/util/file_descriptor.h" -#include "test/util/posix_error.h" - -namespace gvisor { -namespace testing { - -// The initial port to be be used on gvisor. -constexpr int TestPort = 40000; - -// Fixture for tests parameterized by the address family to use (AF_INET and -// AF_INET6) when creating sockets. -class UdpSocketTest - : public ::testing::TestWithParam<gvisor::testing::AddressFamily> { - protected: - // Creates two sockets that will be used by test cases. - void SetUp() override; - - // Binds the socket bind_ to the loopback and updates bind_addr_. - PosixError BindLoopback(); - - // Binds the socket bind_ to Any and updates bind_addr_. - PosixError BindAny(); - - // Binds given socket to address addr and updates. - PosixError BindSocket(int socket, struct sockaddr* addr); - - // Return initialized Any address to port 0. - struct sockaddr_storage InetAnyAddr(); - - // Return initialized Loopback address to port 0. - struct sockaddr_storage InetLoopbackAddr(); - - // Disconnects socket sockfd. - void Disconnect(int sockfd); - - // Get family for the test. - int GetFamily(); - - // Socket used by Bind methods - FileDescriptor bind_; - - // Second socket used for tests. - FileDescriptor sock_; - - // Address for bind_ socket. - struct sockaddr* bind_addr_; - - // Initialized to the length based on GetFamily(). - socklen_t addrlen_; - - // Storage for bind_addr_. - struct sockaddr_storage bind_addr_storage_; - - private: - // Helper to initialize addrlen_ for the test case. - socklen_t GetAddrLength(); -}; -} // namespace testing -} // namespace gvisor - -#endif // THIRD_PARTY_GOLANG_GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV4_UDP_UNBOUND_H_ diff --git a/test/syscalls/linux/unlink.cc b/test/syscalls/linux/unlink.cc index 2040375c9..061e2e0f1 100644 --- a/test/syscalls/linux/unlink.cc +++ b/test/syscalls/linux/unlink.cc @@ -208,6 +208,20 @@ TEST(RmdirTest, CanRemoveWithTrailingSlashes) { ASSERT_THAT(rmdir(slashslash.c_str()), SyscallSucceeds()); } +TEST(UnlinkTest, UnlinkAtEmptyPath) { + auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + auto file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + auto fd = ASSERT_NO_ERRNO_AND_VALUE(Open(file.path(), O_RDWR, 0666)); + EXPECT_THAT(unlinkat(fd.get(), "", 0), SyscallFailsWithErrno(ENOENT)); + + auto dirInDir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDirIn(dir.path())); + auto dirFD = ASSERT_NO_ERRNO_AND_VALUE( + Open(dirInDir.path(), O_RDONLY | O_DIRECTORY, 0666)); + EXPECT_THAT(unlinkat(dirFD.get(), "", AT_REMOVEDIR), + SyscallFailsWithErrno(ENOENT)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/vdso_clock_gettime.cc b/test/syscalls/linux/vdso_clock_gettime.cc index ce1899f45..2a8699a7b 100644 --- a/test/syscalls/linux/vdso_clock_gettime.cc +++ b/test/syscalls/linux/vdso_clock_gettime.cc @@ -38,8 +38,6 @@ std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) { switch (info.param) { case CLOCK_MONOTONIC: return "CLOCK_MONOTONIC"; - case CLOCK_REALTIME: - return "CLOCK_REALTIME"; case CLOCK_BOOTTIME: return "CLOCK_BOOTTIME"; default: @@ -47,59 +45,36 @@ std::string PrintClockId(::testing::TestParamInfo<clockid_t> info) { } } -class CorrectVDSOClockTest : public ::testing::TestWithParam<clockid_t> {}; +class MonotonicVDSOClockTest : public ::testing::TestWithParam<clockid_t> {}; -TEST_P(CorrectVDSOClockTest, IsCorrect) { +TEST_P(MonotonicVDSOClockTest, IsCorrect) { + // The VDSO implementation of clock_gettime() uses the TSC. On KVM, sentry and + // application TSCs can be very desynchronized; see + // sentry/platform/kvm/kvm.vCPU.setSystemTime(). + SKIP_IF(GvisorPlatform() == Platform::kKVM); + + // Check that when we alternate readings from the clock_gettime syscall and + // the VDSO's implementation, we observe the combined sequence as being + // monotonic. struct timespec tvdso, tsys; absl::Time vdso_time, sys_time; - uint64_t total_calls = 0; - - // It is expected that 82.5% of clock_gettime calls will be less than 100us - // skewed from the system time. - // Unfortunately this is not only influenced by the VDSO clock skew, but also - // by arbitrary scheduling delays and the like. The test is therefore - // regularly disabled. - std::map<absl::Duration, std::tuple<double, uint64_t, uint64_t>> confidence = - { - {absl::Microseconds(100), std::make_tuple(0.825, 0, 0)}, - {absl::Microseconds(250), std::make_tuple(0.94, 0, 0)}, - {absl::Milliseconds(1), std::make_tuple(0.999, 0, 0)}, - }; - - absl::Time start = absl::Now(); - while (absl::Now() < start + absl::Seconds(30)) { - EXPECT_THAT(clock_gettime(GetParam(), &tvdso), SyscallSucceeds()); - EXPECT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys), - SyscallSucceeds()); - + ASSERT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys), + SyscallSucceeds()); + sys_time = absl::TimeFromTimespec(tsys); + auto end = absl::Now() + absl::Seconds(10); + while (absl::Now() < end) { + ASSERT_THAT(clock_gettime(GetParam(), &tvdso), SyscallSucceeds()); vdso_time = absl::TimeFromTimespec(tvdso); - - for (auto const& conf : confidence) { - std::get<1>(confidence[conf.first]) += - (sys_time - vdso_time) < conf.first; - } - + EXPECT_LE(sys_time, vdso_time); + ASSERT_THAT(syscall(__NR_clock_gettime, GetParam(), &tsys), + SyscallSucceeds()); sys_time = absl::TimeFromTimespec(tsys); - - for (auto const& conf : confidence) { - std::get<2>(confidence[conf.first]) += - (vdso_time - sys_time) < conf.first; - } - - ++total_calls; - } - - for (auto const& conf : confidence) { - EXPECT_GE(std::get<1>(conf.second) / static_cast<double>(total_calls), - std::get<0>(conf.second)); - EXPECT_GE(std::get<2>(conf.second) / static_cast<double>(total_calls), - std::get<0>(conf.second)); + EXPECT_LE(vdso_time, sys_time); } } -INSTANTIATE_TEST_SUITE_P(ClockGettime, CorrectVDSOClockTest, - ::testing::Values(CLOCK_MONOTONIC, CLOCK_REALTIME, - CLOCK_BOOTTIME), +INSTANTIATE_TEST_SUITE_P(ClockGettime, MonotonicVDSOClockTest, + ::testing::Values(CLOCK_MONOTONIC, CLOCK_BOOTTIME), PrintClockId); } // namespace diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index 39b5b2f56..77bcfbb8a 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -133,6 +133,91 @@ TEST_F(WriteTest, WriteExceedsRLimit) { EXPECT_THAT(close(fd), SyscallSucceeds()); } +TEST_F(WriteTest, WriteIncrementOffset) { + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + EXPECT_THAT(WriteBytes(fd, 0), SyscallSucceedsWithValue(0)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(0)); + + const int bytes_total = 1024; + + EXPECT_THAT(WriteBytes(fd, bytes_total), + SyscallSucceedsWithValue(bytes_total)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); +} + +TEST_F(WriteTest, WriteIncrementOffsetSeek) { + const std::string data = "hello world\n"; + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + const int seek_offset = data.size() / 2; + ASSERT_THAT(lseek(fd, seek_offset, SEEK_SET), + SyscallSucceedsWithValue(seek_offset)); + + const int write_bytes = 512; + EXPECT_THAT(WriteBytes(fd, write_bytes), + SyscallSucceedsWithValue(write_bytes)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), + SyscallSucceedsWithValue(seek_offset + write_bytes)); +} + +TEST_F(WriteTest, WriteIncrementOffsetAppend) { + const std::string data = "hello world\n"; + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); + FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE( + Open(tmpfile.path().c_str(), O_WRONLY | O_APPEND)); + int fd = f.get(); + + EXPECT_THAT(WriteBytes(fd, 1024), SyscallSucceedsWithValue(1024)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), + SyscallSucceedsWithValue(data.size() + 1024)); +} + +TEST_F(WriteTest, WriteIncrementOffsetEOF) { + const std::string data = "hello world\n"; + const TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + EXPECT_THAT(lseek(fd, 0, SEEK_END), SyscallSucceedsWithValue(data.size())); + + EXPECT_THAT(WriteBytes(fd, 1024), SyscallSucceedsWithValue(1024)); + EXPECT_THAT(lseek(fd, 0, SEEK_END), + SyscallSucceedsWithValue(data.size() + 1024)); +} + +TEST_F(WriteTest, PwriteNoChangeOffset) { + TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + int fd = f.get(); + + const std::string data = "hello world\n"; + + EXPECT_THAT(pwrite(fd, data.data(), data.size(), 0), + SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(0)); + + const int bytes_total = 1024; + ASSERT_THAT(WriteBytes(fd, bytes_total), + SyscallSucceedsWithValue(bytes_total)); + ASSERT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); + + EXPECT_THAT(pwrite(fd, data.data(), data.size(), bytes_total), + SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/xattr.cc b/test/syscalls/linux/xattr.cc index cbcf08451..bd3f829c4 100644 --- a/test/syscalls/linux/xattr.cc +++ b/test/syscalls/linux/xattr.cc @@ -28,6 +28,7 @@ #include "test/syscalls/linux/file_base.h" #include "test/util/capability_util.h" #include "test/util/file_descriptor.h" +#include "test/util/fs_util.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -37,6 +38,8 @@ namespace testing { namespace { +using ::gvisor::testing::IsTmpfs; + class XattrTest : public FileTest {}; TEST_F(XattrTest, XattrNonexistentFile) { @@ -229,7 +232,7 @@ TEST_F(XattrTest, XattrOnInvalidFileTypes) { EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM)); } -TEST_F(XattrTest, SetxattrSizeSmallerThanValue) { +TEST_F(XattrTest, SetXattrSizeSmallerThanValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -244,7 +247,7 @@ TEST_F(XattrTest, SetxattrSizeSmallerThanValue) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, SetxattrZeroSize) { +TEST_F(XattrTest, SetXattrZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -256,7 +259,7 @@ TEST_F(XattrTest, SetxattrZeroSize) { EXPECT_EQ(buf, '-'); } -TEST_F(XattrTest, SetxattrSizeTooLarge) { +TEST_F(XattrTest, SetXattrSizeTooLarge) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; @@ -271,7 +274,7 @@ TEST_F(XattrTest, SetxattrSizeTooLarge) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); } -TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) { +TEST_F(XattrTest, SetXattrNullValueAndNonzeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 1, /*flags=*/0), @@ -280,7 +283,7 @@ TEST_F(XattrTest, SetxattrNullValueAndNonzeroSize) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); } -TEST_F(XattrTest, SetxattrNullValueAndZeroSize) { +TEST_F(XattrTest, SetXattrNullValueAndZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -288,7 +291,7 @@ TEST_F(XattrTest, SetxattrNullValueAndZeroSize) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) { +TEST_F(XattrTest, SetXattrValueTooLargeButOKSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val(XATTR_SIZE_MAX + 1); @@ -304,7 +307,7 @@ TEST_F(XattrTest, SetxattrValueTooLargeButOKSize) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, SetxattrReplaceWithSmaller) { +TEST_F(XattrTest, SetXattrReplaceWithSmaller) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -319,7 +322,7 @@ TEST_F(XattrTest, SetxattrReplaceWithSmaller) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, SetxattrReplaceWithLarger) { +TEST_F(XattrTest, SetXattrReplaceWithLarger) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -333,7 +336,7 @@ TEST_F(XattrTest, SetxattrReplaceWithLarger) { EXPECT_EQ(buf, val); } -TEST_F(XattrTest, SetxattrCreateFlag) { +TEST_F(XattrTest, SetXattrCreateFlag) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_CREATE), @@ -344,7 +347,7 @@ TEST_F(XattrTest, SetxattrCreateFlag) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, SetxattrReplaceFlag) { +TEST_F(XattrTest, SetXattrReplaceFlag) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, XATTR_REPLACE), @@ -356,14 +359,14 @@ TEST_F(XattrTest, SetxattrReplaceFlag) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, SetxattrInvalidFlags) { +TEST_F(XattrTest, SetXattrInvalidFlags) { const char* path = test_file_name_.c_str(); int invalid_flags = 0xff; EXPECT_THAT(setxattr(path, nullptr, nullptr, 0, invalid_flags), SyscallFailsWithErrno(EINVAL)); } -TEST_F(XattrTest, Getxattr) { +TEST_F(XattrTest, GetXattr) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; int val = 1234; @@ -375,7 +378,7 @@ TEST_F(XattrTest, Getxattr) { EXPECT_EQ(buf, val); } -TEST_F(XattrTest, GetxattrSizeSmallerThanValue) { +TEST_F(XattrTest, GetXattrSizeSmallerThanValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; std::vector<char> val = {'a', 'a'}; @@ -387,7 +390,7 @@ TEST_F(XattrTest, GetxattrSizeSmallerThanValue) { EXPECT_EQ(buf, '-'); } -TEST_F(XattrTest, GetxattrSizeLargerThanValue) { +TEST_F(XattrTest, GetXattrSizeLargerThanValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -402,7 +405,7 @@ TEST_F(XattrTest, GetxattrSizeLargerThanValue) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, GetxattrZeroSize) { +TEST_F(XattrTest, GetXattrZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -415,7 +418,7 @@ TEST_F(XattrTest, GetxattrZeroSize) { EXPECT_EQ(buf, '-'); } -TEST_F(XattrTest, GetxattrSizeTooLarge) { +TEST_F(XattrTest, GetXattrSizeTooLarge) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -431,7 +434,7 @@ TEST_F(XattrTest, GetxattrSizeTooLarge) { EXPECT_EQ(buf, expected_buf); } -TEST_F(XattrTest, GetxattrNullValue) { +TEST_F(XattrTest, GetXattrNullValue) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -442,7 +445,7 @@ TEST_F(XattrTest, GetxattrNullValue) { SyscallFailsWithErrno(EFAULT)); } -TEST_F(XattrTest, GetxattrNullValueAndZeroSize) { +TEST_F(XattrTest, GetXattrNullValueAndZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; char val = 'a'; @@ -458,13 +461,13 @@ TEST_F(XattrTest, GetxattrNullValueAndZeroSize) { EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallSucceedsWithValue(size)); } -TEST_F(XattrTest, GetxattrNonexistentName) { +TEST_F(XattrTest, GetXattrNonexistentName) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(getxattr(path, name, nullptr, 0), SyscallFailsWithErrno(ENODATA)); } -TEST_F(XattrTest, Listxattr) { +TEST_F(XattrTest, ListXattr) { const char* path = test_file_name_.c_str(); const std::string name = "user.test"; const std::string name2 = "user.test2"; @@ -490,7 +493,7 @@ TEST_F(XattrTest, Listxattr) { EXPECT_EQ(got, expected); } -TEST_F(XattrTest, ListxattrNoXattrs) { +TEST_F(XattrTest, ListXattrNoXattrs) { const char* path = test_file_name_.c_str(); std::vector<char> list, expected; @@ -498,13 +501,13 @@ TEST_F(XattrTest, ListxattrNoXattrs) { SyscallSucceedsWithValue(0)); EXPECT_EQ(list, expected); - // Listxattr should succeed if there are no attributes, even if the buffer + // ListXattr should succeed if there are no attributes, even if the buffer // passed in is a nullptr. EXPECT_THAT(listxattr(path, nullptr, sizeof(list)), SyscallSucceedsWithValue(0)); } -TEST_F(XattrTest, ListxattrNullBuffer) { +TEST_F(XattrTest, ListXattrNullBuffer) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -513,7 +516,7 @@ TEST_F(XattrTest, ListxattrNullBuffer) { SyscallFailsWithErrno(EFAULT)); } -TEST_F(XattrTest, ListxattrSizeTooSmall) { +TEST_F(XattrTest, ListXattrSizeTooSmall) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -523,7 +526,7 @@ TEST_F(XattrTest, ListxattrSizeTooSmall) { SyscallFailsWithErrno(ERANGE)); } -TEST_F(XattrTest, ListxattrZeroSize) { +TEST_F(XattrTest, ListXattrZeroSize) { const char* path = test_file_name_.c_str(); const char name[] = "user.test"; EXPECT_THAT(setxattr(path, name, nullptr, 0, /*flags=*/0), SyscallSucceeds()); @@ -604,6 +607,83 @@ TEST_F(XattrTest, XattrWithFD) { EXPECT_THAT(fremovexattr(fd.get(), name), SyscallSucceeds()); } +TEST_F(XattrTest, TrustedNamespaceWithCapSysAdmin) { + // Trusted namespace not supported in VFS1. + SKIP_IF(IsRunningWithVFS1()); + + // TODO(b/66162845): Only gVisor tmpfs currently supports trusted namespace. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_))); + + const char* path = test_file_name_.c_str(); + const char name[] = "trusted.test"; + + // Writing to the trusted.* xattr namespace requires CAP_SYS_ADMIN in the root + // user namespace. There's no easy way to check that, other than trying the + // operation and seeing what happens. We'll call removexattr because it's + // simplest. + if (removexattr(path, name) < 0) { + SKIP_IF(errno == EPERM); + FAIL() << "unexpected errno from removexattr: " << errno; + } + + // Set. + char val = 'a'; + size_t size = sizeof(val); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), SyscallSucceeds()); + + // Get. + char got = '\0'; + EXPECT_THAT(getxattr(path, name, &got, size), SyscallSucceedsWithValue(size)); + EXPECT_EQ(val, got); + + // List. + char list[sizeof(name)]; + EXPECT_THAT(listxattr(path, list, sizeof(list)), + SyscallSucceedsWithValue(sizeof(name))); + EXPECT_STREQ(list, name); + + // Remove. + EXPECT_THAT(removexattr(path, name), SyscallSucceeds()); + + // Get should now return ENODATA. + EXPECT_THAT(getxattr(path, name, &got, size), SyscallFailsWithErrno(ENODATA)); +} + +TEST_F(XattrTest, TrustedNamespaceWithoutCapSysAdmin) { + // Trusted namespace not supported in VFS1. + SKIP_IF(IsRunningWithVFS1()); + + // TODO(b/66162845): Only gVisor tmpfs currently supports trusted namespace. + SKIP_IF(IsRunningOnGvisor() && + !ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(test_file_name_))); + + // Drop CAP_SYS_ADMIN if we have it. + if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))) { + EXPECT_NO_ERRNO(SetCapability(CAP_SYS_ADMIN, false)); + } + + const char* path = test_file_name_.c_str(); + const char name[] = "trusted.test"; + + // Set fails. + char val = 'a'; + size_t size = sizeof(val); + EXPECT_THAT(setxattr(path, name, &val, size, /*flags=*/0), + SyscallFailsWithErrno(EPERM)); + + // Get fails. + char got = '\0'; + EXPECT_THAT(getxattr(path, name, &got, size), SyscallFailsWithErrno(ENODATA)); + + // List still works, but returns no items. + char list[sizeof(name)]; + EXPECT_THAT(listxattr(path, list, sizeof(list)), SyscallSucceedsWithValue(0)); + + // Remove fails. + EXPECT_THAT(removexattr(path, name), SyscallFailsWithErrno(EPERM)); +} + } // namespace } // namespace testing diff --git a/test/util/BUILD b/test/util/BUILD index 2a17c33ee..fc5fb3a8d 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -46,6 +46,13 @@ cc_library( ) cc_library( + name = "fuse_util", + testonly = 1, + srcs = ["fuse_util.cc"], + hdrs = ["fuse_util.h"], +) + +cc_library( name = "proc_util", testonly = 1, srcs = ["proc_util.cc"], diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index 5418948fe..b16055dd8 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -15,7 +15,11 @@ #include "test/util/fs_util.h" #include <dirent.h> +#ifdef __linux__ +#include <linux/magic.h> +#endif // __linux__ #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/types.h> #include <unistd.h> @@ -629,5 +633,35 @@ PosixErrorOr<std::string> ProcessExePath(int pid) { return ReadLink(absl::StrCat("/proc/", pid, "/exe")); } +#ifdef __linux__ +PosixErrorOr<bool> IsTmpfs(const std::string& path) { + struct statfs stat; + if (statfs(path.c_str(), &stat)) { + if (errno == ENOENT) { + // Nothing at path, don't raise this as an error. Instead, just report no + // tmpfs at path. + return false; + } + return PosixError(errno, + absl::StrFormat("statfs(\"%s\", %#p)", path, &stat)); + } + return stat.f_type == TMPFS_MAGIC; +} +#endif // __linux__ + +PosixErrorOr<bool> IsOverlayfs(const std::string& path) { + struct statfs stat; + if (statfs(path.c_str(), &stat)) { + if (errno == ENOENT) { + // Nothing at path, don't raise this as an error. Instead, just report no + // overlayfs at path. + return false; + } + return PosixError(errno, + absl::StrFormat("statfs(\"%s\", %#p)", path, &stat)); + } + return stat.f_type == OVERLAYFS_SUPER_MAGIC; +} + } // namespace testing } // namespace gvisor diff --git a/test/util/fs_util.h b/test/util/fs_util.h index 8cdac23a1..c99cf5eb7 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -17,6 +17,7 @@ #include <dirent.h> #include <sys/stat.h> +#include <sys/statfs.h> #include <sys/types.h> #include <unistd.h> @@ -37,6 +38,10 @@ constexpr int kOLargeFile = 00400000; #error "Unknown architecture" #endif +// From linux/magic.h. For some reason, not defined in the headers for some +// build environments. +#define OVERLAYFS_SUPER_MAGIC 0x794c7630 + // Returns a status or the current working directory. PosixErrorOr<std::string> GetCWD(); @@ -178,6 +183,14 @@ std::string CleanPath(absl::string_view path); // Returns the full path to the executable of the given pid or a PosixError. PosixErrorOr<std::string> ProcessExePath(int pid); +#ifdef __linux__ +// IsTmpfs returns true if the file at path is backed by tmpfs. +PosixErrorOr<bool> IsTmpfs(const std::string& path); +#endif // __linux__ + +// IsOverlayfs returns true if the file at path is backed by overlayfs. +PosixErrorOr<bool> IsOverlayfs(const std::string& path); + namespace internal { // Not part of the public API. std::string JoinPathImpl(std::initializer_list<absl::string_view> paths); diff --git a/test/util/fuse_util.cc b/test/util/fuse_util.cc new file mode 100644 index 000000000..027f8386c --- /dev/null +++ b/test/util/fuse_util.cc @@ -0,0 +1,63 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/fuse_util.h" + +#include <sys/stat.h> +#include <sys/types.h> + +#include <string> + +namespace gvisor { +namespace testing { + +// Create a default FuseAttr struct with specified mode, inode, and size. +fuse_attr DefaultFuseAttr(mode_t mode, uint64_t inode, uint64_t size) { + const int time_sec = 1595436289; + const int time_nsec = 134150844; + return (struct fuse_attr){ + .ino = inode, + .size = size, + .blocks = 4, + .atime = time_sec, + .mtime = time_sec, + .ctime = time_sec, + .atimensec = time_nsec, + .mtimensec = time_nsec, + .ctimensec = time_nsec, + .mode = mode, + .nlink = 2, + .uid = 1234, + .gid = 4321, + .rdev = 12, + .blksize = 4096, + }; +} + +// Create response body with specified mode, nodeID, and size. +fuse_entry_out DefaultEntryOut(mode_t mode, uint64_t node_id, uint64_t size) { + struct fuse_entry_out default_entry_out = { + .nodeid = node_id, + .generation = 0, + .entry_valid = 0, + .attr_valid = 0, + .entry_valid_nsec = 0, + .attr_valid_nsec = 0, + .attr = DefaultFuseAttr(mode, node_id, size), + }; + return default_entry_out; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/fuse_util.h b/test/util/fuse_util.h new file mode 100644 index 000000000..544fe1b38 --- /dev/null +++ b/test/util/fuse_util.h @@ -0,0 +1,75 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_UTIL_FUSE_UTIL_H_ +#define GVISOR_TEST_UTIL_FUSE_UTIL_H_ + +#include <linux/fuse.h> +#include <sys/uio.h> + +#include <string> +#include <vector> + +namespace gvisor { +namespace testing { + +// The fundamental generation function with a single argument. If passed by +// std::string or std::vector<char>, it will call specialized versions as +// implemented below. +template <typename T> +std::vector<struct iovec> FuseGenerateIovecs(T &first) { + return {(struct iovec){.iov_base = &first, .iov_len = sizeof(first)}}; +} + +// If an argument is of type std::string, it must be used in read-only scenario. +// Because we are setting up iovec, which contains the original address of a +// data structure, we have to drop const qualification. Usually used with +// variable-length payload data. +template <typename T = std::string> +std::vector<struct iovec> FuseGenerateIovecs(std::string &first) { + // Pad one byte for null-terminate c-string. + return {(struct iovec){.iov_base = const_cast<char *>(first.c_str()), + .iov_len = first.size() + 1}}; +} + +// If an argument is of type std::vector<char>, it must be used in write-only +// scenario and the size of the variable must be greater than or equal to the +// size of the expected data. Usually used with variable-length payload data. +template <typename T = std::vector<char>> +std::vector<struct iovec> FuseGenerateIovecs(std::vector<char> &first) { + return {(struct iovec){.iov_base = first.data(), .iov_len = first.size()}}; +} + +// A helper function to set up an array of iovec struct for testing purpose. +// Use variadic class template to generalize different numbers and different +// types of FUSE structs. +template <typename T, typename... Types> +std::vector<struct iovec> FuseGenerateIovecs(T &first, Types &...args) { + auto first_iovec = FuseGenerateIovecs(first); + auto iovecs = FuseGenerateIovecs(args...); + first_iovec.insert(std::end(first_iovec), std::begin(iovecs), + std::end(iovecs)); + return first_iovec; +} + +// Create a fuse_attr filled with the specified mode and inode. +fuse_attr DefaultFuseAttr(mode_t mode, uint64_t inode, uint64_t size = 512); + +// Return a fuse_entry_out FUSE server response body. +fuse_entry_out DefaultEntryOut(mode_t mode, uint64_t node_id, + uint64_t size = 512); + +} // namespace testing +} // namespace gvisor +#endif // GVISOR_TEST_UTIL_FUSE_UTIL_H_ diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc index c01f916aa..2cf0bea74 100644 --- a/test/util/pty_util.cc +++ b/test/util/pty_util.cc @@ -23,15 +23,15 @@ namespace gvisor { namespace testing { -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { - PosixErrorOr<int> n = SlaveID(master); +PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& master) { + PosixErrorOr<int> n = ReplicaID(master); if (!n.ok()) { return PosixErrorOr<FileDescriptor>(n.error()); } return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK); } -PosixErrorOr<int> SlaveID(const FileDescriptor& master) { +PosixErrorOr<int> ReplicaID(const FileDescriptor& master) { // Get pty index. int n; int ret = ioctl(master.get(), TIOCGPTN, &n); diff --git a/test/util/pty_util.h b/test/util/pty_util.h index 0722da379..ed7658868 100644 --- a/test/util/pty_util.h +++ b/test/util/pty_util.h @@ -21,11 +21,11 @@ namespace gvisor { namespace testing { -// Opens the slave end of the passed master as R/W and nonblocking. -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master); +// Opens the replica end of the passed master as R/W and nonblocking. +PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& master); -// Get the number of the slave end of the master. -PosixErrorOr<int> SlaveID(const FileDescriptor& master); +// Get the number of the replica end of the master. +PosixErrorOr<int> ReplicaID(const FileDescriptor& master); } // namespace testing } // namespace gvisor diff --git a/test/util/test_util_runfiles.cc b/test/util/test_util_runfiles.cc index 694d21692..7210094eb 100644 --- a/test/util/test_util_runfiles.cc +++ b/test/util/test_util_runfiles.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef __fuchsia__ - #include <iostream> #include <string> @@ -46,5 +44,3 @@ std::string RunfilePath(std::string path) { } // namespace testing } // namespace gvisor - -#endif // __fuchsia__ |