summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.buildkite/pipeline.yaml6
-rw-r--r--.github/workflows/build.yml4
-rw-r--r--.github/workflows/go.yml4
-rw-r--r--Makefile7
-rw-r--r--WORKSPACE4
-rw-r--r--go.mod2
-rw-r--r--go.sum2
-rw-r--r--images/syzkaller/README.md6
-rw-r--r--nogo.yaml2
-rw-r--r--pkg/abi/linux/ptrace_amd64.go11
-rw-r--r--pkg/abi/linux/ptrace_arm64.go11
-rw-r--r--pkg/sentry/fs/proc/sys_net.go121
-rw-r--r--pkg/sentry/fsimpl/devpts/devpts.go5
-rw-r--r--pkg/sentry/fsimpl/ext/filesystem.go5
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go128
-rw-r--r--pkg/sentry/fsimpl/gofer/filesystem.go57
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go110
-rw-r--r--pkg/sentry/fsimpl/host/host.go5
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs_test.go5
-rw-r--r--pkg/sentry/fsimpl/overlay/filesystem.go12
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go5
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go5
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go78
-rw-r--r--pkg/sentry/fsimpl/sockfs/sockfs.go5
-rw-r--r--pkg/sentry/fsimpl/sys/sys.go5
-rw-r--r--pkg/sentry/fsimpl/tmpfs/filesystem.go5
-rw-r--r--pkg/sentry/fsimpl/tmpfs/tmpfs.go5
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go17
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go68
-rw-r--r--pkg/sentry/inet/inet.go8
-rw-r--r--pkg/sentry/inet/test_stack.go12
-rw-r--r--pkg/sentry/socket/hostinet/stack.go11
-rw-r--r--pkg/sentry/socket/netstack/stack.go10
-rw-r--r--pkg/sentry/vfs/anonfs.go5
-rw-r--r--pkg/sentry/vfs/filesystem.go9
-rw-r--r--pkg/sentry/vfs/mount.go8
-rw-r--r--pkg/syserr/netstack.go3
-rw-r--r--pkg/tcpip/checker/checker.go12
-rw-r--r--pkg/tcpip/errors.go13
-rw-r--r--pkg/tcpip/header/tcp.go28
-rw-r--r--pkg/tcpip/header/tcp_test.go20
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go8
-rw-r--r--pkg/tcpip/network/arp/arp.go2
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection.go17
-rw-r--r--pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go22
-rw-r--r--pkg/tcpip/network/ipv4/BUILD2
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go3
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go119
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go36
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go49
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go56
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go6
-rw-r--r--pkg/tcpip/ports/BUILD5
-rw-r--r--pkg/tcpip/ports/flags.go150
-rw-r--r--pkg/tcpip/ports/ports.go619
-rw-r--r--pkg/tcpip/ports/ports_test.go59
-rw-r--r--pkg/tcpip/stack/ndp_test.go105
-rw-r--r--pkg/tcpip/stack/registration.go44
-rw-r--r--pkg/tcpip/stack/stack.go12
-rw-r--r--pkg/tcpip/stack/stack_test.go4
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go6
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go74
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go2
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go28
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/accept.go20
-rw-r--r--pkg/tcpip/transport/tcp/connect.go8
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go72
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go12
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/tcp/segment.go6
-rw-r--r--pkg/tcpip/transport/tcp/snd.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go70
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go4
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go12
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go2
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go44
-rw-r--r--runsc/boot/compat.go12
-rw-r--r--runsc/boot/compat_amd64.go4
-rw-r--r--runsc/boot/controller.go4
-rw-r--r--runsc/boot/filter/config.go379
-rw-r--r--runsc/boot/filter/config_amd64.go33
-rw-r--r--runsc/boot/filter/config_arm64.go17
-rw-r--r--runsc/boot/filter/config_profile.go7
-rw-r--r--runsc/boot/filter/extra_filters_msan.go11
-rw-r--r--runsc/boot/filter/extra_filters_race.go25
-rw-r--r--runsc/boot/fs.go12
-rw-r--r--runsc/boot/limits.go8
-rw-r--r--runsc/boot/loader_test.go7
-rw-r--r--runsc/boot/network.go4
-rw-r--r--runsc/cgroup/cgroup.go9
-rw-r--r--runsc/cli/BUILD1
-rw-r--r--runsc/cli/main.go8
-rw-r--r--runsc/cmd/BUILD3
-rw-r--r--runsc/cmd/boot.go5
-rw-r--r--runsc/cmd/checkpoint.go4
-rw-r--r--runsc/cmd/chroot.go16
-rw-r--r--runsc/cmd/cmd.go10
-rw-r--r--runsc/cmd/debug.go6
-rw-r--r--runsc/cmd/do.go6
-rw-r--r--runsc/cmd/exec.go12
-rw-r--r--runsc/cmd/gofer.go43
-rw-r--r--runsc/cmd/kill.go7
-rw-r--r--runsc/cmd/mitigate.go122
-rw-r--r--runsc/cmd/mitigate_test.go169
-rw-r--r--runsc/cmd/restore.go4
-rw-r--r--runsc/cmd/run.go4
-rw-r--r--runsc/cmd/wait.go6
-rw-r--r--runsc/config/config.go3
-rw-r--r--runsc/config/flags.go1
-rw-r--r--runsc/container/BUILD1
-rw-r--r--runsc/container/console_test.go15
-rw-r--r--runsc/container/container.go35
-rw-r--r--runsc/container/container_test.go50
-rw-r--r--runsc/container/multi_container_test.go34
-rw-r--r--runsc/container/state_file.go6
-rw-r--r--runsc/fsgofer/filter/config.go153
-rw-r--r--runsc/fsgofer/filter/config_amd64.go35
-rw-r--r--runsc/fsgofer/filter/config_arm64.go19
-rw-r--r--runsc/fsgofer/filter/extra_filters_msan.go7
-rw-r--r--runsc/fsgofer/filter/extra_filters_race.go27
-rw-r--r--runsc/fsgofer/fsgofer.go21
-rw-r--r--runsc/fsgofer/fsgofer_test.go32
-rw-r--r--runsc/mitigate/BUILD22
-rw-r--r--runsc/mitigate/cpu.go423
-rw-r--r--runsc/mitigate/cpu_test.go605
-rw-r--r--runsc/mitigate/mitigate.go467
-rw-r--r--runsc/mitigate/mitigate_test.go579
-rw-r--r--runsc/mitigate/mock/BUILD11
-rw-r--r--runsc/mitigate/mock/mock.go141
-rw-r--r--runsc/sandbox/network.go25
-rw-r--r--runsc/sandbox/network_unsafe.go3
-rw-r--r--runsc/sandbox/sandbox.go35
-rw-r--r--runsc/specutils/fs.go78
-rw-r--r--runsc/specutils/namespace.go12
-rw-r--r--runsc/specutils/seccomp/BUILD2
-rw-r--r--runsc/specutils/seccomp/seccomp.go6
-rw-r--r--runsc/specutils/seccomp/seccomp_test.go40
-rw-r--r--runsc/specutils/specutils.go14
-rw-r--r--test/benchmarks/fs/BUILD1
-rw-r--r--test/benchmarks/fs/bazel_test.go61
-rw-r--r--test/benchmarks/fs/fio_test.go59
-rw-r--r--test/benchmarks/harness/BUILD1
-rw-r--r--test/benchmarks/harness/util.go52
-rw-r--r--test/benchmarks/tcp/tcp_proxy.go17
-rw-r--r--test/fuse/linux/mount_test.cc5
-rw-r--r--test/iptables/BUILD1
-rw-r--r--test/iptables/iptables_unsafe.go31
-rw-r--r--test/iptables/nat.go90
-rw-r--r--test/packetimpact/dut/posix_server.cc3
-rw-r--r--test/packetimpact/proto/posix_server.proto3
-rw-r--r--test/packetimpact/runner/defs.bzl15
-rw-r--r--test/packetimpact/runner/dut.go81
-rw-r--r--test/packetimpact/testbench/BUILD2
-rw-r--r--test/packetimpact/testbench/connections.go33
-rw-r--r--test/packetimpact/testbench/dut.go56
-rw-r--r--test/packetimpact/testbench/layers.go12
-rw-r--r--test/packetimpact/testbench/layers_test.go8
-rw-r--r--test/packetimpact/testbench/testbench.go62
-rw-r--r--test/packetimpact/tests/BUILD32
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go10
-rw-r--r--test/packetimpact/tests/ipv4_id_uniqueness_test.go2
-rw-r--r--test/packetimpact/tests/tcp_cork_mss_test.go12
-rw-r--r--test/packetimpact/tests/tcp_fin_retransmission_test.go87
-rw-r--r--test/packetimpact/tests/tcp_handshake_window_size_test.go8
-rw-r--r--test/packetimpact/tests/tcp_info_test.go2
-rw-r--r--test/packetimpact/tests/tcp_linger_test.go31
-rw-r--r--test/packetimpact/tests/tcp_network_unreachable_test.go11
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go2
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_closing_test.go86
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_test.go12
-rw-r--r--test/packetimpact/tests/tcp_paws_mechanism_test.go14
-rw-r--r--test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go53
-rw-r--r--test/packetimpact/tests/tcp_rack_test.go20
-rw-r--r--test/packetimpact/tests/tcp_rcv_buf_space_test.go5
-rw-r--r--test/packetimpact/tests/tcp_retransmits_test.go36
-rw-r--r--test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go4
-rw-r--r--test/packetimpact/tests/tcp_synrcvd_reset_test.go10
-rw-r--r--test/packetimpact/tests/tcp_synsent_reset_test.go18
-rw-r--r--test/packetimpact/tests/tcp_timewait_reset_test.go14
-rw-r--r--test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go94
-rw-r--r--test/packetimpact/tests/tcp_unacc_seq_ack_test.go88
-rw-r--r--test/packetimpact/tests/tcp_user_timeout_test.go6
-rw-r--r--test/packetimpact/tests/tcp_window_shrink_test.go4
-rw-r--r--test/packetimpact/tests/tcp_zero_receive_window_test.go8
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go22
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_test.go12
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go12
-rw-r--r--test/packetimpact/tests/udp_discard_mcast_source_addr_test.go5
-rw-r--r--test/packetimpact/tests/udp_icmp_error_propagation_test.go21
-rw-r--r--test/packetimpact/tests/udp_send_recv_dgram_test.go3
-rw-r--r--test/perf/BUILD2
-rw-r--r--test/runner/gtest/gtest.go50
-rw-r--r--test/runner/runner.go226
-rw-r--r--test/runtimes/proctor/BUILD5
-rw-r--r--test/runtimes/proctor/lib/BUILD1
-rw-r--r--test/runtimes/proctor/lib/lib.go7
-rw-r--r--test/runtimes/proctor/main.go8
-rw-r--r--test/syscalls/BUILD6
-rw-r--r--test/syscalls/linux/BUILD6
-rw-r--r--test/syscalls/linux/proc.cc23
-rw-r--r--test/syscalls/linux/proc_net.cc37
-rw-r--r--test/syscalls/linux/pty.cc7
-rw-r--r--test/syscalls/linux/socket_generic_stress.cc136
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc6
-rw-r--r--test/uds/BUILD1
-rw-r--r--test/uds/uds.go24
-rw-r--r--tools/go_marshal/gomarshal/generator.go8
-rw-r--r--tools/go_marshal/gomarshal/generator_tests.go12
-rwxr-xr-xtools/make_apt.sh2
-rw-r--r--tools/verity/BUILD15
-rw-r--r--tools/verity/measure_tool.go87
-rw-r--r--tools/verity/measure_tool_unsafe.go (renamed from runsc/mitigate/mitigate_conf.go)40
213 files changed, 5133 insertions, 3295 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml
index cb272aef6..aa2fd1f47 100644
--- a/.buildkite/pipeline.yaml
+++ b/.buildkite/pipeline.yaml
@@ -183,9 +183,13 @@ steps:
- <<: *benchmarks
label: ":metal: FFMPEG benchmarks"
command: make benchmark-platforms BENCHMARKS_SUITE=ffmpeg BENCHMARKS_TARGETS=test/benchmarks/media:ffmpeg_test
+ # For fio, running with --test.benchtime=Xs scales the written/read
+ # bytes to several GB. This is not a problem for root/bind/volume mounts,
+ # but for tmpfs mounts, the size can grow to more memory than the machine
+ # has availabe. Fix the runs to 10GB written/read for the benchmark.
- <<: *benchmarks
label: ":floppy_disk: FIO benchmarks"
- command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test
+ command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test BENCHMARKS_OPTIONS=--test.benchtime=10000x
- <<: *benchmarks
label: ":globe_with_meridians: HTTPD benchmarks"
command: make benchmark-platforms BENCHMARKS_FILTER="Continuous" BENCHMARKS_SUITE=httpd BENCHMARKS_TARGETS=test/benchmarks/network:httpd_test
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index b0381a563..b572dc94f 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -16,6 +16,10 @@ jobs:
default:
runs-on: ubuntu-latest
steps:
+ - name: Cancel previous
+ uses: styfle/cancel-workflow-action@0.7.0
+ with:
+ access_token: ${{ github.token }}
- uses: actions/checkout@v2
- run: make
- run: make build OPTIONS="--build_tag_filters nogo" TARGETS="//..."
diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index 594dc7ffc..4c8b8ea5c 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -16,6 +16,10 @@ jobs:
generate:
runs-on: ubuntu-latest
steps:
+ - name: Cancel previous
+ uses: styfle/cancel-workflow-action@0.7.0
+ with:
+ access_token: ${{ github.token }}
- id: setup
run: |
if ! [[ -z "${{ secrets.GO_TOKEN }}" ]]; then
diff --git a/Makefile b/Makefile
index de22509cd..0f79b6a18 100644
--- a/Makefile
+++ b/Makefile
@@ -143,6 +143,7 @@ dev: $(RUNTIME_BIN) ## Installs a set of local runtimes. Requires sudo.
@$(call configure_noreload,$(RUNTIME)-d,--net-raw --debug --strace --log-packets)
@$(call configure_noreload,$(RUNTIME)-p,--net-raw --profile)
@$(call configure_noreload,$(RUNTIME)-vfs2-d,--net-raw --debug --strace --log-packets --vfs2)
+ @$(call configure_noreload,$(RUNTIME)-vfs2-fuse-d,--net-raw --debug --strace --log-packets --vfs2 --fuse)
@$(call reload_docker)
.PHONY: dev
@@ -325,6 +326,7 @@ containerd-tests: containerd-test-1.4.3
## BENCHMARKS_FILTER - filter to be applied to the test suite.
## BENCHMARKS_OPTIONS - options to be passed to the test.
## BENCHMARKS_PROFILE - profile options to be passed to the test.
+## BENCH_RUNTIME_ARGS - args to configure the runtime which runs the benchmarks.
##
BENCHMARKS_PROJECT ?= gvisor-benchmarks
BENCHMARKS_DATASET ?= kokoro
@@ -338,6 +340,7 @@ BENCHMARKS_FILTER := .
BENCHMARKS_OPTIONS := -test.benchtime=30s
BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) $(BENCHMARKS_OPTIONS)
BENCHMARKS_PROFILE := -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex
+BENCH_RUNTIME_ARGS ?= --vfs2
init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema.
@$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE))
@@ -358,13 +361,13 @@ run_benchmark = \
benchmark-platforms: load-benchmarks $(RUNTIME_BIN) ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS.
@$(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \
- $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) --vfs2) && \
+ $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS)) && \
) true
@$(call run_benchmark,runc)
.PHONY: benchmark-platforms
run-benchmark: load-benchmarks $(RUNTIME_BIN) ## Runs single benchmark and optionally sends data to BigQuery.
- @$(call run_benchmark,$(RUNTIME),)
+ @$(call run_benchmark,$(RUNTIME),$(BENCH_RUNTIME_ARGS))
.PHONY: run-benchmark
##
diff --git a/WORKSPACE b/WORKSPACE
index 4ee93a670..779ee6ae6 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -303,8 +303,8 @@ go_repository(
go_repository(
name = "com_github_gofrs_flock",
importpath = "github.com/gofrs/flock",
- sum = "h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs=",
- version = "v0.6.1-0.20180915234121-886344bea079",
+ sum = "h1:MSdYClljsF3PbENUUEx85nkWfJSGfzYI9yEBZOJz6CY=",
+ version = "v0.8.0",
)
go_repository(
diff --git a/go.mod b/go.mod
index 0774d2930..870fb0b83 100644
--- a/go.mod
+++ b/go.mod
@@ -20,7 +20,7 @@ require (
github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 // indirect
github.com/docker/go-connections v0.3.0 // indirect
github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect
- github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect
+ github.com/gofrs/flock v0.8.0 // indirect
github.com/gogo/googleapis v1.4.0 // indirect
github.com/gogo/protobuf v1.3.1 // indirect
github.com/golang/mock v1.4.4 // indirect
diff --git a/go.sum b/go.sum
index 9d7ef2243..bd3d0b0f7 100644
--- a/go.sum
+++ b/go.sum
@@ -135,6 +135,8 @@ github.com/godbus/dbus/v5 v5.0.3 h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME=
github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs=
github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
+github.com/gofrs/flock v0.8.0 h1:MSdYClljsF3PbENUUEx85nkWfJSGfzYI9yEBZOJz6CY=
+github.com/gofrs/flock v0.8.0/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/gogo/googleapis v1.4.0 h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI=
github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c=
github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
diff --git a/images/syzkaller/README.md b/images/syzkaller/README.md
index 47e309422..7e500cab3 100644
--- a/images/syzkaller/README.md
+++ b/images/syzkaller/README.md
@@ -51,8 +51,8 @@ syzkaller repro in /tmp/syzkaller/repro.
Now we can run syz-repro to reproduce a crash:
```bash
-docker run --privileged -it --rm -v
- /tmp/syzkaller:/tmp/syzkaller --entrypoint=""
- gvisor.dev/images/syzkaller:latest ./bin/syz-repro -config
+docker run --privileged -it --rm -v \
+ /tmp/syzkaller:/tmp/syzkaller --entrypoint="" \
+ gvisor.dev/images/syzkaller:latest ./bin/syz-repro -config \
/tmp/syzkaller/syzkaller.cfg /tmp/syzkaller/repro
```
diff --git a/nogo.yaml b/nogo.yaml
index 96e3aeccc..c0445a837 100644
--- a/nogo.yaml
+++ b/nogo.yaml
@@ -46,8 +46,6 @@ global:
- "(field|method|struct|type) .* should be .*"
# Generated proto code sometimes duplicates imports with aliases.
- "duplicate import"
- # TODO(b/179817829): Upgrade to flock to v0.8.0.
- - "flock.NewFlock is deprecated: Use New instead"
internal:
suppress:
# We use ALL_CAPS for system definitions,
diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go
index ed3881e27..50e22fe7e 100644
--- a/pkg/abi/linux/ptrace_amd64.go
+++ b/pkg/abi/linux/ptrace_amd64.go
@@ -50,3 +50,14 @@ type PtraceRegs struct {
Fs uint64
Gs uint64
}
+
+// InstructionPointer returns the address of the next instruction to
+// be executed.
+func (p *PtraceRegs) InstructionPointer() uint64 {
+ return p.Rip
+}
+
+// StackPointer returns the address of the Stack pointer.
+func (p *PtraceRegs) StackPointer() uint64 {
+ return p.Rsp
+}
diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go
index 6147738b3..da36811d2 100644
--- a/pkg/abi/linux/ptrace_arm64.go
+++ b/pkg/abi/linux/ptrace_arm64.go
@@ -27,3 +27,14 @@ type PtraceRegs struct {
Pc uint64
Pstate uint64
}
+
+// InstructionPointer returns the address of the next instruction to be
+// executed.
+func (p *PtraceRegs) InstructionPointer() uint64 {
+ return p.Pc
+}
+
+// StackPointer returns the address of the Stack pointer.
+func (p *PtraceRegs) StackPointer() uint64 {
+ return p.Sp
+}
diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go
index 52061175f..bbe282c03 100644
--- a/pkg/sentry/fs/proc/sys_net.go
+++ b/pkg/sentry/fs/proc/sys_net.go
@@ -17,6 +17,7 @@ package proc
import (
"fmt"
"io"
+ "math"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -26,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -498,6 +500,120 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO
return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled)
}
+// portRangeInode implements fs.InodeOperations. It provides and allows
+// modification of the range of ephemeral ports that IPv4 and IPv6 sockets
+// choose from.
+//
+// +stateify savable
+type portRangeInode struct {
+ fsutil.SimpleFileInode
+
+ stack inet.Stack `state:"wait"`
+
+ // start and end store the port range. We must save/restore this here,
+ // since a netstack instance is created on restore.
+ start *uint16
+ end *uint16
+}
+
+func newPortRangeInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
+ ipf := &portRangeInode{
+ SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC),
+ stack: s,
+ }
+ sattr := fs.StableAttr{
+ DeviceID: device.ProcDevice.DeviceID(),
+ InodeID: device.ProcDevice.NextIno(),
+ BlockSize: usermem.PageSize,
+ Type: fs.SpecialFile,
+ }
+ return fs.NewInode(ctx, ipf, msrc, sattr)
+}
+
+// Truncate implements fs.InodeOperations.Truncate. Truncate is called when
+// O_TRUNC is specified for any kind of existing Dirent but is not called via
+// (f)truncate for proc files.
+func (*portRangeInode) Truncate(context.Context, *fs.Inode, int64) error {
+ return nil
+}
+
+// +stateify savable
+type portRangeFile struct {
+ fsutil.FileGenericSeek `state:"nosave"`
+ fsutil.FileNoIoctl `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoopFsync `state:"nosave"`
+ fsutil.FileNoopRelease `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ waiter.AlwaysReady `state:"nosave"`
+
+ inode *portRangeInode
+}
+
+// GetFile implements fs.InodeOperations.GetFile.
+func (in *portRangeInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ flags.Pread = true
+ flags.Pwrite = true
+ return fs.NewFile(ctx, dirent, flags, &portRangeFile{
+ inode: in,
+ }), nil
+}
+
+// Read implements fs.FileOperations.Read.
+func (pf *portRangeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ return 0, io.EOF
+ }
+
+ if pf.inode.start == nil {
+ start, end := pf.inode.stack.PortRange()
+ pf.inode.start = &start
+ pf.inode.end = &end
+ }
+
+ contents := fmt.Sprintf("%d %d\n", *pf.inode.start, *pf.inode.end)
+ n, err := dst.CopyOut(ctx, []byte(contents))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+//
+// Offset is ignored, multiple writes are not supported.
+func (pf *portRangeFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) {
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Only consider size of one memory page for input for performance
+ // reasons.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ ports := make([]int32, 2)
+ n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ // Port numbers must be uint16s.
+ if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 {
+ return 0, syserror.EINVAL
+ }
+
+ if err := pf.inode.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil {
+ return 0, err
+ }
+ if pf.inode.start == nil {
+ pf.inode.start = new(uint16)
+ pf.inode.end = new(uint16)
+ }
+ *pf.inode.start = uint16(ports[0])
+ *pf.inode.end = uint16(ports[1])
+ return n, nil
+}
+
func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode {
contents := map[string]*fs.Inode{
// Add tcp_sack.
@@ -506,12 +622,15 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine
// Add ip_forward.
"ip_forward": newIPForwardingInode(ctx, msrc, s),
+ // Allow for configurable ephemeral port ranges. Note that this
+ // controls ports for both IPv4 and IPv6 sockets.
+ "ip_local_port_range": newPortRangeInode(ctx, msrc, s),
+
// The following files are simple stubs until they are
// implemented in netstack, most of these files are
// configuration related. We use the value closest to the
// actual netstack behavior or any empty file, all of these
// files will have mode 0444 (read-only for all users).
- "ip_local_port_range": newStaticProcInode(ctx, msrc, []byte("16000 65535")),
"ip_local_reserved_ports": newStaticProcInode(ctx, msrc, []byte("")),
"ipfrag_time": newStaticProcInode(ctx, msrc, []byte("30")),
"ip_nonlocal_bind": newStaticProcInode(ctx, msrc, []byte("0")),
diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go
index d8c237753..e75954105 100644
--- a/pkg/sentry/fsimpl/devpts/devpts.go
+++ b/pkg/sentry/fsimpl/devpts/devpts.go
@@ -137,6 +137,11 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.Release(ctx)
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
// rootInode is the root directory inode for the devpts mounts.
//
// +stateify savable
diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go
index 917f1873d..d4fc484a2 100644
--- a/pkg/sentry/fsimpl/ext/filesystem.go
+++ b/pkg/sentry/fsimpl/ext/filesystem.go
@@ -548,3 +548,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.mu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 204d8d143..fef857afb 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -47,19 +47,14 @@ type FilesystemType struct{}
// +stateify savable
type filesystemOptions struct {
- // userID specifies the numeric uid of the mount owner.
- // This option should not be specified by the filesystem owner.
- // It is set by libfuse (or, if libfuse is not used, must be set
- // by the filesystem itself). For more information, see man page
- // for fuse(8)
- userID uint32
-
- // groupID specifies the numeric gid of the mount owner.
- // This option should not be specified by the filesystem owner.
- // It is set by libfuse (or, if libfuse is not used, must be set
- // by the filesystem itself). For more information, see man page
- // for fuse(8)
- groupID uint32
+ // mopts contains the raw, unparsed mount options passed to this filesystem.
+ mopts string
+
+ // uid of the mount owner.
+ uid auth.KUID
+
+ // gid of the mount owner.
+ gid auth.KGID
// rootMode specifies the the file mode of the filesystem's root.
rootMode linux.FileMode
@@ -73,6 +68,19 @@ type filesystemOptions struct {
// specified as "max_read" in fs parameters.
// If not specified by user, use math.MaxUint32 as default value.
maxRead uint32
+
+ // defaultPermissions is the default_permissions mount option. It instructs
+ // the kernel to perform a standard unix permission checks based on
+ // ownership and mode bits, instead of deferring the check to the server.
+ //
+ // Immutable after mount.
+ defaultPermissions bool
+
+ // allowOther is the allow_other mount option. It allows processes that
+ // don't own the FUSE mount to call into it.
+ //
+ // Immutable after mount.
+ allowOther bool
}
// filesystem implements vfs.FilesystemImpl.
@@ -108,18 +116,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, err
}
- var fsopts filesystemOptions
+ fsopts := filesystemOptions{mopts: opts.Data}
mopts := vfs.GenericParseMountOptions(opts.Data)
deviceDescriptorStr, ok := mopts["fd"]
if !ok {
- log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name())
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option fd missing")
return nil, nil, syserror.EINVAL
}
delete(mopts, "fd")
deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
if err != nil {
- log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err)
+ ctx.Debugf("fusefs.FilesystemType.GetFilesystem: invalid fd: %q (%v)", deviceDescriptorStr, err)
return nil, nil, syserror.EINVAL
}
@@ -141,38 +149,54 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Parse and set all the other supported FUSE mount options.
// TODO(gVisor.dev/issue/3229): Expand the supported mount options.
- if userIDStr, ok := mopts["user_id"]; ok {
+ if uidStr, ok := mopts["user_id"]; ok {
delete(mopts, "user_id")
- userID, err := strconv.ParseUint(userIDStr, 10, 32)
+ uid, err := strconv.ParseUint(uidStr, 10, 32)
if err != nil {
- log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr)
+ log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), uidStr)
return nil, nil, syserror.EINVAL
}
- fsopts.userID = uint32(userID)
+ kuid := creds.UserNamespace.MapToKUID(auth.UID(uid))
+ if !kuid.Ok() {
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped uid: %d", uid)
+ return nil, nil, syserror.EINVAL
+ }
+ fsopts.uid = kuid
+ } else {
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option user_id missing")
+ return nil, nil, syserror.EINVAL
}
- if groupIDStr, ok := mopts["group_id"]; ok {
+ if gidStr, ok := mopts["group_id"]; ok {
delete(mopts, "group_id")
- groupID, err := strconv.ParseUint(groupIDStr, 10, 32)
+ gid, err := strconv.ParseUint(gidStr, 10, 32)
if err != nil {
- log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr)
+ log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), gidStr)
+ return nil, nil, syserror.EINVAL
+ }
+ kgid := creds.UserNamespace.MapToKGID(auth.GID(gid))
+ if !kgid.Ok() {
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped gid: %d", gid)
return nil, nil, syserror.EINVAL
}
- fsopts.groupID = uint32(groupID)
+ fsopts.gid = kgid
+ } else {
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option group_id missing")
+ return nil, nil, syserror.EINVAL
}
- rootMode := linux.FileMode(0777)
- modeStr, ok := mopts["rootmode"]
- if ok {
+ if modeStr, ok := mopts["rootmode"]; ok {
delete(mopts, "rootmode")
mode, err := strconv.ParseUint(modeStr, 8, 32)
if err != nil {
log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr)
return nil, nil, syserror.EINVAL
}
- rootMode = linux.FileMode(mode)
+ fsopts.rootMode = linux.FileMode(mode)
+ } else {
+ ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option rootmode missing")
+ return nil, nil, syserror.EINVAL
}
- fsopts.rootMode = rootMode
// Set the maxInFlightRequests option.
fsopts.maxActiveRequests = maxActiveRequestsDefault
@@ -192,6 +216,16 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fsopts.maxRead = math.MaxUint32
}
+ if _, ok := mopts["default_permissions"]; ok {
+ delete(mopts, "default_permissions")
+ fsopts.defaultPermissions = true
+ }
+
+ if _, ok := mopts["allow_other"]; ok {
+ delete(mopts, "allow_other")
+ fsopts.allowOther = true
+ }
+
// Check for unparsed options.
if len(mopts) != 0 {
log.Warningf("%s.GetFilesystem: unsupported or unknown options: %v", fsType.Name(), mopts)
@@ -260,6 +294,11 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.Release(ctx)
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return fs.opts.mopts
+}
+
// inode implements kernfs.Inode.
//
// +stateify savable
@@ -318,6 +357,37 @@ func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FU
return i
}
+// CheckPermissions implements kernfs.Inode.CheckPermissions.
+func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
+ // Since FUSE operations are ultimately backed by a userspace process (the
+ // fuse daemon), allowing a process to call into fusefs grants the daemon
+ // ptrace-like capabilities over the calling process. Because of this, by
+ // default FUSE only allows the mount owner to interact with the
+ // filesystem. This explicitly excludes setuid/setgid processes.
+ //
+ // This behaviour can be overriden with the 'allow_other' mount option.
+ //
+ // See fs/fuse/dir.c:fuse_allow_current_process() in Linux.
+ if !i.fs.opts.allowOther {
+ if creds.RealKUID != i.fs.opts.uid ||
+ creds.EffectiveKUID != i.fs.opts.uid ||
+ creds.SavedKUID != i.fs.opts.uid ||
+ creds.RealKGID != i.fs.opts.gid ||
+ creds.EffectiveKGID != i.fs.opts.gid ||
+ creds.SavedKGID != i.fs.opts.gid {
+ return syserror.EACCES
+ }
+ }
+
+ // By default, fusefs delegates all permission checks to the server.
+ // However, standard unix permission checks can be enabled with the
+ // default_permissions mount option.
+ if i.fs.opts.defaultPermissions {
+ return i.InodeAttrs.CheckPermissions(ctx, creds, ats)
+ }
+ return nil
+}
+
// Open implements kernfs.Inode.Open.
func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
isDir := i.InodeAttrs.Mode().IsDir()
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 8f95473b6..c34451269 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -15,7 +15,9 @@
package gofer
import (
+ "fmt"
"math"
+ "strings"
"sync"
"sync/atomic"
@@ -1608,3 +1610,58 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.renameMu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
+
+type mopt struct {
+ key string
+ value interface{}
+}
+
+func (m mopt) String() string {
+ if m.value == nil {
+ return fmt.Sprintf("%s", m.key)
+ }
+ return fmt.Sprintf("%s=%v", m.key, m.value)
+}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ optsKV := []mopt{
+ {moptTransport, transportModeFD}, // Only valid value, currently.
+ {moptReadFD, fs.opts.fd}, // Currently, read and write FD are the same.
+ {moptWriteFD, fs.opts.fd}, // Currently, read and write FD are the same.
+ {moptAname, fs.opts.aname},
+ {moptDfltUID, fs.opts.dfltuid},
+ {moptDfltGID, fs.opts.dfltgid},
+ {moptMsize, fs.opts.msize},
+ {moptVersion, fs.opts.version},
+ {moptDentryCacheLimit, fs.opts.maxCachedDentries},
+ }
+
+ switch fs.opts.interop {
+ case InteropModeExclusive:
+ optsKV = append(optsKV, mopt{moptCache, cacheFSCache})
+ case InteropModeWritethrough:
+ optsKV = append(optsKV, mopt{moptCache, cacheFSCacheWritethrough})
+ case InteropModeShared:
+ if fs.opts.regularFilesUseSpecialFileFD {
+ optsKV = append(optsKV, mopt{moptCache, cacheNone})
+ } else {
+ optsKV = append(optsKV, mopt{moptCache, cacheRemoteRevalidating})
+ }
+ }
+ if fs.opts.forcePageCache {
+ optsKV = append(optsKV, mopt{moptForcePageCache, nil})
+ }
+ if fs.opts.limitHostFDTranslation {
+ optsKV = append(optsKV, mopt{moptLimitHostFDTranslation, nil})
+ }
+ if fs.opts.overlayfsStaleRead {
+ optsKV = append(optsKV, mopt{moptOverlayfsStaleRead, nil})
+ }
+
+ opts := make([]string, 0, len(optsKV))
+ for _, opt := range optsKV {
+ opts = append(opts, opt.String())
+ }
+ return strings.Join(opts, ",")
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index 1508cbdf1..71569dc65 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -66,6 +66,34 @@ import (
// Name is the default filesystem name.
const Name = "9p"
+// Mount option names for goferfs.
+const (
+ moptTransport = "trans"
+ moptReadFD = "rfdno"
+ moptWriteFD = "wfdno"
+ moptAname = "aname"
+ moptDfltUID = "dfltuid"
+ moptDfltGID = "dfltgid"
+ moptMsize = "msize"
+ moptVersion = "version"
+ moptDentryCacheLimit = "dentry_cache_limit"
+ moptCache = "cache"
+ moptForcePageCache = "force_page_cache"
+ moptLimitHostFDTranslation = "limit_host_fd_translation"
+ moptOverlayfsStaleRead = "overlayfs_stale_read"
+)
+
+// Valid values for the "cache" mount option.
+const (
+ cacheNone = "none"
+ cacheFSCache = "fscache"
+ cacheFSCacheWritethrough = "fscache_writethrough"
+ cacheRemoteRevalidating = "remote_revalidating"
+)
+
+// Valid values for "trans" mount option.
+const transportModeFD = "fd"
+
// FilesystemType implements vfs.FilesystemType.
//
// +stateify savable
@@ -301,39 +329,39 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Get the attach name.
fsopts.aname = "/"
- if aname, ok := mopts["aname"]; ok {
- delete(mopts, "aname")
+ if aname, ok := mopts[moptAname]; ok {
+ delete(mopts, moptAname)
fsopts.aname = aname
}
// Parse the cache policy. For historical reasons, this defaults to the
// least generally-applicable option, InteropModeExclusive.
fsopts.interop = InteropModeExclusive
- if cache, ok := mopts["cache"]; ok {
- delete(mopts, "cache")
+ if cache, ok := mopts[moptCache]; ok {
+ delete(mopts, moptCache)
switch cache {
- case "fscache":
+ case cacheFSCache:
fsopts.interop = InteropModeExclusive
- case "fscache_writethrough":
+ case cacheFSCacheWritethrough:
fsopts.interop = InteropModeWritethrough
- case "none":
+ case cacheNone:
fsopts.regularFilesUseSpecialFileFD = true
fallthrough
- case "remote_revalidating":
+ case cacheRemoteRevalidating:
fsopts.interop = InteropModeShared
default:
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: cache=%s", cache)
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: %s=%s", moptCache, cache)
return nil, nil, syserror.EINVAL
}
}
// Parse the default UID and GID.
fsopts.dfltuid = _V9FS_DEFUID
- if dfltuidstr, ok := mopts["dfltuid"]; ok {
- delete(mopts, "dfltuid")
+ if dfltuidstr, ok := mopts[moptDfltUID]; ok {
+ delete(mopts, moptDfltUID)
dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr)
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltUID, dfltuidstr)
return nil, nil, syserror.EINVAL
}
// In Linux, dfltuid is interpreted as a UID and is converted to a KUID
@@ -342,11 +370,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fsopts.dfltuid = auth.KUID(dfltuid)
}
fsopts.dfltgid = _V9FS_DEFGID
- if dfltgidstr, ok := mopts["dfltgid"]; ok {
- delete(mopts, "dfltgid")
+ if dfltgidstr, ok := mopts[moptDfltGID]; ok {
+ delete(mopts, moptDfltGID)
dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltgid=%s", dfltgidstr)
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltGID, dfltgidstr)
return nil, nil, syserror.EINVAL
}
fsopts.dfltgid = auth.KGID(dfltgid)
@@ -354,11 +382,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Parse the 9P message size.
fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M
- if msizestr, ok := mopts["msize"]; ok {
- delete(mopts, "msize")
+ if msizestr, ok := mopts[moptMsize]; ok {
+ delete(mopts, moptMsize)
msize, err := strconv.ParseUint(msizestr, 10, 32)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: msize=%s", msizestr)
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: %s=%s", moptMsize, msizestr)
return nil, nil, syserror.EINVAL
}
fsopts.msize = uint32(msize)
@@ -366,34 +394,34 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Parse the 9P protocol version.
fsopts.version = p9.HighestVersionString()
- if version, ok := mopts["version"]; ok {
- delete(mopts, "version")
+ if version, ok := mopts[moptVersion]; ok {
+ delete(mopts, moptVersion)
fsopts.version = version
}
// Parse the dentry cache limit.
fsopts.maxCachedDentries = 1000
- if str, ok := mopts["dentry_cache_limit"]; ok {
- delete(mopts, "dentry_cache_limit")
+ if str, ok := mopts[moptDentryCacheLimit]; ok {
+ delete(mopts, moptDentryCacheLimit)
maxCachedDentries, err := strconv.ParseUint(str, 10, 64)
if err != nil {
- ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: %s=%s", moptDentryCacheLimit, str)
return nil, nil, syserror.EINVAL
}
fsopts.maxCachedDentries = maxCachedDentries
}
// Handle simple flags.
- if _, ok := mopts["force_page_cache"]; ok {
- delete(mopts, "force_page_cache")
+ if _, ok := mopts[moptForcePageCache]; ok {
+ delete(mopts, moptForcePageCache)
fsopts.forcePageCache = true
}
- if _, ok := mopts["limit_host_fd_translation"]; ok {
- delete(mopts, "limit_host_fd_translation")
+ if _, ok := mopts[moptLimitHostFDTranslation]; ok {
+ delete(mopts, moptLimitHostFDTranslation)
fsopts.limitHostFDTranslation = true
}
- if _, ok := mopts["overlayfs_stale_read"]; ok {
- delete(mopts, "overlayfs_stale_read")
+ if _, ok := mopts[moptOverlayfsStaleRead]; ok {
+ delete(mopts, moptOverlayfsStaleRead)
fsopts.overlayfsStaleRead = true
}
// fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying
@@ -469,34 +497,34 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) {
// Check that the transport is "fd".
- trans, ok := mopts["trans"]
- if !ok || trans != "fd" {
- ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'")
+ trans, ok := mopts[moptTransport]
+ if !ok || trans != transportModeFD {
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as '%s=%s'", moptTransport, transportModeFD)
return -1, syserror.EINVAL
}
- delete(mopts, "trans")
+ delete(mopts, moptTransport)
// Check that read and write FDs are provided and identical.
- rfdstr, ok := mopts["rfdno"]
+ rfdstr, ok := mopts[moptReadFD]
if !ok {
- ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'")
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as '%s=<file descriptor>'", moptReadFD)
return -1, syserror.EINVAL
}
- delete(mopts, "rfdno")
+ delete(mopts, moptReadFD)
rfd, err := strconv.Atoi(rfdstr)
if err != nil {
- ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr)
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: %s=%s", moptReadFD, rfdstr)
return -1, syserror.EINVAL
}
- wfdstr, ok := mopts["wfdno"]
+ wfdstr, ok := mopts[moptWriteFD]
if !ok {
- ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'")
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as '%s=<file descriptor>'", moptWriteFD)
return -1, syserror.EINVAL
}
- delete(mopts, "wfdno")
+ delete(mopts, moptWriteFD)
wfd, err := strconv.Atoi(wfdstr)
if err != nil {
- ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr)
+ ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: %s=%s", moptWriteFD, wfdstr)
return -1, syserror.EINVAL
}
if rfd != wfd {
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index ad5de80dc..b9cce4181 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -260,6 +260,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
return vfs.PrependPathSyntheticError{}
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
// CheckPermissions implements kernfs.Inode.CheckPermissions.
func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error {
var s unix.Stat_t
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
index e63588e33..1cd3137e6 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go
@@ -67,6 +67,11 @@ type filesystem struct {
kernfs.Filesystem
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
type file struct {
kernfs.DynamicBytesFile
content string
diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go
index 917709d75..84e37f793 100644
--- a/pkg/sentry/fsimpl/overlay/filesystem.go
+++ b/pkg/sentry/fsimpl/overlay/filesystem.go
@@ -1764,3 +1764,15 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
defer fs.renameMu.RUnlock()
return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b)
}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ // Return the mount options from the topmost layer.
+ var vd vfs.VirtualDentry
+ if fs.opts.UpperRoot.Ok() {
+ vd = fs.opts.UpperRoot
+ } else {
+ vd = fs.opts.LowerRoots[0]
+ }
+ return vd.Mount().Filesystem().Impl().MountOptions()
+}
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index 429733c10..3f05e444e 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -80,6 +80,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
return vfs.PrependPathSyntheticError{}
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
// inode implements kernfs.Inode.
//
// +stateify savable
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 8716d0a3c..254a8b062 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -104,6 +104,11 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.Release(ctx)
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return fmt.Sprintf("dentry_cache_limit=%d", fs.MaxCachedDentries)
+}
+
// dynamicInode is an overfitted interface for common Inodes with
// dynamicByteSource types used in procfs.
//
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index fd7823daa..fb274b78e 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -17,6 +17,7 @@ package proc
import (
"bytes"
"fmt"
+ "math"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
@@ -69,17 +70,17 @@ func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials,
if stack := k.RootNetworkNamespace().Stack(); stack != nil {
contents = map[string]kernfs.Inode{
"ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
- "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}),
- "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
- "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}),
- "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
- "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}),
+ "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}),
+ "ip_local_port_range": fs.newInode(ctx, root, 0644, &portRange{stack: stack}),
+ "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}),
+ "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}),
+ "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}),
+ "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}),
// The following files are simple stubs until they are implemented in
// netstack, most of these files are configuration related. We use the
// value closest to the actual netstack behavior or any empty file, all
// of these files will have mode 0444 (read-only for all users).
- "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")),
"ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")),
"ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")),
"ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")),
@@ -421,3 +422,68 @@ func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offs
}
return n, nil
}
+
+// portRange implements vfs.WritableDynamicBytesSource for
+// /proc/sys/net/ipv4/ip_local_port_range.
+//
+// +stateify savable
+type portRange struct {
+ kernfs.DynamicBytesFile
+
+ stack inet.Stack `state:"wait"`
+
+ // start and end store the port range. We must save/restore this here,
+ // since a netstack instance is created on restore.
+ start *uint16
+ end *uint16
+}
+
+var _ vfs.WritableDynamicBytesSource = (*portRange)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (pr *portRange) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if pr.start == nil {
+ start, end := pr.stack.PortRange()
+ pr.start = &start
+ pr.end = &end
+ }
+ _, err := fmt.Fprintf(buf, "%d %d\n", *pr.start, *pr.end)
+ return err
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ if offset != 0 {
+ // No need to handle partial writes thus far.
+ return 0, syserror.EINVAL
+ }
+ if src.NumBytes() == 0 {
+ return 0, nil
+ }
+
+ // Limit input size so as not to impact performance if input size is
+ // large.
+ src = src.TakeFirst(usermem.PageSize - 1)
+
+ ports := make([]int32, 2)
+ n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts)
+ if err != nil {
+ return 0, err
+ }
+
+ // Port numbers must be uint16s.
+ if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 {
+ return 0, syserror.EINVAL
+ }
+
+ if err := pr.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil {
+ return 0, err
+ }
+ if pr.start == nil {
+ pr.start = new(uint16)
+ pr.end = new(uint16)
+ }
+ *pr.start = uint16(ports[0])
+ *pr.end = uint16(ports[1])
+ return n, nil
+}
diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go
index fda1fa942..735756280 100644
--- a/pkg/sentry/fsimpl/sockfs/sockfs.go
+++ b/pkg/sentry/fsimpl/sockfs/sockfs.go
@@ -85,6 +85,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
return vfs.PrependPathSyntheticError{}
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
// inode implements kernfs.Inode.
//
// +stateify savable
diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go
index dbd9ebdda..1d9280dae 100644
--- a/pkg/sentry/fsimpl/sys/sys.go
+++ b/pkg/sentry/fsimpl/sys/sys.go
@@ -143,6 +143,11 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.Filesystem.Release(ctx)
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return fmt.Sprintf("dentry_cache_limit=%d", fs.MaxCachedDentries)
+}
+
// dir implements kernfs.Inode.
//
// +stateify savable
diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go
index 4f675c21e..5fdca1d46 100644
--- a/pkg/sentry/fsimpl/tmpfs/filesystem.go
+++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go
@@ -898,3 +898,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe
d = d.parent
}
}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return fs.mopts
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
index a01e413e0..8df81f589 100644
--- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go
+++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go
@@ -70,6 +70,10 @@ type filesystem struct {
// devMinor is the filesystem's minor device number. devMinor is immutable.
devMinor uint32
+ // mopts contains the tmpfs-specific mount options passed to this
+ // filesystem. Immutable.
+ mopts string
+
// mu serializes changes to the Dentry tree.
mu sync.RWMutex `state:"nosave"`
@@ -184,6 +188,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
mfp: mfp,
clock: clock,
devMinor: devMinor,
+ mopts: opts.Data,
}
fs.vfsfs.Init(vfsObj, newFSType, &fs)
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 9057d2b4e..6cb1a23e0 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -590,6 +590,23 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
return nil, err
}
+ // Clear the Merkle tree file if they are to be generated at runtime.
+ // TODO(b/182315468): Optimize the Merkle tree generate process to
+ // allow only updating certain files/directories.
+ if fs.allowRuntimeEnable {
+ childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: childMerkleVD,
+ Start: childMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_TRUNC,
+ Mode: 0644,
+ })
+ if err != nil {
+ return nil, err
+ }
+ childMerkleFD.DecRef(ctx)
+ }
+
// The dentry needs to be cleaned up if any error occurs. IncRef will be
// called if a verity child dentry is successfully created.
defer childMerkleVD.DecRef(ctx)
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index 374f71568..0d9b0ee2c 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -38,6 +38,7 @@ import (
"fmt"
"math"
"strconv"
+ "strings"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -310,6 +311,24 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
d.DecRef(ctx)
return nil, nil, alertIntegrityViolation("Failed to find root Merkle file")
}
+
+ // Clear the Merkle tree file if they are to be generated at runtime.
+ // TODO(b/182315468): Optimize the Merkle tree generate process to
+ // allow only updating certain files/directories.
+ if fs.allowRuntimeEnable {
+ lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{
+ Root: lowerMerkleVD,
+ Start: lowerMerkleVD,
+ }, &vfs.OpenOptions{
+ Flags: linux.O_RDWR | linux.O_TRUNC,
+ Mode: 0644,
+ })
+ if err != nil {
+ return nil, nil, err
+ }
+ lowerMerkleFD.DecRef(ctx)
+ }
+
d.lowerMerkleVD = lowerMerkleVD
// Get metadata from the underlying file system.
@@ -418,6 +437,11 @@ func (fs *filesystem) Release(ctx context.Context) {
fs.lowerMount.DecRef(ctx)
}
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ return ""
+}
+
// dentry implements vfs.DentryImpl.
//
// +stateify savable
@@ -750,6 +774,50 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions)
return syserror.EPERM
}
+// IterDirents implements vfs.FileDescriptionImpl.IterDirents.
+func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error {
+ if !fd.d.isDir() {
+ return syserror.ENOTDIR
+ }
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+
+ var ds []vfs.Dirent
+ err := fd.lowerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error {
+ // Do not include the Merkle tree files.
+ if strings.Contains(dirent.Name, merklePrefix) || strings.Contains(dirent.Name, merkleRootPrefix) {
+ return nil
+ }
+ if fd.d.verityEnabled() {
+ // Verify that the child is expected.
+ if dirent.Name != "." && dirent.Name != ".." {
+ if _, ok := fd.d.childrenNames[dirent.Name]; !ok {
+ return alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name))
+ }
+ }
+ }
+ ds = append(ds, dirent)
+ return nil
+ }))
+
+ if err != nil {
+ return err
+ }
+
+ // The result should contain all children plus "." and "..".
+ if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 {
+ return alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds)))
+ }
+
+ for fd.off < int64(len(ds)) {
+ if err := cb.Handle(ds[fd.off]); err != nil {
+ return err
+ }
+ fd.off++
+ }
+ return nil
+}
+
// Seek implements vfs.FileDescriptionImpl.Seek.
func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
fd.mu.Lock()
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index f31277d30..6b71bd3a9 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -93,6 +93,14 @@ type Stack interface {
// SetForwarding enables or disables packet forwarding between NICs.
SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error
+
+ // PortRange returns the UDP and TCP inclusive range of ephemeral ports
+ // used in both IPv4 and IPv6.
+ PortRange() (uint16, uint16)
+
+ // SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range
+ // (inclusive).
+ SetPortRange(start uint16, end uint16) error
}
// Interface contains information about a network interface.
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 9ebeba8a3..03e2608c2 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -164,3 +164,15 @@ func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable b
s.IPForwarding = enable
return nil
}
+
+// PortRange implements inet.Stack.PortRange.
+func (*TestStack) PortRange() (uint16, uint16) {
+ // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net().
+ return 32768, 28232
+}
+
+// SetPortRange implements inet.Stack.SetPortRange.
+func (*TestStack) SetPortRange(start uint16, end uint16) error {
+ // No-op.
+ return nil
+}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index e6323244c..5bcf92e14 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -504,3 +504,14 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
return syserror.EACCES
}
+
+// PortRange implements inet.Stack.PortRange.
+func (*Stack) PortRange() (uint16, uint16) {
+ // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net().
+ return 32768, 28232
+}
+
+// SetPortRange implements inet.Stack.SetPortRange.
+func (*Stack) SetPortRange(start uint16, end uint16) error {
+ return syserror.EACCES
+}
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 71c3bc034..b215067cf 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -478,3 +478,13 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool)
}
return nil
}
+
+// PortRange implements inet.Stack.PortRange.
+func (s *Stack) PortRange() (uint16, uint16) {
+ return s.Stack.PortRange()
+}
+
+// SetPortRange implements inet.Stack.SetPortRange.
+func (s *Stack) SetPortRange(start uint16, end uint16) error {
+ return syserr.TranslateNetstackError(s.Stack.SetPortRange(start, end)).ToError()
+}
diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go
index 7ad0eaf86..3caf417ca 100644
--- a/pkg/sentry/vfs/anonfs.go
+++ b/pkg/sentry/vfs/anonfs.go
@@ -291,6 +291,11 @@ func (fs *anonFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDe
return PrependPathSyntheticError{}
}
+// MountOptions implements FilesystemImpl.MountOptions.
+func (fs *anonFilesystem) MountOptions() string {
+ return ""
+}
+
// IncRef implements DentryImpl.IncRef.
func (d *anonDentry) IncRef() {
// no-op
diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go
index 2c4b81e78..059939010 100644
--- a/pkg/sentry/vfs/filesystem.go
+++ b/pkg/sentry/vfs/filesystem.go
@@ -502,6 +502,15 @@ type FilesystemImpl interface {
//
// Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl.
PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error
+
+ // MountOptions returns mount options for the current filesystem. This
+ // should only return options specific to the filesystem (i.e. don't return
+ // "ro", "rw", etc). Options should be returned as a comma-separated string,
+ // similar to the input to the 5th argument to mount.
+ //
+ // If the implementation has no filesystem-specific options, it should
+ // return the empty string.
+ MountOptions() string
}
// PrependPathAtVFSRootError is returned by implementations of
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index bac9eb905..922f9e697 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -959,13 +959,17 @@ func manglePath(p string) string {
// superBlockOpts returns the super block options string for the the mount at
// the given path.
func superBlockOpts(mountPath string, mnt *Mount) string {
- // gVisor doesn't (yet) have a concept of super block options, so we
- // use the ro/rw bit from the mount flag.
+ // Compose super block options by combining global mount flags with
+ // FS-specific mount options.
opts := "rw"
if mnt.ReadOnly() {
opts = "ro"
}
+ if mopts := mnt.fs.Impl().MountOptions(); mopts != "" {
+ opts += "," + mopts
+ }
+
// NOTE(b/147673608): If the mount is a cgroup, we also need to include
// the cgroup name in the options. For now we just read that from the
// path.
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 0b9139570..79e564de6 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -51,6 +51,7 @@ var (
ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM)
ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT)
ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), linux.EINVAL)
+ ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), linux.EINVAL)
)
// TranslateNetstackError converts an error from the tcpip package to a sentry
@@ -135,6 +136,8 @@ func TranslateNetstackError(err tcpip.Error) *Error {
return ErrBadBuffer
case *tcpip.ErrMalformedHeader:
return ErrMalformedHeader
+ case *tcpip.ErrInvalidPortRange:
+ return ErrInvalidPortRange
default:
panic(fmt.Sprintf("unknown error %T", err))
}
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
index 75d8e1f03..fc622b246 100644
--- a/pkg/tcpip/checker/checker.go
+++ b/pkg/tcpip/checker/checker.go
@@ -567,7 +567,7 @@ func TCPWindowLessThanEq(window uint16) TransportChecker {
}
// TCPFlags creates a checker that checks the tcp flags.
-func TCPFlags(flags uint8) TransportChecker {
+func TCPFlags(flags header.TCPFlags) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
@@ -576,15 +576,15 @@ func TCPFlags(flags uint8) TransportChecker {
t.Fatalf("TCP header not found in h: %T", h)
}
- if f := tcp.Flags(); f != flags {
- t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
+ if got := tcp.Flags(); got != flags {
+ t.Errorf("got tcp.Flags() = %s, want %s", got, flags)
}
}
}
// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
// given mask, match the supplied flags.
-func TCPFlagsMatch(flags, mask uint8) TransportChecker {
+func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
@@ -593,8 +593,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker {
t.Fatalf("TCP header not found in h: %T", h)
}
- if f := tcp.Flags(); (f & mask) != (flags & mask) {
- t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
+ if got := tcp.Flags(); (got & mask) != (flags & mask) {
+ t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask)
}
}
}
diff --git a/pkg/tcpip/errors.go b/pkg/tcpip/errors.go
index 3b7cc52f3..5d478ac32 100644
--- a/pkg/tcpip/errors.go
+++ b/pkg/tcpip/errors.go
@@ -300,6 +300,19 @@ func (*ErrInvalidOptionValue) IgnoreStats() bool {
}
func (*ErrInvalidOptionValue) String() string { return "invalid option value specified" }
+// ErrInvalidPortRange indicates an attempt to set an invalid port range.
+//
+// +stateify savable
+type ErrInvalidPortRange struct{}
+
+func (*ErrInvalidPortRange) isError() {}
+
+// IgnoreStats implements Error.
+func (*ErrInvalidPortRange) IgnoreStats() bool {
+ return true
+}
+func (*ErrInvalidPortRange) String() string { return "invalid port range" }
+
// ErrMalformedHeader indicates the operation encountered a malformed header.
//
// +stateify savable
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
index 4c6f808e5..adc835d30 100644
--- a/pkg/tcpip/header/tcp.go
+++ b/pkg/tcpip/header/tcp.go
@@ -45,9 +45,23 @@ const (
TCPMaxSACKBlocks = 4
)
+// TCPFlags is the dedicated type for TCP flags.
+type TCPFlags uint8
+
+// String implements Stringer.String.
+func (f TCPFlags) String() string {
+ flagsStr := []byte("FSRPAU")
+ for i := range flagsStr {
+ if f&(1<<uint(i)) == 0 {
+ flagsStr[i] = ' '
+ }
+ }
+ return string(flagsStr)
+}
+
// Flags that may be set in a TCP segment.
const (
- TCPFlagFin = 1 << iota
+ TCPFlagFin TCPFlags = 1 << iota
TCPFlagSyn
TCPFlagRst
TCPFlagPsh
@@ -94,7 +108,7 @@ type TCPFields struct {
DataOffset uint8
// Flags is the "flags" field of a TCP packet.
- Flags uint8
+ Flags TCPFlags
// WindowSize is the "window size" field of a TCP packet.
WindowSize uint16
@@ -234,8 +248,8 @@ func (b TCP) Payload() []byte {
}
// Flags returns the flags field of the tcp header.
-func (b TCP) Flags() uint8 {
- return b[TCPFlagsOffset]
+func (b TCP) Flags() TCPFlags {
+ return TCPFlags(b[TCPFlagsOffset])
}
// WindowSize returns the "window size" field of the tcp header.
@@ -319,10 +333,10 @@ func (b TCP) ParsedOptions() TCPOptions {
return ParseTCPOptions(b.Options())
}
-func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) {
+func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) {
binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq)
binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack)
- b[TCPFlagsOffset] = flags
+ b[TCPFlagsOffset] = uint8(flags)
binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
}
@@ -338,7 +352,7 @@ func (b TCP) Encode(t *TCPFields) {
// EncodePartial updates a subset of the fields of the tcp header. It is useful
// in cases when similar segments are produced.
-func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) {
+func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) {
// Add the total length and "flags" field contributions to the checksum.
// We don't use the flags field directly from the header because it's a
// one-byte field with an odd offset, so it would be accounted for
diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go
index 72563837b..96db8460f 100644
--- a/pkg/tcpip/header/tcp_test.go
+++ b/pkg/tcpip/header/tcp_test.go
@@ -146,3 +146,23 @@ func TestTCPParseOptions(t *testing.T) {
}
}
}
+
+func TestTCPFlags(t *testing.T) {
+ for _, tt := range []struct {
+ flags header.TCPFlags
+ want string
+ }{
+ {header.TCPFlagFin, "F "},
+ {header.TCPFlagSyn, " S "},
+ {header.TCPFlagRst, " R "},
+ {header.TCPFlagPsh, " P "},
+ {header.TCPFlagAck, " A "},
+ {header.TCPFlagUrg, " U"},
+ {header.TCPFlagSyn | header.TCPFlagAck, " S A "},
+ {header.TCPFlagFin | header.TCPFlagAck, "F A "},
+ } {
+ if got := tt.flags.String(); got != tt.want {
+ t.Errorf("got TCPFlags(%#b).String() = %s, want = %s", tt.flags, got, tt.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 84189bba5..7aaee3d13 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -398,13 +398,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe
// Initialize the TCP flags.
flags := tcp.Flags()
- flagsStr := []byte("FSRPAU")
- for i := range flagsStr {
- if flags&(1<<uint(i)) == 0 {
- flagsStr[i] = ' '
- }
- }
- details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
+ details = fmt.Sprintf("flags: %s seqnum: %d ack: %d win: %d xsum:0x%x", flags, tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
if flags&header.TCPFlagSyn != 0 {
details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
} else {
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index 3fcdea119..ae0461a6d 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -232,7 +232,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
e.mu.Lock()
- e.mu.dad.StopLocked(addr, false /* aborted */)
+ e.mu.dad.StopLocked(addr, &stack.DADDupAddrDetected{HolderLinkAddress: linkAddr})
e.mu.Unlock()
// The solicited, override, and isRouter flags are not available for ARP;
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
index 6f89a6a16..0053646ee 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go
@@ -126,9 +126,12 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet
s.timer.Stop()
delete(d.addresses, addr)
- r := stack.DADResult{Resolved: dadDone, Err: err}
+ var res stack.DADResult = &stack.DADSucceeded{}
+ if err != nil {
+ res = &stack.DADError{Err: err}
+ }
for _, h := range s.completionHandlers {
- h(r)
+ h(res)
}
}),
}
@@ -142,7 +145,7 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet
// StopLocked stops a currently running DAD process.
//
// Precondition: d.protocolMU must be locked.
-func (d *DAD) StopLocked(addr tcpip.Address, aborted bool) {
+func (d *DAD) StopLocked(addr tcpip.Address, reason stack.DADResult) {
s, ok := d.addresses[addr]
if !ok {
return
@@ -152,14 +155,8 @@ func (d *DAD) StopLocked(addr tcpip.Address, aborted bool) {
s.timer.Stop()
delete(d.addresses, addr)
- var err tcpip.Error
- if aborted {
- err = &tcpip.ErrAborted{}
- }
-
- r := stack.DADResult{Resolved: false, Err: err}
for _, h := range s.completionHandlers {
- h(r)
+ h(reason)
}
}
diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
index 18c357b56..e00aa4678 100644
--- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
+++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go
@@ -78,10 +78,10 @@ func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADC
return m.mu.dad.CheckDuplicateAddressLocked(addr, h)
}
-func (m *mockDADProtocol) stop(addr tcpip.Address, aborted bool) {
+func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) {
m.mu.Lock()
defer m.mu.Unlock()
- m.mu.dad.StopLocked(addr, aborted)
+ m.mu.dad.StopLocked(addr, reason)
}
func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) {
@@ -175,7 +175,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
}
clock.Advance(delta)
for i := 0; i < 2; i++ {
- if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff)
}
}
@@ -189,7 +189,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
default:
}
clock.Advance(delta)
- if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
@@ -202,7 +202,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
clock.Advance(dadConfig2Duration)
- if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
@@ -241,19 +241,19 @@ func TestDADStop(t *testing.T) {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
- dad.stop(addr1, true /* aborted */)
- if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: false, Err: &tcpip.ErrAborted{}}}, <-ch); diff != "" {
+ dad.stop(addr1, &stack.DADAborted{})
+ if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADAborted{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
- dad.stop(addr2, false /* aborted */)
- if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: false, Err: nil}}, <-ch); diff != "" {
+ dad.stop(addr2, &stack.DADDupAddrDetected{})
+ if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADDupAddrDetected{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer
clock.Advance(dadResolutionDuration)
- if diff := cmp.Diff(dadResult{Addr: addr3, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ if diff := cmp.Diff(dadResult{Addr: addr3, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
@@ -266,7 +266,7 @@ func TestDADStop(t *testing.T) {
t.Errorf("dad check mismatch (-want +got):\n%s", diff)
}
clock.Advance(dadResolutionDuration)
- if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" {
+ if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" {
t.Errorf("dad result mismatch (-want +got):\n%s", diff)
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index 4b21ee79c..5e7f10f4b 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -32,12 +32,14 @@ go_test(
"ipv4_test.go",
],
deps = [
+ "//pkg/sync",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/internal/testutil",
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index cabe274d6..8a2140ebe 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -899,10 +899,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {
e.mu.Lock()
- defer e.mu.Unlock()
-
e.disableLocked()
e.mu.addressableEndpointState.Cleanup()
+ e.mu.Unlock()
e.protocol.forgetEndpoint(e.nic.ID())
}
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 26d9696d7..cfed241bf 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -26,12 +26,14 @@ import (
"time"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
@@ -2985,3 +2987,120 @@ func TestPacketQueing(t *testing.T) {
})
}
}
+
+// TestCloseLocking test that lock ordering is followed when closing an
+// endpoint.
+func TestCloseLocking(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ src = tcpip.Address("\x10\x00\x00\x01")
+ dst = tcpip.Address("\x10\x00\x00\x02")
+
+ iterations = 1000
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ // Perform NAT so that the endoint tries to search for a sibling endpoint
+ // which ends up taking the protocol and endpoint lock (in that order).
+ table := stack.Table{
+ Rules: []stack.Rule{
+ {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}},
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 1,
+ stack.Forward: stack.HookUnset,
+ stack.Output: 2,
+ stack.Postrouting: 3,
+ },
+ Underflows: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 1,
+ stack.Forward: stack.HookUnset,
+ stack.Output: 2,
+ stack.Postrouting: 3,
+ },
+ }
+ if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil {
+ t.Fatalf("s.IPTables().ReplaceTable(...): %s", err)
+ }
+
+ e := channel.New(0, defaultMTU, "")
+ if err := s.CreateNIC(nicID1, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID1,
+ }})
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ep.Close()
+
+ addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53}
+ if err := ep.Connect(addr); err != nil {
+ t.Errorf("ep.Connect(%#v): %s", addr, err)
+ }
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ // Writing packets should trigger NAT which requires the stack to search the
+ // protocol for network endpoints with the destination address.
+ //
+ // Creating and removing interfaces should modify the protocol and endpoint
+ // which requires taking the locks of each.
+ //
+ // We expect the protocol > endpoint lock ordering to be followed here.
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+
+ data := []byte{1, 2, 3, 4}
+
+ for i := 0; i < iterations; i++ {
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
+ t.Errorf("ep.Write(_, _): %s", err)
+ return
+ } else if want := int64(len(data)); n != want {
+ t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
+ return
+ }
+ }
+ }()
+ go func() {
+ defer wg.Done()
+
+ for i := 0; i < iterations; i++ {
+ if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
+ t.Errorf("CreateNIC(%d, _): %s", nicID2, err)
+ return
+ }
+ if err := s.RemoveNIC(nicID2); err != nil {
+ t.Errorf("RemoveNIC(%d): %s", nicID2, err)
+ return
+ }
+ }
+ }()
+}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index e80e681da..6344a3e09 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -382,6 +382,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
// stack know so it can handle such a scenario and do nothing further with
// the NS.
if srcAddr == header.IPv6Any {
+ // Since this is a DAD message we know the sender does not actually hold
+ // the target address so there is no "holder".
+ var holderLinkAddress tcpip.LinkAddress
+
// We would get an error if the address no longer exists or the address
// is no longer tentative (DAD resolved between the call to
// hasTentativeAddr and this point). Both of these are valid scenarios:
@@ -393,7 +397,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) {
+ switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress); err.(type) {
case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
default:
panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err))
@@ -561,10 +565,24 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
// 5, NDP messages cannot be fragmented. Also note that in the common case
// NDP datagrams are very small and AsView() will not incur allocations.
na := header.NDPNeighborAdvert(payload.AsView())
+
+ it, err := na.Options().Iter(false /* check */)
+ if err != nil {
+ // If we have a malformed NDP NA option, drop the packet.
+ received.invalid.Increment()
+ return
+ }
+
+ targetLinkAddr, ok := getTargetLinkAddr(it)
+ if !ok {
+ received.invalid.Increment()
+ return
+ }
+
targetAddr := na.TargetAddress()
e.dad.mu.Lock()
- e.dad.mu.dad.StopLocked(targetAddr, false /* aborted */)
+ e.dad.mu.dad.StopLocked(targetAddr, &stack.DADDupAddrDetected{HolderLinkAddress: targetLinkAddr})
e.dad.mu.Unlock()
if e.hasTentativeAddr(targetAddr) {
@@ -584,7 +602,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
//
// TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate
// address is detected for an assigned address.
- switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) {
+ switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr); err.(type) {
case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState:
return
default:
@@ -592,13 +610,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
}
}
- it, err := na.Options().Iter(false /* check */)
- if err != nil {
- // If we have a malformed NDP NA option, drop the packet.
- received.invalid.Increment()
- return
- }
-
// At this point we know that the target address is not tentative on the
// NIC. However, the target address may still be assigned to the NIC but not
// tentative (it could be permanent). Such a scenario is beyond the scope of
@@ -608,11 +619,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
// TODO(b/143147598): Handle the scenario described above. Also inform the
// netstack integration that a duplicate address was detected outside of
// DAD.
- targetLinkAddr, ok := getTargetLinkAddr(it)
- if !ok {
- received.invalid.Increment()
- return
- }
// As per RFC 4861 section 7.1.2:
// A node MUST silently discard any received Neighbor Advertisement
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 544717678..46b6cc41a 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -348,7 +348,7 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool {
// dupTentativeAddrDetected removes the tentative address if it exists. If the
// address was generated via SLAAC, an attempt is made to generate a new
// address.
-func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error {
+func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress) tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
@@ -363,7 +363,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error {
// If the address is a SLAAC address, do not invalidate its SLAAC prefix as an
// attempt will be made to generate a new address for it.
- if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, true /* dadFailure */); err != nil {
+ if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil {
return err
}
@@ -536,8 +536,20 @@ func (e *endpoint) disableLocked() {
}
e.mu.ndp.stopSolicitingRouters()
+ // Stop DAD for all the tentative unicast addresses.
+ e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
+ if addressEndpoint.GetKind() != stack.PermanentTentative {
+ return true
+ }
+
+ addr := addressEndpoint.AddressWithPrefix().Address
+ if header.IsV6UnicastAddress(addr) {
+ e.mu.ndp.stopDuplicateAddressDetection(addr, &stack.DADAborted{})
+ }
+
+ return true
+ })
e.mu.ndp.cleanupState(false /* hostOnly */)
- e.stopDADForPermanentAddressesLocked()
// The endpoint may have already left the multicast group.
switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) {
@@ -555,25 +567,6 @@ func (e *endpoint) disableLocked() {
}
}
-// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses.
-//
-// Precondition: e.mu must be write locked.
-func (e *endpoint) stopDADForPermanentAddressesLocked() {
- // Stop DAD for all the tentative unicast addresses.
- e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
- if addressEndpoint.GetKind() != stack.PermanentTentative {
- return true
- }
-
- addr := addressEndpoint.AddressWithPrefix().Address
- if header.IsV6UnicastAddress(addr) {
- e.mu.ndp.stopDuplicateAddressDetection(addr, false /* failed */)
- }
-
- return true
- })
-}
-
// DefaultTTL is the default hop limit for this endpoint.
func (e *endpoint) DefaultTTL() uint8 {
return e.protocol.DefaultTTL()
@@ -1384,8 +1377,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) {
func (e *endpoint) Close() {
e.mu.Lock()
e.disableLocked()
- e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */)
- e.stopDADForPermanentAddressesLocked()
e.mu.addressableEndpointState.Cleanup()
e.mu.Unlock()
@@ -1451,14 +1442,14 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error {
return &tcpip.ErrBadLocalAddress{}
}
- return e.removePermanentEndpointLocked(addressEndpoint, true /* allowSLAACInvalidation */, false /* dadFailure */)
+ return e.removePermanentEndpointLocked(addressEndpoint, true /* allowSLAACInvalidation */, &stack.DADAborted{})
}
// removePermanentEndpointLocked is like removePermanentAddressLocked except
// it works with a stack.AddressEndpoint.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation, dadFailure bool) tcpip.Error {
+func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool, dadResult stack.DADResult) tcpip.Error {
addr := addressEndpoint.AddressWithPrefix()
// If we are removing an address generated via SLAAC, cleanup
// its SLAAC resources and notify the integrator.
@@ -1469,16 +1460,16 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn
e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr)
}
- return e.removePermanentEndpointInnerLocked(addressEndpoint, dadFailure)
+ return e.removePermanentEndpointInnerLocked(addressEndpoint, dadResult)
}
// removePermanentEndpointInnerLocked is like removePermanentEndpointLocked
// except it does not cleanup SLAAC address state.
//
// Precondition: e.mu must be write locked.
-func (e *endpoint) removePermanentEndpointInnerLocked(addressEndpoint stack.AddressEndpoint, dadFailure bool) tcpip.Error {
+func (e *endpoint) removePermanentEndpointInnerLocked(addressEndpoint stack.AddressEndpoint, dadResult stack.DADResult) tcpip.Error {
addr := addressEndpoint.AddressWithPrefix()
- e.mu.ndp.stopDuplicateAddressDetection(addr.Address, dadFailure)
+ e.mu.ndp.stopDuplicateAddressDetection(addr.Address, dadResult)
if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil {
return err
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index c22f60709..d9b728878 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -208,16 +208,12 @@ const (
// NDPDispatcher is the interface integrators of netstack must implement to
// receive and handle NDP related events.
type NDPDispatcher interface {
- // OnDuplicateAddressDetectionStatus is called when the DAD process for an
- // address (addr) on a NIC (with ID nicID) completes. resolved is set to true
- // if DAD completed successfully (no duplicate addr detected); false otherwise
- // (addr was detected to be a duplicate on the link the NIC is a part of, or
- // it was stopped for some other reason, such as the address being removed).
- // If an error occured during DAD, err is set and resolved must be ignored.
+ // OnDuplicateAddressDetectionResult is called when the DAD process for an
+ // address on a NIC completes.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error)
+ OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult)
// OnDefaultRouterDiscovered is called when a new default router is
// discovered. Implementations must return true if the newly discovered
@@ -225,14 +221,14 @@ type NDPDispatcher interface {
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool
+ OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool
// OnDefaultRouterInvalidated is called when a discovered default router that
// was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address)
+ OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address)
// OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered.
// Implementations must return true if the newly discovered on-link prefix
@@ -240,14 +236,14 @@ type NDPDispatcher interface {
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool
+ OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool
// OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that
// was remembered is invalidated.
//
// This function is not permitted to block indefinitely. This function
// is also not permitted to call into the stack.
- OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet)
+ OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet)
// OnAutoGenAddress is called when a new prefix with its autonomous address-
// configuration flag set is received and SLAAC was performed. Implementations
@@ -280,12 +276,12 @@ type NDPDispatcher interface {
// It is up to the caller to use the DNS Servers only for their valid
// lifetime. OnRecursiveDNSServerOption may be called for new or
// already known DNS servers. If called with known DNS servers, their
- // valid lifetimes must be refreshed to lifetime (it may be increased,
- // decreased, or completely invalidated when lifetime = 0).
+ // valid lifetimes must be refreshed to the lifetime (it may be increased,
+ // decreased, or completely invalidated when the lifetime = 0).
//
// This function is not permitted to block indefinitely. It must not
// call functions on the stack itself.
- OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration)
+ OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration)
// OnDNSSearchListOption is called when the stack learns of DNS search lists
// through NDP.
@@ -293,9 +289,9 @@ type NDPDispatcher interface {
// It is up to the caller to use the domain names in the search list
// for only their valid lifetime. OnDNSSearchListOption may be called
// with new or already known domain names. If called with known domain
- // names, their valid lifetimes must be refreshed to lifetime (it may
- // be increased, decreased or completely invalidated when lifetime = 0.
- OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration)
+ // names, their valid lifetimes must be refreshed to the lifetime (it may
+ // be increased, decreased or completely invalidated when the lifetime = 0.
+ OnDNSSearchListOption(tcpip.NICID, []string, time.Duration)
// OnDHCPv6Configuration is called with an updated configuration that is
// available via DHCPv6 for the passed NIC.
@@ -587,15 +583,25 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID()))
}
- if r.Resolved {
+ var dadSucceeded bool
+ switch r.(type) {
+ case *stack.DADAborted, *stack.DADError, *stack.DADDupAddrDetected:
+ dadSucceeded = false
+ case *stack.DADSucceeded:
+ dadSucceeded = true
+ default:
+ panic(fmt.Sprintf("unrecognized DAD result = %T", r))
+ }
+
+ if dadSucceeded {
addressEndpoint.SetKind(stack.Permanent)
}
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, r.Resolved, r.Err)
+ ndpDisp.OnDuplicateAddressDetectionResult(ndp.ep.nic.ID(), addr, r)
}
- if r.Resolved {
+ if dadSucceeded {
if addressEndpoint.ConfigType() == stack.AddressConfigSlaac {
// Reset the generation attempts counter as we are starting the
// generation of a new address for the SLAAC prefix.
@@ -616,7 +622,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// Consider DAD to have resolved even if no DAD messages were actually
// transmitted.
if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil {
- ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil)
+ ndpDisp.OnDuplicateAddressDetectionResult(ndp.ep.nic.ID(), addr, &stack.DADSucceeded{})
}
ndp.ep.onAddressAssignedLocked(addr)
@@ -633,8 +639,8 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE
// of this function to handle such a scenario.
//
// The IPv6 endpoint that ndp belongs to MUST be locked.
-func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address, failed bool) {
- ndp.dad.StopLocked(addr, !failed)
+func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address, reason stack.DADResult) {
+ ndp.dad.StopLocked(addr, reason)
}
// handleRA handles a Router Advertisement message that arrived on the NIC
@@ -1501,7 +1507,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi
ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix())
}
- if err := ndp.ep.removePermanentEndpointInnerLocked(addressEndpoint, false /* dadFailure */); err != nil {
+ if err := ndp.ep.removePermanentEndpointInnerLocked(addressEndpoint, &stack.DADAborted{}); err != nil {
panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err))
}
}
@@ -1560,7 +1566,7 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa
func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) {
ndp.cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs, tempAddr, tempAddrState)
- if err := ndp.ep.removePermanentEndpointInnerLocked(tempAddrState.addressEndpoint, false /* dadFailure */); err != nil {
+ if err := ndp.ep.removePermanentEndpointInnerLocked(tempAddrState.addressEndpoint, &stack.DADAborted{}); err != nil {
panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err))
}
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 0d53c260d..6e850fd46 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -90,7 +90,7 @@ type testNDPDispatcher struct {
addr tcpip.Address
}
-func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) {
+func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool {
@@ -1314,10 +1314,10 @@ func TestCheckDuplicateAddress(t *testing.T) {
t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning)
}
- // Wait for DAD to resolve.
+ // Wait for DAD to complete.
clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer)
for i := 0; i < dadRequestsMade; i++ {
- if diff := cmp.Diff(stack.DADResult{Resolved: true}, <-ch); diff != "" {
+ if diff := cmp.Diff(&stack.DADSucceeded{}, <-ch); diff != "" {
t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
}
}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
index 57abec5c9..210262703 100644
--- a/pkg/tcpip/ports/BUILD
+++ b/pkg/tcpip/ports/BUILD
@@ -4,7 +4,10 @@ package(licenses = ["notice"])
go_library(
name = "ports",
- srcs = ["ports.go"],
+ srcs = [
+ "flags.go",
+ "ports.go",
+ ],
visibility = ["//visibility:public"],
deps = [
"//pkg/sync",
diff --git a/pkg/tcpip/ports/flags.go b/pkg/tcpip/ports/flags.go
new file mode 100644
index 000000000..a8d7bff25
--- /dev/null
+++ b/pkg/tcpip/ports/flags.go
@@ -0,0 +1,150 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ports
+
+// Flags represents the type of port reservation.
+//
+// +stateify savable
+type Flags struct {
+ // MostRecent represents UDP SO_REUSEADDR.
+ MostRecent bool
+
+ // LoadBalanced indicates SO_REUSEPORT.
+ //
+ // LoadBalanced takes precidence over MostRecent.
+ LoadBalanced bool
+
+ // TupleOnly represents TCP SO_REUSEADDR.
+ TupleOnly bool
+}
+
+// Bits converts the Flags to their bitset form.
+func (f Flags) Bits() BitFlags {
+ var rf BitFlags
+ if f.MostRecent {
+ rf |= MostRecentFlag
+ }
+ if f.LoadBalanced {
+ rf |= LoadBalancedFlag
+ }
+ if f.TupleOnly {
+ rf |= TupleOnlyFlag
+ }
+ return rf
+}
+
+// Effective returns the effective behavior of a flag config.
+func (f Flags) Effective() Flags {
+ e := f
+ if e.LoadBalanced && e.MostRecent {
+ e.MostRecent = false
+ }
+ return e
+}
+
+// BitFlags is a bitset representation of Flags.
+type BitFlags uint32
+
+const (
+ // MostRecentFlag represents Flags.MostRecent.
+ MostRecentFlag BitFlags = 1 << iota
+
+ // LoadBalancedFlag represents Flags.LoadBalanced.
+ LoadBalancedFlag
+
+ // TupleOnlyFlag represents Flags.TupleOnly.
+ TupleOnlyFlag
+
+ // nextFlag is the value that the next added flag will have.
+ //
+ // It is used to calculate FlagMask below. It is also the number of
+ // valid flag states.
+ nextFlag
+
+ // FlagMask is a bit mask for BitFlags.
+ FlagMask = nextFlag - 1
+
+ // MultiBindFlagMask contains the flags that allow binding the same
+ // tuple multiple times.
+ MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
+)
+
+// ToFlags converts the bitset into a Flags struct.
+func (f BitFlags) ToFlags() Flags {
+ return Flags{
+ MostRecent: f&MostRecentFlag != 0,
+ LoadBalanced: f&LoadBalancedFlag != 0,
+ TupleOnly: f&TupleOnlyFlag != 0,
+ }
+}
+
+// FlagCounter counts how many references each flag combination has.
+type FlagCounter struct {
+ // refs stores the count for each possible flag combination, (0 though
+ // FlagMask).
+ refs [nextFlag]int
+}
+
+// AddRef increases the reference count for a specific flag combination.
+func (c *FlagCounter) AddRef(flags BitFlags) {
+ c.refs[flags]++
+}
+
+// DropRef decreases the reference count for a specific flag combination.
+func (c *FlagCounter) DropRef(flags BitFlags) {
+ c.refs[flags]--
+}
+
+// TotalRefs calculates the total number of references for all flag
+// combinations.
+func (c FlagCounter) TotalRefs() int {
+ var total int
+ for _, r := range c.refs {
+ total += r
+ }
+ return total
+}
+
+// FlagRefs returns the number of references with all specified flags.
+func (c FlagCounter) FlagRefs(flags BitFlags) int {
+ var total int
+ for i, r := range c.refs {
+ if BitFlags(i)&flags == flags {
+ total += r
+ }
+ }
+ return total
+}
+
+// AllRefsHave returns if all references have all specified flags.
+func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
+ for i, r := range c.refs {
+ if BitFlags(i)&flags != flags && r > 0 {
+ return false
+ }
+ }
+ return true
+}
+
+// SharedFlags returns the set of flags shared by all references.
+func (c FlagCounter) SharedFlags() BitFlags {
+ intersection := FlagMask
+ for i, r := range c.refs {
+ if r > 0 {
+ intersection &= BitFlags(i)
+ }
+ }
+ return intersection
+}
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 11dbdbbcf..678199371 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package ports provides PortManager that manages allocating, reserving and releasing ports.
+// Package ports provides PortManager that manages allocating, reserving and
+// releasing ports.
package ports
import (
- "math"
"math/rand"
"sync/atomic"
@@ -24,169 +24,44 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
)
-const (
- // FirstEphemeral is the first ephemeral port.
- FirstEphemeral = 16000
+const anyIPAddress tcpip.Address = ""
- // numEphemeralPorts it the mnumber of available ephemeral ports to
- // Netstack.
- numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1
+// Reservation describes a port reservation.
+type Reservation struct {
+ // Networks is a list of network protocols to which the reservation
+ // applies. Can be IPv4, IPv6, or both.
+ Networks []tcpip.NetworkProtocolNumber
- anyIPAddress tcpip.Address = ""
-)
-
-type portDescriptor struct {
- network tcpip.NetworkProtocolNumber
- transport tcpip.TransportProtocolNumber
- port uint16
-}
-
-// Flags represents the type of port reservation.
-//
-// +stateify savable
-type Flags struct {
- // MostRecent represents UDP SO_REUSEADDR.
- MostRecent bool
-
- // LoadBalanced indicates SO_REUSEPORT.
- //
- // LoadBalanced takes precidence over MostRecent.
- LoadBalanced bool
-
- // TupleOnly represents TCP SO_REUSEADDR.
- TupleOnly bool
-}
-
-// Bits converts the Flags to their bitset form.
-func (f Flags) Bits() BitFlags {
- var rf BitFlags
- if f.MostRecent {
- rf |= MostRecentFlag
- }
- if f.LoadBalanced {
- rf |= LoadBalancedFlag
- }
- if f.TupleOnly {
- rf |= TupleOnlyFlag
- }
- return rf
-}
-
-// Effective returns the effective behavior of a flag config.
-func (f Flags) Effective() Flags {
- e := f
- if e.LoadBalanced && e.MostRecent {
- e.MostRecent = false
- }
- return e
-}
-
-// PortManager manages allocating, reserving and releasing ports.
-type PortManager struct {
- mu sync.RWMutex
- allocatedPorts map[portDescriptor]bindAddresses
-
- // hint is used to pick ports ephemeral ports in a stable order for
- // a given port offset.
- //
- // hint must be accessed using the portHint/incPortHint helpers.
- // TODO(gvisor.dev/issue/940): S/R this field.
- hint uint32
-}
-
-// BitFlags is a bitset representation of Flags.
-type BitFlags uint32
-
-const (
- // MostRecentFlag represents Flags.MostRecent.
- MostRecentFlag BitFlags = 1 << iota
-
- // LoadBalancedFlag represents Flags.LoadBalanced.
- LoadBalancedFlag
-
- // TupleOnlyFlag represents Flags.TupleOnly.
- TupleOnlyFlag
-
- // nextFlag is the value that the next added flag will have.
- //
- // It is used to calculate FlagMask below. It is also the number of
- // valid flag states.
- nextFlag
-
- // FlagMask is a bit mask for BitFlags.
- FlagMask = nextFlag - 1
+ // Transport is the transport protocol to which the reservation applies.
+ Transport tcpip.TransportProtocolNumber
- // MultiBindFlagMask contains the flags that allow binding the same
- // tuple multiple times.
- MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag
-)
-
-// ToFlags converts the bitset into a Flags struct.
-func (f BitFlags) ToFlags() Flags {
- return Flags{
- MostRecent: f&MostRecentFlag != 0,
- LoadBalanced: f&LoadBalancedFlag != 0,
- TupleOnly: f&TupleOnlyFlag != 0,
- }
-}
+ // Addr is the address of the local endpoint.
+ Addr tcpip.Address
-// FlagCounter counts how many references each flag combination has.
-type FlagCounter struct {
- // refs stores the count for each possible flag combination, (0 though
- // FlagMask).
- refs [nextFlag]int
-}
+ // Port is the local port number.
+ Port uint16
-// AddRef increases the reference count for a specific flag combination.
-func (c *FlagCounter) AddRef(flags BitFlags) {
- c.refs[flags]++
-}
+ // Flags describe features of the reservation.
+ Flags Flags
-// DropRef decreases the reference count for a specific flag combination.
-func (c *FlagCounter) DropRef(flags BitFlags) {
- c.refs[flags]--
-}
+ // BindToDevice is the NIC to which the reservation applies.
+ BindToDevice tcpip.NICID
-// TotalRefs calculates the total number of references for all flag
-// combinations.
-func (c FlagCounter) TotalRefs() int {
- var total int
- for _, r := range c.refs {
- total += r
- }
- return total
+ // Dest is the destination address.
+ Dest tcpip.FullAddress
}
-// FlagRefs returns the number of references with all specified flags.
-func (c FlagCounter) FlagRefs(flags BitFlags) int {
- var total int
- for i, r := range c.refs {
- if BitFlags(i)&flags == flags {
- total += r
- }
- }
- return total
-}
-
-// AllRefsHave returns if all references have all specified flags.
-func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
- for i, r := range c.refs {
- if BitFlags(i)&flags != flags && r > 0 {
- return false
- }
+func (rs Reservation) dst() destination {
+ return destination{
+ rs.Dest.Addr,
+ rs.Dest.Port,
}
- return true
}
-// IntersectionRefs returns the set of flags shared by all references.
-func (c FlagCounter) IntersectionRefs() BitFlags {
- intersection := FlagMask
- for i, r := range c.refs {
- if r > 0 {
- intersection &= BitFlags(i)
- }
- }
- return intersection
+type portDescriptor struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+ port uint16
}
type destination struct {
@@ -194,18 +69,14 @@ type destination struct {
port uint16
}
-func makeDestination(a tcpip.FullAddress) destination {
- return destination{
- a.Addr,
- a.Port,
- }
-}
-
-// portNode is never empty. When it has no elements, it is removed from the
-// map that references it.
-type portNode map[destination]FlagCounter
+// destToCounter maps each destination to the FlagCounter that represents
+// endpoints to that destination.
+//
+// destToCounter is never empty. When it has no elements, it is removed from
+// the map that references it.
+type destToCounter map[destination]FlagCounter
-// intersectionRefs calculates the intersection of flag bit values which affect
+// intersectionFlags calculates the intersection of flag bit values which affect
// the specified destination.
//
// If no destinations are present, all flag values are returned as there are no
@@ -213,20 +84,20 @@ type portNode map[destination]FlagCounter
//
// In addition to the intersection, the number of intersecting refs is
// returned.
-func (p portNode) intersectionRefs(dst destination) (BitFlags, int) {
+func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) {
intersection := FlagMask
var count int
- for d, f := range p {
- if d == dst {
- intersection &= f.IntersectionRefs()
+ for dest, counter := range dc {
+ if dest == res.dst() {
+ intersection &= counter.SharedFlags()
count++
continue
}
// Wildcard destinations affect all destinations for TupleOnly.
- if d.addr == anyIPAddress || dst.addr == anyIPAddress {
+ if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress {
// Only bitwise and the TupleOnlyFlag.
- intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs())
+ intersection &= ((^TupleOnlyFlag) | counter.SharedFlags())
count++
}
}
@@ -234,27 +105,29 @@ func (p portNode) intersectionRefs(dst destination) (BitFlags, int) {
return intersection, count
}
-// deviceNode is never empty. When it has no elements, it is removed from the
+// deviceToDest maps NICs to destinations for which there are port reservations.
+//
+// deviceToDest is never empty. When it has no elements, it is removed from the
// map that references it.
-type deviceNode map[tcpip.NICID]portNode
+type deviceToDest map[tcpip.NICID]destToCounter
-// isAvailable checks whether binding is possible by device. If not binding to a
-// device, check against all FlagCounters. If binding to a specific device, check
-// against the unspecified device and the provided device.
+// isAvailable checks whether binding is possible by device. If not binding to
+// a device, check against all FlagCounters. If binding to a specific device,
+// check against the unspecified device and the provided device.
//
// If either of the port reuse flags is enabled on any of the nodes, all nodes
// sharing a port must share at least one reuse flag. This matches Linux's
// behavior.
-func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- flagBits := flags.Bits()
- if bindToDevice == 0 {
+func (dd deviceToDest) isAvailable(res Reservation) bool {
+ flagBits := res.Flags.Bits()
+ if res.BindToDevice == 0 {
intersection := FlagMask
- for _, p := range d {
- i, c := p.intersectionRefs(dst)
- if c == 0 {
+ for _, dest := range dd {
+ flags, count := dest.intersectionFlags(res)
+ if count == 0 {
continue
}
- intersection &= i
+ intersection &= flags
if intersection&flagBits == 0 {
// Can't bind because the (addr,port) was
// previously bound without reuse.
@@ -266,18 +139,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti
intersection := FlagMask
- if p, ok := d[0]; ok {
- var c int
- intersection, c = p.intersectionRefs(dst)
- if c > 0 && intersection&flagBits == 0 {
+ if dests, ok := dd[0]; ok {
+ var count int
+ intersection, count = dests.intersectionFlags(res)
+ if count > 0 && intersection&flagBits == 0 {
return false
}
}
- if p, ok := d[bindToDevice]; ok {
- i, c := p.intersectionRefs(dst)
- intersection &= i
- if c > 0 && intersection&flagBits == 0 {
+ if dests, ok := dd[res.BindToDevice]; ok {
+ flags, count := dests.intersectionFlags(res)
+ intersection &= flags
+ if count > 0 && intersection&flagBits == 0 {
return false
}
}
@@ -285,18 +158,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti
return true
}
-// bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]deviceNode
+// addrToDevice maps IP addresses to NICs that have port reservations.
+type addrToDevice map[tcpip.Address]deviceToDest
// isAvailable checks whether an IP address is available to bind to. If the
// address is the "any" address, check all other addresses. Otherwise, just
// check against the "any" address and the provided address.
-func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- if addr == anyIPAddress {
- // If binding to the "any" address then check that there are no conflicts
- // with all addresses.
- for _, d := range b {
- if !d.isAvailable(flags, bindToDevice, dst) {
+func (ad addrToDevice) isAvailable(res Reservation) bool {
+ if res.Addr == anyIPAddress {
+ // If binding to the "any" address then check that there are no
+ // conflicts with all addresses.
+ for _, devices := range ad {
+ if !devices.isAvailable(res) {
return false
}
}
@@ -304,15 +177,15 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice
}
// Check that there is no conflict with the "any" address.
- if d, ok := b[anyIPAddress]; ok {
- if !d.isAvailable(flags, bindToDevice, dst) {
+ if devices, ok := ad[anyIPAddress]; ok {
+ if !devices.isAvailable(res) {
return false
}
}
// Check that this is no conflict with the provided address.
- if d, ok := b[addr]; ok {
- if !d.isAvailable(flags, bindToDevice, dst) {
+ if devices, ok := ad[res.Addr]; ok {
+ if !devices.isAvailable(res) {
return false
}
}
@@ -320,50 +193,93 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice
return true
}
+// PortManager manages allocating, reserving and releasing ports.
+type PortManager struct {
+ // mu protects allocatedPorts.
+ // LOCK ORDERING: mu > ephemeralMu.
+ mu sync.RWMutex
+ // allocatedPorts is a nesting of maps that ultimately map Reservations
+ // to FlagCounters describing whether the Reservation is valid and can
+ // be reused.
+ allocatedPorts map[portDescriptor]addrToDevice
+
+ // ephemeralMu protects firstEphemeral and numEphemeral.
+ ephemeralMu sync.RWMutex
+ firstEphemeral uint16
+ numEphemeral uint16
+
+ // hint is used to pick ports ephemeral ports in a stable order for
+ // a given port offset.
+ //
+ // hint must be accessed using the portHint/incPortHint helpers.
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ hint uint32
+}
+
// NewPortManager creates new PortManager.
func NewPortManager() *PortManager {
- return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)}
+ return &PortManager{
+ allocatedPorts: make(map[portDescriptor]addrToDevice),
+ // Match Linux's default ephemeral range. See:
+ // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842
+ firstEphemeral: 32768,
+ numEphemeral: 28232,
+ }
}
+// PortTester indicates whether the passed in port is suitable. Returning an
+// error causes the function to which the PortTester is passed to return that
+// error.
+type PortTester func(port uint16) (good bool, err tcpip.Error)
+
// PickEphemeralPort randomly chooses a starting point and iterates over all
// possible ephemeral ports, allowing the caller to decide whether a given port
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
-func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
- offset := uint32(rand.Int31n(numEphemeralPorts))
- return s.pickEphemeralPort(offset, numEphemeralPorts, testPort)
+func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) {
+ pm.ephemeralMu.RLock()
+ firstEphemeral := pm.firstEphemeral
+ numEphemeral := pm.numEphemeral
+ pm.ephemeralMu.RUnlock()
+
+ offset := uint16(rand.Int31n(int32(numEphemeral)))
+ return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort)
}
-// portHint atomically reads and returns the s.hint value.
-func (s *PortManager) portHint() uint32 {
- return atomic.LoadUint32(&s.hint)
+// portHint atomically reads and returns the pm.hint value.
+func (pm *PortManager) portHint() uint16 {
+ return uint16(atomic.LoadUint32(&pm.hint))
}
-// incPortHint atomically increments s.hint by 1.
-func (s *PortManager) incPortHint() {
- atomic.AddUint32(&s.hint, 1)
+// incPortHint atomically increments pm.hint by 1.
+func (pm *PortManager) incPortHint() {
+ atomic.AddUint32(&pm.hint, 1)
}
-// PickEphemeralPortStable starts at the specified offset + s.portHint and
+// PickEphemeralPortStable starts at the specified offset + pm.portHint and
// iterates over all ephemeral ports, allowing the caller to decide whether a
// given port is suitable for its needs and stopping when a port is found or an
// error occurs.
-func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
- p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort)
+func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+ pm.ephemeralMu.RLock()
+ firstEphemeral := pm.firstEphemeral
+ numEphemeral := pm.numEphemeral
+ pm.ephemeralMu.RUnlock()
+
+ p, err := pickEphemeralPort(pm.portHint()+offset, firstEphemeral, numEphemeral, testPort)
if err == nil {
- s.incPortHint()
+ pm.incPortHint()
}
return p, err
-
}
// pickEphemeralPort starts at the offset specified from the FirstEphemeral port
// and iterates over the number of ports specified by count and allows the
// caller to decide whether a given port is suitable for its needs, and stopping
// when a port is found or an error occurs.
-func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) {
- for i := uint32(0); i < count; i++ {
- port = uint16(FirstEphemeral + (offset+i)%count)
+func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) {
+ for i := uint16(0); i < count; i++ {
+ port = first + (offset+i)%count
ok, err := testPort(port)
if err != nil {
return 0, err
@@ -377,144 +293,145 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui
return 0, &tcpip.ErrNoPortAvailable{}
}
-// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
- s.mu.Lock()
- defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest))
-}
-
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, flags, bindToDevice, dst) {
- return false
- }
- }
- }
- return true
-}
-
// ReservePort marks a port/IP combination as reserved so that it cannot be
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
//
-// An optional testPort closure can be passed in which if provided will be used
-// to test if the picked port can be used. The function should return true if
-// the port is safe to use, false otherwise.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- dst := makeDestination(dest)
+// An optional PortTester can be passed in which if provided will be used to
+// test if the picked port can be used. The function should return true if the
+// port is safe to use, false otherwise.
+func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) {
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
- if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) {
+ if res.Port != 0 {
+ if !pm.reserveSpecificPortLocked(res) {
return 0, &tcpip.ErrPortInUse{}
}
- if testPort != nil && !testPort(port) {
- s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst)
- return 0, &tcpip.ErrPortInUse{}
+ if testPort != nil {
+ ok, err := testPort(res.Port)
+ if err != nil {
+ pm.releasePortLocked(res)
+ return 0, err
+ }
+ if !ok {
+ pm.releasePortLocked(res)
+ return 0, &tcpip.ErrPortInUse{}
+ }
}
- return port, nil
+ return res.Port, nil
}
// A port wasn't specified, so try to find one.
- return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
- if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) {
+ return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) {
+ res.Port = p
+ if !pm.reserveSpecificPortLocked(res) {
return false, nil
}
- if testPort != nil && !testPort(p) {
- s.releasePortLocked(networks, transport, addr, p, flags.Bits(), bindToDevice, dst)
- return false, nil
+ if testPort != nil {
+ ok, err := testPort(p)
+ if err != nil {
+ pm.releasePortLocked(res)
+ return false, err
+ }
+ if !ok {
+ pm.releasePortLocked(res)
+ return false, nil
+ }
}
return true, nil
})
}
-// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) {
- return false
+// reserveSpecificPortLocked tries to reserve the given port on all given
+// protocols.
+func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool {
+ // Make sure the port is available.
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ if addrs, ok := pm.allocatedPorts[desc]; ok {
+ if !addrs.isAvailable(res) {
+ return false
+ }
+ }
}
- flagBits := flags.Bits()
-
// Reserve port on all network protocols.
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- m, ok := s.allocatedPorts[desc]
+ flagBits := res.Flags.Bits()
+ dst := res.dst()
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ addrToDev, ok := pm.allocatedPorts[desc]
if !ok {
- m = make(bindAddresses)
- s.allocatedPorts[desc] = m
+ addrToDev = make(addrToDevice)
+ pm.allocatedPorts[desc] = addrToDev
}
- d, ok := m[addr]
+ devToDest, ok := addrToDev[res.Addr]
if !ok {
- d = make(deviceNode)
- m[addr] = d
+ devToDest = make(deviceToDest)
+ addrToDev[res.Addr] = devToDest
}
- p := d[bindToDevice]
- if p == nil {
- p = make(portNode)
+ destToCntr := devToDest[res.BindToDevice]
+ if destToCntr == nil {
+ destToCntr = make(destToCounter)
}
- n := p[dst]
- n.AddRef(flagBits)
- p[dst] = n
- d[bindToDevice] = p
+ counter := destToCntr[dst]
+ counter.AddRef(flagBits)
+ destToCntr[dst] = counter
+ devToDest[res.BindToDevice] = destToCntr
}
return true
}
// ReserveTuple adds a port reservation for the tuple on all given protocol.
-func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool {
- flagBits := flags.Bits()
- dst := makeDestination(dest)
+func (pm *PortManager) ReserveTuple(res Reservation) bool {
+ flagBits := res.Flags.Bits()
+ dst := res.dst()
- s.mu.Lock()
- defer s.mu.Unlock()
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
// It is easier to undo the entire reservation, so if we find that the
// tuple can't be fully added, finish and undo the whole thing.
undo := false
// Reserve port on all network protocols.
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- m, ok := s.allocatedPorts[desc]
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ addrToDev, ok := pm.allocatedPorts[desc]
if !ok {
- m = make(bindAddresses)
- s.allocatedPorts[desc] = m
+ addrToDev = make(addrToDevice)
+ pm.allocatedPorts[desc] = addrToDev
}
- d, ok := m[addr]
+ devToDest, ok := addrToDev[res.Addr]
if !ok {
- d = make(deviceNode)
- m[addr] = d
+ devToDest = make(deviceToDest)
+ addrToDev[res.Addr] = devToDest
}
- p := d[bindToDevice]
- if p == nil {
- p = make(portNode)
+ destToCntr := devToDest[res.BindToDevice]
+ if destToCntr == nil {
+ destToCntr = make(destToCounter)
}
- n := p[dst]
- if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 {
+ counter := destToCntr[dst]
+ if counter.TotalRefs() != 0 && counter.SharedFlags()&flagBits == 0 {
// Tuple already exists.
undo = true
}
- n.AddRef(flagBits)
- p[dst] = n
- d[bindToDevice] = p
+ counter.AddRef(flagBits)
+ destToCntr[dst] = counter
+ devToDest[res.BindToDevice] = destToCntr
}
if undo {
// releasePortLocked decrements the counts (rather than setting
// them to zero), so it will undo the incorrect incrementing
// above.
- s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst)
+ pm.releasePortLocked(res)
return false
}
@@ -523,47 +440,71 @@ func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, trans
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) {
- s.mu.Lock()
- defer s.mu.Unlock()
+func (pm *PortManager) ReleasePort(res Reservation) {
+ pm.mu.Lock()
+ defer pm.mu.Unlock()
- s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest))
+ pm.releasePortLocked(res)
}
-func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) {
- for _, network := range networks {
- desc := portDescriptor{network, transport, port}
- if m, ok := s.allocatedPorts[desc]; ok {
- d, ok := m[addr]
- if !ok {
- continue
- }
- p, ok := d[bindToDevice]
- if !ok {
- continue
- }
- n, ok := p[dst]
- if !ok {
- continue
- }
- n.DropRef(flags)
- if n.TotalRefs() > 0 {
- p[dst] = n
- continue
- }
- delete(p, dst)
- if len(p) > 0 {
- continue
- }
- delete(d, bindToDevice)
- if len(d) > 0 {
- continue
- }
- delete(m, addr)
- if len(m) > 0 {
- continue
- }
- delete(s.allocatedPorts, desc)
+func (pm *PortManager) releasePortLocked(res Reservation) {
+ dst := res.dst()
+ for _, network := range res.Networks {
+ desc := portDescriptor{network, res.Transport, res.Port}
+ addrToDev, ok := pm.allocatedPorts[desc]
+ if !ok {
+ continue
}
+ devToDest, ok := addrToDev[res.Addr]
+ if !ok {
+ continue
+ }
+ destToCounter, ok := devToDest[res.BindToDevice]
+ if !ok {
+ continue
+ }
+ counter, ok := destToCounter[dst]
+ if !ok {
+ continue
+ }
+ counter.DropRef(res.Flags.Bits())
+ if counter.TotalRefs() > 0 {
+ destToCounter[dst] = counter
+ continue
+ }
+ delete(destToCounter, dst)
+ if len(destToCounter) > 0 {
+ continue
+ }
+ delete(devToDest, res.BindToDevice)
+ if len(devToDest) > 0 {
+ continue
+ }
+ delete(addrToDev, res.Addr)
+ if len(addrToDev) > 0 {
+ continue
+ }
+ delete(pm.allocatedPorts, desc)
+ }
+}
+
+// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
+// both IPv4 and IPv6.
+func (pm *PortManager) PortRange() (uint16, uint16) {
+ pm.ephemeralMu.RLock()
+ defer pm.ephemeralMu.RUnlock()
+ return pm.firstEphemeral, pm.firstEphemeral + pm.numEphemeral - 1
+}
+
+// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range
+// (inclusive).
+func (pm *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error {
+ if start > end {
+ return &tcpip.ErrInvalidPortRange{}
}
+ pm.ephemeralMu.Lock()
+ defer pm.ephemeralMu.Unlock()
+ pm.firstEphemeral = start
+ pm.numEphemeral = end - start + 1
+ return nil
}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index e70fbb72b..0f43dc8f8 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -329,16 +329,35 @@ func TestPortReservation(t *testing.T) {
net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
for _, test := range test.actions {
+ first, _ := pm.PortRange()
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest)
+ portRes := Reservation{
+ Networks: net,
+ Transport: fakeTransNumber,
+ Addr: test.ip,
+ Port: test.port,
+ Flags: test.flags,
+ BindToDevice: test.device,
+ Dest: test.dest,
+ }
+ pm.ReleasePort(portRes)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */)
+ portRes := Reservation{
+ Networks: net,
+ Transport: fakeTransNumber,
+ Addr: test.ip,
+ Port: test.port,
+ Flags: test.flags,
+ BindToDevice: test.device,
+ Dest: test.dest,
+ }
+ gotPort, err := pm.ReservePort(portRes, nil /* testPort */)
if diff := cmp.Diff(test.want, err); diff != "" {
- t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff)
+ t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff)
}
- if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
- t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
+ if test.port == 0 && (gotPort == 0 || gotPort < first) {
+ t.Fatalf("ReservePort(%+v, _) = %d, want port number >= %d to be picked", portRes, gotPort, first)
}
}
})
@@ -346,6 +365,11 @@ func TestPortReservation(t *testing.T) {
}
func TestPickEphemeralPort(t *testing.T) {
+ const (
+ firstEphemeral = 32000
+ numEphemeralPorts = 1000
+ )
+
for _, test := range []struct {
name string
f func(port uint16) (bool, tcpip.Error)
@@ -369,17 +393,17 @@ func TestPickEphemeralPort(t *testing.T) {
{
name: "only-port-16042-available",
f: func(port uint16) (bool, tcpip.Error) {
- if port == FirstEphemeral+42 {
+ if port == firstEphemeral+42 {
return true, nil
}
return false, nil
},
- wantPort: FirstEphemeral + 42,
+ wantPort: firstEphemeral + 42,
},
{
name: "only-port-under-16000-available",
f: func(port uint16) (bool, tcpip.Error) {
- if port < FirstEphemeral {
+ if port < firstEphemeral {
return true, nil
}
return false, nil
@@ -389,6 +413,9 @@ func TestPickEphemeralPort(t *testing.T) {
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
+ if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
+ t.Fatalf("failed to set ephemeral port range: %s", err)
+ }
port, err := pm.PickEphemeralPort(test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
@@ -401,6 +428,11 @@ func TestPickEphemeralPort(t *testing.T) {
}
func TestPickEphemeralPortStable(t *testing.T) {
+ const (
+ firstEphemeral = 32000
+ numEphemeralPorts = 1000
+ )
+
for _, test := range []struct {
name string
f func(port uint16) (bool, tcpip.Error)
@@ -424,17 +456,17 @@ func TestPickEphemeralPortStable(t *testing.T) {
{
name: "only-port-16042-available",
f: func(port uint16) (bool, tcpip.Error) {
- if port == FirstEphemeral+42 {
+ if port == firstEphemeral+42 {
return true, nil
}
return false, nil
},
- wantPort: FirstEphemeral + 42,
+ wantPort: firstEphemeral + 42,
},
{
name: "only-port-under-16000-available",
f: func(port uint16) (bool, tcpip.Error) {
- if port < FirstEphemeral {
+ if port < firstEphemeral {
return true, nil
}
return false, nil
@@ -444,7 +476,10 @@ func TestPickEphemeralPortStable(t *testing.T) {
} {
t.Run(test.name, func(t *testing.T) {
pm := NewPortManager()
- portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
+ if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil {
+ t.Fatalf("failed to set ephemeral port range: %s", err)
+ }
+ portOffset := uint16(rand.Int31n(int32(numEphemeralPorts)))
port, err := pm.PickEphemeralPortStable(portOffset, test.f)
if diff := cmp.Diff(test.wantErr, err); diff != "" {
t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff)
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 78a4cb072..47796a6ba 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -99,12 +99,11 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi
}
// ndpDADEvent is a set of parameters that was passed to
-// ndpDispatcher.OnDuplicateAddressDetectionStatus.
+// ndpDispatcher.OnDuplicateAddressDetectionResult.
type ndpDADEvent struct {
- nicID tcpip.NICID
- addr tcpip.Address
- resolved bool
- err tcpip.Error
+ nicID tcpip.NICID
+ addr tcpip.Address
+ res stack.DADResult
}
type ndpRouterEvent struct {
@@ -173,14 +172,13 @@ type ndpDispatcher struct {
dhcpv6ConfigurationC chan ndpDHCPv6Event
}
-// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus.
-func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) {
+// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionResult.
+func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) {
if n.dadC != nil {
n.dadC <- ndpDADEvent{
nicID,
addr,
- resolved,
- err,
+ res,
}
}
}
@@ -311,8 +309,8 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 {
// Check e to make sure that the event is for addr on nic with ID 1, and the
// resolved flag set to resolved with the specified err.
-func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string {
- return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e))
+func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) string {
+ return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, res: res}, e, cmp.AllowUnexported(e))
}
// TestDADDisabled tests that an address successfully resolves immediately
@@ -344,8 +342,8 @@ func TestDADDisabled(t *testing.T) {
// DAD on it.
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected DAD event")
@@ -491,8 +489,8 @@ func TestDADResolve(t *testing.T) {
case <-time.After(defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
}
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
@@ -598,9 +596,10 @@ func TestDADFail(t *testing.T) {
const nicID = 1
tests := []struct {
- name string
- rxPkt func(e *channel.Endpoint, tgt tcpip.Address)
- getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ name string
+ rxPkt func(e *channel.Endpoint, tgt tcpip.Address)
+ getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter
+ expectedHolderLinkAddress tcpip.LinkAddress
}{
{
name: "RxSolicit",
@@ -608,6 +607,7 @@ func TestDADFail(t *testing.T) {
getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return s.NeighborSolicit
},
+ expectedHolderLinkAddress: "",
},
{
name: "RxAdvert",
@@ -642,6 +642,7 @@ func TestDADFail(t *testing.T) {
getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter {
return s.NeighborAdvert
},
+ expectedHolderLinkAddress: linkAddr1,
},
}
@@ -691,8 +692,8 @@ func TestDADFail(t *testing.T) {
// something is wrong.
t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
}
if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
@@ -790,8 +791,8 @@ func TestDADStop(t *testing.T) {
// time + extra 1s buffer, something is wrong.
t.Fatal("timed out waiting for DAD failure")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr1, false, &tcpip.ErrAborted{}); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
}
@@ -852,8 +853,8 @@ func TestSetNDPConfigurations(t *testing.T) {
expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) {
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatalf("expected DAD event for %s", addr)
@@ -944,8 +945,8 @@ func TestSetNDPConfigurations(t *testing.T) {
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
}
if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
@@ -1963,8 +1964,8 @@ func TestAutoGenTempAddr(t *testing.T) {
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
@@ -2169,8 +2170,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) {
}
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
@@ -2257,8 +2258,8 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) {
// address to be generated.
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
@@ -2723,8 +2724,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
t.Helper()
clock.Advance(dupAddrTransmits * retransmitTimer)
- if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
}
@@ -2754,8 +2755,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) {
rxNDPSolicit(e, addr.Address)
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected DAD event")
@@ -3853,26 +3854,26 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
}
}
- expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool, err tcpip.Error) {
+ expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
t.Helper()
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, resolved, err); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected DAD event")
}
}
- expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) {
+ expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) {
t.Helper()
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr, res); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
@@ -3929,7 +3930,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// generated.
e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100))
expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr)
- expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true)
+ expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{})
// The stable address will be assigned throughout the test.
return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest}
@@ -4004,7 +4005,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
// Simulate a DAD conflict.
rxNDPSolicit(e, addr.Address)
expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr)
- expectDADEvent(t, &ndpDisp, addr.Address, false, nil)
+ expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{})
// Attempting to add the address manually should not fail if the
// address's state was cleaned up when DAD failed.
@@ -4014,7 +4015,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if err := s.RemoveAddress(nicID, addr.Address); err != nil {
t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err)
}
- expectDADEvent(t, &ndpDisp, addr.Address, false, &tcpip.ErrAborted{})
+ expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{})
}
// Should not have any new addresses assigned to the NIC.
@@ -4027,7 +4028,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) {
if maxRetries+1 > numFailures {
addr := addrType.addrGenFn(numFailures, tempIIDHistory[:])
expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr)
- expectDADEventAsync(t, &ndpDisp, addr.Address, true)
+ expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{})
if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" {
t.Fatal(mismatch)
}
@@ -4144,8 +4145,8 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) {
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected DAD event")
@@ -4243,8 +4244,8 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
expectAutoGenAddrEvent(addr, invalidatedAddr)
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
default:
t.Fatal("expected DAD event")
@@ -4255,8 +4256,8 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) {
expectAutoGenAddrEvent(addr, newAddr)
select {
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" {
- t.Errorf("dad event mismatch (-want +got):\n%s", diff)
+ if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" {
+ t.Errorf("DAD event mismatch (-want +got):\n%s", diff)
}
case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD event")
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 43e9e4beb..85f0f471a 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -852,18 +852,46 @@ type InjectableLinkEndpoint interface {
InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error
}
-// DADResult is the result of a duplicate address detection process.
-type DADResult struct {
- // Resolved is true when DAD completed without detecting a duplicate address
- // on the link.
- //
- // Ignored when Err is non-nil.
- Resolved bool
+// DADResult is a marker interface for the result of a duplicate address
+// detection process.
+type DADResult interface {
+ isDADResult()
+}
+
+var _ DADResult = (*DADSucceeded)(nil)
+
+// DADSucceeded indicates DAD completed without finding any duplicate addresses.
+type DADSucceeded struct{}
- // Err is an error encountered while performing DAD.
+func (*DADSucceeded) isDADResult() {}
+
+var _ DADResult = (*DADError)(nil)
+
+// DADError indicates DAD hit an error.
+type DADError struct {
Err tcpip.Error
}
+func (*DADError) isDADResult() {}
+
+var _ DADResult = (*DADAborted)(nil)
+
+// DADAborted indicates DAD was aborted.
+type DADAborted struct{}
+
+func (*DADAborted) isDADResult() {}
+
+var _ DADResult = (*DADDupAddrDetected)(nil)
+
+// DADDupAddrDetected indicates DAD detected a duplicate address.
+type DADDupAddrDetected struct {
+ // HolderLinkAddress is the link address of the node that holds the duplicate
+ // address.
+ HolderLinkAddress tcpip.LinkAddress
+}
+
+func (*DADDupAddrDetected) isDADResult() {}
+
// DADCompletionHandler is a handler for DAD completion.
type DADCompletionHandler func(DADResult)
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index de94ddfda..53370c354 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -813,6 +813,18 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool {
return forwardingProtocol.Forwarding()
}
+// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in
+// both IPv4 and IPv6.
+func (s *Stack) PortRange() (uint16, uint16) {
+ return s.PortManager.PortRange()
+}
+
+// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range
+// (inclusive).
+func (s *Stack) SetPortRange(start uint16, end uint16) tcpip.Error {
+ return s.PortManager.SetPortRange(start, end)
+}
+
// SetRouteTable assigns the route table to be used by this stack. It
// specifies which NIC to use for given destination address ranges.
//
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index f45cf5fdf..880219007 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -2605,7 +2605,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
// means something is wrong.
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" {
+ if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
@@ -3289,7 +3289,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout):
t.Fatal("timed out waiting for DAD resolution")
case e := <-ndpDisp.dadC:
- if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" {
+ if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index e799f9290..e188efccb 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -359,7 +359,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32
return mpep.endpoints[0]
}
- if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent {
+ if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent {
return mpep.endpoints[len(mpep.endpoints)-1]
}
@@ -410,7 +410,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
- if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
return &tcpip.ErrPortInUse{}
}
}
@@ -429,7 +429,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error
if len(ep.endpoints) != 0 {
// If it was previously bound, we need to check if we can bind again.
- if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 {
+ if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 {
return &tcpip.ErrPortInUse{}
}
}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 405c74c65..095623789 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -1167,53 +1167,53 @@ func TestDAD(t *testing.T) {
}
tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- dadNetProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedResolved bool
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ dadNetProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedResult stack.DADResult
}{
{
- name: "IPv4 own address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
- expectedResolved: true,
+ name: "IPv4 own address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address,
+ expectedResult: &stack.DADSucceeded{},
},
{
- name: "IPv6 own address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
- expectedResolved: true,
+ name: "IPv6 own address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address,
+ expectedResult: &stack.DADSucceeded{},
},
{
- name: "IPv4 duplicate address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
- expectedResolved: false,
+ name: "IPv4 duplicate address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2},
},
{
- name: "IPv6 duplicate address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
- expectedResolved: false,
+ name: "IPv6 duplicate address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2},
},
{
- name: "IPv4 no duplicate address",
- netProto: ipv4.ProtocolNumber,
- dadNetProto: arp.ProtocolNumber,
- remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
- expectedResolved: true,
+ name: "IPv4 no duplicate address",
+ netProto: ipv4.ProtocolNumber,
+ dadNetProto: arp.ProtocolNumber,
+ remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address,
+ expectedResult: &stack.DADSucceeded{},
},
{
- name: "IPv6 no duplicate address",
- netProto: ipv6.ProtocolNumber,
- dadNetProto: ipv6.ProtocolNumber,
- remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
- expectedResolved: true,
+ name: "IPv6 no duplicate address",
+ netProto: ipv6.ProtocolNumber,
+ dadNetProto: ipv6.ProtocolNumber,
+ remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address,
+ expectedResult: &stack.DADSucceeded{},
},
}
@@ -1260,7 +1260,7 @@ func TestDAD(t *testing.T) {
}
expectResults := 1
- if test.expectedResolved {
+ if _, ok := test.expectedResult.(*stack.DADSucceeded); ok {
const delta = time.Nanosecond
clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta)
select {
@@ -1285,7 +1285,7 @@ func TestDAD(t *testing.T) {
}
for i := 0; i < expectResults; i++ {
- if diff := cmp.Diff(stack.DADResult{Resolved: test.expectedResolved}, <-ch); diff != "" {
+ if diff := cmp.Diff(test.expectedResult, <-ch); diff != "" {
t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff)
}
}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index c56155ea2..80afc2825 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -38,7 +38,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil)
type ndpDispatcher struct{}
-func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) {
+func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) {
}
func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool {
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 09e9d027d..06c63e74a 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -26,6 +26,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// TODO(https://gvisor.dev/issues/5623): Unit test this package.
+
// +stateify savable
type icmpPacket struct {
icmpPacketEntry
@@ -414,6 +416,11 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
return &tcpip.ErrInvalidEndpointState{}
}
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest
+
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
@@ -422,7 +429,14 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
+
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
+ r.Stats().ICMP.V4.PacketsSent.Dropped.Increment()
+ return err
+ }
+
+ sentStat.Increment()
+ return nil
}
func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error {
@@ -444,6 +458,10 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return &tcpip.ErrInvalidEndpointState{}
}
+ // Because this icmp endpoint is implemented in the transport layer, we can
+ // only increment the 'stack-wide' stats but we can't increment the
+ // 'per-NetworkEndpoint' stats.
+ sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest
pkt.Data().AppendView(data)
dataRange := pkt.Data().AsRange()
@@ -458,7 +476,13 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro
if ttl == 0 {
ttl = r.DefaultTTL()
}
- return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt)
+
+ if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil {
+ r.Stats().ICMP.V6.PacketsSent.Dropped.Increment()
+ }
+
+ sentStat.Increment()
+ return nil
}
// checkV4MappedLocked determines the effective network protocol and converts
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index fcdd032c5..a69d6624d 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -105,7 +105,6 @@ go_test(
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
- "//pkg/tcpip/ports",
"//pkg/tcpip/seqnum",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/tcp/testing/context",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 842c1622b..3b574837c 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
@@ -432,15 +433,16 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// * e.mu is held.
func (e *endpoint) reserveTupleLocked() bool {
dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort}
- if !e.stack.ReserveTuple(
- e.effectiveNetProtos,
- ProtocolNumber,
- e.ID.LocalAddress,
- e.ID.LocalPort,
- e.boundPortFlags,
- e.boundBindToDevice,
- dest,
- ) {
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ Flags: e.boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: dest,
+ }
+ if !e.stack.ReserveTuple(portRes) {
return false
}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index d1e452421..3404af6bb 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -68,7 +68,7 @@ type handshake struct {
ep *endpoint
state handshakeState
active bool
- flags uint8
+ flags header.TCPFlags
ackNum seqnum.Value
// iss is the initial send sequence number, as defined in RFC 793.
@@ -606,7 +606,7 @@ func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer
func (bt *backoffTimer) reset() tcpip.Error {
bt.timeout *= 2
- if bt.timeout > MaxRTO {
+ if bt.timeout > bt.maxTimeout {
return &tcpip.ErrTimeout{}
}
bt.t.Reset(bt.timeout)
@@ -700,7 +700,7 @@ type tcpFields struct {
id stack.TransportEndpointID
ttl uint8
tos uint8
- flags byte
+ flags header.TCPFlags
seq seqnum.Value
ack seqnum.Value
rcvWnd seqnum.Size
@@ -877,7 +877,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
}
// sendRaw sends a TCP segment to the endpoint's peer.
-func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error {
+func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error {
var sackBlocks []header.SACKBlock
if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) {
sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 687b9f459..129f36d11 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1097,7 +1097,16 @@ func (e *endpoint) closeNoShutdownLocked() {
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ Flags: e.boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: e.boundDest,
+ }
+ e.stack.ReleasePort(portRes)
e.isPortReserved = false
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
@@ -1172,7 +1181,16 @@ func (e *endpoint) cleanupLocked() {
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest)
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ Flags: e.boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: e.boundDest,
+ }
+ e.stack.ReleasePort(portRes)
e.isPortReserved = false
}
e.boundBindToDevice = 0
@@ -2220,7 +2238,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
portBuf := make([]byte, 2)
binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort)
h.Write(portBuf)
- portOffset := h.Sum32()
+ portOffset := uint16(h.Sum32())
var twReuse tcpip.TCPTimeWaitReuseOption
if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil {
@@ -2242,7 +2260,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
if sameAddr && p == e.ID.RemotePort {
return false, nil
}
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: p,
+ Flags: e.portFlags,
+ BindToDevice: bindToDevice,
+ Dest: addr,
+ }
+ if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse {
return false, nil
}
@@ -2280,7 +2307,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
tcpEP.notifyProtocolGoroutine(notifyAbort)
tcpEP.UnlockUser()
// Now try and Reserve again if it fails then we skip.
- if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil {
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: p,
+ Flags: e.portFlags,
+ BindToDevice: bindToDevice,
+ Dest: addr,
+ }
+ if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil {
return false, nil
}
}
@@ -2288,7 +2324,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp
id := e.ID
id.LocalPort = p
if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr)
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: p,
+ Flags: e.portFlags,
+ BindToDevice: bindToDevice,
+ Dest: addr,
+ }
+ e.stack.ReleasePort(portRes)
if _, ok := err.(*tcpip.ErrPortInUse); ok {
return false, nil
}
@@ -2604,7 +2649,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
}
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool {
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: addr.Addr,
+ Port: addr.Port,
+ Flags: e.portFlags,
+ BindToDevice: bindToDevice,
+ Dest: tcpip.FullAddress{},
+ }
+ port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) {
id := e.ID
id.LocalPort = p
// CheckRegisterTransportEndpoint should only return an error if there is a
@@ -2616,9 +2670,9 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) {
// address/port. Hence this will only return an error if there is a matching
// listening endpoint.
if err := e.stack.CheckRegisterTransportEndpoint(netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil {
- return false
+ return false, nil
}
- return true
+ return true, nil
})
if err != nil {
return err
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 7c15690a3..a53d76917 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -22,6 +22,7 @@ import (
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -208,7 +209,16 @@ func (e *endpoint) Resume(s *stack.Stack) {
if err != nil {
panic("unable to parse BindAddr: " + err.String())
}
- if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok {
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: addr.Addr,
+ Port: addr.Port,
+ Flags: e.boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: e.boundDest,
+ }
+ if ok := e.stack.ReserveTuple(portRes); !ok {
panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
}
e.isPortReserved = true
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 04012cd40..2a4667906 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -226,7 +226,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
ack := seqnum.Value(0)
- flags := byte(header.TCPFlagRst)
+ flags := header.TCPFlagRst
// As per RFC 793 page 35 (Reset Generation)
// 1. If the connection does not exist (CLOSED) then a reset is sent
// in response to any incoming segment except another reset. In
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 744382100..8edd6775b 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -62,7 +62,7 @@ type segment struct {
views [8]buffer.View `state:"nosave"`
sequenceNumber seqnum.Value
ackNumber seqnum.Value
- flags uint8
+ flags header.TCPFlags
window seqnum.Size
// csum is only populated for received segments.
csum uint16
@@ -141,12 +141,12 @@ func (s *segment) clone() *segment {
}
// flagIsSet checks if at least one flag in flags is set in s.flags.
-func (s *segment) flagIsSet(flags uint8) bool {
+func (s *segment) flagIsSet(flags header.TCPFlags) bool {
return s.flags&flags != 0
}
// flagsAreSet checks if all flags in flags are set in s.flags.
-func (s *segment) flagsAreSet(flags uint8) bool {
+func (s *segment) flagsAreSet(flags header.TCPFlags) bool {
return s.flags&flags == flags
}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
index 83c8deb0e..18817029d 100644
--- a/pkg/tcpip/transport/tcp/snd.go
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -1613,7 +1613,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error {
// sendSegmentFromView sends a new segment containing the given payload, flags
// and sequence number.
-func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error {
+func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error {
s.lastSendTime = time.Now()
if seq == s.rttMeasureSeqNum {
s.rttMeasureTime = s.lastSendTime
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 0128c1f7e..fd499a47b 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -33,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
- "gvisor.dev/gvisor/pkg/tcpip/ports"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -1373,7 +1372,7 @@ func TestTOSV4(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790), // Acknum is initial sequence number + 1
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
checker.TOS(tos, 0),
)
@@ -1421,7 +1420,7 @@ func TestTrafficClassV6(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
checker.TOS(tos, 0),
)
@@ -2202,7 +2201,7 @@ func TestSimpleSend(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2242,7 +2241,7 @@ func TestZeroWindowSend(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2264,7 +2263,7 @@ func TestZeroWindowSend(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2311,7 +2310,7 @@ func TestScaledWindowConnect(t *testing.T) {
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(0x5fff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
}
@@ -2342,7 +2341,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(0xffff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
}
@@ -2415,7 +2414,7 @@ func TestScaledWindowAccept(t *testing.T) {
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(0x5fff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
}
@@ -2488,7 +2487,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(0xffff),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
}
@@ -2666,7 +2665,7 @@ func TestSegmentMerging(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
}
@@ -2689,7 +2688,7 @@ func TestSegmentMerging(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+11),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2738,7 +2737,7 @@ func TestDelay(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2786,7 +2785,7 @@ func TestUndelay(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2809,7 +2808,7 @@ func TestUndelay(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2872,7 +2871,7 @@ func TestMSSNotDelayed(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(seq)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -2923,7 +2922,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3438,7 +3437,7 @@ func TestMaxRTO(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
const numRetransmits = 2
@@ -3447,7 +3446,7 @@ func TestMaxRTO(t *testing.T) {
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
@@ -3490,7 +3489,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
checker.FragmentFlags(0),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}}
@@ -3502,7 +3501,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
checker.FragmentFlags(0),
checker.TCP(
checker.DstPort(context.TestPort),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
id := header.IPv4(pkt).ID()
@@ -3633,7 +3632,7 @@ func TestFinWithNoPendingData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -3710,7 +3709,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -3729,7 +3728,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3796,7 +3795,7 @@ func TestFinWithPendingData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -3822,7 +3821,7 @@ func TestFinWithPendingData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -3886,7 +3885,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -3907,7 +3906,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(791),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -3923,7 +3922,7 @@ func TestFinWithPartialAck(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(791),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
next += uint32(len(view))
@@ -4033,7 +4032,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -4783,7 +4782,8 @@ func TestConnectAvoidsBoundPorts(t *testing.T) {
t.Fatalf("unknown address type: '%s'", candidateAddressType)
}
- for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
+ start, end := s.PortRange()
+ for i := start; i <= end; i++ {
if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
t.Fatalf("Bind(%d) failed: %s", i, err)
}
@@ -4844,7 +4844,7 @@ func TestPathMTUDiscovery(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(seqNum),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
seqNum += uint32(size)
@@ -5129,7 +5129,7 @@ func TestKeepalive(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -7174,7 +7174,7 @@ func TestTCPCloseWithData(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -7274,7 +7274,7 @@ func TestTCPUserTimeout(t *testing.T) {
checker.DstPort(context.TestPort),
checker.TCPSeqNum(next),
checker.TCPAckNum(790),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 5a9745ad7..cb4f82903 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -170,7 +170,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(wndSize),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
checker.TCPTimestampChecker(true, 0, tsVal+1),
),
)
@@ -231,7 +231,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
checker.TCPSeqNum(uint32(c.IRS)+1),
checker.TCPAckNum(790),
checker.TCPWindow(wndSize),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
checker.TCPTimestampChecker(false, 0, 0),
),
)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index b1cb9a324..2f1c1011d 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -101,7 +101,7 @@ type Headers struct {
AckNum seqnum.Value
// Flags are the TCP flags in the TCP header.
- Flags int
+ Flags header.TCPFlags
// RcvWnd is the window to be advertised in the ReceiveWindow field of
// the TCP header.
@@ -452,7 +452,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp
SeqNum: uint32(h.SeqNum),
AckNum: uint32(h.AckNum),
DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)),
- Flags: uint8(h.Flags),
+ Flags: h.Flags,
WindowSize: uint16(h.RcvWnd),
})
@@ -544,7 +544,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op
checker.DstPort(TestPort),
checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -571,7 +571,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int
checker.DstPort(TestPort),
checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))),
- checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh),
),
)
@@ -650,7 +650,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp
SeqNum: uint32(h.SeqNum),
AckNum: uint32(h.AckNum),
DataOffset: header.TCPMinimumSize,
- Flags: uint8(h.Flags),
+ Flags: h.Flags,
WindowSize: uint16(h.RcvWnd),
})
@@ -780,7 +780,7 @@ type RawEndpoint struct {
C *Context
SrcPort uint16
DstPort uint16
- Flags int
+ Flags header.TCPFlags
NextSeqNum seqnum.Value
AckNum seqnum.Value
WndSize seqnum.Size
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
index 5e271b7ca..6c5ddc3c7 100644
--- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
@@ -465,7 +465,7 @@ func TestIgnoreBadResetOnSynSent(t *testing.T) {
// Receive a RST with a bad ACK, it should not cause the connection to
// be reset.
acks := []uint32{1234, 1236, 1000, 5000}
- flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
+ flags := []header.TCPFlags{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
for _, a := range acks {
for _, f := range flags {
tcp.Encode(&header.TCPFields{
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index b519afed1..c0f566459 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -245,7 +245,16 @@ func (e *endpoint) Close() {
switch e.EndpointState() {
case StateBound, StateConnected:
e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ Flags: e.boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: tcpip.FullAddress{},
+ }
+ e.stack.ReleasePort(portRes)
e.boundBindToDevice = 0
e.boundPortFlags = ports.Flags{}
}
@@ -920,7 +929,16 @@ func (e *endpoint) Disconnect() tcpip.Error {
} else {
if e.ID.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{})
+ portRes := ports.Reservation{
+ Networks: e.effectiveNetProtos,
+ Transport: ProtocolNumber,
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ Flags: boundPortFlags,
+ BindToDevice: e.boundBindToDevice,
+ Dest: tcpip.FullAddress{},
+ }
+ e.stack.ReleasePort(portRes)
e.boundPortFlags = ports.Flags{}
}
e.setEndpointState(StateInitial)
@@ -1072,7 +1090,16 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if e.ID.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */)
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ Flags: e.portFlags,
+ BindToDevice: bindToDevice,
+ Dest: tcpip.FullAddress{},
+ }
+ port, err := e.stack.ReservePort(portRes, nil /* testPort */)
if err != nil {
return id, bindToDevice, err
}
@@ -1082,7 +1109,16 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id
err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{})
+ portRes := ports.Reservation{
+ Networks: netProtos,
+ Transport: ProtocolNumber,
+ Addr: id.LocalAddress,
+ Port: id.LocalPort,
+ Flags: e.boundPortFlags,
+ BindToDevice: bindToDevice,
+ Dest: tcpip.FullAddress{},
+ }
+ e.stack.ReleasePort(portRes)
e.boundPortFlags = ports.Flags{}
}
return id, bindToDevice, err
diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go
index a3a76b609..28e82e117 100644
--- a/runsc/boot/compat.go
+++ b/runsc/boot/compat.go
@@ -17,8 +17,8 @@ package boot
import (
"fmt"
"os"
- "syscall"
+ "golang.org/x/sys/unix"
"google.golang.org/protobuf/proto"
"gvisor.dev/gvisor/pkg/eventchannel"
"gvisor.dev/gvisor/pkg/log"
@@ -93,19 +93,19 @@ func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) {
tr := c.trackers[sysnr]
if tr == nil {
switch sysnr {
- case syscall.SYS_PRCTL:
+ case unix.SYS_PRCTL:
// args: cmd, ...
tr = newArgsTracker(0)
- case syscall.SYS_IOCTL, syscall.SYS_EPOLL_CTL, syscall.SYS_SHMCTL, syscall.SYS_FUTEX, syscall.SYS_FALLOCATE:
+ case unix.SYS_IOCTL, unix.SYS_EPOLL_CTL, unix.SYS_SHMCTL, unix.SYS_FUTEX, unix.SYS_FALLOCATE:
// args: fd/addr, cmd, ...
tr = newArgsTracker(1)
- case syscall.SYS_GETSOCKOPT, syscall.SYS_SETSOCKOPT:
+ case unix.SYS_GETSOCKOPT, unix.SYS_SETSOCKOPT:
// args: fd, level, name, ...
tr = newArgsTracker(1, 2)
- case syscall.SYS_SEMCTL:
+ case unix.SYS_SEMCTL:
// args: semid, semnum, cmd, ...
tr = newArgsTracker(2)
@@ -131,7 +131,7 @@ func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) {
}
func (c *compatEmitter) emitUncaughtSignal(msg *ucspb.UncaughtSignal) {
- sig := syscall.Signal(msg.SignalNumber)
+ sig := unix.Signal(msg.SignalNumber)
c.sink.Infof(
"Uncaught signal: %q (%d), PID: %d, TID: %d, fault addr: %#x",
sig, msg.SignalNumber, msg.Pid, msg.Tid, msg.FaultAddr)
diff --git a/runsc/boot/compat_amd64.go b/runsc/boot/compat_amd64.go
index 8eb76b2ba..7e13ff87c 100644
--- a/runsc/boot/compat_amd64.go
+++ b/runsc/boot/compat_amd64.go
@@ -16,8 +16,8 @@ package boot
import (
"fmt"
- "syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi"
"gvisor.dev/gvisor/pkg/sentry/arch"
rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto"
@@ -92,7 +92,7 @@ func syscallNum(regs *rpb.Registers) uint64 {
func newArchArgsTracker(sysnr uint64) syscallTracker {
switch sysnr {
- case syscall.SYS_ARCH_PRCTL:
+ case unix.SYS_ARCH_PRCTL:
// args: cmd, ...
return newArgsTracker(0)
}
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 5e849cb37..1cd5fba5c 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -18,9 +18,9 @@ import (
"errors"
"fmt"
"os"
- "syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/control/server"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/log"
@@ -366,7 +366,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
case 2:
// The device file is donated to the platform.
// Can't take ownership away from os.File. dup them to get a new FD.
- fd, err := syscall.Dup(int(o.Files[1].Fd()))
+ fd, err := unix.Dup(int(o.Files[1].Fd()))
if err != nil {
return fmt.Errorf("failed to dup file: %v", err)
}
diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go
index 2a8c916d5..49b503f99 100644
--- a/runsc/boot/filter/config.go
+++ b/runsc/boot/filter/config.go
@@ -16,7 +16,6 @@ package filter
import (
"os"
- "syscall"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -26,19 +25,19 @@ import (
// allowedSyscalls is the set of syscalls executed by the Sentry to the host OS.
var allowedSyscalls = seccomp.SyscallRules{
- syscall.SYS_CLOCK_GETTIME: {},
- syscall.SYS_CLOSE: {},
- syscall.SYS_DUP: {},
- syscall.SYS_DUP3: []seccomp.Rule{
+ unix.SYS_CLOCK_GETTIME: {},
+ unix.SYS_CLOSE: {},
+ unix.SYS_DUP: {},
+ unix.SYS_DUP3: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.O_CLOEXEC),
+ seccomp.EqualTo(unix.O_CLOEXEC),
},
},
- syscall.SYS_EPOLL_CREATE1: {},
- syscall.SYS_EPOLL_CTL: {},
- syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{
+ unix.SYS_EPOLL_CREATE1: {},
+ unix.SYS_EPOLL_CTL: {},
+ unix.SYS_EPOLL_PWAIT: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
@@ -47,34 +46,34 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.EqualTo(0),
},
},
- syscall.SYS_EVENTFD2: []seccomp.Rule{
+ unix.SYS_EVENTFD2: []seccomp.Rule{
{
seccomp.EqualTo(0),
seccomp.EqualTo(0),
},
},
- syscall.SYS_EXIT: {},
- syscall.SYS_EXIT_GROUP: {},
- syscall.SYS_FALLOCATE: {},
- syscall.SYS_FCHMOD: {},
- syscall.SYS_FCNTL: []seccomp.Rule{
+ unix.SYS_EXIT: {},
+ unix.SYS_EXIT_GROUP: {},
+ unix.SYS_FALLOCATE: {},
+ unix.SYS_FCHMOD: {},
+ unix.SYS_FCNTL: []seccomp.Rule{
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.F_GETFL),
+ seccomp.EqualTo(unix.F_GETFL),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.F_SETFL),
+ seccomp.EqualTo(unix.F_SETFL),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.F_GETFD),
+ seccomp.EqualTo(unix.F_GETFD),
},
},
- syscall.SYS_FSTAT: {},
- syscall.SYS_FSYNC: {},
- syscall.SYS_FTRUNCATE: {},
- syscall.SYS_FUTEX: []seccomp.Rule{
+ unix.SYS_FSTAT: {},
+ unix.SYS_FSYNC: {},
+ unix.SYS_FTRUNCATE: {},
+ unix.SYS_FUTEX: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
@@ -109,35 +108,35 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.EqualTo(0),
},
},
- syscall.SYS_GETPID: {},
+ unix.SYS_GETPID: {},
unix.SYS_GETRANDOM: {},
- syscall.SYS_GETSOCKOPT: []seccomp.Rule{
+ unix.SYS_GETSOCKOPT: []seccomp.Rule{
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_DOMAIN),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_DOMAIN),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_TYPE),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_TYPE),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_ERROR),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_ERROR),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_SNDBUF),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_SNDBUF),
},
},
- syscall.SYS_GETTID: {},
- syscall.SYS_GETTIMEOFDAY: {},
+ unix.SYS_GETTID: {},
+ unix.SYS_GETTIMEOFDAY: {},
// SYS_IOCTL is needed for terminal support, but we only allow
// setting/getting termios and winsize.
- syscall.SYS_IOCTL: []seccomp.Rule{
+ unix.SYS_IOCTL: []seccomp.Rule{
{
seccomp.MatchAny{}, /* fd */
seccomp.EqualTo(linux.TCGETS),
@@ -169,94 +168,94 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.MatchAny{}, /* winsize struct */
},
},
- syscall.SYS_LSEEK: {},
- syscall.SYS_MADVISE: {},
+ unix.SYS_LSEEK: {},
+ unix.SYS_MADVISE: {},
unix.SYS_MEMBARRIER: []seccomp.Rule{
{
seccomp.EqualTo(linux.MEMBARRIER_CMD_GLOBAL),
seccomp.EqualTo(0),
},
},
- syscall.SYS_MINCORE: {},
+ unix.SYS_MINCORE: {},
// Used by the Go runtime as a temporarily workaround for a Linux
// 5.2-5.4 bug.
//
// See src/runtime/os_linux_x86.go.
//
// TODO(b/148688965): Remove once this is gone from Go.
- syscall.SYS_MLOCK: []seccomp.Rule{
+ unix.SYS_MLOCK: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.EqualTo(4096),
},
},
- syscall.SYS_MMAP: []seccomp.Rule{
+ unix.SYS_MMAP: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MAP_SHARED),
+ seccomp.EqualTo(unix.MAP_SHARED),
},
{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MAP_PRIVATE),
+ seccomp.EqualTo(unix.MAP_PRIVATE),
},
{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS),
+ seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS),
},
{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_STACK),
+ seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS | unix.MAP_STACK),
},
{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_NORESERVE),
+ seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS | unix.MAP_NORESERVE),
},
{
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.PROT_WRITE | syscall.PROT_READ),
- seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED),
+ seccomp.EqualTo(unix.PROT_WRITE | unix.PROT_READ),
+ seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS | unix.MAP_FIXED),
},
},
- syscall.SYS_MPROTECT: {},
- syscall.SYS_MUNMAP: {},
- syscall.SYS_NANOSLEEP: {},
- syscall.SYS_PPOLL: {},
- syscall.SYS_PREAD64: {},
- syscall.SYS_PREADV: {},
- unix.SYS_PREADV2: {},
- syscall.SYS_PWRITE64: {},
- syscall.SYS_PWRITEV: {},
- unix.SYS_PWRITEV2: {},
- syscall.SYS_READ: {},
- syscall.SYS_RECVMSG: []seccomp.Rule{
+ unix.SYS_MPROTECT: {},
+ unix.SYS_MUNMAP: {},
+ unix.SYS_NANOSLEEP: {},
+ unix.SYS_PPOLL: {},
+ unix.SYS_PREAD64: {},
+ unix.SYS_PREADV: {},
+ unix.SYS_PREADV2: {},
+ unix.SYS_PWRITE64: {},
+ unix.SYS_PWRITEV: {},
+ unix.SYS_PWRITEV2: {},
+ unix.SYS_READ: {},
+ unix.SYS_RECVMSG: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC),
+ seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_TRUNC),
},
{
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK),
+ seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_TRUNC | unix.MSG_PEEK),
},
},
- syscall.SYS_RECVMMSG: []seccomp.Rule{
+ unix.SYS_RECVMMSG: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.EqualTo(fdbased.MaxMsgsPerRecv),
- seccomp.EqualTo(syscall.MSG_DONTWAIT),
+ seccomp.EqualTo(unix.MSG_DONTWAIT),
seccomp.EqualTo(0),
},
},
@@ -265,34 +264,34 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MSG_DONTWAIT),
+ seccomp.EqualTo(unix.MSG_DONTWAIT),
seccomp.EqualTo(0),
},
},
- syscall.SYS_RESTART_SYSCALL: {},
- syscall.SYS_RT_SIGACTION: {},
- syscall.SYS_RT_SIGPROCMASK: {},
- syscall.SYS_RT_SIGRETURN: {},
- syscall.SYS_SCHED_YIELD: {},
- syscall.SYS_SENDMSG: []seccomp.Rule{
+ unix.SYS_RESTART_SYSCALL: {},
+ unix.SYS_RT_SIGACTION: {},
+ unix.SYS_RT_SIGPROCMASK: {},
+ unix.SYS_RT_SIGRETURN: {},
+ unix.SYS_SCHED_YIELD: {},
+ unix.SYS_SENDMSG: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL),
+ seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_NOSIGNAL),
},
},
- syscall.SYS_SETITIMER: {},
- syscall.SYS_SHUTDOWN: []seccomp.Rule{
+ unix.SYS_SETITIMER: {},
+ unix.SYS_SHUTDOWN: []seccomp.Rule{
// Used by fs/host to shutdown host sockets.
- {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RD)},
- {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_WR)},
+ {seccomp.MatchAny{}, seccomp.EqualTo(unix.SHUT_RD)},
+ {seccomp.MatchAny{}, seccomp.EqualTo(unix.SHUT_WR)},
// Used by unet to shutdown connections.
- {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RDWR)},
+ {seccomp.MatchAny{}, seccomp.EqualTo(unix.SHUT_RDWR)},
},
- syscall.SYS_SIGALTSTACK: {},
- unix.SYS_STATX: {},
- syscall.SYS_SYNC_FILE_RANGE: {},
- syscall.SYS_TEE: []seccomp.Rule{
+ unix.SYS_SIGALTSTACK: {},
+ unix.SYS_STATX: {},
+ unix.SYS_SYNC_FILE_RANGE: {},
+ unix.SYS_TEE: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
@@ -300,12 +299,12 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.EqualTo(unix.SPLICE_F_NONBLOCK), /* flags */
},
},
- syscall.SYS_TGKILL: []seccomp.Rule{
+ unix.SYS_TGKILL: []seccomp.Rule{
{
seccomp.EqualTo(uint64(os.Getpid())),
},
},
- syscall.SYS_UTIMENSAT: []seccomp.Rule{
+ unix.SYS_UTIMENSAT: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.EqualTo(0), /* null pathname */
@@ -313,9 +312,9 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.EqualTo(0), /* flags */
},
},
- syscall.SYS_WRITE: {},
+ unix.SYS_WRITE: {},
// For rawfile.NonBlockingWriteIovec.
- syscall.SYS_WRITEV: []seccomp.Rule{
+ unix.SYS_WRITEV: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
@@ -327,313 +326,313 @@ var allowedSyscalls = seccomp.SyscallRules{
// hostInetFilters contains syscalls that are needed by sentry/socket/hostinet.
func hostInetFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
- syscall.SYS_ACCEPT4: []seccomp.Rule{
+ unix.SYS_ACCEPT4: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC),
},
},
- syscall.SYS_BIND: {},
- syscall.SYS_CONNECT: {},
- syscall.SYS_GETPEERNAME: {},
- syscall.SYS_GETSOCKNAME: {},
- syscall.SYS_GETSOCKOPT: []seccomp.Rule{
+ unix.SYS_BIND: {},
+ unix.SYS_CONNECT: {},
+ unix.SYS_GETPEERNAME: {},
+ unix.SYS_GETSOCKNAME: {},
+ unix.SYS_GETSOCKOPT: []seccomp.Rule{
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IP),
- seccomp.EqualTo(syscall.IP_TOS),
+ seccomp.EqualTo(unix.SOL_IP),
+ seccomp.EqualTo(unix.IP_TOS),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IP),
- seccomp.EqualTo(syscall.IP_RECVTOS),
+ seccomp.EqualTo(unix.SOL_IP),
+ seccomp.EqualTo(unix.IP_RECVTOS),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IP),
- seccomp.EqualTo(syscall.IP_PKTINFO),
+ seccomp.EqualTo(unix.SOL_IP),
+ seccomp.EqualTo(unix.IP_PKTINFO),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IP),
- seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR),
+ seccomp.EqualTo(unix.SOL_IP),
+ seccomp.EqualTo(unix.IP_RECVORIGDSTADDR),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IP),
- seccomp.EqualTo(syscall.IP_RECVERR),
+ seccomp.EqualTo(unix.SOL_IP),
+ seccomp.EqualTo(unix.IP_RECVERR),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
- seccomp.EqualTo(syscall.IPV6_TCLASS),
+ seccomp.EqualTo(unix.SOL_IPV6),
+ seccomp.EqualTo(unix.IPV6_TCLASS),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
- seccomp.EqualTo(syscall.IPV6_RECVTCLASS),
+ seccomp.EqualTo(unix.SOL_IPV6),
+ seccomp.EqualTo(unix.IPV6_RECVTCLASS),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
- seccomp.EqualTo(syscall.IPV6_RECVERR),
+ seccomp.EqualTo(unix.SOL_IPV6),
+ seccomp.EqualTo(unix.IPV6_RECVERR),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
- seccomp.EqualTo(syscall.IPV6_V6ONLY),
+ seccomp.EqualTo(unix.SOL_IPV6),
+ seccomp.EqualTo(unix.IPV6_V6ONLY),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(unix.SOL_IPV6),
seccomp.EqualTo(linux.IPV6_RECVORIGDSTADDR),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_ERROR),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_ERROR),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_KEEPALIVE),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_KEEPALIVE),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_SNDBUF),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_SNDBUF),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_RCVBUF),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_RCVBUF),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_REUSEADDR),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_REUSEADDR),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_TYPE),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_TYPE),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_LINGER),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_LINGER),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_TIMESTAMP),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_TIMESTAMP),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_TCP),
- seccomp.EqualTo(syscall.TCP_NODELAY),
+ seccomp.EqualTo(unix.SOL_TCP),
+ seccomp.EqualTo(unix.TCP_NODELAY),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_TCP),
- seccomp.EqualTo(syscall.TCP_INFO),
+ seccomp.EqualTo(unix.SOL_TCP),
+ seccomp.EqualTo(unix.TCP_INFO),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_TCP),
+ seccomp.EqualTo(unix.SOL_TCP),
seccomp.EqualTo(linux.TCP_INQ),
},
},
- syscall.SYS_IOCTL: []seccomp.Rule{
+ unix.SYS_IOCTL: []seccomp.Rule{
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.TIOCOUTQ),
+ seccomp.EqualTo(unix.TIOCOUTQ),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.TIOCINQ),
+ seccomp.EqualTo(unix.TIOCINQ),
},
},
- syscall.SYS_LISTEN: {},
- syscall.SYS_READV: {},
- syscall.SYS_RECVFROM: {},
- syscall.SYS_RECVMSG: {},
- syscall.SYS_SENDMSG: {},
- syscall.SYS_SENDTO: {},
- syscall.SYS_SETSOCKOPT: []seccomp.Rule{
+ unix.SYS_LISTEN: {},
+ unix.SYS_READV: {},
+ unix.SYS_RECVFROM: {},
+ unix.SYS_RECVMSG: {},
+ unix.SYS_SENDMSG: {},
+ unix.SYS_SENDTO: {},
+ unix.SYS_SETSOCKOPT: []seccomp.Rule{
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_SNDBUF),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_SNDBUF),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_RCVBUF),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_RCVBUF),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_REUSEADDR),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_REUSEADDR),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_TIMESTAMP),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_TIMESTAMP),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_TCP),
- seccomp.EqualTo(syscall.TCP_NODELAY),
+ seccomp.EqualTo(unix.SOL_TCP),
+ seccomp.EqualTo(unix.TCP_NODELAY),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_TCP),
+ seccomp.EqualTo(unix.SOL_TCP),
seccomp.EqualTo(linux.TCP_INQ),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IP),
- seccomp.EqualTo(syscall.IP_TOS),
+ seccomp.EqualTo(unix.SOL_IP),
+ seccomp.EqualTo(unix.IP_TOS),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IP),
- seccomp.EqualTo(syscall.IP_RECVTOS),
+ seccomp.EqualTo(unix.SOL_IP),
+ seccomp.EqualTo(unix.IP_RECVTOS),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IP),
- seccomp.EqualTo(syscall.IP_PKTINFO),
+ seccomp.EqualTo(unix.SOL_IP),
+ seccomp.EqualTo(unix.IP_PKTINFO),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IP),
- seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR),
+ seccomp.EqualTo(unix.SOL_IP),
+ seccomp.EqualTo(unix.IP_RECVORIGDSTADDR),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IP),
- seccomp.EqualTo(syscall.IP_RECVERR),
+ seccomp.EqualTo(unix.SOL_IP),
+ seccomp.EqualTo(unix.IP_RECVERR),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
- seccomp.EqualTo(syscall.IPV6_TCLASS),
+ seccomp.EqualTo(unix.SOL_IPV6),
+ seccomp.EqualTo(unix.IPV6_TCLASS),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
- seccomp.EqualTo(syscall.IPV6_RECVTCLASS),
+ seccomp.EqualTo(unix.SOL_IPV6),
+ seccomp.EqualTo(unix.IPV6_RECVTCLASS),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
+ seccomp.EqualTo(unix.SOL_IPV6),
seccomp.EqualTo(linux.IPV6_RECVORIGDSTADDR),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
- seccomp.EqualTo(syscall.IPV6_RECVERR),
+ seccomp.EqualTo(unix.SOL_IPV6),
+ seccomp.EqualTo(unix.IPV6_RECVERR),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_IPV6),
- seccomp.EqualTo(syscall.IPV6_V6ONLY),
+ seccomp.EqualTo(unix.SOL_IPV6),
+ seccomp.EqualTo(unix.IPV6_V6ONLY),
seccomp.MatchAny{},
seccomp.EqualTo(4),
},
},
- syscall.SYS_SHUTDOWN: []seccomp.Rule{
+ unix.SYS_SHUTDOWN: []seccomp.Rule{
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SHUT_RD),
+ seccomp.EqualTo(unix.SHUT_RD),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SHUT_WR),
+ seccomp.EqualTo(unix.SHUT_WR),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SHUT_RDWR),
+ seccomp.EqualTo(unix.SHUT_RDWR),
},
},
- syscall.SYS_SOCKET: []seccomp.Rule{
+ unix.SYS_SOCKET: []seccomp.Rule{
{
- seccomp.EqualTo(syscall.AF_INET),
- seccomp.EqualTo(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(unix.AF_INET),
+ seccomp.EqualTo(unix.SOCK_STREAM | unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC),
seccomp.EqualTo(0),
},
{
- seccomp.EqualTo(syscall.AF_INET),
- seccomp.EqualTo(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(unix.AF_INET),
+ seccomp.EqualTo(unix.SOCK_DGRAM | unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC),
seccomp.EqualTo(0),
},
{
- seccomp.EqualTo(syscall.AF_INET6),
- seccomp.EqualTo(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(unix.AF_INET6),
+ seccomp.EqualTo(unix.SOCK_STREAM | unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC),
seccomp.EqualTo(0),
},
{
- seccomp.EqualTo(syscall.AF_INET6),
- seccomp.EqualTo(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(unix.AF_INET6),
+ seccomp.EqualTo(unix.SOCK_DGRAM | unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC),
seccomp.EqualTo(0),
},
},
- syscall.SYS_WRITEV: {},
+ unix.SYS_WRITEV: {},
}
}
func controlServerFilters(fd int) seccomp.SyscallRules {
return seccomp.SyscallRules{
- syscall.SYS_ACCEPT: []seccomp.Rule{
+ unix.SYS_ACCEPT: []seccomp.Rule{
{
seccomp.EqualTo(fd),
},
},
- syscall.SYS_LISTEN: []seccomp.Rule{
+ unix.SYS_LISTEN: []seccomp.Rule{
{
seccomp.EqualTo(fd),
seccomp.EqualTo(16 /* unet.backlog */),
},
},
- syscall.SYS_GETSOCKOPT: []seccomp.Rule{
+ unix.SYS_GETSOCKOPT: []seccomp.Rule{
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.SOL_SOCKET),
- seccomp.EqualTo(syscall.SO_PEERCRED),
+ seccomp.EqualTo(unix.SOL_SOCKET),
+ seccomp.EqualTo(unix.SO_PEERCRED),
},
},
}
diff --git a/runsc/boot/filter/config_amd64.go b/runsc/boot/filter/config_amd64.go
index cea5613b8..42cb8ed3a 100644
--- a/runsc/boot/filter/config_amd64.go
+++ b/runsc/boot/filter/config_amd64.go
@@ -17,30 +17,29 @@
package filter
import (
- "syscall"
-
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
)
func init() {
- allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{
+ allowedSyscalls[unix.SYS_ARCH_PRCTL] = []seccomp.Rule{
// TODO(b/168828518): No longer used in Go 1.16+.
{seccomp.EqualTo(linux.ARCH_SET_FS)},
}
- allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{
+ allowedSyscalls[unix.SYS_CLONE] = []seccomp.Rule{
// parent_tidptr and child_tidptr are always 0 because neither
// CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used.
{
seccomp.EqualTo(
- syscall.CLONE_VM |
- syscall.CLONE_FS |
- syscall.CLONE_FILES |
- syscall.CLONE_SETTLS |
- syscall.CLONE_SIGHAND |
- syscall.CLONE_SYSVSEM |
- syscall.CLONE_THREAD),
+ unix.CLONE_VM |
+ unix.CLONE_FS |
+ unix.CLONE_FILES |
+ unix.CLONE_SETTLS |
+ unix.CLONE_SIGHAND |
+ unix.CLONE_SYSVSEM |
+ unix.CLONE_THREAD),
seccomp.MatchAny{}, // newsp
seccomp.EqualTo(0), // parent_tidptr
seccomp.EqualTo(0), // child_tidptr
@@ -49,12 +48,12 @@ func init() {
{
// TODO(b/168828518): No longer used in Go 1.16+ (on amd64).
seccomp.EqualTo(
- syscall.CLONE_VM |
- syscall.CLONE_FS |
- syscall.CLONE_FILES |
- syscall.CLONE_SIGHAND |
- syscall.CLONE_SYSVSEM |
- syscall.CLONE_THREAD),
+ unix.CLONE_VM |
+ unix.CLONE_FS |
+ unix.CLONE_FILES |
+ unix.CLONE_SIGHAND |
+ unix.CLONE_SYSVSEM |
+ unix.CLONE_THREAD),
seccomp.MatchAny{}, // newsp
seccomp.EqualTo(0), // parent_tidptr
seccomp.EqualTo(0), // child_tidptr
diff --git a/runsc/boot/filter/config_arm64.go b/runsc/boot/filter/config_arm64.go
index 37313f97f..f162f87ff 100644
--- a/runsc/boot/filter/config_arm64.go
+++ b/runsc/boot/filter/config_arm64.go
@@ -17,21 +17,20 @@
package filter
import (
- "syscall"
-
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/seccomp"
)
func init() {
- allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{
+ allowedSyscalls[unix.SYS_CLONE] = []seccomp.Rule{
{
seccomp.EqualTo(
- syscall.CLONE_VM |
- syscall.CLONE_FS |
- syscall.CLONE_FILES |
- syscall.CLONE_SIGHAND |
- syscall.CLONE_SYSVSEM |
- syscall.CLONE_THREAD),
+ unix.CLONE_VM |
+ unix.CLONE_FS |
+ unix.CLONE_FILES |
+ unix.CLONE_SIGHAND |
+ unix.CLONE_SYSVSEM |
+ unix.CLONE_THREAD),
seccomp.MatchAny{}, // newsp
// These arguments are left uninitialized by the Go
// runtime, so they may be anything (and are unused by
diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go
index 7b8669595..89b66a6da 100644
--- a/runsc/boot/filter/config_profile.go
+++ b/runsc/boot/filter/config_profile.go
@@ -15,19 +15,18 @@
package filter
import (
- "syscall"
-
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/seccomp"
)
// profileFilters returns extra syscalls made by runtime/pprof package.
func profileFilters() seccomp.SyscallRules {
return seccomp.SyscallRules{
- syscall.SYS_OPENAT: []seccomp.Rule{
+ unix.SYS_OPENAT: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC),
+ seccomp.EqualTo(unix.O_RDONLY | unix.O_LARGEFILE | unix.O_CLOEXEC),
},
},
}
diff --git a/runsc/boot/filter/extra_filters_msan.go b/runsc/boot/filter/extra_filters_msan.go
index 209e646a7..41baa78cd 100644
--- a/runsc/boot/filter/extra_filters_msan.go
+++ b/runsc/boot/filter/extra_filters_msan.go
@@ -17,8 +17,7 @@
package filter
import (
- "syscall"
-
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/seccomp"
)
@@ -26,9 +25,9 @@ import (
func instrumentationFilters() seccomp.SyscallRules {
Report("MSAN is enabled: syscall filters less restrictive!")
return seccomp.SyscallRules{
- syscall.SYS_CLONE: {},
- syscall.SYS_MMAP: {},
- syscall.SYS_SCHED_GETAFFINITY: {},
- syscall.SYS_SET_ROBUST_LIST: {},
+ unix.SYS_CLONE: {},
+ unix.SYS_MMAP: {},
+ unix.SYS_SCHED_GETAFFINITY: {},
+ unix.SYS_SET_ROBUST_LIST: {},
}
}
diff --git a/runsc/boot/filter/extra_filters_race.go b/runsc/boot/filter/extra_filters_race.go
index 5b99eb8cd..79b2104f0 100644
--- a/runsc/boot/filter/extra_filters_race.go
+++ b/runsc/boot/filter/extra_filters_race.go
@@ -17,8 +17,7 @@
package filter
import (
- "syscall"
-
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/seccomp"
)
@@ -26,17 +25,17 @@ import (
func instrumentationFilters() seccomp.SyscallRules {
Report("TSAN is enabled: syscall filters less restrictive!")
return seccomp.SyscallRules{
- syscall.SYS_BRK: {},
- syscall.SYS_CLOCK_NANOSLEEP: {},
- syscall.SYS_CLONE: {},
- syscall.SYS_FUTEX: {},
- syscall.SYS_MMAP: {},
- syscall.SYS_MUNLOCK: {},
- syscall.SYS_NANOSLEEP: {},
- syscall.SYS_OPEN: {},
- syscall.SYS_OPENAT: {},
- syscall.SYS_SET_ROBUST_LIST: {},
+ unix.SYS_BRK: {},
+ unix.SYS_CLOCK_NANOSLEEP: {},
+ unix.SYS_CLONE: {},
+ unix.SYS_FUTEX: {},
+ unix.SYS_MMAP: {},
+ unix.SYS_MUNLOCK: {},
+ unix.SYS_NANOSLEEP: {},
+ unix.SYS_OPEN: {},
+ unix.SYS_OPENAT: {},
+ unix.SYS_SET_ROBUST_LIST: {},
// Used within glibc's malloc.
- syscall.SYS_TIME: {},
+ unix.SYS_TIME: {},
}
}
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 2b0d2cd51..77f632bb9 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -20,9 +20,9 @@ import (
"sort"
"strconv"
"strings"
- "syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fd"
@@ -312,11 +312,11 @@ func setupContainerFS(ctx context.Context, conf *config.Config, mntr *containerM
}
func adjustDirentCache(k *kernel.Kernel) error {
- var hl syscall.Rlimit
- if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &hl); err != nil {
+ var hl unix.Rlimit
+ if err := unix.Getrlimit(unix.RLIMIT_NOFILE, &hl); err != nil {
return fmt.Errorf("getting RLIMIT_NOFILE: %v", err)
}
- if int64(hl.Cur) != syscall.RLIM_INFINITY {
+ if hl.Cur != unix.RLIM_INFINITY {
newSize := hl.Cur / 2
if newSize < gofer.DefaultDirentCacheSize {
log.Infof("Setting gofer dirent cache size to %d", newSize)
@@ -844,10 +844,10 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Confi
// than simply printed to the logs for the 'runsc boot' command.
//
// We check the error message string rather than type because the
- // actual error types (syscall.EIO, syscall.EPIPE) are lost by file system
+ // actual error types (unix.EIO, unix.EPIPE) are lost by file system
// implementation (e.g. p9).
// TODO(gvisor.dev/issue/1765): Remove message when bug is resolved.
- if strings.Contains(err.Error(), syscall.EIO.Error()) || strings.Contains(err.Error(), syscall.EPIPE.Error()) {
+ if strings.Contains(err.Error(), unix.EIO.Error()) || strings.Contains(err.Error(), unix.EPIPE.Error()) {
return fmt.Errorf("%v: %s", err, specutils.FaqErrorMsg("memlock", "you may be encountering a Linux kernel bug"))
}
return err
diff --git a/runsc/boot/limits.go b/runsc/boot/limits.go
index ce62236e5..3d2b3506d 100644
--- a/runsc/boot/limits.go
+++ b/runsc/boot/limits.go
@@ -16,9 +16,9 @@ package boot
import (
"fmt"
- "syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/limits"
"gvisor.dev/gvisor/pkg/sync"
@@ -104,9 +104,9 @@ func (d *defs) initDefaults() error {
// Read host limits that directly affect the sandbox and adjust the defaults
// based on them.
- for _, res := range []int{syscall.RLIMIT_FSIZE, syscall.RLIMIT_NOFILE} {
- var hl syscall.Rlimit
- if err := syscall.Getrlimit(res, &hl); err != nil {
+ for _, res := range []int{unix.RLIMIT_FSIZE, unix.RLIMIT_NOFILE} {
+ var hl unix.Rlimit
+ if err := unix.Getrlimit(res, &hl); err != nil {
return err
}
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index b77b4762e..3121ca6eb 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -19,7 +19,6 @@ import (
"math/rand"
"os"
"reflect"
- "syscall"
"testing"
"time"
@@ -78,7 +77,7 @@ func testSpec() *specs.Spec {
// sandbox side of the connection, and a function that when called will stop the
// gofer.
func startGofer(root string) (int, func(), error) {
- fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0)
+ fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return 0, nil, err
}
@@ -86,8 +85,8 @@ func startGofer(root string) (int, func(), error) {
socket, err := unet.NewSocket(goferEnd)
if err != nil {
- syscall.Close(sandboxEnd)
- syscall.Close(goferEnd)
+ unix.Close(sandboxEnd)
+ unix.Close(goferEnd)
return 0, nil, fmt.Errorf("error creating server on FD %d: %v", goferEnd, err)
}
at, err := fsgofer.NewAttachPoint(root, fsgofer.Config{ROMount: true})
diff --git a/runsc/boot/network.go b/runsc/boot/network.go
index 3d3a813df..7e627e4c6 100644
--- a/runsc/boot/network.go
+++ b/runsc/boot/network.go
@@ -19,8 +19,8 @@ import (
"net"
"runtime"
"strings"
- "syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
@@ -195,7 +195,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct
for j := 0; j < link.NumChannels; j++ {
// Copy the underlying FD.
oldFD := args.FilePayload.Files[fdOffset].Fd()
- newFD, err := syscall.Dup(int(oldFD))
+ newFD, err := unix.Dup(int(oldFD))
if err != nil {
return fmt.Errorf("failed to dup FD %v: %v", oldFD, err)
}
diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go
index ac9e4e3a8..438b7ef3e 100644
--- a/runsc/cgroup/cgroup.go
+++ b/runsc/cgroup/cgroup.go
@@ -27,7 +27,6 @@ import (
"path/filepath"
"strconv"
"strings"
- "syscall"
"time"
"github.com/cenkalti/backoff"
@@ -111,7 +110,7 @@ func setValue(path, name, data string) error {
err := ioutil.WriteFile(fullpath, []byte(data), 0700)
if err == nil {
return nil
- } else if !errors.Is(err, syscall.EINTR) {
+ } else if !errors.Is(err, unix.EINTR) {
return err
}
}
@@ -161,7 +160,7 @@ func fillFromAncestor(path string) (string, error) {
err := ioutil.WriteFile(path, []byte(val), 0700)
if err == nil {
break
- } else if !errors.Is(err, syscall.EINTR) {
+ } else if !errors.Is(err, unix.EINTR) {
return "", err
}
}
@@ -337,7 +336,7 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error {
c.Own[key] = true
if err := os.MkdirAll(path, 0755); err != nil {
- if cfg.optional && errors.Is(err, syscall.EROFS) {
+ if cfg.optional && errors.Is(err, unix.EROFS) {
log.Infof("Skipping cgroup %q", key)
continue
}
@@ -370,7 +369,7 @@ func (c *Cgroup) Uninstall() error {
defer cancel()
b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
fn := func() error {
- err := syscall.Rmdir(path)
+ err := unix.Rmdir(path)
if os.IsNotExist(err) {
return nil
}
diff --git a/runsc/cli/BUILD b/runsc/cli/BUILD
index 32cce2a18..f1e3cce68 100644
--- a/runsc/cli/BUILD
+++ b/runsc/cli/BUILD
@@ -18,5 +18,6 @@ go_library(
"//runsc/flag",
"//runsc/specutils",
"@com_github_google_subcommands//:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/runsc/cli/main.go b/runsc/cli/main.go
index bf6928941..a3c515f4b 100644
--- a/runsc/cli/main.go
+++ b/runsc/cli/main.go
@@ -23,10 +23,10 @@ import (
"os"
"os/signal"
"runtime"
- "syscall"
"time"
"github.com/google/subcommands"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/platform"
@@ -198,7 +198,7 @@ func Main(version string) {
// want with them. Since Docker and Containerd both eat boot's stderr, we
// dup our stderr to the provided log FD so that panics will appear in the
// logs, rather than just disappear.
- if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil {
+ if err := unix.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil {
cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err)
}
} else if conf.AlsoLogToStderr {
@@ -227,11 +227,11 @@ func Main(version string) {
// SIGTERM is sent to all processes if a test exceeds its
// timeout and this case is handled by syscall_test_runner.
log.Warningf("Block the TERM signal. This is only safe in tests!")
- signal.Ignore(syscall.SIGTERM)
+ signal.Ignore(unix.SIGTERM)
}
// Call the subcommand and pass in the configuration.
- var ws syscall.WaitStatus
+ var ws unix.WaitStatus
subcmdCode := subcommands.Execute(context.Background(), conf, &ws)
if subcmdCode == subcommands.ExitSuccess {
log.Infof("Exiting with status: %v", ws)
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index e3e289da3..2c3b4058b 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -77,6 +77,7 @@ go_test(
"delete_test.go",
"exec_test.go",
"gofer_test.go",
+ "mitigate_test.go",
],
data = [
"//runsc",
@@ -91,6 +92,8 @@ go_test(
"//pkg/urpc",
"//runsc/config",
"//runsc/container",
+ "//runsc/mitigate",
+ "//runsc/mitigate/mock",
"//runsc/specutils",
"@com_github_google_go_cmp//cmp:go_default_library",
"@com_github_google_go_cmp//cmp/cmpopts:go_default_library",
diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go
index 2c92e3067..a14249641 100644
--- a/runsc/cmd/boot.go
+++ b/runsc/cmd/boot.go
@@ -19,7 +19,6 @@ import (
"os"
"runtime/debug"
"strings"
- "syscall"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
@@ -259,8 +258,8 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
ws := l.WaitExit()
log.Infof("application exiting with %+v", ws)
- waitStatus := args[1].(*syscall.WaitStatus)
- *waitStatus = syscall.WaitStatus(ws.Status())
+ waitStatus := args[1].(*unix.WaitStatus)
+ *waitStatus = unix.WaitStatus(ws.Status())
l.Destroy()
return subcommands.ExitSuccess
}
diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go
index 124198239..a9dbe86de 100644
--- a/runsc/cmd/checkpoint.go
+++ b/runsc/cmd/checkpoint.go
@@ -18,9 +18,9 @@ import (
"context"
"os"
"path/filepath"
- "syscall"
"github.com/google/subcommands"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/container"
@@ -73,7 +73,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa
id := f.Arg(0)
conf := args[0].(*config.Config)
- waitStatus := args[1].(*syscall.WaitStatus)
+ waitStatus := args[1].(*unix.WaitStatus)
cont, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
diff --git a/runsc/cmd/chroot.go b/runsc/cmd/chroot.go
index 189244765..e988247da 100644
--- a/runsc/cmd/chroot.go
+++ b/runsc/cmd/chroot.go
@@ -18,8 +18,8 @@ import (
"fmt"
"os"
"path/filepath"
- "syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -49,11 +49,11 @@ func pivotRoot(root string) error {
// will be moved to "/" too. The parent mount of the old_root will be
// new_root, so after umounting the old_root, we will see only
// the new_root in "/".
- if err := syscall.PivotRoot(".", "."); err != nil {
+ if err := unix.PivotRoot(".", "."); err != nil {
return fmt.Errorf("pivot_root failed, make sure that the root mount has a parent: %v", err)
}
- if err := syscall.Unmount(".", syscall.MNT_DETACH); err != nil {
+ if err := unix.Unmount(".", unix.MNT_DETACH); err != nil {
return fmt.Errorf("error umounting the old root file system: %v", err)
}
return nil
@@ -70,26 +70,26 @@ func setUpChroot(pidns bool) error {
// Convert all shared mounts into slave to be sure that nothing will be
// propagated outside of our namespace.
- if err := syscall.Mount("", "/", "", syscall.MS_SLAVE|syscall.MS_REC, ""); err != nil {
+ if err := unix.Mount("", "/", "", unix.MS_SLAVE|unix.MS_REC, ""); err != nil {
return fmt.Errorf("error converting mounts: %v", err)
}
- if err := syscall.Mount("runsc-root", chroot, "tmpfs", syscall.MS_NOSUID|syscall.MS_NODEV|syscall.MS_NOEXEC, ""); err != nil {
+ if err := unix.Mount("runsc-root", chroot, "tmpfs", unix.MS_NOSUID|unix.MS_NODEV|unix.MS_NOEXEC, ""); err != nil {
return fmt.Errorf("error mounting tmpfs in choot: %v", err)
}
if pidns {
- flags := uint32(syscall.MS_NOSUID | syscall.MS_NODEV | syscall.MS_NOEXEC | syscall.MS_RDONLY)
+ flags := uint32(unix.MS_NOSUID | unix.MS_NODEV | unix.MS_NOEXEC | unix.MS_RDONLY)
if err := mountInChroot(chroot, "proc", "/proc", "proc", flags); err != nil {
return fmt.Errorf("error mounting proc in chroot: %v", err)
}
} else {
- if err := mountInChroot(chroot, "/proc", "/proc", "bind", syscall.MS_BIND|syscall.MS_RDONLY|syscall.MS_REC); err != nil {
+ if err := mountInChroot(chroot, "/proc", "/proc", "bind", unix.MS_BIND|unix.MS_RDONLY|unix.MS_REC); err != nil {
return fmt.Errorf("error mounting proc in chroot: %v", err)
}
}
- if err := syscall.Mount("", chroot, "", syscall.MS_REMOUNT|syscall.MS_RDONLY|syscall.MS_BIND, ""); err != nil {
+ if err := unix.Mount("", chroot, "", unix.MS_REMOUNT|unix.MS_RDONLY|unix.MS_BIND, ""); err != nil {
return fmt.Errorf("error remounting chroot in read-only: %v", err)
}
diff --git a/runsc/cmd/cmd.go b/runsc/cmd/cmd.go
index f1a4887ef..4dd55cc33 100644
--- a/runsc/cmd/cmd.go
+++ b/runsc/cmd/cmd.go
@@ -19,9 +19,9 @@ import (
"fmt"
"runtime"
"strconv"
- "syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/specutils"
)
@@ -71,7 +71,7 @@ func setCapsAndCallSelf(args []string, caps *specs.LinuxCapabilities) error {
binPath := specutils.ExePath
log.Infof("Execve %q again, bye!", binPath)
- err := syscall.Exec(binPath, args, []string{})
+ err := unix.Exec(binPath, args, []string{})
return fmt.Errorf("error executing %s: %v", binPath, err)
}
@@ -83,16 +83,16 @@ func callSelfAsNobody(args []string) error {
const nobody = 65534
- if _, _, err := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(nobody), 0, 0); err != 0 {
+ if _, _, err := unix.RawSyscall(unix.SYS_SETGID, uintptr(nobody), 0, 0); err != 0 {
return fmt.Errorf("error setting uid: %v", err)
}
- if _, _, err := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(nobody), 0, 0); err != 0 {
+ if _, _, err := unix.RawSyscall(unix.SYS_SETUID, uintptr(nobody), 0, 0); err != 0 {
return fmt.Errorf("error setting gid: %v", err)
}
binPath := specutils.ExePath
log.Infof("Execve %q again, bye!", binPath)
- err := syscall.Exec(binPath, args, []string{})
+ err := unix.Exec(binPath, args, []string{})
return fmt.Errorf("error executing %s: %v", binPath, err)
}
diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go
index b84142b0d..6212ffb2e 100644
--- a/runsc/cmd/debug.go
+++ b/runsc/cmd/debug.go
@@ -21,10 +21,10 @@ import (
"strconv"
"strings"
"sync"
- "syscall"
"time"
"github.com/google/subcommands"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/runsc/config"
@@ -135,7 +135,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Perform synchronous actions.
if d.signal > 0 {
log.Infof("Sending signal %d to process: %d", d.signal, c.Sandbox.Pid)
- if err := syscall.Kill(c.Sandbox.Pid, syscall.Signal(d.signal)); err != nil {
+ if err := unix.Kill(c.Sandbox.Pid, unix.Signal(d.signal)); err != nil {
return Errorf("failed to send signal %d to processs %d", d.signal, c.Sandbox.Pid)
}
}
@@ -317,7 +317,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
wg.Wait()
}()
signals := make(chan os.Signal, 1)
- signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
+ signal.Notify(signals, unix.SIGTERM, unix.SIGINT)
select {
case <-readyChan:
break // Safe to proceed.
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
index 8a8d9f752..22c1dfeb8 100644
--- a/runsc/cmd/do.go
+++ b/runsc/cmd/do.go
@@ -26,10 +26,10 @@ import (
"path/filepath"
"strconv"
"strings"
- "syscall"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/container"
@@ -86,7 +86,7 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su
}
conf := args[0].(*config.Config)
- waitStatus := args[1].(*syscall.WaitStatus)
+ waitStatus := args[1].(*unix.WaitStatus)
if conf.Rootless {
if err := specutils.MaybeRunAsRoot(); err != nil {
@@ -225,7 +225,7 @@ func resolvePath(path string) (string, error) {
return "", fmt.Errorf("resolving %q: %v", path, err)
}
path = filepath.Clean(path)
- if err := syscall.Access(path, 0); err != nil {
+ if err := unix.Access(path, 0); err != nil {
return "", fmt.Errorf("unable to access %q: %v", path, err)
}
return path, nil
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index e9726401a..242d474b8 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -24,11 +24,11 @@ import (
"path/filepath"
"strconv"
"strings"
- "syscall"
"time"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -110,7 +110,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
if err != nil {
Fatalf("parsing process spec: %v", err)
}
- waitStatus := args[1].(*syscall.WaitStatus)
+ waitStatus := args[1].(*unix.WaitStatus)
c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{})
if err != nil {
@@ -149,7 +149,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
return ex.exec(c, e, waitStatus)
}
-func (ex *Exec) exec(c *container.Container, e *control.ExecArgs, waitStatus *syscall.WaitStatus) subcommands.ExitStatus {
+func (ex *Exec) exec(c *container.Container, e *control.ExecArgs, waitStatus *unix.WaitStatus) subcommands.ExitStatus {
// Start the new process and get its pid.
pid, err := c.Execute(e)
if err != nil {
@@ -189,7 +189,7 @@ func (ex *Exec) exec(c *container.Container, e *control.ExecArgs, waitStatus *sy
return subcommands.ExitSuccess
}
-func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.ExitStatus {
+func (ex *Exec) execChildAndWait(waitStatus *unix.WaitStatus) subcommands.ExitStatus {
var args []string
for _, a := range os.Args[1:] {
if !strings.Contains(a, "detach") {
@@ -233,7 +233,7 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi
cmd.Stdin = tty
cmd.Stdout = tty
cmd.Stderr = tty
- cmd.SysProcAttr = &syscall.SysProcAttr{
+ cmd.SysProcAttr = &unix.SysProcAttr{
Setsid: true,
Setctty: true,
// The Ctty FD must be the FD in the child process's FD
@@ -263,7 +263,7 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi
}
return pid == cmd.Process.Pid, nil
}
- if pe, ok := err.(*os.PathError); !ok || pe.Err != syscall.ENOENT {
+ if pe, ok := err.(*os.PathError); !ok || pe.Err != unix.ENOENT {
return false, err
}
// No file yet, continue to wait...
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 371fcc0ae..639b2219c 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -21,7 +21,6 @@ import (
"os"
"path/filepath"
"strings"
- "syscall"
"github.com/google/subcommands"
specs "github.com/opencontainers/runtime-spec/specs-go"
@@ -149,16 +148,16 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// fsgofer should run with a umask of 0, because we want to preserve file
// modes exactly as sent by the sandbox, which will have applied its own umask.
- syscall.Umask(0)
+ unix.Umask(0)
if err := fsgofer.OpenProcSelfFD(); err != nil {
Fatalf("failed to open /proc/self/fd: %v", err)
}
- if err := syscall.Chroot(root); err != nil {
+ if err := unix.Chroot(root); err != nil {
Fatalf("failed to chroot to %q: %v", root, err)
}
- if err := syscall.Chdir("/"); err != nil {
+ if err := unix.Chdir("/"); err != nil {
Fatalf("changing working dir: %v", err)
}
log.Infof("Process chroot'd to %q", root)
@@ -166,7 +165,8 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
// Start with root mount, then add any other additional mount as needed.
ats := make([]p9.Attacher, 0, len(spec.Mounts)+1)
ap, err := fsgofer.NewAttachPoint("/", fsgofer.Config{
- ROMount: spec.Root.Readonly || conf.Overlay,
+ ROMount: spec.Root.Readonly || conf.Overlay,
+ EnableXattr: conf.Verity,
})
if err != nil {
Fatalf("creating attach point: %v", err)
@@ -178,8 +178,9 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
for _, m := range spec.Mounts {
if specutils.Is9PMount(m) {
cfg := fsgofer.Config{
- ROMount: isReadonlyMount(m.Options) || conf.Overlay,
- HostUDS: conf.FSGoferHostUDS,
+ ROMount: isReadonlyMount(m.Options) || conf.Overlay,
+ HostUDS: conf.FSGoferHostUDS,
+ EnableXattr: conf.Verity,
}
ap, err := fsgofer.NewAttachPoint(m.Destination, cfg)
if err != nil {
@@ -262,7 +263,7 @@ func isReadonlyMount(opts []string) bool {
func setupRootFS(spec *specs.Spec, conf *config.Config) error {
// Convert all shared mounts into slaves to be sure that nothing will be
// propagated outside of our namespace.
- if err := syscall.Mount("", "/", "", syscall.MS_SLAVE|syscall.MS_REC, ""); err != nil {
+ if err := unix.Mount("", "/", "", unix.MS_SLAVE|unix.MS_REC, ""); err != nil {
Fatalf("error converting mounts: %v", err)
}
@@ -274,30 +275,30 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error {
//
// We need a directory to construct a new root and we know that
// runsc can't start without /proc, so we can use it for this.
- flags := uintptr(syscall.MS_NOSUID | syscall.MS_NODEV | syscall.MS_NOEXEC)
- if err := syscall.Mount("runsc-root", "/proc", "tmpfs", flags, ""); err != nil {
+ flags := uintptr(unix.MS_NOSUID | unix.MS_NODEV | unix.MS_NOEXEC)
+ if err := unix.Mount("runsc-root", "/proc", "tmpfs", flags, ""); err != nil {
Fatalf("error mounting tmpfs: %v", err)
}
// Prepare tree structure for pivot_root(2).
os.Mkdir("/proc/proc", 0755)
os.Mkdir("/proc/root", 0755)
- if err := syscall.Mount("runsc-proc", "/proc/proc", "proc", flags|syscall.MS_RDONLY, ""); err != nil {
+ if err := unix.Mount("runsc-proc", "/proc/proc", "proc", flags|unix.MS_RDONLY, ""); err != nil {
Fatalf("error mounting proc: %v", err)
}
root = "/proc/root"
}
// Mount root path followed by submounts.
- if err := syscall.Mount(spec.Root.Path, root, "bind", syscall.MS_BIND|syscall.MS_REC, ""); err != nil {
+ if err := unix.Mount(spec.Root.Path, root, "bind", unix.MS_BIND|unix.MS_REC, ""); err != nil {
return fmt.Errorf("mounting root on root (%q) err: %v", root, err)
}
- flags := uint32(syscall.MS_SLAVE | syscall.MS_REC)
+ flags := uint32(unix.MS_SLAVE | unix.MS_REC)
if spec.Linux != nil && spec.Linux.RootfsPropagation != "" {
flags = specutils.PropOptionsToFlags([]string{spec.Linux.RootfsPropagation})
}
- if err := syscall.Mount("", root, "", uintptr(flags), ""); err != nil {
+ if err := unix.Mount("", root, "", uintptr(flags), ""); err != nil {
return fmt.Errorf("mounting root (%q) with flags: %#x, err: %v", root, flags, err)
}
@@ -323,8 +324,8 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error {
// If root is a mount point but not read-only, we can change mount options
// to make it read-only for extra safety.
log.Infof("Remounting root as readonly: %q", root)
- flags := uintptr(syscall.MS_BIND | syscall.MS_REMOUNT | syscall.MS_RDONLY | syscall.MS_REC)
- if err := syscall.Mount(root, root, "bind", flags, ""); err != nil {
+ flags := uintptr(unix.MS_BIND | unix.MS_REMOUNT | unix.MS_RDONLY | unix.MS_REC)
+ if err := unix.Mount(root, root, "bind", flags, ""); err != nil {
return fmt.Errorf("remounting root as read-only with source: %q, target: %q, flags: %#x, err: %v", root, root, flags, err)
}
}
@@ -354,10 +355,10 @@ func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error {
return fmt.Errorf("resolving symlinks to %q: %v", m.Destination, err)
}
- flags := specutils.OptionsToFlags(m.Options) | syscall.MS_BIND
+ flags := specutils.OptionsToFlags(m.Options) | unix.MS_BIND
if conf.Overlay {
// Force mount read-only if writes are not going to be sent to it.
- flags |= syscall.MS_RDONLY
+ flags |= unix.MS_RDONLY
}
log.Infof("Mounting src: %q, dst: %q, flags: %#x", m.Source, dst, flags)
@@ -368,7 +369,7 @@ func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error {
// Set propagation options that cannot be set together with other options.
flags = specutils.PropOptionsToFlags(m.Options)
if flags != 0 {
- if err := syscall.Mount("", dst, "", uintptr(flags), ""); err != nil {
+ if err := unix.Mount("", dst, "", uintptr(flags), ""); err != nil {
return fmt.Errorf("mount dst: %q, flags: %#x, err: %v", dst, flags, err)
}
}
@@ -469,8 +470,8 @@ func adjustMountOptions(conf *config.Config, path string, opts []string) ([]stri
copy(rv, opts)
if conf.OverlayfsStaleRead {
- statfs := syscall.Statfs_t{}
- if err := syscall.Statfs(path, &statfs); err != nil {
+ statfs := unix.Statfs_t{}
+ if err := unix.Statfs(path, &statfs); err != nil {
return nil, err
}
if statfs.Type == unix.OVERLAYFS_SUPER_MAGIC {
diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go
index e0df39266..239fc7ac2 100644
--- a/runsc/cmd/kill.go
+++ b/runsc/cmd/kill.go
@@ -19,7 +19,6 @@ import (
"fmt"
"strconv"
"strings"
- "syscall"
"github.com/google/subcommands"
"golang.org/x/sys/unix"
@@ -99,10 +98,10 @@ func (k *Kill) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
return subcommands.ExitSuccess
}
-func parseSignal(s string) (syscall.Signal, error) {
+func parseSignal(s string) (unix.Signal, error) {
n, err := strconv.Atoi(s)
if err == nil {
- sig := syscall.Signal(n)
+ sig := unix.Signal(n)
for _, msig := range signalMap {
if sig == msig {
return sig, nil
@@ -116,7 +115,7 @@ func parseSignal(s string) (syscall.Signal, error) {
return -1, fmt.Errorf("unknown signal %q", s)
}
-var signalMap = map[string]syscall.Signal{
+var signalMap = map[string]unix.Signal{
"ABRT": unix.SIGABRT,
"ALRM": unix.SIGALRM,
"BUS": unix.SIGBUS,
diff --git a/runsc/cmd/mitigate.go b/runsc/cmd/mitigate.go
index 822af1917..fddf0e0dd 100644
--- a/runsc/cmd/mitigate.go
+++ b/runsc/cmd/mitigate.go
@@ -16,6 +16,8 @@ package cmd
import (
"context"
+ "fmt"
+ "io/ioutil"
"github.com/google/subcommands"
"gvisor.dev/gvisor/pkg/log"
@@ -23,9 +25,23 @@ import (
"gvisor.dev/gvisor/runsc/mitigate"
)
+const (
+ // cpuInfo is the path used to parse CPU info.
+ cpuInfo = "/proc/cpuinfo"
+ // allPossibleCPUs is the path used to enable CPUs.
+ allPossibleCPUs = "/sys/devices/system/cpu/possible"
+)
+
// Mitigate implements subcommands.Command for the "mitigate" command.
type Mitigate struct {
- mitigate mitigate.Mitigate
+ // Run the command without changing the underlying system.
+ dryRun bool
+ // Reverse mitigate by turning on all CPU cores.
+ reverse bool
+ // Path to file to read to create CPUSet.
+ path string
+ // Callback to check if a given thread is vulnerable.
+ vulnerable func(other mitigate.Thread) bool
}
// Name implements subcommands.command.name.
@@ -38,14 +54,19 @@ func (*Mitigate) Synopsis() string {
return "mitigate mitigates the underlying system against side channel attacks"
}
-// Usage implements subcommands.Command.Usage.
-func (m *Mitigate) Usage() string {
- return m.mitigate.Usage()
+// Usage implments Usage for cmd.Mitigate.
+func (m Mitigate) Usage() string {
+ return `mitigate [flags]
+
+mitigate mitigates a system to the "MDS" vulnerability by implementing a manual shutdown of SMT. The command checks /proc/cpuinfo for cpus having the MDS vulnerability, and if found, shutdown all but one CPU per hyperthread pair via /sys/devices/system/cpu/cpu{N}/online. CPUs can be restored by writing "2" to each file in /sys/devices/system/cpu/cpu{N}/online or performing a system reboot.
+
+The command can be reversed with --reverse, which reads the total CPUs from /sys/devices/system/cpu/possible and enables all with /sys/devices/system/cpu/cpu{N}/online.`
}
-// SetFlags implements subcommands.Command.SetFlags.
+// SetFlags sets flags for the command Mitigate.
func (m *Mitigate) SetFlags(f *flag.FlagSet) {
- m.mitigate.SetFlags(f)
+ f.BoolVar(&m.dryRun, "dryrun", false, "run the command without changing system")
+ f.BoolVar(&m.reverse, "reverse", false, "reverse mitigate by enabling all CPUs")
}
// Execute implements subcommands.Command.Execute.
@@ -55,10 +76,97 @@ func (m *Mitigate) Execute(_ context.Context, f *flag.FlagSet, args ...interface
return subcommands.ExitUsageError
}
- if err := m.mitigate.Execute(); err != nil {
+ m.path = cpuInfo
+ if m.reverse {
+ m.path = allPossibleCPUs
+ }
+
+ m.vulnerable = func(other mitigate.Thread) bool {
+ return other.IsVulnerable()
+ }
+
+ if _, err := m.doExecute(); err != nil {
log.Warningf("Execute failed: %v", err)
return subcommands.ExitFailure
}
return subcommands.ExitSuccess
}
+
+// Execute executes the Mitigate command.
+func (m *Mitigate) doExecute() (mitigate.CPUSet, error) {
+ if m.dryRun {
+ log.Infof("Running with DryRun. No cpu settings will be changed.")
+ }
+ if m.reverse {
+ data, err := ioutil.ReadFile(m.path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read %s: %v", m.path, err)
+ }
+
+ set, err := m.doReverse(data)
+ if err != nil {
+ return nil, fmt.Errorf("reverse operation failed: %v", err)
+ }
+ return set, nil
+ }
+
+ data, err := ioutil.ReadFile(m.path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read %s: %v", m.path, err)
+ }
+ set, err := m.doMitigate(data)
+ if err != nil {
+ return nil, fmt.Errorf("mitigate operation failed: %v", err)
+ }
+ return set, nil
+}
+
+func (m *Mitigate) doMitigate(data []byte) (mitigate.CPUSet, error) {
+ set, err := mitigate.NewCPUSet(data, m.vulnerable)
+ if err != nil {
+ return nil, err
+ }
+
+ log.Infof("Mitigate found the following CPUs...")
+ log.Infof("%s", set)
+
+ disableList := set.GetShutdownList()
+ log.Infof("Disabling threads on thread pairs.")
+ for _, t := range disableList {
+ log.Infof("Disable thread: %s", t)
+ if m.dryRun {
+ continue
+ }
+ if err := t.Disable(); err != nil {
+ return nil, fmt.Errorf("error disabling thread: %s err: %v", t, err)
+ }
+ }
+ log.Infof("Shutdown successful.")
+ return set, nil
+}
+
+func (m *Mitigate) doReverse(data []byte) (mitigate.CPUSet, error) {
+ set, err := mitigate.NewCPUSetFromPossible(data)
+ if err != nil {
+ return nil, err
+ }
+
+ log.Infof("Reverse mitigate found the following CPUs...")
+ log.Infof("%s", set)
+
+ enableList := set.GetRemainingList()
+
+ log.Infof("Enabling all CPUs...")
+ for _, t := range enableList {
+ log.Infof("Enabling thread: %s", t)
+ if m.dryRun {
+ continue
+ }
+ if err := t.Enable(); err != nil {
+ return nil, fmt.Errorf("error enabling thread: %s err: %v", t, err)
+ }
+ }
+ log.Infof("Enable successful.")
+ return set, nil
+}
diff --git a/runsc/cmd/mitigate_test.go b/runsc/cmd/mitigate_test.go
new file mode 100644
index 000000000..163fece42
--- /dev/null
+++ b/runsc/cmd/mitigate_test.go
@@ -0,0 +1,169 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/runsc/mitigate"
+ "gvisor.dev/gvisor/runsc/mitigate/mock"
+)
+
+type executeTestCase struct {
+ name string
+ mitigateData string
+ mitigateError error
+ mitigateCPU int
+ reverseData string
+ reverseError error
+ reverseCPU int
+}
+
+func TestExecute(t *testing.T) {
+
+ partial := `processor : 1
+vendor_id : AuthenticAMD
+cpu family : 23
+model : 49
+model name : AMD EPYC 7B12
+physical id : 0
+bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass
+power management:
+`
+
+ for _, tc := range []executeTestCase{
+ {
+ name: "CascadeLake4",
+ mitigateData: mock.CascadeLake4.MakeCPUString(),
+ mitigateCPU: 2,
+ reverseData: mock.CascadeLake4.MakeSysPossibleString(),
+ reverseCPU: 4,
+ },
+ {
+ name: "Empty",
+ mitigateData: "",
+ mitigateError: fmt.Errorf(`mitigate operation failed: no cpus found for: ""`),
+ reverseData: "",
+ reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from possible: ""`),
+ },
+ {
+ name: "Partial",
+ mitigateData: `processor : 0
+vendor_id : AuthenticAMD
+cpu family : 23
+model : 49
+model name : AMD EPYC 7B12
+physical id : 0
+core id : 0
+cpu cores : 1
+bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass
+power management::84
+
+` + partial,
+ mitigateError: fmt.Errorf(`mitigate operation failed: failed to match key "core id": %q`, partial),
+ reverseData: "1-",
+ reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from possible: %q`, "1-"),
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ m := &Mitigate{
+ dryRun: true,
+ vulnerable: func(other mitigate.Thread) bool {
+ return other.IsVulnerable()
+ },
+ }
+ m.doExecuteTest(t, "Mitigate", tc.mitigateData, tc.mitigateCPU, tc.mitigateError)
+
+ m.reverse = true
+ m.doExecuteTest(t, "Reverse", tc.reverseData, tc.reverseCPU, tc.reverseError)
+ })
+ }
+}
+
+func TestExecuteSmoke(t *testing.T) {
+ smokeMitigate, err := ioutil.ReadFile(cpuInfo)
+ if err != nil {
+ t.Fatalf("Failed to read %s: %v", cpuInfo, err)
+ }
+
+ m := &Mitigate{
+ dryRun: true,
+ vulnerable: func(other mitigate.Thread) bool {
+ return other.IsVulnerable()
+ },
+ }
+
+ m.doExecuteTest(t, "Mitigate", string(smokeMitigate), 0, nil)
+
+ smokeReverse, err := ioutil.ReadFile(allPossibleCPUs)
+ if err != nil {
+ t.Fatalf("Failed to read %s: %v", allPossibleCPUs, err)
+ }
+
+ m.reverse = true
+ m.doExecuteTest(t, "Reverse", string(smokeReverse), 0, nil)
+}
+
+// doExecuteTest runs Execute with the mitigate operation and reverse operation.
+func (m *Mitigate) doExecuteTest(t *testing.T, name, data string, want int, wantErr error) {
+ t.Run(name, func(t *testing.T) {
+ file, err := ioutil.TempFile("", "outfile.txt")
+ if err != nil {
+ t.Fatalf("Failed to create tmpfile: %v", err)
+ }
+ defer os.Remove(file.Name())
+
+ if _, err := file.WriteString(data); err != nil {
+ t.Fatalf("Failed to write to file: %v", err)
+ }
+
+ // Set fields for mitigate and dryrun to keep test hermetic.
+ m.path = file.Name()
+
+ set, err := m.doExecute()
+ if err = checkErr(wantErr, err); err != nil {
+ t.Fatalf("Mitigate error mismatch: %v", err)
+ }
+
+ // case where test should end in error or we don't care
+ // about how many cpus are returned.
+ if wantErr != nil || want < 1 {
+ return
+ }
+ got := len(set.GetRemainingList())
+ if want != got {
+ t.Fatalf("Failed wrong number of remaining CPUs: want %d, got %d", want, got)
+ }
+
+ })
+}
+
+// checkErr checks error for equality.
+func checkErr(want, got error) error {
+ switch {
+ case want == nil && got == nil:
+ case want != nil && got == nil:
+ fallthrough
+ case want == nil && got != nil:
+ fallthrough
+ case want.Error() != strings.Trim(got.Error(), " "):
+ return fmt.Errorf("got: %v want: %v", got, want)
+ }
+ return nil
+}
diff --git a/runsc/cmd/restore.go b/runsc/cmd/restore.go
index 096ec814c..b21f05921 100644
--- a/runsc/cmd/restore.go
+++ b/runsc/cmd/restore.go
@@ -17,9 +17,9 @@ package cmd
import (
"context"
"path/filepath"
- "syscall"
"github.com/google/subcommands"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/container"
"gvisor.dev/gvisor/runsc/flag"
@@ -78,7 +78,7 @@ func (r *Restore) Execute(_ context.Context, f *flag.FlagSet, args ...interface{
id := f.Arg(0)
conf := args[0].(*config.Config)
- waitStatus := args[1].(*syscall.WaitStatus)
+ waitStatus := args[1].(*unix.WaitStatus)
if conf.Rootless {
return Errorf("Rootless mode not supported with %q", r.Name())
diff --git a/runsc/cmd/run.go b/runsc/cmd/run.go
index c48cbe4cd..722181aff 100644
--- a/runsc/cmd/run.go
+++ b/runsc/cmd/run.go
@@ -16,9 +16,9 @@ package cmd
import (
"context"
- "syscall"
"github.com/google/subcommands"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/container"
"gvisor.dev/gvisor/runsc/flag"
@@ -65,7 +65,7 @@ func (r *Run) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s
id := f.Arg(0)
conf := args[0].(*config.Config)
- waitStatus := args[1].(*syscall.WaitStatus)
+ waitStatus := args[1].(*unix.WaitStatus)
if conf.Rootless {
return Errorf("Rootless mode not supported with %q", r.Name())
diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go
index 5d55422c7..d7a783b88 100644
--- a/runsc/cmd/wait.go
+++ b/runsc/cmd/wait.go
@@ -18,9 +18,9 @@ import (
"context"
"encoding/json"
"os"
- "syscall"
"github.com/google/subcommands"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/runsc/config"
"gvisor.dev/gvisor/runsc/container"
"gvisor.dev/gvisor/runsc/flag"
@@ -77,7 +77,7 @@ func (wt *Wait) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("loading container: %v", err)
}
- var waitStatus syscall.WaitStatus
+ var waitStatus unix.WaitStatus
switch {
// Wait on the whole container.
case wt.rootPID == unsetPID && wt.pid == unsetPID:
@@ -119,7 +119,7 @@ type waitResult struct {
// exitStatus returns the correct exit status for a process based on if it
// was signaled or exited cleanly.
-func exitStatus(status syscall.WaitStatus) int {
+func exitStatus(status unix.WaitStatus) int {
if status.Signaled() {
return 128 + int(status.Signal())
}
diff --git a/runsc/config/config.go b/runsc/config/config.go
index e9fd7708f..34ef48825 100644
--- a/runsc/config/config.go
+++ b/runsc/config/config.go
@@ -64,6 +64,9 @@ type Config struct {
// Overlay is whether to wrap the root filesystem in an overlay.
Overlay bool `flag:"overlay"`
+ // Verity is whether there's one or more verity file system to mount.
+ Verity bool `flag:"verity"`
+
// FSGoferHostUDS enables the gofer to mount a host UDS.
FSGoferHostUDS bool `flag:"fsgofer-host-uds"`
diff --git a/runsc/config/flags.go b/runsc/config/flags.go
index 7e738dfdf..adbee506c 100644
--- a/runsc/config/flags.go
+++ b/runsc/config/flags.go
@@ -69,6 +69,7 @@ func RegisterFlags() {
// Flags that control sandbox runtime behavior: FS related.
flag.Var(fileAccessTypePtr(FileAccessExclusive), "file-access", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.")
flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.")
+ flag.Bool("verity", false, "specifies whether a verity file system will be mounted.")
flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem")
flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.")
flag.Bool("vfs2", false, "enables VFSv2. This uses the new VFS layer that is faster than the previous one.")
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index 8793c8916..3620dc8c3 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -30,6 +30,7 @@ go_library(
"@com_github_cenkalti_backoff//:go_default_library",
"@com_github_gofrs_flock//:go_default_library",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go
index 7a3d5a523..79b056fce 100644
--- a/runsc/container/console_test.go
+++ b/runsc/container/console_test.go
@@ -21,7 +21,6 @@ import (
"math/rand"
"os"
"path/filepath"
- "syscall"
"testing"
"time"
@@ -320,7 +319,7 @@ func TestJobControlSignalExec(t *testing.T) {
// Send a SIGTERM to the foreground process for the exec PID. Note that
// although we pass in the PID of "bash", it should actually terminate
// "sleep", since that is the foreground process.
- if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.SIGTERM, true /* fgProcess */); err != nil {
+ if err := c.Sandbox.SignalProcess(c.ID, pid, unix.SIGTERM, true /* fgProcess */); err != nil {
t.Fatalf("error signaling container: %v", err)
}
@@ -340,7 +339,7 @@ func TestJobControlSignalExec(t *testing.T) {
// Send a SIGKILL to the foreground process again. This time "bash"
// should be killed. We use SIGKILL instead of SIGTERM or SIGINT
// because bash ignores those.
- if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.SIGKILL, true /* fgProcess */); err != nil {
+ if err := c.Sandbox.SignalProcess(c.ID, pid, unix.SIGKILL, true /* fgProcess */); err != nil {
t.Fatalf("error signaling container: %v", err)
}
expectedPL = expectedPL[:1]
@@ -356,7 +355,7 @@ func TestJobControlSignalExec(t *testing.T) {
if !ws.Signaled() {
t.Error("ws.Signaled() got false, want true")
}
- if got, want := ws.Signal(), syscall.SIGKILL; got != want {
+ if got, want := ws.Signal(), unix.SIGKILL; got != want {
t.Errorf("ws.Signal() got %v, want %v", got, want)
}
}
@@ -423,7 +422,7 @@ func TestJobControlSignalRootContainer(t *testing.T) {
// very early, otherwise it might exit before we have a chance to call
// Wait.
var (
- ws syscall.WaitStatus
+ ws unix.WaitStatus
wg sync.WaitGroup
)
wg.Add(1)
@@ -459,7 +458,7 @@ func TestJobControlSignalRootContainer(t *testing.T) {
// Send a SIGTERM to the foreground process. We pass PID=0, indicating
// that the root process should be killed. However, by setting
// fgProcess=true, the signal should actually be sent to sleep.
- if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, syscall.SIGTERM, true /* fgProcess */); err != nil {
+ if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, unix.SIGTERM, true /* fgProcess */); err != nil {
t.Fatalf("error signaling container: %v", err)
}
@@ -479,7 +478,7 @@ func TestJobControlSignalRootContainer(t *testing.T) {
// Send a SIGKILL to the foreground process again. This time "bash"
// should be killed. We use SIGKILL instead of SIGTERM or SIGINT
// because bash ignores those.
- if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, syscall.SIGKILL, true /* fgProcess */); err != nil {
+ if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, unix.SIGKILL, true /* fgProcess */); err != nil {
t.Fatalf("error signaling container: %v", err)
}
@@ -488,7 +487,7 @@ func TestJobControlSignalRootContainer(t *testing.T) {
if !ws.Signaled() {
t.Error("ws.Signaled() got false, want true")
}
- if got, want := ws.Signal(), syscall.SIGKILL; got != want {
+ if got, want := ws.Signal(), unix.SIGKILL; got != want {
t.Errorf("ws.Signal() got %v, want %v", got, want)
}
}
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 40812efb8..f9d83c118 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -30,6 +30,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/log"
@@ -244,7 +245,7 @@ func New(conf *config.Config, args Args) (*Container, error) {
// If there is cgroup config, install it before creating sandbox process.
if err := cg.Install(args.Spec.Linux.Resources); err != nil {
switch {
- case errors.Is(err, syscall.EACCES) && conf.Rootless:
+ case errors.Is(err, unix.EACCES) && conf.Rootless:
log.Warningf("Skipping cgroup configuration in rootless mode: %v", err)
cg = nil
default:
@@ -447,7 +448,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *config.Config, restoreFile s
}
// Run is a helper that calls Create + Start + Wait.
-func Run(conf *config.Config, args Args) (syscall.WaitStatus, error) {
+func Run(conf *config.Config, args Args) (unix.WaitStatus, error) {
log.Debugf("Run container, cid: %s, rootDir: %q", args.ID, conf.RootDir)
c, err := New(conf, args)
if err != nil {
@@ -517,7 +518,7 @@ func (c *Container) SandboxPid() int {
// Wait waits for the container to exit, and returns its WaitStatus.
// Call to wait on a stopped container is needed to retrieve the exit status
// and wait returns immediately.
-func (c *Container) Wait() (syscall.WaitStatus, error) {
+func (c *Container) Wait() (unix.WaitStatus, error) {
log.Debugf("Wait on container, cid: %s", c.ID)
ws, err := c.Sandbox.Wait(c.ID)
if err == nil {
@@ -529,7 +530,7 @@ func (c *Container) Wait() (syscall.WaitStatus, error) {
// WaitRootPID waits for process 'pid' in the sandbox's PID namespace and
// returns its WaitStatus.
-func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) {
+func (c *Container) WaitRootPID(pid int32) (unix.WaitStatus, error) {
log.Debugf("Wait on process %d in sandbox, cid: %s", pid, c.Sandbox.ID)
if !c.IsSandboxRunning() {
return 0, fmt.Errorf("sandbox is not running")
@@ -539,7 +540,7 @@ func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) {
// WaitPID waits for process 'pid' in the container's PID namespace and returns
// its WaitStatus.
-func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) {
+func (c *Container) WaitPID(pid int32) (unix.WaitStatus, error) {
log.Debugf("Wait on process %d in container, cid: %s", pid, c.ID)
if !c.IsSandboxRunning() {
return 0, fmt.Errorf("sandbox is not running")
@@ -551,7 +552,7 @@ func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) {
// is SIGKILL, then waits for all processes to exit before returning.
// SignalContainer returns an error if the container is already stopped.
// TODO(b/113680494): Distinguish different error types.
-func (c *Container) SignalContainer(sig syscall.Signal, all bool) error {
+func (c *Container) SignalContainer(sig unix.Signal, all bool) error {
log.Debugf("Signal container, cid: %s, signal: %v (%d)", c.ID, sig, sig)
// Signaling container in Stopped state is allowed. When all=false,
// an error will be returned anyway; when all=true, this allows
@@ -568,7 +569,7 @@ func (c *Container) SignalContainer(sig syscall.Signal, all bool) error {
}
// SignalProcess sends sig to a specific process in the container.
-func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error {
+func (c *Container) SignalProcess(sig unix.Signal, pid int32) error {
log.Debugf("Signal process %d in container, cid: %s, signal: %v (%d)", pid, c.ID, sig, sig)
if err := c.requireStatus("signal a process inside", Running); err != nil {
return err
@@ -586,7 +587,7 @@ func (c *Container) ForwardSignals(pid int32, fgProcess bool) func() {
log.Debugf("Forwarding all signals to container, cid: %s, PIDPID: %d, fgProcess: %t", c.ID, pid, fgProcess)
stop := sighandling.StartSignalForwarding(func(sig linux.Signal) {
log.Debugf("Forwarding signal %d to container, cid: %s, PID: %d, fgProcess: %t", sig, c.ID, pid, fgProcess)
- if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.Signal(sig), fgProcess); err != nil {
+ if err := c.Sandbox.SignalProcess(c.ID, pid, unix.Signal(sig), fgProcess); err != nil {
log.Warningf("error forwarding signal %d to container %q: %v", sig, c.ID, err)
}
})
@@ -768,9 +769,9 @@ func (c *Container) stop() error {
// Try killing gofer if it does not exit with container.
if c.GoferPid != 0 {
log.Debugf("Killing gofer for container, cid: %s, PID: %d", c.ID, c.GoferPid)
- if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil {
+ if err := unix.Kill(c.GoferPid, unix.SIGKILL); err != nil {
// The gofer may already be stopped, log the error.
- log.Warningf("Error sending signal %d to gofer %d: %v", syscall.SIGKILL, c.GoferPid, err)
+ log.Warningf("Error sending signal %d to gofer %d: %v", unix.SIGKILL, c.GoferPid, err)
}
}
@@ -793,7 +794,7 @@ func (c *Container) waitForStopped() error {
b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx)
op := func() error {
if c.IsSandboxRunning() {
- if err := c.SignalContainer(syscall.Signal(0), false); err == nil {
+ if err := c.SignalContainer(unix.Signal(0), false); err == nil {
return fmt.Errorf("container is still running")
}
}
@@ -803,7 +804,7 @@ func (c *Container) waitForStopped() error {
if c.goferIsChild {
// The gofer process is a child of the current process,
// so we can wait it and collect its zombie.
- wpid, err := syscall.Wait4(int(c.GoferPid), nil, syscall.WNOHANG, nil)
+ wpid, err := unix.Wait4(int(c.GoferPid), nil, unix.WNOHANG, nil)
if err != nil {
return fmt.Errorf("error waiting the gofer process: %v", err)
}
@@ -811,7 +812,7 @@ func (c *Container) waitForStopped() error {
return fmt.Errorf("gofer is still running")
}
- } else if err := syscall.Kill(c.GoferPid, 0); err == nil {
+ } else if err := unix.Kill(c.GoferPid, 0); err == nil {
return fmt.Errorf("gofer is still running")
}
c.GoferPid = 0
@@ -892,7 +893,7 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bu
sandEnds := make([]*os.File, 0, mountCount)
for i := 0; i < mountCount; i++ {
- fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0)
+ fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0)
if err != nil {
return nil, nil, err
}
@@ -914,8 +915,8 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bu
if attached {
// The gofer is attached to the lifetime of this process, so it
// should synchronously die when this process dies.
- cmd.SysProcAttr = &syscall.SysProcAttr{
- Pdeathsig: syscall.SIGKILL,
+ cmd.SysProcAttr = &unix.SysProcAttr{
+ Pdeathsig: unix.SIGKILL,
}
}
@@ -1113,7 +1114,7 @@ func setOOMScoreAdj(pid int, scoreAdj int) error {
}
defer f.Close()
if _, err := f.WriteString(strconv.Itoa(scoreAdj)); err != nil {
- if errors.Is(err, syscall.ESRCH) {
+ if errors.Is(err, unix.ESRCH) {
log.Warningf("Process (%d) exited while setting oom_score_adj", pid)
return nil
}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 862d9444d..5a0c468a4 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -27,12 +27,12 @@ import (
"reflect"
"strconv"
"strings"
- "syscall"
"testing"
"time"
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
@@ -103,7 +103,7 @@ func waitForProcessCount(cont *Container, want int) error {
func blockUntilWaitable(pid int) error {
_, _, err := specutils.RetryEintr(func() (uintptr, uintptr, error) {
var err error
- _, _, err1 := syscall.Syscall6(syscall.SYS_WAITID, 1, uintptr(pid), 0, syscall.WEXITED|syscall.WNOWAIT, 0, 0)
+ _, _, err1 := unix.Syscall6(unix.SYS_WAITID, 1, uintptr(pid), 0, unix.WEXITED|unix.WNOWAIT, 0, 0)
if err1 != 0 {
err = err1
}
@@ -468,7 +468,7 @@ func TestLifecycle(t *testing.T) {
if err != nil {
ch <- err
}
- if got, want := ws.Signal(), syscall.SIGTERM; got != want {
+ if got, want := ws.Signal(), unix.SIGTERM; got != want {
ch <- fmt.Errorf("got signal %v, want %v", got, want)
}
ch <- nil
@@ -479,8 +479,8 @@ func TestLifecycle(t *testing.T) {
time.Sleep(time.Second)
// Send the container a SIGTERM which will cause it to stop.
- if err := c.SignalContainer(syscall.SIGTERM, false); err != nil {
- t.Fatalf("error sending signal %v to container: %v", syscall.SIGTERM, err)
+ if err := c.SignalContainer(unix.SIGTERM, false); err != nil {
+ t.Fatalf("error sending signal %v to container: %v", unix.SIGTERM, err)
}
// Wait for it to die.
@@ -815,11 +815,11 @@ func TestExec(t *testing.T) {
t.Run("nonexist", func(t *testing.T) {
// b/179114837 found by Syzkaller that causes nil pointer panic when
// trying to dec-ref an unix socket FD.
- fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0)
+ fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0)
if err != nil {
t.Fatal(err)
}
- defer syscall.Close(fds[0])
+ defer unix.Close(fds[0])
_, err = cont.executeSync(&control.ExecArgs{
Argv: []string{"/nonexist"},
@@ -956,7 +956,7 @@ func TestKillPid(t *testing.T) {
pid = int32(p.PID)
}
}
- if err := cont.SignalProcess(syscall.SIGKILL, pid); err != nil {
+ if err := cont.SignalProcess(unix.SIGKILL, pid); err != nil {
t.Fatalf("failed to signal process %d: %v", pid, err)
}
@@ -1601,12 +1601,12 @@ func TestReadonlyRoot(t *testing.T) {
}
// Read mounts to check that root is readonly.
- out, err := executeCombinedOutput(c, "/bin/sh", "-c", "mount | grep ' / '")
+ out, err := executeCombinedOutput(c, "/bin/sh", "-c", "mount | grep ' / ' | grep -o -e '(.*)'")
if err != nil {
t.Fatalf("exec failed: %v", err)
}
- t.Logf("root mount: %q", out)
- if !strings.Contains(string(out), "(ro)") {
+ t.Logf("root mount options: %q", out)
+ if !strings.Contains(string(out), "ro") {
t.Errorf("root not mounted readonly: %q", out)
}
@@ -1615,7 +1615,7 @@ func TestReadonlyRoot(t *testing.T) {
if err != nil {
t.Fatalf("touch file in ro mount: %v", err)
}
- if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
+ if !ws.Exited() || unix.Errno(ws.ExitStatus()) != unix.EPERM {
t.Fatalf("wrong waitStatus: %v", ws)
}
})
@@ -1659,13 +1659,13 @@ func TestReadonlyMount(t *testing.T) {
}
// Read mounts to check that volume is readonly.
- cmd := fmt.Sprintf("mount | grep ' %s '", dir)
+ cmd := fmt.Sprintf("mount | grep ' %s ' | grep -o -e '(.*)'", dir)
out, err := executeCombinedOutput(c, "/bin/sh", "-c", cmd)
if err != nil {
t.Fatalf("exec failed, err: %v", err)
}
- t.Logf("mount: %q", out)
- if !strings.Contains(string(out), "(ro)") {
+ t.Logf("mount options: %q", out)
+ if !strings.Contains(string(out), "ro") {
t.Errorf("volume not mounted readonly: %q", out)
}
@@ -1674,7 +1674,7 @@ func TestReadonlyMount(t *testing.T) {
if err != nil {
t.Fatalf("touch file in ro mount: %v", err)
}
- if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM {
+ if !ws.Exited() || unix.Errno(ws.ExitStatus()) != unix.EPERM {
t.Fatalf("wrong WaitStatus: %v", ws)
}
})
@@ -1750,8 +1750,8 @@ func TestUIDMap(t *testing.T) {
if !ws.Exited() || ws.ExitStatus() != 0 {
t.Fatalf("container failed, waitStatus: %v", ws)
}
- st := syscall.Stat_t{}
- if err := syscall.Stat(testFile, &st); err != nil {
+ st := unix.Stat_t{}
+ if err := unix.Stat(testFile, &st); err != nil {
t.Fatalf("error stat /testfile: %v", err)
}
@@ -1880,7 +1880,7 @@ func doGoferExitTest(t *testing.T, vfs2 bool) {
}
err = blockUntilWaitable(c.GoferPid)
- if err != nil && err != syscall.ECHILD {
+ if err != nil && err != unix.ECHILD {
t.Errorf("error waiting for gofer to exit: %v", err)
}
}
@@ -1929,7 +1929,7 @@ func TestUserLog(t *testing.T) {
}
// sched_rr_get_interval - not implemented in gvisor.
- num := strconv.Itoa(syscall.SYS_SCHED_RR_GET_INTERVAL)
+ num := strconv.Itoa(unix.SYS_SCHED_RR_GET_INTERVAL)
spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall="+num)
conf := testutil.TestConfig(t)
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
@@ -2159,10 +2159,10 @@ func TestMountPropagation(t *testing.T) {
f.Close()
// Setup src as a shared mount.
- if err := syscall.Mount(src, src, "bind", syscall.MS_BIND, ""); err != nil {
+ if err := unix.Mount(src, src, "bind", unix.MS_BIND, ""); err != nil {
t.Fatalf("mount(%q, %q, MS_BIND): %v", dir, srcMnt, err)
}
- if err := syscall.Mount("", src, "", syscall.MS_SHARED, ""); err != nil {
+ if err := unix.Mount("", src, "", unix.MS_SHARED, ""); err != nil {
t.Fatalf("mount(%q, MS_SHARED): %v", srcMnt, err)
}
@@ -2209,7 +2209,7 @@ func TestMountPropagation(t *testing.T) {
// After the container is started, mount dir inside source and check what
// happens to both destinations.
- if err := syscall.Mount(dir, srcMnt, "bind", syscall.MS_BIND, ""); err != nil {
+ if err := unix.Mount(dir, srcMnt, "bind", unix.MS_BIND, ""); err != nil {
t.Fatalf("mount(%q, %q, MS_BIND): %v", dir, srcMnt, err)
}
@@ -2449,7 +2449,7 @@ func TestCreateWithCorruptedStateFile(t *testing.T) {
}
}
-func execute(cont *Container, name string, arg ...string) (syscall.WaitStatus, error) {
+func execute(cont *Container, name string, arg ...string) (unix.WaitStatus, error) {
args := &control.ExecArgs{
Filename: name,
Argv: append([]string{name}, arg...),
@@ -2483,7 +2483,7 @@ func executeCombinedOutput(cont *Container, name string, arg ...string) ([]byte,
}
// executeSync synchronously executes a new process.
-func (c *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) {
+func (c *Container) executeSync(args *control.ExecArgs) (unix.WaitStatus, error) {
pid, err := c.Execute(args)
if err != nil {
return 0, fmt.Errorf("error executing: %v", err)
diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go
index b434cdb23..0f0a223ce 100644
--- a/runsc/container/multi_container_test.go
+++ b/runsc/container/multi_container_test.go
@@ -22,11 +22,11 @@ import (
"path"
"path/filepath"
"strings"
- "syscall"
"testing"
"time"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -403,7 +403,7 @@ func TestMultiPIDNSKill(t *testing.T) {
t.Logf("Container %q procs: %s", c.ID, procListToString(procs))
pidToKill := procs[processes-1].PID
t.Logf("PID to kill: %d", pidToKill)
- if err := c.SignalProcess(syscall.SIGKILL, int32(pidToKill)); err != nil {
+ if err := c.SignalProcess(unix.SIGKILL, int32(pidToKill)); err != nil {
t.Errorf("container.SignalProcess: %v", err)
}
// Wait for the process to get killed.
@@ -432,7 +432,7 @@ func TestMultiPIDNSKill(t *testing.T) {
pidToKill = procs[len(procs)-1].PID
t.Logf("PID that should not be killed: %d", pidToKill)
- err = c.SignalProcess(syscall.SIGKILL, int32(pidToKill))
+ err = c.SignalProcess(unix.SIGKILL, int32(pidToKill))
if err == nil {
t.Fatalf("killing another container's process should fail")
}
@@ -640,7 +640,7 @@ func TestMultiContainerSignal(t *testing.T) {
}
// Kill process 2.
- if err := containers[1].SignalContainer(syscall.SIGKILL, false); err != nil {
+ if err := containers[1].SignalContainer(unix.SIGKILL, false); err != nil {
t.Errorf("failed to kill process 2: %v", err)
}
@@ -660,10 +660,10 @@ func TestMultiContainerSignal(t *testing.T) {
t.Errorf("failed to destroy container: %v", err)
}
_, _, err = specutils.RetryEintr(func() (uintptr, uintptr, error) {
- cpid, err := syscall.Wait4(goferPid, nil, 0, nil)
+ cpid, err := unix.Wait4(goferPid, nil, 0, nil)
return uintptr(cpid), 0, err
})
- if err != syscall.ECHILD {
+ if err != unix.ECHILD {
t.Errorf("error waiting for gofer to exit: %v", err)
}
// Make sure process 1 is still running.
@@ -673,28 +673,28 @@ func TestMultiContainerSignal(t *testing.T) {
// Now that process 2 is gone, ensure we get an error trying to
// signal it again.
- if err := containers[1].SignalContainer(syscall.SIGKILL, false); err == nil {
+ if err := containers[1].SignalContainer(unix.SIGKILL, false); err == nil {
t.Errorf("container %q shouldn't exist, but we were able to signal it", containers[1].ID)
}
// Kill process 1.
- if err := containers[0].SignalContainer(syscall.SIGKILL, false); err != nil {
+ if err := containers[0].SignalContainer(unix.SIGKILL, false); err != nil {
t.Errorf("failed to kill process 1: %v", err)
}
// Ensure that container's gofer and sandbox process are no more.
err = blockUntilWaitable(containers[0].GoferPid)
- if err != nil && err != syscall.ECHILD {
+ if err != nil && err != unix.ECHILD {
t.Errorf("error waiting for gofer to exit: %v", err)
}
err = blockUntilWaitable(containers[0].Sandbox.Pid)
- if err != nil && err != syscall.ECHILD {
+ if err != nil && err != unix.ECHILD {
t.Errorf("error waiting for sandbox to exit: %v", err)
}
// The sentry should be gone, so signaling should yield an error.
- if err := containers[0].SignalContainer(syscall.SIGKILL, false); err == nil {
+ if err := containers[0].SignalContainer(unix.SIGKILL, false); err == nil {
t.Errorf("sandbox %q shouldn't exist, but we were able to signal it", containers[0].Sandbox.ID)
}
@@ -893,7 +893,7 @@ func TestMultiContainerKillAll(t *testing.T) {
if tc.killContainer {
// First kill the init process to make the container be stopped with
// processes still running inside.
- containers[1].SignalContainer(syscall.SIGKILL, false)
+ containers[1].SignalContainer(unix.SIGKILL, false)
op := func() error {
c, err := Load(conf.RootDir, FullID{ContainerID: ids[1]}, LoadOpts{})
if err != nil {
@@ -914,7 +914,7 @@ func TestMultiContainerKillAll(t *testing.T) {
t.Fatalf("failed to load child container %q: %v", c.ID, err)
}
// Kill'Em All
- if err := c.SignalContainer(syscall.SIGKILL, true); err != nil {
+ if err := c.SignalContainer(unix.SIGKILL, true); err != nil {
t.Fatalf("failed to send SIGKILL to container %q: %v", c.ID, err)
}
@@ -1640,8 +1640,8 @@ func TestMultiContainerGoferKilled(t *testing.T) {
}
// Kill container's gofer.
- if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil {
- t.Fatalf("syscall.Kill(%d, SIGKILL)=%v", c.GoferPid, err)
+ if err := unix.Kill(c.GoferPid, unix.SIGKILL); err != nil {
+ t.Fatalf("unix.Kill(%d, SIGKILL)=%v", c.GoferPid, err)
}
// Wait until container stops.
@@ -1672,8 +1672,8 @@ func TestMultiContainerGoferKilled(t *testing.T) {
// Kill root container's gofer to bring entire sandbox down.
c = containers[0]
- if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil {
- t.Fatalf("syscall.Kill(%d, SIGKILL)=%v", c.GoferPid, err)
+ if err := unix.Kill(c.GoferPid, unix.SIGKILL); err != nil {
+ t.Fatalf("unix.Kill(%d, SIGKILL)=%v", c.GoferPid, err)
}
// Wait until sandbox stops. waitForProcessList will loop until sandbox exits
diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go
index c46322ba4..0399903a0 100644
--- a/runsc/container/state_file.go
+++ b/runsc/container/state_file.go
@@ -22,9 +22,9 @@ import (
"path/filepath"
"regexp"
"strings"
- "syscall"
"github.com/gofrs/flock"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -89,7 +89,7 @@ func Load(rootDir string, id FullID, opts LoadOpts) (*Container, error) {
c.changeStatus(Stopped)
}
case Running:
- if err := c.SignalContainer(syscall.Signal(0), false); err != nil {
+ if err := c.SignalContainer(unix.Signal(0), false); err != nil {
c.changeStatus(Stopped)
}
}
@@ -245,7 +245,7 @@ type StateFile struct {
// lock globally locks all locking operations for the container.
func (s *StateFile) lock() error {
s.once.Do(func() {
- s.flock = flock.NewFlock(s.lockPath())
+ s.flock = flock.New(s.lockPath())
})
if err := s.flock.Lock(); err != nil {
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index d1af539cb..fd72414ce 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -16,7 +16,6 @@ package filter
import (
"os"
- "syscall"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -25,12 +24,12 @@ import (
// allowedSyscalls is the set of syscalls executed by the gofer.
var allowedSyscalls = seccomp.SyscallRules{
- syscall.SYS_ACCEPT: {},
- syscall.SYS_CLOCK_GETTIME: {},
- syscall.SYS_CLOSE: {},
- syscall.SYS_DUP: {},
- syscall.SYS_EPOLL_CTL: {},
- syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{
+ unix.SYS_ACCEPT: {},
+ unix.SYS_CLOCK_GETTIME: {},
+ unix.SYS_CLOSE: {},
+ unix.SYS_DUP: {},
+ unix.SYS_EPOLL_CTL: {},
+ unix.SYS_EPOLL_PWAIT: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
@@ -39,34 +38,34 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.EqualTo(0),
},
},
- syscall.SYS_EVENTFD2: []seccomp.Rule{
+ unix.SYS_EVENTFD2: []seccomp.Rule{
{
seccomp.EqualTo(0),
seccomp.EqualTo(0),
},
},
- syscall.SYS_EXIT: {},
- syscall.SYS_EXIT_GROUP: {},
- syscall.SYS_FALLOCATE: []seccomp.Rule{
+ unix.SYS_EXIT: {},
+ unix.SYS_EXIT_GROUP: {},
+ unix.SYS_FALLOCATE: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.EqualTo(0),
},
},
- syscall.SYS_FCHMOD: {},
- syscall.SYS_FCHOWNAT: {},
- syscall.SYS_FCNTL: []seccomp.Rule{
+ unix.SYS_FCHMOD: {},
+ unix.SYS_FCHOWNAT: {},
+ unix.SYS_FCNTL: []seccomp.Rule{
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.F_GETFL),
+ seccomp.EqualTo(unix.F_GETFL),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.F_SETFL),
+ seccomp.EqualTo(unix.F_SETFL),
},
{
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.F_GETFD),
+ seccomp.EqualTo(unix.F_GETFD),
},
// Used by flipcall.PacketWindowAllocator.Init().
{
@@ -74,11 +73,11 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.EqualTo(unix.F_ADD_SEALS),
},
},
- syscall.SYS_FSTAT: {},
- syscall.SYS_FSTATFS: {},
- syscall.SYS_FSYNC: {},
- syscall.SYS_FTRUNCATE: {},
- syscall.SYS_FUTEX: {
+ unix.SYS_FSTAT: {},
+ unix.SYS_FSTATFS: {},
+ unix.SYS_FSYNC: {},
+ unix.SYS_FTRUNCATE: {},
+ unix.SYS_FUTEX: {
seccomp.Rule{
seccomp.MatchAny{},
seccomp.EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
@@ -116,78 +115,78 @@ var allowedSyscalls = seccomp.SyscallRules{
seccomp.EqualTo(0),
},
},
- syscall.SYS_GETDENTS64: {},
- syscall.SYS_GETPID: {},
- unix.SYS_GETRANDOM: {},
- syscall.SYS_GETTID: {},
- syscall.SYS_GETTIMEOFDAY: {},
- syscall.SYS_LINKAT: {},
- syscall.SYS_LSEEK: {},
- syscall.SYS_MADVISE: {},
- unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init().
- syscall.SYS_MKDIRAT: {},
- syscall.SYS_MKNODAT: {},
+ unix.SYS_GETDENTS64: {},
+ unix.SYS_GETPID: {},
+ unix.SYS_GETRANDOM: {},
+ unix.SYS_GETTID: {},
+ unix.SYS_GETTIMEOFDAY: {},
+ unix.SYS_LINKAT: {},
+ unix.SYS_LSEEK: {},
+ unix.SYS_MADVISE: {},
+ unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init().
+ unix.SYS_MKDIRAT: {},
+ unix.SYS_MKNODAT: {},
// Used by the Go runtime as a temporarily workaround for a Linux
// 5.2-5.4 bug.
//
// See src/runtime/os_linux_x86.go.
//
// TODO(b/148688965): Remove once this is gone from Go.
- syscall.SYS_MLOCK: []seccomp.Rule{
+ unix.SYS_MLOCK: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.EqualTo(4096),
},
},
- syscall.SYS_MMAP: []seccomp.Rule{
+ unix.SYS_MMAP: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MAP_SHARED),
+ seccomp.EqualTo(unix.MAP_SHARED),
},
{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS),
+ seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS),
},
{
seccomp.MatchAny{},
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED),
+ seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS | unix.MAP_FIXED),
},
},
- syscall.SYS_MPROTECT: {},
- syscall.SYS_MUNMAP: {},
- syscall.SYS_NANOSLEEP: {},
- syscall.SYS_OPENAT: {},
- syscall.SYS_PPOLL: {},
- syscall.SYS_PREAD64: {},
- syscall.SYS_PWRITE64: {},
- syscall.SYS_READ: {},
- syscall.SYS_READLINKAT: {},
- syscall.SYS_RECVMSG: []seccomp.Rule{
+ unix.SYS_MPROTECT: {},
+ unix.SYS_MUNMAP: {},
+ unix.SYS_NANOSLEEP: {},
+ unix.SYS_OPENAT: {},
+ unix.SYS_PPOLL: {},
+ unix.SYS_PREAD64: {},
+ unix.SYS_PWRITE64: {},
+ unix.SYS_READ: {},
+ unix.SYS_READLINKAT: {},
+ unix.SYS_RECVMSG: []seccomp.Rule{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC),
+ seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_TRUNC),
},
{
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK),
+ seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_TRUNC | unix.MSG_PEEK),
},
},
- syscall.SYS_RENAMEAT: {},
- syscall.SYS_RESTART_SYSCALL: {},
+ unix.SYS_RENAMEAT: {},
+ unix.SYS_RESTART_SYSCALL: {},
// May be used by the runtime during panic().
- syscall.SYS_RT_SIGACTION: {},
- syscall.SYS_RT_SIGPROCMASK: {},
- syscall.SYS_RT_SIGRETURN: {},
- syscall.SYS_SCHED_YIELD: {},
- syscall.SYS_SENDMSG: []seccomp.Rule{
+ unix.SYS_RT_SIGACTION: {},
+ unix.SYS_RT_SIGPROCMASK: {},
+ unix.SYS_RT_SIGRETURN: {},
+ unix.SYS_SCHED_YIELD: {},
+ unix.SYS_SENDMSG: []seccomp.Rule{
// Used by fdchannel.Endpoint.SendFD().
{
seccomp.MatchAny{},
@@ -198,51 +197,51 @@ var allowedSyscalls = seccomp.SyscallRules{
{
seccomp.MatchAny{},
seccomp.MatchAny{},
- seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL),
+ seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_NOSIGNAL),
},
},
- syscall.SYS_SHUTDOWN: []seccomp.Rule{
- {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RDWR)},
+ unix.SYS_SHUTDOWN: []seccomp.Rule{
+ {seccomp.MatchAny{}, seccomp.EqualTo(unix.SHUT_RDWR)},
},
- syscall.SYS_SIGALTSTACK: {},
+ unix.SYS_SIGALTSTACK: {},
// Used by fdchannel.NewConnectedSockets().
- syscall.SYS_SOCKETPAIR: {
+ unix.SYS_SOCKETPAIR: {
{
- seccomp.EqualTo(syscall.AF_UNIX),
- seccomp.EqualTo(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC),
+ seccomp.EqualTo(unix.AF_UNIX),
+ seccomp.EqualTo(unix.SOCK_SEQPACKET | unix.SOCK_CLOEXEC),
seccomp.EqualTo(0),
},
},
- syscall.SYS_SYMLINKAT: {},
- syscall.SYS_TGKILL: []seccomp.Rule{
+ unix.SYS_SYMLINKAT: {},
+ unix.SYS_TGKILL: []seccomp.Rule{
{
seccomp.EqualTo(uint64(os.Getpid())),
},
},
- syscall.SYS_UNLINKAT: {},
- syscall.SYS_UTIMENSAT: {},
- syscall.SYS_WRITE: {},
+ unix.SYS_UNLINKAT: {},
+ unix.SYS_UTIMENSAT: {},
+ unix.SYS_WRITE: {},
}
var udsSyscalls = seccomp.SyscallRules{
- syscall.SYS_SOCKET: []seccomp.Rule{
+ unix.SYS_SOCKET: []seccomp.Rule{
{
- seccomp.EqualTo(syscall.AF_UNIX),
- seccomp.EqualTo(syscall.SOCK_STREAM),
+ seccomp.EqualTo(unix.AF_UNIX),
+ seccomp.EqualTo(unix.SOCK_STREAM),
seccomp.EqualTo(0),
},
{
- seccomp.EqualTo(syscall.AF_UNIX),
- seccomp.EqualTo(syscall.SOCK_DGRAM),
+ seccomp.EqualTo(unix.AF_UNIX),
+ seccomp.EqualTo(unix.SOCK_DGRAM),
seccomp.EqualTo(0),
},
{
- seccomp.EqualTo(syscall.AF_UNIX),
- seccomp.EqualTo(syscall.SOCK_SEQPACKET),
+ seccomp.EqualTo(unix.AF_UNIX),
+ seccomp.EqualTo(unix.SOCK_SEQPACKET),
seccomp.EqualTo(0),
},
},
- syscall.SYS_CONNECT: []seccomp.Rule{
+ unix.SYS_CONNECT: []seccomp.Rule{
{
seccomp.MatchAny{},
},
diff --git a/runsc/fsgofer/filter/config_amd64.go b/runsc/fsgofer/filter/config_amd64.go
index 686753d96..2d0151dcc 100644
--- a/runsc/fsgofer/filter/config_amd64.go
+++ b/runsc/fsgofer/filter/config_amd64.go
@@ -17,30 +17,29 @@
package filter
import (
- "syscall"
-
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/seccomp"
)
func init() {
- allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{
+ allowedSyscalls[unix.SYS_ARCH_PRCTL] = []seccomp.Rule{
// TODO(b/168828518): No longer used in Go 1.16+.
{seccomp.EqualTo(linux.ARCH_SET_FS)},
}
- allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{
+ allowedSyscalls[unix.SYS_CLONE] = []seccomp.Rule{
// parent_tidptr and child_tidptr are always 0 because neither
// CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used.
{
seccomp.EqualTo(
- syscall.CLONE_VM |
- syscall.CLONE_FS |
- syscall.CLONE_FILES |
- syscall.CLONE_SETTLS |
- syscall.CLONE_SIGHAND |
- syscall.CLONE_SYSVSEM |
- syscall.CLONE_THREAD),
+ unix.CLONE_VM |
+ unix.CLONE_FS |
+ unix.CLONE_FILES |
+ unix.CLONE_SETTLS |
+ unix.CLONE_SIGHAND |
+ unix.CLONE_SYSVSEM |
+ unix.CLONE_THREAD),
seccomp.MatchAny{}, // newsp
seccomp.EqualTo(0), // parent_tidptr
seccomp.EqualTo(0), // child_tidptr
@@ -49,12 +48,12 @@ func init() {
{
// TODO(b/168828518): No longer used in Go 1.16+ (on amd64).
seccomp.EqualTo(
- syscall.CLONE_VM |
- syscall.CLONE_FS |
- syscall.CLONE_FILES |
- syscall.CLONE_SIGHAND |
- syscall.CLONE_SYSVSEM |
- syscall.CLONE_THREAD),
+ unix.CLONE_VM |
+ unix.CLONE_FS |
+ unix.CLONE_FILES |
+ unix.CLONE_SIGHAND |
+ unix.CLONE_SYSVSEM |
+ unix.CLONE_THREAD),
seccomp.MatchAny{}, // newsp
seccomp.EqualTo(0), // parent_tidptr
seccomp.EqualTo(0), // child_tidptr
@@ -62,5 +61,5 @@ func init() {
},
}
- allowedSyscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{}
+ allowedSyscalls[unix.SYS_NEWFSTATAT] = []seccomp.Rule{}
}
diff --git a/runsc/fsgofer/filter/config_arm64.go b/runsc/fsgofer/filter/config_arm64.go
index ff0cf77a0..7d458c02d 100644
--- a/runsc/fsgofer/filter/config_arm64.go
+++ b/runsc/fsgofer/filter/config_arm64.go
@@ -17,23 +17,22 @@
package filter
import (
- "syscall"
-
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/seccomp"
)
func init() {
- allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{
+ allowedSyscalls[unix.SYS_CLONE] = []seccomp.Rule{
// parent_tidptr and child_tidptr are always 0 because neither
// CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used.
{
seccomp.EqualTo(
- syscall.CLONE_VM |
- syscall.CLONE_FS |
- syscall.CLONE_FILES |
- syscall.CLONE_SIGHAND |
- syscall.CLONE_SYSVSEM |
- syscall.CLONE_THREAD),
+ unix.CLONE_VM |
+ unix.CLONE_FS |
+ unix.CLONE_FILES |
+ unix.CLONE_SIGHAND |
+ unix.CLONE_SYSVSEM |
+ unix.CLONE_THREAD),
seccomp.MatchAny{}, // newsp
// These arguments are left uninitialized by the Go
// runtime, so they may be anything (and are unused by
@@ -44,5 +43,5 @@ func init() {
},
}
- allowedSyscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{}
+ allowedSyscalls[unix.SYS_FSTATAT] = []seccomp.Rule{}
}
diff --git a/runsc/fsgofer/filter/extra_filters_msan.go b/runsc/fsgofer/filter/extra_filters_msan.go
index 8c6179c8f..d768ed0bb 100644
--- a/runsc/fsgofer/filter/extra_filters_msan.go
+++ b/runsc/fsgofer/filter/extra_filters_msan.go
@@ -17,8 +17,7 @@
package filter
import (
- "syscall"
-
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/seccomp"
)
@@ -27,7 +26,7 @@ import (
func instrumentationFilters() seccomp.SyscallRules {
log.Warningf("*** SECCOMP WARNING: MSAN is enabled: syscall filters less restrictive!")
return seccomp.SyscallRules{
- syscall.SYS_SCHED_GETAFFINITY: {},
- syscall.SYS_SET_ROBUST_LIST: {},
+ unix.SYS_SCHED_GETAFFINITY: {},
+ unix.SYS_SET_ROBUST_LIST: {},
}
}
diff --git a/runsc/fsgofer/filter/extra_filters_race.go b/runsc/fsgofer/filter/extra_filters_race.go
index cbd5c487e..9e75c025d 100644
--- a/runsc/fsgofer/filter/extra_filters_race.go
+++ b/runsc/fsgofer/filter/extra_filters_race.go
@@ -17,8 +17,7 @@
package filter
import (
- "syscall"
-
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/seccomp"
)
@@ -27,18 +26,18 @@ import (
func instrumentationFilters() seccomp.SyscallRules {
log.Warningf("*** SECCOMP WARNING: TSAN is enabled: syscall filters less restrictive!")
return seccomp.SyscallRules{
- syscall.SYS_BRK: {},
- syscall.SYS_CLOCK_NANOSLEEP: {},
- syscall.SYS_CLONE: {},
- syscall.SYS_FUTEX: {},
- syscall.SYS_MADVISE: {},
- syscall.SYS_MMAP: {},
- syscall.SYS_MUNLOCK: {},
- syscall.SYS_NANOSLEEP: {},
- syscall.SYS_OPEN: {},
- syscall.SYS_OPENAT: {},
- syscall.SYS_SET_ROBUST_LIST: {},
+ unix.SYS_BRK: {},
+ unix.SYS_CLOCK_NANOSLEEP: {},
+ unix.SYS_CLONE: {},
+ unix.SYS_FUTEX: {},
+ unix.SYS_MADVISE: {},
+ unix.SYS_MMAP: {},
+ unix.SYS_MUNLOCK: {},
+ unix.SYS_NANOSLEEP: {},
+ unix.SYS_OPEN: {},
+ unix.SYS_OPENAT: {},
+ unix.SYS_SET_ROBUST_LIST: {},
// Used within glibc's malloc.
- syscall.SYS_TIME: {},
+ unix.SYS_TIME: {},
}
}
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index cfa3796b1..1e80a634d 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -66,6 +66,9 @@ type Config struct {
// HostUDS signals whether the gofer can mount a host's UDS.
HostUDS bool
+
+ // enableXattr allows Get/SetXattr for the mounted file systems.
+ EnableXattr bool
}
type attachPoint struct {
@@ -795,12 +798,22 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error {
return err
}
-func (*localFile) GetXattr(string, uint64) (string, error) {
- return "", unix.EOPNOTSUPP
+func (l *localFile) GetXattr(name string, size uint64) (string, error) {
+ if !l.attachPoint.conf.EnableXattr {
+ return "", unix.EOPNOTSUPP
+ }
+ buffer := make([]byte, size)
+ if _, err := unix.Fgetxattr(l.file.FD(), name, buffer); err != nil {
+ return "", err
+ }
+ return string(buffer), nil
}
-func (*localFile) SetXattr(string, string, uint32) error {
- return unix.EOPNOTSUPP
+func (l *localFile) SetXattr(name string, value string, flags uint32) error {
+ if !l.attachPoint.conf.EnableXattr {
+ return unix.EOPNOTSUPP
+ }
+ return unix.Fsetxattr(l.file.FD(), name, []byte(value), int(flags))
}
func (*localFile) ListXattr(uint64) (map[string]struct{}, error) {
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index 99ea9bd32..a5f09f88f 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -565,6 +565,38 @@ func TestSetAttrOwner(t *testing.T) {
})
}
+func SetGetXattr(l *localFile, name string, value string) error {
+ if err := l.SetXattr(name, value, 0 /* flags */); err != nil {
+ return err
+ }
+ ret, err := l.GetXattr(name, uint64(len(value)))
+ if err != nil {
+ return err
+ }
+ if ret != value {
+ return fmt.Errorf("Got value %s, want %s", ret, value)
+ }
+ return nil
+}
+
+func TestSetGetXattr(t *testing.T) {
+ xattrConfs := []Config{{ROMount: false, EnableXattr: false}, {ROMount: false, EnableXattr: true}}
+ runCustom(t, []uint32{unix.S_IFREG}, xattrConfs, func(t *testing.T, s state) {
+ name := "user.test"
+ value := "tmp"
+ err := SetGetXattr(s.file, name, value)
+ if s.conf.EnableXattr {
+ if err != nil {
+ t.Fatalf("%v: SetGetXattr failed, err: %v", s, err)
+ }
+ } else {
+ if err == nil {
+ t.Fatalf("%v: SetGetXattr should have failed", s)
+ }
+ }
+ })
+}
+
func TestLink(t *testing.T) {
if !specutils.HasCapabilities(capability.CAP_DAC_READ_SEARCH) {
t.Skipf("Link test requires CAP_DAC_READ_SEARCH, running as %d", os.Getuid())
diff --git a/runsc/mitigate/BUILD b/runsc/mitigate/BUILD
index 561854e66..1238890fc 100644
--- a/runsc/mitigate/BUILD
+++ b/runsc/mitigate/BUILD
@@ -4,28 +4,20 @@ package(licenses = ["notice"])
go_library(
name = "mitigate",
- srcs = [
- "cpu.go",
- "mitigate.go",
- "mitigate_conf.go",
- ],
+ srcs = ["mitigate.go"],
visibility = [
"//runsc:__subpackages__",
],
- deps = [
- "//pkg/log",
- "//runsc/flag",
- "@in_gopkg_yaml_v2//:go_default_library",
- ],
+ deps = ["@in_gopkg_yaml_v2//:go_default_library"],
)
go_test(
name = "mitigate_test",
size = "small",
- srcs = [
- "cpu_test.go",
- "mitigate_test.go",
- ],
+ srcs = ["mitigate_test.go"],
library = ":mitigate",
- deps = ["@com_github_google_go_cmp//cmp:go_default_library"],
+ deps = [
+ "//runsc/mitigate/mock",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
)
diff --git a/runsc/mitigate/cpu.go b/runsc/mitigate/cpu.go
deleted file mode 100644
index 4b2aa351f..000000000
--- a/runsc/mitigate/cpu.go
+++ /dev/null
@@ -1,423 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package mitigate
-
-import (
- "fmt"
- "io/ioutil"
- "regexp"
- "strconv"
- "strings"
-)
-
-const (
- // mds is the only bug we care about.
- mds = "mds"
-
- // Constants for parsing /proc/cpuinfo.
- processorKey = "processor"
- vendorIDKey = "vendor_id"
- cpuFamilyKey = "cpu family"
- modelKey = "model"
- physicalIDKey = "physical id"
- coreIDKey = "core id"
- bugsKey = "bugs"
-
- // Path to shutdown a CPU.
- cpuOnlineTemplate = "/sys/devices/system/cpu/cpu%d/online"
-)
-
-// cpuSet contains a map of all CPUs on the system, mapped
-// by Physical ID and CoreIDs. threads with the same
-// Core and Physical ID are Hyperthread pairs.
-type cpuSet map[cpuID]*threadGroup
-
-// newCPUSet creates a CPUSet from data read from /proc/cpuinfo.
-func newCPUSet(data []byte, vulnerable func(thread) bool) (cpuSet, error) {
- processors, err := getThreads(string(data))
- if err != nil {
- return nil, err
- }
-
- set := make(cpuSet)
- for _, p := range processors {
- // Each ID is of the form physicalID:coreID. Hyperthread pairs
- // have identical physical and core IDs. We need to match
- // Hyperthread pairs so that we can shutdown all but one per
- // pair.
- core, ok := set[p.id]
- if !ok {
- core = &threadGroup{}
- set[p.id] = core
- }
- core.isVulnerable = core.isVulnerable || vulnerable(p)
- core.threads = append(core.threads, p)
- }
- return set, nil
-}
-
-// newCPUSetFromPossible makes a cpuSet data read from
-// /sys/devices/system/cpu/possible. This is used in enable operations
-// where the caller simply wants to enable all CPUS.
-func newCPUSetFromPossible(data []byte) (cpuSet, error) {
- threads, err := getThreadsFromPossible(data)
- if err != nil {
- return nil, err
- }
-
- // We don't care if a CPU is vulnerable or not, we just
- // want to return a list of all CPUs on the host.
- set := cpuSet{
- threads[0].id: &threadGroup{
- threads: threads,
- isVulnerable: false,
- },
- }
- return set, nil
-}
-
-// String implements the String method for CPUSet.
-func (c cpuSet) String() string {
- ret := ""
- for _, tg := range c {
- ret += fmt.Sprintf("%s\n", tg)
- }
- return ret
-}
-
-// getRemainingList returns the list of threads that will remain active
-// after mitigation.
-func (c cpuSet) getRemainingList() []thread {
- threads := make([]thread, 0, len(c))
- for _, core := range c {
- // If we're vulnerable, take only one thread from the pair.
- if core.isVulnerable {
- threads = append(threads, core.threads[0])
- continue
- }
- // Otherwise don't shutdown anything.
- threads = append(threads, core.threads...)
- }
- return threads
-}
-
-// getShutdownList returns the list of threads that will be shutdown on
-// mitigation.
-func (c cpuSet) getShutdownList() []thread {
- threads := make([]thread, 0)
- for _, core := range c {
- // Only if we're vulnerable do shutdown anything. In this case,
- // shutdown all but the first entry.
- if core.isVulnerable && len(core.threads) > 1 {
- threads = append(threads, core.threads[1:]...)
- }
- }
- return threads
-}
-
-// threadGroup represents Hyperthread pairs on the same physical/core ID.
-type threadGroup struct {
- threads []thread
- isVulnerable bool
-}
-
-// String implements the String method for threadGroup.
-func (c threadGroup) String() string {
- ret := fmt.Sprintf("ThreadGroup:\nIsVulnerable: %t\n", c.isVulnerable)
- for _, processor := range c.threads {
- ret += fmt.Sprintf("%s\n", processor)
- }
- return ret
-}
-
-// getThreads returns threads structs from reading /proc/cpuinfo.
-func getThreads(data string) ([]thread, error) {
- // Each processor entry should start with the
- // processor key. Find the beginings of each.
- r := buildRegex(processorKey, `\d+`)
- indices := r.FindAllStringIndex(data, -1)
- if len(indices) < 1 {
- return nil, fmt.Errorf("no cpus found for: %q", data)
- }
-
- // Add the ending index for last entry.
- indices = append(indices, []int{len(data), -1})
-
- // Valid cpus are now defined by strings in between
- // indexes (e.g. data[index[i], index[i+1]]).
- // There should be len(indicies) - 1 CPUs
- // since the last index is the end of the string.
- cpus := make([]thread, 0, len(indices))
- // Find each string that represents a CPU. These begin "processor".
- for i := 1; i < len(indices); i++ {
- start := indices[i-1][0]
- end := indices[i][0]
- // Parse the CPU entry, which should be between start/end.
- c, err := newThread(data[start:end])
- if err != nil {
- return nil, err
- }
- cpus = append(cpus, c)
- }
- return cpus, nil
-}
-
-// getThreadsFromPossible makes threads from data read from /sys/devices/system/cpu/possible.
-func getThreadsFromPossible(data []byte) ([]thread, error) {
- possibleRegex := regexp.MustCompile(`(?m)^(\d+)(-(\d+))?$`)
- matches := possibleRegex.FindStringSubmatch(string(data))
- if len(matches) != 4 {
- return nil, fmt.Errorf("mismatch regex from %s: %q", allPossibleCPUs, string(data))
- }
-
- // If matches[3] is empty, we only have one cpu entry.
- if matches[3] == "" {
- matches[3] = matches[1]
- }
-
- begin, err := strconv.ParseInt(matches[1], 10, 64)
- if err != nil {
- return nil, fmt.Errorf("failed to parse begin: %v", err)
- }
- end, err := strconv.ParseInt(matches[3], 10, 64)
- if err != nil {
- return nil, fmt.Errorf("failed to parse end: %v", err)
- }
- if begin > end || begin < 0 || end < 0 {
- return nil, fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", begin, end)
- }
-
- ret := make([]thread, 0, end-begin)
- for i := begin; i <= end; i++ {
- ret = append(ret, thread{
- processorNumber: i,
- id: cpuID{
- physicalID: 0, // we don't care about id for enable ops.
- coreID: 0,
- },
- })
- }
-
- return ret, nil
-}
-
-// cpuID for each thread is defined by the physical and
-// core IDs. If equal, two threads are Hyperthread pairs.
-type cpuID struct {
- physicalID int64
- coreID int64
-}
-
-// type cpu represents pertinent info about a cpu.
-type thread struct {
- processorNumber int64 // the processor number of this CPU.
- vendorID string // the vendorID of CPU (e.g. AuthenticAMD).
- cpuFamily int64 // CPU family number (e.g. 6 for CascadeLake/Skylake).
- model int64 // CPU model number (e.g. 85 for CascadeLake/Skylake).
- id cpuID // id for this thread
- bugs map[string]struct{} // map of vulnerabilities parsed from the 'bugs' field.
-}
-
-// newThread parses a CPU from a single cpu entry from /proc/cpuinfo.
-func newThread(data string) (thread, error) {
- empty := thread{}
- processor, err := parseProcessor(data)
- if err != nil {
- return empty, err
- }
-
- vendorID, err := parseVendorID(data)
- if err != nil {
- return empty, err
- }
-
- cpuFamily, err := parseCPUFamily(data)
- if err != nil {
- return empty, err
- }
-
- model, err := parseModel(data)
- if err != nil {
- return empty, err
- }
-
- physicalID, err := parsePhysicalID(data)
- if err != nil {
- return empty, err
- }
-
- coreID, err := parseCoreID(data)
- if err != nil {
- return empty, err
- }
-
- bugs, err := parseBugs(data)
- if err != nil {
- return empty, err
- }
-
- return thread{
- processorNumber: processor,
- vendorID: vendorID,
- cpuFamily: cpuFamily,
- model: model,
- id: cpuID{
- physicalID: physicalID,
- coreID: coreID,
- },
- bugs: bugs,
- }, nil
-}
-
-// String implements the String method for thread.
-func (t thread) String() string {
- template := `CPU: %d
-CPU ID: %+v
-Vendor: %s
-Family/Model: %d/%d
-Bugs: %s
-`
- bugs := make([]string, 0)
- for bug := range t.bugs {
- bugs = append(bugs, bug)
- }
-
- return fmt.Sprintf(template, t.processorNumber, t.id, t.vendorID, t.cpuFamily, t.model, strings.Join(bugs, ","))
-}
-
-// enable turns on the CPU by writing 1 to /sys/devices/cpu/cpu{N}/online.
-func (t thread) enable() error {
- cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber)
- return ioutil.WriteFile(cpuPath, []byte{'1'}, 0644)
-}
-
-// disable turns off the CPU by writing 0 to /sys/devices/cpu/cpu{N}/online.
-func (t thread) disable() error {
- cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber)
- return ioutil.WriteFile(cpuPath, []byte{'0'}, 0644)
-}
-
-// isVulnerable checks if a CPU is vulnerable to mds.
-func (t thread) isVulnerable() bool {
- _, ok := t.bugs[mds]
- return ok
-}
-
-// isActive checks if a CPU is active from /sys/devices/system/cpu/cpu{N}/online
-// If the file does not exist (ioutil returns in error), we assume the CPU is on.
-func (t thread) isActive() bool {
- cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber)
- data, err := ioutil.ReadFile(cpuPath)
- if err != nil {
- return true
- }
- return len(data) > 0 && data[0] != '0'
-}
-
-// similarTo checks family/model/bugs fields for equality of two
-// processors.
-func (t thread) similarTo(other thread) bool {
- if t.vendorID != other.vendorID {
- return false
- }
-
- if other.cpuFamily != t.cpuFamily {
- return false
- }
-
- if other.model != t.model {
- return false
- }
-
- if len(other.bugs) != len(t.bugs) {
- return false
- }
-
- for bug := range t.bugs {
- if _, ok := other.bugs[bug]; !ok {
- return false
- }
- }
- return true
-}
-
-// parseProcessor grabs the processor field from /proc/cpuinfo output.
-func parseProcessor(data string) (int64, error) {
- return parseIntegerResult(data, processorKey)
-}
-
-// parseVendorID grabs the vendor_id field from /proc/cpuinfo output.
-func parseVendorID(data string) (string, error) {
- return parseRegex(data, vendorIDKey, `[\w\d]+`)
-}
-
-// parseCPUFamily grabs the cpu family field from /proc/cpuinfo output.
-func parseCPUFamily(data string) (int64, error) {
- return parseIntegerResult(data, cpuFamilyKey)
-}
-
-// parseModel grabs the model field from /proc/cpuinfo output.
-func parseModel(data string) (int64, error) {
- return parseIntegerResult(data, modelKey)
-}
-
-// parsePhysicalID parses the physical id field.
-func parsePhysicalID(data string) (int64, error) {
- return parseIntegerResult(data, physicalIDKey)
-}
-
-// parseCoreID parses the core id field.
-func parseCoreID(data string) (int64, error) {
- return parseIntegerResult(data, coreIDKey)
-}
-
-// parseBugs grabs the bugs field from /proc/cpuinfo output.
-func parseBugs(data string) (map[string]struct{}, error) {
- result, err := parseRegex(data, bugsKey, `[\d\w\s]*`)
- if err != nil {
- return nil, err
- }
- bugs := strings.Split(result, " ")
- ret := make(map[string]struct{}, len(bugs))
- for _, bug := range bugs {
- ret[bug] = struct{}{}
- }
- return ret, nil
-}
-
-// parseIntegerResult parses fields expecting an integer.
-func parseIntegerResult(data, key string) (int64, error) {
- result, err := parseRegex(data, key, `\d+`)
- if err != nil {
- return 0, err
- }
- return strconv.ParseInt(result, 0, 64)
-}
-
-// buildRegex builds a regex for parsing each CPU field.
-func buildRegex(key, match string) *regexp.Regexp {
- reg := fmt.Sprintf(`(?m)^%s\s*:\s*(.*)$`, key)
- return regexp.MustCompile(reg)
-}
-
-// parseRegex parses data with key inserted into a standard regex template.
-func parseRegex(data, key, match string) (string, error) {
- r := buildRegex(key, match)
- matches := r.FindStringSubmatch(data)
- if len(matches) < 2 {
- return "", fmt.Errorf("failed to match key %q: %q", key, data)
- }
- return matches[1], nil
-}
diff --git a/runsc/mitigate/cpu_test.go b/runsc/mitigate/cpu_test.go
deleted file mode 100644
index 374333465..000000000
--- a/runsc/mitigate/cpu_test.go
+++ /dev/null
@@ -1,605 +0,0 @@
-// Copyright 2021 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package mitigate
-
-import (
- "fmt"
- "io/ioutil"
- "strings"
- "testing"
-)
-
-// mockCPU represents data from CPUs that will be mitigated.
-type mockCPU struct {
- name string
- vendorID string
- family int
- model int
- modelName string
- bugs string
- physicalCores int
- cores int
- threadsPerCore int
-}
-
-var cascadeLake4 = mockCPU{
- name: "CascadeLake",
- vendorID: "GenuineIntel",
- family: 6,
- model: 85,
- modelName: "Intel(R) Xeon(R) CPU",
- bugs: "spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa",
- physicalCores: 1,
- cores: 2,
- threadsPerCore: 2,
-}
-
-var haswell2 = mockCPU{
- name: "Haswell",
- vendorID: "GenuineIntel",
- family: 6,
- model: 63,
- modelName: "Intel(R) Xeon(R) CPU",
- bugs: "cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs",
- physicalCores: 1,
- cores: 1,
- threadsPerCore: 2,
-}
-
-var haswell2core = mockCPU{
- name: "Haswell2Physical",
- vendorID: "GenuineIntel",
- family: 6,
- model: 63,
- modelName: "Intel(R) Xeon(R) CPU",
- bugs: "cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs",
- physicalCores: 2,
- cores: 1,
- threadsPerCore: 1,
-}
-
-var amd8 = mockCPU{
- name: "AMD",
- vendorID: "AuthenticAMD",
- family: 23,
- model: 49,
- modelName: "AMD EPYC 7B12",
- bugs: "sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass",
- physicalCores: 4,
- cores: 1,
- threadsPerCore: 2,
-}
-
-// makeCPUString makes a string formated like /proc/cpuinfo for each cpuTestCase
-func (tc mockCPU) makeCPUString() string {
- template := `processor : %d
-vendor_id : %s
-cpu family : %d
-model : %d
-model name : %s
-physical id : %d
-core id : %d
-cpu cores : %d
-bugs : %s
-`
- ret := ``
- for i := 0; i < tc.physicalCores; i++ {
- for j := 0; j < tc.cores; j++ {
- for k := 0; k < tc.threadsPerCore; k++ {
- processorNum := (i*tc.cores+j)*tc.threadsPerCore + k
- ret += fmt.Sprintf(template,
- processorNum, /*processor*/
- tc.vendorID, /*vendor_id*/
- tc.family, /*cpu family*/
- tc.model, /*model*/
- tc.modelName, /*model name*/
- i, /*physical id*/
- j, /*core id*/
- tc.cores*tc.physicalCores, /*cpu cores*/
- tc.bugs /*bugs*/)
- }
- }
- }
- return ret
-}
-
-func (tc mockCPU) makeSysPossibleString() string {
- max := tc.physicalCores * tc.cores * tc.threadsPerCore
- if max == 1 {
- return "0"
- }
- return fmt.Sprintf("0-%d", max-1)
-}
-
-// TestMockCPUSet tests mock cpu test cases against the cpuSet functions.
-func TestMockCPUSet(t *testing.T) {
- for _, tc := range []struct {
- testCase mockCPU
- isVulnerable bool
- }{
- {
- testCase: amd8,
- isVulnerable: false,
- },
- {
- testCase: haswell2,
- isVulnerable: true,
- },
- {
- testCase: haswell2core,
- isVulnerable: true,
- },
-
- {
- testCase: cascadeLake4,
- isVulnerable: true,
- },
- } {
- t.Run(tc.testCase.name, func(t *testing.T) {
- data := tc.testCase.makeCPUString()
- vulnerable := func(t thread) bool {
- return t.isVulnerable()
- }
- set, err := newCPUSet([]byte(data), vulnerable)
- if err != nil {
- t.Fatalf("Failed to ")
- }
- remaining := set.getRemainingList()
- // In the non-vulnerable case, no cores should be shutdown so all should remain.
- want := tc.testCase.physicalCores * tc.testCase.cores * tc.testCase.threadsPerCore
- if tc.isVulnerable {
- want = tc.testCase.physicalCores * tc.testCase.cores
- }
-
- if want != len(remaining) {
- t.Fatalf("Failed to shutdown the correct number of cores: want: %d got: %d", want, len(remaining))
- }
-
- if !tc.isVulnerable {
- return
- }
-
- // If the set is vulnerable, we expect only 1 thread per hyperthread pair.
- for _, r := range remaining {
- if _, ok := set[r.id]; !ok {
- t.Fatalf("Entry %+v not in map, there must be two entries in the same thread group.", r)
- }
- delete(set, r.id)
- }
-
- possible := tc.testCase.makeSysPossibleString()
- set, err = newCPUSetFromPossible([]byte(possible))
- if err != nil {
- t.Fatalf("Failed to make cpuSet: %v", err)
- }
-
- want = tc.testCase.physicalCores * tc.testCase.cores * tc.testCase.threadsPerCore
- got := len(set.getRemainingList())
- if got != want {
- t.Fatalf("Returned the wrong number of CPUs want: %d got: %d", want, got)
- }
- })
- }
-}
-
-// TestGetCPU tests basic parsing of single CPU strings from reading
-// /proc/cpuinfo.
-func TestGetCPU(t *testing.T) {
- data := `processor : 0
-vendor_id : GenuineIntel
-cpu family : 6
-model : 85
-physical id: 0
-core id : 0
-bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa itlb_multihit
-`
- want := thread{
- processorNumber: 0,
- vendorID: "GenuineIntel",
- cpuFamily: 6,
- model: 85,
- id: cpuID{
- physicalID: 0,
- coreID: 0,
- },
- bugs: map[string]struct{}{
- "cpu_meltdown": struct{}{},
- "spectre_v1": struct{}{},
- "spectre_v2": struct{}{},
- "spec_store_bypass": struct{}{},
- "l1tf": struct{}{},
- "mds": struct{}{},
- "swapgs": struct{}{},
- "taa": struct{}{},
- "itlb_multihit": struct{}{},
- },
- }
-
- got, err := newThread(data)
- if err != nil {
- t.Fatalf("getCpu failed with error: %v", err)
- }
-
- if !want.similarTo(got) {
- t.Fatalf("Failed cpus not similar: got: %+v, want: %+v", got, want)
- }
-
- if !got.isVulnerable() {
- t.Fatalf("Failed: cpu should be vulnerable.")
- }
-}
-
-func TestInvalid(t *testing.T) {
- result, err := getThreads(`something not a processor`)
- if err == nil {
- t.Fatalf("getCPU set didn't return an error: %+v", result)
- }
-
- if !strings.Contains(err.Error(), "no cpus") {
- t.Fatalf("Incorrect error returned: %v", err)
- }
-}
-
-// TestCPUSet tests getting the right number of CPUs from
-// parsing full output of /proc/cpuinfo.
-func TestCPUSet(t *testing.T) {
- data := `processor : 0
-vendor_id : GenuineIntel
-cpu family : 6
-model : 63
-model name : Intel(R) Xeon(R) CPU @ 2.30GHz
-stepping : 0
-microcode : 0x1
-cpu MHz : 2299.998
-cache size : 46080 KB
-physical id : 0
-siblings : 2
-core id : 0
-cpu cores : 1
-apicid : 0
-initial apicid : 0
-fpu : yes
-fpu_exception : yes
-cpuid level : 13
-wp : yes
-flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
-bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs
-bogomips : 4599.99
-clflush size : 64
-cache_alignment : 64
-address sizes : 46 bits physical, 48 bits virtual
-power management:
-
-processor : 1
-vendor_id : GenuineIntel
-cpu family : 6
-model : 63
-model name : Intel(R) Xeon(R) CPU @ 2.30GHz
-stepping : 0
-microcode : 0x1
-cpu MHz : 2299.998
-cache size : 46080 KB
-physical id : 0
-siblings : 2
-core id : 0
-cpu cores : 1
-apicid : 1
-initial apicid : 1
-fpu : yes
-fpu_exception : yes
-cpuid level : 13
-wp : yes
-flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
-bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs
-bogomips : 4599.99
-clflush size : 64
-cache_alignment : 64
-address sizes : 46 bits physical, 48 bits virtual
-power management:
-`
- cpuSet, err := getThreads(data)
- if err != nil {
- t.Fatalf("getCPUSet failed: %v", err)
- }
-
- wantCPULen := 2
- if len(cpuSet) != wantCPULen {
- t.Fatalf("Num CPU mismatch: want: %d, got: %d", wantCPULen, len(cpuSet))
- }
-
- wantCPU := thread{
- vendorID: "GenuineIntel",
- cpuFamily: 6,
- model: 63,
- bugs: map[string]struct{}{
- "cpu_meltdown": struct{}{},
- "spectre_v1": struct{}{},
- "spectre_v2": struct{}{},
- "spec_store_bypass": struct{}{},
- "l1tf": struct{}{},
- "mds": struct{}{},
- "swapgs": struct{}{},
- },
- }
-
- for _, c := range cpuSet {
- if !wantCPU.similarTo(c) {
- t.Fatalf("Failed cpus not equal: got: %+v, want: %+v", c, wantCPU)
- }
- }
-}
-
-// TestReadFile is a smoke test for parsing methods.
-func TestReadFile(t *testing.T) {
- data, err := ioutil.ReadFile("/proc/cpuinfo")
- if err != nil {
- t.Fatalf("Failed to read cpuinfo: %v", err)
- }
-
- vulnerable := func(t thread) bool {
- return t.isVulnerable()
- }
-
- set, err := newCPUSet(data, vulnerable)
- if err != nil {
- t.Fatalf("Failed to parse CPU data %v\n%s", err, data)
- }
-
- if len(set) < 1 {
- t.Fatalf("Failed to parse any CPUs: %d", len(set))
- }
-
- t.Log(set)
-}
-
-// TestVulnerable tests if the isVulnerable method is correct
-// among known CPUs in GCP.
-func TestVulnerable(t *testing.T) {
- const haswell = `processor : 0
-vendor_id : GenuineIntel
-cpu family : 6
-model : 63
-model name : Intel(R) Xeon(R) CPU @ 2.30GHz
-stepping : 0
-microcode : 0x1
-cpu MHz : 2299.998
-cache size : 46080 KB
-physical id : 0
-siblings : 4
-core id : 0
-cpu cores : 2
-apicid : 0
-initial apicid : 0
-fpu : yes
-fpu_exception : yes
-cpuid level : 13
-wp : yes
-flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
-bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs
-bogomips : 4599.99
-clflush size : 64
-cache_alignment : 64
-address sizes : 46 bits physical, 48 bits virtual
-power management:`
-
- const skylake = `processor : 0
-vendor_id : GenuineIntel
-cpu family : 6
-model : 85
-model name : Intel(R) Xeon(R) CPU @ 2.00GHz
-stepping : 3
-microcode : 0x1
-cpu MHz : 2000.180
-cache size : 39424 KB
-physical id : 0
-siblings : 2
-core id : 0
-cpu cores : 1
-apicid : 0
-initial apicid : 0
-fpu : yes
-fpu_exception : yes
-cpuid level : 13
-wp : yes
-flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
-bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
-bogomips : 4000.36
-clflush size : 64
-cache_alignment : 64
-address sizes : 46 bits physical, 48 bits virtual
-power management:`
-
- const cascade = `processor : 0
-vendor_id : GenuineIntel
-cpu family : 6
-model : 85
-model name : Intel(R) Xeon(R) CPU
-stepping : 7
-microcode : 0x1
-cpu MHz : 2800.198
-cache size : 33792 KB
-physical id : 0
-siblings : 2
-core id : 0
-cpu cores : 1
-apicid : 0
-initial apicid : 0
-fpu : yes
-fpu_exception : yes
-cpuid level : 13
-wp : yes
-flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2
- ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmu
-lqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowpr
-efetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid r
-tm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves a
-rat avx512_vnni md_clear arch_capabilities
-bugs : spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa
-bogomips : 5600.39
-clflush size : 64
-cache_alignment : 64
-address sizes : 46 bits physical, 48 bits virtual
-power management:`
-
- const amd = `processor : 0
-vendor_id : AuthenticAMD
-cpu family : 23
-model : 49
-model name : AMD EPYC 7B12
-stepping : 0
-microcode : 0x1000065
-cpu MHz : 2250.000
-cache size : 512 KB
-physical id : 0
-siblings : 2
-core id : 0
-cpu cores : 1
-apicid : 0
-initial apicid : 0
-fpu : yes
-fpu_exception : yes
-cpuid level : 13
-wp : yes
-flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid
-bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass
-bogomips : 4500.00
-TLB size : 3072 4K pages
-clflush size : 64
-cache_alignment : 64
-address sizes : 48 bits physical, 48 bits virtual
-power management:`
-
- for _, tc := range []struct {
- name string
- cpuString string
- vulnerable bool
- }{
- {
- name: "haswell",
- cpuString: haswell,
- vulnerable: true,
- }, {
- name: "skylake",
- cpuString: skylake,
- vulnerable: true,
- }, {
- name: "amd",
- cpuString: amd,
- vulnerable: false,
- },
- } {
- t.Run(tc.name, func(t *testing.T) {
- set, err := getThreads(tc.cpuString)
- if err != nil {
- t.Fatalf("Failed to getCPUSet:%v\n %s", err, tc.cpuString)
- }
-
- if len(set) < 1 {
- t.Fatalf("Returned empty cpu set: %v", set)
- }
-
- for _, c := range set {
- got := func() bool {
- return c.isVulnerable()
- }()
-
- if got != tc.vulnerable {
- t.Fatalf("Mismatch vulnerable for cpu %+s: got %t want: %t", tc.name, tc.vulnerable, got)
- }
- }
- })
- }
-}
-
-func TestReverse(t *testing.T) {
- const noParse = "-1-"
- for _, tc := range []struct {
- name string
- output string
- wantErr error
- wantCount int
- }{
- {
- name: "base",
- output: "0-7",
- wantErr: nil,
- wantCount: 8,
- },
- {
- name: "huge",
- output: "0-111",
- wantErr: nil,
- wantCount: 112,
- },
- {
- name: "not zero",
- output: "50-53",
- wantErr: nil,
- wantCount: 4,
- },
- {
- name: "small",
- output: "0",
- wantErr: nil,
- wantCount: 1,
- },
- {
- name: "invalid order",
- output: "10-6",
- wantErr: fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", 10, 6),
- },
- {
- name: "no parse",
- output: noParse,
- wantErr: fmt.Errorf(`mismatch regex from /sys/devices/system/cpu/possible: %q`, noParse),
- },
- } {
- t.Run(tc.name, func(t *testing.T) {
- threads, err := getThreadsFromPossible([]byte(tc.output))
-
- switch {
- case tc.wantErr == nil:
- if err != nil {
- t.Fatalf("Wanted nil err, got: %v", err)
- }
- case err == nil:
- t.Fatalf("Want error: %v got: %v", tc.wantErr, err)
- default:
- if tc.wantErr.Error() != err.Error() {
- t.Fatalf("Want error: %v got error: %v", tc.wantErr, err)
- }
- }
-
- if len(threads) != tc.wantCount {
- t.Fatalf("Want count: %d got: %d", tc.wantCount, len(threads))
- }
- })
- }
-}
-
-func TestReverseSmoke(t *testing.T) {
- data, err := ioutil.ReadFile(allPossibleCPUs)
- if err != nil {
- t.Fatalf("Failed to read from possible: %v", err)
- }
- threads, err := getThreadsFromPossible(data)
- if err != nil {
- t.Fatalf("Could not parse possible output: %v", err)
- }
-
- if len(threads) <= 0 {
- t.Fatalf("Didn't get any CPU cores: %d", len(threads))
- }
-}
diff --git a/runsc/mitigate/mitigate.go b/runsc/mitigate/mitigate.go
index 91de623e3..24f67414c 100644
--- a/runsc/mitigate/mitigate.go
+++ b/runsc/mitigate/mitigate.go
@@ -14,121 +14,440 @@
// Package mitigate provides libraries for the mitigate command. The
// mitigate command mitigates side channel attacks such as MDS. Mitigate
-// shuts down CPUs via /sys/devices/system/cpu/cpu{N}/online. In addition,
-// the mitigate also handles computing available CPU in kubernetes kube_config
-// files.
+// shuts down CPUs via /sys/devices/system/cpu/cpu{N}/online.
package mitigate
import (
"fmt"
"io/ioutil"
-
- "gvisor.dev/gvisor/pkg/log"
- "gvisor.dev/gvisor/runsc/flag"
+ "os"
+ "regexp"
+ "sort"
+ "strconv"
+ "strings"
)
const (
- cpuInfo = "/proc/cpuinfo"
- allPossibleCPUs = "/sys/devices/system/cpu/possible"
+ // mds is the only bug we care about.
+ mds = "mds"
+
+ // Constants for parsing /proc/cpuinfo.
+ processorKey = "processor"
+ vendorIDKey = "vendor_id"
+ cpuFamilyKey = "cpu family"
+ modelKey = "model"
+ physicalIDKey = "physical id"
+ coreIDKey = "core id"
+ bugsKey = "bugs"
+
+ // Path to shutdown a CPU.
+ cpuOnlineTemplate = "/sys/devices/system/cpu/cpu%d/online"
)
-// Mitigate handles high level mitigate operations provided to runsc.
-type Mitigate struct {
- dryRun bool // Run the command without changing the underlying system.
- reverse bool // Reverse mitigate by turning on all CPU cores.
- other mitigate // Struct holds extra mitigate logic.
- path string // path to read for each operation (e.g. /proc/cpuinfo).
+// CPUSet contains a map of all CPUs on the system, mapped
+// by Physical ID and CoreIDs. threads with the same
+// Core and Physical ID are Hyperthread pairs.
+type CPUSet map[threadID]*ThreadGroup
+
+// NewCPUSet creates a CPUSet from data read from /proc/cpuinfo.
+func NewCPUSet(data []byte, vulnerable func(Thread) bool) (CPUSet, error) {
+ processors, err := getThreads(string(data))
+ if err != nil {
+ return nil, err
+ }
+
+ set := make(CPUSet)
+ for _, p := range processors {
+ // Each ID is of the form physicalID:coreID. Hyperthread pairs
+ // have identical physical and core IDs. We need to match
+ // Hyperthread pairs so that we can shutdown all but one per
+ // pair.
+ core, ok := set[p.id]
+ if !ok {
+ core = &ThreadGroup{}
+ set[p.id] = core
+ }
+ core.isVulnerable = core.isVulnerable || vulnerable(p)
+ core.threads = append(core.threads, p)
+ }
+
+ // We need to make sure we shutdown the lowest number processor per
+ // thread group.
+ for _, tg := range set {
+ sort.Slice(tg.threads, func(i, j int) bool {
+ return tg.threads[i].processorNumber < tg.threads[j].processorNumber
+ })
+ }
+ return set, nil
}
-// Usage implments Usage for cmd.Mitigate.
-func (m Mitigate) Usage() string {
- usageString := `mitigate [flags]
+// NewCPUSetFromPossible makes a cpuSet data read from
+// /sys/devices/system/cpu/possible. This is used in enable operations
+// where the caller simply wants to enable all CPUS.
+func NewCPUSetFromPossible(data []byte) (CPUSet, error) {
+ threads, err := GetThreadsFromPossible(data)
+ if err != nil {
+ return nil, err
+ }
+
+ // We don't care if a CPU is vulnerable or not, we just
+ // want to return a list of all CPUs on the host.
+ set := CPUSet{
+ threads[0].id: &ThreadGroup{
+ threads: threads,
+ isVulnerable: false,
+ },
+ }
+ return set, nil
+}
-Mitigate mitigates a system to the "MDS" vulnerability by implementing a manual shutdown of SMT. The command checks /proc/cpuinfo for cpus having the MDS vulnerability, and if found, shutdown all but one CPU per hyperthread pair via /sys/devices/system/cpu/cpu{N}/online. CPUs can be restored by writing "2" to each file in /sys/devices/system/cpu/cpu{N}/online or performing a system reboot.
+// String implements the String method for CPUSet.
+func (c CPUSet) String() string {
+ ret := ""
+ for _, tg := range c {
+ ret += fmt.Sprintf("%s\n", tg)
+ }
+ return ret
+}
-The command can be reversed with --reverse, which reads the total CPUs from /sys/devices/system/cpu/possible and enables all with /sys/devices/system/cpu/cpu{N}/online.
-`
- return usageString + m.other.usage()
+// GetRemainingList returns the list of threads that will remain active
+// after mitigation.
+func (c CPUSet) GetRemainingList() []Thread {
+ threads := make([]Thread, 0, len(c))
+ for _, core := range c {
+ // If we're vulnerable, take only one thread from the pair.
+ if core.isVulnerable {
+ threads = append(threads, core.threads[0])
+ continue
+ }
+ // Otherwise don't shutdown anything.
+ threads = append(threads, core.threads...)
+ }
+ return threads
}
-// SetFlags sets flags for the command Mitigate.
-func (m Mitigate) SetFlags(f *flag.FlagSet) {
- f.BoolVar(&m.dryRun, "dryrun", false, "run the command without changing system")
- f.BoolVar(&m.reverse, "reverse", false, "reverse mitigate by enabling all CPUs")
- m.other.setFlags(f)
- m.path = cpuInfo
- if m.reverse {
- m.path = allPossibleCPUs
+// GetShutdownList returns the list of threads that will be shutdown on
+// mitigation.
+func (c CPUSet) GetShutdownList() []Thread {
+ threads := make([]Thread, 0)
+ for _, core := range c {
+ // Only if we're vulnerable do shutdown anything. In this case,
+ // shutdown all but the first entry.
+ if core.isVulnerable && len(core.threads) > 1 {
+ threads = append(threads, core.threads[1:]...)
+ }
}
+ return threads
}
-// Execute executes the Mitigate command.
-func (m Mitigate) Execute() error {
- data, err := ioutil.ReadFile(m.path)
- if err != nil {
- return fmt.Errorf("failed to read %s: %v", m.path, err)
+// ThreadGroup represents Hyperthread pairs on the same physical/core ID.
+type ThreadGroup struct {
+ threads []Thread
+ isVulnerable bool
+}
+
+// String implements the String method for threadGroup.
+func (c ThreadGroup) String() string {
+ ret := fmt.Sprintf("ThreadGroup:\nIsVulnerable: %t\n", c.isVulnerable)
+ for _, processor := range c.threads {
+ ret += fmt.Sprintf("%s\n", processor)
}
+ return ret
+}
- if m.reverse {
- err := m.doReverse(data)
+// getThreads returns threads structs from reading /proc/cpuinfo.
+func getThreads(data string) ([]Thread, error) {
+ // Each processor entry should start with the
+ // processor key. Find the beginings of each.
+ r := buildRegex(processorKey, `\d+`)
+ indices := r.FindAllStringIndex(data, -1)
+ if len(indices) < 1 {
+ return nil, fmt.Errorf("no cpus found for: %q", data)
+ }
+
+ // Add the ending index for last entry.
+ indices = append(indices, []int{len(data), -1})
+
+ // Valid cpus are now defined by strings in between
+ // indexes (e.g. data[index[i], index[i+1]]).
+ // There should be len(indicies) - 1 CPUs
+ // since the last index is the end of the string.
+ cpus := make([]Thread, 0, len(indices))
+ // Find each string that represents a CPU. These begin "processor".
+ for i := 1; i < len(indices); i++ {
+ start := indices[i-1][0]
+ end := indices[i][0]
+ // Parse the CPU entry, which should be between start/end.
+ c, err := newThread(data[start:end])
if err != nil {
- return fmt.Errorf("reverse operation failed: %v", err)
+ return nil, err
}
- return nil
+ cpus = append(cpus, c)
+ }
+ return cpus, nil
+}
+
+// GetThreadsFromPossible makes threads from data read from /sys/devices/system/cpu/possible.
+func GetThreadsFromPossible(data []byte) ([]Thread, error) {
+ possibleRegex := regexp.MustCompile(`(?m)^(\d+)(-(\d+))?$`)
+ matches := possibleRegex.FindStringSubmatch(string(data))
+ if len(matches) != 4 {
+ return nil, fmt.Errorf("mismatch regex from possible: %q", string(data))
+ }
+
+ // If matches[3] is empty, we only have one cpu entry.
+ if matches[3] == "" {
+ matches[3] = matches[1]
}
- set, err := m.doMitigate(data)
+ begin, err := strconv.ParseInt(matches[1], 10, 64)
if err != nil {
- return fmt.Errorf("mitigate operation failed: %v", err)
+ return nil, fmt.Errorf("failed to parse begin: %v", err)
}
- return m.other.execute(set, m.dryRun)
+ end, err := strconv.ParseInt(matches[3], 10, 64)
+ if err != nil {
+ return nil, fmt.Errorf("failed to parse end: %v", err)
+ }
+ if begin > end || begin < 0 || end < 0 {
+ return nil, fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", begin, end)
+ }
+
+ ret := make([]Thread, 0, end-begin)
+ for i := begin; i <= end; i++ {
+ ret = append(ret, Thread{
+ processorNumber: i,
+ id: threadID{
+ physicalID: 0, // we don't care about id for enable ops.
+ coreID: 0,
+ },
+ })
+ }
+
+ return ret, nil
+}
+
+// threadID for each thread is defined by the physical and
+// core IDs. If equal, two threads are Hyperthread pairs.
+type threadID struct {
+ physicalID int64
+ coreID int64
}
-func (m Mitigate) doMitigate(data []byte) (cpuSet, error) {
- set, err := newCPUSet(data, m.other.vulnerable)
+// Thread represents pertinent info about a single hyperthread in a pair.
+type Thread struct {
+ processorNumber int64 // the processor number of this CPU.
+ vendorID string // the vendorID of CPU (e.g. AuthenticAMD).
+ cpuFamily int64 // CPU family number (e.g. 6 for CascadeLake/Skylake).
+ model int64 // CPU model number (e.g. 85 for CascadeLake/Skylake).
+ id threadID // id for this thread
+ bugs map[string]struct{} // map of vulnerabilities parsed from the 'bugs' field.
+}
+
+// newThread parses a CPU from a single cpu entry from /proc/cpuinfo.
+func newThread(data string) (Thread, error) {
+ empty := Thread{}
+ processor, err := parseProcessor(data)
if err != nil {
- return nil, err
+ return empty, err
}
- log.Infof("Mitigate found the following CPUs...")
- log.Infof("%s", set)
+ vendorID, err := parseVendorID(data)
+ if err != nil {
+ return empty, err
+ }
- disableList := set.getShutdownList()
- log.Infof("Disabling threads on thread pairs.")
- for _, t := range disableList {
- log.Infof("Disable thread: %s", t)
- if m.dryRun {
- continue
- }
- if err := t.disable(); err != nil {
- return nil, fmt.Errorf("error disabling thread: %s err: %v", t, err)
- }
+ cpuFamily, err := parseCPUFamily(data)
+ if err != nil {
+ return empty, err
}
- log.Infof("Shutdown successful.")
- return set, nil
+
+ model, err := parseModel(data)
+ if err != nil {
+ return empty, err
+ }
+
+ physicalID, err := parsePhysicalID(data)
+ if err != nil {
+ return empty, err
+ }
+
+ coreID, err := parseCoreID(data)
+ if err != nil {
+ return empty, err
+ }
+
+ bugs, err := parseBugs(data)
+ if err != nil {
+ return empty, err
+ }
+
+ return Thread{
+ processorNumber: processor,
+ vendorID: vendorID,
+ cpuFamily: cpuFamily,
+ model: model,
+ id: threadID{
+ physicalID: physicalID,
+ coreID: coreID,
+ },
+ bugs: bugs,
+ }, nil
+}
+
+// String implements the String method for thread.
+func (t Thread) String() string {
+ template := `CPU: %d
+CPU ID: %+v
+Vendor: %s
+Family/Model: %d/%d
+Bugs: %s
+`
+ bugs := make([]string, 0)
+ for bug := range t.bugs {
+ bugs = append(bugs, bug)
+ }
+
+ return fmt.Sprintf(template, t.processorNumber, t.id, t.vendorID, t.cpuFamily, t.model, strings.Join(bugs, ","))
+}
+
+// Enable turns on the CPU by writing 1 to /sys/devices/cpu/cpu{N}/online.
+func (t Thread) Enable() error {
+ // Linux ensures that "cpu0" is always online.
+ if t.processorNumber == 0 {
+ return nil
+ }
+ cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber)
+ f, err := os.OpenFile(cpuPath, os.O_WRONLY|os.O_CREATE, 0644)
+ if err != nil {
+ return fmt.Errorf("failed to open file %s: %v", cpuPath, err)
+ }
+ if _, err = f.Write([]byte{'1'}); err != nil {
+ return fmt.Errorf("failed to write '1' to %s: %v", cpuPath, err)
+ }
+ return nil
+}
+
+// Disable turns off the CPU by writing 0 to /sys/devices/cpu/cpu{N}/online.
+func (t Thread) Disable() error {
+ // The core labeled "cpu0" can never be taken offline via this method.
+ // Linux will return EPERM if the user even creates a file at the /sys
+ // path above.
+ if t.processorNumber == 0 {
+ return fmt.Errorf("invalid shutdown operation: cpu0 cannot be disabled")
+ }
+ cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber)
+ return ioutil.WriteFile(cpuPath, []byte{'0'}, 0644)
}
-func (m Mitigate) doReverse(data []byte) error {
- set, err := newCPUSetFromPossible(data)
+// IsVulnerable checks if a CPU is vulnerable to mds.
+func (t Thread) IsVulnerable() bool {
+ _, ok := t.bugs[mds]
+ return ok
+}
+
+// isActive checks if a CPU is active from /sys/devices/system/cpu/cpu{N}/online
+// If the file does not exist (ioutil returns in error), we assume the CPU is on.
+func (t Thread) isActive() bool {
+ cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber)
+ data, err := ioutil.ReadFile(cpuPath)
if err != nil {
- return err
+ return true
}
+ return len(data) > 0 && data[0] != '0'
+}
- log.Infof("Reverse mitigate found the following CPUs...")
- log.Infof("%s", set)
+// SimilarTo checks family/model/bugs fields for equality of two
+// processors.
+func (t Thread) SimilarTo(other Thread) bool {
+ if t.vendorID != other.vendorID {
+ return false
+ }
- enableList := set.getRemainingList()
+ if other.cpuFamily != t.cpuFamily {
+ return false
+ }
- log.Infof("Enabling all CPUs...")
- for _, t := range enableList {
- log.Infof("Enabling thread: %s", t)
- if m.dryRun {
- continue
- }
- if err := t.enable(); err != nil {
- return fmt.Errorf("error enabling thread: %s err: %v", t, err)
+ if other.model != t.model {
+ return false
+ }
+
+ if len(other.bugs) != len(t.bugs) {
+ return false
+ }
+
+ for bug := range t.bugs {
+ if _, ok := other.bugs[bug]; !ok {
+ return false
}
}
- log.Infof("Enable successful.")
- return nil
+ return true
+}
+
+// parseProcessor grabs the processor field from /proc/cpuinfo output.
+func parseProcessor(data string) (int64, error) {
+ return parseIntegerResult(data, processorKey)
+}
+
+// parseVendorID grabs the vendor_id field from /proc/cpuinfo output.
+func parseVendorID(data string) (string, error) {
+ return parseRegex(data, vendorIDKey, `[\w\d]+`)
+}
+
+// parseCPUFamily grabs the cpu family field from /proc/cpuinfo output.
+func parseCPUFamily(data string) (int64, error) {
+ return parseIntegerResult(data, cpuFamilyKey)
+}
+
+// parseModel grabs the model field from /proc/cpuinfo output.
+func parseModel(data string) (int64, error) {
+ return parseIntegerResult(data, modelKey)
+}
+
+// parsePhysicalID parses the physical id field.
+func parsePhysicalID(data string) (int64, error) {
+ return parseIntegerResult(data, physicalIDKey)
+}
+
+// parseCoreID parses the core id field.
+func parseCoreID(data string) (int64, error) {
+ return parseIntegerResult(data, coreIDKey)
+}
+
+// parseBugs grabs the bugs field from /proc/cpuinfo output.
+func parseBugs(data string) (map[string]struct{}, error) {
+ result, err := parseRegex(data, bugsKey, `[\d\w\s]*`)
+ if err != nil {
+ return nil, err
+ }
+ bugs := strings.Split(result, " ")
+ ret := make(map[string]struct{}, len(bugs))
+ for _, bug := range bugs {
+ ret[bug] = struct{}{}
+ }
+ return ret, nil
+}
+
+// parseIntegerResult parses fields expecting an integer.
+func parseIntegerResult(data, key string) (int64, error) {
+ result, err := parseRegex(data, key, `\d+`)
+ if err != nil {
+ return 0, err
+ }
+ return strconv.ParseInt(result, 0, 64)
+}
+
+// buildRegex builds a regex for parsing each CPU field.
+func buildRegex(key, match string) *regexp.Regexp {
+ reg := fmt.Sprintf(`(?m)^%s\s*:\s*(.*)$`, key)
+ return regexp.MustCompile(reg)
+}
+
+// parseRegex parses data with key inserted into a standard regex template.
+func parseRegex(data, key, match string) (string, error) {
+ r := buildRegex(key, match)
+ matches := r.FindStringSubmatch(data)
+ if len(matches) < 2 {
+ return "", fmt.Errorf("failed to match key %q: %q", key, data)
+ }
+ return matches[1], nil
}
diff --git a/runsc/mitigate/mitigate_test.go b/runsc/mitigate/mitigate_test.go
index b3a9a9b18..fbd8eb886 100644
--- a/runsc/mitigate/mitigate_test.go
+++ b/runsc/mitigate/mitigate_test.go
@@ -17,138 +17,519 @@ package mitigate
import (
"fmt"
"io/ioutil"
- "os"
"strings"
"testing"
+
+ "gvisor.dev/gvisor/runsc/mitigate/mock"
)
-type executeTestCase struct {
- name string
- mitigateData string
- mitigateError error
- reverseData string
- reverseError error
+// TestMockCPUSet tests mock cpu test cases against the cpuSet functions.
+func TestMockCPUSet(t *testing.T) {
+ for _, tc := range []struct {
+ testCase mock.CPU
+ isVulnerable bool
+ }{
+ {
+ testCase: mock.AMD8,
+ isVulnerable: false,
+ },
+ {
+ testCase: mock.Haswell2,
+ isVulnerable: true,
+ },
+ {
+ testCase: mock.Haswell2core,
+ isVulnerable: true,
+ },
+ {
+ testCase: mock.CascadeLake2,
+ isVulnerable: true,
+ },
+ {
+ testCase: mock.CascadeLake4,
+ isVulnerable: true,
+ },
+ } {
+ t.Run(tc.testCase.Name, func(t *testing.T) {
+ data := tc.testCase.MakeCPUString()
+ vulnerable := func(t Thread) bool {
+ return t.IsVulnerable()
+ }
+ set, err := NewCPUSet([]byte(data), vulnerable)
+ if err != nil {
+ t.Fatalf("Failed to create cpuSet: %v", err)
+ }
+
+ for _, tg := range set {
+ if err := checkSorted(tg.threads); err != nil {
+ t.Fatalf("Failed to sort cpuSet: %v", err)
+ }
+ }
+
+ remaining := set.GetRemainingList()
+ // In the non-vulnerable case, no cores should be shutdown so all should remain.
+ want := tc.testCase.PhysicalCores * tc.testCase.Cores * tc.testCase.ThreadsPerCore
+ if tc.isVulnerable {
+ want = tc.testCase.PhysicalCores * tc.testCase.Cores
+ }
+
+ if want != len(remaining) {
+ t.Fatalf("Failed to shutdown the correct number of cores: want: %d got: %d", want, len(remaining))
+ }
+
+ if !tc.isVulnerable {
+ return
+ }
+
+ // If the set is vulnerable, we expect only 1 thread per hyperthread pair.
+ for _, r := range remaining {
+ if _, ok := set[r.id]; !ok {
+ t.Fatalf("Entry %+v not in map, there must be two entries in the same thread group.", r)
+ }
+ delete(set, r.id)
+ }
+
+ possible := tc.testCase.MakeSysPossibleString()
+ set, err = NewCPUSetFromPossible([]byte(possible))
+ if err != nil {
+ t.Fatalf("Failed to make cpuSet: %v", err)
+ }
+
+ want = tc.testCase.PhysicalCores * tc.testCase.Cores * tc.testCase.ThreadsPerCore
+ got := len(set.GetRemainingList())
+ if got != want {
+ t.Fatalf("Returned the wrong number of CPUs want: %d got: %d", want, got)
+ }
+ })
+ }
}
-func TestExecute(t *testing.T) {
+// TestGetCPU tests basic parsing of single CPU strings from reading
+// /proc/cpuinfo.
+func TestGetCPU(t *testing.T) {
+ data := `processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 85
+physical id: 0
+core id : 0
+bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa itlb_multihit
+`
+ want := Thread{
+ processorNumber: 0,
+ vendorID: "GenuineIntel",
+ cpuFamily: 6,
+ model: 85,
+ id: threadID{
+ physicalID: 0,
+ coreID: 0,
+ },
+ bugs: map[string]struct{}{
+ "cpu_meltdown": struct{}{},
+ "spectre_v1": struct{}{},
+ "spectre_v2": struct{}{},
+ "spec_store_bypass": struct{}{},
+ "l1tf": struct{}{},
+ "mds": struct{}{},
+ "swapgs": struct{}{},
+ "taa": struct{}{},
+ "itlb_multihit": struct{}{},
+ },
+ }
- partial := `processor : 1
-vendor_id : AuthenticAMD
-cpu family : 23
-model : 49
-model name : AMD EPYC 7B12
-physical id : 0
-bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass
+ got, err := newThread(data)
+ if err != nil {
+ t.Fatalf("getCpu failed with error: %v", err)
+ }
+
+ if !want.SimilarTo(got) {
+ t.Fatalf("Failed cpus not similar: got: %+v, want: %+v", got, want)
+ }
+
+ if !got.IsVulnerable() {
+ t.Fatalf("Failed: cpu should be vulnerable.")
+ }
+}
+
+func TestInvalid(t *testing.T) {
+ result, err := getThreads(`something not a processor`)
+ if err == nil {
+ t.Fatalf("getCPU set didn't return an error: %+v", result)
+ }
+
+ if !strings.Contains(err.Error(), "no cpus") {
+ t.Fatalf("Incorrect error returned: %v", err)
+ }
+}
+
+// TestCPUSet tests getting the right number of CPUs from
+// parsing full output of /proc/cpuinfo.
+func TestCPUSet(t *testing.T) {
+ data := `processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 63
+model name : Intel(R) Xeon(R) CPU @ 2.30GHz
+stepping : 0
+microcode : 0x1
+cpu MHz : 2299.998
+cache size : 46080 KB
+physical id : 0
+siblings : 2
+core id : 0
+cpu cores : 1
+apicid : 0
+initial apicid : 0
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
+bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs
+bogomips : 4599.99
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
+power management:
+
+processor : 1
+vendor_id : GenuineIntel
+cpu family : 6
+model : 63
+model name : Intel(R) Xeon(R) CPU @ 2.30GHz
+stepping : 0
+microcode : 0x1
+cpu MHz : 2299.998
+cache size : 46080 KB
+physical id : 0
+siblings : 2
+core id : 0
+cpu cores : 1
+apicid : 1
+initial apicid : 1
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
+bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs
+bogomips : 4599.99
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
power management:
`
+ cpuSet, err := getThreads(data)
+ if err != nil {
+ t.Fatalf("getCPUSet failed: %v", err)
+ }
- for _, tc := range []executeTestCase{
- {
- name: "CascadeLake4",
- mitigateData: cascadeLake4.makeCPUString(),
- reverseData: cascadeLake4.makeSysPossibleString(),
- },
- {
- name: "Empty",
- mitigateData: "",
- mitigateError: fmt.Errorf(`mitigate operation failed: no cpus found for: ""`),
- reverseData: "",
- reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from %s: ""`, allPossibleCPUs),
+ wantCPULen := 2
+ if len(cpuSet) != wantCPULen {
+ t.Fatalf("Num CPU mismatch: want: %d, got: %d", wantCPULen, len(cpuSet))
+ }
+
+ wantCPU := Thread{
+ vendorID: "GenuineIntel",
+ cpuFamily: 6,
+ model: 63,
+ bugs: map[string]struct{}{
+ "cpu_meltdown": struct{}{},
+ "spectre_v1": struct{}{},
+ "spectre_v2": struct{}{},
+ "spec_store_bypass": struct{}{},
+ "l1tf": struct{}{},
+ "mds": struct{}{},
+ "swapgs": struct{}{},
},
- {
- name: "Partial",
- mitigateData: `processor : 0
+ }
+
+ for _, c := range cpuSet {
+ if !wantCPU.SimilarTo(c) {
+ t.Fatalf("Failed cpus not equal: got: %+v, want: %+v", c, wantCPU)
+ }
+ }
+}
+
+// TestReadFile is a smoke test for parsing methods.
+func TestReadFile(t *testing.T) {
+ data, err := ioutil.ReadFile("/proc/cpuinfo")
+ if err != nil {
+ t.Fatalf("Failed to read cpuinfo: %v", err)
+ }
+
+ vulnerable := func(t Thread) bool {
+ return t.IsVulnerable()
+ }
+
+ set, err := NewCPUSet(data, vulnerable)
+ if err != nil {
+ t.Fatalf("Failed to parse CPU data %v\n%s", err, data)
+ }
+
+ for _, tg := range set {
+ if err := checkSorted(tg.threads); err != nil {
+ t.Fatalf("Failed to sort cpuSet: %v", err)
+ }
+ }
+
+ if len(set) < 1 {
+ t.Fatalf("Failed to parse any CPUs: %d", len(set))
+ }
+
+ t.Log(set)
+}
+
+// TestVulnerable tests if the isVulnerable method is correct
+// among known CPUs in GCP.
+func TestVulnerable(t *testing.T) {
+ const haswell = `processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 63
+model name : Intel(R) Xeon(R) CPU @ 2.30GHz
+stepping : 0
+microcode : 0x1
+cpu MHz : 2299.998
+cache size : 46080 KB
+physical id : 0
+siblings : 4
+core id : 0
+cpu cores : 2
+apicid : 0
+initial apicid : 0
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
+bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs
+bogomips : 4599.99
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
+power management:`
+
+ const skylake = `processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 85
+model name : Intel(R) Xeon(R) CPU @ 2.00GHz
+stepping : 3
+microcode : 0x1
+cpu MHz : 2000.180
+cache size : 39424 KB
+physical id : 0
+siblings : 2
+core id : 0
+cpu cores : 1
+apicid : 0
+initial apicid : 0
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
+bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa
+bogomips : 4000.36
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
+power management:`
+
+ const cascade = `processor : 0
+vendor_id : GenuineIntel
+cpu family : 6
+model : 85
+model name : Intel(R) Xeon(R) CPU
+stepping : 7
+microcode : 0x1
+cpu MHz : 2800.198
+cache size : 33792 KB
+physical id : 0
+siblings : 2
+core id : 0
+cpu cores : 1
+apicid : 0
+initial apicid : 0
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2
+ ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmu
+lqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowpr
+efetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid r
+tm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves a
+rat avx512_vnni md_clear arch_capabilities
+bugs : spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa
+bogomips : 5600.39
+clflush size : 64
+cache_alignment : 64
+address sizes : 46 bits physical, 48 bits virtual
+power management:`
+
+ const amd = `processor : 0
vendor_id : AuthenticAMD
cpu family : 23
model : 49
model name : AMD EPYC 7B12
+stepping : 0
+microcode : 0x1000065
+cpu MHz : 2250.000
+cache size : 512 KB
physical id : 0
+siblings : 2
core id : 0
cpu cores : 1
+apicid : 0
+initial apicid : 0
+fpu : yes
+fpu_exception : yes
+cpuid level : 13
+wp : yes
+flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid
bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass
-power management:
+bogomips : 4500.00
+TLB size : 3072 4K pages
+clflush size : 64
+cache_alignment : 64
+address sizes : 48 bits physical, 48 bits virtual
+power management:`
-` + partial,
- mitigateError: fmt.Errorf(`mitigate operation failed: failed to match key "core id": %q`, partial),
- reverseData: "1-",
- reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from %s: %q`, allPossibleCPUs, "1-"),
+ for _, tc := range []struct {
+ name string
+ cpuString string
+ vulnerable bool
+ }{
+ {
+ name: "haswell",
+ cpuString: haswell,
+ vulnerable: true,
+ }, {
+ name: "skylake",
+ cpuString: skylake,
+ vulnerable: true,
+ }, {
+ name: "amd",
+ cpuString: amd,
+ vulnerable: false,
},
} {
- doExecuteTest(t, Mitigate{}, tc)
+ t.Run(tc.name, func(t *testing.T) {
+ set, err := getThreads(tc.cpuString)
+ if err != nil {
+ t.Fatalf("Failed to getCPUSet:%v\n %s", err, tc.cpuString)
+ }
+
+ if len(set) < 1 {
+ t.Fatalf("Returned empty cpu set: %v", set)
+ }
+
+ for _, c := range set {
+ got := func() bool {
+ return c.IsVulnerable()
+ }()
+
+ if got != tc.vulnerable {
+ t.Fatalf("Mismatch vulnerable for cpu %+s: got %t want: %t", tc.name, tc.vulnerable, got)
+ }
+ }
+ })
}
}
-func TestExecuteSmoke(t *testing.T) {
- smokeMitigate, err := ioutil.ReadFile(cpuInfo)
+func TestReverse(t *testing.T) {
+ const noParse = "-1-"
+ for _, tc := range []struct {
+ name string
+ output string
+ wantErr error
+ wantCount int
+ }{
+ {
+ name: "base",
+ output: "0-7",
+ wantErr: nil,
+ wantCount: 8,
+ },
+ {
+ name: "huge",
+ output: "0-111",
+ wantErr: nil,
+ wantCount: 112,
+ },
+ {
+ name: "not zero",
+ output: "50-53",
+ wantErr: nil,
+ wantCount: 4,
+ },
+ {
+ name: "small",
+ output: "0",
+ wantErr: nil,
+ wantCount: 1,
+ },
+ {
+ name: "invalid order",
+ output: "10-6",
+ wantErr: fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", 10, 6),
+ },
+ {
+ name: "no parse",
+ output: noParse,
+ wantErr: fmt.Errorf(`mismatch regex from possible: %q`, noParse),
+ },
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ threads, err := GetThreadsFromPossible([]byte(tc.output))
+
+ switch {
+ case tc.wantErr == nil:
+ if err != nil {
+ t.Fatalf("Wanted nil err, got: %v", err)
+ }
+ case err == nil:
+ t.Fatalf("Want error: %v got: %v", tc.wantErr, err)
+ default:
+ if tc.wantErr.Error() != err.Error() {
+ t.Fatalf("Want error: %v got error: %v", tc.wantErr, err)
+ }
+ }
+
+ if len(threads) != tc.wantCount {
+ t.Fatalf("Want count: %d got: %d", tc.wantCount, len(threads))
+ }
+ })
+ }
+}
+
+func TestReverseSmoke(t *testing.T) {
+ data, err := ioutil.ReadFile("/sys/devices/system/cpu/possible")
if err != nil {
- t.Fatalf("Failed to read %s: %v", cpuInfo, err)
+ t.Fatalf("Failed to read from possible: %v", err)
}
- smokeReverse, err := ioutil.ReadFile(allPossibleCPUs)
+ threads, err := GetThreadsFromPossible(data)
if err != nil {
- t.Fatalf("Failed to read %s: %v", allPossibleCPUs, err)
+ t.Fatalf("Could not parse possible output: %v", err)
}
- doExecuteTest(t, Mitigate{}, executeTestCase{
- name: "SmokeTest",
- mitigateData: string(smokeMitigate),
- reverseData: string(smokeReverse),
- })
+ if len(threads) <= 0 {
+ t.Fatalf("Didn't get any CPU cores: %d", len(threads))
+ }
}
-// doExecuteTest runs Execute with the mitigate operation and reverse operation.
-func doExecuteTest(t *testing.T, m Mitigate, tc executeTestCase) {
- t.Run("Mitigate"+tc.name, func(t *testing.T) {
- m.dryRun = true
- file, err := ioutil.TempFile("", "outfile.txt")
- if err != nil {
- t.Fatalf("Failed to create tmpfile: %v", err)
- }
- defer os.Remove(file.Name())
-
- if _, err := file.WriteString(tc.mitigateData); err != nil {
- t.Fatalf("Failed to write to file: %v", err)
- }
-
- m.path = file.Name()
-
- got := m.Execute()
- if err = checkErr(tc.mitigateError, got); err != nil {
- t.Fatalf("Mitigate error mismatch: %v", err)
- }
- })
- t.Run("Reverse"+tc.name, func(t *testing.T) {
- m.dryRun = true
- m.reverse = true
-
- file, err := ioutil.TempFile("", "outfile.txt")
- if err != nil {
- t.Fatalf("Failed to create tmpfile: %v", err)
- }
- defer os.Remove(file.Name())
-
- if _, err := file.WriteString(tc.reverseData); err != nil {
- t.Fatalf("Failed to write to file: %v", err)
- }
-
- m.path = file.Name()
- got := m.Execute()
- if err = checkErr(tc.reverseError, got); err != nil {
- t.Fatalf("Mitigate error mismatch: %v", err)
+func checkSorted(threads []Thread) error {
+ if len(threads) < 2 {
+ return nil
+ }
+ last := threads[0].processorNumber
+ for _, t := range threads[1:] {
+ if last >= t.processorNumber {
+ return fmt.Errorf("threads out of order: thread %d before %d", t.processorNumber, last)
}
- })
-
-}
-
-// checkErr checks error for equality.
-func checkErr(want, got error) error {
- switch {
- case want == nil && got == nil:
- case want != nil && got == nil:
- fallthrough
- case want == nil && got != nil:
- fallthrough
- case want.Error() != strings.Trim(got.Error(), " "):
- return fmt.Errorf("got: %v want: %v", got, want)
+ last = t.processorNumber
}
return nil
}
diff --git a/runsc/mitigate/mock/BUILD b/runsc/mitigate/mock/BUILD
new file mode 100644
index 000000000..5019ff9ee
--- /dev/null
+++ b/runsc/mitigate/mock/BUILD
@@ -0,0 +1,11 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "mock",
+ srcs = ["mock.go"],
+ visibility = [
+ "//runsc:__subpackages__",
+ ],
+)
diff --git a/runsc/mitigate/mock/mock.go b/runsc/mitigate/mock/mock.go
new file mode 100644
index 000000000..2db718cb9
--- /dev/null
+++ b/runsc/mitigate/mock/mock.go
@@ -0,0 +1,141 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package mock contains mock CPUs for mitigate tests.
+package mock
+
+import "fmt"
+
+// CPU represents data from CPUs that will be mitigated.
+type CPU struct {
+ Name string
+ VendorID string
+ Family int
+ Model int
+ ModelName string
+ Bugs string
+ PhysicalCores int
+ Cores int
+ ThreadsPerCore int
+}
+
+// CascadeLake2 is a two core Intel CascadeLake machine.
+var CascadeLake2 = CPU{
+ Name: "CascadeLake",
+ VendorID: "GenuineIntel",
+ Family: 6,
+ Model: 85,
+ ModelName: "Intel(R) Xeon(R) CPU",
+ Bugs: "spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa",
+ PhysicalCores: 1,
+ Cores: 1,
+ ThreadsPerCore: 2,
+}
+
+// CascadeLake4 is a four core Intel CascadeLake machine.
+var CascadeLake4 = CPU{
+ Name: "CascadeLake",
+ VendorID: "GenuineIntel",
+ Family: 6,
+ Model: 85,
+ ModelName: "Intel(R) Xeon(R) CPU",
+ Bugs: "spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa",
+ PhysicalCores: 1,
+ Cores: 2,
+ ThreadsPerCore: 2,
+}
+
+// Haswell2 is a two core Intel Haswell machine.
+var Haswell2 = CPU{
+ Name: "Haswell",
+ VendorID: "GenuineIntel",
+ Family: 6,
+ Model: 63,
+ ModelName: "Intel(R) Xeon(R) CPU",
+ Bugs: "cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs",
+ PhysicalCores: 1,
+ Cores: 1,
+ ThreadsPerCore: 2,
+}
+
+// Haswell2core is a 2 core Intel Haswell machine with no hyperthread pairs.
+var Haswell2core = CPU{
+ Name: "Haswell2Physical",
+ VendorID: "GenuineIntel",
+ Family: 6,
+ Model: 63,
+ ModelName: "Intel(R) Xeon(R) CPU",
+ Bugs: "cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs",
+ PhysicalCores: 2,
+ Cores: 1,
+ ThreadsPerCore: 1,
+}
+
+// AMD8 is an eight core AMD machine.
+var AMD8 = CPU{
+ Name: "AMD",
+ VendorID: "AuthenticAMD",
+ Family: 23,
+ Model: 49,
+ ModelName: "AMD EPYC 7B12",
+ Bugs: "sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass",
+ PhysicalCores: 4,
+ Cores: 1,
+ ThreadsPerCore: 2,
+}
+
+// MakeCPUString makes a string formated like /proc/cpuinfo for each cpuTestCase
+func (tc CPU) MakeCPUString() string {
+ template := `processor : %d
+vendor_id : %s
+cpu family : %d
+model : %d
+model name : %s
+physical id : %d
+core id : %d
+cpu cores : %d
+bugs : %s
+
+`
+
+ ret := ``
+ for i := 0; i < tc.PhysicalCores; i++ {
+ for j := 0; j < tc.Cores; j++ {
+ for k := 0; k < tc.ThreadsPerCore; k++ {
+ processorNum := (i*tc.Cores+j)*tc.ThreadsPerCore + k
+ ret += fmt.Sprintf(template,
+ processorNum, /*processor*/
+ tc.VendorID, /*vendor_id*/
+ tc.Family, /*cpu family*/
+ tc.Model, /*model*/
+ tc.ModelName, /*model name*/
+ i, /*physical id*/
+ j, /*core id*/
+ tc.Cores*tc.PhysicalCores, /*cpu cores*/
+ tc.Bugs, /*bugs*/
+ )
+ }
+ }
+ }
+ return ret
+}
+
+// MakeSysPossibleString makes a string representing a the contents of /sys/devices/system/cpu/possible.
+func (tc CPU) MakeSysPossibleString() string {
+ max := tc.PhysicalCores * tc.Cores * tc.ThreadsPerCore
+ if max == 1 {
+ return "0"
+ }
+ return fmt.Sprintf("0-%d", max-1)
+}
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index 9e429f7d5..f69558021 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -21,7 +21,6 @@ import (
"path/filepath"
"runtime"
"strconv"
- "syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/vishvananda/netlink"
@@ -102,11 +101,11 @@ func joinNetNS(nsPath string) (func(), error) {
// isRootNS determines whether we are running in the root net namespace.
// /proc/sys/net/core/rmem_default only exists in root network namespace.
func isRootNS() (bool, error) {
- err := syscall.Access("/proc/sys/net/core/rmem_default", syscall.F_OK)
+ err := unix.Access("/proc/sys/net/core/rmem_default", unix.F_OK)
switch err {
case nil:
return true, nil
- case syscall.ENOENT:
+ case unix.ENOENT:
return false, nil
default:
return false, fmt.Errorf("failed to access /proc/sys/net/core/rmem_default: %v", err)
@@ -270,17 +269,17 @@ type socketEntry struct {
func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (*socketEntry, error) {
// Create the socket.
const protocol = 0x0300 // htons(ETH_P_ALL)
- fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol)
+ fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, protocol)
if err != nil {
return nil, fmt.Errorf("unable to create raw socket: %v", err)
}
deviceFile := os.NewFile(uintptr(fd), "raw-device-fd")
// Bind to the appropriate device.
- ll := syscall.SockaddrLinklayer{
+ ll := unix.SockaddrLinklayer{
Protocol: protocol,
Ifindex: iface.Index,
}
- if err := syscall.Bind(fd, &ll); err != nil {
+ if err := unix.Bind(fd, &ll); err != nil {
return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
}
@@ -291,7 +290,7 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (
return nil, fmt.Errorf("getting GSO for interface %q: %v", iface.Name, err)
}
if gso {
- if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
+ if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err)
}
gsoMaxSize = ifaceLink.Attrs().GSOMaxSize
@@ -307,18 +306,18 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (
// incurring packet drops.
const bufSize = 4 << 20 // 4MB.
- if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, bufSize); err != nil {
- syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize)
- sz, _ := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF)
+ if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, bufSize); err != nil {
+ unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, bufSize)
+ sz, _ := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF)
if sz < bufSize {
log.Warningf("Failed to increase rcv buffer to %d on SOCK_RAW on %s. Current buffer %d: %v", bufSize, iface.Name, sz, err)
}
}
- if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, bufSize); err != nil {
- syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize)
- sz, _ := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, bufSize); err != nil {
+ unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUF, bufSize)
+ sz, _ := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUF)
if sz < bufSize {
log.Warningf("Failed to increase snd buffer to %d on SOCK_RAW on %s. Curent buffer %d: %v", bufSize, iface.Name, sz, err)
}
diff --git a/runsc/sandbox/network_unsafe.go b/runsc/sandbox/network_unsafe.go
index 2a2a0fb7e..1b808a8a0 100644
--- a/runsc/sandbox/network_unsafe.go
+++ b/runsc/sandbox/network_unsafe.go
@@ -15,7 +15,6 @@
package sandbox
import (
- "syscall"
"unsafe"
"golang.org/x/sys/unix"
@@ -48,7 +47,7 @@ func isGSOEnabled(fd int, intf string) (bool, error) {
ifrData: &val,
}
- if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), unix.SIOCETHTOOL, uintptr(unsafe.Pointer(&ifr))); err != 0 {
+ if _, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCETHTOOL, uintptr(unsafe.Pointer(&ifr))); err != 0 {
return false, err
}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index 7fe65c7ba..450f92645 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -30,6 +30,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/syndtr/gocapability/capability"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/cleanup"
"gvisor.dev/gvisor/pkg/control/client"
"gvisor.dev/gvisor/pkg/control/server"
@@ -83,7 +84,7 @@ type Sandbox struct {
// child==true and the sandbox was waited on. This field allows for multiple
// threads to wait on sandbox and get the exit code, since Linux will return
// WaitStatus to one of the waiters only.
- status syscall.WaitStatus
+ status unix.WaitStatus
}
// Args is used to configure a new sandbox.
@@ -383,7 +384,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
binPath := specutils.ExePath
cmd := exec.Command(binPath, conf.ToFlags()...)
- cmd.SysProcAttr = &syscall.SysProcAttr{}
+ cmd.SysProcAttr = &unix.SysProcAttr{}
// Open the log files to pass to the sandbox as FDs.
//
@@ -739,7 +740,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
if args.Attached {
// Kill sandbox if parent process exits in attached mode.
- cmd.SysProcAttr.Pdeathsig = syscall.SIGKILL
+ cmd.SysProcAttr.Pdeathsig = unix.SIGKILL
// Tells boot that any process it creates must have pdeathsig set.
cmd.Args = append(cmd.Args, "--attached")
}
@@ -762,7 +763,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
//
// NOTE: The error message is checked because error types are lost over
// rpc calls.
- if strings.Contains(err.Error(), syscall.EACCES.Error()) {
+ if strings.Contains(err.Error(), unix.EACCES.Error()) {
if permsErr := checkBinaryPermissions(conf); permsErr != nil {
return fmt.Errorf("%v: %v", err, permsErr)
}
@@ -782,7 +783,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
}
// Wait waits for the containerized process to exit, and returns its WaitStatus.
-func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) {
+func (s *Sandbox) Wait(cid string) (unix.WaitStatus, error) {
log.Debugf("Waiting for container %q in sandbox %q", cid, s.ID)
if conn, err := s.sandboxConnect(); err != nil {
@@ -790,14 +791,14 @@ func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) {
// There is nothing we can do for subcontainers. For the init container, we
// can try to get the sandbox exit code.
if !s.IsRootContainer(cid) {
- return syscall.WaitStatus(0), err
+ return unix.WaitStatus(0), err
}
log.Warningf("Wait on container %q failed: %v. Will try waiting on the sandbox process instead.", cid, err)
} else {
defer conn.Close()
// Try the Wait RPC to the sandbox.
- var ws syscall.WaitStatus
+ var ws unix.WaitStatus
err = conn.Call(boot.ContainerWait, &cid, &ws)
if err == nil {
// It worked!
@@ -805,7 +806,7 @@ func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) {
}
// See comment above.
if !s.IsRootContainer(cid) {
- return syscall.WaitStatus(0), err
+ return unix.WaitStatus(0), err
}
// The sandbox may have exited after we connected, but before
@@ -817,10 +818,10 @@ func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) {
// The best we can do is ask Linux what the sandbox exit status was, since in
// most cases that will be the same as the container exit status.
if err := s.waitForStopped(); err != nil {
- return syscall.WaitStatus(0), err
+ return unix.WaitStatus(0), err
}
if !s.child {
- return syscall.WaitStatus(0), fmt.Errorf("sandbox no longer running and its exit status is unavailable")
+ return unix.WaitStatus(0), fmt.Errorf("sandbox no longer running and its exit status is unavailable")
}
s.statusMu.Lock()
@@ -830,9 +831,9 @@ func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) {
// WaitPID waits for process 'pid' in the container's sandbox and returns its
// WaitStatus.
-func (s *Sandbox) WaitPID(cid string, pid int32) (syscall.WaitStatus, error) {
+func (s *Sandbox) WaitPID(cid string, pid int32) (unix.WaitStatus, error) {
log.Debugf("Waiting for PID %d in sandbox %q", pid, s.ID)
- var ws syscall.WaitStatus
+ var ws unix.WaitStatus
conn, err := s.sandboxConnect()
if err != nil {
return ws, err
@@ -861,7 +862,7 @@ func (s *Sandbox) destroy() error {
log.Debugf("Destroy sandbox %q", s.ID)
if s.Pid != 0 {
log.Debugf("Killing sandbox %q", s.ID)
- if err := syscall.Kill(s.Pid, syscall.SIGKILL); err != nil && err != syscall.ESRCH {
+ if err := unix.Kill(s.Pid, unix.SIGKILL); err != nil && err != unix.ESRCH {
return fmt.Errorf("killing sandbox %q PID %q: %v", s.ID, s.Pid, err)
}
if err := s.waitForStopped(); err != nil {
@@ -875,7 +876,7 @@ func (s *Sandbox) destroy() error {
// SignalContainer sends the signal to a container in the sandbox. If all is
// true and signal is SIGKILL, then waits for all processes to exit before
// returning.
-func (s *Sandbox) SignalContainer(cid string, sig syscall.Signal, all bool) error {
+func (s *Sandbox) SignalContainer(cid string, sig unix.Signal, all bool) error {
log.Debugf("Signal sandbox %q", s.ID)
conn, err := s.sandboxConnect()
if err != nil {
@@ -903,7 +904,7 @@ func (s *Sandbox) SignalContainer(cid string, sig syscall.Signal, all bool) erro
// fgProcess is true, then the signal is sent to the foreground process group
// in the same session that PID belongs to. This is only valid if the process
// is attached to a host TTY.
-func (s *Sandbox) SignalProcess(cid string, pid int32, sig syscall.Signal, fgProcess bool) error {
+func (s *Sandbox) SignalProcess(cid string, pid int32, sig unix.Signal, fgProcess bool) error {
log.Debugf("Signal sandbox %q", s.ID)
conn, err := s.sandboxConnect()
if err != nil {
@@ -984,7 +985,7 @@ func (s *Sandbox) Resume(cid string) error {
func (s *Sandbox) IsRunning() bool {
if s.Pid != 0 {
// Send a signal 0 to the sandbox process.
- if err := syscall.Kill(s.Pid, 0); err == nil {
+ if err := unix.Kill(s.Pid, 0); err == nil {
// Succeeded, process is running.
return true
}
@@ -1147,7 +1148,7 @@ func (s *Sandbox) waitForStopped() error {
}
// The sandbox process is a child of the current process,
// so we can wait it and collect its zombie.
- wpid, err := syscall.Wait4(int(s.Pid), &s.status, syscall.WNOHANG, nil)
+ wpid, err := unix.Wait4(int(s.Pid), &s.status, unix.WNOHANG, nil)
if err != nil {
return fmt.Errorf("error waiting the sandbox process: %v", err)
}
diff --git a/runsc/specutils/fs.go b/runsc/specutils/fs.go
index 138aa4dd1..b62504a8c 100644
--- a/runsc/specutils/fs.go
+++ b/runsc/specutils/fs.go
@@ -18,9 +18,9 @@ import (
"fmt"
"math/bits"
"path"
- "syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
)
type mapping struct {
@@ -31,48 +31,48 @@ type mapping struct {
// optionsMap maps mount propagation-related OCI filesystem options to mount(2)
// syscall flags.
var optionsMap = map[string]mapping{
- "acl": {set: true, val: syscall.MS_POSIXACL},
- "async": {set: false, val: syscall.MS_SYNCHRONOUS},
- "atime": {set: false, val: syscall.MS_NOATIME},
- "bind": {set: true, val: syscall.MS_BIND},
+ "acl": {set: true, val: unix.MS_POSIXACL},
+ "async": {set: false, val: unix.MS_SYNCHRONOUS},
+ "atime": {set: false, val: unix.MS_NOATIME},
+ "bind": {set: true, val: unix.MS_BIND},
"defaults": {set: true, val: 0},
- "dev": {set: false, val: syscall.MS_NODEV},
- "diratime": {set: false, val: syscall.MS_NODIRATIME},
- "dirsync": {set: true, val: syscall.MS_DIRSYNC},
- "exec": {set: false, val: syscall.MS_NOEXEC},
- "noexec": {set: true, val: syscall.MS_NOEXEC},
- "iversion": {set: true, val: syscall.MS_I_VERSION},
- "loud": {set: false, val: syscall.MS_SILENT},
- "mand": {set: true, val: syscall.MS_MANDLOCK},
- "noacl": {set: false, val: syscall.MS_POSIXACL},
- "noatime": {set: true, val: syscall.MS_NOATIME},
- "nodev": {set: true, val: syscall.MS_NODEV},
- "nodiratime": {set: true, val: syscall.MS_NODIRATIME},
- "noiversion": {set: false, val: syscall.MS_I_VERSION},
- "nomand": {set: false, val: syscall.MS_MANDLOCK},
- "norelatime": {set: false, val: syscall.MS_RELATIME},
- "nostrictatime": {set: false, val: syscall.MS_STRICTATIME},
- "nosuid": {set: true, val: syscall.MS_NOSUID},
- "rbind": {set: true, val: syscall.MS_BIND | syscall.MS_REC},
- "relatime": {set: true, val: syscall.MS_RELATIME},
- "remount": {set: true, val: syscall.MS_REMOUNT},
- "ro": {set: true, val: syscall.MS_RDONLY},
- "rw": {set: false, val: syscall.MS_RDONLY},
- "silent": {set: true, val: syscall.MS_SILENT},
- "strictatime": {set: true, val: syscall.MS_STRICTATIME},
- "suid": {set: false, val: syscall.MS_NOSUID},
- "sync": {set: true, val: syscall.MS_SYNCHRONOUS},
+ "dev": {set: false, val: unix.MS_NODEV},
+ "diratime": {set: false, val: unix.MS_NODIRATIME},
+ "dirsync": {set: true, val: unix.MS_DIRSYNC},
+ "exec": {set: false, val: unix.MS_NOEXEC},
+ "noexec": {set: true, val: unix.MS_NOEXEC},
+ "iversion": {set: true, val: unix.MS_I_VERSION},
+ "loud": {set: false, val: unix.MS_SILENT},
+ "mand": {set: true, val: unix.MS_MANDLOCK},
+ "noacl": {set: false, val: unix.MS_POSIXACL},
+ "noatime": {set: true, val: unix.MS_NOATIME},
+ "nodev": {set: true, val: unix.MS_NODEV},
+ "nodiratime": {set: true, val: unix.MS_NODIRATIME},
+ "noiversion": {set: false, val: unix.MS_I_VERSION},
+ "nomand": {set: false, val: unix.MS_MANDLOCK},
+ "norelatime": {set: false, val: unix.MS_RELATIME},
+ "nostrictatime": {set: false, val: unix.MS_STRICTATIME},
+ "nosuid": {set: true, val: unix.MS_NOSUID},
+ "rbind": {set: true, val: unix.MS_BIND | unix.MS_REC},
+ "relatime": {set: true, val: unix.MS_RELATIME},
+ "remount": {set: true, val: unix.MS_REMOUNT},
+ "ro": {set: true, val: unix.MS_RDONLY},
+ "rw": {set: false, val: unix.MS_RDONLY},
+ "silent": {set: true, val: unix.MS_SILENT},
+ "strictatime": {set: true, val: unix.MS_STRICTATIME},
+ "suid": {set: false, val: unix.MS_NOSUID},
+ "sync": {set: true, val: unix.MS_SYNCHRONOUS},
}
// propOptionsMap is similar to optionsMap, but it lists propagation options
// that cannot be used together with other flags.
var propOptionsMap = map[string]mapping{
- "private": {set: true, val: syscall.MS_PRIVATE},
- "rprivate": {set: true, val: syscall.MS_PRIVATE | syscall.MS_REC},
- "slave": {set: true, val: syscall.MS_SLAVE},
- "rslave": {set: true, val: syscall.MS_SLAVE | syscall.MS_REC},
- "unbindable": {set: true, val: syscall.MS_UNBINDABLE},
- "runbindable": {set: true, val: syscall.MS_UNBINDABLE | syscall.MS_REC},
+ "private": {set: true, val: unix.MS_PRIVATE},
+ "rprivate": {set: true, val: unix.MS_PRIVATE | unix.MS_REC},
+ "slave": {set: true, val: unix.MS_SLAVE},
+ "rslave": {set: true, val: unix.MS_SLAVE | unix.MS_REC},
+ "unbindable": {set: true, val: unix.MS_UNBINDABLE},
+ "runbindable": {set: true, val: unix.MS_UNBINDABLE | unix.MS_REC},
}
// invalidOptions list options not allowed.
@@ -139,7 +139,7 @@ func ValidateMountOptions(opts []string) error {
// correct.
func validateRootfsPropagation(opt string) error {
flags := PropOptionsToFlags([]string{opt})
- if flags&(syscall.MS_SLAVE|syscall.MS_PRIVATE) == 0 {
+ if flags&(unix.MS_SLAVE|unix.MS_PRIVATE) == 0 {
return fmt.Errorf("root mount propagation option must specify private or slave: %q", opt)
}
return validatePropagation(opt)
@@ -147,7 +147,7 @@ func validateRootfsPropagation(opt string) error {
func validatePropagation(opt string) error {
flags := PropOptionsToFlags([]string{opt})
- exclusive := flags & (syscall.MS_SLAVE | syscall.MS_PRIVATE | syscall.MS_SHARED | syscall.MS_UNBINDABLE)
+ exclusive := flags & (unix.MS_SLAVE | unix.MS_PRIVATE | unix.MS_SHARED | unix.MS_UNBINDABLE)
if bits.OnesCount32(exclusive) > 1 {
return fmt.Errorf("mount propagation options are mutually exclusive: %q", opt)
}
diff --git a/runsc/specutils/namespace.go b/runsc/specutils/namespace.go
index 23001d67c..69d7ba5c4 100644
--- a/runsc/specutils/namespace.go
+++ b/runsc/specutils/namespace.go
@@ -109,7 +109,7 @@ func FilterNS(filter []specs.LinuxNamespaceType, s *specs.Spec) []specs.LinuxNam
// setNS sets the namespace of the given type. It must be called with
// OSThreadLocked.
func setNS(fd, nsType uintptr) error {
- if _, _, err := syscall.RawSyscall(unix.SYS_SETNS, fd, nsType, 0); err != 0 {
+ if _, _, err := unix.RawSyscall(unix.SYS_SETNS, fd, nsType, 0); err != 0 {
return err
}
return nil
@@ -158,7 +158,7 @@ func StartInNS(cmd *exec.Cmd, nss []specs.LinuxNamespace) error {
defer runtime.UnlockOSThread()
if cmd.SysProcAttr == nil {
- cmd.SysProcAttr = &syscall.SysProcAttr{}
+ cmd.SysProcAttr = &unix.SysProcAttr{}
}
for _, ns := range nss {
@@ -185,7 +185,7 @@ func SetUIDGIDMappings(cmd *exec.Cmd, s *specs.Spec) {
return
}
if cmd.SysProcAttr == nil {
- cmd.SysProcAttr = &syscall.SysProcAttr{}
+ cmd.SysProcAttr = &unix.SysProcAttr{}
}
for _, idMap := range s.Linux.UIDMappings {
log.Infof("Mapping host uid %d to container uid %d (size=%d)", idMap.HostID, idMap.ContainerID, idMap.Size)
@@ -241,8 +241,8 @@ func MaybeRunAsRoot() error {
cmd := exec.Command("/proc/self/exe", os.Args[1:]...)
- cmd.SysProcAttr = &syscall.SysProcAttr{
- Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS,
+ cmd.SysProcAttr = &unix.SysProcAttr{
+ Cloneflags: unix.CLONE_NEWUSER | unix.CLONE_NEWNS,
// Set current user/group as root inside the namespace. Since we may not
// have CAP_SETUID/CAP_SETGID, just map root to the current user/group.
UidMappings: []syscall.SysProcIDMap{
@@ -255,7 +255,7 @@ func MaybeRunAsRoot() error {
GidMappingsEnableSetgroups: false,
// Make sure child is killed when the parent terminates.
- Pdeathsig: syscall.SIGKILL,
+ Pdeathsig: unix.SIGKILL,
}
cmd.Env = os.Environ()
diff --git a/runsc/specutils/seccomp/BUILD b/runsc/specutils/seccomp/BUILD
index 3520f2d6d..e9e647d82 100644
--- a/runsc/specutils/seccomp/BUILD
+++ b/runsc/specutils/seccomp/BUILD
@@ -18,6 +18,7 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/syscalls/linux",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
@@ -30,5 +31,6 @@ go_test(
"//pkg/binary",
"//pkg/bpf",
"@com_github_opencontainers_runtime_spec//specs-go:go_default_library",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/runsc/specutils/seccomp/seccomp.go b/runsc/specutils/seccomp/seccomp.go
index 5932f7a41..0ef7a4d54 100644
--- a/runsc/specutils/seccomp/seccomp.go
+++ b/runsc/specutils/seccomp/seccomp.go
@@ -18,9 +18,9 @@ package seccomp
import (
"fmt"
- "syscall"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bpf"
"gvisor.dev/gvisor/pkg/log"
@@ -33,9 +33,9 @@ var (
killThreadAction = linux.SECCOMP_RET_KILL_THREAD
trapAction = linux.SECCOMP_RET_TRAP
// runc always returns EPERM as the errorcode for SECCOMP_RET_ERRNO
- errnoAction = linux.SECCOMP_RET_ERRNO.WithReturnCode(uint16(syscall.EPERM))
+ errnoAction = linux.SECCOMP_RET_ERRNO.WithReturnCode(uint16(unix.EPERM))
// runc always returns EPERM as the errorcode for SECCOMP_RET_TRACE
- traceAction = linux.SECCOMP_RET_TRACE.WithReturnCode(uint16(syscall.EPERM))
+ traceAction = linux.SECCOMP_RET_TRACE.WithReturnCode(uint16(unix.EPERM))
allowAction = linux.SECCOMP_RET_ALLOW
)
diff --git a/runsc/specutils/seccomp/seccomp_test.go b/runsc/specutils/seccomp/seccomp_test.go
index 850c237ba..11a6c8daa 100644
--- a/runsc/specutils/seccomp/seccomp_test.go
+++ b/runsc/specutils/seccomp/seccomp_test.go
@@ -16,10 +16,10 @@ package seccomp
import (
"fmt"
- "syscall"
"testing"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/bpf"
)
@@ -184,7 +184,7 @@ var (
Args: []specs.LinuxSeccompArg{
{
Index: 0,
- Value: syscall.CLONE_FS,
+ Value: unix.CLONE_FS,
Op: specs.OpEqualTo,
},
},
@@ -192,7 +192,7 @@ var (
},
},
},
- input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}),
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{unix.CLONE_FS}),
expected: uint32(errnoAction),
},
{
@@ -207,12 +207,12 @@ var (
Args: []specs.LinuxSeccompArg{
{
Index: 0,
- Value: syscall.CLONE_FS,
+ Value: unix.CLONE_FS,
Op: specs.OpEqualTo,
},
{
Index: 0,
- Value: syscall.CLONE_VM,
+ Value: unix.CLONE_VM,
Op: specs.OpEqualTo,
},
},
@@ -220,7 +220,7 @@ var (
},
},
},
- input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}),
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{unix.CLONE_FS}),
expected: uint32(errnoAction),
},
{
@@ -235,12 +235,12 @@ var (
Args: []specs.LinuxSeccompArg{
{
Index: 1,
- Value: syscall.SOL_SOCKET,
+ Value: unix.SOL_SOCKET,
Op: specs.OpEqualTo,
},
{
Index: 2,
- Value: syscall.SO_PEERCRED,
+ Value: unix.SO_PEERCRED,
Op: specs.OpEqualTo,
},
},
@@ -248,7 +248,7 @@ var (
},
},
},
- input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, syscall.SOL_SOCKET, syscall.SO_PEERCRED}),
+ input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, unix.SOL_SOCKET, unix.SO_PEERCRED}),
expected: uint32(errnoAction),
},
{
@@ -263,12 +263,12 @@ var (
Args: []specs.LinuxSeccompArg{
{
Index: 1,
- Value: syscall.SOL_SOCKET,
+ Value: unix.SOL_SOCKET,
Op: specs.OpEqualTo,
},
{
Index: 2,
- Value: syscall.SO_PEERCRED,
+ Value: unix.SO_PEERCRED,
Op: specs.OpEqualTo,
},
},
@@ -276,7 +276,7 @@ var (
},
},
},
- input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, syscall.SOL_SOCKET}),
+ input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, unix.SOL_SOCKET}),
expected: uint32(allowAction),
},
{
@@ -291,7 +291,7 @@ var (
Args: []specs.LinuxSeccompArg{
{
Index: 0,
- Value: syscall.CLONE_FS,
+ Value: unix.CLONE_FS,
Op: specs.OpEqualTo,
},
},
@@ -299,7 +299,7 @@ var (
},
},
},
- input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_VM}),
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{unix.CLONE_VM}),
expected: uint32(allowAction),
},
{
@@ -314,8 +314,8 @@ var (
Args: []specs.LinuxSeccompArg{
{
Index: 0,
- Value: syscall.CLONE_FS,
- ValueTwo: syscall.CLONE_FS,
+ Value: unix.CLONE_FS,
+ ValueTwo: unix.CLONE_FS,
Op: specs.OpMaskedEqual,
},
},
@@ -323,7 +323,7 @@ var (
},
},
},
- input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS | syscall.CLONE_VM}),
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{unix.CLONE_FS | unix.CLONE_VM}),
expected: uint32(errnoAction),
},
{
@@ -338,8 +338,8 @@ var (
Args: []specs.LinuxSeccompArg{
{
Index: 0,
- Value: syscall.CLONE_FS | syscall.CLONE_VM,
- ValueTwo: syscall.CLONE_FS | syscall.CLONE_VM,
+ Value: unix.CLONE_FS | unix.CLONE_VM,
+ ValueTwo: unix.CLONE_FS | unix.CLONE_VM,
Op: specs.OpMaskedEqual,
},
},
@@ -347,7 +347,7 @@ var (
},
},
},
- input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}),
+ input: testInput(nativeArchAuditNo, "clone", &[6]uint64{unix.CLONE_FS}),
expected: uint32(allowAction),
},
{
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index ea55bbc7d..5ba38bfe4 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -26,12 +26,12 @@ import (
"path/filepath"
"strconv"
"strings"
- "syscall"
"time"
"github.com/cenkalti/backoff"
"github.com/mohae/deepcopy"
specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
@@ -375,9 +375,9 @@ func WaitForReady(pid int, timeout time.Duration, ready func() (bool, error)) er
// Check if the process is still running.
// If the process is alive, child is 0 because of the NOHANG option.
// If the process has terminated, child equals the process id.
- var ws syscall.WaitStatus
- var ru syscall.Rusage
- child, err := syscall.Wait4(pid, &ws, syscall.WNOHANG, &ru)
+ var ws unix.WaitStatus
+ var ru unix.Rusage
+ child, err := unix.Wait4(pid, &ws, unix.WNOHANG, &ru)
if err != nil {
return backoff.Permanent(fmt.Errorf("error waiting for process: %v", err))
} else if child == pid {
@@ -437,7 +437,7 @@ func Mount(src, dst, typ string, flags uint32) error {
return fmt.Errorf("mkdir(%q) failed: %v", parent, err)
}
// Create the destination file if it does not exist.
- f, err := os.OpenFile(dst, syscall.O_CREAT, 0777)
+ f, err := os.OpenFile(dst, unix.O_CREAT, 0777)
if err != nil {
return fmt.Errorf("open(%q) failed: %v", dst, err)
}
@@ -445,7 +445,7 @@ func Mount(src, dst, typ string, flags uint32) error {
}
// Do the mount.
- if err := syscall.Mount(src, dst, typ, uintptr(flags), ""); err != nil {
+ if err := unix.Mount(src, dst, typ, uintptr(flags), ""); err != nil {
return fmt.Errorf("mount(%q, %q, %d) failed: %v", src, dst, flags, err)
}
return nil
@@ -466,7 +466,7 @@ func ContainsStr(strs []string, str string) bool {
func RetryEintr(f func() (uintptr, uintptr, error)) (uintptr, uintptr, error) {
for {
r1, r2, err := f()
- if err != syscall.EINTR {
+ if err != unix.EINTR {
return r1, r2, err
}
}
diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD
index b4f967441..c94caab60 100644
--- a/test/benchmarks/fs/BUILD
+++ b/test/benchmarks/fs/BUILD
@@ -11,6 +11,7 @@ benchmark_test(
"//pkg/test/dockerutil",
"//test/benchmarks/harness",
"//test/benchmarks/tools",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
],
)
diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go
index 8baeff0db..7ced963f6 100644
--- a/test/benchmarks/fs/bazel_test.go
+++ b/test/benchmarks/fs/bazel_test.go
@@ -25,6 +25,13 @@ import (
"gvisor.dev/gvisor/test/benchmarks/tools"
)
+// Dimensions here are clean/dirty cache (do or don't drop caches)
+// and if the mount on which we are compiling is a tmpfs/bind mount.
+type benchmark struct {
+ clearCache bool // clearCache drops caches before running.
+ fstype string // type of filesystem to use.
+}
+
// Note: CleanCache versions of this test require running with root permissions.
func BenchmarkBuildABSL(b *testing.B) {
runBuildBenchmark(b, "benchmarks/absl", "/abseil-cpp", "absl/base/...")
@@ -45,17 +52,18 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
}
defer machine.CleanUp()
- // Dimensions here are clean/dirty cache (do or don't drop caches)
- // and if the mount on which we are compiling is a tmpfs/bind mount.
- benchmarks := []struct {
- clearCache bool // clearCache drops caches before running.
- tmpfs bool // tmpfs will run compilation on a tmpfs.
- }{
- {clearCache: true, tmpfs: false},
- {clearCache: false, tmpfs: false},
- {clearCache: true, tmpfs: true},
- {clearCache: false, tmpfs: true},
+ benchmarks := make([]benchmark, 0, 6)
+ for _, filesys := range []string{harness.BindFS, harness.TmpFS, harness.RootFS} {
+ benchmarks = append(benchmarks, benchmark{
+ clearCache: true,
+ fstype: filesys,
+ })
+ benchmarks = append(benchmarks, benchmark{
+ clearCache: false,
+ fstype: filesys,
+ })
}
+
for _, bm := range benchmarks {
pageCache := tools.Parameter{
Name: "page_cache",
@@ -67,10 +75,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
filesystem := tools.Parameter{
Name: "filesystem",
- Value: "bind",
- }
- if bm.tmpfs {
- filesystem.Value = "tmpfs"
+ Value: bm.fstype,
}
name, err := tools.ParametersToName(pageCache, filesystem)
if err != nil {
@@ -83,21 +88,25 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
container := machine.GetContainer(ctx, b)
defer container.CleanUp(ctx)
+ mts, prefix, cleanup, err := harness.MakeMount(machine, bm.fstype)
+ if err != nil {
+ b.Fatalf("Failed to make mount: %v", err)
+ }
+ defer cleanup()
+
+ runOpts := dockerutil.RunOpts{
+ Image: image,
+ Mounts: mts,
+ }
+
// Start a container and sleep.
- if err := container.Spawn(ctx, dockerutil.RunOpts{
- Image: image,
- }, "sleep", fmt.Sprintf("%d", 1000000)); err != nil {
+ if err := container.Spawn(ctx, runOpts, "sleep", fmt.Sprintf("%d", 1000000)); err != nil {
b.Fatalf("run failed with: %v", err)
}
- // If we are running on a tmpfs, copy to /tmp which is a tmpfs.
- prefix := ""
- if bm.tmpfs {
- if out, err := container.Exec(ctx, dockerutil.ExecOpts{},
- "cp", "-r", workdir, "/tmp/."); err != nil {
- b.Fatalf("failed to copy directory: %v (%s)", err, out)
- }
- prefix = "/tmp"
+ if out, err := container.Exec(ctx, dockerutil.ExecOpts{},
+ "cp", "-rf", workdir, prefix+"/."); err != nil {
+ b.Fatalf("failed to copy directory: %v (%s)", err, out)
}
b.ResetTimer()
@@ -118,7 +127,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) {
WorkDir: prefix + workdir,
}, "bazel", "build", "-c", "opt", target)
if err != nil {
- b.Fatalf("build failed with: %v", err)
+ b.Fatalf("build failed with: %v logs: %s", err, got)
}
b.StopTimer()
diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go
index cc2d1cbbc..f783a2b33 100644
--- a/test/benchmarks/fs/fio_test.go
+++ b/test/benchmarks/fs/fio_test.go
@@ -21,7 +21,6 @@ import (
"strings"
"testing"
- "github.com/docker/docker/api/types/mount"
"gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/test/benchmarks/harness"
"gvisor.dev/gvisor/test/benchmarks/tools"
@@ -70,7 +69,7 @@ func BenchmarkFio(b *testing.B) {
}
defer machine.CleanUp()
- for _, fsType := range []mount.Type{mount.TypeBind, mount.TypeTmpfs} {
+ for _, fsType := range []string{harness.BindFS, harness.TmpFS, harness.RootFS} {
for _, tc := range testCases {
operation := tools.Parameter{
Name: "operation",
@@ -82,7 +81,7 @@ func BenchmarkFio(b *testing.B) {
}
filesystem := tools.Parameter{
Name: "filesystem",
- Value: string(fsType),
+ Value: fsType,
}
name, err := tools.ParametersToName(operation, blockSize, filesystem)
if err != nil {
@@ -95,13 +94,7 @@ func BenchmarkFio(b *testing.B) {
container := machine.GetContainer(ctx, b)
defer container.CleanUp(ctx)
- // Directory and filename inside container where fio will read/write.
- outdir := "/data"
- outfile := filepath.Join(outdir, "test.txt")
-
- // Make the required mount and grab a cleanup for bind mounts
- // as they are backed by a temp directory (mktemp).
- mnt, mountCleanup, err := makeMount(machine, fsType, outdir)
+ mnts, outdir, mountCleanup, err := harness.MakeMount(machine, fsType)
if err != nil {
b.Fatalf("failed to make mount: %v", err)
}
@@ -109,12 +102,9 @@ func BenchmarkFio(b *testing.B) {
// Start the container with the mount.
if err := container.Spawn(
- ctx,
- dockerutil.RunOpts{
- Image: "benchmarks/fio",
- Mounts: []mount.Mount{
- mnt,
- },
+ ctx, dockerutil.RunOpts{
+ Image: "benchmarks/fio",
+ Mounts: mnts,
},
// Sleep on the order of b.N.
"sleep", fmt.Sprintf("%d", 1000*b.N),
@@ -122,6 +112,9 @@ func BenchmarkFio(b *testing.B) {
b.Fatalf("failed to start fio container with: %v", err)
}
+ // Directory and filename inside container where fio will read/write.
+ outfile := filepath.Join(outdir, "test.txt")
+
// For reads, we need a file to read so make one inside the container.
if strings.Contains(tc.Test, "read") {
fallocateCmd := fmt.Sprintf("fallocate -l %dK %s", tc.Size, outfile)
@@ -135,6 +128,7 @@ func BenchmarkFio(b *testing.B) {
if err := harness.DropCaches(machine); err != nil {
b.Skipf("failed to drop caches with %v. You probably need root.", err)
}
+
cmd := tc.MakeCmd(outfile)
if err := harness.DropCaches(machine); err != nil {
@@ -154,39 +148,6 @@ func BenchmarkFio(b *testing.B) {
}
}
-// makeMount makes a mount and cleanup based on the requested type. Bind
-// and volume mounts are backed by a temp directory made with mktemp.
-// tmpfs mounts require no such backing and are just made.
-// It is up to the caller to call the returned cleanup.
-func makeMount(machine harness.Machine, mountType mount.Type, target string) (mount.Mount, func(), error) {
- switch mountType {
- case mount.TypeVolume, mount.TypeBind:
- dir, err := machine.RunCommand("mktemp", "-d")
- if err != nil {
- return mount.Mount{}, func() {}, fmt.Errorf("failed to create tempdir: %v", err)
- }
- dir = strings.TrimSuffix(dir, "\n")
-
- out, err := machine.RunCommand("chmod", "777", dir)
- if err != nil {
- machine.RunCommand("rm", "-rf", dir)
- return mount.Mount{}, func() {}, fmt.Errorf("failed modify directory: %v %s", err, out)
- }
- return mount.Mount{
- Target: target,
- Source: dir,
- Type: mount.TypeBind,
- }, func() { machine.RunCommand("rm", "-rf", dir) }, nil
- case mount.TypeTmpfs:
- return mount.Mount{
- Target: target,
- Type: mount.TypeTmpfs,
- }, func() {}, nil
- default:
- return mount.Mount{}, func() {}, fmt.Errorf("illegal mount time not supported: %v", mountType)
- }
-}
-
// TestMain is the main method for package fs.
func TestMain(m *testing.M) {
harness.Init()
diff --git a/test/benchmarks/harness/BUILD b/test/benchmarks/harness/BUILD
index c2e316709..116610938 100644
--- a/test/benchmarks/harness/BUILD
+++ b/test/benchmarks/harness/BUILD
@@ -14,5 +14,6 @@ go_library(
deps = [
"//pkg/test/dockerutil",
"//pkg/test/testutil",
+ "@com_github_docker_docker//api/types/mount:go_default_library",
],
)
diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go
index aeac7ebff..36abe1069 100644
--- a/test/benchmarks/harness/util.go
+++ b/test/benchmarks/harness/util.go
@@ -18,8 +18,10 @@ import (
"context"
"fmt"
"net"
+ "strings"
"testing"
+ "github.com/docker/docker/api/types/mount"
"gvisor.dev/gvisor/pkg/test/dockerutil"
"gvisor.dev/gvisor/pkg/test/testutil"
)
@@ -55,3 +57,53 @@ func DebugLog(b *testing.B, msg string, args ...interface{}) {
b.Logf(msg, args...)
}
}
+
+const (
+ // BindFS indicates a bind mount should be created.
+ BindFS = "bindfs"
+ // TmpFS indicates a tmpfs mount should be created.
+ TmpFS = "tmpfs"
+ // RootFS indicates no mount should be created and the root mount should be used.
+ RootFS = "rootfs"
+)
+
+// MakeMount makes a mount and cleanup based on the requested type. Bind
+// and volume mounts are backed by a temp directory made with mktemp.
+// tmpfs mounts require no such backing and are just made.
+// rootfs mounts do not make a mount, but instead return a target direectory at root.
+// It is up to the caller to call the returned cleanup.
+func MakeMount(machine Machine, fsType string) ([]mount.Mount, string, func(), error) {
+ mounts := make([]mount.Mount, 0, 1)
+ switch fsType {
+ case BindFS:
+ dir, err := machine.RunCommand("mktemp", "-d")
+ if err != nil {
+ return mounts, "", func() {}, fmt.Errorf("failed to create tempdir: %v", err)
+ }
+ dir = strings.TrimSuffix(dir, "\n")
+
+ out, err := machine.RunCommand("chmod", "777", dir)
+ if err != nil {
+ machine.RunCommand("rm", "-rf", dir)
+ return mounts, "", func() {}, fmt.Errorf("failed modify directory: %v %s", err, out)
+ }
+ target := "/data"
+ mounts = append(mounts, mount.Mount{
+ Target: target,
+ Source: dir,
+ Type: mount.TypeBind,
+ })
+ return mounts, target, func() { machine.RunCommand("rm", "-rf", dir) }, nil
+ case RootFS:
+ return mounts, "/", func() {}, nil
+ case TmpFS:
+ target := "/data"
+ mounts = append(mounts, mount.Mount{
+ Target: target,
+ Type: mount.TypeTmpfs,
+ })
+ return mounts, target, func() {}, nil
+ default:
+ return mounts, "", func() {}, fmt.Errorf("illegal mount type not supported: %v", fsType)
+ }
+}
diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go
index 780e4f7ae..be74e4d4a 100644
--- a/test/benchmarks/tcp/tcp_proxy.go
+++ b/test/benchmarks/tcp/tcp_proxy.go
@@ -29,7 +29,6 @@ import (
"runtime"
"runtime/pprof"
"strconv"
- "syscall"
"time"
"golang.org/x/sys/unix"
@@ -112,33 +111,33 @@ func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) {
const protocol = 0x0300 // htons(ETH_P_ALL)
fds := make([]int, numChannels)
for i := range fds {
- fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol)
+ fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, protocol)
if err != nil {
return nil, fmt.Errorf("unable to create raw socket: %v", err)
}
// Bind to the appropriate device.
- ll := syscall.SockaddrLinklayer{
+ ll := unix.SockaddrLinklayer{
Protocol: protocol,
Ifindex: iface.Index,
- Pkttype: syscall.PACKET_HOST,
+ Pkttype: unix.PACKET_HOST,
}
- if err := syscall.Bind(fd, &ll); err != nil {
+ if err := unix.Bind(fd, &ll); err != nil {
return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
}
// RAW Sockets by default have a very small SO_RCVBUF of 256KB,
// up it to at least 4MB to reduce packet drops.
- if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil {
+ if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, bufSize); err != nil {
return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err)
}
- if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil {
+ if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUF, bufSize); err != nil {
return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err)
}
if !*swgso && *gso != 0 {
- if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
+ if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err)
}
}
@@ -403,7 +402,7 @@ func main() {
log.Printf("client=%v, server=%v, ready.", *client, *server)
sigs := make(chan os.Signal, 1)
- signal.Notify(sigs, syscall.SIGTERM)
+ signal.Notify(sigs, unix.SIGTERM)
go func() {
<-sigs
if *cpuprofile != "" {
diff --git a/test/fuse/linux/mount_test.cc b/test/fuse/linux/mount_test.cc
index 8a5478116..276f842ea 100644
--- a/test/fuse/linux/mount_test.cc
+++ b/test/fuse/linux/mount_test.cc
@@ -15,6 +15,7 @@
#include <errno.h>
#include <fcntl.h>
#include <sys/mount.h>
+#include <unistd.h>
#include "gtest/gtest.h"
#include "test/util/mount_util.h"
@@ -29,7 +30,9 @@ namespace {
TEST(FuseMount, Success) {
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY));
- std::string mopts = absl::StrCat("fd=", std::to_string(fd.get()));
+ std::string mopts =
+ absl::StrFormat("fd=%d,user_id=%d,group_id=%d,rootmode=0777", fd.get(),
+ getuid(), getgid());
const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
diff --git a/test/iptables/BUILD b/test/iptables/BUILD
index ae4bba847..94d4ca2d4 100644
--- a/test/iptables/BUILD
+++ b/test/iptables/BUILD
@@ -18,6 +18,7 @@ go_library(
"//pkg/binary",
"//pkg/test/testutil",
"//pkg/usermem",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/test/iptables/iptables_unsafe.go b/test/iptables/iptables_unsafe.go
index bd85a8fea..dd1a1c082 100644
--- a/test/iptables/iptables_unsafe.go
+++ b/test/iptables/iptables_unsafe.go
@@ -16,12 +16,13 @@ package iptables
import (
"fmt"
- "syscall"
"unsafe"
+
+ "golang.org/x/sys/unix"
)
type originalDstError struct {
- errno syscall.Errno
+ errno unix.Errno
}
func (e originalDstError) Error() string {
@@ -32,27 +33,27 @@ func (e originalDstError) Error() string {
// getsockopt.
const SO_ORIGINAL_DST = 80
-func originalDestination4(connfd int) (syscall.RawSockaddrInet4, error) {
- var addr syscall.RawSockaddrInet4
- var addrLen uint32 = syscall.SizeofSockaddrInet4
- if errno := originalDestination(connfd, syscall.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 {
- return syscall.RawSockaddrInet4{}, originalDstError{errno}
+func originalDestination4(connfd int) (unix.RawSockaddrInet4, error) {
+ var addr unix.RawSockaddrInet4
+ var addrLen uint32 = unix.SizeofSockaddrInet4
+ if errno := originalDestination(connfd, unix.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 {
+ return unix.RawSockaddrInet4{}, originalDstError{errno}
}
return addr, nil
}
-func originalDestination6(connfd int) (syscall.RawSockaddrInet6, error) {
- var addr syscall.RawSockaddrInet6
- var addrLen uint32 = syscall.SizeofSockaddrInet6
- if errno := originalDestination(connfd, syscall.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 {
- return syscall.RawSockaddrInet6{}, originalDstError{errno}
+func originalDestination6(connfd int) (unix.RawSockaddrInet6, error) {
+ var addr unix.RawSockaddrInet6
+ var addrLen uint32 = unix.SizeofSockaddrInet6
+ if errno := originalDestination(connfd, unix.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 {
+ return unix.RawSockaddrInet6{}, originalDstError{errno}
}
return addr, nil
}
-func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) syscall.Errno {
- _, _, errno := syscall.Syscall6(
- syscall.SYS_GETSOCKOPT,
+func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) unix.Errno {
+ _, _, errno := unix.Syscall6(
+ unix.SYS_GETSOCKOPT,
uintptr(connfd),
level,
SO_ORIGINAL_DST,
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index 7f1d6d7ad..70d8a1832 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -19,8 +19,8 @@ import (
"errors"
"fmt"
"net"
- "syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -584,33 +584,33 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.
// traditional syscalls.
// Create the listening socket, bind, listen, and accept.
- family := syscall.AF_INET
+ family := unix.AF_INET
if ipv6 {
- family = syscall.AF_INET6
+ family = unix.AF_INET6
}
- sockfd, err := syscall.Socket(family, syscall.SOCK_STREAM, 0)
+ sockfd, err := unix.Socket(family, unix.SOCK_STREAM, 0)
if err != nil {
return err
}
- defer syscall.Close(sockfd)
+ defer unix.Close(sockfd)
- var bindAddr syscall.Sockaddr
+ var bindAddr unix.Sockaddr
if ipv6 {
- bindAddr = &syscall.SockaddrInet6{
+ bindAddr = &unix.SockaddrInet6{
Port: acceptPort,
Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any
}
} else {
- bindAddr = &syscall.SockaddrInet4{
+ bindAddr = &unix.SockaddrInet4{
Port: acceptPort,
Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY
}
}
- if err := syscall.Bind(sockfd, bindAddr); err != nil {
+ if err := unix.Bind(sockfd, bindAddr); err != nil {
return err
}
- if err := syscall.Listen(sockfd, 1); err != nil {
+ if err := unix.Listen(sockfd, 1); err != nil {
return err
}
@@ -619,8 +619,8 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.
errCh := make(chan error)
go func() {
for {
- connFD, _, err := syscall.Accept(sockfd)
- if errors.Is(err, syscall.EINTR) {
+ connFD, _, err := unix.Accept(sockfd)
+ if errors.Is(err, unix.EINTR) {
continue
}
if err != nil {
@@ -641,7 +641,7 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.
return err
case connFD = <-connCh:
}
- defer syscall.Close(connFD)
+ defer unix.Close(connFD)
// Verify that, despite listening on acceptPort, SO_ORIGINAL_DST
// indicates the packet was sent to originalDst:dropPort.
@@ -764,35 +764,35 @@ func recvWithRECVORIGDSTADDR(ctx context.Context, ipv6 bool, expectedDst *net.IP
// Create the listening socket.
var (
- family = syscall.AF_INET
- level = syscall.SOL_IP
- option = syscall.IP_RECVORIGDSTADDR
- bindAddr syscall.Sockaddr = &syscall.SockaddrInet4{
+ family = unix.AF_INET
+ level = unix.SOL_IP
+ option = unix.IP_RECVORIGDSTADDR
+ bindAddr unix.Sockaddr = &unix.SockaddrInet4{
Port: int(port),
Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY
}
)
if ipv6 {
- family = syscall.AF_INET6
- level = syscall.SOL_IPV6
+ family = unix.AF_INET6
+ level = unix.SOL_IPV6
option = 74 // IPV6_RECVORIGDSTADDR, which is missing from the syscall package.
- bindAddr = &syscall.SockaddrInet6{
+ bindAddr = &unix.SockaddrInet6{
Port: int(port),
Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any
}
}
- sockfd, err := syscall.Socket(family, syscall.SOCK_DGRAM, 0)
+ sockfd, err := unix.Socket(family, unix.SOCK_DGRAM, 0)
if err != nil {
- return fmt.Errorf("failed Socket(%d, %d, 0): %w", family, syscall.SOCK_DGRAM, err)
+ return fmt.Errorf("failed Socket(%d, %d, 0): %w", family, unix.SOCK_DGRAM, err)
}
- defer syscall.Close(sockfd)
+ defer unix.Close(sockfd)
- if err := syscall.Bind(sockfd, bindAddr); err != nil {
+ if err := unix.Bind(sockfd, bindAddr); err != nil {
return fmt.Errorf("failed Bind(%d, %+v): %v", sockfd, bindAddr, err)
}
// Enable IP_RECVORIGDSTADDR.
- if err := syscall.SetsockoptInt(sockfd, level, option, 1); err != nil {
+ if err := unix.SetsockoptInt(sockfd, level, option, 1); err != nil {
return fmt.Errorf("failed SetsockoptByte(%d, %d, %d, 1): %v", sockfd, level, option, err)
}
@@ -837,41 +837,41 @@ func recvWithRECVORIGDSTADDR(ctx context.Context, ipv6 bool, expectedDst *net.IP
// Verify that the address has the post-NAT port and address.
if ipv6 {
- return addrMatches6(addr.(syscall.RawSockaddrInet6), localAddrs, redirectPort)
+ return addrMatches6(addr.(unix.RawSockaddrInet6), localAddrs, redirectPort)
}
- return addrMatches4(addr.(syscall.RawSockaddrInet4), localAddrs, redirectPort)
+ return addrMatches4(addr.(unix.RawSockaddrInet4), localAddrs, redirectPort)
}
-func recvOrigDstAddr4(sockfd int) (syscall.RawSockaddrInet4, error) {
- buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet4)
+func recvOrigDstAddr4(sockfd int) (unix.RawSockaddrInet4, error) {
+ buf, err := recvOrigDstAddr(sockfd, unix.SOL_IP, unix.SizeofSockaddrInet4)
if err != nil {
- return syscall.RawSockaddrInet4{}, err
+ return unix.RawSockaddrInet4{}, err
}
- var addr syscall.RawSockaddrInet4
+ var addr unix.RawSockaddrInet4
binary.Unmarshal(buf, usermem.ByteOrder, &addr)
return addr, nil
}
-func recvOrigDstAddr6(sockfd int) (syscall.RawSockaddrInet6, error) {
- buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet6)
+func recvOrigDstAddr6(sockfd int) (unix.RawSockaddrInet6, error) {
+ buf, err := recvOrigDstAddr(sockfd, unix.SOL_IP, unix.SizeofSockaddrInet6)
if err != nil {
- return syscall.RawSockaddrInet6{}, err
+ return unix.RawSockaddrInet6{}, err
}
- var addr syscall.RawSockaddrInet6
+ var addr unix.RawSockaddrInet6
binary.Unmarshal(buf, usermem.ByteOrder, &addr)
return addr, nil
}
func recvOrigDstAddr(sockfd int, level uintptr, addrSize int) ([]byte, error) {
buf := make([]byte, 64)
- oob := make([]byte, syscall.CmsgSpace(addrSize))
+ oob := make([]byte, unix.CmsgSpace(addrSize))
for {
- _, oobn, _, _, err := syscall.Recvmsg(
+ _, oobn, _, _, err := unix.Recvmsg(
sockfd,
buf, // Message buffer.
oob, // Out-of-band buffer.
0) // Flags.
- if errors.Is(err, syscall.EINTR) {
+ if errors.Is(err, unix.EINTR) {
continue
}
if err != nil {
@@ -880,7 +880,7 @@ func recvOrigDstAddr(sockfd int, level uintptr, addrSize int) ([]byte, error) {
oob = oob[:oobn]
// Parse out the control message.
- msgs, err := syscall.ParseSocketControlMessage(oob)
+ msgs, err := unix.ParseSocketControlMessage(oob)
if err != nil {
return nil, fmt.Errorf("failed to parse control message: %w", err)
}
@@ -888,10 +888,10 @@ func recvOrigDstAddr(sockfd int, level uintptr, addrSize int) ([]byte, error) {
}
}
-func addrMatches4(got syscall.RawSockaddrInet4, wantAddrs []net.IP, port uint16) error {
+func addrMatches4(got unix.RawSockaddrInet4, wantAddrs []net.IP, port uint16) error {
for _, wantAddr := range wantAddrs {
- want := syscall.RawSockaddrInet4{
- Family: syscall.AF_INET,
+ want := unix.RawSockaddrInet4{
+ Family: unix.AF_INET,
Port: htons(port),
}
copy(want.Addr[:], wantAddr.To4())
@@ -902,10 +902,10 @@ func addrMatches4(got syscall.RawSockaddrInet4, wantAddrs []net.IP, port uint16)
return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs)
}
-func addrMatches6(got syscall.RawSockaddrInet6, wantAddrs []net.IP, port uint16) error {
+func addrMatches6(got unix.RawSockaddrInet6, wantAddrs []net.IP, port uint16) error {
for _, wantAddr := range wantAddrs {
- want := syscall.RawSockaddrInet6{
- Family: syscall.AF_INET6,
+ want := unix.RawSockaddrInet6{
+ Family: unix.AF_INET6,
Port: htons(port),
}
copy(want.Addr[:], wantAddr.To16())
diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc
index 0d93b806e..ea83bbe72 100644
--- a/test/packetimpact/dut/posix_server.cc
+++ b/test/packetimpact/dut/posix_server.cc
@@ -395,7 +395,8 @@ class PosixImpl final : public posix_server::Posix::Service {
::grpc::Status Shutdown(grpc::ServerContext *context,
const ::posix_server::ShutdownRequest *request,
::posix_server::ShutdownResponse *response) override {
- if (shutdown(request->fd(), request->how()) < 0) {
+ response->set_ret(shutdown(request->fd(), request->how()));
+ if (response->ret() < 0) {
response->set_errno_(errno);
}
return ::grpc::Status::OK;
diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto
index 521f03465..175a65336 100644
--- a/test/packetimpact/proto/posix_server.proto
+++ b/test/packetimpact/proto/posix_server.proto
@@ -214,7 +214,8 @@ message ShutdownRequest {
}
message ShutdownResponse {
- int32 errno_ = 1; // "errno" may fail to compile in c++.
+ int32 ret = 1;
+ int32 errno_ = 2; // "errno" may fail to compile in c++.
}
message RecvRequest {
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
index 8ce5edf2b..567f64c41 100644
--- a/test/packetimpact/runner/defs.bzl
+++ b/test/packetimpact/runner/defs.bzl
@@ -203,6 +203,11 @@ ALL_TESTS = [
name = "tcp_outside_the_window",
),
PacketimpactTestInfo(
+ name = "tcp_outside_the_window_closing",
+ # TODO(b/181625316): Fix netstack then merge into tcp_outside_the_window.
+ expect_netstack_failure = True,
+ ),
+ PacketimpactTestInfo(
name = "tcp_noaccept_close_rst",
),
PacketimpactTestInfo(
@@ -212,6 +217,11 @@ ALL_TESTS = [
name = "tcp_unacc_seq_ack",
),
PacketimpactTestInfo(
+ name = "tcp_unacc_seq_ack_closing",
+ # TODO(b/181625316): Fix netstack then merge into tcp_unacc_seq_ack.
+ expect_netstack_failure = True,
+ ),
+ PacketimpactTestInfo(
name = "tcp_paws_mechanism",
# TODO(b/156682000): Fix netstack then remove the line below.
expect_netstack_failure = True,
@@ -277,6 +287,11 @@ ALL_TESTS = [
PacketimpactTestInfo(
name = "tcp_info",
),
+ PacketimpactTestInfo(
+ name = "tcp_fin_retransmission",
+ # TODO(b/181625316): Fix netstack then remove the line below.
+ expect_netstack_failure = True,
+ ),
]
def validate_all_tests():
diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go
index 3da265b78..1064ca976 100644
--- a/test/packetimpact/runner/dut.go
+++ b/test/packetimpact/runner/dut.go
@@ -109,6 +109,7 @@ type dutInfo struct {
dut DUT
ctrlNet, testNet *dockerutil.Network
netInfo *testbench.DUTTestNet
+ uname *testbench.DUTUname
}
// setUpDUT will set up one DUT and return information for setting up the
@@ -182,6 +183,10 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut
POSIXServerIP: AddressInSubnet(DUTAddr, *ctrlNet.Subnet),
POSIXServerPort: CtrlPort,
}
+ info.uname, err = dut.Uname(ctx)
+ if err != nil {
+ return dutInfo{}, fmt.Errorf("failed to get uname information on DUT: %w", err)
+ }
return info, nil
}
@@ -195,7 +200,7 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
dutInfoChan := make(chan dutInfo, numDUTs)
errChan := make(chan error, numDUTs)
var dockerNetworks []*dockerutil.Network
- var dutTestNets []*testbench.DUTTestNet
+ var dutInfos []*testbench.DUTInfo
var duts []DUT
setUpCtx, cancelSetup := context.WithCancel(ctx)
@@ -214,7 +219,10 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
select {
case info := <-dutInfoChan:
dockerNetworks = append(dockerNetworks, info.ctrlNet, info.testNet)
- dutTestNets = append(dutTestNets, info.netInfo)
+ dutInfos = append(dutInfos, &testbench.DUTInfo{
+ Net: info.netInfo,
+ Uname: info.uname,
+ })
duts = append(duts, info.dut)
case err := <-errChan:
t.Fatal(err)
@@ -241,28 +249,29 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
testbenchContainer,
testbenchAddr,
dockerNetworks,
+ nil, /* sysctls */
"tail", "-f", "/dev/null",
); err != nil {
t.Fatalf("cannot start testbench container: %s", err)
}
- for i := range dutTestNets {
- name, info, err := deviceByIP(ctx, testbenchContainer, dutTestNets[i].LocalIPv4)
+ for i := range dutInfos {
+ name, info, err := deviceByIP(ctx, testbenchContainer, dutInfos[i].Net.LocalIPv4)
if err != nil {
- t.Fatalf("failed to get the device name associated with %s: %s", dutTestNets[i].LocalIPv4, err)
+ t.Fatalf("failed to get the device name associated with %s: %s", dutInfos[i].Net.LocalIPv4, err)
}
- dutTestNets[i].LocalDevName = name
- dutTestNets[i].LocalDevID = info.ID
- dutTestNets[i].LocalMAC = info.MAC
+ dutInfos[i].Net.LocalDevName = name
+ dutInfos[i].Net.LocalDevID = info.ID
+ dutInfos[i].Net.LocalMAC = info.MAC
localIPv6, err := getOrAssignIPv6Addr(ctx, testbenchContainer, name)
if err != nil {
t.Fatalf("failed to get IPV6 address on %s: %s", testbenchContainer.Name, err)
}
- dutTestNets[i].LocalIPv6 = localIPv6
+ dutInfos[i].Net.LocalIPv6 = localIPv6
}
- dutTestNetsBytes, err := json.Marshal(dutTestNets)
+ dutInfosBytes, err := json.Marshal(dutInfos)
if err != nil {
- t.Fatalf("failed to marshal %v into json: %s", dutTestNets, err)
+ t.Fatalf("failed to marshal %v into json: %s", dutInfos, err)
}
baseSnifferArgs := []string{
@@ -296,7 +305,8 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
"-n",
}
}
- for _, n := range dutTestNets {
+ for _, info := range dutInfos {
+ n := info.Net
snifferArgs := append(baseSnifferArgs, "-i", n.LocalDevName)
if !tshark {
snifferArgs = append(
@@ -351,7 +361,7 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co
testArgs = append(testArgs, extraTestArgs...)
testArgs = append(testArgs,
fmt.Sprintf("--native=%t", native),
- "--dut_test_nets_json", string(dutTestNetsBytes),
+ "--dut_infos_json", string(dutInfosBytes),
)
testbenchLogs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, testArgs...)
if (err != nil) != expectFailure {
@@ -388,6 +398,10 @@ type DUT interface {
// The t parameter is supposed to be used for t.Cleanup. Don't use it for
// t.Fatal/FailNow functions.
Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network) (net.IP, net.HardwareAddr, uint32, string, error)
+
+ // Uname gathers information of DUT using command uname.
+ Uname(ctx context.Context) (*testbench.DUTUname, error)
+
// Logs retrieves the logs from the dut.
Logs(ctx context.Context) (string, error)
}
@@ -415,6 +429,10 @@ func (dut *DockerDUT) Prepare(ctx context.Context, _ *testing.T, runOpts dockeru
dut.c,
DUTAddr,
[]*dockerutil.Network{ctrlNet, testNet},
+ map[string]string{
+ // This enables creating ICMP sockets on Linux.
+ "net.ipv4.ping_group_range": "0 0",
+ },
containerPosixServerBinary,
"--ip=0.0.0.0",
fmt.Sprintf("--port=%d", CtrlPort),
@@ -440,6 +458,38 @@ func (dut *DockerDUT) Prepare(ctx context.Context, _ *testing.T, runOpts dockeru
return remoteIPv6, dutDeviceInfo.MAC, dutDeviceInfo.ID, testNetDev, nil
}
+// Uname implements DUT.Uname.
+func (dut *DockerDUT) Uname(ctx context.Context) (*testbench.DUTUname, error) {
+ machine, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "uname", "-m")
+ if err != nil {
+ return nil, err
+ }
+ kernelRelease, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "uname", "-r")
+ if err != nil {
+ return nil, err
+ }
+ kernelVersion, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "uname", "-v")
+ if err != nil {
+ return nil, err
+ }
+ kernelName, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "uname", "-s")
+ if err != nil {
+ return nil, err
+ }
+ // TODO(gvisor.dev/issues/5586): -o is not supported on macOS.
+ operatingSystem, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "uname", "-o")
+ if err != nil {
+ return nil, err
+ }
+ return &testbench.DUTUname{
+ Machine: strings.TrimRight(machine, "\n"),
+ KernelName: strings.TrimRight(kernelName, "\n"),
+ KernelRelease: strings.TrimRight(kernelRelease, "\n"),
+ KernelVersion: strings.TrimRight(kernelVersion, "\n"),
+ OperatingSystem: strings.TrimRight(operatingSystem, "\n"),
+ }, nil
+}
+
// Logs implements DUT.Logs.
func (dut *DockerDUT) Logs(ctx context.Context) (string, error) {
logs, err := dut.c.Logs(ctx)
@@ -545,11 +595,14 @@ func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error {
// StartContainer will create a container instance from runOpts, connect it
// with the specified docker networks and start executing the specified cmd.
-func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerutil.Container, containerAddr net.IP, ns []*dockerutil.Network, cmd ...string) error {
+func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerutil.Container, containerAddr net.IP, ns []*dockerutil.Network, sysctls map[string]string, cmd ...string) error {
conf, hostconf, netconf := c.ConfigsFrom(runOpts, cmd...)
_ = netconf
hostconf.AutoRemove = true
hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"}
+ for k, v := range sysctls {
+ hostconf.Sysctls[k] = v
+ }
if err := c.CreateFrom(ctx, runOpts.Image, conf, hostconf, nil); err != nil {
return fmt.Errorf("unable to create container %s: %w", c.Name, err)
diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD
index 983c2c030..43b4c7ca1 100644
--- a/test/packetimpact/testbench/BUILD
+++ b/test/packetimpact/testbench/BUILD
@@ -1,7 +1,6 @@
load("//tools:defs.bzl", "go_library", "go_test")
package(
- default_visibility = ["//test/packetimpact:__subpackages__"],
licenses = ["notice"],
)
@@ -15,6 +14,7 @@ go_library(
"rawsockets.go",
"testbench.go",
],
+ visibility = ["//test/packetimpact:__subpackages__"],
deps = [
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 15e1a51de..8ad9040ff 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -677,17 +677,17 @@ func (conn *TCPIPv4) Connect(t *testing.T) {
t.Helper()
// Send the SYN.
- conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn)})
+ conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagSyn)})
// Wait for the SYN-ACK.
- synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ synAck, err := conn.Expect(t, TCP{Flags: TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("didn't get synack during handshake: %s", err)
}
conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
// Send an ACK.
- conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)})
+ conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagAck)})
}
// ConnectWithOptions performs a TCP 3-way handshake with given TCP options.
@@ -696,17 +696,17 @@ func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) {
t.Helper()
// Send the SYN.
- conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn), Options: options})
+ conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagSyn), Options: options})
// Wait for the SYN-ACK.
- synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ synAck, err := conn.Expect(t, TCP{Flags: TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("didn't get synack during handshake: %s", err)
}
conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
// Send an ACK.
- conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)})
+ conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagAck)})
}
// ExpectData is a convenient method that expects a Layer and the Layer after
@@ -823,6 +823,27 @@ func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 {
return sa
}
+// GenerateOTWSeqSegment generates a segment with
+// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only
+// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the
+// receiver.
+func GenerateOTWSeqSegment(t *testing.T, conn *TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) TCP {
+ t.Helper()
+ lastAcceptable := conn.LocalSeqNum(t).Add(windowSize)
+ otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
+ return TCP{SeqNum: Uint32(otwSeq), Flags: TCPFlags(header.TCPFlagAck)}
+}
+
+// GenerateUnaccACKSegment generates a segment with
+// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable
+// when seqNumOffset is 0, otherwise an ACK is expected from the receiver.
+func GenerateUnaccACKSegment(t *testing.T, conn *TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) TCP {
+ t.Helper()
+ lastAcceptable := conn.RemoteSeqNum(t)
+ unaccAck := uint32(lastAcceptable.Add(seqNumOffset))
+ return TCP{AckNum: Uint32(unaccAck), Flags: TCPFlags(header.TCPFlagAck)}
+}
+
// IPv4Conn maintains the state for all the layers in a IPv4 connection.
type IPv4Conn struct {
Connection
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index be5121d98..eabdc8cb3 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -19,7 +19,6 @@ import (
"encoding/binary"
"fmt"
"net"
- "syscall"
"testing"
"time"
@@ -35,24 +34,26 @@ type DUT struct {
conn *grpc.ClientConn
posixServer POSIXClient
Net *DUTTestNet
+ Uname *DUTUname
}
// NewDUT creates a new connection with the DUT over gRPC.
func NewDUT(t *testing.T) DUT {
t.Helper()
- n := GetDUTTestNet()
- dut := n.ConnectToDUT(t)
+ info := getDUTInfo()
+ dut := info.ConnectToDUT(t)
t.Cleanup(func() {
dut.TearDownConnection()
- dut.Net.Release()
+ info.release()
})
return dut
}
// ConnectToDUT connects to DUT through gRPC.
-func (n *DUTTestNet) ConnectToDUT(t *testing.T) DUT {
+func (info *DUTInfo) ConnectToDUT(t *testing.T) DUT {
t.Helper()
+ n := info.Net
posixServerAddress := net.JoinHostPort(n.POSIXServerIP.String(), fmt.Sprintf("%d", n.POSIXServerPort))
conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive}))
if err != nil {
@@ -63,6 +64,7 @@ func (n *DUTTestNet) ConnectToDUT(t *testing.T) DUT {
conn: conn,
posixServer: posixServer,
Net: n,
+ Uname: info.Uname,
}
}
@@ -196,7 +198,7 @@ func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32)
if err != nil {
t.Fatalf("failed to call Accept: %s", err)
}
- return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+ return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), unix.Errno(resp.GetErrno_())
}
// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
@@ -225,7 +227,7 @@ func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa un
if err != nil {
t.Fatalf("failed to call Bind: %s", err)
}
- return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), unix.Errno(resp.GetErrno_())
}
// Close calls close on the DUT and causes a fatal test failure if it doesn't
@@ -253,7 +255,7 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int
if err != nil {
t.Fatalf("failed to call Close: %s", err)
}
- return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), unix.Errno(resp.GetErrno_())
}
// Connect calls connect on the DUT and causes a fatal test failure if it
@@ -267,7 +269,7 @@ func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) {
ret, err := dut.ConnectWithErrno(ctx, t, fd, sa)
// Ignore 'operation in progress' error that can be returned when the socket
// is non-blocking.
- if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 {
+ if err != unix.EINPROGRESS && ret != 0 {
t.Fatalf("failed to connect socket: %s", err)
}
}
@@ -284,7 +286,7 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa
if err != nil {
t.Fatalf("failed to call Connect: %s", err)
}
- return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), unix.Errno(resp.GetErrno_())
}
// GetSockName calls getsockname on the DUT and causes a fatal test failure if
@@ -313,7 +315,7 @@ func (dut *DUT) GetSockNameWithErrno(ctx context.Context, t *testing.T, sockfd i
if err != nil {
t.Fatalf("failed to call Bind: %s", err)
}
- return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), unix.Errno(resp.GetErrno_())
}
func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) {
@@ -334,7 +336,7 @@ func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, opt
if optval == nil {
t.Fatalf("GetSockOpt response does not contain a value")
}
- return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), optval, unix.Errno(resp.GetErrno_())
}
// GetSockOpt calls getsockopt on the DUT and causes a fatal test failure if it
@@ -452,7 +454,7 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backl
if err != nil {
t.Fatalf("failed to call Listen: %s", err)
}
- return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), unix.Errno(resp.GetErrno_())
}
// PollOne calls poll on the DUT and asserts that the expected event must be
@@ -519,7 +521,7 @@ func (dut *DUT) PollWithErrno(ctx context.Context, t *testing.T, pfds []unix.Pol
Revents: int16(protoPfd.GetEvents()),
})
}
- return resp.GetRet(), result, syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), result, unix.Errno(resp.GetErrno_())
}
// Send calls send on the DUT and causes a fatal test failure if it doesn't
@@ -550,7 +552,7 @@ func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, b
if err != nil {
t.Fatalf("failed to call Send: %s", err)
}
- return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), unix.Errno(resp.GetErrno_())
}
// SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't
@@ -582,7 +584,7 @@ func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32,
if err != nil {
t.Fatalf("failed to call SendTo: %s", err)
}
- return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), unix.Errno(resp.GetErrno_())
}
// SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking
@@ -602,7 +604,7 @@ func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) {
t.Fatalf("failed to call SetNonblocking: %s", err)
}
if resp.GetRet() == -1 {
- t.Fatalf("fcntl(%d, %s) failed: %s", fd, resp.GetCmd(), syscall.Errno(resp.GetErrno_()))
+ t.Fatalf("fcntl(%d, %s) failed: %s", fd, resp.GetCmd(), unix.Errno(resp.GetErrno_()))
}
}
@@ -619,7 +621,7 @@ func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, opt
if err != nil {
t.Fatalf("failed to call SetSockOpt: %s", err)
}
- return resp.GetRet(), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), unix.Errno(resp.GetErrno_())
}
// SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it
@@ -720,7 +722,7 @@ func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32,
if err != nil {
t.Fatalf("failed to call Socket: %s", err)
}
- return resp.GetFd(), syscall.Errno(resp.GetErrno_())
+ return resp.GetFd(), unix.Errno(resp.GetErrno_())
}
// Recv calls recv on the DUT and causes a fatal test failure if it doesn't
@@ -751,7 +753,7 @@ func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, fl
if err != nil {
t.Fatalf("failed to call Recv: %s", err)
}
- return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), resp.GetBuf(), unix.Errno(resp.GetErrno_())
}
// SetSockLingerOption sets SO_LINGER socket option on the DUT.
@@ -771,16 +773,19 @@ func (dut *DUT) SetSockLingerOption(t *testing.T, sockfd int32, timeout time.Dur
// Shutdown calls shutdown on the DUT and causes a fatal test failure if it
// doesn't succeed. If more control over the timeout or error handling is
// needed, use ShutdownWithErrno.
-func (dut *DUT) Shutdown(t *testing.T, fd, how int32) error {
+func (dut *DUT) Shutdown(t *testing.T, fd, how int32) {
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- return dut.ShutdownWithErrno(ctx, t, fd, how)
+ ret, err := dut.ShutdownWithErrno(ctx, t, fd, how)
+ if ret != 0 {
+ t.Fatalf("failed to shutdown(%d, %d): %s", fd, how, err)
+ }
}
// ShutdownWithErrno calls shutdown on the DUT.
-func (dut *DUT) ShutdownWithErrno(ctx context.Context, t *testing.T, fd, how int32) error {
+func (dut *DUT) ShutdownWithErrno(ctx context.Context, t *testing.T, fd, how int32) (int32, error) {
t.Helper()
req := &pb.ShutdownRequest{
@@ -791,5 +796,8 @@ func (dut *DUT) ShutdownWithErrno(ctx context.Context, t *testing.T, fd, how int
if err != nil {
t.Fatalf("failed to call Shutdown: %s", err)
}
- return syscall.Errno(resp.GetErrno_())
+ if resp.GetErrno_() == 0 {
+ return resp.GetRet(), nil
+ }
+ return resp.GetRet(), unix.Errno(resp.GetErrno_())
}
diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go
index 64a7a171a..2311f7686 100644
--- a/test/packetimpact/testbench/layers.go
+++ b/test/packetimpact/testbench/layers.go
@@ -407,6 +407,12 @@ func Uint8(v uint8) *uint8 {
return &v
}
+// TCPFlags is a helper routine that allocates a new
+// header.TCPFlags value to store v and returns a pointer to it.
+func TCPFlags(v header.TCPFlags) *header.TCPFlags {
+ return &v
+}
+
// Address is a helper routine that allocates a new tcpip.Address value to
// store v and returns a pointer to it.
func Address(v tcpip.Address) *tcpip.Address {
@@ -1030,7 +1036,7 @@ type TCP struct {
SeqNum *uint32
AckNum *uint32
DataOffset *uint8
- Flags *uint8
+ Flags *header.TCPFlags
WindowSize *uint16
Checksum *uint16
UrgentPointer *uint16
@@ -1063,7 +1069,7 @@ func (l *TCP) ToBytes() ([]byte, error) {
h.SetDataOffset(uint8(l.length()))
}
if l.Flags != nil {
- h.SetFlags(*l.Flags)
+ h.SetFlags(uint8(*l.Flags))
}
if l.WindowSize != nil {
h.SetWindowSize(*l.WindowSize)
@@ -1157,7 +1163,7 @@ func parseTCP(b []byte) (Layer, layerParser) {
SeqNum: Uint32(h.SequenceNumber()),
AckNum: Uint32(h.AckNumber()),
DataOffset: Uint8(h.DataOffset()),
- Flags: Uint8(h.Flags()),
+ Flags: TCPFlags(h.Flags()),
WindowSize: Uint16(h.WindowSize()),
Checksum: Uint16(h.Checksum()),
UrgentPointer: Uint16(h.UrgentPointer()),
diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go
index eca0780b5..614a5de1e 100644
--- a/test/packetimpact/testbench/layers_test.go
+++ b/test/packetimpact/testbench/layers_test.go
@@ -178,7 +178,7 @@ func TestLayerStringFormat(t *testing.T) {
SeqNum: Uint32(3452155723),
AckNum: Uint32(2596996163),
DataOffset: Uint8(5),
- Flags: Uint8(20),
+ Flags: TCPFlags(header.TCPFlagRst | header.TCPFlagAck),
WindowSize: Uint16(64240),
Checksum: Uint16(0x2e2b),
},
@@ -188,7 +188,7 @@ func TestLayerStringFormat(t *testing.T) {
"SeqNum:3452155723 " +
"AckNum:2596996163 " +
"DataOffset:5 " +
- "Flags:20 " +
+ "Flags: R A " +
"WindowSize:64240 " +
"Checksum:11819" +
"}",
@@ -436,7 +436,7 @@ func TestTCPOptions(t *testing.T) {
DstPort: Uint16(54321),
SeqNum: Uint32(0),
AckNum: Uint32(0),
- Flags: Uint8(header.TCPFlagSyn),
+ Flags: TCPFlags(header.TCPFlagSyn),
WindowSize: Uint16(8192),
Checksum: Uint16(0xf51c),
UrgentPointer: Uint16(0),
@@ -480,7 +480,7 @@ func TestTCPOptions(t *testing.T) {
DstPort: Uint16(54321),
SeqNum: Uint32(0),
AckNum: Uint32(0),
- Flags: Uint8(header.TCPFlagSyn),
+ Flags: TCPFlags(header.TCPFlagSyn),
WindowSize: Uint16(8192),
Checksum: Uint16(0xe521),
UrgentPointer: Uint16(0),
diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go
index 891897d55..a73c07e64 100644
--- a/test/packetimpact/testbench/testbench.go
+++ b/test/packetimpact/testbench/testbench.go
@@ -34,14 +34,29 @@ var (
// RPCTimeout is the gRPC timeout.
RPCTimeout = 100 * time.Millisecond
- // dutTestNetsJSON is the json string that describes all the test networks to
+ // dutInfosJSON is the json string that describes information about all the
// duts available to use.
- dutTestNetsJSON string
- // dutTestNets is the pool among which the testbench can choose a DUT to work
+ dutInfosJSON string
+ // dutInfo is the pool among which the testbench can choose a DUT to work
// with.
- dutTestNets chan *DUTTestNet
+ dutInfo chan *DUTInfo
)
+// DUTInfo has both network and uname information about the DUT.
+type DUTInfo struct {
+ Uname *DUTUname
+ Net *DUTTestNet
+}
+
+// DUTUname contains information about the DUT from uname.
+type DUTUname struct {
+ Machine string
+ KernelName string
+ KernelRelease string
+ KernelVersion string
+ OperatingSystem string
+}
+
// DUTTestNet describes the test network setup on dut and how the testbench
// should connect with an existing DUT.
type DUTTestNet struct {
@@ -86,7 +101,7 @@ func registerFlags(fs *flag.FlagSet) {
fs.BoolVar(&Native, "native", Native, "whether the test is running natively")
fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout")
fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive")
- fs.StringVar(&dutTestNetsJSON, "dut_test_nets_json", dutTestNetsJSON, "path to the dut test nets json file")
+ fs.StringVar(&dutInfosJSON, "dut_infos_json", dutInfosJSON, "json that describes the DUTs")
}
// Initialize initializes the testbench, it parse the flags and sets up the
@@ -94,27 +109,27 @@ func registerFlags(fs *flag.FlagSet) {
func Initialize(fs *flag.FlagSet) {
registerFlags(fs)
flag.Parse()
- if err := loadDUTTestNets(); err != nil {
+ if err := loadDUTInfos(); err != nil {
panic(err)
}
}
-// loadDUTTestNets loads available DUT test networks from the json file, it
+// loadDUTInfos loads available DUT test infos from the json file, it
// must be called after flag.Parse().
-func loadDUTTestNets() error {
- var parsedTestNets []DUTTestNet
- if err := json.Unmarshal([]byte(dutTestNetsJSON), &parsedTestNets); err != nil {
+func loadDUTInfos() error {
+ var dutInfos []DUTInfo
+ if err := json.Unmarshal([]byte(dutInfosJSON), &dutInfos); err != nil {
return fmt.Errorf("failed to unmarshal JSON: %w", err)
}
- if got, want := len(parsedTestNets), 1; got < want {
+ if got, want := len(dutInfos), 1; got < want {
return fmt.Errorf("got %d DUTs, the test requires at least %d DUTs", got, want)
}
// Using a buffered channel as semaphore
- dutTestNets = make(chan *DUTTestNet, len(parsedTestNets))
- for i := range parsedTestNets {
- parsedTestNets[i].LocalIPv4 = parsedTestNets[i].LocalIPv4.To4()
- parsedTestNets[i].RemoteIPv4 = parsedTestNets[i].RemoteIPv4.To4()
- dutTestNets <- &parsedTestNets[i]
+ dutInfo = make(chan *DUTInfo, len(dutInfos))
+ for i := range dutInfos {
+ dutInfos[i].Net.LocalIPv4 = dutInfos[i].Net.LocalIPv4.To4()
+ dutInfos[i].Net.RemoteIPv4 = dutInfos[i].Net.RemoteIPv4.To4()
+ dutInfo <- &dutInfos[i]
}
return nil
}
@@ -130,14 +145,13 @@ func GenerateRandomPayload(t *testing.T, n int) []byte {
return buf
}
-// GetDUTTestNet gets a usable DUTTestNet, the function will block until any
-// becomes available.
-func GetDUTTestNet() *DUTTestNet {
- return <-dutTestNets
+// getDUTInfo returns information about an available DUT from the pool. If no
+// DUT is readily available, getDUTInfo blocks until one becomes available.
+func getDUTInfo() *DUTInfo {
+ return <-dutInfo
}
-// Release releases the DUTTestNet back to the pool so that some other test
-// can use.
-func (n *DUTTestNet) Release() {
- dutTestNets <- n
+// release returns the DUTInfo back to the pool.
+func (info *DUTInfo) release() {
+ dutInfo <- info
}
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 301cf4980..d5cb0ae06 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -124,6 +124,17 @@ packetimpact_testbench(
)
packetimpact_testbench(
+ name = "tcp_outside_the_window_closing",
+ srcs = ["tcp_outside_the_window_closing_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_testbench(
name = "tcp_noaccept_close_rst",
srcs = ["tcp_noaccept_close_rst_test.go"],
deps = [
@@ -155,6 +166,17 @@ packetimpact_testbench(
)
packetimpact_testbench(
+ name = "tcp_unacc_seq_ack_closing",
+ srcs = ["tcp_unacc_seq_ack_closing_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_testbench(
name = "tcp_paws_mechanism",
srcs = ["tcp_paws_mechanism_test.go"],
deps = [
@@ -375,6 +397,16 @@ packetimpact_testbench(
],
)
+packetimpact_testbench(
+ name = "tcp_fin_retransmission",
+ srcs = ["tcp_fin_retransmission_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
validate_all_tests()
[packetimpact_go_test(
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
index 11f0fcd1e..cff8ca51d 100644
--- a/test/packetimpact/tests/fin_wait2_timeout_test.go
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -51,21 +51,21 @@ func TestFinWait2Timeout(t *testing.T) {
}
dut.Close(t, acceptFd)
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
time.Sleep(5 * time.Second)
conn.Drain(t)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
if tt.linger2 {
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, time.Second); err != nil {
t.Fatalf("expected a RST packet within a second but got none: %s", err)
}
} else {
- if got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil {
+ if got, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil {
t.Fatalf("expected no RST packets within ten seconds but got one: %s", got)
}
}
diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
index a63b41366..2b69ceecb 100644
--- a/test/packetimpact/tests/ipv4_id_uniqueness_test.go
+++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
@@ -100,7 +100,7 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) {
// Let the DUT estimate RTO with RTT from the DATA-ACK.
// TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
// we can skip sending this ACK.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
dut.Send(t, remoteFD, tc.payload, 0)
expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))}
diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go
index a7ba5035e..1db3c9883 100644
--- a/test/packetimpact/tests/tcp_cork_mss_test.go
+++ b/test/packetimpact/tests/tcp_cork_mss_test.go
@@ -60,24 +60,24 @@ func TestTCPCorkMSS(t *testing.T) {
// Expect the segments to be coalesced and sent and capped to MSS.
expectedPayload := testbench.Payload{Bytes: expectedData[:mss]}
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
// Expect the coalesced segment to be split and transmitted.
expectedPayload = testbench.Payload{Bytes: expectedData[mss:]}
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
// Check for segments to *not* be held up because of TCP_CORK when
// the current send window is less than MSS.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))})
dut.Send(t, acceptFD, sampleData, 0)
dut.Send(t, acceptFD, sampleData, 0)
expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)}
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
}
diff --git a/test/packetimpact/tests/tcp_fin_retransmission_test.go b/test/packetimpact/tests/tcp_fin_retransmission_test.go
new file mode 100644
index 000000000..500f7a783
--- /dev/null
+++ b/test/packetimpact/tests/tcp_fin_retransmission_test.go
@@ -0,0 +1,87 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_fin_retransmission_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+// TestTCPClosingFinRetransmission tests that TCP implementation should retransmit
+// FIN segment in CLOSING state.
+func TestTCPClosingFinRetransmission(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ flags header.TCPFlags
+ }{
+ {"CLOSING", header.TCPFlagAck | header.TCPFlagFin},
+ {"FIN_WAIT_1", header.TCPFlagAck},
+ } {
+ t.Run(tt.description, func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFD)
+
+ // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK.
+ // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
+ // we can skip the next block of code.
+ sampleData := []byte("Sample Data")
+ if got, want := dut.Send(t, acceptFD, sampleData, 0), len(sampleData); int(got) != want {
+ t.Fatalf("got dut.Send(t, %d, %s, 0) = %d, want %d", acceptFD, sampleData, got, want)
+ }
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil {
+ t.Fatalf("expected payload was not received: %s", err)
+ }
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+
+ dut.Shutdown(t, acceptFD, unix.SHUT_WR)
+
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected FINACK from DUT, but got none: %s", err)
+ }
+
+ // Do not ack the FIN from DUT so that we can test for retransmission.
+ seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1)
+ conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(tt.flags)})
+
+ if tt.flags&header.TCPFlagFin != 0 {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Errorf("expected an ACK to our FIN, but got none: %s", err)
+ }
+ }
+
+ if _, err := conn.Expect(t, testbench.TCP{
+ SeqNum: seqNumForTheirFIN,
+ Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck),
+ }, time.Second); err != nil {
+ t.Errorf("expected retransmission of FIN from the DUT: %s", err)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go
index 5d1266f3c..668e0275c 100644
--- a/test/packetimpact/tests/tcp_handshake_window_size_test.go
+++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go
@@ -38,8 +38,8 @@ func TestTCPHandshakeWindowSize(t *testing.T) {
defer conn.Close(t)
// Start handshake with zero window size.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))})
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN-ACK: %s", err)
}
// Update the advertised window size to a non-zero value with the ACK that
@@ -47,7 +47,7 @@ func TestTCPHandshakeWindowSize(t *testing.T) {
//
// Set the window size with MSB set and expect the dut to treat it as
// an unsigned value.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))})
acceptFd, _ := dut.Accept(t, listenFD)
defer dut.Close(t, acceptFd)
@@ -59,7 +59,7 @@ func TestTCPHandshakeWindowSize(t *testing.T) {
// expect the dut to honor the recently advertised non-zero window
// and actually send out the data instead of probing for zero window.
dut.Send(t, acceptFd, sampleData, 0)
- if _, err := conn.ExpectNextData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectNextData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
}
diff --git a/test/packetimpact/tests/tcp_info_test.go b/test/packetimpact/tests/tcp_info_test.go
index 69275e54b..3fc2c7fe5 100644
--- a/test/packetimpact/tests/tcp_info_test.go
+++ b/test/packetimpact/tests/tcp_info_test.go
@@ -51,7 +51,7 @@ func TestTCPInfo(t *testing.T) {
if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
info := linux.TCPInfo{}
infoBytes := dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo))
diff --git a/test/packetimpact/tests/tcp_linger_test.go b/test/packetimpact/tests/tcp_linger_test.go
index bc4b64388..88942904d 100644
--- a/test/packetimpact/tests/tcp_linger_test.go
+++ b/test/packetimpact/tests/tcp_linger_test.go
@@ -17,7 +17,6 @@ package tcp_linger_test
import (
"context"
"flag"
- "syscall"
"testing"
"time"
@@ -58,10 +57,10 @@ func TestTCPLingerZeroTimeout(t *testing.T) {
dut.Close(t, acceptFD)
// If the linger timeout is set to zero, the DUT should send a RST.
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil {
t.Errorf("expected RST-ACK packet within a second but got none: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
}
// TestTCPLingerOff tests when SO_LINGER is not set. DUT should send FIN-ACK
@@ -75,10 +74,10 @@ func TestTCPLingerOff(t *testing.T) {
dut.Close(t, acceptFD)
// If SO_LINGER is not set, DUT should send a FIN-ACK.
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
t.Errorf("expected FIN-ACK packet within a second but got none: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
}
// TestTCPLingerNonZeroTimeout tests when SO_LINGER is set with non-zero timeout.
@@ -115,10 +114,10 @@ func TestTCPLingerNonZeroTimeout(t *testing.T) {
t.Errorf("expected close to return within a second, but returned later")
}
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
t.Errorf("expected FIN-ACK packet within a second but got none: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
})
}
}
@@ -166,10 +165,10 @@ func TestTCPLingerSendNonZeroTimeout(t *testing.T) {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
t.Errorf("expected FIN-ACK packet within a second but got none: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
})
}
}
@@ -183,19 +182,19 @@ func TestTCPLingerShutdownZeroTimeout(t *testing.T) {
defer closeAll(t, dut, listenFD, conn)
dut.SetSockLingerOption(t, acceptFD, 0, true)
- dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR)
+ dut.Shutdown(t, acceptFD, unix.SHUT_RDWR)
dut.Close(t, acceptFD)
// Shutdown will send FIN-ACK with read/write option.
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
t.Errorf("expected FIN-ACK packet within a second but got none: %s", err)
}
// If the linger timeout is set to zero, the DUT should send a RST.
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil {
t.Errorf("expected RST-ACK packet within a second but got none: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
}
// TestTCPLingerShutdownSendNonZeroTimeout tests SO_LINGER with shutdown() and
@@ -220,7 +219,7 @@ func TestTCPLingerShutdownSendNonZeroTimeout(t *testing.T) {
sampleData := []byte("Sample Data")
dut.Send(t, acceptFD, sampleData, 0)
- dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR)
+ dut.Shutdown(t, acceptFD, unix.SHUT_RDWR)
// Increase timeout as Close will take longer time to
// return when SO_LINGER is set with non-zero timeout.
@@ -243,10 +242,10 @@ func TestTCPLingerShutdownSendNonZeroTimeout(t *testing.T) {
t.Fatalf("expected a packet with payload %v: %s", samplePayload, err)
}
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
t.Errorf("expected FIN-ACK packet within a second but got none: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
})
}
}
diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go
index 53dc903e4..5168450ad 100644
--- a/test/packetimpact/tests/tcp_network_unreachable_test.go
+++ b/test/packetimpact/tests/tcp_network_unreachable_test.go
@@ -17,7 +17,6 @@ package tcp_synsent_reset_test
import (
"context"
"flag"
- "syscall"
"testing"
"time"
@@ -46,12 +45,12 @@ func TestTCPSynSentUnreachable(t *testing.T) {
defer cancel()
sa := unix.SockaddrInet4{Port: int(port)}
copy(sa.Addr[:], dut.Net.LocalIPv4)
- if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) {
+ if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != unix.EINPROGRESS {
t.Errorf("got connect() = %v, want EINPROGRESS", err)
}
// Get the SYN.
- tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second)
+ tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, nil, time.Second)
if err != nil {
t.Fatalf("expected SYN: %s", err)
}
@@ -100,12 +99,12 @@ func TestTCPSynSentUnreachable6(t *testing.T) {
ZoneId: dut.Net.RemoteDevID,
}
copy(sa.Addr[:], dut.Net.LocalIPv6)
- if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) {
+ if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != unix.EINPROGRESS {
t.Errorf("got connect() = %v, want EINPROGRESS", err)
}
// Get the SYN.
- tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second)
+ tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, nil, time.Second)
if err != nil {
t.Fatalf("expected SYN: %s", err)
}
@@ -156,7 +155,7 @@ func getConnectError(t *testing.T, dut *testbench.DUT, fd int32) error {
// failure).
dut.PollOne(t, fd, unix.POLLOUT, 10*time.Second)
if errno := dut.GetSockOptInt(t, fd, unix.SOL_SOCKET, unix.SO_ERROR); errno != 0 {
- return syscall.Errno(errno)
+ return unix.Errno(errno)
}
return nil
}
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
index d2871df08..14eb7d93b 100644
--- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -40,7 +40,7 @@ func TestTcpNoAcceptCloseReset(t *testing.T) {
// it will only respond RST instead of RST+ACK.
dut.PollOne(t, listenFd, unix.POLLIN, time.Second)
dut.Close(t, listenFd)
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
t.Fatalf("expected a RST-ACK packet but got none: %s", err)
}
}
diff --git a/test/packetimpact/tests/tcp_outside_the_window_closing_test.go b/test/packetimpact/tests/tcp_outside_the_window_closing_test.go
new file mode 100644
index 000000000..1097746c7
--- /dev/null
+++ b/test/packetimpact/tests/tcp_outside_the_window_closing_test.go
@@ -0,0 +1,86 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_outside_the_window_closing_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+// TestAckOTWSeqInClosing tests that the DUT should send an ACK with
+// the right ACK number when receiving a packet with OTW Seq number
+// in CLOSING state. https://tools.ietf.org/html/rfc793#page-69
+func TestAckOTWSeqInClosing(t *testing.T) {
+ for seqNumOffset := seqnum.Size(0); seqNumOffset < 3; seqNumOffset++ {
+ for _, tt := range []struct {
+ description string
+ flags header.TCPFlags
+ payloads testbench.Layers
+ }{
+ {"SYN", header.TCPFlagSyn, nil},
+ {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil},
+ {"ACK", header.TCPFlagAck, nil},
+ {"FINACK", header.TCPFlagFin | header.TCPFlagAck, nil},
+ {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}},
+ } {
+ t.Run(fmt.Sprintf("%s%d", tt.description, seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFD)
+
+ dut.Shutdown(t, acceptFD, unix.SHUT_WR)
+
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ t.Fatalf("expected FINACK from DUT, but got none: %s", err)
+ }
+
+ // Do not ack the FIN from DUT so that the TCP state on DUT is CLOSING instead of CLOSED.
+ seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1)
+ conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)})
+
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Errorf("expected an ACK to our FIN, but got none: %s", err)
+ }
+
+ windowSize := seqnum.Size(*conn.SynAck(t).WindowSize) + seqNumOffset
+ conn.SendFrameStateless(t, conn.CreateFrame(t, testbench.Layers{&testbench.TCP{
+ SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))),
+ AckNum: seqNumForTheirFIN,
+ Flags: testbench.TCPFlags(tt.flags),
+ }}, tt.payloads...))
+
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Errorf("expected an ACK but got none: %s", err)
+ }
+ })
+ }
+ }
+}
diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go
index 8909a348e..7cd7ff703 100644
--- a/test/packetimpact/tests/tcp_outside_the_window_test.go
+++ b/test/packetimpact/tests/tcp_outside_the_window_test.go
@@ -37,7 +37,7 @@ func init() {
func TestTCPOutsideTheWindow(t *testing.T) {
for _, tt := range []struct {
description string
- tcpFlags uint8
+ tcpFlags header.TCPFlags
payload []testbench.Layer
seqNumOffset seqnum.Size
expectACK bool
@@ -76,11 +76,11 @@ func TestTCPOutsideTheWindow(t *testing.T) {
// to the AckNum.
localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t)))
conn.Send(t, testbench.TCP{
- Flags: testbench.Uint8(tt.tcpFlags),
+ Flags: testbench.TCPFlags(tt.tcpFlags),
SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))),
}, tt.payload...)
- timeout := 3 * time.Second
- gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
+ timeout := time.Second
+ gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
if tt.expectACK && err != nil {
t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err)
}
@@ -93,11 +93,11 @@ func TestTCPOutsideTheWindow(t *testing.T) {
// has passed since the last ACK was sent.
t.Logf("sending another segment")
conn.Send(t, testbench.TCP{
- Flags: testbench.Uint8(tt.tcpFlags),
+ Flags: testbench.TCPFlags(tt.tcpFlags),
SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))),
}, tt.payload...)
timeout := 3 * time.Second
- gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
+ gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
if err == nil {
t.Fatalf("expected no ACK packet but got one: %s", gotACK)
}
diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go
index 24d9ef4ec..9054955ea 100644
--- a/test/packetimpact/tests/tcp_paws_mechanism_test.go
+++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go
@@ -38,8 +38,8 @@ func TestPAWSMechanism(t *testing.T) {
options := make([]byte, header.TCPOptionTSLength)
header.EncodeTSOption(currentTS(), 0, options)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options})
- synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn), Options: options})
+ synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("didn't get synack during handshake: %s", err)
}
@@ -49,7 +49,7 @@ func TestPAWSMechanism(t *testing.T) {
}
tsecr := parsedSynOpts.TSVal
header.EncodeTSOption(currentTS(), tsecr, options)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), Options: options})
acceptFD, _ := dut.Accept(t, listenFD)
defer dut.Close(t, acceptFD)
@@ -60,9 +60,9 @@ func TestPAWSMechanism(t *testing.T) {
// every time we send one, it should not cause any flakiness because timestamps
// only need to be non-decreasing.
time.Sleep(3 * time.Millisecond)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
- gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected an ACK but got none: %s", err)
}
@@ -85,9 +85,9 @@ func TestPAWSMechanism(t *testing.T) {
// 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness
// due to the exact same reasoning discussed above.
time.Sleep(3 * time.Millisecond)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
- gotTCP, err = conn.Expect(t, testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ gotTCP, err = conn.Expect(t, testbench.TCP{AckNum: lastAckNum, Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err)
}
diff --git a/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go
index 7dd1c326a..1c8b72ebe 100644
--- a/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go
+++ b/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go
@@ -21,7 +21,6 @@ import (
"errors"
"flag"
"sync"
- "syscall"
"testing"
"time"
@@ -45,10 +44,10 @@ func TestQueueSendInSynSentHandshake(t *testing.T) {
sampleData := []byte("Sample Data")
dut.SetNonBlocking(t, socket, true)
- if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) {
+ if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, unix.EINPROGRESS) {
t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err)
}
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, time.Second); err != nil {
t.Fatalf("expected a SYN from DUT, but got none: %s", err)
}
@@ -85,19 +84,19 @@ func TestQueueSendInSynSentHandshake(t *testing.T) {
time.Sleep(100 * time.Millisecond)
// Bring the connection to Established.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)})
// Expect the data from the DUT's enqueued send request.
//
// On Linux, this can be piggybacked with the ACK completing the
// handshake. On gVisor, getting such a piggyback is a bit more
// complicated because the actual data enqueuing occurs in the
// callers of endpoint Write.
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData})
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected an ACK from DUT, but got none: %s", err)
}
}
@@ -113,15 +112,15 @@ func TestQueueRecvInSynSentHandshake(t *testing.T) {
sampleData := []byte("Sample Data")
dut.SetNonBlocking(t, socket, true)
- if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) {
+ if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, unix.EINPROGRESS) {
t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err)
}
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, time.Second); err != nil {
t.Fatalf("expected a SYN from DUT, but got none: %s", err)
}
- if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) {
- t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err)
+ if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != unix.EWOULDBLOCK {
+ t.Fatalf("expected error %s, got %s", unix.EWOULDBLOCK, err)
}
// Test blocking read.
@@ -160,14 +159,14 @@ func TestQueueRecvInSynSentHandshake(t *testing.T) {
time.Sleep(100 * time.Millisecond)
// Bring the connection to Established.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)})
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected an ACK from DUT, but got none: %s", err)
}
// Send sample payload so that DUT can recv.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData})
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected an ACK from DUT, but got none: %s", err)
}
}
@@ -183,10 +182,10 @@ func TestQueueSendInSynSentRST(t *testing.T) {
sampleData := []byte("Sample Data")
dut.SetNonBlocking(t, socket, true)
- if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) {
+ if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, unix.EINPROGRESS) {
t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err)
}
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, time.Second); err != nil {
t.Fatalf("expected a SYN from DUT, but got none: %s", err)
}
@@ -207,8 +206,8 @@ func TestQueueSendInSynSentRST(t *testing.T) {
// Issue SEND call in SYN-SENT, this should be queued for
// process until the connection is established.
n, err := dut.SendWithErrno(ctx, t, socket, sampleData, 0)
- if err != syscall.Errno(unix.ECONNREFUSED) {
- t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err)
+ if err != unix.ECONNREFUSED {
+ t.Errorf("expected error %s, got %s", unix.ECONNREFUSED, err)
}
if n != -1 {
t.Errorf("expected return value %d, got %d", -1, n)
@@ -224,7 +223,7 @@ func TestQueueSendInSynSentRST(t *testing.T) {
// request and the system actually being blocked.
time.Sleep(100 * time.Millisecond)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)})
}
// TestQueueRecvInSynSentRST tests recv behavior when the TCP state
@@ -238,15 +237,15 @@ func TestQueueRecvInSynSentRST(t *testing.T) {
sampleData := []byte("Sample Data")
dut.SetNonBlocking(t, socket, true)
- if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) {
+ if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, unix.EINPROGRESS) {
t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err)
}
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, time.Second); err != nil {
t.Fatalf("expected a SYN from DUT, but got none: %s", err)
}
- if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) {
- t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err)
+ if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != unix.EWOULDBLOCK {
+ t.Fatalf("expected error %s, got %s", unix.EWOULDBLOCK, err)
}
// Test blocking read.
@@ -266,8 +265,8 @@ func TestQueueRecvInSynSentRST(t *testing.T) {
// Issue RECEIVE call in SYN-SENT, this should be queued for
// process until the connection is established.
n, _, err := dut.RecvWithErrno(ctx, t, socket, int32(len(sampleData)), 0)
- if err != syscall.Errno(unix.ECONNREFUSED) {
- t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err)
+ if err != unix.ECONNREFUSED {
+ t.Errorf("expected error %s, got %s", unix.ECONNREFUSED, err)
}
if n != -1 {
t.Errorf("expected return value %d, got %d", -1, n)
@@ -283,5 +282,5 @@ func TestQueueRecvInSynSentRST(t *testing.T) {
// request and the system actually being blocked.
time.Sleep(100 * time.Millisecond)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)})
}
diff --git a/test/packetimpact/tests/tcp_rack_test.go b/test/packetimpact/tests/tcp_rack_test.go
index ef902c54d..0a5b0f12b 100644
--- a/test/packetimpact/tests/tcp_rack_test.go
+++ b/test/packetimpact/tests/tcp_rack_test.go
@@ -97,7 +97,7 @@ func sendAndReceive(t *testing.T, dut testbench.DUT, conn testbench.TCPIPv4, num
if sendACK {
time.Sleep(simulatedRTT)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(sn))})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(sn))})
}
}
return lastSent
@@ -149,7 +149,7 @@ func TestRACKTLPLost(t *testing.T) {
// Cumulative ACK for #[1-5] packets.
ackNum := seqNum1.Add(seqnum.Size(6 * payloadSize))
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(ackNum))})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(ackNum))})
// Probe Timeout (PTO) should be two times RTT. Check that the last
// packet is retransmitted after probe timeout.
@@ -194,7 +194,7 @@ func TestRACKWithSACK(t *testing.T) {
sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{
start, end,
}}, sackBlock[sbOff:])
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
rtt, _ := getRTTAndRTO(t, dut, acceptFd)
timeout := 2 * rtt
@@ -206,7 +206,7 @@ func TestRACKWithSACK(t *testing.T) {
time.Sleep(simulatedRTT)
// ACK for #1 packet.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(end))})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(end))})
// RACK considers transmission times of the packets to mark them lost.
// As the 3rd packet was sent before the retransmitted 1st packet, RACK
@@ -243,7 +243,7 @@ func TestRACKWithoutReorder(t *testing.T) {
start, end,
}}, sackBlock[sbOff:])
time.Sleep(simulatedRTT)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
// RACK marks #1 and #2 packets as lost and retransmits both after
// RTT + reorderWindow. The reorderWindow initially will be a small
@@ -289,7 +289,7 @@ func TestRACKWithReorder(t *testing.T) {
sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{
start, end,
}}, sackBlock[sbOff:])
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
}
// Send a DSACK block indicating both original and retransmitted
@@ -304,7 +304,7 @@ func TestRACKWithReorder(t *testing.T) {
dbOff += header.EncodeSACKBlocks([]header.SACKBlock{{
start, end,
}}, dsackBlock[dbOff:])
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1 + numPkts*payloadSize)), Options: dsackBlock[:dbOff]})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1 + numPkts*payloadSize)), Options: dsackBlock[:dbOff]})
seqNum1.UpdateForward(seqnum.Size(numPkts * payloadSize))
sendTime := time.Now()
@@ -321,7 +321,7 @@ func TestRACKWithReorder(t *testing.T) {
sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{
start, end,
}}, sackBlock[sbOff:])
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
// Expect the retransmission of #1 packet after RTT+ReorderWindow.
if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, time.Second); err != nil {
@@ -361,7 +361,7 @@ func TestRACKWithLostRetransmission(t *testing.T) {
start, end,
}}, sackBlock[sbOff:])
time.Sleep(simulatedRTT)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
// RACK marks #1 packet as lost and retransmits it after
// RTT + reorderWindow. The reorderWindow is bounded between a small
@@ -394,7 +394,7 @@ func TestRACKWithLostRetransmission(t *testing.T) {
start, end,
}}, sackBlock1[sbOff1:])
time.Sleep(simulatedRTT)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock1[:sbOff1]})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock1[:sbOff1]})
// Expect re-retransmission of #1 packet without entering an RTO.
if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, timeout); err != nil {
diff --git a/test/packetimpact/tests/tcp_rcv_buf_space_test.go b/test/packetimpact/tests/tcp_rcv_buf_space_test.go
index d6ad5cda6..f121d44eb 100644
--- a/test/packetimpact/tests/tcp_rcv_buf_space_test.go
+++ b/test/packetimpact/tests/tcp_rcv_buf_space_test.go
@@ -17,7 +17,6 @@ package tcp_rcv_buf_space_test
import (
"context"
"flag"
- "syscall"
"testing"
"golang.org/x/sys/unix"
@@ -61,7 +60,7 @@ func TestReduceRecvBuf(t *testing.T) {
payloadBytes = l
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, []testbench.Layer{&testbench.Payload{Bytes: payload[:payloadBytes]}}...)
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, []testbench.Layer{&testbench.Payload{Bytes: payload[:payloadBytes]}}...)
payload = payload[payloadBytes:]
}
@@ -73,7 +72,7 @@ func TestReduceRecvBuf(t *testing.T) {
// Second read should return EAGAIN as the last segment should have been
// dropped due to it exceeding the receive buffer space available in the
// socket.
- if ret, got, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(len(sampleData)), syscall.MSG_DONTWAIT); got != nil || ret != -1 || err != syscall.EAGAIN {
+ if ret, got, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(len(sampleData)), unix.MSG_DONTWAIT); got != nil || ret != -1 || err != unix.EAGAIN {
t.Fatalf("expected no packets but got: %s", got)
}
}
diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go
index ba79fbf55..3dc8f63ab 100644
--- a/test/packetimpact/tests/tcp_retransmits_test.go
+++ b/test/packetimpact/tests/tcp_retransmits_test.go
@@ -15,6 +15,7 @@
package tcp_retransmits_test
import (
+ "bytes"
"flag"
"testing"
"time"
@@ -59,38 +60,43 @@ func TestRetransmits(t *testing.T) {
sampleData := []byte("Sample Data")
samplePayload := &testbench.Payload{Bytes: sampleData}
+ // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK.
+ // This is to reduce the test run-time from the default initial RTO of 1s.
+ // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
+ // we can skip this data send/recv which is solely to estimate RTO.
dut.Send(t, acceptFd, sampleData, 0)
if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK.
- // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
- // we can skip sending this ACK.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+ // Wait for the DUT to receive the data, thus ensuring that the stack has
+ // estimated RTO before we query RTO via TCP_INFO.
+ if got := dut.Recv(t, acceptFd, int32(len(sampleData)), 0); !bytes.Equal(got, sampleData) {
+ t.Fatalf("got dut.Recv(t, %d, %d, 0) = %s, want %s", acceptFd, len(sampleData), got, sampleData)
+ }
const timeoutCorrection = time.Second
- const diffCorrection = time.Millisecond
+ const diffCorrection = 200 * time.Millisecond
rto := getRTO(t, dut, acceptFd)
- timeout := rto + timeoutCorrection
- startTime := time.Now()
dut.Send(t, acceptFd, sampleData, 0)
seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))
- if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, timeout); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, rto+timeoutCorrection); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
// Expect retransmits of the same segment.
for i := 0; i < 5; i++ {
- if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, timeout); err != nil {
- t.Fatalf("expected payload was not received within %d loop %d err %s", timeout, i, err)
+ startTime := time.Now()
+ rto = getRTO(t, dut, acceptFd)
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, rto+timeoutCorrection); err != nil {
+ t.Fatalf("expected payload was not received within %s loop %d err %s", rto+timeoutCorrection, i, err)
}
-
if diff := time.Since(startTime); diff+diffCorrection < rto {
- t.Fatalf("retransmit came sooner got: %d want: >= %d probe %d", diff, rto, i)
+ t.Fatalf("retransmit came sooner got: %s want: >= %s probe %d", diff+diffCorrection, rto, i)
}
- startTime = time.Now()
- rto = getRTO(t, dut, acceptFd)
- timeout = rto + timeoutCorrection
}
}
diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
index 418393796..64b7288fb 100644
--- a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
+++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
@@ -71,7 +71,7 @@ func TestSendWindowSizesPiggyback(t *testing.T) {
dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
- expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}
+ expectedTCP := testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}
dut.Send(t, acceptFd, sampleData, 0)
expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1}
@@ -90,7 +90,7 @@ func TestSendWindowSizesPiggyback(t *testing.T) {
// Send ACK for the previous segment along with data for the dut to
// receive and ACK back. Sending this ACK would make room for the dut
// to transmit any enqueued segment.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData})
// Expect the dut to piggyback the ACK for received data along with
// the segment enqueued for transmit.
diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
index 32271d7b2..3346d43c4 100644
--- a/test/packetimpact/tests/tcp_synrcvd_reset_test.go
+++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
@@ -37,11 +37,11 @@ func TestTCPSynRcvdReset(t *testing.T) {
defer conn.Close(t)
// Expect dut connection to have transitioned to SYN-RCVD state.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)})
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN-ACK %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)})
// Expect the connection to have transitioned SYN-RCVD to CLOSED.
//
@@ -49,8 +49,8 @@ func TestTCPSynRcvdReset(t *testing.T) {
// CLOSED. We cannot use TCP_INFO to lookup the state as this is a passive
// DUT connection.
for i := 0; i < 5; i++ {
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Logf("retransmit%d ACK as we did not get the expected RST, %s", i, err)
continue
}
diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go
index 2c8bb101b..cccb0abc6 100644
--- a/test/packetimpact/tests/tcp_synsent_reset_test.go
+++ b/test/packetimpact/tests/tcp_synsent_reset_test.go
@@ -42,7 +42,7 @@ func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, uint16,
copy(sa.Addr[:], dut.Net.LocalIPv4)
// Bring the dut to SYN-SENT state with a non-blocking connect.
dut.Connect(t, clientFD, &sa)
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN\n")
}
@@ -53,11 +53,11 @@ func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, uint16,
func TestTCPSynSentReset(t *testing.T) {
_, conn, _, _ := dutSynSentState(t)
defer conn.Close(t)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)})
// Expect the connection to have closed.
// TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST")
}
}
@@ -73,15 +73,15 @@ func TestTCPSynSentRcvdReset(t *testing.T) {
// Initiate new SYN connection with the same port pair
// (simultaneous open case), expect the dut connection to move to
// SYN-RCVD state
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)})
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN-ACK %s\n", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)})
// Expect the connection to have transitioned SYN-RCVD to CLOSED.
// TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST")
}
}
diff --git a/test/packetimpact/tests/tcp_timewait_reset_test.go b/test/packetimpact/tests/tcp_timewait_reset_test.go
index d1d2fb83d..89037f0a4 100644
--- a/test/packetimpact/tests/tcp_timewait_reset_test.go
+++ b/test/packetimpact/tests/tcp_timewait_reset_test.go
@@ -42,26 +42,26 @@ func TestTimeWaitReset(t *testing.T) {
// Trigger active close.
dut.Close(t, acceptFD)
- _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second)
+ _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected a FIN: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
// Send a FIN, DUT should transition to TIME_WAIT from FIN_WAIT2.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)})
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected an ACK for our FIN: %s", err)
}
// Send a RST, the DUT should transition to CLOSED from TIME_WAIT.
// This is the default Linux behavior, it can be changed to ignore RSTs via
// sysctl net.ipv4.tcp_rfc1337.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)})
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
// The DUT should reply with RST to our ACK as the state should have
// transitioned to CLOSED.
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, time.Second); err != nil {
t.Fatalf("expected a RST: %s", err)
}
}
diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go
new file mode 100644
index 000000000..a208210ac
--- /dev/null
+++ b/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go
@@ -0,0 +1,94 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_unacc_seq_ack_closing_test
+
+import (
+ "flag"
+ "fmt"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+func TestSimultaneousCloseUnaccSeqAck(t *testing.T) {
+ for _, tt := range []struct {
+ description string
+ makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP
+ seqNumOffset seqnum.Size
+ expectAck bool
+ }{
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, expectAck: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, expectAck: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, expectAck: true},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, expectAck: false},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, expectAck: true},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, expectAck: true},
+ } {
+ t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
+ dut := testbench.NewDUT(t)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
+ defer dut.Close(t, listenFD)
+ conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer conn.Close(t)
+
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+
+ // Trigger active close.
+ dut.Shutdown(t, acceptFD, unix.SHUT_WR)
+
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second)
+ if err != nil {
+ t.Fatalf("expected a FIN: %s", err)
+ }
+ // Do not ack the FIN from DUT so that we get to CLOSING.
+ seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1)
+ conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)})
+
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
+ t.Errorf("expected an ACK to our FIN, but got none: %s", err)
+ }
+
+ sampleData := []byte("Sample Data")
+ samplePayload := &testbench.Payload{Bytes: sampleData}
+
+ origSeq := uint32(*conn.LocalSeqNum(t))
+ // Send a segment with OTW Seq / unacc ACK.
+ tcp := tt.makeTestingTCP(t, &conn, tt.seqNumOffset, seqnum.Size(*gotTCP.WindowSize))
+ if tt.description == "OTWSeq" {
+ // If we generate an OTW Seq segment, make sure we don't acknowledge their FIN so that
+ // we stay in CLOSING.
+ tcp.AckNum = seqNumForTheirFIN
+ }
+ conn.Send(t, tcp, samplePayload)
+
+ got, err := conn.Expect(t, testbench.TCP{AckNum: testbench.Uint32(origSeq), Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
+ if tt.expectAck && err != nil {
+ t.Errorf("expected an ack in CLOSING state, but got none: %s", err)
+ }
+ if !tt.expectAck && got != nil {
+ t.Errorf("expected no ack in CLOSING state, but got one: %s", got)
+ }
+ })
+ }
+}
diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
index ea962c818..ce0a26171 100644
--- a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
+++ b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go
@@ -17,7 +17,6 @@ package tcp_unacc_seq_ack_test
import (
"flag"
"fmt"
- "syscall"
"testing"
"time"
@@ -39,12 +38,12 @@ func TestEstablishedUnaccSeqAck(t *testing.T) {
expectAck bool
restoreSeq bool
}{
- {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: true, restoreSeq: true},
- {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true, restoreSeq: true},
- {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true, restoreSeq: true},
- {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: true, restoreSeq: false},
- {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: false, restoreSeq: true},
- {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: false, restoreSeq: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, expectAck: true, restoreSeq: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, expectAck: true, restoreSeq: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, expectAck: true, restoreSeq: true},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, expectAck: true, restoreSeq: false},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, expectAck: false, restoreSeq: true},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, expectAck: false, restoreSeq: true},
} {
t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
dut := testbench.NewDUT(t)
@@ -59,8 +58,8 @@ func TestEstablishedUnaccSeqAck(t *testing.T) {
sampleData := []byte("Sample Data")
samplePayload := &testbench.Payload{Bytes: sampleData}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected ack %s", err)
}
@@ -74,7 +73,7 @@ func TestEstablishedUnaccSeqAck(t *testing.T) {
// ACK matches the TCP layer state.
*conn.LocalSeqNum(t) = origSeq
}
- gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
if tt.expectAck && err != nil {
t.Fatalf("expected an ack but got none: %s", err)
}
@@ -92,12 +91,12 @@ func TestPassiveCloseUnaccSeqAck(t *testing.T) {
seqNumOffset seqnum.Size
expectAck bool
}{
- {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: false},
- {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true},
- {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true},
- {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: false},
- {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: true},
- {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, expectAck: false},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, expectAck: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, expectAck: true},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, expectAck: false},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, expectAck: true},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, expectAck: true},
} {
t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
dut := testbench.NewDUT(t)
@@ -110,8 +109,8 @@ func TestPassiveCloseUnaccSeqAck(t *testing.T) {
acceptFD, _ := dut.Accept(t, listenFD)
// Send a FIN to DUT to intiate the passive close.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
- gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagFin)})
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err)
}
@@ -122,7 +121,7 @@ func TestPassiveCloseUnaccSeqAck(t *testing.T) {
// Send a segment with OTW Seq / unacc ACK.
conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), samplePayload)
- gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
if tt.expectAck && err != nil {
t.Errorf("expected an ack but got none: %s", err)
}
@@ -132,14 +131,14 @@ func TestPassiveCloseUnaccSeqAck(t *testing.T) {
// Now let's verify DUT is indeed in CLOSE_WAIT
dut.Close(t, acceptFD)
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
t.Fatalf("expected DUT to send a FIN: %s", err)
}
// Ack the FIN from DUT
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
// Send some extra data to DUT
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, samplePayload)
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, samplePayload)
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, time.Second); err != nil {
t.Fatalf("expected DUT to send an RST: %s", err)
}
})
@@ -153,12 +152,12 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) {
seqNumOffset seqnum.Size
restoreSeq bool
}{
- {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, restoreSeq: true},
- {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, restoreSeq: true},
- {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, restoreSeq: true},
- {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, restoreSeq: false},
- {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, restoreSeq: true},
- {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, restoreSeq: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, restoreSeq: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, restoreSeq: true},
+ {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, restoreSeq: true},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, restoreSeq: false},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, restoreSeq: true},
+ {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, restoreSeq: true},
} {
t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
dut := testbench.NewDUT(t)
@@ -171,14 +170,14 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) {
acceptFD, _ := dut.Accept(t, listenFD)
// Trigger active close.
- dut.Shutdown(t, acceptFD, syscall.SHUT_WR)
+ dut.Shutdown(t, acceptFD, unix.SHUT_WR)
// Get to FIN_WAIT2
- gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected a FIN: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
sendUnaccSeqAck := func(state string) {
t.Helper()
@@ -193,7 +192,7 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) {
// incoming ACK matches the TCP layer state.
*conn.LocalSeqNum(t) = origSeq
}
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
t.Errorf("expected an ack in %s state, but got none: %s", state, err)
}
}
@@ -201,8 +200,8 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) {
sendUnaccSeqAck("FIN_WAIT2")
// Send a FIN to DUT to get to TIME_WAIT
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)})
- if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected an ACK for our fin and DUT should enter TIME_WAIT: %s", err)
}
@@ -210,22 +209,3 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) {
})
}
}
-
-// generateOTWSeqSegment generates an segment with
-// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only
-// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the
-// receiver.
-func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
- lastAcceptable := conn.LocalSeqNum(t).Add(windowSize)
- otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
- return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)}
-}
-
-// generateUnaccACKSegment generates an segment with
-// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable
-// when seqNumOffset is 0, otherwise an ACK is expected from the receiver.
-func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
- lastAcceptable := conn.RemoteSeqNum(t)
- unaccAck := uint32(lastAcceptable.Add(seqNumOffset))
- return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)}
-}
diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go
index b16e65366..ef38bd738 100644
--- a/test/packetimpact/tests/tcp_user_timeout_test.go
+++ b/test/packetimpact/tests/tcp_user_timeout_test.go
@@ -35,7 +35,7 @@ func sendPayload(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd i
}
conn.Drain(t)
dut.Send(t, fd, sampleData, 0)
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil {
t.Fatalf("expected data but got none: %w", err)
}
}
@@ -79,14 +79,14 @@ func TestTCPUserTimeout(t *testing.T) {
time.Sleep(tt.sendDelay)
conn.Drain(t)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
// If TCP_USER_TIMEOUT was set and the above delay was longer than the
// TCP_USER_TIMEOUT then the DUT should send a RST in response to the
// testbench's packet.
expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout
expectTimeout := 5 * time.Second
- got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout)
+ got, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, expectTimeout)
if expectRST && err != nil {
t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err)
}
diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go
index 093484721..0d65a2ea2 100644
--- a/test/packetimpact/tests/tcp_window_shrink_test.go
+++ b/test/packetimpact/tests/tcp_window_shrink_test.go
@@ -48,7 +48,7 @@ func TestWindowShrink(t *testing.T) {
if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
dut.Send(t, acceptFd, sampleData, 0)
dut.Send(t, acceptFd, sampleData, 0)
@@ -59,7 +59,7 @@ func TestWindowShrink(t *testing.T) {
t.Fatalf("expected payload was not received: %s", err)
}
// We close our receiving window here
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
dut.Send(t, acceptFd, []byte("Sample Data"), 0)
// Note: There is another kind of zero-window probing which Windows uses (by sending one
diff --git a/test/packetimpact/tests/tcp_zero_receive_window_test.go b/test/packetimpact/tests/tcp_zero_receive_window_test.go
index d06690705..d73495454 100644
--- a/test/packetimpact/tests/tcp_zero_receive_window_test.go
+++ b/test/packetimpact/tests/tcp_zero_receive_window_test.go
@@ -49,8 +49,8 @@ func TestZeroReceiveWindow(t *testing.T) {
// Expect the DUT to eventually advertise zero receive window.
// The test would timeout otherwise.
for readOnce := false; ; {
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected packet was not received: %s", err)
}
@@ -100,8 +100,8 @@ func TestNonZeroReceiveWindow(t *testing.T) {
// we sent. Once we have received ACKs with non-zero receive windows, we break
// the loop.
for {
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected packet was not received: %s", err)
}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
index d094c10eb..22b17a39e 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
@@ -15,6 +15,7 @@
package tcp_zero_window_probe_retransmit_test
import (
+ "bytes"
"flag"
"testing"
"time"
@@ -51,18 +52,23 @@ func TestZeroWindowProbeRetransmit(t *testing.T) {
if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
- t.Fatalf("expected packet was not received: %s", err)
- }
// Check for the dut to keep the connection alive as long as the zero window
// probes are acknowledged. Check if the zero window probes are sent at
// exponentially increasing intervals. The timeout intervals are function
// of the recorded first zero probe transmission duration.
//
- // Advertize zero receive window again.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ // Advertize zero receive window along with a payload.
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(0)}, samplePayload)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected packet was not received: %s", err)
+ }
+ // Wait for the payload to be received by the DUT, which is also an
+ // indication of receive of the peer window advertisement.
+ if got := dut.Recv(t, acceptFd, int32(len(sampleData)), 0); !bytes.Equal(got, sampleData) {
+ t.Fatalf("got dut.Recv(t, %d, %d, 0) = %s, want %s", acceptFd, len(sampleData), got, sampleData)
+ }
+
probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1))
ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))
@@ -98,10 +104,10 @@ func TestZeroWindowProbeRetransmit(t *testing.T) {
}
prev = got
// Acknowledge the zero-window probes from the dut.
- conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
}
// Advertize non-zero window.
- conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.TCPFlags(header.TCPFlagAck)})
// Expect the dut to recover and transmit data.
if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go
index 650a569cc..8b90fcbe9 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go
@@ -53,8 +53,8 @@ func TestZeroWindowProbe(t *testing.T) {
t.Fatalf("expected payload was not received: %s", err)
}
sendTime := time.Now().Sub(start)
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected packet was not received: %s", err)
}
@@ -62,7 +62,7 @@ func TestZeroWindowProbe(t *testing.T) {
// probe to be sent.
//
// Advertize zero window to the dut.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
// Expected sequence number of the zero window probe.
probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1))
@@ -93,7 +93,7 @@ func TestZeroWindowProbe(t *testing.T) {
// and sends out the sample payload after the send window opens.
//
// Advertize non-zero window to the dut and ack the zero window probe.
- conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.TCPFlags(header.TCPFlagAck)})
// Expect the dut to recover and transmit data.
if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
@@ -104,8 +104,8 @@ func TestZeroWindowProbe(t *testing.T) {
// Basically with sequence number to one byte behind the unacknowledged
// sequence number.
p := testbench.Uint32(uint32(*conn.LocalSeqNum(t)))
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))})
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with ack number: %d: %s", p, err)
}
}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
index 079fea68c..1ce4d22b7 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
@@ -51,8 +51,8 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) {
if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected packet was not received: %s", err)
}
@@ -60,7 +60,7 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) {
// probe to be sent.
//
// Advertize zero window to the dut.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
// Expected sequence number of the zero window probe.
probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1))
@@ -81,7 +81,7 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) {
// Reduce the retransmit timeout.
dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds()))
// Advertize zero window again.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
// Ask the dut to send out data that would trigger zero window probe retransmissions.
dut.Send(t, acceptFd, sampleData, 0)
@@ -90,8 +90,8 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) {
// Expect the connection to have timed out and closed which would cause the dut
// to reply with a RST to the ACK we send.
- conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST")
}
}
diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
index 52c6f9d91..f63cfcc9a 100644
--- a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
+++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
@@ -19,7 +19,6 @@ import (
"flag"
"fmt"
"net"
- "syscall"
"testing"
"golang.org/x/sys/unix"
@@ -56,7 +55,7 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) {
)
ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0)
- if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK {
+ if errno != unix.EAGAIN || errno != unix.EWOULDBLOCK {
t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno)
}
})
@@ -86,7 +85,7 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) {
&testbench.Payload{Bytes: []byte("test payload")},
)
ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0)
- if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK {
+ if errno != unix.EAGAIN || errno != unix.EWOULDBLOCK {
t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno)
}
})
diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
index 3fca8c7a3..3159d5b89 100644
--- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go
+++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
@@ -20,7 +20,6 @@ import (
"fmt"
"net"
"sync"
- "syscall"
"testing"
"time"
@@ -86,16 +85,16 @@ type testData struct {
remotePort uint16
cleanFD int32
cleanPort uint16
- wantErrno syscall.Errno
+ wantErrno unix.Errno
}
// wantErrno computes the errno to expect given the connection mode of a UDP
// socket and the ICMP error it will receive.
-func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno {
+func wantErrno(c connectionMode, icmpErr icmpError) unix.Errno {
if c && icmpErr == portUnreachable {
- return syscall.Errno(unix.ECONNREFUSED)
+ return unix.ECONNREFUSED
}
- return syscall.Errno(0)
+ return unix.Errno(0)
}
// sendICMPError sends an ICMP error message in response to a UDP datagram.
@@ -123,7 +122,7 @@ func sendICMPError(t *testing.T, conn *testbench.UDPIPv4, icmpErr icmpError, udp
conn.SendFrameStateless(t, layers)
}
-// testRecv tests observing the ICMP error through the recv syscall. A packet
+// testRecv tests observing the ICMP error through the recv unix. A packet
// is sent to the DUT, and if wantErrno is non-zero, then the first recv should
// fail and the second should succeed. Otherwise if wantErrno is zero then the
// first recv should succeed immediately.
@@ -136,7 +135,7 @@ func testRecv(ctx context.Context, t *testing.T, d testData) {
d.conn.Send(t, testbench.UDP{})
- if d.wantErrno != syscall.Errno(0) {
+ if d.wantErrno != unix.Errno(0) {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
ret, _, err := d.dut.RecvWithErrno(ctx, t, d.remoteFD, 100, 0)
@@ -162,7 +161,7 @@ func testSendTo(ctx context.Context, t *testing.T, d testData) {
t.Fatalf("did not receive UDP packet from clean socket on DUT: %s", err)
}
- if d.wantErrno != syscall.Errno(0) {
+ if d.wantErrno != unix.Errno(0) {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
ret, err := d.dut.SendToWithErrno(ctx, t, d.remoteFD, nil, 0, d.conn.LocalAddr(t))
@@ -183,11 +182,11 @@ func testSendTo(ctx context.Context, t *testing.T, d testData) {
func testSockOpt(_ context.Context, t *testing.T, d testData) {
// Check that there's no pending error on the clean socket.
- if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) {
+ if errno := unix.Errno(d.dut.GetSockOptInt(t, d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != unix.Errno(0) {
t.Fatalf("unexpected error (%[1]d) %[1]v on clean socket", errno)
}
- if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno {
+ if errno := unix.Errno(d.dut.GetSockOptInt(t, d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno {
t.Fatalf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno)
}
@@ -310,7 +309,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
go func() {
defer wg.Done()
- if wantErrno != syscall.Errno(0) {
+ if wantErrno != unix.Errno(0) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go
index 894d156cf..230b012c7 100644
--- a/test/packetimpact/tests/udp_send_recv_dgram_test.go
+++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go
@@ -19,7 +19,6 @@ import (
"flag"
"fmt"
"net"
- "syscall"
"testing"
"time"
@@ -241,7 +240,7 @@ func TestUDP(t *testing.T) {
},
)
ret, recvPayload, errno := dut.RecvWithErrno(context.Background(), t, socketFD, 100, 0)
- if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK {
+ if errno != unix.EAGAIN || errno != unix.EWOULDBLOCK {
t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, recvPayload, errno)
}
}
diff --git a/test/perf/BUILD b/test/perf/BUILD
index e25f090ae..ed899ac22 100644
--- a/test/perf/BUILD
+++ b/test/perf/BUILD
@@ -1,4 +1,3 @@
-load("//tools:defs.bzl", "more_shards")
load("//test/runner:defs.bzl", "syscall_test")
package(licenses = ["notice"])
@@ -38,7 +37,6 @@ syscall_test(
syscall_test(
size = "enormous",
debug = False,
- shard_count = more_shards,
tags = ["nogotsan"],
test = "//test/perf/linux:getdents_benchmark",
)
diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go
index 38e57d62f..2ad5f58ef 100644
--- a/test/runner/gtest/gtest.go
+++ b/test/runner/gtest/gtest.go
@@ -35,6 +35,39 @@ var (
filterBenchmarkFlag = "--benchmark_filter"
)
+// BuildTestArgs builds arguments to be passed to the test binary to execute
+// only the test cases in `indices`.
+func BuildTestArgs(indices []int, testCases []TestCase) []string {
+ var testFilter, benchFilter string
+ for _, tci := range indices {
+ tc := testCases[tci]
+ if tc.all {
+ // No argument will make all tests run.
+ return nil
+ }
+ if tc.benchmark {
+ if len(benchFilter) > 0 {
+ benchFilter += "|"
+ }
+ benchFilter += "^" + tc.Name + "$"
+ } else {
+ if len(testFilter) > 0 {
+ testFilter += ":"
+ }
+ testFilter += tc.FullName()
+ }
+ }
+
+ var args []string
+ if len(testFilter) > 0 {
+ args = append(args, fmt.Sprintf("%s=%s", filterTestFlag, testFilter))
+ }
+ if len(benchFilter) > 0 {
+ args = append(args, fmt.Sprintf("%s=%s", filterBenchmarkFlag, benchFilter))
+ }
+ return args
+}
+
// TestCase is a single gtest test case.
type TestCase struct {
// Suite is the suite for this test.
@@ -59,22 +92,6 @@ func (tc TestCase) FullName() string {
return fmt.Sprintf("%s.%s", tc.Suite, tc.Name)
}
-// Args returns arguments to be passed when invoking the test.
-func (tc TestCase) Args() []string {
- if tc.all {
- return []string{} // No arguments.
- }
- if tc.benchmark {
- return []string{
- fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name),
- fmt.Sprintf("%s=", filterTestFlag),
- }
- }
- return []string{
- fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()),
- }
-}
-
// ParseTestCases calls a gtest test binary to list its test and returns a
// slice with the name and suite of each test.
//
@@ -90,6 +107,7 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes
// We failed to list tests with the given flags. Just
// return something that will run the binary with no
// flags, which should execute all tests.
+ fmt.Printf("failed to get test list: %v\n", err)
return []TestCase{
{
Suite: "Default",
diff --git a/test/runner/runner.go b/test/runner/runner.go
index e72c59200..a8a134fe2 100644
--- a/test/runner/runner.go
+++ b/test/runner/runner.go
@@ -26,7 +26,6 @@ import (
"path/filepath"
"strings"
"syscall"
- "testing"
"time"
specs "github.com/opencontainers/runtime-spec/specs-go"
@@ -57,13 +56,82 @@ var (
leakCheck = flag.Bool("leak-check", false, "check for reference leaks")
)
+func main() {
+ flag.Parse()
+ if flag.NArg() != 1 {
+ fatalf("test must be provided")
+ }
+
+ log.SetLevel(log.Info)
+ if *debug {
+ log.SetLevel(log.Debug)
+ }
+
+ if *platform != "native" && *runscPath == "" {
+ if err := testutil.ConfigureExePath(); err != nil {
+ panic(err.Error())
+ }
+ *runscPath = specutils.ExePath
+ }
+
+ // Make sure stdout and stderr are opened with O_APPEND, otherwise logs
+ // from outside the sandbox can (and will) stomp on logs from inside
+ // the sandbox.
+ for _, f := range []*os.File{os.Stdout, os.Stderr} {
+ flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0)
+ if err != nil {
+ fatalf("error getting file flags for %v: %v", f, err)
+ }
+ if flags&unix.O_APPEND == 0 {
+ flags |= unix.O_APPEND
+ if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil {
+ fatalf("error setting file flags for %v: %v", f, err)
+ }
+ }
+ }
+
+ // Resolve the absolute path for the binary.
+ testBin, err := filepath.Abs(flag.Args()[0])
+ if err != nil {
+ fatalf("Abs(%q) failed: %v", flag.Args()[0], err)
+ }
+
+ // Get all test cases in each binary.
+ testCases, err := gtest.ParseTestCases(testBin, true)
+ if err != nil {
+ fatalf("ParseTestCases(%q) failed: %v", testBin, err)
+ }
+
+ // Get subset of tests corresponding to shard.
+ indices, err := testutil.TestIndicesForShard(len(testCases))
+ if err != nil {
+ fatalf("TestsForShard() failed: %v", err)
+ }
+ if len(indices) == 0 {
+ log.Warningf("No tests to run in this shard")
+ return
+ }
+ args := gtest.BuildTestArgs(indices, testCases)
+
+ switch *platform {
+ case "native":
+ if err := runTestCaseNative(testBin, args); err != nil {
+ fatalf(err.Error())
+ }
+ default:
+ if err := runTestCaseRunsc(testBin, args); err != nil {
+ fatalf(err.Error())
+ }
+ }
+}
+
// runTestCaseNative runs the test case directly on the host machine.
-func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
+func runTestCaseNative(testBin string, args []string) error {
// These tests might be running in parallel, so make sure they have a
// unique test temp dir.
tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "")
if err != nil {
- t.Fatalf("could not create temp dir: %v", err)
+ return fmt.Errorf("could not create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
@@ -84,12 +152,12 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
}
// Remove shard env variables so that the gunit binary does not try to
// interpret them.
- env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"})
+ env = filterEnv(env, "TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS")
if *addUDSTree {
socketDir, cleanup, err := uds.CreateSocketTree("/tmp")
if err != nil {
- t.Fatalf("failed to create socket tree: %v", err)
+ return fmt.Errorf("failed to create socket tree: %v", err)
}
defer cleanup()
@@ -99,24 +167,25 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir)
}
- cmd := exec.Command(testBin, tc.Args()...)
+ cmd := exec.Command(testBin, args...)
cmd.Env = env
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
- cmd.SysProcAttr = &syscall.SysProcAttr{}
+ cmd.SysProcAttr = &unix.SysProcAttr{}
if specutils.HasCapabilities(capability.CAP_SYS_ADMIN) {
- cmd.SysProcAttr.Cloneflags |= syscall.CLONE_NEWUTS
+ cmd.SysProcAttr.Cloneflags |= unix.CLONE_NEWUTS
}
if specutils.HasCapabilities(capability.CAP_NET_ADMIN) {
- cmd.SysProcAttr.Cloneflags |= syscall.CLONE_NEWNET
+ cmd.SysProcAttr.Cloneflags |= unix.CLONE_NEWNET
}
if err := cmd.Run(); err != nil {
ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus)
- t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus())
+ return fmt.Errorf("test exited with status %d, want 0", ws.ExitStatus())
}
+ return nil
}
// runRunsc runs spec in runsc in a standard test configuration.
@@ -124,7 +193,7 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) {
// runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR.
//
// Returns an error if the sandboxed application exits non-zero.
-func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
+func runRunsc(spec *specs.Spec) error {
bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
if err != nil {
return fmt.Errorf("SetupBundleDir failed: %v", err)
@@ -137,9 +206,8 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
}
defer cleanup()
- name := tc.FullName()
id := testutil.RandomContainerID()
- log.Infof("Running test %q in container %q", name, id)
+ log.Infof("Running test in container %q", id)
specutils.LogSpec(spec)
args := []string{
@@ -148,7 +216,7 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
"-log-format=text",
"-TESTONLY-unsafe-nonroot=true",
"-net-raw=true",
- fmt.Sprintf("-panic-signal=%d", syscall.SIGTERM),
+ fmt.Sprintf("-panic-signal=%d", unix.SIGTERM),
"-watchdog-action=panic",
"-platform", *platform,
"-file-access", *fileAccess,
@@ -175,13 +243,8 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
args = append(args, "-ref-leak-mode=log-names")
}
- testLogDir := ""
- if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
- // Create log directory dedicated for this test.
- testLogDir = filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1))
- if err := os.MkdirAll(testLogDir, 0755); err != nil {
- return fmt.Errorf("could not create test dir: %v", err)
- }
+ testLogDir := os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR")
+ if len(testLogDir) > 0 {
debugLogDir, err := ioutil.TempDir(testLogDir, "runsc")
if err != nil {
return fmt.Errorf("could not create temp dir: %v", err)
@@ -200,8 +263,8 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
// as root inside that namespace to get it.
rArgs := append(args, "run", "--bundle", bundleDir, id)
cmd := exec.Command(*runscPath, rArgs...)
- cmd.SysProcAttr = &syscall.SysProcAttr{
- Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS,
+ cmd.SysProcAttr = &unix.SysProcAttr{
+ Cloneflags: unix.CLONE_NEWUSER | unix.CLONE_NEWNS,
// Set current user/group as root inside the namespace.
UidMappings: []syscall.SysProcIDMap{
{ContainerID: 0, HostID: os.Getuid(), Size: 1},
@@ -219,14 +282,14 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
cmd.Stderr = os.Stderr
sig := make(chan os.Signal, 1)
defer close(sig)
- signal.Notify(sig, syscall.SIGTERM)
+ signal.Notify(sig, unix.SIGTERM)
defer signal.Stop(sig)
go func() {
s, ok := <-sig
if !ok {
return
}
- log.Warningf("%s: Got signal: %v", name, s)
+ log.Warningf("Got signal: %v", s)
done := make(chan bool, 1)
dArgs := append([]string{}, args...)
dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id)
@@ -247,7 +310,7 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
log.Warningf("Send SIGTERM to the sandbox process")
dArgs = append(args, "debug",
- fmt.Sprintf("--signal=%d", syscall.SIGTERM),
+ fmt.Sprintf("--signal=%d", unix.SIGTERM),
id)
signal := exec.Command(*runscPath, dArgs...)
signal.Stdout = os.Stdout
@@ -259,7 +322,7 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error {
if err == nil && len(testLogDir) > 0 {
// If the test passed, then we erase the log directory. This speeds up
// uploading logs in continuous integration & saves on disk space.
- os.RemoveAll(testLogDir)
+ _ = os.RemoveAll(testLogDir)
}
return err
@@ -314,10 +377,10 @@ func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) {
}
// runsTestCaseRunsc runs the test case in runsc.
-func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
+func runTestCaseRunsc(testBin string, args []string) error {
// Run a new container with the test executable and filter for the
// given test suite and name.
- spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...)
+ spec := testutil.NewSpecWithArgs(append([]string{testBin}, args...)...)
// Mark the root as writeable, as some tests attempt to
// write to the rootfs, and expect EACCES, not EROFS.
@@ -343,12 +406,12 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// users, so make sure it is world-accessible.
tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "")
if err != nil {
- t.Fatalf("could not create temp dir: %v", err)
+ return fmt.Errorf("could not create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
if err := os.Chmod(tmpDir, 0777); err != nil {
- t.Fatalf("could not chmod temp dir: %v", err)
+ return fmt.Errorf("could not chmod temp dir: %v", err)
}
// "/tmp" is not replaced with a tmpfs mount inside the sandbox
@@ -368,13 +431,12 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Set environment variables that indicate we are running in gVisor with
// the given platform, network, and filesystem stack.
- platformVar := "TEST_ON_GVISOR"
- networkVar := "GVISOR_NETWORK"
- env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network)
- vfsVar := "GVISOR_VFS"
+ env := []string{"TEST_ON_GVISOR=" + *platform, "GVISOR_NETWORK=" + *network}
+ env = append(env, os.Environ()...)
+ const vfsVar = "GVISOR_VFS"
if *vfs2 {
env = append(env, vfsVar+"=VFS2")
- fuseVar := "FUSE_ENABLED"
+ const fuseVar = "FUSE_ENABLED"
if *fuse {
env = append(env, fuseVar+"=TRUE")
} else {
@@ -386,11 +448,11 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
// Remove shard env variables so that the gunit binary does not try to
// interpret them.
- env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"})
+ env = filterEnv(env, "TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS")
// Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to
// be backed by tmpfs.
- env = filterEnv(env, []string{"TEST_TMPDIR"})
+ env = filterEnv(env, "TEST_TMPDIR")
env = append(env, fmt.Sprintf("TEST_TMPDIR=%s", testTmpDir))
spec.Process.Env = env
@@ -398,18 +460,19 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) {
if *addUDSTree {
cleanup, err := setupUDSTree(spec)
if err != nil {
- t.Fatalf("error creating UDS tree: %v", err)
+ return fmt.Errorf("error creating UDS tree: %v", err)
}
defer cleanup()
}
- if err := runRunsc(tc, spec); err != nil {
- t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err)
+ if err := runRunsc(spec); err != nil {
+ return fmt.Errorf("test failed with error %v, want nil", err)
}
+ return nil
}
// filterEnv returns an environment with the excluded variables removed.
-func filterEnv(env, exclude []string) []string {
+func filterEnv(env []string, exclude ...string) []string {
var out []string
for _, kv := range env {
ok := true
@@ -430,82 +493,3 @@ func fatalf(s string, args ...interface{}) {
fmt.Fprintf(os.Stderr, s+"\n", args...)
os.Exit(1)
}
-
-func matchString(a, b string) (bool, error) {
- return a == b, nil
-}
-
-func main() {
- flag.Parse()
- if flag.NArg() != 1 {
- fatalf("test must be provided")
- }
- testBin := flag.Args()[0] // Only argument.
-
- log.SetLevel(log.Info)
- if *debug {
- log.SetLevel(log.Debug)
- }
-
- if *platform != "native" && *runscPath == "" {
- if err := testutil.ConfigureExePath(); err != nil {
- panic(err.Error())
- }
- *runscPath = specutils.ExePath
- }
-
- // Make sure stdout and stderr are opened with O_APPEND, otherwise logs
- // from outside the sandbox can (and will) stomp on logs from inside
- // the sandbox.
- for _, f := range []*os.File{os.Stdout, os.Stderr} {
- flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0)
- if err != nil {
- fatalf("error getting file flags for %v: %v", f, err)
- }
- if flags&unix.O_APPEND == 0 {
- flags |= unix.O_APPEND
- if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil {
- fatalf("error setting file flags for %v: %v", f, err)
- }
- }
- }
-
- // Get all test cases in each binary.
- testCases, err := gtest.ParseTestCases(testBin, true)
- if err != nil {
- fatalf("ParseTestCases(%q) failed: %v", testBin, err)
- }
-
- // Get subset of tests corresponding to shard.
- indices, err := testutil.TestIndicesForShard(len(testCases))
- if err != nil {
- fatalf("TestsForShard() failed: %v", err)
- }
-
- // Resolve the absolute path for the binary.
- testBin, err = filepath.Abs(testBin)
- if err != nil {
- fatalf("Abs() failed: %v", err)
- }
-
- // Run the tests.
- var tests []testing.InternalTest
- for _, tci := range indices {
- // Capture tc.
- tc := testCases[tci]
- tests = append(tests, testing.InternalTest{
- Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name),
- F: func(t *testing.T) {
- if *platform == "native" {
- // Run the test case on host.
- runTestCaseNative(testBin, tc, t)
- } else {
- // Run the test case in runsc.
- runTestCaseRunsc(testBin, tc, t)
- }
- },
- })
- }
-
- testing.Main(matchString, tests, nil, nil)
-}
diff --git a/test/runtimes/proctor/BUILD b/test/runtimes/proctor/BUILD
index fdc6d3173..b4a9b12de 100644
--- a/test/runtimes/proctor/BUILD
+++ b/test/runtimes/proctor/BUILD
@@ -7,5 +7,8 @@ go_binary(
srcs = ["main.go"],
pure = True,
visibility = ["//test/runtimes:__pkg__"],
- deps = ["//test/runtimes/proctor/lib"],
+ deps = [
+ "//test/runtimes/proctor/lib",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
)
diff --git a/test/runtimes/proctor/lib/BUILD b/test/runtimes/proctor/lib/BUILD
index 0c8367dfe..f834f1b5a 100644
--- a/test/runtimes/proctor/lib/BUILD
+++ b/test/runtimes/proctor/lib/BUILD
@@ -13,6 +13,7 @@ go_library(
"python.go",
],
visibility = ["//test/runtimes/proctor:__pkg__"],
+ deps = ["@org_golang_x_sys//unix:go_default_library"],
)
go_test(
diff --git a/test/runtimes/proctor/lib/lib.go b/test/runtimes/proctor/lib/lib.go
index f2ba82498..36c60088a 100644
--- a/test/runtimes/proctor/lib/lib.go
+++ b/test/runtimes/proctor/lib/lib.go
@@ -22,7 +22,8 @@ import (
"os/signal"
"path/filepath"
"regexp"
- "syscall"
+
+ "golang.org/x/sys/unix"
)
// TestRunner is an interface that must be implemented for each runtime
@@ -59,7 +60,7 @@ func TestRunnerForRuntime(runtime string) (TestRunner, error) {
func PauseAndReap() {
// Get notified of any new children.
ch := make(chan os.Signal, 1)
- signal.Notify(ch, syscall.SIGCHLD)
+ signal.Notify(ch, unix.SIGCHLD)
for {
if _, ok := <-ch; !ok {
@@ -69,7 +70,7 @@ func PauseAndReap() {
// Reap the child.
for {
- if cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil); cpid < 1 {
+ if cpid, _ := unix.Wait4(-1, nil, unix.WNOHANG, nil); cpid < 1 {
break
}
}
diff --git a/test/runtimes/proctor/main.go b/test/runtimes/proctor/main.go
index 81cb68381..8c076a499 100644
--- a/test/runtimes/proctor/main.go
+++ b/test/runtimes/proctor/main.go
@@ -22,8 +22,8 @@ import (
"log"
"os"
"strings"
- "syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/test/runtimes/proctor/lib"
)
@@ -42,14 +42,14 @@ func setNumFilesLimit() error {
// timeout if the NOFILE limit is too high. On gVisor, syscalls are
// slower so these tests will need even more time to pass.
const nofile = 32768
- rLimit := syscall.Rlimit{}
- err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit)
+ rLimit := unix.Rlimit{}
+ err := unix.Getrlimit(unix.RLIMIT_NOFILE, &rLimit)
if err != nil {
return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err)
}
if rLimit.Cur > nofile {
rLimit.Cur = nofile
- err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit)
+ err := unix.Setrlimit(unix.RLIMIT_NOFILE, &rLimit)
if err != nil {
return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err)
}
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 9adb1cea3..ef299799e 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -65,14 +65,8 @@ syscall_test(
syscall_test(
size = "large",
- # Produce too many logs in the debug mode.
- debug = False,
shard_count = most_shards,
- # Takes too long for TSAN. Since this is kind of a stress test that doesn't
- # involve much concurrency, TSAN's usefulness here is limited anyway.
- tags = ["nogotsan"],
test = "//test/syscalls/linux:socket_stress_test",
- vfs2 = False,
)
syscall_test(
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 5371f825c..5399d8106 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -2330,13 +2330,15 @@ cc_binary(
],
linkstatic = 1,
deps = [
+ gtest,
":ip_socket_test_util",
":socket_test_util",
- "@com_google_absl//absl/strings",
- gtest,
+ "//test/util:file_descriptor",
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
],
)
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index e508ce27f..61a421788 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -2162,7 +2162,13 @@ class BlockingChild {
return tid_;
}
- void Join() { Stop(); }
+ void Join() {
+ {
+ absl::MutexLock ml(&mu_);
+ stop_ = true;
+ }
+ thread_.Join();
+ }
private:
void Start() {
@@ -2172,11 +2178,6 @@ class BlockingChild {
mu_.Await(absl::Condition(&stop_));
}
- void Stop() {
- absl::MutexLock ml(&mu_);
- stop_ = true;
- }
-
mutable absl::Mutex mu_;
bool stop_ ABSL_GUARDED_BY(mu_) = false;
pid_t tid_;
@@ -2190,16 +2191,18 @@ class BlockingChild {
TEST(ProcTask, NewThreadAppears) {
auto initial = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/task", false));
BlockingChild child1;
- EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task",
- TaskFiles(initial, {child1.Tid()})));
+ // Use Eventually* in case a proc from ealier test is still tearing down.
+ EXPECT_NO_ERRNO(EventuallyDirContainsExactly(
+ "/proc/self/task", TaskFiles(initial, {child1.Tid()})));
}
TEST(ProcTask, KilledThreadsDisappear) {
auto initial = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/task/", false));
BlockingChild child1;
- EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task",
- TaskFiles(initial, {child1.Tid()})));
+ // Use Eventually* in case a proc from ealier test is still tearing down.
+ EXPECT_NO_ERRNO(EventuallyDirContainsExactly(
+ "/proc/self/task", TaskFiles(initial, {child1.Tid()})));
// Stat child1's task file. Regression test for b/32097707.
struct stat statbuf;
diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc
index 73140b2e9..20f1dc305 100644
--- a/test/syscalls/linux/proc_net.cc
+++ b/test/syscalls/linux/proc_net.cc
@@ -40,6 +40,7 @@ namespace {
constexpr const char kProcNet[] = "/proc/net";
constexpr const char kIpForward[] = "/proc/sys/net/ipv4/ip_forward";
+constexpr const char kRangeFile[] = "/proc/sys/net/ipv4/ip_local_port_range";
TEST(ProcNetSymlinkTarget, FileMode) {
struct stat s;
@@ -562,6 +563,42 @@ TEST(ProcSysNetIpv4IpForward, CanReadAndWrite) {
EXPECT_EQ(buf, to_write);
}
+TEST(ProcSysNetPortRange, CanReadAndWrite) {
+ int min;
+ int max;
+ std::string rangefile = ASSERT_NO_ERRNO_AND_VALUE(GetContents(kRangeFile));
+ ASSERT_EQ(rangefile.back(), '\n');
+ rangefile.pop_back();
+ std::vector<std::string> range =
+ absl::StrSplit(rangefile, absl::ByAnyChar("\t "));
+ ASSERT_GT(range.size(), 1);
+ ASSERT_TRUE(absl::SimpleAtoi(range.front(), &min));
+ ASSERT_TRUE(absl::SimpleAtoi(range.back(), &max));
+ EXPECT_LE(min, max);
+
+ // If the file isn't writable, there's nothing else to do here.
+ if (access(kRangeFile, W_OK)) {
+ return;
+ }
+
+ constexpr int kSize = 77;
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(kRangeFile, O_WRONLY | O_TRUNC, 0));
+ max = min + kSize;
+ const std::string small_range = absl::StrFormat("%d %d", min, max);
+ ASSERT_THAT(write(fd.get(), small_range.c_str(), small_range.size()),
+ SyscallSucceedsWithValue(small_range.size()));
+
+ rangefile = ASSERT_NO_ERRNO_AND_VALUE(GetContents(kRangeFile));
+ ASSERT_EQ(rangefile.back(), '\n');
+ rangefile.pop_back();
+ range = absl::StrSplit(rangefile, absl::ByAnyChar("\t "));
+ ASSERT_GT(range.size(), 1);
+ ASSERT_TRUE(absl::SimpleAtoi(range.front(), &min));
+ ASSERT_TRUE(absl::SimpleAtoi(range.back(), &max));
+ EXPECT_EQ(min + kSize, max);
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index 294b9f6fd..8d15c491e 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -1255,8 +1255,11 @@ TEST_F(PtyTest, PartialBadBuffer) {
// Read from the replica into bad_buffer.
ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), size));
- EXPECT_THAT(ReadFd(replica_.get(), bad_buffer, size),
- SyscallFailsWithErrno(EFAULT));
+ // Before Linux 3b830a9c this returned EFAULT, but after that commit it
+ // returns EAGAIN.
+ EXPECT_THAT(
+ ReadFd(replica_.get(), bad_buffer, size),
+ AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EAGAIN)));
EXPECT_THAT(munmap(addr, 2 * kPageSize), SyscallSucceeds()) << addr;
}
diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc
index 679586530..c35aa2183 100644
--- a/test/syscalls/linux/socket_generic_stress.cc
+++ b/test/syscalls/linux/socket_generic_stress.cc
@@ -17,29 +17,72 @@
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/un.h>
+#include <unistd.h>
#include <array>
#include <string>
#include "gtest/gtest.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/syscalls/linux/ip_socket_test_util.h"
#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
namespace gvisor {
namespace testing {
+constexpr char kRangeFile[] = "/proc/sys/net/ipv4/ip_local_port_range";
+
+PosixErrorOr<int> NumPorts() {
+ int min = 0;
+ int max = 1 << 16;
+
+ // Read the ephemeral range from /proc.
+ ASSIGN_OR_RETURN_ERRNO(std::string rangefile, GetContents(kRangeFile));
+ const std::string err_msg =
+ absl::StrFormat("%s has invalid content: %s", kRangeFile, rangefile);
+ if (rangefile.back() != '\n') {
+ return PosixError(EINVAL, err_msg);
+ }
+ rangefile.pop_back();
+ std::vector<std::string> range =
+ absl::StrSplit(rangefile, absl::ByAnyChar("\t "));
+ if (range.size() < 2 || !absl::SimpleAtoi(range.front(), &min) ||
+ !absl::SimpleAtoi(range.back(), &max)) {
+ return PosixError(EINVAL, err_msg);
+ }
+
+ // If we can open as writable, limit the range.
+ if (!access(kRangeFile, W_OK)) {
+ ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd,
+ Open(kRangeFile, O_WRONLY | O_TRUNC, 0));
+ max = min + 50;
+ const std::string small_range = absl::StrFormat("%d %d", min, max);
+ int n = write(fd.get(), small_range.c_str(), small_range.size());
+ if (n < 0) {
+ return PosixError(
+ errno,
+ absl::StrFormat("write(%d [%s], \"%s\", %d)", fd.get(), kRangeFile,
+ small_range.c_str(), small_range.size()));
+ }
+ }
+ return max - min;
+}
+
// Test fixture for tests that apply to pairs of connected sockets.
using ConnectStressTest = SocketPairTest;
-TEST_P(ConnectStressTest, Reset65kTimes) {
- // TODO(b/165912341): These are too slow on KVM platform with nested virt.
- SKIP_IF(GvisorPlatform() == Platform::kKVM);
-
- for (int i = 0; i < 1 << 16; ++i) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(ConnectStressTest, Reset) {
+ const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts());
+ for (int i = 0; i < nports * 2; i++) {
+ const std::unique_ptr<SocketPair> sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
// Send some data to ensure that the connection gets reset and the port gets
// released immediately. This avoids either end entering TIME-WAIT.
@@ -57,6 +100,24 @@ TEST_P(ConnectStressTest, Reset65kTimes) {
}
}
+// Tests that opening too many connections -- without closing them -- does lead
+// to port exhaustion.
+TEST_P(ConnectStressTest, TooManyOpen) {
+ const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts());
+ int err_num = 0;
+ std::vector<std::unique_ptr<SocketPair>> sockets =
+ std::vector<std::unique_ptr<SocketPair>>(nports);
+ for (int i = 0; i < nports * 2; i++) {
+ PosixErrorOr<std::unique_ptr<SocketPair>> socks = NewSocketPair();
+ if (!socks.ok()) {
+ err_num = socks.error().errno_value();
+ break;
+ }
+ sockets.push_back(std::move(socks).ValueOrDie());
+ }
+ ASSERT_EQ(err_num, EADDRINUSE);
+}
+
INSTANTIATE_TEST_SUITE_P(
AllConnectedSockets, ConnectStressTest,
::testing::Values(IPv6UDPBidirectionalBindSocketPair(0),
@@ -73,14 +134,40 @@ INSTANTIATE_TEST_SUITE_P(
// Test fixture for tests that apply to pairs of connected sockets created with
// a persistent listener (if applicable).
-using PersistentListenerConnectStressTest = SocketPairTest;
+class PersistentListenerConnectStressTest : public SocketPairTest {
+ protected:
+ PersistentListenerConnectStressTest() : slept_{false} {}
-TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) {
- // TODO(b/165912341): These are too slow on KVM platform with nested virt.
- SKIP_IF(GvisorPlatform() == Platform::kKVM);
+ // NewSocketSleep is the same as NewSocketPair, but will sleep once (over the
+ // lifetime of the fixture) and retry if creation fails due to EADDRNOTAVAIL.
+ PosixErrorOr<std::unique_ptr<SocketPair>> NewSocketSleep() {
+ // We can't reuse a connection too close in time to its last use, as TCP
+ // uses the timestamp difference to disambiguate connections. With a
+ // sufficiently small port range, we'll cycle through too quickly, and TCP
+ // won't allow for connection reuse. Thus, we sleep the first time
+ // encountering EADDRINUSE to allow for that difference (1 second in
+ // gVisor).
+ PosixErrorOr<std::unique_ptr<SocketPair>> socks = NewSocketPair();
+ if (socks.ok()) {
+ return socks;
+ }
+ if (!slept_ && socks.error().errno_value() == EADDRNOTAVAIL) {
+ absl::SleepFor(absl::Milliseconds(1500));
+ slept_ = true;
+ return NewSocketPair();
+ }
+ return socks;
+ }
- for (int i = 0; i < 1 << 16; ++i) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ private:
+ bool slept_;
+};
+
+TEST_P(PersistentListenerConnectStressTest, ShutdownCloseFirst) {
+ const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts());
+ for (int i = 0; i < nports * 2; i++) {
+ std::unique_ptr<SocketPair> sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocketSleep());
ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds());
if (GetParam().type == SOCK_STREAM) {
// Poll the other FD to make sure that we see the FIN from the other
@@ -97,12 +184,11 @@ TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) {
}
}
-TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) {
- // TODO(b/165912341): These are too slow on KVM platform with nested virt.
- SKIP_IF(GvisorPlatform() == Platform::kKVM);
-
- for (int i = 0; i < 1 << 16; ++i) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(PersistentListenerConnectStressTest, ShutdownCloseSecond) {
+ const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts());
+ for (int i = 0; i < nports * 2; i++) {
+ const std::unique_ptr<SocketPair> sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds());
if (GetParam().type == SOCK_STREAM) {
// Poll the other FD to make sure that we see the FIN from the other
@@ -119,12 +205,11 @@ TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) {
}
}
-TEST_P(PersistentListenerConnectStressTest, 65kTimesClose) {
- // TODO(b/165912341): These are too slow on KVM platform with nested virt.
- SKIP_IF(GvisorPlatform() == Platform::kKVM);
-
- for (int i = 0; i < 1 << 16; ++i) {
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+TEST_P(PersistentListenerConnectStressTest, Close) {
+ const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts());
+ for (int i = 0; i < nports * 2; i++) {
+ std::unique_ptr<SocketPair> sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocketSleep());
}
}
@@ -149,7 +234,8 @@ TEST_P(DataTransferStressTest, BigDataTransfer) {
// TODO(b/165912341): These are too slow on KVM platform with nested virt.
SKIP_IF(GvisorPlatform() == Platform::kKVM);
- auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+ const std::unique_ptr<SocketPair> sockets =
+ ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
int client_fd = sockets->first_fd();
int server_fd = sockets->second_fd();
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 344a5a22c..54b45b075 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -705,12 +705,6 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) {
ds.reset();
- if (!IsRunningOnGvisor()) {
- ASSERT_THAT(
- bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr),
- conn_addrlen),
- SyscallSucceeds());
- }
ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(),
reinterpret_cast<sockaddr*>(&conn_addr),
conn_addrlen),
diff --git a/test/uds/BUILD b/test/uds/BUILD
index 51e2c7ce8..a8f49b50c 100644
--- a/test/uds/BUILD
+++ b/test/uds/BUILD
@@ -12,5 +12,6 @@ go_library(
deps = [
"//pkg/log",
"//pkg/unet",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/test/uds/uds.go b/test/uds/uds.go
index b714c61b0..02a4a7dee 100644
--- a/test/uds/uds.go
+++ b/test/uds/uds.go
@@ -21,8 +21,8 @@ import (
"io/ioutil"
"os"
"path/filepath"
- "syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -31,16 +31,16 @@ import (
//
// Only works for stream, seqpacket sockets.
func createEchoSocket(path string, protocol int) (cleanup func(), err error) {
- fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0)
+ fd, err := unix.Socket(unix.AF_UNIX, protocol, 0)
if err != nil {
return nil, fmt.Errorf("error creating echo(%d) socket: %v", protocol, err)
}
- if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil {
+ if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil {
return nil, fmt.Errorf("error binding echo(%d) socket: %v", protocol, err)
}
- if err := syscall.Listen(fd, 0); err != nil {
+ if err := unix.Listen(fd, 0); err != nil {
return nil, fmt.Errorf("error listening echo(%d) socket: %v", protocol, err)
}
@@ -97,17 +97,17 @@ func createEchoSocket(path string, protocol int) (cleanup func(), err error) {
//
// Only relevant for stream, seqpacket sockets.
func createNonListeningSocket(path string, protocol int) (cleanup func(), err error) {
- fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0)
+ fd, err := unix.Socket(unix.AF_UNIX, protocol, 0)
if err != nil {
return nil, fmt.Errorf("error creating nonlistening(%d) socket: %v", protocol, err)
}
- if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil {
+ if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil {
return nil, fmt.Errorf("error binding nonlistening(%d) socket: %v", protocol, err)
}
cleanup = func() {
- if err := syscall.Close(fd); err != nil {
+ if err := unix.Close(fd); err != nil {
log.Warningf("Failed to close nonlistening(%d) socket: %v", protocol, err)
}
}
@@ -119,12 +119,12 @@ func createNonListeningSocket(path string, protocol int) (cleanup func(), err er
//
// Only works for dgram sockets.
func createNullSocket(path string, protocol int) (cleanup func(), err error) {
- fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0)
+ fd, err := unix.Socket(unix.AF_UNIX, protocol, 0)
if err != nil {
return nil, fmt.Errorf("error creating null(%d) socket: %v", protocol, err)
}
- if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil {
+ if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil {
return nil, fmt.Errorf("error binding null(%d) socket: %v", protocol, err)
}
@@ -174,7 +174,7 @@ func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) {
sockets map[string]socketCreator
}{
{
- protocol: syscall.SOCK_STREAM,
+ protocol: unix.SOCK_STREAM,
name: "stream",
sockets: map[string]socketCreator{
"echo": createEchoSocket,
@@ -182,7 +182,7 @@ func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) {
},
},
{
- protocol: syscall.SOCK_SEQPACKET,
+ protocol: unix.SOCK_SEQPACKET,
name: "seqpacket",
sockets: map[string]socketCreator{
"echo": createEchoSocket,
@@ -190,7 +190,7 @@ func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) {
},
},
{
- protocol: syscall.SOCK_DGRAM,
+ protocol: unix.SOCK_DGRAM,
name: "dgram",
sockets: map[string]socketCreator{
"null": createNullSocket,
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index abd6f69ea..634abd1af 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -427,7 +427,7 @@ func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *inter
// implementations type t.
func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator {
i := newTestGenerator(t.spec, t.recv)
- i.emitTests(t.slice, t.dynamic)
+ i.emitTests(t.slice)
return i
}
@@ -488,7 +488,11 @@ func (g *Generator) Run() error {
panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name))
}
}
- ts = append(ts, g.generateOneTestSuite(t))
+ // Do not generate tests for dynamic types because they inherently
+ // violate some go_marshal requirements.
+ if !t.dynamic {
+ ts = append(ts, g.generateOneTestSuite(t))
+ }
}
}
diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go
index ca3e15c16..6cf00843f 100644
--- a/tools/go_marshal/gomarshal/generator_tests.go
+++ b/tools/go_marshal/gomarshal/generator_tests.go
@@ -216,16 +216,12 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() {
})
}
-func (g *testGenerator) emitTests(slice *sliceAPI, isDynamic bool) {
+func (g *testGenerator) emitTests(slice *sliceAPI) {
g.emitTestNonZeroSize()
g.emitTestSuspectAlignment()
- if !isDynamic {
- // Do not test these for dynamic structs because they violate some
- // assumptions that these tests make.
- g.emitTestMarshalUnmarshalPreservesData()
- g.emitTestWriteToUnmarshalPreservesData()
- g.emitTestSizeBytesOnTypedNilPtr()
- }
+ g.emitTestMarshalUnmarshalPreservesData()
+ g.emitTestWriteToUnmarshalPreservesData()
+ g.emitTestSizeBytesOnTypedNilPtr()
if slice != nil {
g.emitTestMarshalUnmarshalSlicePreservesData(slice)
diff --git a/tools/make_apt.sh b/tools/make_apt.sh
index 68f6973ec..935c4db2d 100755
--- a/tools/make_apt.sh
+++ b/tools/make_apt.sh
@@ -107,7 +107,9 @@ for pkg in "$@"; do
cp -a -L "$(dirname "${pkg}")/${name}.deb" "${destdir}"
cp -a -L "$(dirname "${pkg}")/${name}.changes" "${destdir}"
chmod 0644 "${destdir}"/"${name}".*
+ # Sign a package only if it isn't signed yet.
# We use [*] here to expand the gpg_opts array into a single shell-word.
+ dpkg-sig -g "${gpg_opts[*]}" --verify "${destdir}/${name}.deb" ||
dpkg-sig -g "${gpg_opts[*]}" --sign builder "${destdir}/${name}.deb"
done
diff --git a/tools/verity/BUILD b/tools/verity/BUILD
new file mode 100644
index 000000000..77d16359c
--- /dev/null
+++ b/tools/verity/BUILD
@@ -0,0 +1,15 @@
+load("//tools:defs.bzl", "go_binary")
+
+licenses(["notice"])
+
+go_binary(
+ name = "measure_tool",
+ srcs = [
+ "measure_tool.go",
+ "measure_tool_unsafe.go",
+ ],
+ pure = True,
+ deps = [
+ "//pkg/abi/linux",
+ ],
+)
diff --git a/tools/verity/measure_tool.go b/tools/verity/measure_tool.go
new file mode 100644
index 000000000..0d314ae70
--- /dev/null
+++ b/tools/verity/measure_tool.go
@@ -0,0 +1,87 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This binary can be used to run a measurement of the verity file system,
+// generate the corresponding Merkle tree files, and return the root hash.
+package main
+
+import (
+ "flag"
+ "io/ioutil"
+ "log"
+ "os"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
+
+var path = flag.String("path", "", "path to the verity file system.")
+
+const maxDigestSize = 64
+
+type digest struct {
+ metadata linux.DigestMetadata
+ digest [maxDigestSize]byte
+}
+
+func main() {
+ flag.Parse()
+ if *path == "" {
+ log.Fatalf("no path provided")
+ }
+ if err := enableDir(*path); err != nil {
+ log.Fatalf("Failed to enable file system %s: %v", *path, err)
+ }
+ // Print the root hash of the file system to stdout.
+ if err := measure(*path); err != nil {
+ log.Fatalf("Failed to measure file system %s: %v", *path, err)
+ }
+}
+
+// enableDir enables verity features on all the files and sub-directories within
+// path.
+func enableDir(path string) error {
+ files, err := ioutil.ReadDir(path)
+ if err != nil {
+ return err
+ }
+ for _, file := range files {
+ if file.IsDir() {
+ // For directories, first enable its children.
+ if err := enableDir(path + "/" + file.Name()); err != nil {
+ return err
+ }
+ } else if file.Mode().IsRegular() {
+ // For regular files, open and enable verity feature.
+ f, err := os.Open(path + "/" + file.Name())
+ if err != nil {
+ return err
+ }
+ var p uintptr
+ if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f.Fd()), uintptr(linux.FS_IOC_ENABLE_VERITY), p); err != 0 {
+ return err
+ }
+ }
+ }
+ // Once all children are enabled, enable the parent directory.
+ f, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ var p uintptr
+ if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f.Fd()), uintptr(linux.FS_IOC_ENABLE_VERITY), p); err != 0 {
+ return err
+ }
+ return nil
+}
diff --git a/runsc/mitigate/mitigate_conf.go b/tools/verity/measure_tool_unsafe.go
index ee326324b..d4079be9e 100644
--- a/runsc/mitigate/mitigate_conf.go
+++ b/tools/verity/measure_tool_unsafe.go
@@ -11,27 +11,29 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
-
-package mitigate
+package main
import (
- "gvisor.dev/gvisor/runsc/flag"
-)
-
-type mitigate struct {
-}
+ "encoding/hex"
+ "fmt"
+ "os"
+ "syscall"
+ "unsafe"
-// usage returns the usage string portion for the mitigate.
-func (m mitigate) usage() string { return "" }
-
-// setFlags sets additional flags for the Mitigate command.
-func (m mitigate) setFlags(f *flag.FlagSet) {}
-
-// execute performs additional parts of Execute for Mitigate.
-func (m mitigate) execute(set cpuSet, dryrun bool) error {
- return nil
-}
+ "gvisor.dev/gvisor/pkg/abi/linux"
+)
-func (m mitigate) vulnerable(other thread) bool {
- return other.isVulnerable()
+// measure prints the hash of path to stdout.
+func measure(path string) error {
+ f, err := os.Open(path)
+ if err != nil {
+ return err
+ }
+ var digest digest
+ digest.metadata.DigestSize = maxDigestSize
+ if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f.Fd()), uintptr(linux.FS_IOC_MEASURE_VERITY), uintptr(unsafe.Pointer(&digest))); err != 0 {
+ return err
+ }
+ fmt.Fprintf(os.Stdout, "%s\n", hex.EncodeToString(digest.digest[:digest.metadata.DigestSize]))
+ return err
}