summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.bazelrc18
-rw-r--r--Makefile5
-rw-r--r--WORKSPACE32
-rw-r--r--images/benchmarks/ffmpeg/Dockerfile9
-rw-r--r--images/benchmarks/redis/Dockerfile1
-rw-r--r--pkg/abi/linux/BUILD1
-rw-r--r--pkg/abi/linux/fuse.go143
-rw-r--r--pkg/sentry/fs/fsutil/BUILD7
-rw-r--r--pkg/sentry/fs/fsutil/dirty_set.go7
-rw-r--r--pkg/sentry/fs/fsutil/file_range_set.go15
-rw-r--r--pkg/sentry/fs/fsutil/frame_ref_set.go10
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go5
-rw-r--r--pkg/sentry/fs/fsutil/host_mappable.go19
-rw-r--r--pkg/sentry/fs/fsutil/inode_cached.go25
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD41
-rw-r--r--pkg/sentry/fsimpl/fuse/connection.go255
-rw-r--r--pkg/sentry/fsimpl/fuse/dev.go289
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go429
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go52
-rw-r--r--pkg/sentry/fsimpl/fuse/register.go42
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go27
-rw-r--r--pkg/sentry/fsimpl/host/BUILD1
-rw-r--r--pkg/sentry/fsimpl/host/mmap.go21
-rw-r--r--pkg/sentry/kernel/shm/BUILD1
-rw-r--r--pkg/sentry/kernel/shm/shm.go3
-rw-r--r--pkg/sentry/kernel/timekeeper.go4
-rw-r--r--pkg/sentry/kernel/vdso.go6
-rw-r--r--pkg/sentry/memmap/BUILD14
-rw-r--r--pkg/sentry/memmap/memmap.go60
-rw-r--r--pkg/sentry/mm/BUILD4
-rw-r--r--pkg/sentry/mm/aio_context.go3
-rw-r--r--pkg/sentry/mm/mm.go10
-rw-r--r--pkg/sentry/mm/pma.go25
-rw-r--r--pkg/sentry/mm/special_mappable.go7
-rw-r--r--pkg/sentry/pgalloc/BUILD10
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go66
-rw-r--r--pkg/sentry/platform/BUILD20
-rw-r--r--pkg/sentry/platform/kvm/BUILD1
-rw-r--r--pkg/sentry/platform/kvm/address_space.go3
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64_unsafe.go4
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go16
-rw-r--r--pkg/sentry/platform/platform.go50
-rw-r--r--pkg/sentry/platform/ptrace/BUILD1
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go3
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/vfs2.go102
-rw-r--r--pkg/tcpip/network/arp/arp.go7
-rw-r--r--pkg/tcpip/network/arp/arp_test.go58
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go8
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go52
-rw-r--r--pkg/tcpip/stack/forwarder_test.go18
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go2
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go2
-rw-r--r--pkg/tcpip/stack/nic_test.go2
-rw-r--r--pkg/tcpip/stack/registration.go7
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go2
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go4
-rw-r--r--pkg/tcpip/transport/tcp/segment.go6
-rw-r--r--pkg/tcpip/transport/tcp/segment_unsafe.go23
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go8
-rw-r--r--pkg/test/dockerutil/BUILD19
-rw-r--r--pkg/test/dockerutil/README.md86
-rw-r--r--pkg/test/dockerutil/container.go82
-rw-r--r--pkg/test/dockerutil/dockerutil.go21
-rw-r--r--pkg/test/dockerutil/profile.go152
-rw-r--r--pkg/test/dockerutil/profile_test.go117
-rw-r--r--runsc/container/container_test.go4
-rwxr-xr-xscripts/benchmark.sh30
-rwxr-xr-xscripts/common.sh27
-rwxr-xr-xscripts/docker_tests.sh4
-rw-r--r--test/benchmarks/README.md81
-rw-r--r--test/benchmarks/database/BUILD28
-rw-r--r--test/benchmarks/database/database.go31
-rw-r--r--test/benchmarks/database/redis_test.go197
-rw-r--r--test/benchmarks/fs/bazel_test.go32
-rw-r--r--test/benchmarks/harness/machine.go12
-rw-r--r--test/benchmarks/harness/util.go12
-rw-r--r--test/benchmarks/media/BUILD21
-rw-r--r--test/benchmarks/media/ffmpeg_test.go52
-rw-r--r--test/benchmarks/media/media.go31
-rw-r--r--test/benchmarks/network/BUILD1
-rw-r--r--test/benchmarks/network/httpd_test.go9
-rw-r--r--test/benchmarks/network/iperf_test.go40
-rw-r--r--test/iptables/filter_input.go30
-rw-r--r--test/iptables/iptables_util.go33
-rw-r--r--test/iptables/nat.go12
-rw-r--r--test/packetimpact/runner/packetimpact_test.go5
-rw-r--r--test/packetimpact/testbench/connections.go330
-rw-r--r--test/packetimpact/testbench/dut.go358
-rw-r--r--test/packetimpact/testbench/rawsockets.go44
-rw-r--r--test/packetimpact/tests/fin_wait2_timeout_test.go26
-rw-r--r--test/packetimpact/tests/icmpv6_param_problem_test.go8
-rw-r--r--test/packetimpact/tests/ipv4_id_uniqueness_test.go34
-rw-r--r--test/packetimpact/tests/ipv6_fragment_reassembly_test.go10
-rw-r--r--test/packetimpact/tests/ipv6_unknown_options_action_test.go44
-rw-r--r--test/packetimpact/tests/tcp_close_wait_ack_test.go63
-rw-r--r--test/packetimpact/tests/tcp_cork_mss_test.go36
-rw-r--r--test/packetimpact/tests/tcp_handshake_window_size_test.go20
-rw-r--r--test/packetimpact/tests/tcp_network_unreachable_test.go28
-rw-r--r--test/packetimpact/tests/tcp_noaccept_close_rst_test.go10
-rw-r--r--test/packetimpact/tests/tcp_outside_the_window_test.go24
-rw-r--r--test/packetimpact/tests/tcp_paws_mechanism_test.go24
-rw-r--r--test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go26
-rw-r--r--test/packetimpact/tests/tcp_reordering_test.go48
-rw-r--r--test/packetimpact/tests/tcp_retransmits_test.go28
-rw-r--r--test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go24
-rw-r--r--test/packetimpact/tests/tcp_synrcvd_reset_test.go16
-rw-r--r--test/packetimpact/tests/tcp_synsent_reset_test.go30
-rw-r--r--test/packetimpact/tests/tcp_user_timeout_test.go41
-rw-r--r--test/packetimpact/tests/tcp_window_shrink_test.go36
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go39
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_test.go44
-rw-r--r--test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go40
-rw-r--r--test/packetimpact/tests/udp_discard_mcast_source_addr_test.go22
-rw-r--r--test/packetimpact/tests/udp_icmp_error_propagation_test.go130
-rw-r--r--test/packetimpact/tests/udp_recv_mcast_bcast_test.go9
-rw-r--r--test/packetimpact/tests/udp_send_recv_dgram_test.go28
-rw-r--r--test/runner/defs.bzl2
-rw-r--r--test/syscalls/linux/mount.cc10
-rw-r--r--tools/bazel.mk82
-rw-r--r--tools/bazeldefs/BUILD37
121 files changed, 3757 insertions, 1230 deletions
diff --git a/.bazelrc b/.bazelrc
index 4a0671f4a..3c31282ce 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# RBE requires a strong hash function, such as SHA256.
+startup --host_jvm_args=-Dbazel.DigestFunction=SHA256
+
# Build with C++17.
build --cxxopt=-std=c++17
@@ -22,11 +25,17 @@ build --stamp --workspace_status_command tools/workspace_status.sh
build:remote --remote_executor=grpcs://remotebuildexecution.googleapis.com
build:remote --project_id=gvisor-rbe
build:remote --remote_instance_name=projects/gvisor-rbe/instances/default_instance
+build:remote3 --remote_executor=grpcs://remotebuildexecution.googleapis.com
+build:remote3 --project_id=gvisor-rbe
+build:remote3 --remote_instance_name=projects/gvisor-rbe/instances/default_instance
+
# Enable authentication. This will pick up application default credentials by
# default. You can use --google_credentials=some_file.json to use a service
# account credential instead.
build:remote --google_default_credentials=true
build:remote --auth_scope="https://www.googleapis.com/auth/cloud-source-tools"
+build:remote3 --google_default_credentials=true
+build:remote3 --auth_scope="https://www.googleapis.com/auth/cloud-source-tools"
# Add a custom platform and toolchain that builds in a privileged docker
# container, which is required by our syscall tests.
@@ -37,8 +46,13 @@ build:remote --platforms=//tools/bazeldefs:rbe_ubuntu1604
build:remote --crosstool_top=@rbe_default//cc:toolchain
build:remote --jobs=50
build:remote --remote_timeout=3600
-# RBE requires a strong hash function, such as SHA256.
-startup --host_jvm_args=-Dbazel.DigestFunction=SHA256
+build:remote3 --host_platform=//tools/bazeldefs:rbe_ubuntu1604_bazel3
+build:remote3 --extra_toolchains=//tools/bazeldefs:cc-toolchain-clang-x86_64-default_bazel3
+build:remote3 --extra_execution_platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3
+build:remote3 --platforms=//tools/bazeldefs:rbe_ubuntu1604_bazel3
+build:remote3 --crosstool_top=@rbe_default//cc:toolchain
+build:remote3 --jobs=50
+build:remote3 --remote_timeout=3600
# Set flags for uploading to BES in order to view results in the Bazel Build
# Results UI.
diff --git a/Makefile b/Makefile
index 7dc155ccc..8e97fc978 100644
--- a/Makefile
+++ b/Makefile
@@ -166,11 +166,14 @@ do-tests: runsc
simple-tests: unit-tests # Compatibility target.
.PHONY: simple-tests
+IMAGE_FILTER := HelloWorld\|Httpd\|Ruby\|Stdio
+INTEGRATION_FILTER := Life\|Pause\|Connect\|JobControl\|Overlay\|Exec\|DirCreation/root
+
docker-tests: load-basic-images
@$(call submake,install-test-runtime RUNTIME="vfs1")
@$(call submake,test-runtime RUNTIME="vfs1" TARGETS="$(INTEGRATION_TARGETS)")
@$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2")
- @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_filter=.*TestHelloWorld" TARGETS="$(INTEGRATION_TARGETS)")
+ @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_filter=$(IMAGE_FILTER)\|$(INTEGRATION_FILTER)" TARGETS="$(INTEGRATION_TARGETS)")
.PHONY: docker-tests
overlay-tests: load-basic-images
diff --git a/WORKSPACE b/WORKSPACE
index 2ab750a9d..49f231755 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -42,6 +42,28 @@ http_archive(
],
)
+http_archive(
+ name = "io_bazel_rules_go_bazel3", # To replace the above.
+ patch_args = ["-p1"],
+ patches = [
+ "//tools/nogo:io_bazel_rules_go-visibility.patch",
+ ],
+ sha256 = "87f0fb9747854cb76a0a82430adccb6269f7d394237104a4523b51061c469171",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.23.1/rules_go-v0.23.1.tar.gz",
+ ],
+)
+
+http_archive(
+ name = "bazel_gazelle_bazel3", # To replace the above.
+ sha256 = "bfd86b3cbe855d6c16c6fce60d76bd51f5c8dbc9cfcaef7a2bb5c1aafd0710e8",
+ urls = [
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz",
+ "https://github.com/bazelbuild/bazel-gazelle/releases/download/v0.21.0/bazel-gazelle-v0.21.0.tar.gz",
+ ],
+)
+
load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_dependencies")
go_rules_dependencies()
@@ -123,6 +145,16 @@ http_archive(
],
)
+http_archive(
+ name = "bazel_toolchains_bazel3", # To replace the above.
+ sha256 = "144290c4166bd67e76a54f96cd504ed86416ca3ca82030282760f0823c10be48",
+ strip_prefix = "bazel-toolchains-3.1.1",
+ urls = [
+ "https://github.com/bazelbuild/bazel-toolchains/releases/download/3.1.1/bazel-toolchains-3.1.1.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/releases/download/3.1.1/bazel-toolchains-3.1.1.tar.gz",
+ ],
+)
+
# Creates a default toolchain config for RBE.
load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig")
diff --git a/images/benchmarks/ffmpeg/Dockerfile b/images/benchmarks/ffmpeg/Dockerfile
new file mode 100644
index 000000000..7108df64f
--- /dev/null
+++ b/images/benchmarks/ffmpeg/Dockerfile
@@ -0,0 +1,9 @@
+FROM ubuntu:18.04
+
+RUN set -x \
+ && apt-get update \
+ && apt-get install -y \
+ ffmpeg \
+ && rm -rf /var/lib/apt/lists/*
+WORKDIR /media
+ADD https://samples.ffmpeg.org/MPEG-4/video.mp4 video.mp4
diff --git a/images/benchmarks/redis/Dockerfile b/images/benchmarks/redis/Dockerfile
new file mode 100644
index 000000000..0f17249af
--- /dev/null
+++ b/images/benchmarks/redis/Dockerfile
@@ -0,0 +1 @@
+FROM redis:5.0.4
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD
index a4bb62013..05ca5342f 100644
--- a/pkg/abi/linux/BUILD
+++ b/pkg/abi/linux/BUILD
@@ -29,6 +29,7 @@ go_library(
"file_amd64.go",
"file_arm64.go",
"fs.go",
+ "fuse.go",
"futex.go",
"inotify.go",
"ioctl.go",
diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go
new file mode 100644
index 000000000..d3ebbccc4
--- /dev/null
+++ b/pkg/abi/linux/fuse.go
@@ -0,0 +1,143 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package linux
+
+// +marshal
+type FUSEOpcode uint32
+
+// +marshal
+type FUSEOpID uint64
+
+// Opcodes for FUSE operations. Analogous to the opcodes in include/linux/fuse.h.
+const (
+ FUSE_LOOKUP FUSEOpcode = 1
+ FUSE_FORGET = 2 /* no reply */
+ FUSE_GETATTR = 3
+ FUSE_SETATTR = 4
+ FUSE_READLINK = 5
+ FUSE_SYMLINK = 6
+ _
+ FUSE_MKNOD = 8
+ FUSE_MKDIR = 9
+ FUSE_UNLINK = 10
+ FUSE_RMDIR = 11
+ FUSE_RENAME = 12
+ FUSE_LINK = 13
+ FUSE_OPEN = 14
+ FUSE_READ = 15
+ FUSE_WRITE = 16
+ FUSE_STATFS = 17
+ FUSE_RELEASE = 18
+ _
+ FUSE_FSYNC = 20
+ FUSE_SETXATTR = 21
+ FUSE_GETXATTR = 22
+ FUSE_LISTXATTR = 23
+ FUSE_REMOVEXATTR = 24
+ FUSE_FLUSH = 25
+ FUSE_INIT = 26
+ FUSE_OPENDIR = 27
+ FUSE_READDIR = 28
+ FUSE_RELEASEDIR = 29
+ FUSE_FSYNCDIR = 30
+ FUSE_GETLK = 31
+ FUSE_SETLK = 32
+ FUSE_SETLKW = 33
+ FUSE_ACCESS = 34
+ FUSE_CREATE = 35
+ FUSE_INTERRUPT = 36
+ FUSE_BMAP = 37
+ FUSE_DESTROY = 38
+ FUSE_IOCTL = 39
+ FUSE_POLL = 40
+ FUSE_NOTIFY_REPLY = 41
+ FUSE_BATCH_FORGET = 42
+)
+
+const (
+ // FUSE_MIN_READ_BUFFER is the minimum size the read can be for any FUSE filesystem.
+ // This is the minimum size Linux supports. See linux.fuse.h.
+ FUSE_MIN_READ_BUFFER uint32 = 8192
+)
+
+// FUSEHeaderIn is the header read by the daemon with each request.
+//
+// +marshal
+type FUSEHeaderIn struct {
+ // Len specifies the total length of the data, including this header.
+ Len uint32
+
+ // Opcode specifies the kind of operation of the request.
+ Opcode FUSEOpcode
+
+ // Unique specifies the unique identifier for this request.
+ Unique FUSEOpID
+
+ // NodeID is the ID of the filesystem object being operated on.
+ NodeID uint64
+
+ // UID is the UID of the requesting process.
+ UID uint32
+
+ // GID is the GID of the requesting process.
+ GID uint32
+
+ // PID is the PID of the requesting process.
+ PID uint32
+
+ _ uint32
+}
+
+// FUSEHeaderOut is the header written by the daemon when it processes
+// a request and wants to send a reply (almost all operations require a
+// reply; if they do not, this will be explicitly documented).
+//
+// +marshal
+type FUSEHeaderOut struct {
+ // Len specifies the total length of the data, including this header.
+ Len uint32
+
+ // Error specifies the error that occurred (0 if none).
+ Error int32
+
+ // Unique specifies the unique identifier of the corresponding request.
+ Unique FUSEOpID
+}
+
+// FUSEWriteIn is the header written by a daemon when it makes a
+// write request to the FUSE filesystem.
+//
+// +marshal
+type FUSEWriteIn struct {
+ // Fh specifies the file handle that is being written to.
+ Fh uint64
+
+ // Offset is the offset of the write.
+ Offset uint64
+
+ // Size is the size of data being written.
+ Size uint32
+
+ // WriteFlags is the flags used during the write.
+ WriteFlags uint32
+
+ // LockOwner is the ID of the lock owner.
+ LockOwner uint64
+
+ // Flags is the flags for the request.
+ Flags uint32
+
+ _ uint32
+}
diff --git a/pkg/sentry/fs/fsutil/BUILD b/pkg/sentry/fs/fsutil/BUILD
index 789369220..5fb419bcd 100644
--- a/pkg/sentry/fs/fsutil/BUILD
+++ b/pkg/sentry/fs/fsutil/BUILD
@@ -8,7 +8,6 @@ go_template_instance(
out = "dirty_set_impl.go",
imports = {
"memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
},
package = "fsutil",
prefix = "Dirty",
@@ -25,14 +24,14 @@ go_template_instance(
name = "frame_ref_set_impl",
out = "frame_ref_set_impl.go",
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "fsutil",
prefix = "FrameRef",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "uint64",
"Functions": "FrameRefSetFunctions",
},
@@ -43,7 +42,6 @@ go_template_instance(
out = "file_range_set_impl.go",
imports = {
"memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
},
package = "fsutil",
prefix = "FileRange",
@@ -86,7 +84,6 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
- "//pkg/sentry/platform",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usage",
"//pkg/state",
diff --git a/pkg/sentry/fs/fsutil/dirty_set.go b/pkg/sentry/fs/fsutil/dirty_set.go
index c6cd45087..2c9446c1d 100644
--- a/pkg/sentry/fs/fsutil/dirty_set.go
+++ b/pkg/sentry/fs/fsutil/dirty_set.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -159,7 +158,7 @@ func (ds *DirtySet) AllowClean(mr memmap.MappableRange) {
// repeatedly until all bytes have been written. max is the true size of the
// cached object; offsets beyond max will not be passed to writeAt, even if
// they are marked dirty.
-func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
var changedDirty bool
defer func() {
if changedDirty {
@@ -194,7 +193,7 @@ func SyncDirty(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet
// successful partial write, SyncDirtyAll will call it repeatedly until all
// bytes have been written. max is the true size of the cached object; offsets
// beyond max will not be passed to writeAt, even if they are marked dirty.
-func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
dseg := dirty.FirstSegment()
for dseg.Ok() {
if err := syncDirtyRange(ctx, dseg.Range(), cache, max, mem, writeAt); err != nil {
@@ -210,7 +209,7 @@ func SyncDirtyAll(ctx context.Context, cache *FileRangeSet, dirty *DirtySet, max
}
// Preconditions: mr must be page-aligned.
-func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem platform.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
+func syncDirtyRange(ctx context.Context, mr memmap.MappableRange, cache *FileRangeSet, max uint64, mem memmap.File, writeAt func(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)) error {
for cseg := cache.LowerBoundSegment(mr.Start); cseg.Ok() && cseg.Start() < mr.End; cseg = cseg.NextSegment() {
wbr := cseg.Range().Intersect(mr)
if max < wbr.Start {
diff --git a/pkg/sentry/fs/fsutil/file_range_set.go b/pkg/sentry/fs/fsutil/file_range_set.go
index 5643cdac9..bbafebf03 100644
--- a/pkg/sentry/fs/fsutil/file_range_set.go
+++ b/pkg/sentry/fs/fsutil/file_range_set.go
@@ -23,13 +23,12 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/usermem"
)
// FileRangeSet maps offsets into a memmap.Mappable to offsets into a
-// platform.File. It is used to implement Mappables that store data in
+// memmap.File. It is used to implement Mappables that store data in
// sparsely-allocated memory.
//
// type FileRangeSet <generated by go_generics>
@@ -65,20 +64,20 @@ func (FileRangeSetFunctions) Split(mr memmap.MappableRange, frstart uint64, spli
}
// FileRange returns the FileRange mapped by seg.
-func (seg FileRangeIterator) FileRange() platform.FileRange {
+func (seg FileRangeIterator) FileRange() memmap.FileRange {
return seg.FileRangeOf(seg.Range())
}
// FileRangeOf returns the FileRange mapped by mr.
//
// Preconditions: seg.Range().IsSupersetOf(mr). mr.Length() != 0.
-func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) platform.FileRange {
+func (seg FileRangeIterator) FileRangeOf(mr memmap.MappableRange) memmap.FileRange {
frstart := seg.Value() + (mr.Start - seg.Start())
- return platform.FileRange{frstart, frstart + mr.Length()}
+ return memmap.FileRange{frstart, frstart + mr.Length()}
}
// Fill attempts to ensure that all memmap.Mappable offsets in required are
-// mapped to a platform.File offset, by allocating from mf with the given
+// mapped to a memmap.File offset, by allocating from mf with the given
// memory usage kind and invoking readAt to store data into memory. (If readAt
// returns a successful partial read, Fill will call it repeatedly until all
// bytes have been read.) EOF is handled consistently with the requirements of
@@ -141,7 +140,7 @@ func (frs *FileRangeSet) Fill(ctx context.Context, required, optional memmap.Map
}
// Drop removes segments for memmap.Mappable offsets in mr, freeing the
-// corresponding platform.FileRanges.
+// corresponding memmap.FileRanges.
//
// Preconditions: mr must be page-aligned.
func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
@@ -154,7 +153,7 @@ func (frs *FileRangeSet) Drop(mr memmap.MappableRange, mf *pgalloc.MemoryFile) {
}
// DropAll removes all segments in mr, freeing the corresponding
-// platform.FileRanges.
+// memmap.FileRanges.
func (frs *FileRangeSet) DropAll(mf *pgalloc.MemoryFile) {
for seg := frs.FirstSegment(); seg.Ok(); seg = seg.NextSegment() {
mf.DecRef(seg.FileRange())
diff --git a/pkg/sentry/fs/fsutil/frame_ref_set.go b/pkg/sentry/fs/fsutil/frame_ref_set.go
index dd6f5aba6..a808894df 100644
--- a/pkg/sentry/fs/fsutil/frame_ref_set.go
+++ b/pkg/sentry/fs/fsutil/frame_ref_set.go
@@ -17,7 +17,7 @@ package fsutil
import (
"math"
- "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usage"
)
@@ -39,7 +39,7 @@ func (FrameRefSetFunctions) ClearValue(val *uint64) {
}
// Merge implements segment.Functions.Merge.
-func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.FileRange, val2 uint64) (uint64, bool) {
+func (FrameRefSetFunctions) Merge(_ memmap.FileRange, val1 uint64, _ memmap.FileRange, val2 uint64) (uint64, bool) {
if val1 != val2 {
return 0, false
}
@@ -47,13 +47,13 @@ func (FrameRefSetFunctions) Merge(_ platform.FileRange, val1 uint64, _ platform.
}
// Split implements segment.Functions.Split.
-func (FrameRefSetFunctions) Split(_ platform.FileRange, val uint64, _ uint64) (uint64, uint64) {
+func (FrameRefSetFunctions) Split(_ memmap.FileRange, val uint64, _ uint64) (uint64, uint64) {
return val, val
}
// IncRefAndAccount adds a reference on the range fr. All newly inserted segments
// are accounted as host page cache memory mappings.
-func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) {
+func (refs *FrameRefSet) IncRefAndAccount(fr memmap.FileRange) {
seg, gap := refs.Find(fr.Start)
for {
switch {
@@ -74,7 +74,7 @@ func (refs *FrameRefSet) IncRefAndAccount(fr platform.FileRange) {
// DecRefAndAccount removes a reference on the range fr and untracks segments
// that are removed from memory accounting.
-func (refs *FrameRefSet) DecRefAndAccount(fr platform.FileRange) {
+func (refs *FrameRefSet) DecRefAndAccount(fr memmap.FileRange) {
seg := refs.FindSegment(fr.Start)
for seg.Ok() && seg.Start() < fr.End {
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index e82afd112..ef0113b52 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -126,7 +125,7 @@ func (f *HostFileMapper) DecRefOn(mr memmap.MappableRange) {
// offsets in fr or until the next call to UnmapAll.
//
// Preconditions: The caller must hold a reference on all offsets in fr.
-func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool) (safemem.BlockSeq, error) {
+func (f *HostFileMapper) MapInternal(fr memmap.FileRange, fd int, write bool) (safemem.BlockSeq, error) {
chunks := ((fr.End + chunkMask) >> chunkShift) - (fr.Start >> chunkShift)
f.mapsMu.Lock()
defer f.mapsMu.Unlock()
@@ -146,7 +145,7 @@ func (f *HostFileMapper) MapInternal(fr platform.FileRange, fd int, write bool)
}
// Preconditions: f.mapsMu must be locked.
-func (f *HostFileMapper) forEachMappingBlockLocked(fr platform.FileRange, fd int, write bool, fn func(safemem.Block)) error {
+func (f *HostFileMapper) forEachMappingBlockLocked(fr memmap.FileRange, fd int, write bool, fn func(safemem.Block)) error {
prot := syscall.PROT_READ
if write {
prot |= syscall.PROT_WRITE
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index 78fec553e..c15d8a946 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -21,18 +21,17 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
-// HostMappable implements memmap.Mappable and platform.File over a
+// HostMappable implements memmap.Mappable and memmap.File over a
// CachedFileObject.
//
// Lock order (compare the lock order model in mm/mm.go):
// truncateMu ("fs locks")
// mu ("memmap.Mappable locks not taken by Translate")
-// ("platform.File locks")
+// ("memmap.File locks")
// backingFile ("CachedFileObject locks")
//
// +stateify savable
@@ -124,24 +123,24 @@ func (h *HostMappable) NotifyChangeFD() error {
return nil
}
-// MapInternal implements platform.File.MapInternal.
-func (h *HostMappable) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (h *HostMappable) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return h.hostFileMapper.MapInternal(fr, h.backingFile.FD(), at.Write)
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (h *HostMappable) FD() int {
return h.backingFile.FD()
}
-// IncRef implements platform.File.IncRef.
-func (h *HostMappable) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (h *HostMappable) IncRef(fr memmap.FileRange) {
mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
h.hostFileMapper.IncRefOn(mr)
}
-// DecRef implements platform.File.DecRef.
-func (h *HostMappable) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (h *HostMappable) DecRef(fr memmap.FileRange) {
mr := memmap.MappableRange{Start: fr.Start, End: fr.End}
h.hostFileMapper.DecRefOn(mr)
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index 800c8b4e1..fe8b0b6ac 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -26,7 +26,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -934,7 +933,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
func (c *CachingInodeOperations) InvalidateUnsavable(ctx context.Context) error {
- // Whether we have a host fd (and consequently what platform.File is
+ // Whether we have a host fd (and consequently what memmap.File is
// mapped) can change across save/restore, so invalidate all translations
// unconditionally.
c.mapsMu.Lock()
@@ -999,10 +998,10 @@ func (c *CachingInodeOperations) Evict(ctx context.Context, er pgalloc.Evictable
}
}
-// IncRef implements platform.File.IncRef. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// IncRef implements memmap.File.IncRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
-func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
+func (c *CachingInodeOperations) IncRef(fr memmap.FileRange) {
// Hot path. Avoid defers.
c.dataMu.Lock()
seg, gap := c.refs.Find(fr.Start)
@@ -1024,10 +1023,10 @@ func (c *CachingInodeOperations) IncRef(fr platform.FileRange) {
}
}
-// DecRef implements platform.File.DecRef. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// DecRef implements memmap.File.DecRef. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
-func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
+func (c *CachingInodeOperations) DecRef(fr memmap.FileRange) {
// Hot path. Avoid defers.
c.dataMu.Lock()
seg := c.refs.FindSegment(fr.Start)
@@ -1046,15 +1045,15 @@ func (c *CachingInodeOperations) DecRef(fr platform.FileRange) {
c.dataMu.Unlock()
}
-// MapInternal implements platform.File.MapInternal. This is used when we
+// MapInternal implements memmap.File.MapInternal. This is used when we
// directly map an underlying host fd and CachingInodeOperations is used as the
-// platform.File during translation.
-func (c *CachingInodeOperations) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// memmap.File during translation.
+func (c *CachingInodeOperations) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return c.hostFileMapper.MapInternal(fr, c.backingFile.FD(), at.Write)
}
-// FD implements platform.File.FD. This is used when we directly map an
-// underlying host fd and CachingInodeOperations is used as the platform.File
+// FD implements memmap.File.FD. This is used when we directly map an
+// underlying host fd and CachingInodeOperations is used as the memmap.File
// during translation.
func (c *CachingInodeOperations) FD() int {
return c.backingFile.FD()
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 737007748..67649e811 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -1,12 +1,28 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
licenses(["notice"])
+go_template_instance(
+ name = "request_list",
+ out = "request_list.go",
+ package = "fuse",
+ prefix = "request",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*Request",
+ "Linker": "*Request",
+ },
+)
+
go_library(
name = "fuse",
srcs = [
+ "connection.go",
"dev.go",
"fusefs.go",
+ "register.go",
+ "request_list.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
@@ -18,7 +34,30 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+go_test(
+ name = "dev_test",
+ size = "small",
+ srcs = ["dev_test.go"],
+ library = ":fuse",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/fsimpl/testutil",
+ "//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/vfs",
"//pkg/syserror",
"//pkg/usermem",
+ "//pkg/waiter",
+ "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/fsimpl/fuse/connection.go b/pkg/sentry/fsimpl/fuse/connection.go
new file mode 100644
index 000000000..f330da0bd
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/connection.go
@@ -0,0 +1,255 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "errors"
+ "fmt"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// MaxActiveRequestsDefault is the default setting controlling the upper bound
+// on the number of active requests at any given time.
+const MaxActiveRequestsDefault = 10000
+
+var (
+ // Ordinary requests have even IDs, while interrupts IDs are odd.
+ InitReqBit uint64 = 1
+ ReqIDStep uint64 = 2
+)
+
+// Request represents a FUSE operation request that hasn't been sent to the
+// server yet.
+//
+// +stateify savable
+type Request struct {
+ requestEntry
+
+ id linux.FUSEOpID
+ hdr *linux.FUSEHeaderIn
+ data []byte
+}
+
+// Response represents an actual response from the server, including the
+// response payload.
+//
+// +stateify savable
+type Response struct {
+ opcode linux.FUSEOpcode
+ hdr linux.FUSEHeaderOut
+ data []byte
+}
+
+// Connection is the struct by which the sentry communicates with the FUSE server daemon.
+type Connection struct {
+ fd *DeviceFD
+
+ // MaxWrite is the daemon's maximum size of a write buffer.
+ // This is negotiated during FUSE_INIT.
+ MaxWrite uint32
+}
+
+// NewFUSEConnection creates a FUSE connection to fd
+func NewFUSEConnection(_ context.Context, fd *vfs.FileDescription, maxInFlightRequests uint64) (*Connection, error) {
+ // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to
+ // mount a FUSE filesystem.
+ fuseFD := fd.Impl().(*DeviceFD)
+ fuseFD.mounted = true
+
+ // Create the writeBuf for the header to be stored in.
+ hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ fuseFD.writeBuf = make([]byte, hdrLen)
+ fuseFD.completions = make(map[linux.FUSEOpID]*futureResponse)
+ fuseFD.fullQueueCh = make(chan struct{}, maxInFlightRequests)
+ fuseFD.writeCursor = 0
+
+ return &Connection{
+ fd: fuseFD,
+ }, nil
+}
+
+// NewRequest creates a new request that can be sent to the FUSE server.
+func (conn *Connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+ conn.fd.nextOpID += linux.FUSEOpID(ReqIDStep)
+
+ hdrLen := (*linux.FUSEHeaderIn)(nil).SizeBytes()
+ hdr := linux.FUSEHeaderIn{
+ Len: uint32(hdrLen + payload.SizeBytes()),
+ Opcode: opcode,
+ Unique: conn.fd.nextOpID,
+ NodeID: ino,
+ UID: uint32(creds.EffectiveKUID),
+ GID: uint32(creds.EffectiveKGID),
+ PID: pid,
+ }
+
+ buf := make([]byte, hdr.Len)
+ hdr.MarshalUnsafe(buf[:hdrLen])
+ payload.MarshalUnsafe(buf[hdrLen:])
+
+ return &Request{
+ id: hdr.Unique,
+ hdr: &hdr,
+ data: buf,
+ }, nil
+}
+
+// Call makes a request to the server and blocks the invoking task until a
+// server responds with a response.
+// NOTE: If no task is provided then the Call will simply enqueue the request
+// and return a nil response. No blocking will happen in this case. Instead,
+// this is used to signify that the processing of this request will happen by
+// the kernel.Task that writes the response. See FUSE_INIT for such an
+// invocation.
+func (conn *Connection) Call(t *kernel.Task, r *Request) (*Response, error) {
+ fut, err := conn.callFuture(t, r)
+ if err != nil {
+ return nil, err
+ }
+
+ return fut.resolve(t)
+}
+
+// Error returns the error of the FUSE call.
+func (r *Response) Error() error {
+ errno := r.hdr.Error
+ if errno >= 0 {
+ return nil
+ }
+
+ sysErrNo := syscall.Errno(-errno)
+ return error(sysErrNo)
+}
+
+// UnmarshalPayload unmarshals the response data into m.
+func (r *Response) UnmarshalPayload(m marshal.Marshallable) error {
+ hdrLen := r.hdr.SizeBytes()
+ haveDataLen := r.hdr.Len - uint32(hdrLen)
+ wantDataLen := uint32(m.SizeBytes())
+
+ if haveDataLen < wantDataLen {
+ return fmt.Errorf("payload too small. Minimum data lenth required: %d, but got data length %d", wantDataLen, haveDataLen)
+ }
+
+ m.UnmarshalUnsafe(r.data[hdrLen:])
+ return nil
+}
+
+// callFuture makes a request to the server and returns a future response.
+// Call resolve() when the response needs to be fulfilled.
+func (conn *Connection) callFuture(t *kernel.Task, r *Request) (*futureResponse, error) {
+ conn.fd.mu.Lock()
+ defer conn.fd.mu.Unlock()
+
+ // Is the queue full?
+ //
+ // We must busy wait here until the request can be queued. We don't
+ // block on the fd.fullQueueCh with a lock - so after being signalled,
+ // before we acquire the lock, it is possible that a barging task enters
+ // and queues a request. As a result, upon acquiring the lock we must
+ // again check if the room is available.
+ //
+ // This can potentially starve a request forever but this can only happen
+ // if there are always too many ongoing requests all the time. The
+ // supported maxActiveRequests setting should be really high to avoid this.
+ for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
+ if t == nil {
+ // Since there is no task that is waiting. We must error out.
+ return nil, errors.New("FUSE request queue full")
+ }
+
+ log.Infof("Blocking request %v from being queued. Too many active requests: %v",
+ r.id, conn.fd.numActiveRequests)
+ conn.fd.mu.Unlock()
+ err := t.Block(conn.fd.fullQueueCh)
+ conn.fd.mu.Lock()
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return conn.callFutureLocked(t, r)
+}
+
+// callFutureLocked makes a request to the server and returns a future response.
+func (conn *Connection) callFutureLocked(t *kernel.Task, r *Request) (*futureResponse, error) {
+ conn.fd.queue.PushBack(r)
+ conn.fd.numActiveRequests += 1
+ fut := newFutureResponse(r.hdr.Opcode)
+ conn.fd.completions[r.id] = fut
+
+ // Signal the readers that there is something to read.
+ conn.fd.waitQueue.Notify(waiter.EventIn)
+
+ return fut, nil
+}
+
+// futureResponse represents an in-flight request, that may or may not have
+// completed yet. Convert it to a resolved Response by calling Resolve, but note
+// that this may block.
+//
+// +stateify savable
+type futureResponse struct {
+ opcode linux.FUSEOpcode
+ ch chan struct{}
+ hdr *linux.FUSEHeaderOut
+ data []byte
+}
+
+// newFutureResponse creates a future response to a FUSE request.
+func newFutureResponse(opcode linux.FUSEOpcode) *futureResponse {
+ return &futureResponse{
+ opcode: opcode,
+ ch: make(chan struct{}),
+ }
+}
+
+// resolve blocks the task until the server responds to its corresponding request,
+// then returns a resolved response.
+func (f *futureResponse) resolve(t *kernel.Task) (*Response, error) {
+ // If there is no Task associated with this request - then we don't try to resolve
+ // the response. Instead, the task writing the response (proxy to the server) will
+ // process the response on our behalf.
+ if t == nil {
+ log.Infof("fuse.Response.resolve: Not waiting on a response from server.")
+ return nil, nil
+ }
+
+ if err := t.Block(f.ch); err != nil {
+ return nil, err
+ }
+
+ return f.getResponse(), nil
+}
+
+// getResponse creates a Response from the data the futureResponse has.
+func (f *futureResponse) getResponse() *Response {
+ return &Response{
+ opcode: f.opcode,
+ hdr: *f.hdr,
+ data: f.data,
+ }
+}
diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go
index c9e12a94f..f3443ac71 100644
--- a/pkg/sentry/fsimpl/fuse/dev.go
+++ b/pkg/sentry/fsimpl/fuse/dev.go
@@ -15,13 +15,17 @@
package fuse
import (
+ "syscall"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
)
const fuseDevMinor = 229
@@ -54,9 +58,43 @@ type DeviceFD struct {
// mounted specifies whether a FUSE filesystem was mounted using the DeviceFD.
mounted bool
- // TODO(gvisor.dev/issue/2987): Add all the data structures needed to enqueue
- // and deque requests, control synchronization and establish communication
- // between the FUSE kernel module and the /dev/fuse character device.
+ // nextOpID is used to create new requests.
+ nextOpID linux.FUSEOpID
+
+ // queue is the list of requests that need to be processed by the FUSE server.
+ queue requestList
+
+ // numActiveRequests is the number of requests made by the Sentry that has
+ // yet to be responded to.
+ numActiveRequests uint64
+
+ // completions is used to map a request to its response. A Writer will use this
+ // to notify the caller of a completed response.
+ completions map[linux.FUSEOpID]*futureResponse
+
+ writeCursor uint32
+
+ // writeBuf is the memory buffer used to copy in the FUSE out header from
+ // userspace.
+ writeBuf []byte
+
+ // writeCursorFR current FR being copied from server.
+ writeCursorFR *futureResponse
+
+ // mu protects all the queues, maps, buffers and cursors and nextOpID.
+ mu sync.Mutex
+
+ // waitQueue is used to notify interested parties when the device becomes
+ // readable or writable.
+ waitQueue waiter.Queue
+
+ // fullQueueCh is a channel used to synchronize the readers with the writers.
+ // Writers (inbound requests to the filesystem) block if there are too many
+ // unprocessed in-flight requests.
+ fullQueueCh chan struct{}
+
+ // fs is the FUSE filesystem that this FD is being used for.
+ fs *filesystem
}
// Release implements vfs.FileDescriptionImpl.Release.
@@ -79,7 +117,75 @@ func (fd *DeviceFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.R
return 0, syserror.EPERM
}
- return 0, syserror.ENOSYS
+ // We require that any Read done on this filesystem have a sane minimum
+ // read buffer. It must have the capacity for the fixed parts of any request
+ // header (Linux uses the request header and the FUSEWriteIn header for this
+ // calculation) + the negotiated MaxWrite room for the data.
+ minBuffSize := linux.FUSE_MIN_READ_BUFFER
+ inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
+ writeHdrLen := uint32((*linux.FUSEWriteIn)(nil).SizeBytes())
+ negotiatedMinBuffSize := inHdrLen + writeHdrLen + fd.fs.conn.MaxWrite
+ if minBuffSize < negotiatedMinBuffSize {
+ minBuffSize = negotiatedMinBuffSize
+ }
+
+ // If the read buffer is too small, error out.
+ if dst.NumBytes() < int64(minBuffSize) {
+ return 0, syserror.EINVAL
+ }
+
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.readLocked(ctx, dst, opts)
+}
+
+// readLocked implements the reading of the fuse device while locked with DeviceFD.mu.
+func (fd *DeviceFD) readLocked(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ if fd.queue.Empty() {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ var readCursor uint32
+ var bytesRead int64
+ for {
+ req := fd.queue.Front()
+ if dst.NumBytes() < int64(req.hdr.Len) {
+ // The request is too large. Cannot process it. All requests must be smaller than the
+ // negotiated size as specified by Connection.MaxWrite set as part of the FUSE_INIT
+ // handshake.
+ errno := -int32(syscall.EIO)
+ if req.hdr.Opcode == linux.FUSE_SETXATTR {
+ errno = -int32(syscall.E2BIG)
+ }
+
+ // Return the error to the calling task.
+ if err := fd.sendError(ctx, errno, req); err != nil {
+ return 0, err
+ }
+
+ // We're done with this request.
+ fd.queue.Remove(req)
+
+ // Restart the read as this request was invalid.
+ log.Warningf("fuse.DeviceFD.Read: request found was too large. Restarting read.")
+ return fd.readLocked(ctx, dst, opts)
+ }
+
+ n, err := dst.CopyOut(ctx, req.data[readCursor:])
+ if err != nil {
+ return 0, err
+ }
+ readCursor += uint32(n)
+ bytesRead += int64(n)
+
+ if readCursor >= req.hdr.Len {
+ // Fully done with this req, remove it from the queue.
+ fd.queue.Remove(req)
+ break
+ }
+ }
+
+ return bytesRead, nil
}
// PWrite implements vfs.FileDescriptionImpl.PWrite.
@@ -94,12 +200,128 @@ func (fd *DeviceFD) PWrite(ctx context.Context, src usermem.IOSequence, offset i
// Write implements vfs.FileDescriptionImpl.Write.
func (fd *DeviceFD) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ fd.mu.Lock()
+ defer fd.mu.Unlock()
+ return fd.writeLocked(ctx, src, opts)
+}
+
+// writeLocked implements writing to the fuse device while locked with DeviceFD.mu.
+func (fd *DeviceFD) writeLocked(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
// Operations on /dev/fuse don't make sense until a FUSE filesystem is mounted.
if !fd.mounted {
return 0, syserror.EPERM
}
- return 0, syserror.ENOSYS
+ var cn, n int64
+ hdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+
+ for src.NumBytes() > 0 {
+ if fd.writeCursorFR != nil {
+ // Already have common header, and we're now copying the payload.
+ wantBytes := fd.writeCursorFR.hdr.Len
+
+ // Note that the FR data doesn't have the header. Copy it over if its necessary.
+ if fd.writeCursorFR.data == nil {
+ fd.writeCursorFR.data = make([]byte, wantBytes)
+ }
+
+ bytesCopied, err := src.CopyIn(ctx, fd.writeCursorFR.data[fd.writeCursor:wantBytes])
+ if err != nil {
+ return 0, err
+ }
+ src = src.DropFirst(bytesCopied)
+
+ cn = int64(bytesCopied)
+ n += cn
+ fd.writeCursor += uint32(cn)
+ if fd.writeCursor == wantBytes {
+ // Done reading this full response. Clean up and unblock the
+ // initiator.
+ break
+ }
+
+ // Check if we have more data in src.
+ continue
+ }
+
+ // Assert that the header isn't read into the writeBuf yet.
+ if fd.writeCursor >= hdrLen {
+ return 0, syserror.EINVAL
+ }
+
+ // We don't have the full common response header yet.
+ wantBytes := hdrLen - fd.writeCursor
+ bytesCopied, err := src.CopyIn(ctx, fd.writeBuf[fd.writeCursor:wantBytes])
+ if err != nil {
+ return 0, err
+ }
+ src = src.DropFirst(bytesCopied)
+
+ cn = int64(bytesCopied)
+ n += cn
+ fd.writeCursor += uint32(cn)
+ if fd.writeCursor == hdrLen {
+ // Have full header in the writeBuf. Use it to fetch the actual futureResponse
+ // from the device's completions map.
+ var hdr linux.FUSEHeaderOut
+ hdr.UnmarshalBytes(fd.writeBuf)
+
+ // We have the header now and so the writeBuf has served its purpose.
+ // We could reset it manually here but instead of doing that, at the
+ // end of the write, the writeCursor will be set to 0 thereby allowing
+ // the next request to overwrite whats in the buffer,
+
+ fut, ok := fd.completions[hdr.Unique]
+ if !ok {
+ // Server sent us a response for a request we never sent?
+ return 0, syserror.EINVAL
+ }
+
+ delete(fd.completions, hdr.Unique)
+
+ // Copy over the header into the future response. The rest of the payload
+ // will be copied over to the FR's data in the next iteration.
+ fut.hdr = &hdr
+ fd.writeCursorFR = fut
+
+ // Next iteration will now try read the complete request, if src has
+ // any data remaining. Otherwise we're done.
+ }
+ }
+
+ if fd.writeCursorFR != nil {
+ if err := fd.sendResponse(ctx, fd.writeCursorFR); err != nil {
+ return 0, err
+ }
+
+ // Ready the device for the next request.
+ fd.writeCursorFR = nil
+ fd.writeCursor = 0
+ }
+
+ return n, nil
+}
+
+// Readiness implements vfs.FileDescriptionImpl.Readiness.
+func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask {
+ var ready waiter.EventMask
+ ready |= waiter.EventOut // FD is always writable
+ if !fd.queue.Empty() {
+ // Have reqs available, FD is readable.
+ ready |= waiter.EventIn
+ }
+
+ return ready & mask
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (fd *DeviceFD) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ fd.waitQueue.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (fd *DeviceFD) EventUnregister(e *waiter.Entry) {
+ fd.waitQueue.EventUnregister(e)
}
// Seek implements vfs.FileDescriptionImpl.Seek.
@@ -112,22 +334,61 @@ func (fd *DeviceFD) Seek(ctx context.Context, offset int64, whence int32) (int64
return 0, syserror.ENOSYS
}
-// Register registers the FUSE device with vfsObj.
-func Register(vfsObj *vfs.VirtualFilesystem) error {
- if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{
- GroupName: "misc",
- }); err != nil {
+// sendResponse sends a response to the waiting task (if any).
+func (fd *DeviceFD) sendResponse(ctx context.Context, fut *futureResponse) error {
+ // See if the running task need to perform some action before returning.
+ // Since we just finished writing the future, we can be sure that
+ // getResponse generates a populated response.
+ if err := fd.noReceiverAction(ctx, fut.getResponse()); err != nil {
return err
}
+ // Signal that the queue is no longer full.
+ select {
+ case fd.fullQueueCh <- struct{}{}:
+ default:
+ }
+ fd.numActiveRequests -= 1
+
+ // Signal the task waiting on a response.
+ close(fut.ch)
return nil
}
-// CreateDevtmpfsFile creates a device special file in devtmpfs.
-func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error {
- if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil {
+// sendError sends an error response to the waiting task (if any).
+func (fd *DeviceFD) sendError(ctx context.Context, errno int32, req *Request) error {
+ // Return the error to the calling task.
+ outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ respHdr := linux.FUSEHeaderOut{
+ Len: outHdrLen,
+ Error: errno,
+ Unique: req.hdr.Unique,
+ }
+
+ fut, ok := fd.completions[respHdr.Unique]
+ if !ok {
+ // Server sent us a response for a request we never sent?
+ return syserror.EINVAL
+ }
+ delete(fd.completions, respHdr.Unique)
+
+ fut.hdr = &respHdr
+ if err := fd.sendResponse(ctx, fut); err != nil {
return err
}
return nil
}
+
+// noReceiverAction has the calling kernel.Task do some action if its known that no
+// receiver is going to be waiting on the future channel. This is to be used by:
+// FUSE_INIT.
+func (fd *DeviceFD) noReceiverAction(ctx context.Context, r *Response) error {
+ if r.opcode == linux.FUSE_INIT {
+ // TODO: process init response here.
+ // Maybe get the creds from the context?
+ // creds := auth.CredentialsFromContext(ctx)
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
new file mode 100644
index 000000000..fcd77832a
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -0,0 +1,429 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "fmt"
+ "io"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+)
+
+// echoTestOpcode is the Opcode used during testing. The server used in tests
+// will simply echo the payload back with the appropriate headers.
+const echoTestOpcode linux.FUSEOpcode = 1000
+
+type testPayload struct {
+ data uint32
+}
+
+// TestFUSECommunication tests that the communication layer between the Sentry and the
+// FUSE server daemon works as expected.
+func TestFUSECommunication(t *testing.T) {
+ s := setup(t)
+ defer s.Destroy()
+
+ k := kernel.KernelFromContext(s.Ctx)
+ creds := auth.CredentialsFromContext(s.Ctx)
+
+ // Create test cases with different number of concurrent clients and servers.
+ testCases := []struct {
+ Name string
+ NumClients int
+ NumServers int
+ MaxActiveRequests uint64
+ }{
+ {
+ Name: "SingleClientSingleServer",
+ NumClients: 1,
+ NumServers: 1,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "SingleClientMultipleServers",
+ NumClients: 1,
+ NumServers: 10,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "MultipleClientsSingleServer",
+ NumClients: 10,
+ NumServers: 1,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "MultipleClientsMultipleServers",
+ NumClients: 10,
+ NumServers: 10,
+ MaxActiveRequests: MaxActiveRequestsDefault,
+ },
+ {
+ Name: "RequestCapacityFull",
+ NumClients: 10,
+ NumServers: 1,
+ MaxActiveRequests: 1,
+ },
+ {
+ Name: "RequestCapacityContinuouslyFull",
+ NumClients: 100,
+ NumServers: 2,
+ MaxActiveRequests: 2,
+ },
+ }
+
+ for _, testCase := range testCases {
+ t.Run(testCase.Name, func(t *testing.T) {
+ conn, fd, err := newTestConnection(s, k, testCase.MaxActiveRequests)
+ if err != nil {
+ t.Fatalf("newTestConnection: %v", err)
+ }
+
+ clientsDone := make([]chan struct{}, testCase.NumClients)
+ serversDone := make([]chan struct{}, testCase.NumServers)
+ serversKill := make([]chan struct{}, testCase.NumServers)
+
+ // FUSE clients.
+ for i := 0; i < testCase.NumClients; i++ {
+ clientsDone[i] = make(chan struct{})
+ go func(i int) {
+ fuseClientRun(t, s, k, conn, creds, uint32(i), uint64(i), clientsDone[i])
+ }(i)
+ }
+
+ // FUSE servers.
+ for j := 0; j < testCase.NumServers; j++ {
+ serversDone[j] = make(chan struct{})
+ serversKill[j] = make(chan struct{}, 1) // The kill command shouldn't block.
+ go func(j int) {
+ fuseServerRun(t, s, k, fd, serversDone[j], serversKill[j])
+ }(j)
+ }
+
+ // Tear down.
+ //
+ // Make sure all the clients are done.
+ for i := 0; i < testCase.NumClients; i++ {
+ <-clientsDone[i]
+ }
+
+ // Kill any server that is potentially waiting.
+ for j := 0; j < testCase.NumServers; j++ {
+ serversKill[j] <- struct{}{}
+ }
+
+ // Make sure all the servers are done.
+ for j := 0; j < testCase.NumServers; j++ {
+ <-serversDone[j]
+ }
+ })
+ }
+}
+
+// CallTest makes a request to the server and blocks the invoking
+// goroutine until a server responds with a response. Doesn't block
+// a kernel.Task. Analogous to Connection.Call but used for testing.
+func CallTest(conn *Connection, t *kernel.Task, r *Request, i uint32) (*Response, error) {
+ conn.fd.mu.Lock()
+
+ // Wait until we're certain that a new request can be processed.
+ for conn.fd.numActiveRequests == conn.fd.fs.opts.maxActiveRequests {
+ conn.fd.mu.Unlock()
+ select {
+ case <-conn.fd.fullQueueCh:
+ }
+ conn.fd.mu.Lock()
+ }
+
+ fut, err := conn.callFutureLocked(t, r) // No task given.
+ conn.fd.mu.Unlock()
+
+ if err != nil {
+ return nil, err
+ }
+
+ // Resolve the response.
+ //
+ // Block without a task.
+ select {
+ case <-fut.ch:
+ }
+
+ // A response is ready. Resolve and return it.
+ return fut.getResponse(), nil
+}
+
+// ReadTest is analogous to vfs.FileDescription.Read and reads from the FUSE
+// device. However, it does so by - not blocking the task that is calling - and
+// instead just waits on a channel. The behaviour is essentially the same as
+// DeviceFD.Read except it guarantees that the task is not blocked.
+func ReadTest(serverTask *kernel.Task, fd *vfs.FileDescription, inIOseq usermem.IOSequence, killServer chan struct{}) (int64, bool, error) {
+ var err error
+ var n, total int64
+
+ dev := fd.Impl().(*DeviceFD)
+
+ // Register for notifications.
+ w, ch := waiter.NewChannelEntry(nil)
+ dev.EventRegister(&w, waiter.EventIn)
+ for {
+ // Issue the request and break out if it completes with anything other than
+ // "would block".
+ n, err = dev.Read(serverTask, inIOseq, vfs.ReadOptions{})
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ // Wait for a notification that we should retry.
+ // Emulate the blocking for when no requests are available
+ select {
+ case <-ch:
+ case <-killServer:
+ // Server killed by the main program.
+ return 0, true, nil
+ }
+ }
+
+ dev.EventUnregister(&w)
+ return total, false, err
+}
+
+// fuseClientRun emulates all the actions of a normal FUSE request. It creates
+// a header, a payload, calls the server, waits for the response, and processes
+// the response.
+func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *Connection, creds *auth.Credentials, pid uint32, inode uint64, clientDone chan struct{}) {
+ defer func() { clientDone <- struct{}{} }()
+
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ clientTask, err := testutil.CreateTask(s.Ctx, fmt.Sprintf("fuse-client-%v", pid), tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatal(err)
+ }
+ testObj := &testPayload{
+ data: rand.Uint32(),
+ }
+
+ req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
+ if err != nil {
+ t.Fatalf("NewRequest creation failed: %v", err)
+ }
+
+ // Queue up a request.
+ // Analogous to Call except it doesn't block on the task.
+ resp, err := CallTest(conn, clientTask, req, pid)
+ if err != nil {
+ t.Fatalf("CallTaskNonBlock failed: %v", err)
+ }
+
+ if err = resp.Error(); err != nil {
+ t.Fatalf("Server responded with an error: %v", err)
+ }
+
+ var respTestPayload testPayload
+ if err := resp.UnmarshalPayload(&respTestPayload); err != nil {
+ t.Fatalf("Unmarshalling payload error: %v", err)
+ }
+
+ if resp.hdr.Unique != req.hdr.Unique {
+ t.Fatalf("got response for another request. Expected response for req %v but got response for req %v",
+ req.hdr.Unique, resp.hdr.Unique)
+ }
+
+ if respTestPayload.data != testObj.data {
+ t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data)
+ }
+
+}
+
+// fuseServerRun creates a task and emulates all the actions of a simple FUSE server
+// that simply reads a request and echos the same struct back as a response using the
+// appropriate headers.
+func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.FileDescription, serverDone, killServer chan struct{}) {
+ defer func() { serverDone <- struct{}{} }()
+
+ // Create the tasks that the server will be using.
+ tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
+ var readPayload testPayload
+
+ serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Read the request.
+ for {
+ inHdrLen := uint32((*linux.FUSEHeaderIn)(nil).SizeBytes())
+ payloadLen := uint32(readPayload.SizeBytes())
+
+ // The raed buffer must meet some certain size criteria.
+ buffSize := inHdrLen + payloadLen
+ if buffSize < linux.FUSE_MIN_READ_BUFFER {
+ buffSize = linux.FUSE_MIN_READ_BUFFER
+ }
+ inBuf := make([]byte, buffSize)
+ inIOseq := usermem.BytesIOSequence(inBuf)
+
+ n, serverKilled, err := ReadTest(serverTask, fd, inIOseq, killServer)
+ if err != nil {
+ t.Fatalf("Read failed :%v", err)
+ }
+
+ // Server should shut down. No new requests are going to be made.
+ if serverKilled {
+ break
+ }
+
+ if n <= 0 {
+ t.Fatalf("Read read no bytes")
+ }
+
+ var readFUSEHeaderIn linux.FUSEHeaderIn
+ readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen])
+ readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen])
+
+ if readFUSEHeaderIn.Opcode != echoTestOpcode {
+ t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload)
+ }
+
+ // Write the response.
+ outHdrLen := uint32((*linux.FUSEHeaderOut)(nil).SizeBytes())
+ outBuf := make([]byte, outHdrLen+payloadLen)
+ outHeader := linux.FUSEHeaderOut{
+ Len: outHdrLen + payloadLen,
+ Error: 0,
+ Unique: readFUSEHeaderIn.Unique,
+ }
+
+ // Echo the payload back.
+ outHeader.MarshalUnsafe(outBuf[:outHdrLen])
+ readPayload.MarshalUnsafe(outBuf[outHdrLen:])
+ outIOseq := usermem.BytesIOSequence(outBuf)
+
+ n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed :%v", err)
+ }
+ }
+}
+
+func setup(t *testing.T) *testutil.System {
+ k, err := testutil.Boot()
+ if err != nil {
+ t.Fatalf("Error creating kernel: %v", err)
+ }
+
+ ctx := k.SupervisorContext()
+ creds := auth.CredentialsFromContext(ctx)
+
+ k.VFS().MustRegisterFilesystemType(Name, &FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserList: true,
+ AllowUserMount: true,
+ })
+
+ mntns, err := k.VFS().NewMountNamespace(ctx, creds, "", tmpfs.Name, &vfs.GetFilesystemOptions{})
+ if err != nil {
+ t.Fatalf("NewMountNamespace(): %v", err)
+ }
+
+ return testutil.NewSystem(ctx, t, k.VFS(), mntns)
+}
+
+// newTestConnection creates a fuse connection that the sentry can communicate with
+// and the FD for the server to communicate with.
+func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveRequests uint64) (*Connection, *vfs.FileDescription, error) {
+ vfsObj := &vfs.VirtualFilesystem{}
+ fuseDev := &DeviceFD{}
+
+ if err := vfsObj.Init(); err != nil {
+ return nil, nil, err
+ }
+
+ vd := vfsObj.NewAnonVirtualDentry("genCountFD")
+ defer vd.DecRef()
+ if err := fuseDev.vfsfd.Init(fuseDev, linux.O_RDWR|linux.O_CREAT, vd.Mount(), vd.Dentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, nil, err
+ }
+
+ fsopts := filesystemOptions{
+ maxActiveRequests: maxActiveRequests,
+ }
+ fs, err := NewFUSEFilesystem(system.Ctx, 0, &fsopts, &fuseDev.vfsfd)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return fs.conn, &fuseDev.vfsfd, nil
+}
+
+// SizeBytes implements marshal.Marshallable.SizeBytes.
+func (t *testPayload) SizeBytes() int {
+ return 4
+}
+
+// MarshalBytes implements marshal.Marshallable.MarshalBytes.
+func (t *testPayload) MarshalBytes(dst []byte) {
+ usermem.ByteOrder.PutUint32(dst[:4], t.data)
+}
+
+// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
+func (t *testPayload) UnmarshalBytes(src []byte) {
+ *t = testPayload{data: usermem.ByteOrder.Uint32(src[:4])}
+}
+
+// Packed implements marshal.Marshallable.Packed.
+func (t *testPayload) Packed() bool {
+ return true
+}
+
+// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
+func (t *testPayload) MarshalUnsafe(dst []byte) {
+ t.MarshalBytes(dst)
+}
+
+// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
+func (t *testPayload) UnmarshalUnsafe(src []byte) {
+ t.UnmarshalBytes(src)
+}
+
+// CopyOutN implements marshal.Marshallable.CopyOutN.
+func (t *testPayload) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
+ panic("not implemented")
+}
+
+// CopyOut implements marshal.Marshallable.CopyOut.
+func (t *testPayload) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
+ panic("not implemented")
+}
+
+// CopyIn implements marshal.Marshallable.CopyIn.
+func (t *testPayload) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
+ panic("not implemented")
+}
+
+// WriteTo implements io.WriterTo.WriteTo.
+func (t *testPayload) WriteTo(w io.Writer) (int64, error) {
+ panic("not implemented")
+}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index f7775fb9b..911b6f7cb 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -51,6 +51,11 @@ type filesystemOptions struct {
// rootMode specifies the the file mode of the filesystem's root.
rootMode linux.FileMode
+
+ // maxActiveRequests specifies the maximum number of active requests that can
+ // exist at any time. Any further requests will block when trying to
+ // Call the server.
+ maxActiveRequests uint64
}
// filesystem implements vfs.FilesystemImpl.
@@ -58,12 +63,12 @@ type filesystem struct {
kernfs.Filesystem
devMinor uint32
- // fuseFD is the FD returned when opening /dev/fuse. It is used for communication
- // between the FUSE server daemon and the sentry fusefs.
- fuseFD *DeviceFD
+ // conn is used for communication between the FUSE server
+ // daemon and the sentry fusefs.
+ conn *Connection
// opts is the options the fusefs is initialized with.
- opts filesystemOptions
+ opts *filesystemOptions
}
// Name implements vfs.FilesystemType.Name.
@@ -100,7 +105,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
fuseFd := kernelTask.GetFileVFS2(int32(deviceDescriptor))
// Parse and set all the other supported FUSE mount options.
- // TODO: Expand the supported mount options.
+ // TODO(gVisor.dev/issue/3229): Expand the supported mount options.
if userIDStr, ok := mopts["user_id"]; ok {
delete(mopts, "user_id")
userID, err := strconv.ParseUint(userIDStr, 10, 32)
@@ -134,21 +139,20 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
fsopts.rootMode = rootMode
+ // Set the maxInFlightRequests option.
+ fsopts.maxActiveRequests = MaxActiveRequestsDefault
+
// Check for unparsed options.
if len(mopts) != 0 {
log.Warningf("%s.GetFilesystem: unknown options: %v", fsType.Name(), mopts)
return nil, nil, syserror.EINVAL
}
- // Mark the device as ready so it can be used. /dev/fuse can only be used if the FD was used to
- // mount a FUSE filesystem.
- fuseFD := fuseFd.Impl().(*DeviceFD)
- fuseFD.mounted = true
-
- fs := &filesystem{
- devMinor: devMinor,
- fuseFD: fuseFD,
- opts: fsopts,
+ // Create a new FUSE filesystem.
+ fs, err := NewFUSEFilesystem(ctx, devMinor, &fsopts, fuseFd)
+ if err != nil {
+ log.Warningf("%s.NewFUSEFilesystem: failed with error: %v", fsType.Name(), err)
+ return nil, nil, err
}
fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
@@ -162,6 +166,26 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return fs.VFSFilesystem(), root.VFSDentry(), nil
}
+// NewFUSEFilesystem creates a new FUSE filesystem.
+func NewFUSEFilesystem(ctx context.Context, devMinor uint32, opts *filesystemOptions, device *vfs.FileDescription) (*filesystem, error) {
+ fs := &filesystem{
+ devMinor: devMinor,
+ opts: opts,
+ }
+
+ conn, err := NewFUSEConnection(ctx, device, opts.maxActiveRequests)
+ if err != nil {
+ log.Warningf("fuse.NewFUSEFilesystem: NewFUSEConnection failed with error: %v", err)
+ return nil, syserror.EINVAL
+ }
+
+ fs.conn = conn
+ fuseFD := device.Impl().(*DeviceFD)
+ fuseFD.fs = fs
+
+ return fs, nil
+}
+
// Release implements vfs.FilesystemImpl.Release.
func (fs *filesystem) Release() {
fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
diff --git a/pkg/sentry/fsimpl/fuse/register.go b/pkg/sentry/fsimpl/fuse/register.go
new file mode 100644
index 000000000..b5b581152
--- /dev/null
+++ b/pkg/sentry/fsimpl/fuse/register.go
@@ -0,0 +1,42 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package fuse
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+)
+
+// Register registers the FUSE device with vfsObj.
+func Register(vfsObj *vfs.VirtualFilesystem) error {
+ if err := vfsObj.RegisterDevice(vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, fuseDevice{}, &vfs.RegisterDeviceOptions{
+ GroupName: "misc",
+ }); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// CreateDevtmpfsFile creates a device special file in devtmpfs.
+func CreateDevtmpfsFile(ctx context.Context, dev *devtmpfs.Accessor) error {
+ if err := dev.CreateDeviceFile(ctx, "fuse", vfs.CharDevice, linux.MISC_MAJOR, fuseDevMinor, 0666 /* mode */); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 02317a133..09f142cfc 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -29,7 +29,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserror"
@@ -221,12 +220,12 @@ func (fd *regularFileFD) pwriteLocked(ctx context.Context, src usermem.IOSequenc
return 0, syserror.EINVAL
}
mr := memmap.MappableRange{pgstart, pgend}
- var freed []platform.FileRange
+ var freed []memmap.FileRange
d.dataMu.Lock()
cseg := d.cache.LowerBoundSegment(mr.Start)
for cseg.Ok() && cseg.Start() < mr.End {
cseg = d.cache.Isolate(cseg, mr)
- freed = append(freed, platform.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
+ freed = append(freed, memmap.FileRange{cseg.Value(), cseg.Value() + cseg.Range().Length()})
cseg = d.cache.Remove(cseg).NextSegment()
}
d.dataMu.Unlock()
@@ -821,7 +820,7 @@ func maxFillRange(required, optional memmap.MappableRange) memmap.MappableRange
// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
func (d *dentry) InvalidateUnsavable(ctx context.Context) error {
- // Whether we have a host fd (and consequently what platform.File is
+ // Whether we have a host fd (and consequently what memmap.File is
// mapped) can change across save/restore, so invalidate all translations
// unconditionally.
d.mapsMu.Lock()
@@ -869,8 +868,8 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
}
}
-// dentryPlatformFile implements platform.File. It exists solely because dentry
-// cannot implement both vfs.DentryImpl.IncRef and platform.File.IncRef.
+// dentryPlatformFile implements memmap.File. It exists solely because dentry
+// cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef.
//
// dentryPlatformFile is only used when a host FD representing the remote file
// is available (i.e. dentry.handle.fd >= 0), and that FD is used for
@@ -878,7 +877,7 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) {
type dentryPlatformFile struct {
*dentry
- // fdRefs counts references on platform.File offsets. fdRefs is protected
+ // fdRefs counts references on memmap.File offsets. fdRefs is protected
// by dentry.dataMu.
fdRefs fsutil.FrameRefSet
@@ -890,29 +889,29 @@ type dentryPlatformFile struct {
hostFileMapperInitOnce sync.Once
}
-// IncRef implements platform.File.IncRef.
-func (d *dentryPlatformFile) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (d *dentryPlatformFile) IncRef(fr memmap.FileRange) {
d.dataMu.Lock()
d.fdRefs.IncRefAndAccount(fr)
d.dataMu.Unlock()
}
-// DecRef implements platform.File.DecRef.
-func (d *dentryPlatformFile) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) {
d.dataMu.Lock()
d.fdRefs.DecRefAndAccount(fr)
d.dataMu.Unlock()
}
-// MapInternal implements platform.File.MapInternal.
-func (d *dentryPlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
d.handleMu.RLock()
bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write)
d.handleMu.RUnlock()
return bs, err
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (d *dentryPlatformFile) FD() int {
d.handleMu.RLock()
fd := d.handle.fd
diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD
index e86fbe2d5..bd701bbc7 100644
--- a/pkg/sentry/fsimpl/host/BUILD
+++ b/pkg/sentry/fsimpl/host/BUILD
@@ -34,7 +34,6 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/memmap",
- "//pkg/sentry/platform",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/host/mmap.go
index 8545a82f0..65d3af38c 100644
--- a/pkg/sentry/fsimpl/host/mmap.go
+++ b/pkg/sentry/fsimpl/host/mmap.go
@@ -19,13 +19,12 @@ import (
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
)
-// inodePlatformFile implements platform.File. It exists solely because inode
-// cannot implement both kernfs.Inode.IncRef and platform.File.IncRef.
+// inodePlatformFile implements memmap.File. It exists solely because inode
+// cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef.
//
// inodePlatformFile should only be used if inode.canMap is true.
type inodePlatformFile struct {
@@ -34,7 +33,7 @@ type inodePlatformFile struct {
// fdRefsMu protects fdRefs.
fdRefsMu sync.Mutex
- // fdRefs counts references on platform.File offsets. It is used solely for
+ // fdRefs counts references on memmap.File offsets. It is used solely for
// memory accounting.
fdRefs fsutil.FrameRefSet
@@ -45,32 +44,32 @@ type inodePlatformFile struct {
fileMapperInitOnce sync.Once
}
-// IncRef implements platform.File.IncRef.
+// IncRef implements memmap.File.IncRef.
//
// Precondition: i.inode.canMap must be true.
-func (i *inodePlatformFile) IncRef(fr platform.FileRange) {
+func (i *inodePlatformFile) IncRef(fr memmap.FileRange) {
i.fdRefsMu.Lock()
i.fdRefs.IncRefAndAccount(fr)
i.fdRefsMu.Unlock()
}
-// DecRef implements platform.File.DecRef.
+// DecRef implements memmap.File.DecRef.
//
// Precondition: i.inode.canMap must be true.
-func (i *inodePlatformFile) DecRef(fr platform.FileRange) {
+func (i *inodePlatformFile) DecRef(fr memmap.FileRange) {
i.fdRefsMu.Lock()
i.fdRefs.DecRefAndAccount(fr)
i.fdRefsMu.Unlock()
}
-// MapInternal implements platform.File.MapInternal.
+// MapInternal implements memmap.File.MapInternal.
//
// Precondition: i.inode.canMap must be true.
-func (i *inodePlatformFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
return i.fileMapper.MapInternal(fr, i.hostFD, at.Write)
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (i *inodePlatformFile) FD() int {
return i.hostFD
}
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index bfd779837..c211fc8d0 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -20,7 +20,6 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/memmap",
"//pkg/sentry/pgalloc",
- "//pkg/sentry/platform",
"//pkg/sentry/usage",
"//pkg/sync",
"//pkg/syserror",
diff --git a/pkg/sentry/kernel/shm/shm.go b/pkg/sentry/kernel/shm/shm.go
index f66cfcc7f..55b4c2cdb 100644
--- a/pkg/sentry/kernel/shm/shm.go
+++ b/pkg/sentry/kernel/shm/shm.go
@@ -45,7 +45,6 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -370,7 +369,7 @@ type Shm struct {
// fr is the offset into mfp.MemoryFile() that backs this contents of this
// segment. Immutable.
- fr platform.FileRange
+ fr memmap.FileRange
// mu protects all fields below.
mu sync.Mutex `state:"nosave"`
diff --git a/pkg/sentry/kernel/timekeeper.go b/pkg/sentry/kernel/timekeeper.go
index 5f3908d8b..7c4fefb16 100644
--- a/pkg/sentry/kernel/timekeeper.go
+++ b/pkg/sentry/kernel/timekeeper.go
@@ -21,8 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/log"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
sentrytime "gvisor.dev/gvisor/pkg/sentry/time"
"gvisor.dev/gvisor/pkg/sync"
)
@@ -90,7 +90,7 @@ type Timekeeper struct {
// NewTimekeeper does not take ownership of paramPage.
//
// SetClocks must be called on the returned Timekeeper before it is usable.
-func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage platform.FileRange) (*Timekeeper, error) {
+func NewTimekeeper(mfp pgalloc.MemoryFileProvider, paramPage memmap.FileRange) (*Timekeeper, error) {
return &Timekeeper{
params: NewVDSOParamPage(mfp, paramPage),
}, nil
diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go
index f1b3c212c..290c32466 100644
--- a/pkg/sentry/kernel/vdso.go
+++ b/pkg/sentry/kernel/vdso.go
@@ -19,8 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -58,7 +58,7 @@ type vdsoParams struct {
type VDSOParamPage struct {
// The parameter page is fr, allocated from mfp.MemoryFile().
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
// seq is the current sequence count written to the page.
//
@@ -81,7 +81,7 @@ type VDSOParamPage struct {
// * VDSOParamPage must be the only writer to fr.
//
// * mfp.MemoryFile().MapInternal(fr) must return a single safemem.Block.
-func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *VDSOParamPage {
+func NewVDSOParamPage(mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *VDSOParamPage {
return &VDSOParamPage{mfp: mfp, fr: fr}
}
diff --git a/pkg/sentry/memmap/BUILD b/pkg/sentry/memmap/BUILD
index a98b66de1..2c95669cd 100644
--- a/pkg/sentry/memmap/BUILD
+++ b/pkg/sentry/memmap/BUILD
@@ -28,9 +28,21 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "file_range",
+ out = "file_range.go",
+ package = "memmap",
+ prefix = "File",
+ template = "//pkg/segment:generic_range",
+ types = {
+ "T": "uint64",
+ },
+)
+
go_library(
name = "memmap",
srcs = [
+ "file_range.go",
"mappable_range.go",
"mapping_set.go",
"mapping_set_impl.go",
@@ -40,7 +52,7 @@ go_library(
deps = [
"//pkg/context",
"//pkg/log",
- "//pkg/sentry/platform",
+ "//pkg/safemem",
"//pkg/syserror",
"//pkg/usermem",
],
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index c6db9fc8f..c188f6c29 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -19,12 +19,12 @@ import (
"fmt"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/usermem"
)
// Mappable represents a memory-mappable object, a mutable mapping from uint64
-// offsets to (platform.File, uint64 File offset) pairs.
+// offsets to (File, uint64 File offset) pairs.
//
// See mm/mm.go for Mappable's place in the lock order.
//
@@ -74,7 +74,7 @@ type Mappable interface {
// Translations are valid until invalidated by a callback to
// MappingSpace.Invalidate or until the caller removes its mapping of the
// translated range. Mappable implementations must ensure that at least one
- // reference is held on all pages in a platform.File that may be the result
+ // reference is held on all pages in a File that may be the result
// of a valid Translation.
//
// Preconditions: required.Length() > 0. optional.IsSupersetOf(required).
@@ -100,7 +100,7 @@ type Translation struct {
Source MappableRange
// File is the mapped file.
- File platform.File
+ File File
// Offset is the offset into File at which this Translation begins.
Offset uint64
@@ -110,9 +110,9 @@ type Translation struct {
Perms usermem.AccessType
}
-// FileRange returns the platform.FileRange represented by t.
-func (t Translation) FileRange() platform.FileRange {
- return platform.FileRange{t.Offset, t.Offset + t.Source.Length()}
+// FileRange returns the FileRange represented by t.
+func (t Translation) FileRange() FileRange {
+ return FileRange{t.Offset, t.Offset + t.Source.Length()}
}
// CheckTranslateResult returns an error if (ts, terr) does not satisfy all
@@ -361,3 +361,49 @@ type MMapOpts struct {
// TODO(jamieliu): Replace entirely with MappingIdentity?
Hint string
}
+
+// File represents a host file that may be mapped into an platform.AddressSpace.
+type File interface {
+ // All pages in a File are reference-counted.
+
+ // IncRef increments the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr. (The File
+ // interface does not provide a way to acquire an initial reference;
+ // implementors may define mechanisms for doing so.)
+ IncRef(fr FileRange)
+
+ // DecRef decrements the reference count on all pages in fr.
+ //
+ // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
+ // 0. At least one reference must be held on all pages in fr.
+ DecRef(fr FileRange)
+
+ // MapInternal returns a mapping of the given file offsets in the invoking
+ // process' address space for reading and writing.
+ //
+ // Note that fr.Start and fr.End need not be page-aligned.
+ //
+ // Preconditions: fr.Length() > 0. At least one reference must be held on
+ // all pages in fr.
+ //
+ // Postconditions: The returned mapping is valid as long as at least one
+ // reference is held on the mapped pages.
+ MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
+
+ // FD returns the file descriptor represented by the File.
+ //
+ // The only permitted operation on the returned file descriptor is to map
+ // pages from it consistent with the requirements of AddressSpace.MapFile.
+ FD() int
+}
+
+// FileRange represents a range of uint64 offsets into a File.
+//
+// type FileRange <generated using go_generics>
+
+// String implements fmt.Stringer.String.
+func (fr FileRange) String() string {
+ return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
+}
diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD
index a036ce53c..f9d0837a1 100644
--- a/pkg/sentry/mm/BUILD
+++ b/pkg/sentry/mm/BUILD
@@ -7,14 +7,14 @@ go_template_instance(
name = "file_refcount_set",
out = "file_refcount_set.go",
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "mm",
prefix = "fileRefcount",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "int32",
"Functions": "fileRefcountSetFunctions",
},
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 379148903..1999ec706 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -20,7 +20,6 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
@@ -243,7 +242,7 @@ type aioMappable struct {
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
}
var aioRingBufferSize = uint64(usermem.Addr(linux.AIORingSize).MustRoundUp())
diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go
index 6db7c3d40..3e85964e4 100644
--- a/pkg/sentry/mm/mm.go
+++ b/pkg/sentry/mm/mm.go
@@ -25,7 +25,7 @@
// Locks taken by memmap.Mappable.Translate
// mm.privateRefs.mu
// platform.AddressSpace locks
-// platform.File locks
+// memmap.File locks
// mm.aioManager.mu
// mm.AIOContext.mu
//
@@ -396,7 +396,7 @@ type pma struct {
// file is the file mapped by this pma. Only pmas for which file ==
// MemoryManager.mfp.MemoryFile() may be saved. pmas hold a reference to
// the corresponding file range while they exist.
- file platform.File `state:"nosave"`
+ file memmap.File `state:"nosave"`
// off is the offset into file at which this pma begins.
//
@@ -436,7 +436,7 @@ type pma struct {
private bool
// If internalMappings is not empty, it is the cached return value of
- // file.MapInternal for the platform.FileRange mapped by this pma.
+ // file.MapInternal for the memmap.FileRange mapped by this pma.
internalMappings safemem.BlockSeq `state:"nosave"`
}
@@ -469,10 +469,10 @@ func (fileRefcountSetFunctions) MaxKey() uint64 {
func (fileRefcountSetFunctions) ClearValue(_ *int32) {
}
-func (fileRefcountSetFunctions) Merge(_ platform.FileRange, rc1 int32, _ platform.FileRange, rc2 int32) (int32, bool) {
+func (fileRefcountSetFunctions) Merge(_ memmap.FileRange, rc1 int32, _ memmap.FileRange, rc2 int32) (int32, bool) {
return rc1, rc1 == rc2
}
-func (fileRefcountSetFunctions) Split(_ platform.FileRange, rc int32, _ uint64) (int32, int32) {
+func (fileRefcountSetFunctions) Split(_ memmap.FileRange, rc int32, _ uint64) (int32, int32) {
return rc, rc
}
diff --git a/pkg/sentry/mm/pma.go b/pkg/sentry/mm/pma.go
index 62e4c20af..930ec895f 100644
--- a/pkg/sentry/mm/pma.go
+++ b/pkg/sentry/mm/pma.go
@@ -21,7 +21,6 @@ import (
"gvisor.dev/gvisor/pkg/safecopy"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/memmap"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -604,7 +603,7 @@ func (mm *MemoryManager) invalidateLocked(ar usermem.AddrRange, invalidatePrivat
}
}
-// Pin returns the platform.File ranges currently mapped by addresses in ar in
+// Pin returns the memmap.File ranges currently mapped by addresses in ar in
// mm, acquiring a reference on the returned ranges which the caller must
// release by calling Unpin. If not all addresses are mapped, Pin returns a
// non-nil error. Note that Pin may return both a non-empty slice of
@@ -674,15 +673,15 @@ type PinnedRange struct {
Source usermem.AddrRange
// File is the mapped file.
- File platform.File
+ File memmap.File
// Offset is the offset into File at which this PinnedRange begins.
Offset uint64
}
-// FileRange returns the platform.File offsets mapped by pr.
-func (pr PinnedRange) FileRange() platform.FileRange {
- return platform.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}
+// FileRange returns the memmap.File offsets mapped by pr.
+func (pr PinnedRange) FileRange() memmap.FileRange {
+ return memmap.FileRange{pr.Offset, pr.Offset + uint64(pr.Source.Length())}
}
// Unpin releases the reference held by prs.
@@ -857,7 +856,7 @@ func (mm *MemoryManager) vecInternalMappingsLocked(ars usermem.AddrRangeSeq) saf
}
// incPrivateRef acquires a reference on private pages in fr.
-func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) {
+func (mm *MemoryManager) incPrivateRef(fr memmap.FileRange) {
mm.privateRefs.mu.Lock()
defer mm.privateRefs.mu.Unlock()
refSet := &mm.privateRefs.refs
@@ -878,8 +877,8 @@ func (mm *MemoryManager) incPrivateRef(fr platform.FileRange) {
}
// decPrivateRef releases a reference on private pages in fr.
-func (mm *MemoryManager) decPrivateRef(fr platform.FileRange) {
- var freed []platform.FileRange
+func (mm *MemoryManager) decPrivateRef(fr memmap.FileRange) {
+ var freed []memmap.FileRange
mm.privateRefs.mu.Lock()
refSet := &mm.privateRefs.refs
@@ -951,7 +950,7 @@ func (pmaSetFunctions) Merge(ar1 usermem.AddrRange, pma1 pma, ar2 usermem.AddrRa
// Discard internal mappings instead of trying to merge them, since merging
// them requires an allocation and getting them again from the
- // platform.File might not.
+ // memmap.File might not.
pma1.internalMappings = safemem.BlockSeq{}
return pma1, true
}
@@ -1012,12 +1011,12 @@ func (pseg pmaIterator) getInternalMappingsLocked() error {
return nil
}
-func (pseg pmaIterator) fileRange() platform.FileRange {
+func (pseg pmaIterator) fileRange() memmap.FileRange {
return pseg.fileRangeOf(pseg.Range())
}
// Preconditions: pseg.Range().IsSupersetOf(ar). ar.Length != 0.
-func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange {
+func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) memmap.FileRange {
if checkInvariants {
if !pseg.Ok() {
panic("terminal pma iterator")
@@ -1032,5 +1031,5 @@ func (pseg pmaIterator) fileRangeOf(ar usermem.AddrRange) platform.FileRange {
pma := pseg.ValuePtr()
pstart := pseg.Start()
- return platform.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)}
+ return memmap.FileRange{pma.off + uint64(ar.Start-pstart), pma.off + uint64(ar.End-pstart)}
}
diff --git a/pkg/sentry/mm/special_mappable.go b/pkg/sentry/mm/special_mappable.go
index 9ad52082d..0e142fb11 100644
--- a/pkg/sentry/mm/special_mappable.go
+++ b/pkg/sentry/mm/special_mappable.go
@@ -19,7 +19,6 @@ import (
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/pgalloc"
- "gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
@@ -35,7 +34,7 @@ type SpecialMappable struct {
refs.AtomicRefCount
mfp pgalloc.MemoryFileProvider
- fr platform.FileRange
+ fr memmap.FileRange
name string
}
@@ -44,7 +43,7 @@ type SpecialMappable struct {
// SpecialMappable will use the given name in /proc/[pid]/maps.
//
// Preconditions: fr.Length() != 0.
-func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr platform.FileRange) *SpecialMappable {
+func NewSpecialMappable(name string, mfp pgalloc.MemoryFileProvider, fr memmap.FileRange) *SpecialMappable {
m := SpecialMappable{mfp: mfp, fr: fr, name: name}
m.EnableLeakCheck("mm.SpecialMappable")
return &m
@@ -126,7 +125,7 @@ func (m *SpecialMappable) MemoryFileProvider() pgalloc.MemoryFileProvider {
// FileRange returns the offsets into MemoryFileProvider().MemoryFile() that
// store the SpecialMappable's contents.
-func (m *SpecialMappable) FileRange() platform.FileRange {
+func (m *SpecialMappable) FileRange() memmap.FileRange {
return m.fr
}
diff --git a/pkg/sentry/pgalloc/BUILD b/pkg/sentry/pgalloc/BUILD
index e1fcb175f..7a3311a70 100644
--- a/pkg/sentry/pgalloc/BUILD
+++ b/pkg/sentry/pgalloc/BUILD
@@ -36,14 +36,14 @@ go_template_instance(
"trackGaps": "1",
},
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "pgalloc",
prefix = "usage",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "usageInfo",
"Functions": "usageSetFunctions",
},
@@ -56,14 +56,14 @@ go_template_instance(
"minDegree": "10",
},
imports = {
- "platform": "gvisor.dev/gvisor/pkg/sentry/platform",
+ "memmap": "gvisor.dev/gvisor/pkg/sentry/memmap",
},
package = "pgalloc",
prefix = "reclaim",
template = "//pkg/segment:generic_set",
types = {
"Key": "uint64",
- "Range": "platform.FileRange",
+ "Range": "memmap.FileRange",
"Value": "reclaimSetValue",
"Functions": "reclaimSetFunctions",
},
@@ -89,7 +89,7 @@ go_library(
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/hostmm",
- "//pkg/sentry/platform",
+ "//pkg/sentry/memmap",
"//pkg/sentry/usage",
"//pkg/state",
"//pkg/state/wire",
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index afab97c0a..3243d7214 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -33,14 +33,14 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/hostmm"
- "gvisor.dev/gvisor/pkg/sentry/platform"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/usage"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
)
-// MemoryFile is a platform.File whose pages may be allocated to arbitrary
+// MemoryFile is a memmap.File whose pages may be allocated to arbitrary
// users.
type MemoryFile struct {
// opts holds options passed to NewMemoryFile. opts is immutable.
@@ -372,7 +372,7 @@ func (f *MemoryFile) Destroy() {
// to Allocate.
//
// Preconditions: length must be page-aligned and non-zero.
-func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.FileRange, error) {
+func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.FileRange, error) {
if length == 0 || length%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid allocation length: %#x", length))
}
@@ -390,7 +390,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
// Find a range in the underlying file.
fr, ok := findAvailableRange(&f.usage, f.fileSize, length, alignment)
if !ok {
- return platform.FileRange{}, syserror.ENOMEM
+ return memmap.FileRange{}, syserror.ENOMEM
}
// Expand the file if needed.
@@ -398,7 +398,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
// Round the new file size up to be chunk-aligned.
newFileSize := (int64(fr.End) + chunkMask) &^ chunkMask
if err := f.file.Truncate(newFileSize); err != nil {
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
f.fileSize = newFileSize
f.mappingsMu.Lock()
@@ -416,7 +416,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
bs[i] = 0
}
}); err != nil {
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
}
if !f.usage.Add(fr, usageInfo{
@@ -439,7 +439,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (platform.Fi
// space for mappings to be allocated downwards.
//
// Precondition: alignment must be a power of 2.
-func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (platform.FileRange, bool) {
+func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint64) (memmap.FileRange, bool) {
alignmentMask := alignment - 1
// Search for space in existing gaps, starting at the current end of the
@@ -461,7 +461,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6
break
}
if start := unalignedStart &^ alignmentMask; start >= gap.Start() {
- return platform.FileRange{start, start + length}, true
+ return memmap.FileRange{start, start + length}, true
}
gap = gap.PrevLargeEnoughGap(length)
@@ -475,7 +475,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6
min = (min + alignmentMask) &^ alignmentMask
if min+length < min {
// Overflow: allocation would exceed the range of uint64.
- return platform.FileRange{}, false
+ return memmap.FileRange{}, false
}
// Determine the minimum file size required to fit this allocation at its end.
@@ -484,7 +484,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6
if newFileSize <= fileSize {
if fileSize != 0 {
// Overflow: allocation would exceed the range of int64.
- return platform.FileRange{}, false
+ return memmap.FileRange{}, false
}
newFileSize = chunkSize
}
@@ -496,7 +496,7 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6
continue
}
if start := unalignedStart &^ alignmentMask; start >= min {
- return platform.FileRange{start, start + length}, true
+ return memmap.FileRange{start, start + length}, true
}
}
}
@@ -508,22 +508,22 @@ func findAvailableRange(usage *usageSet, fileSize int64, length, alignment uint6
// by r.ReadToBlocks(), it returns that error.
//
// Preconditions: length > 0. length must be page-aligned.
-func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (platform.FileRange, error) {
+func (f *MemoryFile) AllocateAndFill(length uint64, kind usage.MemoryKind, r safemem.Reader) (memmap.FileRange, error) {
fr, err := f.Allocate(length, kind)
if err != nil {
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
dsts, err := f.MapInternal(fr, usermem.Write)
if err != nil {
f.DecRef(fr)
- return platform.FileRange{}, err
+ return memmap.FileRange{}, err
}
n, err := safemem.ReadFullToBlocks(r, dsts)
un := uint64(usermem.Addr(n).RoundDown())
if un < length {
// Free unused memory and update fr to contain only the memory that is
// still allocated.
- f.DecRef(platform.FileRange{fr.Start + un, fr.End})
+ f.DecRef(memmap.FileRange{fr.Start + un, fr.End})
fr.End = fr.Start + un
}
return fr, err
@@ -540,7 +540,7 @@ const (
// will read zeroes.
//
// Preconditions: fr.Length() > 0.
-func (f *MemoryFile) Decommit(fr platform.FileRange) error {
+func (f *MemoryFile) Decommit(fr memmap.FileRange) error {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -560,7 +560,7 @@ func (f *MemoryFile) Decommit(fr platform.FileRange) error {
return nil
}
-func (f *MemoryFile) markDecommitted(fr platform.FileRange) {
+func (f *MemoryFile) markDecommitted(fr memmap.FileRange) {
f.mu.Lock()
defer f.mu.Unlock()
// Since we're changing the knownCommitted attribute, we need to merge
@@ -581,8 +581,8 @@ func (f *MemoryFile) markDecommitted(fr platform.FileRange) {
f.usage.MergeRange(fr)
}
-// IncRef implements platform.File.IncRef.
-func (f *MemoryFile) IncRef(fr platform.FileRange) {
+// IncRef implements memmap.File.IncRef.
+func (f *MemoryFile) IncRef(fr memmap.FileRange) {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -600,8 +600,8 @@ func (f *MemoryFile) IncRef(fr platform.FileRange) {
f.usage.MergeAdjacent(fr)
}
-// DecRef implements platform.File.DecRef.
-func (f *MemoryFile) DecRef(fr platform.FileRange) {
+// DecRef implements memmap.File.DecRef.
+func (f *MemoryFile) DecRef(fr memmap.FileRange) {
if !fr.WellFormed() || fr.Length() == 0 || fr.Start%usermem.PageSize != 0 || fr.End%usermem.PageSize != 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -637,8 +637,8 @@ func (f *MemoryFile) DecRef(fr platform.FileRange) {
}
}
-// MapInternal implements platform.File.MapInternal.
-func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
+// MapInternal implements memmap.File.MapInternal.
+func (f *MemoryFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) {
if !fr.WellFormed() || fr.Length() == 0 {
panic(fmt.Sprintf("invalid range: %v", fr))
}
@@ -664,7 +664,7 @@ func (f *MemoryFile) MapInternal(fr platform.FileRange, at usermem.AccessType) (
// forEachMappingSlice invokes fn on a sequence of byte slices that
// collectively map all bytes in fr.
-func (f *MemoryFile) forEachMappingSlice(fr platform.FileRange, fn func([]byte)) error {
+func (f *MemoryFile) forEachMappingSlice(fr memmap.FileRange, fn func([]byte)) error {
mappings := f.mappings.Load().([]uintptr)
for chunkStart := fr.Start &^ chunkMask; chunkStart < fr.End; chunkStart += chunkSize {
chunk := int(chunkStart >> chunkShift)
@@ -944,7 +944,7 @@ func (f *MemoryFile) updateUsageLocked(currentUsage uint64, checkCommitted func(
continue
case !populated && populatedRun:
// Finish the run by changing this segment.
- runRange := platform.FileRange{
+ runRange := memmap.FileRange{
Start: r.Start + uint64(populatedRunStart*usermem.PageSize),
End: r.Start + uint64(i*usermem.PageSize),
}
@@ -1009,7 +1009,7 @@ func (f *MemoryFile) File() *os.File {
return f.file
}
-// FD implements platform.File.FD.
+// FD implements memmap.File.FD.
func (f *MemoryFile) FD() int {
return int(f.file.Fd())
}
@@ -1090,13 +1090,13 @@ func (f *MemoryFile) runReclaim() {
//
// Note that there returned range will be removed from tracking. It
// must be reclaimed (removed from f.usage) at this point.
-func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
+func (f *MemoryFile) findReclaimable() (memmap.FileRange, bool) {
f.mu.Lock()
defer f.mu.Unlock()
for {
for {
if f.destroyed {
- return platform.FileRange{}, false
+ return memmap.FileRange{}, false
}
if f.reclaimable {
break
@@ -1120,7 +1120,7 @@ func (f *MemoryFile) findReclaimable() (platform.FileRange, bool) {
}
}
-func (f *MemoryFile) markReclaimed(fr platform.FileRange) {
+func (f *MemoryFile) markReclaimed(fr memmap.FileRange) {
f.mu.Lock()
defer f.mu.Unlock()
seg := f.usage.FindSegment(fr.Start)
@@ -1222,11 +1222,11 @@ func (usageSetFunctions) MaxKey() uint64 {
func (usageSetFunctions) ClearValue(val *usageInfo) {
}
-func (usageSetFunctions) Merge(_ platform.FileRange, val1 usageInfo, _ platform.FileRange, val2 usageInfo) (usageInfo, bool) {
+func (usageSetFunctions) Merge(_ memmap.FileRange, val1 usageInfo, _ memmap.FileRange, val2 usageInfo) (usageInfo, bool) {
return val1, val1 == val2
}
-func (usageSetFunctions) Split(_ platform.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) {
+func (usageSetFunctions) Split(_ memmap.FileRange, val usageInfo, _ uint64) (usageInfo, usageInfo) {
return val, val
}
@@ -1270,10 +1270,10 @@ func (reclaimSetFunctions) MaxKey() uint64 {
func (reclaimSetFunctions) ClearValue(val *reclaimSetValue) {
}
-func (reclaimSetFunctions) Merge(_ platform.FileRange, _ reclaimSetValue, _ platform.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) {
+func (reclaimSetFunctions) Merge(_ memmap.FileRange, _ reclaimSetValue, _ memmap.FileRange, _ reclaimSetValue) (reclaimSetValue, bool) {
return reclaimSetValue{}, true
}
-func (reclaimSetFunctions) Split(_ platform.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) {
+func (reclaimSetFunctions) Split(_ memmap.FileRange, _ reclaimSetValue, _ uint64) (reclaimSetValue, reclaimSetValue) {
return reclaimSetValue{}, reclaimSetValue{}
}
diff --git a/pkg/sentry/platform/BUILD b/pkg/sentry/platform/BUILD
index 453241eca..209b28053 100644
--- a/pkg/sentry/platform/BUILD
+++ b/pkg/sentry/platform/BUILD
@@ -1,39 +1,21 @@
load("//tools:defs.bzl", "go_library")
-load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
-go_template_instance(
- name = "file_range",
- out = "file_range.go",
- package = "platform",
- prefix = "File",
- template = "//pkg/segment:generic_range",
- types = {
- "T": "uint64",
- },
-)
-
go_library(
name = "platform",
srcs = [
"context.go",
- "file_range.go",
"mmap_min_addr.go",
"platform.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/atomicbitops",
"//pkg/context",
- "//pkg/log",
- "//pkg/safecopy",
- "//pkg/safemem",
"//pkg/seccomp",
"//pkg/sentry/arch",
- "//pkg/sentry/usage",
- "//pkg/syserror",
+ "//pkg/sentry/memmap",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 10a10bfe2..b5d27a72a 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -47,6 +47,7 @@ go_library(
"//pkg/safecopy",
"//pkg/seccomp",
"//pkg/sentry/arch",
+ "//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sentry/platform/ring0",
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index faf1d5e1c..98a3e539d 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -18,6 +18,7 @@ import (
"sync/atomic"
"gvisor.dev/gvisor/pkg/atomicbitops"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/sync"
@@ -150,7 +151,7 @@ func (as *addressSpace) mapLocked(addr usermem.Addr, m hostMapEntry, at usermem.
}
// MapFile implements platform.AddressSpace.MapFile.
-func (as *addressSpace) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+func (as *addressSpace) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
as.mu.Lock()
defer as.mu.Unlock()
diff --git a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
index 6531bae1d..48ccf8474 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64_unsafe.go
@@ -22,7 +22,8 @@ import (
)
var (
- runDataSize int
+ runDataSize int
+ hasGuestPCID bool
)
func updateSystemValues(fd int) error {
@@ -33,6 +34,7 @@ func updateSystemValues(fd int) error {
}
// Save the data.
runDataSize = int(sz)
+ hasGuestPCID = true
// Success.
return nil
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 3de309c1a..ff8c068c0 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0"
+ "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -156,6 +157,14 @@ func (c *vCPU) initArchState() error {
return err
}
+ // Initialize the PCID database.
+ if hasGuestPCID {
+ // Note that NewPCIDs may return a nil table here, in which
+ // case we simply don't use PCID support (see below). In
+ // practice, this should not happen, however.
+ c.PCIDs = pagetables.NewPCIDs(fixedKernelPCID+1, poolPCIDs)
+ }
+
c.floatingPointState = arch.NewFloatingPointData()
return nil
}
@@ -234,6 +243,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info)
}
+ // Assign PCIDs.
+ if c.PCIDs != nil {
+ var requireFlushPCID bool // Force a flush?
+ switchOpts.UserASID, requireFlushPCID = c.PCIDs.Assign(switchOpts.PageTables)
+ switchOpts.Flush = switchOpts.Flush || requireFlushPCID
+ }
+
var vector ring0.Vector
ttbr0App := switchOpts.PageTables.TTBR0_EL1(false, 0)
c.SetTtbr0App(uintptr(ttbr0App))
diff --git a/pkg/sentry/platform/platform.go b/pkg/sentry/platform/platform.go
index 171513f3f..4b13eec30 100644
--- a/pkg/sentry/platform/platform.go
+++ b/pkg/sentry/platform/platform.go
@@ -22,9 +22,9 @@ import (
"os"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/seccomp"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -207,7 +207,7 @@ type AddressSpace interface {
// Preconditions: addr and fr must be page-aligned. fr.Length() > 0.
// at.Any() == true. At least one reference must be held on all pages in
// fr, and must continue to be held as long as pages are mapped.
- MapFile(addr usermem.Addr, f File, fr FileRange, at usermem.AccessType, precommit bool) error
+ MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error
// Unmap unmaps the given range.
//
@@ -310,52 +310,6 @@ func (f SegmentationFault) Error() string {
return fmt.Sprintf("segmentation fault at %#x", f.Addr)
}
-// File represents a host file that may be mapped into an AddressSpace.
-type File interface {
- // All pages in a File are reference-counted.
-
- // IncRef increments the reference count on all pages in fr.
- //
- // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
- // 0. At least one reference must be held on all pages in fr. (The File
- // interface does not provide a way to acquire an initial reference;
- // implementors may define mechanisms for doing so.)
- IncRef(fr FileRange)
-
- // DecRef decrements the reference count on all pages in fr.
- //
- // Preconditions: fr.Start and fr.End must be page-aligned. fr.Length() >
- // 0. At least one reference must be held on all pages in fr.
- DecRef(fr FileRange)
-
- // MapInternal returns a mapping of the given file offsets in the invoking
- // process' address space for reading and writing.
- //
- // Note that fr.Start and fr.End need not be page-aligned.
- //
- // Preconditions: fr.Length() > 0. At least one reference must be held on
- // all pages in fr.
- //
- // Postconditions: The returned mapping is valid as long as at least one
- // reference is held on the mapped pages.
- MapInternal(fr FileRange, at usermem.AccessType) (safemem.BlockSeq, error)
-
- // FD returns the file descriptor represented by the File.
- //
- // The only permitted operation on the returned file descriptor is to map
- // pages from it consistent with the requirements of AddressSpace.MapFile.
- FD() int
-}
-
-// FileRange represents a range of uint64 offsets into a File.
-//
-// type FileRange <generated using go_generics>
-
-// String implements fmt.Stringer.String.
-func (fr FileRange) String() string {
- return fmt.Sprintf("[%#x, %#x)", fr.Start, fr.End)
-}
-
// Requirements is used to specify platform specific requirements.
type Requirements struct {
// RequiresCurrentPIDNS indicates that the sandbox has to be started in the
diff --git a/pkg/sentry/platform/ptrace/BUILD b/pkg/sentry/platform/ptrace/BUILD
index 30402c2df..29fd23cc3 100644
--- a/pkg/sentry/platform/ptrace/BUILD
+++ b/pkg/sentry/platform/ptrace/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/hostcpu",
+ "//pkg/sentry/memmap",
"//pkg/sentry/platform",
"//pkg/sentry/platform/interrupt",
"//pkg/sync",
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 2389423b0..c990f3454 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/procid"
"gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/platform"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/usermem"
@@ -616,7 +617,7 @@ func (s *subprocess) syscall(sysno uintptr, args ...arch.SyscallArgument) (uintp
}
// MapFile implements platform.AddressSpace.MapFile.
-func (s *subprocess) MapFile(addr usermem.Addr, f platform.File, fr platform.FileRange, at usermem.AccessType, precommit bool) error {
+func (s *subprocess) MapFile(addr usermem.Addr, f memmap.File, fr memmap.FileRange, at usermem.AccessType, precommit bool) error {
var flags int
if precommit {
flags |= syscall.MAP_POPULATE
diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
index 1b2cfad7d..c576d9475 100644
--- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go
+++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go
@@ -62,7 +62,7 @@ func Override() {
s.Table[55] = syscalls.Supported("getsockopt", GetSockOpt)
s.Table[59] = syscalls.Supported("execve", Execve)
s.Table[72] = syscalls.Supported("fcntl", Fcntl)
- s.Table[73] = syscalls.Supported("fcntl", Flock)
+ s.Table[73] = syscalls.Supported("flock", Flock)
s.Table[74] = syscalls.Supported("fsync", Fsync)
s.Table[75] = syscalls.Supported("fdatasync", Fdatasync)
s.Table[76] = syscalls.Supported("truncate", Truncate)
@@ -163,6 +163,106 @@ func Override() {
// Override ARM64.
s = linux.ARM64
+ s.Table[5] = syscalls.Supported("setxattr", Setxattr)
+ s.Table[6] = syscalls.Supported("lsetxattr", Lsetxattr)
+ s.Table[7] = syscalls.Supported("fsetxattr", Fsetxattr)
+ s.Table[8] = syscalls.Supported("getxattr", Getxattr)
+ s.Table[9] = syscalls.Supported("lgetxattr", Lgetxattr)
+ s.Table[10] = syscalls.Supported("fgetxattr", Fgetxattr)
+ s.Table[11] = syscalls.Supported("listxattr", Listxattr)
+ s.Table[12] = syscalls.Supported("llistxattr", Llistxattr)
+ s.Table[13] = syscalls.Supported("flistxattr", Flistxattr)
+ s.Table[14] = syscalls.Supported("removexattr", Removexattr)
+ s.Table[15] = syscalls.Supported("lremovexattr", Lremovexattr)
+ s.Table[16] = syscalls.Supported("fremovexattr", Fremovexattr)
+ s.Table[17] = syscalls.Supported("getcwd", Getcwd)
+ s.Table[19] = syscalls.Supported("eventfd2", Eventfd2)
+ s.Table[20] = syscalls.Supported("epoll_create1", EpollCreate1)
+ s.Table[21] = syscalls.Supported("epoll_ctl", EpollCtl)
+ s.Table[22] = syscalls.Supported("epoll_pwait", EpollPwait)
+ s.Table[23] = syscalls.Supported("dup", Dup)
+ s.Table[24] = syscalls.Supported("dup3", Dup3)
+ s.Table[25] = syscalls.Supported("fcntl", Fcntl)
+ s.Table[26] = syscalls.PartiallySupported("inotify_init1", InotifyInit1, "inotify events are only available inside the sandbox.", nil)
+ s.Table[27] = syscalls.PartiallySupported("inotify_add_watch", InotifyAddWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[28] = syscalls.PartiallySupported("inotify_rm_watch", InotifyRmWatch, "inotify events are only available inside the sandbox.", nil)
+ s.Table[29] = syscalls.Supported("ioctl", Ioctl)
+ s.Table[32] = syscalls.Supported("flock", Flock)
+ s.Table[33] = syscalls.Supported("mknodat", Mknodat)
+ s.Table[34] = syscalls.Supported("mkdirat", Mkdirat)
+ s.Table[35] = syscalls.Supported("unlinkat", Unlinkat)
+ s.Table[36] = syscalls.Supported("symlinkat", Symlinkat)
+ s.Table[37] = syscalls.Supported("linkat", Linkat)
+ s.Table[38] = syscalls.Supported("renameat", Renameat)
+ s.Table[39] = syscalls.Supported("umount2", Umount2)
+ s.Table[40] = syscalls.Supported("mount", Mount)
+ s.Table[43] = syscalls.Supported("statfs", Statfs)
+ s.Table[44] = syscalls.Supported("fstatfs", Fstatfs)
+ s.Table[45] = syscalls.Supported("truncate", Truncate)
+ s.Table[46] = syscalls.Supported("ftruncate", Ftruncate)
+ s.Table[48] = syscalls.Supported("faccessat", Faccessat)
+ s.Table[49] = syscalls.Supported("chdir", Chdir)
+ s.Table[50] = syscalls.Supported("fchdir", Fchdir)
+ s.Table[51] = syscalls.Supported("chroot", Chroot)
+ s.Table[52] = syscalls.Supported("fchmod", Fchmod)
+ s.Table[53] = syscalls.Supported("fchmodat", Fchmodat)
+ s.Table[54] = syscalls.Supported("fchownat", Fchownat)
+ s.Table[55] = syscalls.Supported("fchown", Fchown)
+ s.Table[56] = syscalls.Supported("openat", Openat)
+ s.Table[57] = syscalls.Supported("close", Close)
+ s.Table[59] = syscalls.Supported("pipe2", Pipe2)
+ s.Table[61] = syscalls.Supported("getdents64", Getdents64)
+ s.Table[62] = syscalls.Supported("lseek", Lseek)
s.Table[63] = syscalls.Supported("read", Read)
+ s.Table[64] = syscalls.Supported("write", Write)
+ s.Table[65] = syscalls.Supported("readv", Readv)
+ s.Table[66] = syscalls.Supported("writev", Writev)
+ s.Table[67] = syscalls.Supported("pread64", Pread64)
+ s.Table[68] = syscalls.Supported("pwrite64", Pwrite64)
+ s.Table[69] = syscalls.Supported("preadv", Preadv)
+ s.Table[70] = syscalls.Supported("pwritev", Pwritev)
+ s.Table[72] = syscalls.Supported("pselect", Pselect)
+ s.Table[73] = syscalls.Supported("ppoll", Ppoll)
+ s.Table[74] = syscalls.Supported("signalfd4", Signalfd4)
+ s.Table[76] = syscalls.Supported("splice", Splice)
+ s.Table[77] = syscalls.Supported("tee", Tee)
+ s.Table[78] = syscalls.Supported("readlinkat", Readlinkat)
+ s.Table[80] = syscalls.Supported("fstat", Fstat)
+ s.Table[81] = syscalls.Supported("sync", Sync)
+ s.Table[82] = syscalls.Supported("fsync", Fsync)
+ s.Table[83] = syscalls.Supported("fdatasync", Fdatasync)
+ s.Table[84] = syscalls.Supported("sync_file_range", SyncFileRange)
+ s.Table[85] = syscalls.Supported("timerfd_create", TimerfdCreate)
+ s.Table[86] = syscalls.Supported("timerfd_settime", TimerfdSettime)
+ s.Table[87] = syscalls.Supported("timerfd_gettime", TimerfdGettime)
+ s.Table[88] = syscalls.Supported("utimensat", Utimensat)
+ s.Table[198] = syscalls.Supported("socket", Socket)
+ s.Table[199] = syscalls.Supported("socketpair", SocketPair)
+ s.Table[200] = syscalls.Supported("bind", Bind)
+ s.Table[201] = syscalls.Supported("listen", Listen)
+ s.Table[202] = syscalls.Supported("accept", Accept)
+ s.Table[203] = syscalls.Supported("connect", Connect)
+ s.Table[204] = syscalls.Supported("getsockname", GetSockName)
+ s.Table[205] = syscalls.Supported("getpeername", GetPeerName)
+ s.Table[206] = syscalls.Supported("sendto", SendTo)
+ s.Table[207] = syscalls.Supported("recvfrom", RecvFrom)
+ s.Table[208] = syscalls.Supported("setsockopt", SetSockOpt)
+ s.Table[209] = syscalls.Supported("getsockopt", GetSockOpt)
+ s.Table[210] = syscalls.Supported("shutdown", Shutdown)
+ s.Table[211] = syscalls.Supported("sendmsg", SendMsg)
+ s.Table[212] = syscalls.Supported("recvmsg", RecvMsg)
+ s.Table[221] = syscalls.Supported("execve", Execve)
+ s.Table[222] = syscalls.Supported("mmap", Mmap)
+ s.Table[242] = syscalls.Supported("accept4", Accept4)
+ s.Table[243] = syscalls.Supported("recvmmsg", RecvMMsg)
+ s.Table[267] = syscalls.Supported("syncfs", Syncfs)
+ s.Table[269] = syscalls.Supported("sendmmsg", SendMMsg)
+ s.Table[276] = syscalls.Supported("renameat2", Renameat2)
+ s.Table[279] = syscalls.Supported("memfd_create", MemfdCreate)
+ s.Table[281] = syscalls.Supported("execveat", Execveat)
+ s.Table[286] = syscalls.Supported("preadv2", Preadv2)
+ s.Table[287] = syscalls.Supported("pwritev2", Pwritev2)
+ s.Table[291] = syscalls.Supported("statx", Statx)
+
s.Init()
}
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index b0f57040c..31a242482 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -160,9 +160,12 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
- RemoteLinkAddress: header.EthernetBroadcastAddress,
+ RemoteLinkAddress: remoteLinkAddr,
+ }
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetBroadcastAddress
}
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 66e67429c..a35a64a0f 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -32,10 +32,14 @@ import (
)
const (
- stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
- stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
- stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
- stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+ stackLinkAddr1 = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+ stackLinkAddr2 = tcpip.LinkAddress("\x0b\x0b\x0c\x0c\x0d\x0d")
+ stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
+ stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
+ stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
type testContext struct {
@@ -50,8 +54,7 @@ func newTestContext(t *testing.T) *testContext {
TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
})
- const defaultMTU = 65536
- ep := channel.New(256, defaultMTU, stackLinkAddr)
+ ep := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
@@ -119,7 +122,7 @@ func TestDirectRequest(t *testing.T) {
if !rep.IsValid() {
t.Fatalf("invalid ARP response pi.Pkt.Header.UsedLength()=%d", pi.Pkt.Header.UsedLength())
}
- if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr; got != want {
+ if got, want := tcpip.LinkAddress(rep.HardwareAddressSender()), stackLinkAddr1; got != want {
t.Errorf("got HardwareAddressSender = %s, want = %s", got, want)
}
if got, want := tcpip.Address(rep.ProtocolAddressSender()), tcpip.Address(h.ProtocolAddressTarget()); got != want {
@@ -144,3 +147,44 @@ func TestDirectRequest(t *testing.T) {
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: stackLinkAddr2,
+ expectLinkAddr: stackLinkAddr2,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: header.EthernetBroadcastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ p := arp.NewProtocol()
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr1)
+ if err := linkRes.LinkAddressRequest(stackAddr1, stackAddr2, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr1, stackAddr2, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index ff1cb53dd..24600d877 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -504,7 +504,7 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements stack.LinkAddressResolver.
-func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error {
snaddr := header.SolicitedNodeAddr(addr)
// TODO(b/148672031): Use stack.FindRoute instead of manually creating the
@@ -513,8 +513,12 @@ func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.
r := &stack.Route{
LocalAddress: localAddr,
RemoteAddress: snaddr,
- RemoteLinkAddress: header.EthernetAddressFromMulticastIPv6Address(snaddr),
+ RemoteLinkAddress: remoteLinkAddr,
}
+ if len(r.RemoteLinkAddress) == 0 {
+ r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(snaddr)
+ }
+
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 52a01b44e..f86aaed1d 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -34,6 +34,9 @@ const (
linkAddr0 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e")
linkAddr2 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
+
+ defaultChannelSize = 1
+ defaultMTU = 65536
)
var (
@@ -257,8 +260,7 @@ func newTestContext(t *testing.T) *testContext {
}),
}
- const defaultMTU = 65536
- c.linkEP0 = channel.New(256, defaultMTU, linkAddr0)
+ c.linkEP0 = channel.New(defaultChannelSize, defaultMTU, linkAddr0)
wrappedEP0 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP0})
if testing.Verbose() {
@@ -271,7 +273,7 @@ func newTestContext(t *testing.T) *testContext {
t.Fatalf("AddAddress lladdr0: %v", err)
}
- c.linkEP1 = channel.New(256, defaultMTU, linkAddr1)
+ c.linkEP1 = channel.New(defaultChannelSize, defaultMTU, linkAddr1)
wrappedEP1 := stack.LinkEndpoint(endpointWithResolutionCapability{LinkEndpoint: c.linkEP1})
if err := c.s1.CreateNIC(1, wrappedEP1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
@@ -951,3 +953,47 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) {
})
}
}
+
+func TestLinkAddressRequest(t *testing.T) {
+ snaddr := header.SolicitedNodeAddr(lladdr0)
+ mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr)
+
+ tests := []struct {
+ name string
+ remoteLinkAddr tcpip.LinkAddress
+ expectLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "Unicast",
+ remoteLinkAddr: linkAddr1,
+ expectLinkAddr: linkAddr1,
+ },
+ {
+ name: "Multicast",
+ remoteLinkAddr: "",
+ expectLinkAddr: mcaddr,
+ },
+ }
+
+ for _, test := range tests {
+ p := NewProtocol()
+ linkRes, ok := p.(stack.LinkAddressResolver)
+ if !ok {
+ t.Fatalf("expected IPv6 protocol to implement stack.LinkAddressResolver")
+ }
+
+ linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0)
+ if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil {
+ t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err)
+ }
+
+ pkt, ok := linkEP.Read()
+ if !ok {
+ t.Fatal("expected to send a link address request")
+ }
+
+ if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want {
+ t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/stack/forwarder_test.go b/pkg/tcpip/stack/forwarder_test.go
index bca1d940b..c962693f5 100644
--- a/pkg/tcpip/stack/forwarder_test.go
+++ b/pkg/tcpip/stack/forwarder_test.go
@@ -121,10 +121,12 @@ func (*fwdTestNetworkEndpoint) Close() {}
type fwdTestNetworkProtocol struct {
addrCache *linkAddrCache
addrResolveDelay time.Duration
- onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address)
+ onLinkAddressResolved func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress)
onResolveStaticAddress func(tcpip.Address) (tcpip.LinkAddress, bool)
}
+var _ LinkAddressResolver = (*fwdTestNetworkProtocol)(nil)
+
func (f *fwdTestNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fwdTestNetNumber
}
@@ -174,10 +176,10 @@ func (f *fwdTestNetworkProtocol) Close() {}
func (f *fwdTestNetworkProtocol) Wait() {}
-func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error {
+func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error {
if f.addrCache != nil && f.onLinkAddressResolved != nil {
time.AfterFunc(f.addrResolveDelay, func() {
- f.onLinkAddressResolved(f.addrCache, addr)
+ f.onLinkAddressResolved(f.addrCache, addr, remoteLinkAddr)
})
}
return nil
@@ -405,7 +407,7 @@ func TestForwardingWithFakeResolver(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any address will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -463,7 +465,7 @@ func TestForwardingWithFakeResolverPartialTimeout(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Only packets to address 3 will be resolved to the
// link address "c".
if addr == "\x03" {
@@ -515,7 +517,7 @@ func TestForwardingWithFakeResolverTwoPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -559,7 +561,7 @@ func TestForwardingWithFakeResolverManyPackets(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
@@ -616,7 +618,7 @@ func TestForwardingWithFakeResolverManyResolutions(t *testing.T) {
// Create a network protocol with a fake resolver.
proto := &fwdTestNetworkProtocol{
addrResolveDelay: 500 * time.Millisecond,
- onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address) {
+ onLinkAddressResolved: func(cache *linkAddrCache, addr tcpip.Address, _ tcpip.LinkAddress) {
// Any packets will be resolved to the link address "c".
cache.add(tcpip.FullAddress{NIC: 2, Addr: addr}, "c")
},
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 403557fd7..6f73a0ce4 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -244,7 +244,7 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
- linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+ linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP)
select {
case now := <-time.After(c.resolutionTimeout):
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 1baa498d0..b15b8d1cb 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -48,7 +48,7 @@ type testLinkAddressResolver struct {
onLinkAddressRequest func()
}
-func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
time.AfterFunc(r.delay, func() { r.fakeRequest(addr) })
if f := r.onLinkAddressRequest; f != nil {
f()
diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go
index c477e31d8..a70792b50 100644
--- a/pkg/tcpip/stack/nic_test.go
+++ b/pkg/tcpip/stack/nic_test.go
@@ -243,7 +243,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
}
// LinkAddressRequest implements LinkAddressResolver.
-func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error {
return nil
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 9e1b2d25f..8604c4259 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -478,12 +478,13 @@ type InjectableLinkEndpoint interface {
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
- // LinkAddressRequest sends a request for the LinkAddress of addr.
- // The request is sent on linkEP with localAddr as the source.
+ // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts
+ // the request on the local network if remoteLinkAddr is the zero value. The
+ // request is sent on linkEP with localAddr as the source.
//
// A valid response will cause the discovery protocol's network
// endpoint to call AddLinkAddress.
- LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+ LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 18ff89ffc..e860ee484 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -49,6 +49,7 @@ go_library(
"segment_heap.go",
"segment_queue.go",
"segment_state.go",
+ "segment_unsafe.go",
"snd.go",
"snd_state.go",
"tcp_endpoint_list.go",
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index b34e47bbd..5d6174a59 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -49,7 +49,7 @@ const (
// DefaultReceiveBufferSize is the default size of the receive buffer
// for an endpoint.
- DefaultReceiveBufferSize = 1 << 20 // 1MB
+ DefaultReceiveBufferSize = 32 << 10 // 32KB
// MaxBufferSize is the largest size a receive/send buffer can grow to.
MaxBufferSize = 4 << 20 // 4MB
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index dd89a292a..5e0bfe585 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -372,7 +372,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
// We only store the segment if it's within our buffer
// size limit.
if r.pendingBufUsed < r.pendingBufSize {
- r.pendingBufUsed += s.logicalLen()
+ r.pendingBufUsed += seqnum.Size(s.segMemSize())
s.incRef()
heap.Push(&r.pendingRcvdSegments, s)
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
@@ -406,7 +406,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
}
heap.Pop(&r.pendingRcvdSegments)
- r.pendingBufUsed -= s.logicalLen()
+ r.pendingBufUsed -= seqnum.Size(s.segMemSize())
s.decRef()
}
return false, nil
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 0280892a8..bb60dc29d 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -138,6 +138,12 @@ func (s *segment) logicalLen() seqnum.Size {
return l
}
+// segMemSize is the amount of memory used to hold the segment data and
+// the associated metadata.
+func (s *segment) segMemSize() int {
+ return segSize + s.data.Size()
+}
+
// parse populates the sequence & ack numbers, flags, and window fields of the
// segment from the TCP header stored in the data. It then updates the view to
// skip the header.
diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go
new file mode 100644
index 000000000..0ab7b8f56
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_unsafe.go
@@ -0,0 +1,23 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "unsafe"
+)
+
+const (
+ segSize = int(unsafe.Sizeof(segment{}))
+)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 06fde2a79..37e7767d6 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -143,12 +143,14 @@ func New(t *testing.T, mtu uint32) *Context {
TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
})
+ const sendBufferSize = 1 << 20 // 1 MiB
+ const recvBufferSize = 1 << 20 // 1 MiB
// Allow minimum send/receive buffer sizes to be 1 during tests.
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
- if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil {
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %s", err)
}
@@ -202,7 +204,7 @@ func New(t *testing.T, mtu uint32) *Context {
t: t,
s: s,
linkEP: ep,
- WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
+ WindowScale: uint8(tcp.FindWndScale(recvBufferSize)),
}
}
diff --git a/pkg/test/dockerutil/BUILD b/pkg/test/dockerutil/BUILD
index 83b80c8bc..a5e84658a 100644
--- a/pkg/test/dockerutil/BUILD
+++ b/pkg/test/dockerutil/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -10,6 +10,7 @@ go_library(
"dockerutil.go",
"exec.go",
"network.go",
+ "profile.go",
],
visibility = ["//:sandbox"],
deps = [
@@ -23,3 +24,19 @@ go_library(
"@com_github_docker_go_connections//nat:go_default_library",
],
)
+
+go_test(
+ name = "profile_test",
+ size = "large",
+ srcs = [
+ "profile_test.go",
+ ],
+ library = ":dockerutil",
+ tags = [
+ # Requires docker and runsc to be configured before test runs.
+ # Also requires the test to be run as root.
+ "manual",
+ "local",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/test/dockerutil/README.md b/pkg/test/dockerutil/README.md
new file mode 100644
index 000000000..870292096
--- /dev/null
+++ b/pkg/test/dockerutil/README.md
@@ -0,0 +1,86 @@
+# dockerutil
+
+This package is for creating and controlling docker containers for testing
+runsc, gVisor's docker/kubernetes binary. A simple test may look like:
+
+```
+ func TestSuperCool(t *testing.T) {
+ ctx := context.Background()
+ c := dockerutil.MakeContainer(ctx, t)
+ got, err := c.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/alpine"
+ }, "echo", "super cool")
+ if err != nil {
+ t.Fatalf("err was not nil: %v", err)
+ }
+ want := "super cool"
+ if !strings.Contains(got, want){
+ t.Fatalf("want: %s, got: %s", want, got)
+ }
+ }
+```
+
+For further examples, see many of our end to end tests elsewhere in the repo,
+such as those in //test/e2e or benchmarks at //test/benchmarks.
+
+dockerutil uses the "official" docker golang api, which is
+[very powerful](https://godoc.org/github.com/docker/docker/client). dockerutil
+is a thin wrapper around this API, allowing desired new use cases to be easily
+implemented.
+
+## Profiling
+
+dockerutil is capable of generating profiles. Currently, the only option is to
+use pprof profiles generated by `runsc debug`. The profiler will generate Block,
+CPU, Heap, Goroutine, and Mutex profiles. To generate profiles:
+
+* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc
+ ARGS="--profile"` Also add other flags with ARGS like `--platform=kvm` or
+ `--vfs2`.
+* Restart docker: `sudo service docker restart`
+
+To run and generate CPU profiles run:
+
+```
+make sudo TARGETS=//path/to:target \
+ ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt"
+```
+
+Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof`
+
+Container name in most tests and benchmarks in gVisor is usually the test name
+and some random characters like so:
+`BenchmarkABSL-CleanCache-JF2J2ZYF3U7SL47QAA727CSJI3C4ZAW2`
+
+Profiling requires root as runsc debug inspects running containers in /var/run
+among other things.
+
+### Writing for Profiling
+
+The below shows an example of using profiles with dockerutil.
+
+```
+func TestSuperCool(t *testing.T){
+ ctx := context.Background()
+ // profiled and using runtime from dockerutil.runtime flag
+ profiled := MakeContainer()
+
+ // not profiled and using runtime runc
+ native := MakeNativeContainer()
+
+ err := profiled.Spawn(ctx, RunOpts{
+ Image: "some/image",
+ }, "sleep", "100000")
+ // profiling has begun here
+ ...
+ expensive setup that I don't want to profile.
+ ...
+ profiled.RestartProfiles()
+ // profiled activity
+}
+```
+
+In the above example, `profiled` would be profiled and `native` would not. The
+call to `RestartProfiles()` restarts the clock on profiling. This is useful if
+the main activity being tested is done with `docker exec` or `container.Spawn()`
+followed by one or more `container.Exec()` calls.
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
index 17acdaf6f..b59503188 100644
--- a/pkg/test/dockerutil/container.go
+++ b/pkg/test/dockerutil/container.go
@@ -43,15 +43,21 @@ import (
// See: https://pkg.go.dev/github.com/docker/docker.
type Container struct {
Name string
- Runtime string
+ runtime string
logger testutil.Logger
client *client.Client
id string
mounts []mount.Mount
links []string
- cleanups []func()
copyErr error
+ cleanups []func()
+
+ // Profiles are profiles added to this container. They contain methods
+ // that are run after Creation, Start, and Cleanup of this Container, along
+ // a handle to restart the profile. Generally, tests/benchmarks using
+ // profiles need to run as root.
+ profiles []Profile
// Stores streams attached to the container. Used by WaitForOutputSubmatch.
streams types.HijackedResponse
@@ -106,7 +112,19 @@ type RunOpts struct {
// MakeContainer sets up the struct for a Docker container.
//
// Names of containers will be unique.
+// Containers will check flags for profiling requests.
func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
+ c := MakeNativeContainer(ctx, logger)
+ c.runtime = *runtime
+ if p := MakePprofFromFlags(c); p != nil {
+ c.AddProfile(p)
+ }
+ return c
+}
+
+// MakeNativeContainer sets up the struct for a DockerContainer using runc. Native
+// containers aren't profiled.
+func MakeNativeContainer(ctx context.Context, logger testutil.Logger) *Container {
// Slashes are not allowed in container names.
name := testutil.RandomID(logger.Name())
name = strings.ReplaceAll(name, "/", "-")
@@ -114,20 +132,33 @@ func MakeContainer(ctx context.Context, logger testutil.Logger) *Container {
if err != nil {
return nil
}
-
client.NegotiateAPIVersion(ctx)
-
return &Container{
logger: logger,
Name: name,
- Runtime: *runtime,
+ runtime: "",
client: client,
}
}
+// AddProfile adds a profile to this container.
+func (c *Container) AddProfile(p Profile) {
+ c.profiles = append(c.profiles, p)
+}
+
+// RestartProfiles calls Restart on all profiles for this container.
+func (c *Container) RestartProfiles() error {
+ for _, profile := range c.profiles {
+ if err := profile.Restart(c); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
// Spawn is analogous to 'docker run -d'.
func (c *Container) Spawn(ctx context.Context, r RunOpts, args ...string) error {
- if err := c.create(ctx, r, args); err != nil {
+ if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
return err
}
return c.Start(ctx)
@@ -153,7 +184,7 @@ func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string)
// Run is analogous to 'docker run'.
func (c *Container) Run(ctx context.Context, r RunOpts, args ...string) (string, error) {
- if err := c.create(ctx, r, args); err != nil {
+ if err := c.create(ctx, c.config(r, args), c.hostConfig(r), nil); err != nil {
return "", err
}
@@ -181,27 +212,25 @@ func (c *Container) MakeLink(target string) string {
// CreateFrom creates a container from the given configs.
func (c *Container) CreateFrom(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
- cont, err := c.client.ContainerCreate(ctx, conf, hostconf, netconf, c.Name)
- if err != nil {
- return err
- }
- c.id = cont.ID
- return nil
+ return c.create(ctx, conf, hostconf, netconf)
}
// Create is analogous to 'docker create'.
func (c *Container) Create(ctx context.Context, r RunOpts, args ...string) error {
- return c.create(ctx, r, args)
+ return c.create(ctx, c.config(r, args), c.hostConfig(r), nil)
}
-func (c *Container) create(ctx context.Context, r RunOpts, args []string) error {
- conf := c.config(r, args)
- hostconf := c.hostConfig(r)
+func (c *Container) create(ctx context.Context, conf *container.Config, hostconf *container.HostConfig, netconf *network.NetworkingConfig) error {
cont, err := c.client.ContainerCreate(ctx, conf, hostconf, nil, c.Name)
if err != nil {
return err
}
c.id = cont.ID
+ for _, profile := range c.profiles {
+ if err := profile.OnCreate(c); err != nil {
+ return fmt.Errorf("OnCreate method failed with: %v", err)
+ }
+ }
return nil
}
@@ -227,7 +256,7 @@ func (c *Container) hostConfig(r RunOpts) *container.HostConfig {
c.mounts = append(c.mounts, r.Mounts...)
return &container.HostConfig{
- Runtime: c.Runtime,
+ Runtime: c.runtime,
Mounts: c.mounts,
PublishAllPorts: true,
Links: r.Links,
@@ -261,8 +290,15 @@ func (c *Container) Start(ctx context.Context) error {
c.cleanups = append(c.cleanups, func() {
c.streams.Close()
})
-
- return c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{})
+ if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil {
+ return fmt.Errorf("ContainerStart failed: %v", err)
+ }
+ for _, profile := range c.profiles {
+ if err := profile.OnStart(c); err != nil {
+ return fmt.Errorf("OnStart method failed: %v", err)
+ }
+ }
+ return nil
}
// Stop is analogous to 'docker stop'.
@@ -482,6 +518,12 @@ func (c *Container) Remove(ctx context.Context) error {
// CleanUp kills and deletes the container (best effort).
func (c *Container) CleanUp(ctx context.Context) {
+ // Execute profile cleanups before the container goes down.
+ for _, profile := range c.profiles {
+ profile.OnCleanUp(c)
+ }
+ // Forget profiles.
+ c.profiles = nil
// Kill the container.
if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") {
// Just log; can't do anything here.
diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go
index df09babf3..5a9dd8bd8 100644
--- a/pkg/test/dockerutil/dockerutil.go
+++ b/pkg/test/dockerutil/dockerutil.go
@@ -25,6 +25,7 @@ import (
"os/exec"
"regexp"
"strconv"
+ "time"
"gvisor.dev/gvisor/pkg/test/testutil"
)
@@ -42,6 +43,26 @@ var (
// config is the default Docker daemon configuration path.
config = flag.String("config_path", "/etc/docker/daemon.json", "configuration file for reading paths")
+
+ // The following flags are for the "pprof" profiler tool.
+
+ // pprofBaseDir allows the user to change the directory to which profiles are
+ // written. By default, profiles will appear under:
+ // /tmp/profile/RUNTIME/CONTAINER_NAME/*.pprof.
+ pprofBaseDir = flag.String("pprof-dir", "/tmp/profile", "base directory in: BASEDIR/RUNTIME/CONTINER_NAME/FILENAME (e.g. /tmp/profile/runtime/mycontainer/cpu.pprof)")
+
+ // duration is the max duration `runsc debug` will run and capture profiles.
+ // If the container's clean up method is called prior to duration, the
+ // profiling process will be killed.
+ duration = flag.Duration("pprof-duration", 10*time.Second, "duration to run the profile in seconds")
+
+ // The below flags enable each type of profile. Multiple profiles can be
+ // enabled for each run.
+ pprofBlock = flag.Bool("pprof-block", false, "enables block profiling with runsc debug")
+ pprofCPU = flag.Bool("pprof-cpu", false, "enables CPU profiling with runsc debug")
+ pprofGo = flag.Bool("pprof-go", false, "enables goroutine profiling with runsc debug")
+ pprofHeap = flag.Bool("pprof-heap", false, "enables heap profiling with runsc debug")
+ pprofMutex = flag.Bool("pprof-mutex", false, "enables mutex profiling with runsc debug")
)
// EnsureSupportedDockerVersion checks if correct docker is installed.
diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go
new file mode 100644
index 000000000..1fab33083
--- /dev/null
+++ b/pkg/test/dockerutil/profile.go
@@ -0,0 +1,152 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "time"
+)
+
+// Profile represents profile-like operations on a container,
+// such as running perf or pprof. It is meant to be added to containers
+// such that the container type calls the Profile during its lifecycle.
+type Profile interface {
+ // OnCreate is called just after the container is created when the container
+ // has a valid ID (e.g. c.ID()).
+ OnCreate(c *Container) error
+
+ // OnStart is called just after the container is started when the container
+ // has a valid Pid (e.g. c.SandboxPid()).
+ OnStart(c *Container) error
+
+ // Restart restarts the Profile on request.
+ Restart(c *Container) error
+
+ // OnCleanUp is called during the container's cleanup method.
+ // Cleanups should just log errors if they have them.
+ OnCleanUp(c *Container) error
+}
+
+// Pprof is for running profiles with 'runsc debug'. Pprof workloads
+// should be run as root and ONLY against runsc sandboxes. The runtime
+// should have --profile set as an option in /etc/docker/daemon.json in
+// order for profiling to work with Pprof.
+type Pprof struct {
+ BasePath string // path to put profiles
+ BlockProfile bool
+ CPUProfile bool
+ GoRoutineProfile bool
+ HeapProfile bool
+ MutexProfile bool
+ Duration time.Duration // duration to run profiler e.g. '10s' or '1m'.
+ shouldRun bool
+ cmd *exec.Cmd
+ stdout io.ReadCloser
+ stderr io.ReadCloser
+}
+
+// MakePprofFromFlags makes a Pprof profile from flags.
+func MakePprofFromFlags(c *Container) *Pprof {
+ if !(*pprofBlock || *pprofCPU || *pprofGo || *pprofHeap || *pprofMutex) {
+ return nil
+ }
+ return &Pprof{
+ BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name),
+ BlockProfile: *pprofBlock,
+ CPUProfile: *pprofCPU,
+ GoRoutineProfile: *pprofGo,
+ HeapProfile: *pprofHeap,
+ MutexProfile: *pprofMutex,
+ Duration: *duration,
+ }
+}
+
+// OnCreate implements Profile.OnCreate.
+func (p *Pprof) OnCreate(c *Container) error {
+ return os.MkdirAll(p.BasePath, 0755)
+}
+
+// OnStart implements Profile.OnStart.
+func (p *Pprof) OnStart(c *Container) error {
+ path, err := RuntimePath()
+ if err != nil {
+ return fmt.Errorf("failed to get runtime path: %v", err)
+ }
+
+ // The root directory of this container's runtime.
+ root := fmt.Sprintf("--root=/var/run/docker/runtime-%s/moby", c.runtime)
+ // Format is `runsc --root=rootdir debug --profile-*=file --duration=* containerID`.
+ args := []string{root, "debug"}
+ args = append(args, p.makeProfileArgs(c)...)
+ args = append(args, c.ID())
+
+ // Best effort wait until container is running.
+ for now := time.Now(); time.Since(now) < 5*time.Second; {
+ if status, err := c.Status(context.Background()); err != nil {
+ return fmt.Errorf("failed to get status with: %v", err)
+
+ } else if status.Running {
+ break
+ }
+ time.Sleep(500 * time.Millisecond)
+ }
+ p.cmd = exec.Command(path, args...)
+ if err := p.cmd.Start(); err != nil {
+ return fmt.Errorf("process failed: %v", err)
+ }
+ return nil
+}
+
+// Restart implements Profile.Restart.
+func (p *Pprof) Restart(c *Container) error {
+ p.OnCleanUp(c)
+ return p.OnStart(c)
+}
+
+// OnCleanUp implements Profile.OnCleanup
+func (p *Pprof) OnCleanUp(c *Container) error {
+ defer func() { p.cmd = nil }()
+ if p.cmd != nil && p.cmd.Process != nil && p.cmd.ProcessState != nil && !p.cmd.ProcessState.Exited() {
+ return p.cmd.Process.Kill()
+ }
+ return nil
+}
+
+// makeProfileArgs turns Pprof fields into runsc debug flags.
+func (p *Pprof) makeProfileArgs(c *Container) []string {
+ var ret []string
+ if p.BlockProfile {
+ ret = append(ret, fmt.Sprintf("--profile-block=%s", filepath.Join(p.BasePath, "block.pprof")))
+ }
+ if p.CPUProfile {
+ ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof")))
+ }
+ if p.GoRoutineProfile {
+ ret = append(ret, fmt.Sprintf("--profile-goroutine=%s", filepath.Join(p.BasePath, "go.pprof")))
+ }
+ if p.HeapProfile {
+ ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof")))
+ }
+ if p.MutexProfile {
+ ret = append(ret, fmt.Sprintf("--profile-mutex=%s", filepath.Join(p.BasePath, "mutex.pprof")))
+ }
+ ret = append(ret, fmt.Sprintf("--duration=%s", p.Duration))
+ return ret
+}
diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go
new file mode 100644
index 000000000..b7b4d7618
--- /dev/null
+++ b/pkg/test/dockerutil/profile_test.go
@@ -0,0 +1,117 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package dockerutil
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "path/filepath"
+ "testing"
+ "time"
+)
+
+type testCase struct {
+ name string
+ pprof Pprof
+ expectedFiles []string
+}
+
+func TestPprof(t *testing.T) {
+ // Basepath and expected file names for each type of profile.
+ basePath := "/tmp/test/profile"
+ block := "block.pprof"
+ cpu := "cpu.pprof"
+ goprofle := "go.pprof"
+ heap := "heap.pprof"
+ mutex := "mutex.pprof"
+
+ testCases := []testCase{
+ {
+ name: "Cpu",
+ pprof: Pprof{
+ BasePath: basePath,
+ CPUProfile: true,
+ Duration: 2 * time.Second,
+ },
+ expectedFiles: []string{cpu},
+ },
+ {
+ name: "All",
+ pprof: Pprof{
+ BasePath: basePath,
+ BlockProfile: true,
+ CPUProfile: true,
+ GoRoutineProfile: true,
+ HeapProfile: true,
+ MutexProfile: true,
+ Duration: 2 * time.Second,
+ },
+ expectedFiles: []string{block, cpu, goprofle, heap, mutex},
+ },
+ }
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ ctx := context.Background()
+ c := MakeContainer(ctx, t)
+ // Set basepath to include the container name so there are no conflicts.
+ tc.pprof.BasePath = filepath.Join(tc.pprof.BasePath, c.Name)
+ c.AddProfile(&tc.pprof)
+
+ func() {
+ defer c.CleanUp(ctx)
+ // Start a container.
+ if err := c.Spawn(ctx, RunOpts{
+ Image: "basic/alpine",
+ }, "sleep", "1000"); err != nil {
+ t.Fatalf("run failed with: %v", err)
+ }
+
+ if status, err := c.Status(context.Background()); !status.Running {
+ t.Fatalf("container is not yet running: %+v err: %v", status, err)
+ }
+
+ // End early if the expected files exist and have data.
+ for start := time.Now(); time.Since(start) < tc.pprof.Duration; time.Sleep(500 * time.Millisecond) {
+ if err := checkFiles(tc); err == nil {
+ break
+ }
+ }
+ }()
+
+ // Check all expected files exist and have data.
+ if err := checkFiles(tc); err != nil {
+ t.Fatalf(err.Error())
+ }
+ })
+ }
+}
+
+func checkFiles(tc testCase) error {
+ for _, file := range tc.expectedFiles {
+ stat, err := os.Stat(filepath.Join(tc.pprof.BasePath, file))
+ if err != nil {
+ return fmt.Errorf("stat failed with: %v", err)
+ } else if stat.Size() < 1 {
+ return fmt.Errorf("file not written to: %+v", stat)
+ }
+ }
+ return nil
+}
+
+func TestMain(m *testing.M) {
+ EnsureSupportedDockerVersion()
+ os.Exit(m.Run())
+}
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index cd76645bd..5e8247bc8 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -643,7 +643,9 @@ func TestExec(t *testing.T) {
if err != nil {
t.Fatalf("error creating temporary directory: %v", err)
}
- cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100", dir)
+ // Note that some shells may exec the final command in a sequence as
+ // an optimization. We avoid this here by adding the exit 0.
+ cmd := fmt.Sprintf("ln -s /bin/true %q/symlink && sleep 100 && exit 0", dir)
spec := testutil.NewSpecWithArgs("sh", "-c", cmd)
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
diff --git a/scripts/benchmark.sh b/scripts/benchmark.sh
deleted file mode 100755
index c49f988b8..000000000
--- a/scripts/benchmark.sh
+++ /dev/null
@@ -1,30 +0,0 @@
-#!/bin/bash
-
-# Copyright 2020 The gVisor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-source $(dirname $0)/common.sh
-
-make load-all-images
-
-if [[ -z "${1:-}" ]]; then
- target=$(query "attr(tags, manual, tests(//test/benchmarks/...))")
-else
- target="$1"
-fi
-
-install_runsc_for_benchmarks benchmark
-
-echo $target
-benchmark_runsc $target "${@:2}"
diff --git a/scripts/common.sh b/scripts/common.sh
index 36158654f..3ca699e4a 100755
--- a/scripts/common.sh
+++ b/scripts/common.sh
@@ -42,15 +42,6 @@ function test_runsc() {
test --test_arg=--runtime=${RUNTIME} "$@"
}
-function benchmark_runsc() {
- test_runsc -c opt \
- --nocache_test_results \
- --test_arg=-test.bench=. \
- --test_arg=-test.benchmem \
- --jobs=1 \
- "$@"
-}
-
function install_runsc_for_test() {
local -r test_name=$1
shift
@@ -72,24 +63,6 @@ function install_runsc_for_test() {
"$@"
}
-function install_runsc_for_benchmarks() {
- local -r test_name=$1
- shift
- if [[ -z "${test_name}" ]]; then
- echo "Missing mandatory test name"
- exit 1
- fi
-
- # Add test to the name, so it doesn't conflict with other runtimes.
- set_runtime $(find_branch_name)_"${test_name}"
-
- # ${RUNSC_TEST_NAME} is set by tests (see dockerutil) to pass the test name
- # down to the runtime.
- install_runsc "${RUNTIME}" \
- --TESTONLY-test-name-env=RUNSC_TEST_NAME \
- "$@"
-}
-
# Installs the runsc with given runtime name. set_runtime must have been called
# to set runtime and logs location.
function install_runsc() {
diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh
index dce0a4085..07e9f3109 100755
--- a/scripts/docker_tests.sh
+++ b/scripts/docker_tests.sh
@@ -22,4 +22,6 @@ install_runsc_for_test docker
test_runsc //test/image:image_test //test/e2e:integration_test
install_runsc_for_test docker --vfs2
-test_runsc //test/image:image_test --test_filter=.*TestHelloWorld
+IMAGE_FILTER="Hello|Httpd|Ruby|Stdio"
+INTEGRATION_FILTER="LifeCycle|Pause|Connect|JobControl|Overlay|Exec|DirCreation/root"
+test_runsc //test/e2e:integration_test //test/image:image_test --test_filter="${IMAGE_FILTER}|${INTEGRATION_FILTER}"
diff --git a/test/benchmarks/README.md b/test/benchmarks/README.md
index 9ff602cf1..d1bbabf6f 100644
--- a/test/benchmarks/README.md
+++ b/test/benchmarks/README.md
@@ -13,33 +13,51 @@ To run benchmarks you will need:
* Docker installed (17.09.0 or greater).
-The easiest way to run benchmarks is to use the script at
-//scripts/benchmark.sh.
+The easiest way to setup runsc for running benchmarks is to use the make file.
+From the root directory:
-If not using the script, you will need:
+* Download images: `make load-all-images`
+* Install runsc suitable for benchmarking, which should probably not have
+ strace or debug logs enabled. For example:`make configure RUNTIME=myrunsc
+ ARGS=--platform=kvm`.
+* Restart docker: `sudo service docker restart`
-* `runsc` configured with docker
+You should now have a runtime with the following options configured in
+`/etc/docker/daemon.json`
-Note: benchmarks call the runtime by name. If docker can run it with
-`--runtime=` flag, these tools should work.
+```
+"myrunsc": {
+ "path": "/tmp/myrunsc/runsc",
+ "runtimeArgs": [
+ "--debug-log",
+ "/tmp/bench/logs/runsc.log.%TEST%.%TIMESTAMP%.%COMMAND%",
+ "--platform=kvm"
+ ]
+ },
+
+```
+
+This runtime has been configured with a debugging off and strace logs off and is
+using kvm for demonstration.
## Running benchmarks
-The easiest way to run is with the script at //scripts/benchmarks.sh. The script
-will run all benchmarks under //test/benchmarks if a target is not provided.
+Given the runtime above runtime `myrunsc`, run benchmarks with the following:
-```bash
-./script/benchmarks.sh //path/to/target
+```
+make sudo TARGETS=//path/to:target ARGS="--runtime=myrunsc -test.v \
+ -test.bench=." OPTIONS="-c opt
```
-If you want to run benchmarks manually:
-
-* Run `make load-all-images` from `//`
-* Run with:
+For example, to run only the Iperf tests:
-```bash
-bazel test --test_arg=--runtime=RUNTIME -c opt --test_output=streamed --test_timeout=600 --test_arg=-test.bench=. --nocache_test_results //path/to/target
```
+make sudo TARGETS=//test/benchmarks/network:network_test \
+ ARGS="--runtime=myrunsc -test.v -test.bench=Iperf" OPTIONS="-c opt"
+```
+
+Benchmarks are run with root as some benchmarks require root privileges to do
+things like drop caches.
## Writing benchmarks
@@ -69,6 +87,7 @@ var h harness.Harness
func BenchmarkMyCoolOne(b *testing.B) {
machine, err := h.GetMachine()
// check err
+ defer machine.CleanUp()
ctx := context.Background()
container := machine.GetContainer(ctx, b)
@@ -82,7 +101,7 @@ func BenchmarkMyCoolOne(b *testing.B) {
Image: "benchmarks/my-cool-image",
Env: []string{"MY_VAR=awesome"},
other options...see dockerutil
- }, "sh", "-c", "echo MY_VAR" ...)
+ }, "sh", "-c", "echo MY_VAR")
//check err
b.StopTimer()
@@ -107,12 +126,32 @@ Some notes on the above:
flags, remote virtual machines (eventually), and other services.
* Respect `b.N` in that users of the benchmark may want to "run for an hour"
or something of the sort.
-* Use the `b.ReportMetric` method to report custom metrics.
+* Use the `b.ReportMetric()` method to report custom metrics.
* Set the timer if time is useful for reporting. There isn't a way to turn off
default metrics in testing.B (B/op, allocs/op, ns/op).
* Take a look at dockerutil at //pkg/test/dockerutil to see all methods
available from containers. The API is based on the "official"
[docker API for golang](https://pkg.go.dev/mod/github.com/docker/docker).
-* `harness.GetMachine` marks how many machines this tests needs. If you have a
- client and server and to mark them as multiple machines, call it
- `GetMachine` twice.
+* `harness.GetMachine()` marks how many machines this tests needs. If you have
+ a client and server and to mark them as multiple machines, call
+ `harness.GetMachine()` twice.
+
+## Profiling
+
+For profiling, the runtime is required to have the `--profile` flag enabled.
+This flag loosens seccomp filters so that the runtime can write profile data to
+disk. This configuration is not recommended for production.
+
+* Install runsc with the `--profile` flag: `make configure RUNTIME=myrunsc
+ ARGS="--profile --platform=kvm --vfs2"`. The kvm and vfs2 flags are not
+ required, but are included for demonstration.
+* Restart docker: `sudo service docker restart`
+
+To run and generate CPU profiles fs_test test run:
+
+```
+make sudo TARGETS=//test/benchmarks/fs:fs_test \
+ ARGS="--runtime=myrunsc -test.v -test.bench=. --pprof-cpu" OPTIONS="-c opt"
+```
+
+Profiles would be at: `/tmp/profile/myrunsc/CONTAINERNAME/cpu.pprof`
diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD
new file mode 100644
index 000000000..5e33465cd
--- /dev/null
+++ b/test/benchmarks/database/BUILD
@@ -0,0 +1,28 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "database",
+ testonly = 1,
+ srcs = ["database.go"],
+ deps = ["//test/benchmarks/harness"],
+)
+
+go_test(
+ name = "database_test",
+ size = "enormous",
+ srcs = [
+ "redis_test.go",
+ ],
+ library = ":database",
+ tags = [
+ # Requires docker and runsc to be configured before test runs.
+ "manual",
+ "local",
+ ],
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ ],
+)
diff --git a/test/benchmarks/database/database.go b/test/benchmarks/database/database.go
new file mode 100644
index 000000000..9eeb59f9a
--- /dev/null
+++ b/test/benchmarks/database/database.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package database holds benchmarks around database applications.
+package database
+
+import (
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+var h harness.Harness
+
+// TestMain is the main method for package database.
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/database/redis_test.go b/test/benchmarks/database/redis_test.go
new file mode 100644
index 000000000..6d39f4d66
--- /dev/null
+++ b/test/benchmarks/database/redis_test.go
@@ -0,0 +1,197 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package database
+
+import (
+ "context"
+ "fmt"
+ "regexp"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+// All possible operations from redis. Note: "ping" will
+// run both PING_INLINE and PING_BUILD.
+var operations []string = []string{
+ "PING_INLINE",
+ "PING_BULK",
+ "SET",
+ "GET",
+ "INCR",
+ "LPUSH",
+ "RPUSH",
+ "LPOP",
+ "RPOP",
+ "SADD",
+ "HSET",
+ "SPOP",
+ "LRANGE_100",
+ "LRANGE_300",
+ "LRANGE_500",
+ "LRANGE_600",
+ "MSET",
+}
+
+// BenchmarkRedis runs redis-benchmark against a redis instance and reports
+// data in queries per second. Each is reported by named operation (e.g. LPUSH).
+func BenchmarkRedis(b *testing.B) {
+ clientMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer clientMachine.CleanUp()
+
+ serverMachine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer serverMachine.CleanUp()
+
+ // Redis runs on port 6379 by default.
+ port := 6379
+ ctx := context.Background()
+
+ for _, operation := range operations {
+ b.Run(operation, func(b *testing.B) {
+ server := serverMachine.GetContainer(ctx, b)
+ defer server.CleanUp(ctx)
+
+ // The redis docker container takes no arguments to run a redis server.
+ if err := server.Spawn(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/redis",
+ Ports: []int{port},
+ }); err != nil {
+ b.Fatalf("failed to start redis server with: %v", err)
+ }
+
+ if out, err := server.WaitForOutput(ctx, "Ready to accept connections", 3*time.Second); err != nil {
+ b.Fatalf("failed to start redis server: %v %s", err, out)
+ }
+
+ ip, err := serverMachine.IPAddress()
+ if err != nil {
+ b.Fatal("failed to get IP from server: %v", err)
+ }
+
+ serverPort, err := server.FindPort(ctx, port)
+ if err != nil {
+ b.Fatal("failed to get IP from server: %v", err)
+ }
+
+ if err = harness.WaitUntilServing(ctx, clientMachine, ip, serverPort); err != nil {
+ b.Fatalf("failed to start redis with: %v", err)
+ }
+
+ // runs redis benchmark -t operation for 100K requests against server.
+ cmd := strings.Split(
+ fmt.Sprintf("redis-benchmark --csv -t %s -h %s -p %d", operation, ip, serverPort), " ")
+
+ // There is no -t PING_BULK for redis-benchmark, so adjust the command in that case.
+ // Note that "ping" will run both PING_INLINE and PING_BULK.
+ if operation == "PING_BULK" {
+ cmd = strings.Split(
+ fmt.Sprintf("redis-benchmark --csv -t ping -h %s -p %d", ip, serverPort), " ")
+ }
+ // Reset profiles and timer to begin the measurement.
+ server.RestartProfiles()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ client := clientMachine.GetNativeContainer(ctx, b)
+ defer client.CleanUp(ctx)
+ out, err := client.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/redis",
+ }, cmd...)
+ if err != nil {
+ b.Fatalf("redis-benchmark failed with: %v", err)
+ }
+
+ // Stop time while we parse results.
+ b.StopTimer()
+ result, err := parseOperation(operation, out)
+ if err != nil {
+ b.Fatalf("parsing result %s failed with err: %v", out, err)
+ }
+ b.ReportMetric(result, operation) // operations per second
+ b.StartTimer()
+ }
+ })
+ }
+}
+
+// parseOperation grabs the metric operations per second from redis-benchmark output.
+func parseOperation(operation, data string) (float64, error) {
+ re := regexp.MustCompile(fmt.Sprintf(`"%s( .*)?","(\d*\.\d*)"`, operation))
+ match := re.FindStringSubmatch(data)
+ // If no match, simply don't add it to the result map.
+ if len(match) < 3 {
+ return 0.0, fmt.Errorf("could not find %s in %s", operation, data)
+ }
+ return strconv.ParseFloat(match[2], 64)
+}
+
+// TestParser tests the parser on sample data.
+func TestParser(t *testing.T) {
+ sampleData := `
+ "PING_INLINE","48661.80"
+ "PING_BULK","50301.81"
+ "SET","48923.68"
+ "GET","49382.71"
+ "INCR","49975.02"
+ "LPUSH","49875.31"
+ "RPUSH","50276.52"
+ "LPOP","50327.12"
+ "RPOP","50556.12"
+ "SADD","49504.95"
+ "HSET","49504.95"
+ "SPOP","50025.02"
+ "LPUSH (needed to benchmark LRANGE)","48875.86"
+ "LRANGE_100 (first 100 elements)","33955.86"
+ "LRANGE_300 (first 300 elements)","16550.81"
+ "LRANGE_500 (first 450 elements)","13653.74"
+ "LRANGE_600 (first 600 elements)","11219.57"
+ "MSET (10 keys)","44682.75"
+ `
+ wants := map[string]float64{
+ "PING_INLINE": 48661.80,
+ "PING_BULK": 50301.81,
+ "SET": 48923.68,
+ "GET": 49382.71,
+ "INCR": 49975.02,
+ "LPUSH": 49875.31,
+ "RPUSH": 50276.52,
+ "LPOP": 50327.12,
+ "RPOP": 50556.12,
+ "SADD": 49504.95,
+ "HSET": 49504.95,
+ "SPOP": 50025.02,
+ "LRANGE_100": 33955.86,
+ "LRANGE_300": 16550.81,
+ "LRANGE_500": 13653.74,
+ "LRANGE_600": 11219.57,
+ "MSET": 44682.75,
+ }
+ for op, want := range wants {
+ if got, err := parseOperation(op, sampleData); err != nil {
+ t.Fatalf("failed to parse %s: %v", op, err)
+ } else if want != got {
+ t.Fatalf("wanted %f for op %s, got %f", want, op, got)
+ }
+ }
+}
diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go
index fdcac1a7a..9b652fd43 100644
--- a/test/benchmarks/fs/bazel_test.go
+++ b/test/benchmarks/fs/bazel_test.go
@@ -15,6 +15,7 @@ package fs
import (
"context"
+ "fmt"
"strings"
"testing"
@@ -51,10 +52,10 @@ func BenchmarkABSL(b *testing.B) {
workdir := "/abseil-cpp"
- // Start a container.
+ // Start a container and sleep by an order of b.N.
if err := container.Spawn(ctx, dockerutil.RunOpts{
Image: "benchmarks/absl",
- }, "sleep", "1000"); err != nil {
+ }, "sleep", fmt.Sprintf("%d", 1000000)); err != nil {
b.Fatalf("run failed with: %v", err)
}
@@ -67,15 +68,21 @@ func BenchmarkABSL(b *testing.B) {
workdir = "/tmp" + workdir
}
- // Drop Caches.
- if bm.clearCache {
- if out, err := machine.RunCommand("/bin/sh -c sync; echo 3 > /proc/sys/vm/drop_caches"); err != nil {
- b.Fatalf("failed to drop caches: %v %s", err, out)
- }
- }
-
+ // Restart profiles after the copy.
+ container.RestartProfiles()
b.ResetTimer()
+ // Drop Caches and bazel clean should happen inside the loop as we may use
+ // time options with b.N. (e.g. Run for an hour.)
for i := 0; i < b.N; i++ {
+ b.StopTimer()
+ // Drop Caches for clear cache runs.
+ if bm.clearCache {
+ if out, err := machine.RunCommand("/bin/sh", "-c", "sync && sysctl vm.drop_caches=3"); err != nil {
+ b.Skipf("failed to drop caches: %v %s. You probably need root.", err, out)
+ }
+ }
+ b.StartTimer()
+
got, err := container.Exec(ctx, dockerutil.ExecOpts{
WorkDir: workdir,
}, "bazel", "build", "-c", "opt", "absl/base/...")
@@ -88,6 +95,13 @@ func BenchmarkABSL(b *testing.B) {
if !strings.Contains(got, want) {
b.Fatalf("string %s not in: %s", want, got)
}
+ // Clean bazel in case we use b.N.
+ _, err = container.Exec(ctx, dockerutil.ExecOpts{
+ WorkDir: workdir,
+ }, "bazel", "clean")
+ if err != nil {
+ b.Fatalf("build failed with: %v", err)
+ }
b.StartTimer()
}
})
diff --git a/test/benchmarks/harness/machine.go b/test/benchmarks/harness/machine.go
index 93c0db9ce..88e5e841b 100644
--- a/test/benchmarks/harness/machine.go
+++ b/test/benchmarks/harness/machine.go
@@ -25,9 +25,14 @@ import (
// Machine describes a real machine for use in benchmarks.
type Machine interface {
- // GetContainer gets a container from the machine,
+ // GetContainer gets a container from the machine. The container uses the
+ // runtime under test and is profiled if requested by flags.
GetContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container
+ // GetNativeContainer gets a native container from the machine. Native containers
+ // use runc by default and are not profiled.
+ GetNativeContainer(ctx context.Context, log testutil.Logger) *dockerutil.Container
+
// RunCommand runs cmd on this machine.
RunCommand(cmd string, args ...string) (string, error)
@@ -47,6 +52,11 @@ func (l *localMachine) GetContainer(ctx context.Context, logger testutil.Logger)
return dockerutil.MakeContainer(ctx, logger)
}
+// GetContainer implements Machine.GetContainer for localMachine.
+func (l *localMachine) GetNativeContainer(ctx context.Context, logger testutil.Logger) *dockerutil.Container {
+ return dockerutil.MakeNativeContainer(ctx, logger)
+}
+
// RunCommand implements Machine.RunCommand for localMachine.
func (l *localMachine) RunCommand(cmd string, args ...string) (string, error) {
c := exec.Command(cmd, args...)
diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go
index cc7de6426..bc551c582 100644
--- a/test/benchmarks/harness/util.go
+++ b/test/benchmarks/harness/util.go
@@ -27,12 +27,20 @@ import (
// IP:port.
func WaitUntilServing(ctx context.Context, machine Machine, server net.IP, port int) error {
var logger testutil.DefaultLogger = "netcat"
- netcat := machine.GetContainer(ctx, logger)
+ netcat := machine.GetNativeContainer(ctx, logger)
defer netcat.CleanUp(ctx)
- cmd := fmt.Sprintf("while ! nc -zv %s %d; do true; done", server.String(), port)
+ cmd := fmt.Sprintf("while ! nc -zv %s %d; do true; done", server, port)
_, err := netcat.Run(ctx, dockerutil.RunOpts{
Image: "packetdrill",
}, "sh", "-c", cmd)
return err
}
+
+// DropCaches drops caches on the provided machine. Requires root.
+func DropCaches(machine Machine) error {
+ if out, err := machine.RunCommand("/bin/sh", "-c", "sync | sysctl vm.drop_caches=3"); err != nil {
+ return fmt.Errorf("failed to drop caches: %v logs: %s", err, out)
+ }
+ return nil
+}
diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD
new file mode 100644
index 000000000..6c41fc4f6
--- /dev/null
+++ b/test/benchmarks/media/BUILD
@@ -0,0 +1,21 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "media",
+ testonly = 1,
+ srcs = ["media.go"],
+ deps = ["//test/benchmarks/harness"],
+)
+
+go_test(
+ name = "media_test",
+ size = "large",
+ srcs = ["ffmpeg_test.go"],
+ library = ":media",
+ deps = [
+ "//pkg/test/dockerutil",
+ "//test/benchmarks/harness",
+ ],
+)
diff --git a/test/benchmarks/media/ffmpeg_test.go b/test/benchmarks/media/ffmpeg_test.go
new file mode 100644
index 000000000..bfcfbab80
--- /dev/null
+++ b/test/benchmarks/media/ffmpeg_test.go
@@ -0,0 +1,52 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+package media
+
+import (
+ "context"
+ "strings"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+// BenchmarkFfmpeg runs ffmpeg in a container and records runtime.
+// BenchmarkFfmpeg should run as root to drop caches.
+func BenchmarkFfmpeg(b *testing.B) {
+ machine, err := h.GetMachine()
+ if err != nil {
+ b.Fatalf("failed to get machine: %v", err)
+ }
+ defer machine.CleanUp()
+
+ ctx := context.Background()
+ container := machine.GetContainer(ctx, b)
+ cmd := strings.Split("ffmpeg -i video.mp4 -c:v libx264 -preset veryslow output.mp4", " ")
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ b.StopTimer()
+ if err := harness.DropCaches(machine); err != nil {
+ b.Skipf("failed to drop caches: %v. You probably need root.", err)
+ }
+ b.StartTimer()
+
+ if _, err := container.Run(ctx, dockerutil.RunOpts{
+ Image: "benchmarks/ffmpeg",
+ }, cmd...); err != nil {
+ b.Fatalf("failed to run container: %v", err)
+ }
+ }
+}
diff --git a/test/benchmarks/media/media.go b/test/benchmarks/media/media.go
new file mode 100644
index 000000000..c7b35b758
--- /dev/null
+++ b/test/benchmarks/media/media.go
@@ -0,0 +1,31 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package media holds benchmarks around media processing applications.
+package media
+
+import (
+ "os"
+ "testing"
+
+ "gvisor.dev/gvisor/test/benchmarks/harness"
+)
+
+var h harness.Harness
+
+// TestMain is the main method for package media.
+func TestMain(m *testing.M) {
+ h.Init()
+ os.Exit(m.Run())
+}
diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD
index 16d267bc8..363041fb7 100644
--- a/test/benchmarks/network/BUILD
+++ b/test/benchmarks/network/BUILD
@@ -24,6 +24,7 @@ go_test(
],
deps = [
"//pkg/test/dockerutil",
+ "//pkg/test/testutil",
"//test/benchmarks/harness",
],
)
diff --git a/test/benchmarks/network/httpd_test.go b/test/benchmarks/network/httpd_test.go
index f9afdf15f..fe23ca949 100644
--- a/test/benchmarks/network/httpd_test.go
+++ b/test/benchmarks/network/httpd_test.go
@@ -52,12 +52,12 @@ func BenchmarkHttpdConcurrency(b *testing.B) {
defer serverMachine.CleanUp()
// The test iterates over client concurrency, so set other parameters.
- requests := 1000
+ requests := 10000
concurrency := []int{1, 5, 10, 25}
doc := docs["10Kb"]
for _, c := range concurrency {
- b.Run(fmt.Sprintf("%dConcurrency", c), func(b *testing.B) {
+ b.Run(fmt.Sprintf("%d", c), func(b *testing.B) {
runHttpd(b, clientMachine, serverMachine, doc, requests, c)
})
}
@@ -78,7 +78,7 @@ func BenchmarkHttpdDocSize(b *testing.B) {
}
defer serverMachine.CleanUp()
- requests := 1000
+ requests := 10000
concurrency := 1
for name, filename := range docs {
@@ -129,7 +129,7 @@ func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, doc st
harness.WaitUntilServing(ctx, clientMachine, ip, servingPort)
// Grab a client.
- client := clientMachine.GetContainer(ctx, b)
+ client := clientMachine.GetNativeContainer(ctx, b)
defer client.CleanUp(ctx)
path := fmt.Sprintf("http://%s:%d/%s", ip, servingPort, doc)
@@ -137,6 +137,7 @@ func runHttpd(b *testing.B, clientMachine, serverMachine harness.Machine, doc st
cmd = fmt.Sprintf("ab -n %d -c %d %s", requests, concurrency, path)
b.ResetTimer()
+ server.RestartProfiles()
for i := 0; i < b.N; i++ {
out, err := client.Run(ctx, dockerutil.RunOpts{
Image: "benchmarks/ab",
diff --git a/test/benchmarks/network/iperf_test.go b/test/benchmarks/network/iperf_test.go
index 664e0797e..a5e198e14 100644
--- a/test/benchmarks/network/iperf_test.go
+++ b/test/benchmarks/network/iperf_test.go
@@ -22,12 +22,13 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/test/dockerutil"
+ "gvisor.dev/gvisor/pkg/test/testutil"
"gvisor.dev/gvisor/test/benchmarks/harness"
)
func BenchmarkIperf(b *testing.B) {
+ const time = 10 // time in seconds to run the client.
- // Get two machines
clientMachine, err := h.GetMachine()
if err != nil {
b.Fatalf("failed to get machine: %v", err)
@@ -39,30 +40,32 @@ func BenchmarkIperf(b *testing.B) {
b.Fatalf("failed to get machine: %v", err)
}
defer serverMachine.CleanUp()
-
+ ctx := context.Background()
for _, bm := range []struct {
- name string
- clientRuntime string
- serverRuntime string
+ name string
+ clientFunc func(context.Context, testutil.Logger) *dockerutil.Container
+ serverFunc func(context.Context, testutil.Logger) *dockerutil.Container
}{
// We are either measuring the server or the client. The other should be
// runc. e.g. Upload sees how fast the runtime under test uploads to a native
// server.
- {name: "Upload", clientRuntime: dockerutil.Runtime(), serverRuntime: "runc"},
- {name: "Download", clientRuntime: "runc", serverRuntime: dockerutil.Runtime()},
+ {
+ name: "Upload",
+ clientFunc: clientMachine.GetContainer,
+ serverFunc: serverMachine.GetNativeContainer,
+ },
+ {
+ name: "Download",
+ clientFunc: clientMachine.GetNativeContainer,
+ serverFunc: serverMachine.GetContainer,
+ },
} {
b.Run(bm.name, func(b *testing.B) {
-
- // Get a container from the server and set its runtime.
- ctx := context.Background()
- server := serverMachine.GetContainer(ctx, b)
+ // Set up the containers.
+ server := bm.serverFunc(ctx, b)
defer server.CleanUp(ctx)
- server.Runtime = bm.serverRuntime
-
- // Get a container from the client and set its runtime.
- client := clientMachine.GetContainer(ctx, b)
+ client := bm.clientFunc(ctx, b)
defer client.CleanUp(ctx)
- client.Runtime = bm.clientRuntime
// iperf serves on port 5001 by default.
port := 5001
@@ -91,11 +94,14 @@ func BenchmarkIperf(b *testing.B) {
}
// iperf report in Kb realtime
- cmd := fmt.Sprintf("iperf -f K --realtime -c %s -p %d", ip.String(), servingPort)
+ cmd := fmt.Sprintf("iperf -f K --realtime --time %d -c %s -p %d", time, ip.String(), servingPort)
// Run the client.
b.ResetTimer()
+ // Restart the server profiles. If the server isn't being profiled
+ // this does nothing.
+ server.RestartProfiles()
for i := 0; i < b.N; i++ {
out, err := client.Run(ctx, dockerutil.RunOpts{
Image: "benchmarks/iperf",
diff --git a/test/iptables/filter_input.go b/test/iptables/filter_input.go
index 068f228bd..af4355ba8 100644
--- a/test/iptables/filter_input.go
+++ b/test/iptables/filter_input.go
@@ -81,7 +81,7 @@ func (FilterInputDropUDP) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropUDP) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, dropPort, sendloopDuration)
+ return spawnUDPLoop(ip, dropPort, sendloopDuration)
}
// FilterInputDropOnlyUDP tests that "-p udp -j DROP" only affects UDP traffic.
@@ -141,7 +141,7 @@ func (FilterInputDropUDPPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropUDPPort) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, dropPort, sendloopDuration)
+ return spawnUDPLoop(ip, dropPort, sendloopDuration)
}
// FilterInputDropDifferentUDPPort tests that dropping traffic for a single UDP port
@@ -169,7 +169,7 @@ func (FilterInputDropDifferentUDPPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropDifferentUDPPort) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// FilterInputDropTCPDestPort tests that connections are not accepted on specified source ports.
@@ -269,7 +269,7 @@ func (FilterInputDropAll) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDropAll) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, dropPort, sendloopDuration)
+ return spawnUDPLoop(ip, dropPort, sendloopDuration)
}
// FilterInputMultiUDPRules verifies that multiple UDP rules are applied
@@ -365,7 +365,7 @@ func (FilterInputDefaultPolicyAccept) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDefaultPolicyAccept) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// FilterInputDefaultPolicyDrop tests the default DROP policy.
@@ -396,7 +396,7 @@ func (FilterInputDefaultPolicyDrop) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDefaultPolicyDrop) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// FilterInputReturnUnderflow tests that -j RETURN in a built-in chain causes
@@ -428,7 +428,7 @@ func (FilterInputReturnUnderflow) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputReturnUnderflow) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// FilterInputSerializeJump verifies that we can serialize jumps.
@@ -482,7 +482,7 @@ func (FilterInputJumpBasic) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputJumpBasic) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// FilterInputJumpReturn jumps, returns, and executes a rule.
@@ -512,7 +512,7 @@ func (FilterInputJumpReturn) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputJumpReturn) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// FilterInputJumpReturnDrop jumps to a chain, returns, and DROPs packets.
@@ -549,7 +549,7 @@ func (FilterInputJumpReturnDrop) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputJumpReturnDrop) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, dropPort, sendloopDuration)
+ return spawnUDPLoop(ip, dropPort, sendloopDuration)
}
// FilterInputJumpBuiltin verifies that jumping to a top-levl chain is illegal.
@@ -604,7 +604,7 @@ func (FilterInputJumpTwice) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputJumpTwice) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// FilterInputDestination verifies that we can filter packets via `-d
@@ -638,7 +638,7 @@ func (FilterInputDestination) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputDestination) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// FilterInputInvertDestination verifies that we can filter packets via `! -d
@@ -667,7 +667,7 @@ func (FilterInputInvertDestination) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputInvertDestination) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// FilterInputSource verifies that we can filter packets via `-s
@@ -696,7 +696,7 @@ func (FilterInputSource) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputSource) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// FilterInputInvertSource verifies that we can filter packets via `! -s
@@ -725,5 +725,5 @@ func (FilterInputInvertSource) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (FilterInputInvertSource) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
diff --git a/test/iptables/iptables_util.go b/test/iptables/iptables_util.go
index d4bc55b24..174694002 100644
--- a/test/iptables/iptables_util.go
+++ b/test/iptables/iptables_util.go
@@ -84,17 +84,42 @@ func listenUDP(port int, timeout time.Duration) error {
// sendUDPLoop sends 1 byte UDP packets repeatedly to the IP and port specified
// over a duration.
func sendUDPLoop(ip net.IP, port int, duration time.Duration) error {
- // Send packets for a few seconds.
+ conn, err := connectUDP(ip, port)
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+ loopUDP(conn, duration)
+ return nil
+}
+
+// spawnUDPLoop works like sendUDPLoop, but returns immediately and sends
+// packets in another goroutine.
+func spawnUDPLoop(ip net.IP, port int, duration time.Duration) error {
+ conn, err := connectUDP(ip, port)
+ if err != nil {
+ return err
+ }
+ go func() {
+ defer conn.Close()
+ loopUDP(conn, duration)
+ }()
+ return nil
+}
+
+func connectUDP(ip net.IP, port int) (net.Conn, error) {
remote := net.UDPAddr{
IP: ip,
Port: port,
}
conn, err := net.DialUDP(network, nil, &remote)
if err != nil {
- return err
+ return nil, err
}
- defer conn.Close()
+ return conn, nil
+}
+func loopUDP(conn net.Conn, duration time.Duration) {
to := time.After(duration)
for timedOut := false; !timedOut; {
// This may return an error (connection refused) if the remote
@@ -109,8 +134,6 @@ func sendUDPLoop(ip net.IP, port int, duration time.Duration) error {
time.Sleep(200 * time.Millisecond)
}
}
-
- return nil
}
// listenTCP listens for connections on a TCP port.
diff --git a/test/iptables/nat.go b/test/iptables/nat.go
index 149dec2bb..23288577d 100644
--- a/test/iptables/nat.go
+++ b/test/iptables/nat.go
@@ -67,7 +67,7 @@ func (NATPreRedirectUDPPort) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (NATPreRedirectUDPPort) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// NATPreRedirectTCPPort tests that connections are redirected on specified ports.
@@ -187,7 +187,7 @@ func (NATDropUDP) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (NATDropUDP) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// NATAcceptAll tests that all UDP packets are accepted.
@@ -213,7 +213,7 @@ func (NATAcceptAll) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (NATAcceptAll) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// NATOutRedirectIP uses iptables to select packets based on destination IP and
@@ -310,7 +310,7 @@ func (NATPreRedirectIP) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (NATPreRedirectIP) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, dropPort, sendloopDuration)
+ return spawnUDPLoop(ip, dropPort, sendloopDuration)
}
// NATPreDontRedirectIP tests that iptables matching with "-d" does not match
@@ -332,7 +332,7 @@ func (NATPreDontRedirectIP) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (NATPreDontRedirectIP) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, acceptPort, sendloopDuration)
+ return spawnUDPLoop(ip, acceptPort, sendloopDuration)
}
// NATPreRedirectInvert tests that iptables can match with "! -d".
@@ -353,7 +353,7 @@ func (NATPreRedirectInvert) ContainerAction(ip net.IP) error {
// LocalAction implements TestCase.LocalAction.
func (NATPreRedirectInvert) LocalAction(ip net.IP) error {
- return sendUDPLoop(ip, dropPort, sendloopDuration)
+ return spawnUDPLoop(ip, dropPort, sendloopDuration)
}
// NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a
diff --git a/test/packetimpact/runner/packetimpact_test.go b/test/packetimpact/runner/packetimpact_test.go
index 1a0221893..74e1e6def 100644
--- a/test/packetimpact/runner/packetimpact_test.go
+++ b/test/packetimpact/runner/packetimpact_test.go
@@ -142,7 +142,7 @@ func TestOne(t *testing.T) {
// Create the Docker container for the DUT.
dut := dockerutil.MakeContainer(ctx, logger("dut"))
if *dutPlatform == "linux" {
- dut.Runtime = ""
+ dut = dockerutil.MakeNativeContainer(ctx, logger("dut"))
}
runOpts := dockerutil.RunOpts{
@@ -208,8 +208,7 @@ func TestOne(t *testing.T) {
}
// Create the Docker container for the testbench.
- testbench := dockerutil.MakeContainer(ctx, logger("testbench"))
- testbench.Runtime = "" // The testbench always runs on Linux.
+ testbench := dockerutil.MakeNativeContainer(ctx, logger("testbench"))
tbb := path.Base(*testbenchBinary)
containerTestbenchBinary := "/packetimpact/" + tbb
diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go
index 87ce58c24..3af5f83fd 100644
--- a/test/packetimpact/testbench/connections.go
+++ b/test/packetimpact/testbench/connections.go
@@ -429,7 +429,6 @@ type Connection struct {
layerStates []layerState
injector Injector
sniffer Sniffer
- t *testing.T
}
// Returns the default incoming frame against which to match. If received is
@@ -462,7 +461,9 @@ func (conn *Connection) match(override, received Layers) bool {
}
// Close frees associated resources held by the Connection.
-func (conn *Connection) Close() {
+func (conn *Connection) Close(t *testing.T) {
+ t.Helper()
+
errs := multierr.Combine(conn.sniffer.close(), conn.injector.close())
for _, s := range conn.layerStates {
if err := s.close(); err != nil {
@@ -470,7 +471,7 @@ func (conn *Connection) Close() {
}
}
if errs != nil {
- conn.t.Fatalf("unable to close %+v: %s", conn, errs)
+ t.Fatalf("unable to close %+v: %s", conn, errs)
}
}
@@ -482,7 +483,9 @@ func (conn *Connection) Close() {
// overriden first. As an example, valid values of overrideLayers for a TCP-
// over-IPv4-over-Ethernet connection are: nil, [TCP], [IPv4, TCP], and
// [Ethernet, IPv4, TCP].
-func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...Layer) Layers {
+func (conn *Connection) CreateFrame(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) Layers {
+ t.Helper()
+
var layersToSend Layers
for i, s := range conn.layerStates {
layer := s.outgoing()
@@ -491,7 +494,7 @@ func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...L
// end.
if j := len(overrideLayers) - (len(conn.layerStates) - i); j >= 0 {
if err := layer.merge(overrideLayers[j]); err != nil {
- conn.t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err)
+ t.Fatalf("can't merge %+v into %+v: %s", layer, overrideLayers[j], err)
}
}
layersToSend = append(layersToSend, layer)
@@ -505,21 +508,25 @@ func (conn *Connection) CreateFrame(overrideLayers Layers, additionalLayers ...L
// This method is useful for sending out-of-band control messages such as
// ICMP packets, where it would not make sense to update the transport layer's
// state using the ICMP header.
-func (conn *Connection) SendFrameStateless(frame Layers) {
+func (conn *Connection) SendFrameStateless(t *testing.T, frame Layers) {
+ t.Helper()
+
outBytes, err := frame.ToBytes()
if err != nil {
- conn.t.Fatalf("can't build outgoing packet: %s", err)
+ t.Fatalf("can't build outgoing packet: %s", err)
}
- conn.injector.Send(outBytes)
+ conn.injector.Send(t, outBytes)
}
// SendFrame sends a frame on the wire and updates the state of all layers.
-func (conn *Connection) SendFrame(frame Layers) {
+func (conn *Connection) SendFrame(t *testing.T, frame Layers) {
+ t.Helper()
+
outBytes, err := frame.ToBytes()
if err != nil {
- conn.t.Fatalf("can't build outgoing packet: %s", err)
+ t.Fatalf("can't build outgoing packet: %s", err)
}
- conn.injector.Send(outBytes)
+ conn.injector.Send(t, outBytes)
// frame might have nil values where the caller wanted to use default values.
// sentFrame will have no nil values in it because it comes from parsing the
@@ -528,7 +535,7 @@ func (conn *Connection) SendFrame(frame Layers) {
// Update the state of each layer based on what was sent.
for i, s := range conn.layerStates {
if err := s.sent(sentFrame[i]); err != nil {
- conn.t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
+ t.Fatalf("Unable to update the state of %+v with %s: %s", s, sentFrame[i], err)
}
}
}
@@ -538,18 +545,22 @@ func (conn *Connection) SendFrame(frame Layers) {
//
// Types defined with Connection as the underlying type should expose
// type-safe versions of this method.
-func (conn *Connection) send(overrideLayers Layers, additionalLayers ...Layer) {
- conn.SendFrame(conn.CreateFrame(overrideLayers, additionalLayers...))
+func (conn *Connection) send(t *testing.T, overrideLayers Layers, additionalLayers ...Layer) {
+ t.Helper()
+
+ conn.SendFrame(t, conn.CreateFrame(t, overrideLayers, additionalLayers...))
}
// recvFrame gets the next successfully parsed frame (of type Layers) within the
// timeout provided. If no parsable frame arrives before the timeout, it returns
// nil.
-func (conn *Connection) recvFrame(timeout time.Duration) Layers {
+func (conn *Connection) recvFrame(t *testing.T, timeout time.Duration) Layers {
+ t.Helper()
+
if timeout <= 0 {
return nil
}
- b := conn.sniffer.Recv(timeout)
+ b := conn.sniffer.Recv(t, timeout)
if b == nil {
return nil
}
@@ -569,32 +580,36 @@ func (e *layersError) Error() string {
// Expect expects a frame with the final layerStates layer matching the
// provided Layer within the timeout specified. If it doesn't arrive in time,
// an error is returned.
-func (conn *Connection) Expect(layer Layer, timeout time.Duration) (Layer, error) {
+func (conn *Connection) Expect(t *testing.T, layer Layer, timeout time.Duration) (Layer, error) {
+ t.Helper()
+
// Make a frame that will ignore all but the final layer.
layers := make([]Layer, len(conn.layerStates))
layers[len(layers)-1] = layer
- gotFrame, err := conn.ExpectFrame(layers, timeout)
+ gotFrame, err := conn.ExpectFrame(t, layers, timeout)
if err != nil {
return nil, err
}
if len(conn.layerStates)-1 < len(gotFrame) {
return gotFrame[len(conn.layerStates)-1], nil
}
- conn.t.Fatal("the received frame should be at least as long as the expected layers")
+ t.Fatalf("the received frame should be at least as long as the expected layers, got %d layers, want at least %d layers, got frame: %#v", len(gotFrame), len(conn.layerStates), gotFrame)
panic("unreachable")
}
// ExpectFrame expects a frame that matches the provided Layers within the
// timeout specified. If one arrives in time, the Layers is returned without an
// error. If it doesn't arrive in time, it returns nil and error is non-nil.
-func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layers, error) {
+func (conn *Connection) ExpectFrame(t *testing.T, layers Layers, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
deadline := time.Now().Add(timeout)
var errs error
for {
var gotLayers Layers
if timeout = time.Until(deadline); timeout > 0 {
- gotLayers = conn.recvFrame(timeout)
+ gotLayers = conn.recvFrame(t, timeout)
}
if gotLayers == nil {
if errs == nil {
@@ -605,7 +620,7 @@ func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layer
if conn.match(layers, gotLayers) {
for i, s := range conn.layerStates {
if err := s.received(gotLayers[i]); err != nil {
- conn.t.Fatal(err)
+ t.Fatalf("failed to update test connection's layer states based on received frame: %s", err)
}
}
return gotLayers, nil
@@ -616,8 +631,10 @@ func (conn *Connection) ExpectFrame(layers Layers, timeout time.Duration) (Layer
// Drain drains the sniffer's receive buffer by receiving packets until there's
// nothing else to receive.
-func (conn *Connection) Drain() {
- conn.sniffer.Drain()
+func (conn *Connection) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
}
// TCPIPv4 maintains the state for all the layers in a TCP/IPv4 connection.
@@ -625,6 +642,8 @@ type TCPIPv4 Connection
// NewTCPIPv4 creates a new TCPIPv4 connection with reasonable defaults.
func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
+ t.Helper()
+
etherState, err := newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make etherState: %s", err)
@@ -650,57 +669,58 @@ func NewTCPIPv4(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv4 {
layerStates: []layerState{etherState, ipv4State, tcpState},
injector: injector,
sniffer: sniffer,
- t: t,
}
}
// Connect performs a TCP 3-way handshake. The input Connection should have a
// final TCP Layer.
-func (conn *TCPIPv4) Connect() {
- conn.t.Helper()
+func (conn *TCPIPv4) Connect(t *testing.T) {
+ t.Helper()
// Send the SYN.
- conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn)})
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn)})
// Wait for the SYN-ACK.
- synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
if err != nil {
- conn.t.Fatalf("didn't get synack during handshake: %s", err)
+ t.Fatalf("didn't get synack during handshake: %s", err)
}
conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
// Send an ACK.
- conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)})
}
// ConnectWithOptions performs a TCP 3-way handshake with given TCP options.
// The input Connection should have a final TCP Layer.
-func (conn *TCPIPv4) ConnectWithOptions(options []byte) {
- conn.t.Helper()
+func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) {
+ t.Helper()
// Send the SYN.
- conn.Send(TCP{Flags: Uint8(header.TCPFlagSyn), Options: options})
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn), Options: options})
// Wait for the SYN-ACK.
- synAck, err := conn.Expect(TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
if err != nil {
- conn.t.Fatalf("didn't get synack during handshake: %s", err)
+ t.Fatalf("didn't get synack during handshake: %s", err)
}
conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck
// Send an ACK.
- conn.Send(TCP{Flags: Uint8(header.TCPFlagAck)})
+ conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)})
}
// ExpectData is a convenient method that expects a Layer and the Layer after
// it. If it doens't arrive in time, it returns nil.
-func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+func (conn *TCPIPv4) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
expected := make([]Layer, len(conn.layerStates))
expected[len(expected)-1] = tcp
if payload != nil {
expected = append(expected, payload)
}
- return (*Connection)(conn).ExpectFrame(expected, timeout)
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
}
// ExpectNextData attempts to receive the next incoming segment for the
@@ -709,9 +729,11 @@ func (conn *TCPIPv4) ExpectData(tcp *TCP, payload *Payload, timeout time.Duratio
// It differs from ExpectData() in that here we are only interested in the next
// received segment, while ExpectData() can receive multiple segments for the
// connection until there is a match with given layers or a timeout.
-func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+func (conn *TCPIPv4) ExpectNextData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
// Receive the first incoming TCP segment for this connection.
- got, err := conn.ExpectData(&TCP{}, nil, timeout)
+ got, err := conn.ExpectData(t, &TCP{}, nil, timeout)
if err != nil {
return nil, err
}
@@ -720,7 +742,7 @@ func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Dur
expected[len(expected)-1] = tcp
if payload != nil {
expected = append(expected, payload)
- tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum()) - uint32(payload.Length()))
+ tcp.SeqNum = Uint32(uint32(*conn.RemoteSeqNum(t)) - uint32(payload.Length()))
}
if !(*Connection)(conn).match(expected, got) {
return nil, fmt.Errorf("next frame is not matching %s during %s: got %s", expected, timeout, got)
@@ -730,71 +752,91 @@ func (conn *TCPIPv4) ExpectNextData(tcp *TCP, payload *Payload, timeout time.Dur
// Send a packet with reasonable defaults. Potentially override the TCP layer in
// the connection with the provided layer and add additionLayers.
-func (conn *TCPIPv4) Send(tcp TCP, additionalLayers ...Layer) {
- (*Connection)(conn).send(Layers{&tcp}, additionalLayers...)
+func (conn *TCPIPv4) Send(t *testing.T, tcp TCP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&tcp}, additionalLayers...)
}
// Close frees associated resources held by the TCPIPv4 connection.
-func (conn *TCPIPv4) Close() {
- (*Connection)(conn).Close()
+func (conn *TCPIPv4) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
}
// Expect expects a frame with the TCP layer matching the provided TCP within
// the timeout specified. If it doesn't arrive in time, an error is returned.
-func (conn *TCPIPv4) Expect(tcp TCP, timeout time.Duration) (*TCP, error) {
- layer, err := (*Connection)(conn).Expect(&tcp, timeout)
+func (conn *TCPIPv4) Expect(t *testing.T, tcp TCP, timeout time.Duration) (*TCP, error) {
+ t.Helper()
+
+ layer, err := (*Connection)(conn).Expect(t, &tcp, timeout)
if layer == nil {
return nil, err
}
gotTCP, ok := layer.(*TCP)
if !ok {
- conn.t.Fatalf("expected %s to be TCP", layer)
+ t.Fatalf("expected %s to be TCP", layer)
}
return gotTCP, err
}
-func (conn *TCPIPv4) tcpState() *tcpState {
+func (conn *TCPIPv4) tcpState(t *testing.T) *tcpState {
+ t.Helper()
+
state, ok := conn.layerStates[2].(*tcpState)
if !ok {
- conn.t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2])
+ t.Fatalf("got transport-layer state type=%T, expected tcpState", conn.layerStates[2])
}
return state
}
-func (conn *TCPIPv4) ipv4State() *ipv4State {
+func (conn *TCPIPv4) ipv4State(t *testing.T) *ipv4State {
+ t.Helper()
+
state, ok := conn.layerStates[1].(*ipv4State)
if !ok {
- conn.t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1])
+ t.Fatalf("expected network-layer state type=%T, expected ipv4State", conn.layerStates[1])
}
return state
}
// RemoteSeqNum returns the next expected sequence number from the DUT.
-func (conn *TCPIPv4) RemoteSeqNum() *seqnum.Value {
- return conn.tcpState().remoteSeqNum
+func (conn *TCPIPv4) RemoteSeqNum(t *testing.T) *seqnum.Value {
+ t.Helper()
+
+ return conn.tcpState(t).remoteSeqNum
}
// LocalSeqNum returns the next sequence number to send from the testbench.
-func (conn *TCPIPv4) LocalSeqNum() *seqnum.Value {
- return conn.tcpState().localSeqNum
+func (conn *TCPIPv4) LocalSeqNum(t *testing.T) *seqnum.Value {
+ t.Helper()
+
+ return conn.tcpState(t).localSeqNum
}
// SynAck returns the SynAck that was part of the handshake.
-func (conn *TCPIPv4) SynAck() *TCP {
- return conn.tcpState().synAck
+func (conn *TCPIPv4) SynAck(t *testing.T) *TCP {
+ t.Helper()
+
+ return conn.tcpState(t).synAck
}
// LocalAddr gets the local socket address of this connection.
-func (conn *TCPIPv4) LocalAddr() *unix.SockaddrInet4 {
- sa := &unix.SockaddrInet4{Port: int(*conn.tcpState().out.SrcPort)}
- copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr)
+func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 {
+ t.Helper()
+
+ sa := &unix.SockaddrInet4{Port: int(*conn.tcpState(t).out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr)
return sa
}
// Drain drains the sniffer's receive buffer by receiving packets until there's
// nothing else to receive.
-func (conn *TCPIPv4) Drain() {
- conn.sniffer.Drain()
+func (conn *TCPIPv4) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
}
// IPv6Conn maintains the state for all the layers in a IPv6 connection.
@@ -802,6 +844,8 @@ type IPv6Conn Connection
// NewIPv6Conn creates a new IPv6Conn connection with reasonable defaults.
func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
+ t.Helper()
+
etherState, err := newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make EtherState: %s", err)
@@ -824,25 +868,30 @@ func NewIPv6Conn(t *testing.T, outgoingIPv6, incomingIPv6 IPv6) IPv6Conn {
layerStates: []layerState{etherState, ipv6State},
injector: injector,
sniffer: sniffer,
- t: t,
}
}
// Send sends a frame with ipv6 overriding the IPv6 layer defaults and
// additionalLayers added after it.
-func (conn *IPv6Conn) Send(ipv6 IPv6, additionalLayers ...Layer) {
- (*Connection)(conn).send(Layers{&ipv6}, additionalLayers...)
+func (conn *IPv6Conn) Send(t *testing.T, ipv6 IPv6, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&ipv6}, additionalLayers...)
}
// Close to clean up any resources held.
-func (conn *IPv6Conn) Close() {
- (*Connection)(conn).Close()
+func (conn *IPv6Conn) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
}
// ExpectFrame expects a frame that matches the provided Layers within the
// timeout specified. If it doesn't arrive in time, an error is returned.
-func (conn *IPv6Conn) ExpectFrame(frame Layers, timeout time.Duration) (Layers, error) {
- return (*Connection)(conn).ExpectFrame(frame, timeout)
+func (conn *IPv6Conn) ExpectFrame(t *testing.T, frame Layers, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
+ return (*Connection)(conn).ExpectFrame(t, frame, timeout)
}
// UDPIPv4 maintains the state for all the layers in a UDP/IPv4 connection.
@@ -850,6 +899,8 @@ type UDPIPv4 Connection
// NewUDPIPv4 creates a new UDPIPv4 connection with reasonable defaults.
func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
+ t.Helper()
+
etherState, err := newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make etherState: %s", err)
@@ -875,81 +926,96 @@ func NewUDPIPv4(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv4 {
layerStates: []layerState{etherState, ipv4State, udpState},
injector: injector,
sniffer: sniffer,
- t: t,
}
}
-func (conn *UDPIPv4) udpState() *udpState {
+func (conn *UDPIPv4) udpState(t *testing.T) *udpState {
+ t.Helper()
+
state, ok := conn.layerStates[2].(*udpState)
if !ok {
- conn.t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
+ t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
}
return state
}
-func (conn *UDPIPv4) ipv4State() *ipv4State {
+func (conn *UDPIPv4) ipv4State(t *testing.T) *ipv4State {
+ t.Helper()
+
state, ok := conn.layerStates[1].(*ipv4State)
if !ok {
- conn.t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1])
+ t.Fatalf("got network-layer state type=%T, expected ipv4State", conn.layerStates[1])
}
return state
}
// LocalAddr gets the local socket address of this connection.
-func (conn *UDPIPv4) LocalAddr() *unix.SockaddrInet4 {
- sa := &unix.SockaddrInet4{Port: int(*conn.udpState().out.SrcPort)}
- copy(sa.Addr[:], *conn.ipv4State().out.SrcAddr)
+func (conn *UDPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 {
+ t.Helper()
+
+ sa := &unix.SockaddrInet4{Port: int(*conn.udpState(t).out.SrcPort)}
+ copy(sa.Addr[:], *conn.ipv4State(t).out.SrcAddr)
return sa
}
// Send sends a packet with reasonable defaults, potentially overriding the UDP
// layer and adding additionLayers.
-func (conn *UDPIPv4) Send(udp UDP, additionalLayers ...Layer) {
- (*Connection)(conn).send(Layers{&udp}, additionalLayers...)
+func (conn *UDPIPv4) Send(t *testing.T, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...)
}
// SendIP sends a packet with reasonable defaults, potentially overriding the
// UDP and IPv4 headers and adding additionLayers.
-func (conn *UDPIPv4) SendIP(ip IPv4, udp UDP, additionalLayers ...Layer) {
- (*Connection)(conn).send(Layers{&ip, &udp}, additionalLayers...)
+func (conn *UDPIPv4) SendIP(t *testing.T, ip IPv4, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...)
}
// Expect expects a frame with the UDP layer matching the provided UDP within
// the timeout specified. If it doesn't arrive in time, an error is returned.
-func (conn *UDPIPv4) Expect(udp UDP, timeout time.Duration) (*UDP, error) {
- conn.t.Helper()
- layer, err := (*Connection)(conn).Expect(&udp, timeout)
+func (conn *UDPIPv4) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) {
+ t.Helper()
+
+ layer, err := (*Connection)(conn).Expect(t, &udp, timeout)
if err != nil {
return nil, err
}
gotUDP, ok := layer.(*UDP)
if !ok {
- conn.t.Fatalf("expected %s to be UDP", layer)
+ t.Fatalf("expected %s to be UDP", layer)
}
return gotUDP, nil
}
// ExpectData is a convenient method that expects a Layer and the Layer after
// it. If it doens't arrive in time, it returns nil.
-func (conn *UDPIPv4) ExpectData(udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
- conn.t.Helper()
+func (conn *UDPIPv4) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
expected := make([]Layer, len(conn.layerStates))
expected[len(expected)-1] = &udp
if payload.length() != 0 {
expected = append(expected, &payload)
}
- return (*Connection)(conn).ExpectFrame(expected, timeout)
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
}
// Close frees associated resources held by the UDPIPv4 connection.
-func (conn *UDPIPv4) Close() {
- (*Connection)(conn).Close()
+func (conn *UDPIPv4) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
}
// Drain drains the sniffer's receive buffer by receiving packets until there's
// nothing else to receive.
-func (conn *UDPIPv4) Drain() {
- conn.sniffer.Drain()
+func (conn *UDPIPv4) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
}
// UDPIPv6 maintains the state for all the layers in a UDP/IPv6 connection.
@@ -957,6 +1023,8 @@ type UDPIPv6 Connection
// NewUDPIPv6 creates a new UDPIPv6 connection with reasonable defaults.
func NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 {
+ t.Helper()
+
etherState, err := newEtherState(Ether{}, Ether{})
if err != nil {
t.Fatalf("can't make etherState: %s", err)
@@ -981,86 +1049,101 @@ func NewUDPIPv6(t *testing.T, outgoingUDP, incomingUDP UDP) UDPIPv6 {
layerStates: []layerState{etherState, ipv6State, udpState},
injector: injector,
sniffer: sniffer,
- t: t,
}
}
-func (conn *UDPIPv6) udpState() *udpState {
+func (conn *UDPIPv6) udpState(t *testing.T) *udpState {
+ t.Helper()
+
state, ok := conn.layerStates[2].(*udpState)
if !ok {
- conn.t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
+ t.Fatalf("got transport-layer state type=%T, expected udpState", conn.layerStates[2])
}
return state
}
-func (conn *UDPIPv6) ipv6State() *ipv6State {
+func (conn *UDPIPv6) ipv6State(t *testing.T) *ipv6State {
+ t.Helper()
+
state, ok := conn.layerStates[1].(*ipv6State)
if !ok {
- conn.t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1])
+ t.Fatalf("got network-layer state type=%T, expected ipv6State", conn.layerStates[1])
}
return state
}
// LocalAddr gets the local socket address of this connection.
-func (conn *UDPIPv6) LocalAddr() *unix.SockaddrInet6 {
+func (conn *UDPIPv6) LocalAddr(t *testing.T) *unix.SockaddrInet6 {
+ t.Helper()
+
sa := &unix.SockaddrInet6{
- Port: int(*conn.udpState().out.SrcPort),
+ Port: int(*conn.udpState(t).out.SrcPort),
// Local address is in perspective to the remote host, so it's scoped to the
// ID of the remote interface.
ZoneId: uint32(RemoteInterfaceID),
}
- copy(sa.Addr[:], *conn.ipv6State().out.SrcAddr)
+ copy(sa.Addr[:], *conn.ipv6State(t).out.SrcAddr)
return sa
}
// Send sends a packet with reasonable defaults, potentially overriding the UDP
// layer and adding additionLayers.
-func (conn *UDPIPv6) Send(udp UDP, additionalLayers ...Layer) {
- (*Connection)(conn).send(Layers{&udp}, additionalLayers...)
+func (conn *UDPIPv6) Send(t *testing.T, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&udp}, additionalLayers...)
}
// SendIPv6 sends a packet with reasonable defaults, potentially overriding the
// UDP and IPv6 headers and adding additionLayers.
-func (conn *UDPIPv6) SendIPv6(ip IPv6, udp UDP, additionalLayers ...Layer) {
- (*Connection)(conn).send(Layers{&ip, &udp}, additionalLayers...)
+func (conn *UDPIPv6) SendIPv6(t *testing.T, ip IPv6, udp UDP, additionalLayers ...Layer) {
+ t.Helper()
+
+ (*Connection)(conn).send(t, Layers{&ip, &udp}, additionalLayers...)
}
// Expect expects a frame with the UDP layer matching the provided UDP within
// the timeout specified. If it doesn't arrive in time, an error is returned.
-func (conn *UDPIPv6) Expect(udp UDP, timeout time.Duration) (*UDP, error) {
- conn.t.Helper()
- layer, err := (*Connection)(conn).Expect(&udp, timeout)
+func (conn *UDPIPv6) Expect(t *testing.T, udp UDP, timeout time.Duration) (*UDP, error) {
+ t.Helper()
+
+ layer, err := (*Connection)(conn).Expect(t, &udp, timeout)
if err != nil {
return nil, err
}
gotUDP, ok := layer.(*UDP)
if !ok {
- conn.t.Fatalf("expected %s to be UDP", layer)
+ t.Fatalf("expected %s to be UDP", layer)
}
return gotUDP, nil
}
// ExpectData is a convenient method that expects a Layer and the Layer after
// it. If it doens't arrive in time, it returns nil.
-func (conn *UDPIPv6) ExpectData(udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
- conn.t.Helper()
+func (conn *UDPIPv6) ExpectData(t *testing.T, udp UDP, payload Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
expected := make([]Layer, len(conn.layerStates))
expected[len(expected)-1] = &udp
if payload.length() != 0 {
expected = append(expected, &payload)
}
- return (*Connection)(conn).ExpectFrame(expected, timeout)
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
}
// Close frees associated resources held by the UDPIPv6 connection.
-func (conn *UDPIPv6) Close() {
- (*Connection)(conn).Close()
+func (conn *UDPIPv6) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
}
// Drain drains the sniffer's receive buffer by receiving packets until there's
// nothing else to receive.
-func (conn *UDPIPv6) Drain() {
- conn.sniffer.Drain()
+func (conn *UDPIPv6) Drain(t *testing.T) {
+ t.Helper()
+
+ conn.sniffer.Drain(t)
}
// TCPIPv6 maintains the state for all the layers in a TCP/IPv6 connection.
@@ -1093,7 +1176,6 @@ func NewTCPIPv6(t *testing.T, outgoingTCP, incomingTCP TCP) TCPIPv6 {
layerStates: []layerState{etherState, ipv6State, tcpState},
injector: injector,
sniffer: sniffer,
- t: t,
}
}
@@ -1104,16 +1186,20 @@ func (conn *TCPIPv6) SrcPort() uint16 {
// ExpectData is a convenient method that expects a Layer and the Layer after
// it. If it doens't arrive in time, it returns nil.
-func (conn *TCPIPv6) ExpectData(tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+func (conn *TCPIPv6) ExpectData(t *testing.T, tcp *TCP, payload *Payload, timeout time.Duration) (Layers, error) {
+ t.Helper()
+
expected := make([]Layer, len(conn.layerStates))
expected[len(expected)-1] = tcp
if payload != nil {
expected = append(expected, payload)
}
- return (*Connection)(conn).ExpectFrame(expected, timeout)
+ return (*Connection)(conn).ExpectFrame(t, expected, timeout)
}
// Close frees associated resources held by the TCPIPv6 connection.
-func (conn *TCPIPv6) Close() {
- (*Connection)(conn).Close()
+func (conn *TCPIPv6) Close(t *testing.T) {
+ t.Helper()
+
+ (*Connection)(conn).Close(t)
}
diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go
index 51be13759..73c532e75 100644
--- a/test/packetimpact/testbench/dut.go
+++ b/test/packetimpact/testbench/dut.go
@@ -31,13 +31,14 @@ import (
// DUT communicates with the DUT to force it to make POSIX calls.
type DUT struct {
- t *testing.T
conn *grpc.ClientConn
posixServer POSIXClient
}
// NewDUT creates a new connection with the DUT over gRPC.
func NewDUT(t *testing.T) DUT {
+ t.Helper()
+
flag.Parse()
if err := genPseudoFlags(); err != nil {
t.Fatal("generating psuedo flags:", err)
@@ -50,7 +51,6 @@ func NewDUT(t *testing.T) DUT {
}
posixServer := NewPOSIXClient(conn)
return DUT{
- t: t,
conn: conn,
posixServer: posixServer,
}
@@ -61,8 +61,9 @@ func (dut *DUT) TearDown() {
dut.conn.Close()
}
-func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr {
- dut.t.Helper()
+func (dut *DUT) sockaddrToProto(t *testing.T, sa unix.Sockaddr) *pb.Sockaddr {
+ t.Helper()
+
switch s := sa.(type) {
case *unix.SockaddrInet4:
return &pb.Sockaddr{
@@ -87,12 +88,13 @@ func (dut *DUT) sockaddrToProto(sa unix.Sockaddr) *pb.Sockaddr {
},
}
}
- dut.t.Fatalf("can't parse Sockaddr struct: %+v", sa)
+ t.Fatalf("can't parse Sockaddr struct: %+v", sa)
return nil
}
-func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr {
- dut.t.Helper()
+func (dut *DUT) protoToSockaddr(t *testing.T, sa *pb.Sockaddr) unix.Sockaddr {
+ t.Helper()
+
switch s := sa.Sockaddr.(type) {
case *pb.Sockaddr_In:
ret := unix.SockaddrInet4{
@@ -108,31 +110,32 @@ func (dut *DUT) protoToSockaddr(sa *pb.Sockaddr) unix.Sockaddr {
copy(ret.Addr[:], s.In6.GetAddr())
return &ret
}
- dut.t.Fatalf("can't parse Sockaddr proto: %+v", sa)
+ t.Fatalf("can't parse Sockaddr proto: %#v", sa)
return nil
}
// CreateBoundSocket makes a new socket on the DUT, with type typ and protocol
// proto, and bound to the IP address addr. Returns the new file descriptor and
// the port that was selected on the DUT.
-func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16) {
- dut.t.Helper()
+func (dut *DUT) CreateBoundSocket(t *testing.T, typ, proto int32, addr net.IP) (int32, uint16) {
+ t.Helper()
+
var fd int32
if addr.To4() != nil {
- fd = dut.Socket(unix.AF_INET, typ, proto)
+ fd = dut.Socket(t, unix.AF_INET, typ, proto)
sa := unix.SockaddrInet4{}
copy(sa.Addr[:], addr.To4())
- dut.Bind(fd, &sa)
+ dut.Bind(t, fd, &sa)
} else if addr.To16() != nil {
- fd = dut.Socket(unix.AF_INET6, typ, proto)
+ fd = dut.Socket(t, unix.AF_INET6, typ, proto)
sa := unix.SockaddrInet6{}
copy(sa.Addr[:], addr.To16())
sa.ZoneId = uint32(RemoteInterfaceID)
- dut.Bind(fd, &sa)
+ dut.Bind(t, fd, &sa)
} else {
- dut.t.Fatalf("unknown ip addr type for remoteIP")
+ t.Fatalf("invalid IP address: %s", addr)
}
- sa := dut.GetSockName(fd)
+ sa := dut.GetSockName(t, fd)
var port int
switch s := sa.(type) {
case *unix.SockaddrInet4:
@@ -140,15 +143,17 @@ func (dut *DUT) CreateBoundSocket(typ, proto int32, addr net.IP) (int32, uint16)
case *unix.SockaddrInet6:
port = s.Port
default:
- dut.t.Fatalf("unknown sockaddr type from getsockname: %t", sa)
+ t.Fatalf("unknown sockaddr type from getsockname: %T", sa)
}
return fd, uint16(port)
}
// CreateListener makes a new TCP connection. If it fails, the test ends.
-func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
- fd, remotePort := dut.CreateBoundSocket(typ, proto, net.ParseIP(RemoteIPv4))
- dut.Listen(fd, backlog)
+func (dut *DUT) CreateListener(t *testing.T, typ, proto, backlog int32) (int32, uint16) {
+ t.Helper()
+
+ fd, remotePort := dut.CreateBoundSocket(t, typ, proto, net.ParseIP(RemoteIPv4))
+ dut.Listen(t, fd, backlog)
return fd, remotePort
}
@@ -158,53 +163,57 @@ func (dut *DUT) CreateListener(typ, proto, backlog int32) (int32, uint16) {
// Accept calls accept on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// AcceptWithErrno.
-func (dut *DUT) Accept(sockfd int32) (int32, unix.Sockaddr) {
- dut.t.Helper()
+func (dut *DUT) Accept(t *testing.T, sockfd int32) (int32, unix.Sockaddr) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- fd, sa, err := dut.AcceptWithErrno(ctx, sockfd)
+ fd, sa, err := dut.AcceptWithErrno(ctx, t, sockfd)
if fd < 0 {
- dut.t.Fatalf("failed to accept: %s", err)
+ t.Fatalf("failed to accept: %s", err)
}
return fd, sa
}
// AcceptWithErrno calls accept on the DUT.
-func (dut *DUT) AcceptWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
- dut.t.Helper()
+func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) {
+ t.Helper()
+
req := pb.AcceptRequest{
Sockfd: sockfd,
}
resp, err := dut.posixServer.Accept(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Accept: %s", err)
+ t.Fatalf("failed to call Accept: %s", err)
}
- return resp.GetFd(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+ return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_())
}
// Bind calls bind on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is
// needed, use BindWithErrno.
-func (dut *DUT) Bind(fd int32, sa unix.Sockaddr) {
- dut.t.Helper()
+func (dut *DUT) Bind(t *testing.T, fd int32, sa unix.Sockaddr) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.BindWithErrno(ctx, fd, sa)
+ ret, err := dut.BindWithErrno(ctx, t, fd, sa)
if ret != 0 {
- dut.t.Fatalf("failed to bind socket: %s", err)
+ t.Fatalf("failed to bind socket: %s", err)
}
}
// BindWithErrno calls bind on the DUT.
-func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) {
+ t.Helper()
+
req := pb.BindRequest{
Sockfd: fd,
- Addr: dut.sockaddrToProto(sa),
+ Addr: dut.sockaddrToProto(t, sa),
}
resp, err := dut.posixServer.Bind(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Bind: %s", err)
+ t.Fatalf("failed to call Bind: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -212,25 +221,27 @@ func (dut *DUT) BindWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (
// Close calls close on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// CloseWithErrno.
-func (dut *DUT) Close(fd int32) {
- dut.t.Helper()
+func (dut *DUT) Close(t *testing.T, fd int32) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.CloseWithErrno(ctx, fd)
+ ret, err := dut.CloseWithErrno(ctx, t, fd)
if ret != 0 {
- dut.t.Fatalf("failed to close: %s", err)
+ t.Fatalf("failed to close: %s", err)
}
}
// CloseWithErrno calls close on the DUT.
-func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int32, error) {
+ t.Helper()
+
req := pb.CloseRequest{
Fd: fd,
}
resp, err := dut.posixServer.Close(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Close: %s", err)
+ t.Fatalf("failed to call Close: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -238,28 +249,30 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, fd int32) (int32, error) {
// Connect calls connect on the DUT and causes a fatal test failure if it
// doesn't succeed. If more control over the timeout or error handling is
// needed, use ConnectWithErrno.
-func (dut *DUT) Connect(fd int32, sa unix.Sockaddr) {
- dut.t.Helper()
+func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.ConnectWithErrno(ctx, fd, sa)
+ ret, err := dut.ConnectWithErrno(ctx, t, fd, sa)
// Ignore 'operation in progress' error that can be returned when the socket
// is non-blocking.
if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 {
- dut.t.Fatalf("failed to connect socket: %s", err)
+ t.Fatalf("failed to connect socket: %s", err)
}
}
// ConnectWithErrno calls bind on the DUT.
-func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa unix.Sockaddr) (int32, error) {
+ t.Helper()
+
req := pb.ConnectRequest{
Sockfd: fd,
- Addr: dut.sockaddrToProto(sa),
+ Addr: dut.sockaddrToProto(t, sa),
}
resp, err := dut.posixServer.Connect(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Connect: %s", err)
+ t.Fatalf("failed to call Connect: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -267,20 +280,22 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, fd int32, sa unix.Sockaddr
// Fcntl calls fcntl on the DUT and causes a fatal test failure if it
// doesn't succeed. If more control over the timeout or error handling is
// needed, use FcntlWithErrno.
-func (dut *DUT) Fcntl(fd, cmd, arg int32) int32 {
- dut.t.Helper()
+func (dut *DUT) Fcntl(t *testing.T, fd, cmd, arg int32) int32 {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.FcntlWithErrno(ctx, fd, cmd, arg)
+ ret, err := dut.FcntlWithErrno(ctx, t, fd, cmd, arg)
if ret == -1 {
- dut.t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err)
+ t.Fatalf("failed to Fcntl: ret=%d, errno=%s", ret, err)
}
return ret
}
// FcntlWithErrno calls fcntl on the DUT.
-func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) FcntlWithErrno(ctx context.Context, t *testing.T, fd, cmd, arg int32) (int32, error) {
+ t.Helper()
+
req := pb.FcntlRequest{
Fd: fd,
Cmd: cmd,
@@ -288,7 +303,7 @@ func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32,
}
resp, err := dut.posixServer.Fcntl(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Fcntl: %s", err)
+ t.Fatalf("failed to call Fcntl: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -296,32 +311,35 @@ func (dut *DUT) FcntlWithErrno(ctx context.Context, fd, cmd, arg int32) (int32,
// GetSockName calls getsockname on the DUT and causes a fatal test failure if
// it doesn't succeed. If more control over the timeout or error handling is
// needed, use GetSockNameWithErrno.
-func (dut *DUT) GetSockName(sockfd int32) unix.Sockaddr {
- dut.t.Helper()
+func (dut *DUT) GetSockName(t *testing.T, sockfd int32) unix.Sockaddr {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, sa, err := dut.GetSockNameWithErrno(ctx, sockfd)
+ ret, sa, err := dut.GetSockNameWithErrno(ctx, t, sockfd)
if ret != 0 {
- dut.t.Fatalf("failed to getsockname: %s", err)
+ t.Fatalf("failed to getsockname: %s", err)
}
return sa
}
// GetSockNameWithErrno calls getsockname on the DUT.
-func (dut *DUT) GetSockNameWithErrno(ctx context.Context, sockfd int32) (int32, unix.Sockaddr, error) {
- dut.t.Helper()
+func (dut *DUT) GetSockNameWithErrno(ctx context.Context, t *testing.T, sockfd int32) (int32, unix.Sockaddr, error) {
+ t.Helper()
+
req := pb.GetSockNameRequest{
Sockfd: sockfd,
}
resp, err := dut.posixServer.GetSockName(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Bind: %s", err)
+ t.Fatalf("failed to call Bind: %s", err)
}
- return resp.GetRet(), dut.protoToSockaddr(resp.GetAddr()), syscall.Errno(resp.GetErrno_())
+ return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_())
}
-func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) {
- dut.t.Helper()
+func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) {
+ t.Helper()
+
req := pb.GetSockOptRequest{
Sockfd: sockfd,
Level: level,
@@ -331,11 +349,11 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i
}
resp, err := dut.posixServer.GetSockOpt(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call GetSockOpt: %s", err)
+ t.Fatalf("failed to call GetSockOpt: %s", err)
}
optval := resp.GetOptval()
if optval == nil {
- dut.t.Fatalf("GetSockOpt response does not contain a value")
+ t.Fatalf("GetSockOpt response does not contain a value")
}
return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_())
}
@@ -345,13 +363,14 @@ func (dut *DUT) getSockOpt(ctx context.Context, sockfd, level, optname, optlen i
// needed, use GetSockOptWithErrno. Because endianess and the width of values
// might differ between the testbench and DUT architectures, prefer to use a
// more specific GetSockOptXxx function.
-func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte {
- dut.t.Helper()
+func (dut *DUT) GetSockOpt(t *testing.T, sockfd, level, optname, optlen int32) []byte {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, optval, err := dut.GetSockOptWithErrno(ctx, sockfd, level, optname, optlen)
+ ret, optval, err := dut.GetSockOptWithErrno(ctx, t, sockfd, level, optname, optlen)
if ret != 0 {
- dut.t.Fatalf("failed to GetSockOpt: %s", err)
+ t.Fatalf("failed to GetSockOpt: %s", err)
}
return optval
}
@@ -359,12 +378,13 @@ func (dut *DUT) GetSockOpt(sockfd, level, optname, optlen int32) []byte {
// GetSockOptWithErrno calls getsockopt on the DUT. Because endianess and the
// width of values might differ between the testbench and DUT architectures,
// prefer to use a more specific GetSockOptXxxWithErrno function.
-func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname, optlen int32) (int32, []byte, error) {
- dut.t.Helper()
- ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES)
+func (dut *DUT) GetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32) (int32, []byte, error) {
+ t.Helper()
+
+ ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, optlen, pb.GetSockOptRequest_BYTES)
bytesval, ok := optval.Val.(*pb.SockOptVal_Bytesval)
if !ok {
- dut.t.Fatalf("GetSockOpt got value type: %T, want bytes", optval)
+ t.Fatalf("GetSockOpt got value type: %T, want bytes", optval.Val)
}
return ret, bytesval.Bytesval, errno
}
@@ -372,24 +392,26 @@ func (dut *DUT) GetSockOptWithErrno(ctx context.Context, sockfd, level, optname,
// GetSockOptInt calls getsockopt on the DUT and causes a fatal test failure
// if it doesn't succeed. If more control over the int optval or error handling
// is needed, use GetSockOptIntWithErrno.
-func (dut *DUT) GetSockOptInt(sockfd, level, optname int32) int32 {
- dut.t.Helper()
+func (dut *DUT) GetSockOptInt(t *testing.T, sockfd, level, optname int32) int32 {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, intval, err := dut.GetSockOptIntWithErrno(ctx, sockfd, level, optname)
+ ret, intval, err := dut.GetSockOptIntWithErrno(ctx, t, sockfd, level, optname)
if ret != 0 {
- dut.t.Fatalf("failed to GetSockOptInt: %s", err)
+ t.Fatalf("failed to GetSockOptInt: %s", err)
}
return intval
}
// GetSockOptIntWithErrno calls getsockopt with an integer optval.
-func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, int32, error) {
- dut.t.Helper()
- ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_INT)
+func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, int32, error) {
+ t.Helper()
+
+ ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_INT)
intval, ok := optval.Val.(*pb.SockOptVal_Intval)
if !ok {
- dut.t.Fatalf("GetSockOpt got value type: %T, want int", optval)
+ t.Fatalf("GetSockOpt got value type: %T, want int", optval.Val)
}
return ret, intval.Intval, errno
}
@@ -397,24 +419,26 @@ func (dut *DUT) GetSockOptIntWithErrno(ctx context.Context, sockfd, level, optna
// GetSockOptTimeval calls getsockopt on the DUT and causes a fatal test failure
// if it doesn't succeed. If more control over the timeout or error handling is
// needed, use GetSockOptTimevalWithErrno.
-func (dut *DUT) GetSockOptTimeval(sockfd, level, optname int32) unix.Timeval {
- dut.t.Helper()
+func (dut *DUT) GetSockOptTimeval(t *testing.T, sockfd, level, optname int32) unix.Timeval {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, sockfd, level, optname)
+ ret, timeval, err := dut.GetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname)
if ret != 0 {
- dut.t.Fatalf("failed to GetSockOptTimeval: %s", err)
+ t.Fatalf("failed to GetSockOptTimeval: %s", err)
}
return timeval
}
// GetSockOptTimevalWithErrno calls getsockopt and returns a timeval.
-func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32) (int32, unix.Timeval, error) {
- dut.t.Helper()
- ret, optval, errno := dut.getSockOpt(ctx, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME)
+func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32) (int32, unix.Timeval, error) {
+ t.Helper()
+
+ ret, optval, errno := dut.getSockOpt(ctx, t, sockfd, level, optname, 0, pb.GetSockOptRequest_TIME)
tv, ok := optval.Val.(*pb.SockOptVal_Timeval)
if !ok {
- dut.t.Fatalf("GetSockOpt got value type: %T, want timeval", optval)
+ t.Fatalf("GetSockOpt got value type: %T, want timeval", optval.Val)
}
timeval := unix.Timeval{
Sec: tv.Timeval.Seconds,
@@ -426,26 +450,28 @@ func (dut *DUT) GetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, o
// Listen calls listen on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// ListenWithErrno.
-func (dut *DUT) Listen(sockfd, backlog int32) {
- dut.t.Helper()
+func (dut *DUT) Listen(t *testing.T, sockfd, backlog int32) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.ListenWithErrno(ctx, sockfd, backlog)
+ ret, err := dut.ListenWithErrno(ctx, t, sockfd, backlog)
if ret != 0 {
- dut.t.Fatalf("failed to listen: %s", err)
+ t.Fatalf("failed to listen: %s", err)
}
}
// ListenWithErrno calls listen on the DUT.
-func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backlog int32) (int32, error) {
+ t.Helper()
+
req := pb.ListenRequest{
Sockfd: sockfd,
Backlog: backlog,
}
resp, err := dut.posixServer.Listen(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Listen: %s", err)
+ t.Fatalf("failed to call Listen: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -453,20 +479,22 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, sockfd, backlog int32) (int
// Send calls send on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// SendWithErrno.
-func (dut *DUT) Send(sockfd int32, buf []byte, flags int32) int32 {
- dut.t.Helper()
+func (dut *DUT) Send(t *testing.T, sockfd int32, buf []byte, flags int32) int32 {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.SendWithErrno(ctx, sockfd, buf, flags)
+ ret, err := dut.SendWithErrno(ctx, t, sockfd, buf, flags)
if ret == -1 {
- dut.t.Fatalf("failed to send: %s", err)
+ t.Fatalf("failed to send: %s", err)
}
return ret
}
// SendWithErrno calls send on the DUT.
-func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32) (int32, error) {
+ t.Helper()
+
req := pb.SendRequest{
Sockfd: sockfd,
Buf: buf,
@@ -474,7 +502,7 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla
}
resp, err := dut.posixServer.Send(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Send: %s", err)
+ t.Fatalf("failed to call Send: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -482,48 +510,52 @@ func (dut *DUT) SendWithErrno(ctx context.Context, sockfd int32, buf []byte, fla
// SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// SendToWithErrno.
-func (dut *DUT) SendTo(sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 {
- dut.t.Helper()
+func (dut *DUT) SendTo(t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) int32 {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.SendToWithErrno(ctx, sockfd, buf, flags, destAddr)
+ ret, err := dut.SendToWithErrno(ctx, t, sockfd, buf, flags, destAddr)
if ret == -1 {
- dut.t.Fatalf("failed to sendto: %s", err)
+ t.Fatalf("failed to sendto: %s", err)
}
return ret
}
// SendToWithErrno calls sendto on the DUT.
-func (dut *DUT) SendToWithErrno(ctx context.Context, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, buf []byte, flags int32, destAddr unix.Sockaddr) (int32, error) {
+ t.Helper()
+
req := pb.SendToRequest{
Sockfd: sockfd,
Buf: buf,
Flags: flags,
- DestAddr: dut.sockaddrToProto(destAddr),
+ DestAddr: dut.sockaddrToProto(t, destAddr),
}
resp, err := dut.posixServer.SendTo(ctx, &req)
if err != nil {
- dut.t.Fatalf("faled to call SendTo: %s", err)
+ t.Fatalf("faled to call SendTo: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
// SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking
// is true, otherwise it will clear the flag.
-func (dut *DUT) SetNonBlocking(fd int32, nonblocking bool) {
- dut.t.Helper()
- flags := dut.Fcntl(fd, unix.F_GETFL, 0)
+func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) {
+ t.Helper()
+
+ flags := dut.Fcntl(t, fd, unix.F_GETFL, 0)
if nonblocking {
flags |= unix.O_NONBLOCK
} else {
flags &= ^unix.O_NONBLOCK
}
- dut.Fcntl(fd, unix.F_SETFL, flags)
+ dut.Fcntl(t, fd, unix.F_SETFL, flags)
}
-func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval *pb.SockOptVal) (int32, error) {
+ t.Helper()
+
req := pb.SetSockOptRequest{
Sockfd: sockfd,
Level: level,
@@ -532,7 +564,7 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op
}
resp, err := dut.posixServer.SetSockOpt(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call SetSockOpt: %s", err)
+ t.Fatalf("failed to call SetSockOpt: %s", err)
}
return resp.GetRet(), syscall.Errno(resp.GetErrno_())
}
@@ -542,81 +574,89 @@ func (dut *DUT) setSockOpt(ctx context.Context, sockfd, level, optname int32, op
// needed, use SetSockOptWithErrno. Because endianess and the width of values
// might differ between the testbench and DUT architectures, prefer to use a
// more specific SetSockOptXxx function.
-func (dut *DUT) SetSockOpt(sockfd, level, optname int32, optval []byte) {
- dut.t.Helper()
+func (dut *DUT) SetSockOpt(t *testing.T, sockfd, level, optname int32, optval []byte) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.SetSockOptWithErrno(ctx, sockfd, level, optname, optval)
+ ret, err := dut.SetSockOptWithErrno(ctx, t, sockfd, level, optname, optval)
if ret != 0 {
- dut.t.Fatalf("failed to SetSockOpt: %s", err)
+ t.Fatalf("failed to SetSockOpt: %s", err)
}
}
// SetSockOptWithErrno calls setsockopt on the DUT. Because endianess and the
// width of values might differ between the testbench and DUT architectures,
// prefer to use a more specific SetSockOptXxxWithErrno function.
-func (dut *DUT) SetSockOptWithErrno(ctx context.Context, sockfd, level, optname int32, optval []byte) (int32, error) {
- dut.t.Helper()
- return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}})
+func (dut *DUT) SetSockOptWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, optval []byte) (int32, error) {
+ t.Helper()
+
+ return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Bytesval{optval}})
}
// SetSockOptInt calls setsockopt on the DUT and causes a fatal test failure
// if it doesn't succeed. If more control over the int optval or error handling
// is needed, use SetSockOptIntWithErrno.
-func (dut *DUT) SetSockOptInt(sockfd, level, optname, optval int32) {
- dut.t.Helper()
+func (dut *DUT) SetSockOptInt(t *testing.T, sockfd, level, optname, optval int32) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.SetSockOptIntWithErrno(ctx, sockfd, level, optname, optval)
+ ret, err := dut.SetSockOptIntWithErrno(ctx, t, sockfd, level, optname, optval)
if ret != 0 {
- dut.t.Fatalf("failed to SetSockOptInt: %s", err)
+ t.Fatalf("failed to SetSockOptInt: %s", err)
}
}
// SetSockOptIntWithErrno calls setsockopt with an integer optval.
-func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, sockfd, level, optname, optval int32) (int32, error) {
- dut.t.Helper()
- return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}})
+func (dut *DUT) SetSockOptIntWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname, optval int32) (int32, error) {
+ t.Helper()
+
+ return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Intval{optval}})
}
// SetSockOptTimeval calls setsockopt on the DUT and causes a fatal test failure
// if it doesn't succeed. If more control over the timeout or error handling is
// needed, use SetSockOptTimevalWithErrno.
-func (dut *DUT) SetSockOptTimeval(sockfd, level, optname int32, tv *unix.Timeval) {
- dut.t.Helper()
+func (dut *DUT) SetSockOptTimeval(t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, err := dut.SetSockOptTimevalWithErrno(ctx, sockfd, level, optname, tv)
+ ret, err := dut.SetSockOptTimevalWithErrno(ctx, t, sockfd, level, optname, tv)
if ret != 0 {
- dut.t.Fatalf("failed to SetSockOptTimeval: %s", err)
+ t.Fatalf("failed to SetSockOptTimeval: %s", err)
}
}
// SetSockOptTimevalWithErrno calls setsockopt with the timeval converted to
// bytes.
-func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) SetSockOptTimevalWithErrno(ctx context.Context, t *testing.T, sockfd, level, optname int32, tv *unix.Timeval) (int32, error) {
+ t.Helper()
+
timeval := pb.Timeval{
Seconds: int64(tv.Sec),
Microseconds: int64(tv.Usec),
}
- return dut.setSockOpt(ctx, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}})
+ return dut.setSockOpt(ctx, t, sockfd, level, optname, &pb.SockOptVal{Val: &pb.SockOptVal_Timeval{&timeval}})
}
// Socket calls socket on the DUT and returns the file descriptor. If socket
// fails on the DUT, the test ends.
-func (dut *DUT) Socket(domain, typ, proto int32) int32 {
- dut.t.Helper()
- fd, err := dut.SocketWithErrno(domain, typ, proto)
+func (dut *DUT) Socket(t *testing.T, domain, typ, proto int32) int32 {
+ t.Helper()
+
+ fd, err := dut.SocketWithErrno(t, domain, typ, proto)
if fd < 0 {
- dut.t.Fatalf("failed to create socket: %s", err)
+ t.Fatalf("failed to create socket: %s", err)
}
return fd
}
// SocketWithErrno calls socket on the DUT and returns the fd and errno.
-func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
- dut.t.Helper()
+func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, error) {
+ t.Helper()
+
req := pb.SocketRequest{
Domain: domain,
Type: typ,
@@ -625,7 +665,7 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
ctx := context.Background()
resp, err := dut.posixServer.Socket(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Socket: %s", err)
+ t.Fatalf("failed to call Socket: %s", err)
}
return resp.GetFd(), syscall.Errno(resp.GetErrno_())
}
@@ -633,20 +673,22 @@ func (dut *DUT) SocketWithErrno(domain, typ, proto int32) (int32, error) {
// Recv calls recv on the DUT and causes a fatal test failure if it doesn't
// succeed. If more control over the timeout or error handling is needed, use
// RecvWithErrno.
-func (dut *DUT) Recv(sockfd, len, flags int32) []byte {
- dut.t.Helper()
+func (dut *DUT) Recv(t *testing.T, sockfd, len, flags int32) []byte {
+ t.Helper()
+
ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout)
defer cancel()
- ret, buf, err := dut.RecvWithErrno(ctx, sockfd, len, flags)
+ ret, buf, err := dut.RecvWithErrno(ctx, t, sockfd, len, flags)
if ret == -1 {
- dut.t.Fatalf("failed to recv: %s", err)
+ t.Fatalf("failed to recv: %s", err)
}
return buf
}
// RecvWithErrno calls recv on the DUT.
-func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (int32, []byte, error) {
- dut.t.Helper()
+func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, flags int32) (int32, []byte, error) {
+ t.Helper()
+
req := pb.RecvRequest{
Sockfd: sockfd,
Len: len,
@@ -654,7 +696,7 @@ func (dut *DUT) RecvWithErrno(ctx context.Context, sockfd, len, flags int32) (in
}
resp, err := dut.posixServer.Recv(ctx, &req)
if err != nil {
- dut.t.Fatalf("failed to call Recv: %s", err)
+ t.Fatalf("failed to call Recv: %s", err)
}
return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_())
}
diff --git a/test/packetimpact/testbench/rawsockets.go b/test/packetimpact/testbench/rawsockets.go
index 278229b7e..57e822725 100644
--- a/test/packetimpact/testbench/rawsockets.go
+++ b/test/packetimpact/testbench/rawsockets.go
@@ -28,7 +28,6 @@ import (
// Sniffer can sniff raw packets on the wire.
type Sniffer struct {
- t *testing.T
fd int
}
@@ -40,6 +39,8 @@ func htons(x uint16) uint16 {
// NewSniffer creates a Sniffer connected to *device.
func NewSniffer(t *testing.T) (Sniffer, error) {
+ t.Helper()
+
snifferFd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, int(htons(unix.ETH_P_ALL)))
if err != nil {
return Sniffer{}, err
@@ -51,7 +52,6 @@ func NewSniffer(t *testing.T) (Sniffer, error) {
t.Fatalf("can't setsockopt SO_RCVBUF to 10M: %s", err)
}
return Sniffer{
- t: t,
fd: snifferFd,
}, nil
}
@@ -61,7 +61,9 @@ func NewSniffer(t *testing.T) (Sniffer, error) {
const maxReadSize int = 65536
// Recv tries to read one frame until the timeout is up.
-func (s *Sniffer) Recv(timeout time.Duration) []byte {
+func (s *Sniffer) Recv(t *testing.T, timeout time.Duration) []byte {
+ t.Helper()
+
deadline := time.Now().Add(timeout)
for {
timeout = deadline.Sub(time.Now())
@@ -75,7 +77,7 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte {
}
if err := unix.SetsockoptTimeval(s.fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
- s.t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err)
+ t.Fatalf("can't setsockopt SO_RCVTIMEO: %s", err)
}
buf := make([]byte, maxReadSize)
@@ -85,10 +87,10 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte {
continue
}
if err != nil {
- s.t.Fatalf("can't read: %s", err)
+ t.Fatalf("can't read: %s", err)
}
if nread > maxReadSize {
- s.t.Fatalf("received a truncated frame of %d bytes", nread)
+ t.Fatalf("received a truncated frame of %d bytes, want at most %d bytes", nread, maxReadSize)
}
return buf[:nread]
}
@@ -96,14 +98,16 @@ func (s *Sniffer) Recv(timeout time.Duration) []byte {
// Drain drains the Sniffer's socket receive buffer by receiving until there's
// nothing else to receive.
-func (s *Sniffer) Drain() {
- s.t.Helper()
+func (s *Sniffer) Drain(t *testing.T) {
+ t.Helper()
+
flags, err := unix.FcntlInt(uintptr(s.fd), unix.F_GETFL, 0)
if err != nil {
- s.t.Fatalf("failed to get sniffer socket fd flags: %s", err)
+ t.Fatalf("failed to get sniffer socket fd flags: %s", err)
}
- if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags|unix.O_NONBLOCK); err != nil {
- s.t.Fatalf("failed to make sniffer socket non-blocking: %s", err)
+ nonBlockingFlags := flags | unix.O_NONBLOCK
+ if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, nonBlockingFlags); err != nil {
+ t.Fatalf("failed to make sniffer socket non-blocking with flags %b: %s", nonBlockingFlags, err)
}
for {
buf := make([]byte, maxReadSize)
@@ -113,7 +117,7 @@ func (s *Sniffer) Drain() {
}
}
if _, err := unix.FcntlInt(uintptr(s.fd), unix.F_SETFL, flags); err != nil {
- s.t.Fatalf("failed to restore sniffer socket fd flags: %s", err)
+ t.Fatalf("failed to restore sniffer socket fd flags to %b: %s", flags, err)
}
}
@@ -128,12 +132,13 @@ func (s *Sniffer) close() error {
// Injector can inject raw frames.
type Injector struct {
- t *testing.T
fd int
}
// NewInjector creates a new injector on *device.
func NewInjector(t *testing.T) (Injector, error) {
+ t.Helper()
+
ifInfo, err := net.InterfaceByName(Device)
if err != nil {
return Injector{}, err
@@ -156,15 +161,20 @@ func NewInjector(t *testing.T) (Injector, error) {
return Injector{}, err
}
return Injector{
- t: t,
fd: injectFd,
}, nil
}
// Send a raw frame.
-func (i *Injector) Send(b []byte) {
- if _, err := unix.Write(i.fd, b); err != nil {
- i.t.Fatalf("can't write: %s of len %d", err, len(b))
+func (i *Injector) Send(t *testing.T, b []byte) {
+ t.Helper()
+
+ n, err := unix.Write(i.fd, b)
+ if err != nil {
+ t.Fatalf("can't write bytes of len %d: %s", len(b), err)
+ }
+ if n != len(b) {
+ t.Fatalf("got %d bytes written, want %d", n, len(b))
}
}
diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go
index 407565078..a61054c2c 100644
--- a/test/packetimpact/tests/fin_wait2_timeout_test.go
+++ b/test/packetimpact/tests/fin_wait2_timeout_test.go
@@ -39,34 +39,34 @@ func TestFinWait2Timeout(t *testing.T) {
t.Run(tt.description, func(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFd)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
- conn.Connect()
+ defer conn.Close(t)
+ conn.Connect(t)
- acceptFd, _ := dut.Accept(listenFd)
+ acceptFd, _ := dut.Accept(t, listenFd)
if tt.linger2 {
tv := unix.Timeval{Sec: 1, Usec: 0}
- dut.SetSockOptTimeval(acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv)
+ dut.SetSockOptTimeval(t, acceptFd, unix.SOL_TCP, unix.TCP_LINGER2, &tv)
}
- dut.Close(acceptFd)
+ dut.Close(t, acceptFd)
- if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err)
}
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
time.Sleep(5 * time.Second)
- conn.Drain()
+ conn.Drain(t)
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
if tt.linger2 {
- if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
t.Fatalf("expected a RST packet within a second but got none: %s", err)
}
} else {
- if got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil {
+ if got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil {
t.Fatalf("expected no RST packets within ten seconds but got one: %s", got)
}
}
diff --git a/test/packetimpact/tests/icmpv6_param_problem_test.go b/test/packetimpact/tests/icmpv6_param_problem_test.go
index 8dfd26ee8..2d59d552d 100644
--- a/test/packetimpact/tests/icmpv6_param_problem_test.go
+++ b/test/packetimpact/tests/icmpv6_param_problem_test.go
@@ -34,7 +34,7 @@ func TestICMPv6ParamProblemTest(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
- defer conn.Close()
+ defer conn.Close(t)
ipv6 := testbench.IPv6{
// 254 is reserved and used for experimentation and testing. This should
// cause an error.
@@ -45,8 +45,8 @@ func TestICMPv6ParamProblemTest(t *testing.T) {
Payload: []byte("hello world"),
}
- toSend := (*testbench.Connection)(&conn).CreateFrame(testbench.Layers{&ipv6}, &icmpv6)
- (*testbench.Connection)(&conn).SendFrame(toSend)
+ toSend := (*testbench.Connection)(&conn).CreateFrame(t, testbench.Layers{&ipv6}, &icmpv6)
+ (*testbench.Connection)(&conn).SendFrame(t, toSend)
// Build the expected ICMPv6 payload, which includes an index to the
// problematic byte and also the problematic packet as described in
@@ -72,7 +72,7 @@ func TestICMPv6ParamProblemTest(t *testing.T) {
&expectedICMPv6,
}
timeout := time.Second
- if _, err := conn.ExpectFrame(paramProblem, timeout); err != nil {
+ if _, err := conn.ExpectFrame(t, paramProblem, timeout); err != nil {
t.Errorf("expected %s within %s but got none: %s", paramProblem, timeout, err)
}
}
diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
index 70f6df5e0..cf881418c 100644
--- a/test/packetimpact/tests/ipv4_id_uniqueness_test.go
+++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go
@@ -31,8 +31,8 @@ func init() {
testbench.RegisterFlags(flag.CommandLine)
}
-func recvTCPSegment(conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) {
- layers, err := conn.ExpectData(expect, expectPayload, time.Second)
+func recvTCPSegment(t *testing.T, conn *testbench.TCPIPv4, expect *testbench.TCP, expectPayload *testbench.Payload) (uint16, error) {
+ layers, err := conn.ExpectData(t, expect, expectPayload, time.Second)
if err != nil {
return 0, fmt.Errorf("failed to receive TCP segment: %s", err)
}
@@ -69,17 +69,17 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFD)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
- conn.Connect()
- remoteFD, _ := dut.Accept(listenFD)
- defer dut.Close(remoteFD)
+ conn.Connect(t)
+ remoteFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, remoteFD)
- dut.SetSockOptInt(remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+ dut.SetSockOptInt(t, remoteFD, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
// TODO(b/129291778) The following socket option clears the DF bit on
// IP packets sent over the socket, and is currently not supported by
@@ -87,30 +87,30 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) {
// socket option being not supported does not affect the operation of
// this test. Once the socket option is supported, the following call
// can be changed to simply assert success.
- ret, errno := dut.SetSockOptIntWithErrno(context.Background(), remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT)
+ ret, errno := dut.SetSockOptIntWithErrno(context.Background(), t, remoteFD, unix.IPPROTO_IP, linux.IP_MTU_DISCOVER, linux.IP_PMTUDISC_DONT)
if ret == -1 && errno != unix.ENOTSUP {
t.Fatalf("failed to set IP_MTU_DISCOVER socket option to IP_PMTUDISC_DONT: %s", errno)
}
samplePayload := &testbench.Payload{Bytes: tc.payload}
- dut.Send(remoteFD, tc.payload, 0)
- if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ dut.Send(t, remoteFD, tc.payload, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("failed to receive TCP segment sent for RTT calculation: %s", err)
}
// Let the DUT estimate RTO with RTT from the DATA-ACK.
// TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
// we can skip sending this ACK.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
- dut.Send(remoteFD, tc.payload, 0)
- expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum()))}
- originalID, err := recvTCPSegment(&conn, expectTCP, samplePayload)
+ dut.Send(t, remoteFD, tc.payload, 0)
+ expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))}
+ originalID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload)
if err != nil {
t.Fatalf("failed to receive TCP segment: %s", err)
}
- retransmitID, err := recvTCPSegment(&conn, expectTCP, samplePayload)
+ retransmitID, err := recvTCPSegment(t, &conn, expectTCP, samplePayload)
if err != nil {
t.Fatalf("failed to receive retransmitted TCP segment: %s", err)
}
diff --git a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
index 7b462c8e2..b5f94ad4b 100644
--- a/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
+++ b/test/packetimpact/tests/ipv6_fragment_reassembly_test.go
@@ -48,7 +48,7 @@ func TestIPv6FragmentReassembly(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
- defer conn.Close()
+ defer conn.Close(t)
firstPayloadToSend := make([]byte, firstPayloadLength)
for i := range firstPayloadToSend {
@@ -81,7 +81,7 @@ func TestIPv6FragmentReassembly(t *testing.T) {
buffer.NewVectorisedView(len(secondPayloadToSend), []buffer.View{secondPayloadToSend}),
)
- conn.Send(testbench.IPv6{},
+ conn.Send(t, testbench.IPv6{},
&testbench.IPv6FragmentExtHdr{
FragmentOffset: testbench.Uint16(0),
MoreFragments: testbench.Bool(true),
@@ -96,7 +96,7 @@ func TestIPv6FragmentReassembly(t *testing.T) {
icmpv6ProtoNum := header.IPv6ExtensionHeaderIdentifier(header.ICMPv6ProtocolNumber)
- conn.Send(testbench.IPv6{},
+ conn.Send(t, testbench.IPv6{},
&testbench.IPv6FragmentExtHdr{
NextHeader: &icmpv6ProtoNum,
FragmentOffset: testbench.Uint16((firstPayloadLength + header.ICMPv6EchoMinimumSize) / 8),
@@ -107,7 +107,7 @@ func TestIPv6FragmentReassembly(t *testing.T) {
Bytes: secondPayloadToSend,
})
- gotEchoReplyFirstPart, err := conn.ExpectFrame(testbench.Layers{
+ gotEchoReplyFirstPart, err := conn.ExpectFrame(t, testbench.Layers{
&testbench.Ether{},
&testbench.IPv6{},
&testbench.IPv6FragmentExtHdr{
@@ -142,7 +142,7 @@ func TestIPv6FragmentReassembly(t *testing.T) {
hex.Dump(wantFirstPayload))
}
- gotEchoReplySecondPart, err := conn.ExpectFrame(testbench.Layers{
+ gotEchoReplySecondPart, err := conn.ExpectFrame(t, testbench.Layers{
&testbench.Ether{},
&testbench.IPv6{},
&testbench.IPv6FragmentExtHdr{
diff --git a/test/packetimpact/tests/ipv6_unknown_options_action_test.go b/test/packetimpact/tests/ipv6_unknown_options_action_test.go
index 100b30ad7..d7d63cbd2 100644
--- a/test/packetimpact/tests/ipv6_unknown_options_action_test.go
+++ b/test/packetimpact/tests/ipv6_unknown_options_action_test.go
@@ -23,21 +23,21 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- tb "gvisor.dev/gvisor/test/packetimpact/testbench"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
)
func init() {
- tb.RegisterFlags(flag.CommandLine)
+ testbench.RegisterFlags(flag.CommandLine)
}
-func mkHopByHopOptionsExtHdr(optType byte) tb.Layer {
- return &tb.IPv6HopByHopOptionsExtHdr{
+func mkHopByHopOptionsExtHdr(optType byte) testbench.Layer {
+ return &testbench.IPv6HopByHopOptionsExtHdr{
Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00},
}
}
-func mkDestinationOptionsExtHdr(optType byte) tb.Layer {
- return &tb.IPv6DestinationOptionsExtHdr{
+func mkDestinationOptionsExtHdr(optType byte) testbench.Layer {
+ return &testbench.IPv6DestinationOptionsExtHdr{
Options: []byte{optType, 0x04, 0x00, 0x00, 0x00, 0x00},
}
}
@@ -49,7 +49,7 @@ func optionTypeFromAction(action header.IPv6OptionUnknownAction) byte {
func TestIPv6UnknownOptionAction(t *testing.T) {
for _, tt := range []struct {
description string
- mkExtHdr func(optType byte) tb.Layer
+ mkExtHdr func(optType byte) testbench.Layer
action header.IPv6OptionUnknownAction
multicastDst bool
wantICMPv6 bool
@@ -140,21 +140,21 @@ func TestIPv6UnknownOptionAction(t *testing.T) {
},
} {
t.Run(tt.description, func(t *testing.T) {
- dut := tb.NewDUT(t)
+ dut := testbench.NewDUT(t)
defer dut.TearDown()
- ipv6Conn := tb.NewIPv6Conn(t, tb.IPv6{}, tb.IPv6{})
- conn := (*tb.Connection)(&ipv6Conn)
- defer ipv6Conn.Close()
+ ipv6Conn := testbench.NewIPv6Conn(t, testbench.IPv6{}, testbench.IPv6{})
+ conn := (*testbench.Connection)(&ipv6Conn)
+ defer ipv6Conn.Close(t)
- outgoingOverride := tb.Layers{}
+ outgoingOverride := testbench.Layers{}
if tt.multicastDst {
- outgoingOverride = tb.Layers{&tb.IPv6{
- DstAddr: tb.Address(tcpip.Address(net.ParseIP("ff02::1"))),
+ outgoingOverride = testbench.Layers{&testbench.IPv6{
+ DstAddr: testbench.Address(tcpip.Address(net.ParseIP("ff02::1"))),
}}
}
- outgoing := conn.CreateFrame(outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action)))
- conn.SendFrame(outgoing)
+ outgoing := conn.CreateFrame(t, outgoingOverride, tt.mkExtHdr(optionTypeFromAction(tt.action)))
+ conn.SendFrame(t, outgoing)
ipv6Sent := outgoing[1:]
invokingPacket, err := ipv6Sent.ToBytes()
if err != nil {
@@ -167,12 +167,12 @@ func TestIPv6UnknownOptionAction(t *testing.T) {
// after the IPv6 header (after NextHeader and ExtHdrLen).
binary.BigEndian.PutUint32(icmpv6Payload, header.IPv6MinimumSize+2)
icmpv6Payload = append(icmpv6Payload, invokingPacket...)
- gotICMPv6, err := ipv6Conn.ExpectFrame(tb.Layers{
- &tb.Ether{},
- &tb.IPv6{},
- &tb.ICMPv6{
- Type: tb.ICMPv6Type(header.ICMPv6ParamProblem),
- Code: tb.Byte(2),
+ gotICMPv6, err := ipv6Conn.ExpectFrame(t, testbench.Layers{
+ &testbench.Ether{},
+ &testbench.IPv6{},
+ &testbench.ICMPv6{
+ Type: testbench.ICMPv6Type(header.ICMPv6ParamProblem),
+ Code: testbench.Byte(2),
Payload: icmpv6Payload,
},
}, time.Second)
diff --git a/test/packetimpact/tests/tcp_close_wait_ack_test.go b/test/packetimpact/tests/tcp_close_wait_ack_test.go
index 6e7ff41d7..e6a96f214 100644
--- a/test/packetimpact/tests/tcp_close_wait_ack_test.go
+++ b/test/packetimpact/tests/tcp_close_wait_ack_test.go
@@ -33,39 +33,39 @@ func init() {
func TestCloseWaitAck(t *testing.T) {
for _, tt := range []struct {
description string
- makeTestingTCP func(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP
+ makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP
seqNumOffset seqnum.Size
expectAck bool
}{
- {"OTW", GenerateOTWSeqSegment, 0, false},
- {"OTW", GenerateOTWSeqSegment, 1, true},
- {"OTW", GenerateOTWSeqSegment, 2, true},
- {"ACK", GenerateUnaccACKSegment, 0, false},
- {"ACK", GenerateUnaccACKSegment, 1, true},
- {"ACK", GenerateUnaccACKSegment, 2, true},
+ {"OTW", generateOTWSeqSegment, 0, false},
+ {"OTW", generateOTWSeqSegment, 1, true},
+ {"OTW", generateOTWSeqSegment, 2, true},
+ {"ACK", generateUnaccACKSegment, 0, false},
+ {"ACK", generateUnaccACKSegment, 1, true},
+ {"ACK", generateUnaccACKSegment, 2, true},
} {
t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFd)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
- conn.Connect()
- acceptFd, _ := dut.Accept(listenFd)
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
// Send a FIN to DUT to intiate the active close
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
- gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)})
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err)
}
windowSize := seqnum.Size(*gotTCP.WindowSize)
// Send a segment with OTW Seq / unacc ACK and expect an ACK back
- conn.Send(tt.makeTestingTCP(&conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")})
- gotAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), &testbench.Payload{Bytes: []byte("Sample Data")})
+ gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
if tt.expectAck && err != nil {
t.Fatalf("expected an ack but got none: %s", err)
}
@@ -74,35 +74,36 @@ func TestCloseWaitAck(t *testing.T) {
}
// Now let's verify DUT is indeed in CLOSE_WAIT
- dut.Close(acceptFd)
- if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
+ dut.Close(t, acceptFd)
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil {
t.Fatalf("expected DUT to send a FIN: %s", err)
}
// Ack the FIN from DUT
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
// Send some extra data to DUT
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")})
- if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: []byte("Sample Data")})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil {
t.Fatalf("expected DUT to send an RST: %s", err)
}
})
}
}
-// This generates an segment with seqnum = RCV.NXT + RCV.WND + seqNumOffset, the
-// generated segment is only acceptable when seqNumOffset is 0, otherwise an ACK
-// is expected from the receiver.
-func GenerateOTWSeqSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
- lastAcceptable := conn.LocalSeqNum().Add(windowSize)
+// generateOTWSeqSegment generates an segment with
+// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only
+// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the
+// receiver.
+func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
+ lastAcceptable := conn.LocalSeqNum(t).Add(windowSize)
otwSeq := uint32(lastAcceptable.Add(seqNumOffset))
return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)}
}
-// This generates an segment with acknum = SND.NXT + seqNumOffset, the generated
-// segment is only acceptable when seqNumOffset is 0, otherwise an ACK is
-// expected from the receiver.
-func GenerateUnaccACKSegment(conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
- lastAcceptable := conn.RemoteSeqNum()
+// generateUnaccACKSegment generates an segment with
+// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable
+// when seqNumOffset is 0, otherwise an ACK is expected from the receiver.
+func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP {
+ lastAcceptable := conn.RemoteSeqNum(t)
unaccAck := uint32(lastAcceptable.Add(seqNumOffset))
return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)}
}
diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go
index fb8f48629..8feea4a82 100644
--- a/test/packetimpact/tests/tcp_cork_mss_test.go
+++ b/test/packetimpact/tests/tcp_cork_mss_test.go
@@ -32,53 +32,53 @@ func init() {
func TestTCPCorkMSS(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFD)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
const mss = uint32(header.TCPDefaultMSS)
options := make([]byte, header.TCPOptionMSSLength)
header.EncodeMSSOption(mss, options)
- conn.ConnectWithOptions(options)
+ conn.ConnectWithOptions(t, options)
- acceptFD, _ := dut.Accept(listenFD)
- defer dut.Close(acceptFD)
+ acceptFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFD)
- dut.SetSockOptInt(acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1)
+ dut.SetSockOptInt(t, acceptFD, unix.IPPROTO_TCP, unix.TCP_CORK, 1)
// Let the dut application send 2 small segments to be held up and coalesced
// until the application sends a larger segment to fill up to > MSS.
sampleData := []byte("Sample Data")
- dut.Send(acceptFD, sampleData, 0)
- dut.Send(acceptFD, sampleData, 0)
+ dut.Send(t, acceptFD, sampleData, 0)
+ dut.Send(t, acceptFD, sampleData, 0)
expectedData := sampleData
expectedData = append(expectedData, sampleData...)
largeData := make([]byte, mss+1)
expectedData = append(expectedData, largeData...)
- dut.Send(acceptFD, largeData, 0)
+ dut.Send(t, acceptFD, largeData, 0)
// Expect the segments to be coalesced and sent and capped to MSS.
expectedPayload := testbench.Payload{Bytes: expectedData[:mss]}
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
// Expect the coalesced segment to be split and transmitted.
expectedPayload = testbench.Payload{Bytes: expectedData[mss:]}
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
// Check for segments to *not* be held up because of TCP_CORK when
// the current send window is less than MSS.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))})
- dut.Send(acceptFD, sampleData, 0)
- dut.Send(acceptFD, sampleData, 0)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))})
+ dut.Send(t, acceptFD, sampleData, 0)
+ dut.Send(t, acceptFD, sampleData, 0)
expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)}
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
}
diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go
index 652b530d0..22937d92f 100644
--- a/test/packetimpact/tests/tcp_handshake_window_size_test.go
+++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go
@@ -33,14 +33,14 @@ func init() {
func TestTCPHandshakeWindowSize(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFD)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
// Start handshake with zero window size.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))})
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN-ACK: %s", err)
}
// Update the advertised window size to a non-zero value with the ACK that
@@ -48,10 +48,10 @@ func TestTCPHandshakeWindowSize(t *testing.T) {
//
// Set the window size with MSB set and expect the dut to treat it as
// an unsigned value.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))})
- acceptFd, _ := dut.Accept(listenFD)
- defer dut.Close(acceptFd)
+ acceptFd, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFd)
sampleData := []byte("Sample Data")
samplePayload := &testbench.Payload{Bytes: sampleData}
@@ -59,8 +59,8 @@ func TestTCPHandshakeWindowSize(t *testing.T) {
// Since we advertised a zero window followed by a non-zero window,
// expect the dut to honor the recently advertised non-zero window
// and actually send out the data instead of probing for zero window.
- dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectNextData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil {
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectNextData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
}
diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go
index 868a08da8..900352fa1 100644
--- a/test/packetimpact/tests/tcp_network_unreachable_test.go
+++ b/test/packetimpact/tests/tcp_network_unreachable_test.go
@@ -38,29 +38,29 @@ func TestTCPSynSentUnreachable(t *testing.T) {
// Create the DUT and connection.
dut := testbench.NewDUT(t)
defer dut.TearDown()
- clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
+ clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
port := uint16(9001)
conn := testbench.NewTCPIPv4(t, testbench.TCP{SrcPort: &port, DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort, DstPort: &port})
- defer conn.Close()
+ defer conn.Close(t)
// Bring the DUT to SYN-SENT state with a non-blocking connect.
ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout)
defer cancel()
sa := unix.SockaddrInet4{Port: int(port)}
copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv4)).To4())
- if _, err := dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) {
+ if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) {
t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err)
}
// Get the SYN.
- tcpLayers, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second)
+ tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second)
if err != nil {
t.Fatalf("expected SYN: %s", err)
}
// Send a host unreachable message.
rawConn := (*testbench.Connection)(&conn)
- layers := rawConn.CreateFrame(nil)
+ layers := rawConn.CreateFrame(t, nil)
layers = layers[:len(layers)-1]
const ipLayer = 1
const tcpLayer = ipLayer + 1
@@ -74,9 +74,9 @@ func TestTCPSynSentUnreachable(t *testing.T) {
}
var icmpv4 testbench.ICMPv4 = testbench.ICMPv4{Type: testbench.ICMPv4Type(header.ICMPv4DstUnreachable), Code: testbench.Uint8(header.ICMPv4HostUnreachable)}
layers = append(layers, &icmpv4, ip, tcp)
- rawConn.SendFrameStateless(layers)
+ rawConn.SendFrameStateless(t, layers)
- if _, err = dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.EHOSTUNREACH) {
+ if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EHOSTUNREACH) {
t.Errorf("expected connect to fail with EHOSTUNREACH, but got %v", err)
}
}
@@ -88,9 +88,9 @@ func TestTCPSynSentUnreachable6(t *testing.T) {
// Create the DUT and connection.
dut := testbench.NewDUT(t)
defer dut.TearDown()
- clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv6))
+ clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv6))
conn := testbench.NewTCPIPv6(t, testbench.TCP{DstPort: &clientPort}, testbench.TCP{SrcPort: &clientPort})
- defer conn.Close()
+ defer conn.Close(t)
// Bring the DUT to SYN-SENT state with a non-blocking connect.
ctx, cancel := context.WithTimeout(context.Background(), testbench.RPCTimeout)
@@ -100,19 +100,19 @@ func TestTCPSynSentUnreachable6(t *testing.T) {
ZoneId: uint32(testbench.RemoteInterfaceID),
}
copy(sa.Addr[:], net.IP(net.ParseIP(testbench.LocalIPv6)).To16())
- if _, err := dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) {
+ if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) {
t.Errorf("expected connect to fail with EINPROGRESS, but got %v", err)
}
// Get the SYN.
- tcpLayers, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second)
+ tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second)
if err != nil {
t.Fatalf("expected SYN: %s", err)
}
// Send a host unreachable message.
rawConn := (*testbench.Connection)(&conn)
- layers := rawConn.CreateFrame(nil)
+ layers := rawConn.CreateFrame(t, nil)
layers = layers[:len(layers)-1]
const ipLayer = 1
const tcpLayer = ipLayer + 1
@@ -131,9 +131,9 @@ func TestTCPSynSentUnreachable6(t *testing.T) {
Payload: []byte{0, 0, 0, 0},
}
layers = append(layers, &icmpv6, ip, tcp)
- rawConn.SendFrameStateless(layers)
+ rawConn.SendFrameStateless(t, layers)
- if _, err = dut.ConnectWithErrno(ctx, clientFD, &sa); err != syscall.Errno(unix.ENETUNREACH) {
+ if _, err = dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.ENETUNREACH) {
t.Errorf("expected connect to fail with ENETUNREACH, but got %v", err)
}
}
diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
index b9b3e91d3..82b7a85ff 100644
--- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
+++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go
@@ -31,12 +31,12 @@ func init() {
func TestTcpNoAcceptCloseReset(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- conn.Connect()
- defer conn.Close()
- dut.Close(listenFd)
- if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
+ conn.Connect(t)
+ defer conn.Close(t)
+ dut.Close(t, listenFd)
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil {
t.Fatalf("expected a RST-ACK packet but got none: %s", err)
}
}
diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go
index ad8c74234..08f759f7c 100644
--- a/test/packetimpact/tests/tcp_outside_the_window_test.go
+++ b/test/packetimpact/tests/tcp_outside_the_window_test.go
@@ -63,25 +63,25 @@ func TestTCPOutsideTheWindow(t *testing.T) {
t.Run(fmt.Sprintf("%s%d", tt.description, tt.seqNumOffset), func(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFD)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
- conn.Connect()
- acceptFD, _ := dut.Accept(listenFD)
- defer dut.Close(acceptFD)
+ defer conn.Close(t)
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFD)
- windowSize := seqnum.Size(*conn.SynAck().WindowSize) + tt.seqNumOffset
- conn.Drain()
+ windowSize := seqnum.Size(*conn.SynAck(t).WindowSize) + tt.seqNumOffset
+ conn.Drain(t)
// Ignore whatever incrementing that this out-of-order packet might cause
// to the AckNum.
- localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum()))
- conn.Send(testbench.TCP{
+ localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t)))
+ conn.Send(t, testbench.TCP{
Flags: testbench.Uint8(tt.tcpFlags),
- SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum().Add(windowSize))),
+ SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))),
}, tt.payload...)
timeout := 3 * time.Second
- gotACK, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
+ gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout)
if tt.expectACK && err != nil {
t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err)
}
diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go
index 55db4ece6..37f3b56dd 100644
--- a/test/packetimpact/tests/tcp_paws_mechanism_test.go
+++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go
@@ -32,15 +32,15 @@ func init() {
func TestPAWSMechanism(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFD)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
options := make([]byte, header.TCPOptionTSLength)
header.EncodeTSOption(currentTS(), 0, options)
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options})
- synAck, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options})
+ synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("didn't get synack during handshake: %s", err)
}
@@ -50,9 +50,9 @@ func TestPAWSMechanism(t *testing.T) {
}
tsecr := parsedSynOpts.TSVal
header.EncodeTSOption(currentTS(), tsecr, options)
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options})
- acceptFD, _ := dut.Accept(listenFD)
- defer dut.Close(acceptFD)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options})
+ acceptFD, _ := dut.Accept(t, listenFD)
+ defer dut.Close(t, acceptFD)
sampleData := []byte("Sample Data")
sentTSVal := currentTS()
@@ -61,9 +61,9 @@ func TestPAWSMechanism(t *testing.T) {
// every time we send one, it should not cause any flakiness because timestamps
// only need to be non-decreasing.
time.Sleep(3 * time.Millisecond)
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
- gotTCP, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected an ACK but got none: %s", err)
}
@@ -86,9 +86,9 @@ func TestPAWSMechanism(t *testing.T) {
// 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness
// due to the exact same reasoning discussed above.
time.Sleep(3 * time.Millisecond)
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData})
- gotTCP, err = conn.Expect(testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
+ gotTCP, err = conn.Expect(t, testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second)
if err != nil {
t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err)
}
diff --git a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
index 8fbec893b..d9f3ea0f2 100644
--- a/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
+++ b/test/packetimpact/tests/tcp_queue_receive_in_syn_sent_test.go
@@ -52,26 +52,26 @@ func TestQueueReceiveInSynSent(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- socket, remotePort := dut.CreateBoundSocket(unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
+ socket, remotePort := dut.CreateBoundSocket(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, net.ParseIP(testbench.RemoteIPv4))
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
sampleData := []byte("Sample Data")
- dut.SetNonBlocking(socket, true)
- if _, err := dut.ConnectWithErrno(context.Background(), socket, conn.LocalAddr()); !errors.Is(err, syscall.EINPROGRESS) {
+ dut.SetNonBlocking(t, socket, true)
+ if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) {
t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err)
}
- if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil {
t.Fatalf("expected a SYN from DUT, but got none: %s", err)
}
- if _, _, err := dut.RecvWithErrno(context.Background(), socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) {
+ if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) {
t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err)
}
// Test blocking read.
- dut.SetNonBlocking(socket, false)
+ dut.SetNonBlocking(t, socket, false)
var wg sync.WaitGroup
defer wg.Wait()
@@ -86,7 +86,7 @@ func TestQueueReceiveInSynSent(t *testing.T) {
block.Done()
// Issue RECEIVE call in SYN-SENT, this should be queued for
// process until the connection is established.
- n, buff, err := dut.RecvWithErrno(ctx, socket, int32(len(sampleData)), 0)
+ n, buff, err := dut.RecvWithErrno(ctx, t, socket, int32(len(sampleData)), 0)
if tt.reset {
if err != syscall.Errno(unix.ECONNREFUSED) {
t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err)
@@ -112,19 +112,19 @@ func TestQueueReceiveInSynSent(t *testing.T) {
time.Sleep(100 * time.Millisecond)
if tt.reset {
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
return
}
// Bring the connection to Established.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)})
- if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected an ACK from DUT, but got none: %s", err)
}
// Send sample payload and expect an ACK.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData})
- if _, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData})
+ if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil {
t.Fatalf("expected an ACK from DUT, but got none: %s", err)
}
})
diff --git a/test/packetimpact/tests/tcp_reordering_test.go b/test/packetimpact/tests/tcp_reordering_test.go
index a5378a9dd..8742819ca 100644
--- a/test/packetimpact/tests/tcp_reordering_test.go
+++ b/test/packetimpact/tests/tcp_reordering_test.go
@@ -32,10 +32,10 @@ func init() {
func TestReorderingWindow(t *testing.T) {
dut := tb.NewDUT(t)
defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFd)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
conn := tb.NewTCPIPv4(t, tb.TCP{DstPort: &remotePort}, tb.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
// Enable SACK.
opts := make([]byte, 40)
@@ -49,17 +49,17 @@ func TestReorderingWindow(t *testing.T) {
const mss = minMTU - header.IPv4MinimumSize - header.TCPMinimumSize
optsOff += header.EncodeMSSOption(mss, opts[optsOff:])
- conn.ConnectWithOptions(opts[:optsOff])
+ conn.ConnectWithOptions(t, opts[:optsOff])
- acceptFd, _ := dut.Accept(listenFd)
- defer dut.Close(acceptFd)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
if tb.DUTType == "linux" {
// Linux has changed its handling of reordering, force the old behavior.
- dut.SetSockOpt(acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno"))
+ dut.SetSockOpt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_CONGESTION, []byte("reno"))
}
- pls := dut.GetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG)
+ pls := dut.GetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_MAXSEG)
if tb.DUTType == "netstack" {
// netstack does not impliment TCP_MAXSEG correctly. Fake it
// here. Netstack uses the max SACK size which is 32. The MSS
@@ -69,13 +69,13 @@ func TestReorderingWindow(t *testing.T) {
payload := make([]byte, pls)
- seqNum1 := *conn.RemoteSeqNum()
+ seqNum1 := *conn.RemoteSeqNum(t)
const numPkts = 10
// Send some packets, checking that we receive each.
for i, sn := 0, seqNum1; i < numPkts; i++ {
- dut.Send(acceptFd, payload, 0)
+ dut.Send(t, acceptFd, payload, 0)
- gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
sn.UpdateForward(seqnum.Size(len(payload)))
if err != nil {
t.Errorf("Expect #%d: %s", i+1, err)
@@ -86,7 +86,7 @@ func TestReorderingWindow(t *testing.T) {
}
}
- seqNum2 := *conn.RemoteSeqNum()
+ seqNum2 := *conn.RemoteSeqNum(t)
// SACK packets #2-4.
sackBlock := make([]byte, 40)
@@ -97,13 +97,13 @@ func TestReorderingWindow(t *testing.T) {
seqNum1.Add(seqnum.Size(len(payload))),
seqNum1.Add(seqnum.Size(4 * len(payload))),
}}, sackBlock[sbOff:])
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]})
// ACK first packet.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))})
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum1) + uint32(len(payload)))})
// Check for retransmit.
- gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second)
+ gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(seqNum1))}, time.Second)
if err != nil {
t.Error("Expect for retransmit:", err)
}
@@ -123,14 +123,14 @@ func TestReorderingWindow(t *testing.T) {
seqNum1.Add(seqnum.Size(4 * len(payload))),
}}, dsackBlock[dsbOff:])
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]})
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck), AckNum: tb.Uint32(uint32(seqNum2)), Options: dsackBlock[:dsbOff]})
// Send half of the original window of packets, checking that we
// received each.
for i, sn := 0, seqNum2; i < numPkts/2; i++ {
- dut.Send(acceptFd, payload, 0)
+ dut.Send(t, acceptFd, payload, 0)
- gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
sn.UpdateForward(seqnum.Size(len(payload)))
if err != nil {
t.Errorf("Expect #%d: %s", i+1, err)
@@ -144,8 +144,8 @@ func TestReorderingWindow(t *testing.T) {
if tb.DUTType == "netstack" {
// The window should now be halved, so we should receive any
// more, even if we send them.
- dut.Send(acceptFd, payload, 0)
- if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil {
+ dut.Send(t, acceptFd, payload, 0)
+ if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil {
t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got)
}
return
@@ -153,9 +153,9 @@ func TestReorderingWindow(t *testing.T) {
// Linux reduces the window by three. Check that we can receive the rest.
for i, sn := 0, seqNum2.Add(seqnum.Size(numPkts/2*len(payload))); i < 2; i++ {
- dut.Send(acceptFd, payload, 0)
+ dut.Send(t, acceptFd, payload, 0)
- gotOne, err := conn.Expect(tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
+ gotOne, err := conn.Expect(t, tb.TCP{SeqNum: tb.Uint32(uint32(sn))}, time.Second)
sn.UpdateForward(seqnum.Size(len(payload)))
if err != nil {
t.Errorf("Expect #%d: %s", i+1, err)
@@ -167,8 +167,8 @@ func TestReorderingWindow(t *testing.T) {
}
// The window should now be full.
- dut.Send(acceptFd, payload, 0)
- if got, err := conn.Expect(tb.TCP{}, 100*time.Millisecond); got != nil || err == nil {
+ dut.Send(t, acceptFd, payload, 0)
+ if got, err := conn.Expect(t, tb.TCP{}, 100*time.Millisecond); got != nil || err == nil {
t.Fatalf("expected no packets within 100 millisecond, but got one: %s", got)
}
}
diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go
index 6940eb7fb..072014ff8 100644
--- a/test/packetimpact/tests/tcp_retransmits_test.go
+++ b/test/packetimpact/tests/tcp_retransmits_test.go
@@ -33,41 +33,41 @@ func init() {
func TestRetransmits(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFd)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
- conn.Connect()
- acceptFd, _ := dut.Accept(listenFd)
- defer dut.Close(acceptFd)
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
- dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
sampleData := []byte("Sample Data")
samplePayload := &testbench.Payload{Bytes: sampleData}
- dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
// Give a chance for the dut to estimate RTO with RTT from the DATA-ACK.
// TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which
// we can skip sending this ACK.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
startRTO := time.Second
current := startRTO
first := time.Now()
- dut.Send(acceptFd, sampleData, 0)
- seq := testbench.Uint32(uint32(*conn.RemoteSeqNum()))
- if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil {
+ dut.Send(t, acceptFd, sampleData, 0)
+ seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, startRTO); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
// Expect retransmits of the same segment.
for i := 0; i < 5; i++ {
start := time.Now()
- if _, err := conn.ExpectData(&testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, 2*current); err != nil {
t.Fatalf("expected payload was not received: %s loop %d", err, i)
}
if i == 0 {
diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
index 90ab85419..f91b06ba1 100644
--- a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
+++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go
@@ -61,23 +61,23 @@ func TestSendWindowSizesPiggyback(t *testing.T) {
t.Run(fmt.Sprintf("%s%d", tt.description, tt.windowSize), func(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFd)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort, WindowSize: testbench.Uint16(tt.windowSize)}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
- conn.Connect()
- acceptFd, _ := dut.Accept(listenFd)
- defer dut.Close(acceptFd)
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
- dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}
- dut.Send(acceptFd, sampleData, 0)
+ dut.Send(t, acceptFd, sampleData, 0)
expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1}
- if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
@@ -86,18 +86,18 @@ func TestSendWindowSizesPiggyback(t *testing.T) {
if tt.enqueue {
// Enqueue a segment for the dut to transmit.
- dut.Send(acceptFd, sampleData, 0)
+ dut.Send(t, acceptFd, sampleData, 0)
}
// Send ACK for the previous segment along with data for the dut to
// receive and ACK back. Sending this ACK would make room for the dut
// to transmit any enqueued segment.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData})
// Expect the dut to piggyback the ACK for received data along with
// the segment enqueued for transmit.
expectedPayload = testbench.Payload{Bytes: tt.expectedPayload2}
- if _, err := conn.ExpectData(&expectedTCP, &expectedPayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &expectedTCP, &expectedPayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
})
diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
index 7d5deab01..57d034dd1 100644
--- a/test/packetimpact/tests/tcp_synrcvd_reset_test.go
+++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go
@@ -32,21 +32,21 @@ func init() {
func TestTCPSynRcvdReset(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFD)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
// Expect dut connection to have transitioned to SYN-RCVD state.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)})
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN-ACK %s", err)
}
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)})
// Expect the connection to have transitioned SYN-RCVD to CLOSED.
// TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST %s", err)
}
}
diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go
index 6898a2239..eac8eb19d 100644
--- a/test/packetimpact/tests/tcp_synsent_reset_test.go
+++ b/test/packetimpact/tests/tcp_synsent_reset_test.go
@@ -31,17 +31,19 @@ func init() {
// dutSynSentState sets up the dut connection in SYN-SENT state.
func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) {
+ t.Helper()
+
dut := tb.NewDUT(t)
- clientFD, clientPort := dut.CreateBoundSocket(unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4))
+ clientFD, clientPort := dut.CreateBoundSocket(t, unix.SOCK_STREAM|unix.SOCK_NONBLOCK, unix.IPPROTO_TCP, net.ParseIP(tb.RemoteIPv4))
port := uint16(9001)
conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &port, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &port})
sa := unix.SockaddrInet4{Port: int(port)}
copy(sa.Addr[:], net.IP(net.ParseIP(tb.LocalIPv4)).To4())
// Bring the dut to SYN-SENT state with a non-blocking connect.
- dut.Connect(clientFD, &sa)
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil {
+ dut.Connect(t, clientFD, &sa)
+ if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN\n")
}
@@ -51,13 +53,13 @@ func dutSynSentState(t *testing.T) (*tb.DUT, *tb.TCPIPv4, uint16, uint16) {
// TestTCPSynSentReset tests RFC793, p67: SYN-SENT to CLOSED transition.
func TestTCPSynSentReset(t *testing.T) {
dut, conn, _, _ := dutSynSentState(t)
- defer conn.Close()
+ defer conn.Close(t)
defer dut.TearDown()
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst | header.TCPFlagAck)})
// Expect the connection to have closed.
// TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST")
}
}
@@ -67,22 +69,22 @@ func TestTCPSynSentReset(t *testing.T) {
func TestTCPSynSentRcvdReset(t *testing.T) {
dut, c, remotePort, clientPort := dutSynSentState(t)
defer dut.TearDown()
- defer c.Close()
+ defer c.Close(t)
conn := tb.NewTCPIPv4(t, tb.TCP{SrcPort: &remotePort, DstPort: &clientPort}, tb.TCP{SrcPort: &clientPort, DstPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
// Initiate new SYN connection with the same port pair
// (simultaneous open case), expect the dut connection to move to
// SYN-RCVD state
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)})
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected SYN-ACK %s\n", err)
}
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)})
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)})
// Expect the connection to have transitioned SYN-RCVD to CLOSED.
// TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side.
- conn.Send(tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(&tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(t, tb.TCP{Flags: tb.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &tb.TCP{Flags: tb.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST")
}
}
diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go
index 87e45d765..551dc78e7 100644
--- a/test/packetimpact/tests/tcp_user_timeout_test.go
+++ b/test/packetimpact/tests/tcp_user_timeout_test.go
@@ -16,7 +16,6 @@ package tcp_user_timeout_test
import (
"flag"
- "fmt"
"testing"
"time"
@@ -29,22 +28,20 @@ func init() {
testbench.RegisterFlags(flag.CommandLine)
}
-func sendPayload(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error {
+func sendPayload(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) {
sampleData := make([]byte, 100)
for i := range sampleData {
sampleData[i] = uint8(i)
}
- conn.Drain()
- dut.Send(fd, sampleData, 0)
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil {
- return fmt.Errorf("expected data but got none: %w", err)
+ conn.Drain(t)
+ dut.Send(t, fd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil {
+ t.Fatalf("expected data but got none: %w", err)
}
- return nil
}
-func sendFIN(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error {
- dut.Close(fd)
- return nil
+func sendFIN(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) {
+ dut.Close(t, fd)
}
func TestTCPUserTimeout(t *testing.T) {
@@ -59,7 +56,7 @@ func TestTCPUserTimeout(t *testing.T) {
} {
for _, ttf := range []struct {
description string
- f func(conn *testbench.TCPIPv4, dut *testbench.DUT, fd int32) error
+ f func(_ *testing.T, _ *testbench.TCPIPv4, _ *testbench.DUT, fd int32)
}{
{"AfterPayload", sendPayload},
{"AfterFIN", sendFIN},
@@ -68,31 +65,29 @@ func TestTCPUserTimeout(t *testing.T) {
// Create a socket, listen, TCP handshake, and accept.
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFD, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFD)
+ listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFD)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
- conn.Connect()
- acceptFD, _ := dut.Accept(listenFD)
+ defer conn.Close(t)
+ conn.Connect(t)
+ acceptFD, _ := dut.Accept(t, listenFD)
if tt.userTimeout != 0 {
- dut.SetSockOptInt(acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds()))
+ dut.SetSockOptInt(t, acceptFD, unix.SOL_TCP, unix.TCP_USER_TIMEOUT, int32(tt.userTimeout.Milliseconds()))
}
- if err := ttf.f(&conn, &dut, acceptFD); err != nil {
- t.Fatal(err)
- }
+ ttf.f(t, &conn, &dut, acceptFD)
time.Sleep(tt.sendDelay)
- conn.Drain()
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Drain(t)
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
// If TCP_USER_TIMEOUT was set and the above delay was longer than the
// TCP_USER_TIMEOUT then the DUT should send a RST in response to the
// testbench's packet.
expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout
expectTimeout := 5 * time.Second
- got, err := conn.Expect(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout)
+ got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout)
if expectRST && err != nil {
t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err)
}
diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go
index e78d04756..5b001fbec 100644
--- a/test/packetimpact/tests/tcp_window_shrink_test.go
+++ b/test/packetimpact/tests/tcp_window_shrink_test.go
@@ -31,43 +31,43 @@ func init() {
func TestWindowShrink(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFd)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
- conn.Connect()
- acceptFd, _ := dut.Accept(listenFd)
- defer dut.Close(acceptFd)
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
- dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
sampleData := []byte("Sample Data")
samplePayload := &testbench.Payload{Bytes: sampleData}
- dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
- dut.Send(acceptFd, sampleData, 0)
- dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ dut.Send(t, acceptFd, sampleData, 0)
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
// We close our receiving window here
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
- dut.Send(acceptFd, []byte("Sample Data"), 0)
+ dut.Send(t, acceptFd, []byte("Sample Data"), 0)
// Note: There is another kind of zero-window probing which Windows uses (by sending one
// new byte at `RemoteSeqNum`), if netstack wants to go that way, we may want to change
// the following lines.
- expectedRemoteSeqNum := *conn.RemoteSeqNum() - 1
- if _, err := conn.ExpectData(&testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil {
+ expectedRemoteSeqNum := *conn.RemoteSeqNum(t) - 1
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: testbench.Uint32(uint32(expectedRemoteSeqNum))}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with sequence number %d: %s", expectedRemoteSeqNum, err)
}
}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
index 8c89d57c9..da93267d6 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go
@@ -33,27 +33,27 @@ func init() {
func TestZeroWindowProbeRetransmit(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFd)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
- conn.Connect()
- acceptFd, _ := dut.Accept(listenFd)
- defer dut.Close(acceptFd)
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
- dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
sampleData := []byte("Sample Data")
samplePayload := &testbench.Payload{Bytes: sampleData}
// Send and receive sample data to the dut.
- dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected packet was not received: %s", err)
}
@@ -63,15 +63,15 @@ func TestZeroWindowProbeRetransmit(t *testing.T) {
// of the recorded first zero probe transmission duration.
//
// Advertize zero receive window again.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
- probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1))
- ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum()))
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1))
+ ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))
startProbeDuration := time.Second
current := startProbeDuration
first := time.Now()
// Ask the dut to send out data.
- dut.Send(acceptFd, sampleData, 0)
+ dut.Send(t, acceptFd, sampleData, 0)
// Expect the dut to keep the connection alive as long as the remote is
// acknowledging the zero-window probes.
for i := 0; i < 5; i++ {
@@ -79,7 +79,7 @@ func TestZeroWindowProbeRetransmit(t *testing.T) {
// Expect zero-window probe with a timeout which is a function of the typical
// first retransmission time. The retransmission times is supposed to
// exponentially increase.
- if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*current); err != nil {
t.Fatalf("expected a probe with sequence number %d: loop %d", probeSeq, i)
}
if i == 0 {
@@ -92,14 +92,13 @@ func TestZeroWindowProbeRetransmit(t *testing.T) {
t.Errorf("got zero probe %d after %s, want >= %s", i, got, want)
}
// Acknowledge the zero-window probes from the dut.
- conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
current *= 2
}
// Advertize non-zero window.
- conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
// Expect the dut to recover and transmit data.
- if _, err := conn.ExpectData(&testbench.
- TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go
index 649fd5699..44cac42f8 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go
@@ -33,29 +33,29 @@ func init() {
func TestZeroWindowProbe(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFd)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
- conn.Connect()
- acceptFd, _ := dut.Accept(listenFd)
- defer dut.Close(acceptFd)
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
- dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
sampleData := []byte("Sample Data")
samplePayload := &testbench.Payload{Bytes: sampleData}
start := time.Now()
// Send and receive sample data to the dut.
- dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
sendTime := time.Now().Sub(start)
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected packet was not received: %s", err)
}
@@ -63,24 +63,24 @@ func TestZeroWindowProbe(t *testing.T) {
// probe to be sent.
//
// Advertize zero window to the dut.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
// Expected sequence number of the zero window probe.
- probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1))
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1))
// Expected ack number of the ACK for the probe.
- ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum()))
+ ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))
// Expect there are no zero-window probes sent until there is data to be sent out
// from the dut.
- if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, 2*time.Second); err == nil {
t.Fatalf("unexpected packet with sequence number %d: %s", probeSeq, err)
}
start = time.Now()
// Ask the dut to send out data.
- dut.Send(acceptFd, sampleData, 0)
+ dut.Send(t, acceptFd, sampleData, 0)
// Expect zero-window probe from the dut.
- if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err)
}
// Expect the probe to be sent after some time. Compare against the previous
@@ -94,9 +94,9 @@ func TestZeroWindowProbe(t *testing.T) {
// and sends out the sample payload after the send window opens.
//
// Advertize non-zero window to the dut and ack the zero window probe.
- conn.Send(testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
+ conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)})
// Expect the dut to recover and transmit data.
- if _, err := conn.ExpectData(&testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
@@ -104,9 +104,9 @@ func TestZeroWindowProbe(t *testing.T) {
// Check if the dut responds as we do for a similar probe sent to it.
// Basically with sequence number to one byte behind the unacknowledged
// sequence number.
- p := testbench.Uint32(uint32(*conn.LocalSeqNum()))
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum() - 1))})
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil {
+ p := testbench.Uint32(uint32(*conn.LocalSeqNum(t)))
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with ack number: %d: %s", p, err)
}
}
diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
index 3c467b14f..09a1c653f 100644
--- a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
+++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go
@@ -33,27 +33,27 @@ func init() {
func TestZeroWindowProbeUserTimeout(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- listenFd, remotePort := dut.CreateListener(unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
- defer dut.Close(listenFd)
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1)
+ defer dut.Close(t, listenFd)
conn := testbench.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
- conn.Connect()
- acceptFd, _ := dut.Accept(listenFd)
- defer dut.Close(acceptFd)
+ conn.Connect(t)
+ acceptFd, _ := dut.Accept(t, listenFd)
+ defer dut.Close(t, acceptFd)
- dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1)
sampleData := []byte("Sample Data")
samplePayload := &testbench.Payload{Bytes: sampleData}
// Send and receive sample data to the dut.
- dut.Send(acceptFd, sampleData, 0)
- if _, err := conn.ExpectData(&testbench.TCP{}, samplePayload, time.Second); err != nil {
+ dut.Send(t, acceptFd, sampleData, 0)
+ if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil {
t.Fatalf("expected payload was not received: %s", err)
}
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload)
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil {
t.Fatalf("expected packet was not received: %s", err)
}
@@ -61,15 +61,15 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) {
// probe to be sent.
//
// Advertize zero window to the dut.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
// Expected sequence number of the zero window probe.
- probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum() - 1))
+ probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1))
start := time.Now()
// Ask the dut to send out data.
- dut.Send(acceptFd, sampleData, 0)
+ dut.Send(t, acceptFd, sampleData, 0)
// Expect zero-window probe from the dut.
- if _, err := conn.ExpectData(&testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: probeSeq}, nil, time.Second); err != nil {
t.Fatalf("expected a packet with sequence number %d: %s", probeSeq, err)
}
// Record the duration for first probe, the dut sends the zero window probe after
@@ -80,19 +80,19 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) {
// when the dut is sending zero-window probes.
//
// Reduce the retransmit timeout.
- dut.SetSockOptInt(acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds()))
+ dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds()))
// Advertize zero window again.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)})
// Ask the dut to send out data that would trigger zero window probe retransmissions.
- dut.Send(acceptFd, sampleData, 0)
+ dut.Send(t, acceptFd, sampleData, 0)
// Wait for the connection to timeout after multiple zero-window probe retransmissions.
time.Sleep(8 * startProbeDuration)
// Expect the connection to have timed out and closed which would cause the dut
// to reply with a RST to the ACK we send.
- conn.Send(testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
- if _, err := conn.ExpectData(&testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
+ conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil {
t.Fatalf("expected a TCP RST")
}
}
diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
index b0315e67c..d30177e64 100644
--- a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
+++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go
@@ -36,11 +36,11 @@ func init() {
func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4))
- defer dut.Close(remoteFD)
- dut.SetSockOptTimeval(remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond)
+ remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv4))
+ defer dut.Close(t, remoteFD)
+ dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond)
conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
for _, mcastAddr := range []net.IP{
net.IPv4allsys,
@@ -50,11 +50,12 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) {
} {
t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) {
conn.SendIP(
+ t,
testbench.IPv4{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To4()))},
testbench.UDP{},
)
- ret, payload, errno := dut.RecvWithErrno(context.Background(), remoteFD, 100, 0)
+ ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0)
if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK {
t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno)
}
@@ -65,11 +66,11 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) {
func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv6))
- defer dut.Close(remoteFD)
- dut.SetSockOptTimeval(remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond)
+ remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(testbench.RemoteIPv6))
+ defer dut.Close(t, remoteFD)
+ dut.SetSockOptTimeval(t, remoteFD, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &oneSecond)
conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
for _, mcastAddr := range []net.IP{
net.IPv6interfacelocalallnodes,
@@ -80,10 +81,11 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) {
} {
t.Run(fmt.Sprintf("srcaddr=%s", mcastAddr), func(t *testing.T) {
conn.SendIPv6(
+ t,
testbench.IPv6{SrcAddr: testbench.Address(tcpip.Address(mcastAddr.To16()))},
testbench.UDP{},
)
- ret, payload, errno := dut.RecvWithErrno(context.Background(), remoteFD, 100, 0)
+ ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0)
if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK {
t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno)
}
diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
index b754918f6..715e8f5b5 100644
--- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go
+++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go
@@ -72,7 +72,7 @@ func (e icmpError) ToICMPv4() *testbench.ICMPv4 {
type errorDetection struct {
name string
useValidConn bool
- f func(context.Context, testData) error
+ f func(context.Context, *testing.T, testData)
}
type testData struct {
@@ -95,12 +95,14 @@ func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno {
}
// sendICMPError sends an ICMP error message in response to a UDP datagram.
-func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) error {
- layers := (*testbench.Connection)(conn).CreateFrame(nil)
+func sendICMPError(t *testing.T, conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UDP) {
+ t.Helper()
+
+ layers := (*testbench.Connection)(conn).CreateFrame(t, nil)
layers = layers[:len(layers)-1]
ip, ok := udp.Prev().(*testbench.IPv4)
if !ok {
- return fmt.Errorf("expected %s to be IPv4", udp.Prev())
+ t.Fatalf("expected %s to be IPv4", udp.Prev())
}
if icmpErr == timeToLiveExceeded {
*ip.TTL = 1
@@ -114,84 +116,82 @@ func sendICMPError(conn *testbench.UDPIPv4, icmpErr icmpError, udp *testbench.UD
// resulting in a mal-formed packet.
layers = append(layers, icmpErr.ToICMPv4(), ip, udp)
- (*testbench.Connection)(conn).SendFrameStateless(layers)
- return nil
+ (*testbench.Connection)(conn).SendFrameStateless(t, layers)
}
// testRecv tests observing the ICMP error through the recv syscall. A packet
// is sent to the DUT, and if wantErrno is non-zero, then the first recv should
// fail and the second should succeed. Otherwise if wantErrno is zero then the
// first recv should succeed immediately.
-func testRecv(ctx context.Context, d testData) error {
+func testRecv(ctx context.Context, t *testing.T, d testData) {
+ t.Helper()
+
// Check that receiving on the clean socket works.
- d.conn.Send(testbench.UDP{DstPort: &d.cleanPort})
- d.dut.Recv(d.cleanFD, 100, 0)
+ d.conn.Send(t, testbench.UDP{DstPort: &d.cleanPort})
+ d.dut.Recv(t, d.cleanFD, 100, 0)
- d.conn.Send(testbench.UDP{})
+ d.conn.Send(t, testbench.UDP{})
if d.wantErrno != syscall.Errno(0) {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
- ret, _, err := d.dut.RecvWithErrno(ctx, d.remoteFD, 100, 0)
+ ret, _, err := d.dut.RecvWithErrno(ctx, t, d.remoteFD, 100, 0)
if ret != -1 {
- return fmt.Errorf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno)
+ t.Fatalf("recv after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno)
}
if err != d.wantErrno {
- return fmt.Errorf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno)
+ t.Fatalf("recv after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno)
}
}
- d.dut.Recv(d.remoteFD, 100, 0)
- return nil
+ d.dut.Recv(t, d.remoteFD, 100, 0)
}
// testSendTo tests observing the ICMP error through the send syscall. If
// wantErrno is non-zero, the first send should fail and a subsequent send
// should suceed; while if wantErrno is zero then the first send should just
// succeed.
-func testSendTo(ctx context.Context, d testData) error {
+func testSendTo(ctx context.Context, t *testing.T, d testData) {
// Check that sending on the clean socket works.
- d.dut.SendTo(d.cleanFD, nil, 0, d.conn.LocalAddr())
- if _, err := d.conn.Expect(testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil {
- return fmt.Errorf("did not receive UDP packet from clean socket on DUT: %s", err)
+ d.dut.SendTo(t, d.cleanFD, nil, 0, d.conn.LocalAddr(t))
+ if _, err := d.conn.Expect(t, testbench.UDP{SrcPort: &d.cleanPort}, time.Second); err != nil {
+ t.Fatalf("did not receive UDP packet from clean socket on DUT: %s", err)
}
if d.wantErrno != syscall.Errno(0) {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
- ret, err := d.dut.SendToWithErrno(ctx, d.remoteFD, nil, 0, d.conn.LocalAddr())
+ ret, err := d.dut.SendToWithErrno(ctx, t, d.remoteFD, nil, 0, d.conn.LocalAddr(t))
if ret != -1 {
- return fmt.Errorf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno)
+ t.Fatalf("sendto after ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", d.wantErrno)
}
if err != d.wantErrno {
- return fmt.Errorf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno)
+ t.Fatalf("sendto after ICMP error resulted in error (%[1]d) %[1]v, expected (%[2]d) %[2]v", err, d.wantErrno)
}
}
- d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr())
- if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil {
- return fmt.Errorf("did not receive UDP packet as expected: %s", err)
+ d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t))
+ if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil {
+ t.Fatalf("did not receive UDP packet as expected: %s", err)
}
- return nil
}
-func testSockOpt(_ context.Context, d testData) error {
+func testSockOpt(_ context.Context, t *testing.T, d testData) {
// Check that there's no pending error on the clean socket.
- if errno := syscall.Errno(d.dut.GetSockOptInt(d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) {
- return fmt.Errorf("unexpected error (%[1]d) %[1]v on clean socket", errno)
+ if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) {
+ t.Fatalf("unexpected error (%[1]d) %[1]v on clean socket", errno)
}
- if errno := syscall.Errno(d.dut.GetSockOptInt(d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno {
- return fmt.Errorf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno)
+ if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno {
+ t.Fatalf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno)
}
// Check that after clearing socket error, sending doesn't fail.
- d.dut.SendTo(d.remoteFD, nil, 0, d.conn.LocalAddr())
- if _, err := d.conn.Expect(testbench.UDP{}, time.Second); err != nil {
- return fmt.Errorf("did not receive UDP packet as expected: %s", err)
+ d.dut.SendTo(t, d.remoteFD, nil, 0, d.conn.LocalAddr(t))
+ if _, err := d.conn.Expect(t, testbench.UDP{}, time.Second); err != nil {
+ t.Fatalf("did not receive UDP packet as expected: %s", err)
}
- return nil
}
// TestUDPICMPErrorPropagation tests that ICMP error messages in response to
@@ -227,31 +227,29 @@ func TestUDPICMPErrorPropagation(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
- defer dut.Close(remoteFD)
+ remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(t, remoteFD)
// Create a second, clean socket on the DUT to ensure that the ICMP
// error messages only affect the sockets they are intended for.
- cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
- defer dut.Close(cleanFD)
+ cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(t, cleanFD)
conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
if connect {
- dut.Connect(remoteFD, conn.LocalAddr())
- dut.Connect(cleanFD, conn.LocalAddr())
+ dut.Connect(t, remoteFD, conn.LocalAddr(t))
+ dut.Connect(t, cleanFD, conn.LocalAddr(t))
}
- dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
- udp, err := conn.Expect(testbench.UDP{}, time.Second)
+ dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t))
+ udp, err := conn.Expect(t, testbench.UDP{}, time.Second)
if err != nil {
t.Fatalf("did not receive message from DUT: %s", err)
}
- if err := sendICMPError(&conn, icmpErr, udp); err != nil {
- t.Fatal(err)
- }
+ sendICMPError(t, &conn, icmpErr, udp)
errDetectConn := &conn
if errDetect.useValidConn {
@@ -260,14 +258,12 @@ func TestUDPICMPErrorPropagation(t *testing.T) {
// interactions between it and the the DUT should be independent of
// the ICMP error at least at the port level.
connClean := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- defer connClean.Close()
+ defer connClean.Close(t)
errDetectConn = &connClean
}
- if err := errDetect.f(context.Background(), testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno}); err != nil {
- t.Fatal(err)
- }
+ errDetect.f(context.Background(), t, testData{&dut, errDetectConn, remoteFD, remotePort, cleanFD, cleanPort, wantErrno})
})
}
}
@@ -285,24 +281,24 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- remoteFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
- defer dut.Close(remoteFD)
+ remoteFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(t, remoteFD)
// Create a second, clean socket on the DUT to ensure that the ICMP
// error messages only affect the sockets they are intended for.
- cleanFD, cleanPort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
- defer dut.Close(cleanFD)
+ cleanFD, cleanPort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP("0.0.0.0"))
+ defer dut.Close(t, cleanFD)
conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
if connect {
- dut.Connect(remoteFD, conn.LocalAddr())
- dut.Connect(cleanFD, conn.LocalAddr())
+ dut.Connect(t, remoteFD, conn.LocalAddr(t))
+ dut.Connect(t, cleanFD, conn.LocalAddr(t))
}
- dut.SendTo(remoteFD, nil, 0, conn.LocalAddr())
- udp, err := conn.Expect(testbench.UDP{}, time.Second)
+ dut.SendTo(t, remoteFD, nil, 0, conn.LocalAddr(t))
+ udp, err := conn.Expect(t, testbench.UDP{}, time.Second)
if err != nil {
t.Fatalf("did not receive message from DUT: %s", err)
}
@@ -316,7 +312,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
- ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0)
+ ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0)
if ret != -1 {
t.Errorf("recv during ICMP error succeeded unexpectedly, expected (%[1]d) %[1]v", wantErrno)
return
@@ -330,7 +326,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
- if ret, _, err := dut.RecvWithErrno(ctx, remoteFD, 100, 0); ret == -1 {
+ if ret, _, err := dut.RecvWithErrno(ctx, t, remoteFD, 100, 0); ret == -1 {
t.Errorf("recv after ICMP error failed with (%[1]d) %[1]", err)
}
}()
@@ -341,7 +337,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
- if ret, _, err := dut.RecvWithErrno(ctx, cleanFD, 100, 0); ret == -1 {
+ if ret, _, err := dut.RecvWithErrno(ctx, t, cleanFD, 100, 0); ret == -1 {
t.Errorf("recv on clean socket failed with (%[1]d) %[1]", err)
}
}()
@@ -352,12 +348,10 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) {
// alternative is available.
time.Sleep(2 * time.Second)
- if err := sendICMPError(&conn, icmpErr, udp); err != nil {
- t.Fatal(err)
- }
+ sendICMPError(t, &conn, icmpErr, udp)
- conn.Send(testbench.UDP{DstPort: &cleanPort})
- conn.Send(testbench.UDP{})
+ conn.Send(t, testbench.UDP{DstPort: &cleanPort})
+ conn.Send(t, testbench.UDP{})
wg.Wait()
})
}
diff --git a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
index 263a54291..fcd202643 100644
--- a/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
+++ b/test/packetimpact/tests/udp_recv_mcast_bcast_test.go
@@ -31,10 +31,10 @@ func init() {
func TestUDPRecvMulticastBroadcast(t *testing.T) {
dut := testbench.NewDUT(t)
defer dut.TearDown()
- boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4(0, 0, 0, 0))
- defer dut.Close(boundFD)
+ boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.IPv4(0, 0, 0, 0))
+ defer dut.Close(t, boundFD)
conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- defer conn.Close()
+ defer conn.Close(t)
for _, bcastAddr := range []net.IP{
broadcastAddr(net.ParseIP(testbench.RemoteIPv4), net.CIDRMask(testbench.IPv4PrefixLength, 32)),
@@ -43,12 +43,13 @@ func TestUDPRecvMulticastBroadcast(t *testing.T) {
} {
payload := testbench.GenerateRandomPayload(t, 1<<10)
conn.SendIP(
+ t,
testbench.IPv4{DstAddr: testbench.Address(tcpip.Address(bcastAddr.To4()))},
testbench.UDP{},
&testbench.Payload{Bytes: payload},
)
t.Logf("Receiving packet sent to address: %s", bcastAddr)
- if got, want := string(dut.Recv(boundFD, int32(len(payload)), 0)), string(payload); got != want {
+ if got, want := string(dut.Recv(t, boundFD, int32(len(payload)), 0)), string(payload); got != want {
t.Errorf("received payload does not match sent payload got: %s, want: %s", got, want)
}
}
diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go
index bd53ad90b..dc20275d6 100644
--- a/test/packetimpact/tests/udp_send_recv_dgram_test.go
+++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go
@@ -29,10 +29,10 @@ func init() {
}
type udpConn interface {
- Send(testbench.UDP, ...testbench.Layer)
- ExpectData(testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error)
- Drain()
- Close()
+ Send(*testing.T, testbench.UDP, ...testbench.Layer)
+ ExpectData(*testing.T, testbench.UDP, testbench.Payload, time.Duration) (testbench.Layers, error)
+ Drain(*testing.T)
+ Close(*testing.T)
}
func TestUDP(t *testing.T) {
@@ -51,21 +51,21 @@ func TestUDP(t *testing.T) {
} else {
addr = testbench.RemoteIPv6
}
- boundFD, remotePort := dut.CreateBoundSocket(unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(addr))
- defer dut.Close(boundFD)
+ boundFD, remotePort := dut.CreateBoundSocket(t, unix.SOCK_DGRAM, unix.IPPROTO_UDP, net.ParseIP(addr))
+ defer dut.Close(t, boundFD)
var conn udpConn
var localAddr unix.Sockaddr
if isIPv4 {
v4Conn := testbench.NewUDPIPv4(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- localAddr = v4Conn.LocalAddr()
+ localAddr = v4Conn.LocalAddr(t)
conn = &v4Conn
} else {
v6Conn := testbench.NewUDPIPv6(t, testbench.UDP{DstPort: &remotePort}, testbench.UDP{SrcPort: &remotePort})
- localAddr = v6Conn.LocalAddr()
+ localAddr = v6Conn.LocalAddr(t)
conn = &v6Conn
}
- defer conn.Close()
+ defer conn.Close(t)
testCases := []struct {
name string
@@ -81,17 +81,17 @@ func TestUDP(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("Send", func(t *testing.T) {
- conn.Send(testbench.UDP{}, &testbench.Payload{Bytes: tc.payload})
- if got, want := string(dut.Recv(boundFD, int32(len(tc.payload)), 0)), string(tc.payload); got != want {
+ conn.Send(t, testbench.UDP{}, &testbench.Payload{Bytes: tc.payload})
+ if got, want := string(dut.Recv(t, boundFD, int32(len(tc.payload)), 0)), string(tc.payload); got != want {
t.Fatalf("received payload does not match sent payload got: %s, want: %s", got, want)
}
})
t.Run("Recv", func(t *testing.T) {
- conn.Drain()
- if got, want := int(dut.SendTo(boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want {
+ conn.Drain(t)
+ if got, want := int(dut.SendTo(t, boundFD, tc.payload, 0, localAddr)), len(tc.payload); got != want {
t.Fatalf("short write got: %d, want: %d", got, want)
}
- if _, err := conn.ExpectData(testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil {
+ if _, err := conn.ExpectData(t, testbench.UDP{SrcPort: &remotePort}, testbench.Payload{Bytes: tc.payload}, time.Second); err != nil {
t.Fatal(err)
}
})
diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl
index 600cb5192..c92392b35 100644
--- a/test/runner/defs.bzl
+++ b/test/runner/defs.bzl
@@ -157,7 +157,7 @@ def syscall_test(
platform = "native",
use_tmpfs = False,
add_uds_tree = add_uds_tree,
- tags = tags,
+ tags = list(tags),
)
for (platform, platform_tags) in platforms.items():
diff --git a/test/syscalls/linux/mount.cc b/test/syscalls/linux/mount.cc
index 7664fa73d..97e8d0f7e 100644
--- a/test/syscalls/linux/mount.cc
+++ b/test/syscalls/linux/mount.cc
@@ -323,10 +323,7 @@ TEST(MountTest, RenameRemoveMountPoint) {
TEST(MountTest, MountFuseFilesystemNoDevice) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
-
- // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new
- // device registration is complete.
- SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor());
+ SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled());
auto const dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
EXPECT_THAT(mount("", dir.path().c_str(), "fuse", 0, ""),
@@ -335,10 +332,7 @@ TEST(MountTest, MountFuseFilesystemNoDevice) {
TEST(MountTest, MountFuseFilesystem) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
-
- // Note(gvisor.dev/issue/3076) This won't work in the sentry until the new
- // device registration is complete.
- SKIP_IF(IsRunningWithVFS1() || IsRunningOnGvisor());
+ SKIP_IF(IsRunningOnGvisor() && !IsFUSEEnabled());
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY));
diff --git a/tools/bazel.mk b/tools/bazel.mk
index e27e907ab..45d6007cf 100644
--- a/tools/bazel.mk
+++ b/tools/bazel.mk
@@ -15,6 +15,7 @@
# limitations under the License.
# See base Makefile.
+SHELL=/bin/bash -o pipefail
BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \
git rev-parse --abbrev-ref HEAD 2>/dev/null) | \
xargs -n 1 basename 2>/dev/null)
@@ -22,8 +23,11 @@ BRANCH_NAME := $(shell (git branch --show-current 2>/dev/null || \
# Bazel container configuration (see below).
USER ?= gvisor
HASH ?= $(shell readlink -m $(CURDIR) | md5sum | cut -c1-8)
+BUILDER_BASE := gvisor.dev/images/default
+BUILDER_IMAGE := gvisor.dev/images/builder
+BUILDER_NAME ?= gvisor-builder-$(HASH)
DOCKER_NAME ?= gvisor-bazel-$(HASH)
-DOCKER_PRIVILEGED ?= --privileged --network host
+DOCKER_PRIVILEGED ?= --privileged
BAZEL_CACHE := $(shell readlink -m ~/.cache/bazel/)
GCLOUD_CONFIG := $(shell readlink -m ~/.config/gcloud/)
DOCKER_SOCKET := /var/run/docker.sock
@@ -32,17 +36,25 @@ DOCKER_SOCKET := /var/run/docker.sock
OPTIONS += --test_output=errors --keep_going --verbose_failures=true
BAZEL := bazel $(STARTUP_OPTIONS)
-# Non-configurable.
+# Basic options.
UID := $(shell id -u ${USER})
GID := $(shell id -g ${USER})
USERADD_OPTIONS :=
FULL_DOCKER_RUN_OPTIONS := $(DOCKER_RUN_OPTIONS)
+FULL_DOCKER_RUN_OPTIONS += --user $(UID):$(GID)
+FULL_DOCKER_RUN_OPTIONS += --entrypoint ""
+FULL_DOCKER_RUN_OPTIONS += --init
FULL_DOCKER_RUN_OPTIONS += -v "$(BAZEL_CACHE):$(BAZEL_CACHE)"
FULL_DOCKER_RUN_OPTIONS += -v "$(GCLOUD_CONFIG):$(GCLOUD_CONFIG)"
FULL_DOCKER_RUN_OPTIONS += -v "/tmp:/tmp"
+FULL_DOCKER_EXEC_OPTIONS := --user $(UID):$(GID)
+FULL_DOCKER_EXEC_OPTIONS += -i
+
+# Add docker passthrough options.
ifneq ($(DOCKER_PRIVILEGED),)
FULL_DOCKER_RUN_OPTIONS += -v "$(DOCKER_SOCKET):$(DOCKER_SOCKET)"
FULL_DOCKER_RUN_OPTIONS += $(DOCKER_PRIVILEGED)
+FULL_DOCKER_EXEC_OPTIONS += $(DOCKER_PRIVILEGED)
DOCKER_GROUP := $(shell stat -c '%g' $(DOCKER_SOCKET))
ifneq ($(GID),$(DOCKER_GROUP))
USERADD_OPTIONS += --groups $(DOCKER_GROUP)
@@ -50,7 +62,35 @@ GROUPADD_DOCKER += groupadd --gid $(DOCKER_GROUP) --non-unique docker-$(HASH) &&
FULL_DOCKER_RUN_OPTIONS += --group-add $(DOCKER_GROUP)
endif
endif
-SHELL=/bin/bash -o pipefail
+
+# Add KVM passthrough options.
+ifneq (,$(wildcard /dev/kvm))
+FULL_DOCKER_RUN_OPTIONS += --device=/dev/kvm
+KVM_GROUP := $(shell stat -c '%g' /dev/kvm)
+ifneq ($(GID),$(KVM_GROUP))
+USERADD_OPTIONS += --groups $(KVM_GROUP)
+GROUPADD_DOCKER += groupadd --gid $(KVM_GROUP) --non-unique kvm-$(HASH) &&
+FULL_DOCKER_RUN_OPTIONS += --group-add $(KVM_GROUP)
+endif
+endif
+
+# Load the appropriate config.
+ifneq (,$(BAZEL_CONFIG))
+OPTIONS += --config=$(BAZEL_CONFIG)
+endif
+
+bazel-image: load-default
+ @if docker ps --all | grep $(BUILDER_NAME); then docker rm -f $(BUILDER_NAME); fi
+ docker run --user 0:0 --entrypoint "" --name $(BUILDER_NAME) \
+ $(BUILDER_BASE) \
+ sh -c "groupadd --gid $(GID) --non-unique $(USER) && \
+ $(GROUPADD_DOCKER) \
+ useradd --uid $(UID) --non-unique --no-create-home \
+ --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \
+ if [[ -e /dev/kvm ]]; then chmod a+rw /dev/kvm; fi"
+ docker commit $(BUILDER_NAME) $(BUILDER_IMAGE)
+ @docker rm -f $(BUILDER_NAME)
+.PHONY: bazel-image
##
## Bazel helpers.
@@ -65,41 +105,37 @@ SHELL=/bin/bash -o pipefail
## GCLOUD_CONFIG - The gcloud config directory (detect: detected).
## DOCKER_SOCKET - The Docker socket (default: detected).
##
-bazel-server-start: load-default ## Starts the bazel server.
+bazel-server-start: bazel-image ## Starts the bazel server.
@mkdir -p $(BAZEL_CACHE)
@mkdir -p $(GCLOUD_CONFIG)
@if docker ps --all | grep $(DOCKER_NAME); then docker rm -f $(DOCKER_NAME); fi
- docker run -d --rm \
- --init \
- --name $(DOCKER_NAME) \
- --user 0:0 $(DOCKER_GROUP_OPTIONS) \
+ # This command runs a bazel server, and the container sticks around
+ # until the bazel server exits. This should ensure that it does not
+ # exit in the middle of running a build, but also it won't stick around
+ # forever. The build commands wrap around an appropriate exec into the
+ # container in order to perform work via the bazel client.
+ docker run -d --rm --name $(DOCKER_NAME) \
-v "$(CURDIR):$(CURDIR)" \
--workdir "$(CURDIR)" \
- --entrypoint "" \
$(FULL_DOCKER_RUN_OPTIONS) \
- gvisor.dev/images/default \
- sh -c "groupadd --gid $(GID) --non-unique $(USER) && \
- $(GROUPADD_DOCKER) \
- useradd --uid $(UID) --non-unique --no-create-home --gid $(GID) $(USERADD_OPTIONS) -d $(HOME) $(USER) && \
- $(BAZEL) version && \
- exec tail --pid=\$$($(BAZEL) info server_pid) -f /dev/null"
- @while :; do if docker logs $(DOCKER_NAME) 2>/dev/null | grep "Build label:" >/dev/null; then break; fi; \
- if ! docker ps | grep $(DOCKER_NAME); then docker logs $(DOCKER_NAME); exit 1; else sleep 1; fi; done
+ $(BUILDER_IMAGE) \
+ sh -c "tail -f --pid=\$$($(BAZEL) info server_pid)"
.PHONY: bazel-server-start
bazel-shutdown: ## Shuts down a running bazel server.
- @docker exec --user $(UID):$(GID) $(DOCKER_NAME) $(BAZEL) shutdown; rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]]
+ @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) shutdown; \
+ rc=$$?; docker kill $(DOCKER_NAME) || [[ $$rc -ne 0 ]]
.PHONY: bazel-shutdown
bazel-alias: ## Emits an alias that can be used within the shell.
- @echo "alias bazel='docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) bazel'"
+ @echo "alias bazel='docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) bazel'"
.PHONY: bazel-alias
bazel-server: ## Ensures that the server exists. Used as an internal target.
- @docker exec $(DOCKER_NAME) true || $(MAKE) bazel-server-start
+ @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) true || $(MAKE) bazel-server-start
.PHONY: bazel-server
-build_cmd = docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) build $(OPTIONS) $(TARGETS)'
+build_cmd = docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) sh -o pipefail -c '$(BAZEL) build $(OPTIONS) $(TARGETS)'
build_paths = $(build_cmd) 2>&1 \
| tee /proc/self/fd/2 \
@@ -126,9 +162,9 @@ sudo: bazel-server
.PHONY: sudo
test: bazel-server
- @docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) $(BAZEL) test $(OPTIONS) $(TARGETS)
+ @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) test $(OPTIONS) $(TARGETS)
.PHONY: test
query: bazel-server
- @docker exec --user $(UID):$(GID) -i $(DOCKER_NAME) $(BAZEL) query $(OPTIONS) '$(TARGETS)'
+ @docker exec $(FULL_DOCKER_EXEC_OPTIONS) $(DOCKER_NAME) $(BAZEL) query $(OPTIONS) '$(TARGETS)'
.PHONY: query
diff --git a/tools/bazeldefs/BUILD b/tools/bazeldefs/BUILD
index f2f80bae1..3f809065d 100644
--- a/tools/bazeldefs/BUILD
+++ b/tools/bazeldefs/BUILD
@@ -49,3 +49,40 @@ rbe_toolchain(
toolchain = "@bazel_toolchains//configs/ubuntu16_04_clang/10.0.0/bazel_2.0.0/cc:cc-compiler-k8",
toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
)
+
+# Updated versions of the above, compatible with bazel3.
+rbe_platform(
+ name = "rbe_ubuntu1604_bazel3",
+ constraint_values = [
+ "@bazel_tools//platforms:x86_64",
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//tools/cpp:clang",
+ "@bazel_toolchains_bazel3//constraints:xenial",
+ "@bazel_toolchains_bazel3//constraints/sanitizers:support_msan",
+ ],
+ remote_execution_properties = """
+ properties: {
+ name: "container-image"
+ value:"docker://gcr.io/cloud-marketplace/google/rbe-ubuntu16-04@sha256:b516a2d69537cb40a7c6a7d92d0008abb29fba8725243772bdaf2c83f1be2272"
+ }
+ properties: {
+ name: "dockerAddCapabilities"
+ value: "SYS_ADMIN"
+ }
+ properties: {
+ name: "dockerPrivileged"
+ value: "true"
+ }
+ """,
+)
+
+rbe_toolchain(
+ name = "cc-toolchain-clang-x86_64-default_bazel3",
+ exec_compatible_with = [],
+ tags = [
+ "manual",
+ ],
+ target_compatible_with = [],
+ toolchain = "@bazel_toolchains_bazel3//configs/ubuntu16_04_clang/11.0.0/bazel_3.1.0/cc:cc-compiler-k8",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)