-rw-r--r--  WORKSPACE  20
-rw-r--r--  kokoro/build.cfg  2
-rw-r--r--  kokoro/build_tests.cfg  1
-rw-r--r--  kokoro/docker_tests.cfg  1
-rw-r--r--  kokoro/hostnet_tests.cfg  1
-rw-r--r--  kokoro/kvm_tests.cfg  1
-rw-r--r--  kokoro/overlay_tests.cfg  1
-rw-r--r--  kokoro/root_tests.cfg  1
-rw-r--r--  pkg/fd/BUILD  3
-rw-r--r--  pkg/fd/fd.go  8
-rw-r--r--  pkg/metric/BUILD  7
-rw-r--r--  pkg/p9/p9test/client_test.go  95
-rw-r--r--  pkg/seccomp/seccomp_unsafe.go  2
-rw-r--r--  pkg/sentry/arch/BUILD  7
-rw-r--r--  pkg/sentry/fs/fsutil/host_mappable.go  2
-rw-r--r--  pkg/sentry/fs/fsutil/inode_cached.go  22
-rw-r--r--  pkg/sentry/fs/fsutil/inode_cached_test.go  4
-rw-r--r--  pkg/sentry/fs/gofer/inode.go  15
-rw-r--r--  pkg/sentry/fs/host/inode.go  4
-rw-r--r--  pkg/sentry/fs/host/tty.go  15
-rw-r--r--  pkg/sentry/fs/proc/net.go  6
-rw-r--r--  pkg/sentry/fs/proc/proc.go  2
-rw-r--r--  pkg/sentry/fs/splice.go  4
-rw-r--r--  pkg/sentry/fs/timerfd/timerfd.go  3
-rw-r--r--  pkg/sentry/fsimpl/ext/directory.go  8
-rw-r--r--  pkg/sentry/fsimpl/ext/ext_test.go  2
-rw-r--r--  pkg/sentry/fsimpl/memfs/directory.go  24
-rw-r--r--  pkg/sentry/kernel/BUILD  7
-rw-r--r--  pkg/sentry/kernel/kernel.go  129
-rw-r--r--  pkg/sentry/kernel/memevent/BUILD  7
-rw-r--r--  pkg/sentry/kernel/posixtimer.go  8
-rw-r--r--  pkg/sentry/kernel/task_identity.go  4
-rw-r--r--  pkg/sentry/kernel/task_sched.go  33
-rw-r--r--  pkg/sentry/kernel/thread_group.go  3
-rw-r--r--  pkg/sentry/kernel/time/time.go  27
-rw-r--r--  pkg/sentry/loader/elf.go  18
-rw-r--r--  pkg/sentry/loader/loader.go  3
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess.go  14
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_linux.go  4
-rw-r--r--  pkg/sentry/sighandling/sighandling_unsafe.go  2
-rw-r--r--  pkg/sentry/socket/epsocket/epsocket.go  42
-rw-r--r--  pkg/sentry/socket/epsocket/provider.go  2
-rw-r--r--  pkg/sentry/socket/rpcinet/BUILD  9
-rw-r--r--  pkg/sentry/socket/unix/transport/unix.go  82
-rw-r--r--  pkg/sentry/strace/BUILD  7
-rw-r--r--  pkg/sentry/syscalls/linux/linux64.go  2
-rw-r--r--  pkg/sentry/syscalls/linux/sys_read.go  33
-rw-r--r--  pkg/sentry/syscalls/linux/sys_socket.go  2
-rw-r--r--  pkg/sentry/syscalls/linux/sys_splice.go  10
-rw-r--r--  pkg/sentry/syscalls/linux/sys_time.go  39
-rw-r--r--  pkg/sentry/syscalls/linux/sys_utsname.go  6
-rw-r--r--  pkg/sentry/unimpl/BUILD  7
-rw-r--r--  pkg/sentry/usermem/usermem.go  8
-rw-r--r--  pkg/sentry/vfs/file_description.go  7
-rw-r--r--  pkg/tcpip/adapters/gonet/gonet_test.go  5
-rw-r--r--  pkg/tcpip/link/channel/channel.go  3
-rw-r--r--  pkg/tcpip/link/fdbased/endpoint.go  32
-rw-r--r--  pkg/tcpip/link/loopback/loopback.go  3
-rw-r--r--  pkg/tcpip/link/muxed/injectable.go  7
-rw-r--r--  pkg/tcpip/link/sharedmem/sharedmem.go  3
-rw-r--r--  pkg/tcpip/link/sniffer/sniffer.go  3
-rw-r--r--  pkg/tcpip/link/waitable/waitable.go  3
-rw-r--r--  pkg/tcpip/link/waitable/waitable_test.go  3
-rw-r--r--  pkg/tcpip/network/arp/arp.go  16
-rw-r--r--  pkg/tcpip/network/arp/arp_test.go  5
-rw-r--r--  pkg/tcpip/network/ip_test.go  13
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4.go  44
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4_test.go  9
-rw-r--r--  pkg/tcpip/network/ipv6/icmp_test.go  15
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6.go  24
-rw-r--r--  pkg/tcpip/network/ipv6/ipv6_test.go  34
-rw-r--r--  pkg/tcpip/network/ipv6/ndp_test.go  5
-rw-r--r--  pkg/tcpip/ports/ports.go  161
-rw-r--r--  pkg/tcpip/ports/ports_test.go  169
-rw-r--r--  pkg/tcpip/sample/tun_tcp_connect/main.go  5
-rw-r--r--  pkg/tcpip/sample/tun_tcp_echo/main.go  5
-rw-r--r--  pkg/tcpip/stack/BUILD  5
-rw-r--r--  pkg/tcpip/stack/icmp_rate_limit.go  49
-rw-r--r--  pkg/tcpip/stack/nic.go  80
-rw-r--r--  pkg/tcpip/stack/registration.go  45
-rw-r--r--  pkg/tcpip/stack/stack.go  162
-rw-r--r--  pkg/tcpip/stack/stack_test.go  175
-rw-r--r--  pkg/tcpip/stack/transport_demuxer.go  227
-rw-r--r--  pkg/tcpip/stack/transport_demuxer_test.go  352
-rw-r--r--  pkg/tcpip/stack/transport_test.go  45
-rw-r--r--  pkg/tcpip/tcpip.go  36
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint.go  35
-rw-r--r--  pkg/tcpip/transport/icmp/protocol.go  27
-rw-r--r--  pkg/tcpip/transport/raw/endpoint.go  30
-rw-r--r--  pkg/tcpip/transport/raw/protocol.go  9
-rw-r--r--  pkg/tcpip/transport/tcp/BUILD  1
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go  2
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go  217
-rw-r--r--  pkg/tcpip/transport/tcp/protocol.go  22
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_noracedetector_test.go  10
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_test.go  258
-rw-r--r--  pkg/tcpip/transport/tcp/testing/context/context.go  37
-rw-r--r--  pkg/tcpip/transport/udp/BUILD  1
-rw-r--r--  pkg/tcpip/transport/udp/endpoint.go  74
-rw-r--r--  pkg/tcpip/transport/udp/forwarder.go  2
-rw-r--r--  pkg/tcpip/transport/udp/protocol.go  12
-rw-r--r--  pkg/tcpip/transport/udp/udp_test.go  125
-rw-r--r--  runsc/BUILD  7
-rw-r--r--  runsc/boot/BUILD  2
-rw-r--r--  runsc/boot/config.go  4
-rw-r--r--  runsc/boot/loader.go  49
-rw-r--r--  runsc/boot/user.go  28
-rw-r--r--  runsc/boot/user_test.go  3
-rw-r--r--  runsc/cmd/exec.go  53
-rw-r--r--  runsc/cmd/exec_test.go  4
-rw-r--r--  runsc/cmd/gofer.go  5
-rw-r--r--  runsc/container/BUILD  1
-rw-r--r--  runsc/container/container_test.go  25
-rw-r--r--  runsc/container/test_app/test_app.go  65
-rw-r--r--  runsc/dockerutil/dockerutil.go  17
-rw-r--r--  runsc/fsgofer/filter/config.go  13
-rw-r--r--  runsc/fsgofer/filter/filter.go  13
-rw-r--r--  runsc/fsgofer/fsgofer.go  70
-rw-r--r--  runsc/fsgofer/fsgofer_test.go  2
-rw-r--r--  runsc/main.go  4
-rw-r--r--  runsc/sandbox/sandbox.go  2
-rw-r--r--  runsc/specutils/BUILD  1
-rw-r--r--  runsc/specutils/namespace.go  14
-rw-r--r--  runsc/specutils/specutils.go  10
-rwxr-xr-x  scripts/build.sh  2
-rwxr-xr-x  scripts/common_bazel.sh  13
-rwxr-xr-x  scripts/dev.sh  2
-rwxr-xr-x  scripts/make_tests.sh  1
-rw-r--r--  test/e2e/BUILD  2
-rw-r--r--  test/e2e/exec_test.go  132
-rw-r--r--  test/runtimes/BUILD  1
-rw-r--r--  test/runtimes/blacklist_nodejs12.4.0.csv  47
-rw-r--r--  test/runtimes/build_defs.bzl  29
-rw-r--r--  test/runtimes/images/proctor/proctor.go  31
-rw-r--r--  test/runtimes/runner.go  154
-rw-r--r--  test/syscalls/BUILD  5
-rw-r--r--  test/syscalls/linux/BUILD  92
-rw-r--r--  test/syscalls/linux/clock_nanosleep.cc  86
-rw-r--r--  test/syscalls/linux/exec_binary.cc  70
-rw-r--r--  test/syscalls/linux/itimer.cc  4
-rw-r--r--  test/syscalls/linux/packet_socket.cc  29
-rw-r--r--  test/syscalls/linux/packet_socket_raw.cc  21
-rw-r--r--  test/syscalls/linux/proc.cc  5
-rw-r--r--  test/syscalls/linux/proc_net_tcp.cc  39
-rw-r--r--  test/syscalls/linux/pty.cc  53
-rw-r--r--  test/syscalls/linux/raw_socket_hdrincl.cc  43
-rw-r--r--  test/syscalls/linux/raw_socket_icmp.cc  13
-rw-r--r--  test/syscalls/linux/raw_socket_ipv4.cc  13
-rw-r--r--  test/syscalls/linux/readahead.cc  91
-rw-r--r--  test/syscalls/linux/socket.cc  24
-rw-r--r--  test/syscalls/linux/socket_bind_to_device.cc  314
-rw-r--r--  test/syscalls/linux/socket_bind_to_device_distribution.cc  381
-rw-r--r--  test/syscalls/linux/socket_bind_to_device_sequence.cc  316
-rw-r--r--  test/syscalls/linux/socket_bind_to_device_util.cc  75
-rw-r--r--  test/syscalls/linux/socket_bind_to_device_util.h  67
-rw-r--r--  test/syscalls/linux/socket_ip_tcp_generic.cc  2
-rw-r--r--  test/syscalls/linux/socket_test_util.h  2
-rw-r--r--  test/syscalls/linux/uidgid.cc  25
-rw-r--r--  test/syscalls/linux/uname.cc  14
-rw-r--r--  test/util/BUILD  11
-rw-r--r--  test/util/proc_util.cc  2
-rw-r--r--  test/util/uid_util.cc  44
-rw-r--r--  test/util/uid_util.h  29
163 files changed, 4701 insertions, 1356 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 6f62eb73f..082e26ee9 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -3,10 +3,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "io_bazel_rules_go",
- sha256 = "ae8c36ff6e565f674c7a3692d6a9ea1096e4c1ade497272c2108a810fb39acd2",
+ sha256 = "513c12397db1bc9aa46dd62f02dd94b49a9b5d17444d49b5a04c5a89f3053c1c",
urls = [
- "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/0.19.4/rules_go-0.19.4.tar.gz",
- "https://github.com/bazelbuild/rules_go/releases/download/0.19.4/rules_go-0.19.4.tar.gz",
+ "https://storage.googleapis.com/bazel-mirror/github.com/bazelbuild/rules_go/releases/download/v0.19.5/rules_go-v0.19.5.tar.gz",
+ "https://github.com/bazelbuild/rules_go/releases/download/v0.19.5/rules_go-v0.19.5.tar.gz",
],
)
@@ -24,7 +24,7 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_rules_dependencies", "go_register_to
go_rules_dependencies()
go_register_toolchains(
- go_version = "1.13",
+ go_version = "1.13.1",
nogo = "@//:nogo",
)
@@ -75,6 +75,16 @@ load("@bazel_toolchains//rules:rbe_repo.bzl", "rbe_autoconfig")
rbe_autoconfig(name = "rbe_default")
+http_archive(
+ name = "rules_pkg",
+ sha256 = "5bdc04987af79bd27bc5b00fe30f59a858f77ffa0bd2d8143d5b31ad8b1bd71c",
+ url = "https://github.com/bazelbuild/rules_pkg/releases/download/0.2.0/rules_pkg-0.2.0.tar.gz",
+)
+
+load("@rules_pkg//:deps.bzl", "rules_pkg_dependencies")
+
+rules_pkg_dependencies()
+
# External repositories, in sorted order.
go_repository(
name = "com_github_cenkalti_backoff",
@@ -196,7 +206,7 @@ go_repository(
go_repository(
name = "org_golang_x_time",
- commit = "9d24e82272b4f38b78bc8cff74fa936d31ccd8ef",
+ commit = "c4c64cad1fd0a1a8dab2523e04e61d35308e131e",
importpath = "golang.org/x/time",
)
diff --git a/kokoro/build.cfg b/kokoro/build.cfg
index cb2e5fbec..6c1d262d4 100644
--- a/kokoro/build.cfg
+++ b/kokoro/build.cfg
@@ -17,7 +17,7 @@ env_vars {
action {
define_artifacts {
regex: "**/runsc"
- regex: "**/runsc.sha256"
+ regex: "**/runsc.*"
regex: "**/dists/**"
}
}
diff --git a/kokoro/build_tests.cfg b/kokoro/build_tests.cfg
new file mode 100644
index 000000000..c64b7e679
--- /dev/null
+++ b/kokoro/build_tests.cfg
@@ -0,0 +1 @@
+build_file: "repo/scripts/build.sh"
diff --git a/kokoro/docker_tests.cfg b/kokoro/docker_tests.cfg
index 717d71dd3..0a0ef87ed 100644
--- a/kokoro/docker_tests.cfg
+++ b/kokoro/docker_tests.cfg
@@ -5,5 +5,6 @@ action {
regex: "**/sponge_log.xml"
regex: "**/sponge_log.log"
regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
}
}
diff --git a/kokoro/hostnet_tests.cfg b/kokoro/hostnet_tests.cfg
index 532755f4a..520dc55a3 100644
--- a/kokoro/hostnet_tests.cfg
+++ b/kokoro/hostnet_tests.cfg
@@ -5,5 +5,6 @@ action {
regex: "**/sponge_log.xml"
regex: "**/sponge_log.log"
regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
}
}
diff --git a/kokoro/kvm_tests.cfg b/kokoro/kvm_tests.cfg
index 54365c2b2..1feb60c8a 100644
--- a/kokoro/kvm_tests.cfg
+++ b/kokoro/kvm_tests.cfg
@@ -5,5 +5,6 @@ action {
regex: "**/sponge_log.xml"
regex: "**/sponge_log.log"
regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
}
}
diff --git a/kokoro/overlay_tests.cfg b/kokoro/overlay_tests.cfg
index abd96f60c..6a2ddbd03 100644
--- a/kokoro/overlay_tests.cfg
+++ b/kokoro/overlay_tests.cfg
@@ -5,5 +5,6 @@ action {
regex: "**/sponge_log.xml"
regex: "**/sponge_log.log"
regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
}
}
diff --git a/kokoro/root_tests.cfg b/kokoro/root_tests.cfg
index 20b97766a..28351695c 100644
--- a/kokoro/root_tests.cfg
+++ b/kokoro/root_tests.cfg
@@ -5,5 +5,6 @@ action {
regex: "**/sponge_log.xml"
regex: "**/sponge_log.log"
regex: "**/outputs.zip"
+ regex: "**/runsc_logs_*.tar.gz"
}
}
diff --git a/pkg/fd/BUILD b/pkg/fd/BUILD
index afa8f7659..c7f549428 100644
--- a/pkg/fd/BUILD
+++ b/pkg/fd/BUILD
@@ -8,6 +8,9 @@ go_library(
srcs = ["fd.go"],
importpath = "gvisor.dev/gvisor/pkg/fd",
visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/unet",
+ ],
)
go_test(
diff --git a/pkg/fd/fd.go b/pkg/fd/fd.go
index 83bcfe220..7691b477b 100644
--- a/pkg/fd/fd.go
+++ b/pkg/fd/fd.go
@@ -22,6 +22,8 @@ import (
"runtime"
"sync/atomic"
"syscall"
+
+ "gvisor.dev/gvisor/pkg/unet"
)
// ReadWriter implements io.ReadWriter, io.ReaderAt, and io.WriterAt for fd. It
@@ -185,6 +187,12 @@ func OpenAt(dir *FD, path string, flags int, mode uint32) (*FD, error) {
return New(f), nil
}
+// DialUnix connects to a Unix Domain Socket and returns the file descriptor.
+func DialUnix(path string) (*FD, error) {
+ socket, err := unet.Connect(path, false)
+ return New(socket.FD()), err
+}
+
// Close closes the file descriptor contained in the FD.
//
// Close is safe to call multiple times, but will return an error after the
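The new fd.DialUnix helper simply wraps unet.Connect and hands ownership of the connected socket back as an *fd.FD. A minimal usage sketch follows; the socket path and logging are illustrative and not part of this change:

package main

import (
	"log"

	"gvisor.dev/gvisor/pkg/fd"
)

func main() {
	// Connect to a hypothetical gofer socket; DialUnix returns an *fd.FD
	// that owns the connected host file descriptor.
	f, err := fd.DialUnix("/tmp/gofer.sock")
	if err != nil {
		log.Fatalf("DialUnix failed: %v", err)
	}
	defer f.Close()

	// The *fd.FD satisfies io.ReadWriter, so it can be handed directly to
	// a p9 client or any other code expecting a byte stream.
	log.Printf("connected on host fd %d", f.FD())
}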
diff --git a/pkg/metric/BUILD b/pkg/metric/BUILD
index 842788179..dd6ca6d39 100644
--- a/pkg/metric/BUILD
+++ b/pkg/metric/BUILD
@@ -1,6 +1,7 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -22,6 +23,12 @@ proto_library(
visibility = ["//:sandbox"],
)
+cc_proto_library(
+ name = "metric_cc_proto",
+ visibility = ["//:sandbox"],
+ deps = [":metric_proto"],
+)
+
go_proto_library(
name = "metric_go_proto",
importpath = "gvisor.dev/gvisor/pkg/metric/metric_go_proto",
diff --git a/pkg/p9/p9test/client_test.go b/pkg/p9/p9test/client_test.go
index fe649c2e8..8bbdb2488 100644
--- a/pkg/p9/p9test/client_test.go
+++ b/pkg/p9/p9test/client_test.go
@@ -2127,3 +2127,98 @@ func TestConcurrency(t *testing.T) {
}
}
}
+
+func TestReadWriteConcurrent(t *testing.T) {
+ h, c := NewHarness(t)
+ defer h.Finish()
+
+ _, root := newRoot(h, c)
+ defer root.Close()
+
+ const (
+ instances = 10
+ iterations = 10000
+ dataSize = 1024
+ )
+ var (
+ dataSets [instances][dataSize]byte
+ backends [instances]*Mock
+ files [instances]p9.File
+ )
+
+ // Walk to the file normally.
+ for i := 0; i < instances; i++ {
+ _, backends[i], files[i] = walkHelper(h, "file", root)
+ defer files[i].Close()
+ }
+
+ // Open the files.
+ for i := 0; i < instances; i++ {
+ backends[i].EXPECT().Open(p9.ReadWrite)
+ if _, _, _, err := files[i].Open(p9.ReadWrite); err != nil {
+ t.Fatalf("open got %v, wanted nil", err)
+ }
+ }
+
+ // Initialize random data for each instance.
+ for i := 0; i < instances; i++ {
+ if _, err := rand.Read(dataSets[i][:]); err != nil {
+ t.Fatalf("error initializing dataSet#%d, got %v", i, err)
+ }
+ }
+
+ // Define our random read/write mechanism.
+ randRead := func(h *Harness, backend *Mock, f p9.File, data, test []byte) {
+ // Prepare the backend.
+ backend.EXPECT().ReadAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if n := copy(p, data); n != len(data) {
+ // Note that we have to assert the result here, as the Return statement
+ // below cannot be dynamic: it will be bound before this call is made.
+ h.t.Errorf("wanted length %d, got %d", len(data), n)
+ }
+ }).Return(len(data), nil)
+
+ // Execute the read.
+ if n, err := f.ReadAt(test, 0); n != len(test) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(test), n, err)
+ return // No sense doing check below.
+ }
+ if !bytes.Equal(test, data) {
+ t.Errorf("data integrity failed during read") // Not as expected.
+ }
+ }
+ randWrite := func(h *Harness, backend *Mock, f p9.File, data []byte) {
+ // Prepare the backend.
+ backend.EXPECT().WriteAt(gomock.Any(), uint64(0)).Do(func(p []byte, offset uint64) {
+ if !bytes.Equal(p, data) {
+ h.t.Errorf("data integrity failed during write") // Not as expected.
+ }
+ }).Return(len(data), nil)
+
+ // Execute the write.
+ if n, err := f.WriteAt(data, 0); n != len(data) || err != nil {
+ t.Errorf("failed read: wanted (%d, nil), got (%d, %v)", len(data), n, err)
+ }
+ }
+ randReadWrite := func(n int, h *Harness, backend *Mock, f p9.File, data []byte) {
+ test := make([]byte, len(data))
+ for i := 0; i < n; i++ {
+ if rand.Intn(2) == 0 {
+ randRead(h, backend, f, data, test)
+ } else {
+ randWrite(h, backend, f, data)
+ }
+ }
+ }
+
+ // Start reading and writing.
+ var wg sync.WaitGroup
+ for i := 0; i < instances; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ randReadWrite(iterations, h, backends[i], files[i], dataSets[i][:])
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/pkg/seccomp/seccomp_unsafe.go b/pkg/seccomp/seccomp_unsafe.go
index 0a3d92854..be328db12 100644
--- a/pkg/seccomp/seccomp_unsafe.go
+++ b/pkg/seccomp/seccomp_unsafe.go
@@ -35,7 +35,7 @@ type sockFprog struct {
//go:nosplit
func SetFilter(instrs []linux.BPFInstruction) syscall.Errno {
// PR_SET_NO_NEW_PRIVS is required in order to enable seccomp. See seccomp(2) for details.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PRCTL, linux.PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0, 0); errno != 0 {
return errno
}
diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD
index 7aace2d7b..c71cff9f3 100644
--- a/pkg/sentry/arch/BUILD
+++ b/pkg/sentry/arch/BUILD
@@ -1,4 +1,5 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -42,6 +43,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "registers_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":registers_proto"],
+)
+
go_proto_library(
name = "registers_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto",
diff --git a/pkg/sentry/fs/fsutil/host_mappable.go b/pkg/sentry/fs/fsutil/host_mappable.go
index d2495cb83..693625ddc 100644
--- a/pkg/sentry/fs/fsutil/host_mappable.go
+++ b/pkg/sentry/fs/fsutil/host_mappable.go
@@ -144,7 +144,7 @@ func (h *HostMappable) Truncate(ctx context.Context, newSize int64) error {
mask := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: newSize}
- if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr); err != nil {
+ if err := h.backingFile.SetMaskedAttributes(ctx, mask, attr, false); err != nil {
return err
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached.go b/pkg/sentry/fs/fsutil/inode_cached.go
index d404a79d4..dd80757dc 100644
--- a/pkg/sentry/fs/fsutil/inode_cached.go
+++ b/pkg/sentry/fs/fsutil/inode_cached.go
@@ -140,12 +140,16 @@ type CachedFileObject interface {
// WriteFromBlocksAt may return a partial write without an error.
WriteFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, offset uint64) (uint64, error)
- // SetMaskedAttributes sets the attributes in attr that are true in mask
- // on the backing file.
+ // SetMaskedAttributes sets the attributes in attr that are true in
+ // mask on the backing file. If the mask contains only ATime or MTime
+ // and the CachedFileObject has an FD to the file, then this operation
+ // is a noop unless forceSetTimestamps is true. This avoids an extra
+ // RPC to the gofer in the open-read/write-close case, when the
+ // timestamps on the file will be updated by the host kernel for us.
//
// SetMaskedAttributes may be called at any point, regardless of whether
// the file was opened.
- SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error
+ SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error
// Allocate allows the caller to reserve disk space for the inode.
// It's equivalent to fallocate(2) with 'mode=0'.
@@ -224,7 +228,7 @@ func (c *CachingInodeOperations) SetPermissions(ctx context.Context, inode *fs.I
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Perms: true}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Perms: perms}, false); err != nil {
return false
}
c.attr.Perms = perms
@@ -246,7 +250,7 @@ func (c *CachingInodeOperations) SetOwner(ctx context.Context, inode *fs.Inode,
UID: owner.UID.Ok(),
GID: owner.GID.Ok(),
}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{Owner: owner}, false); err != nil {
return err
}
if owner.UID.Ok() {
@@ -282,7 +286,9 @@ func (c *CachingInodeOperations) SetTimestamps(ctx context.Context, inode *fs.In
AccessTime: !ts.ATimeOmit,
ModificationTime: !ts.MTimeOmit,
}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}); err != nil {
+ // Call SetMaskedAttributes with forceSetTimestamps = true to make sure
+ // the timestamp is updated.
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, fs.UnstableAttr{AccessTime: ts.ATime, ModificationTime: ts.MTime}, true); err != nil {
return err
}
if !ts.ATimeOmit {
@@ -305,7 +311,7 @@ func (c *CachingInodeOperations) Truncate(ctx context.Context, inode *fs.Inode,
now := ktime.NowFromContext(ctx)
masked := fs.AttrMask{Size: true}
attr := fs.UnstableAttr{Size: size}
- if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, masked, attr, false); err != nil {
c.dataMu.Unlock()
return err
}
@@ -394,7 +400,7 @@ func (c *CachingInodeOperations) WriteOut(ctx context.Context, inode *fs.Inode)
c.dirtyAttr.Size = false
// Write out cached attributes.
- if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr); err != nil {
+ if err := c.backingFile.SetMaskedAttributes(ctx, c.dirtyAttr, c.attr, false); err != nil {
c.attrMu.Unlock()
return err
}
diff --git a/pkg/sentry/fs/fsutil/inode_cached_test.go b/pkg/sentry/fs/fsutil/inode_cached_test.go
index eb5730c35..129f314c8 100644
--- a/pkg/sentry/fs/fsutil/inode_cached_test.go
+++ b/pkg/sentry/fs/fsutil/inode_cached_test.go
@@ -39,7 +39,7 @@ func (noopBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.Block
return srcs.NumBytes(), nil
}
-func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error {
+func (noopBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
return nil
}
@@ -230,7 +230,7 @@ func (f *sliceBackingFile) WriteFromBlocksAt(ctx context.Context, srcs safemem.B
return w.WriteFromBlocks(srcs)
}
-func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr) error {
+func (*sliceBackingFile) SetMaskedAttributes(context.Context, fs.AttrMask, fs.UnstableAttr, bool) error {
return nil
}
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 95b064aea..d918d6620 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -215,8 +215,8 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
-func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
- if i.skipSetAttr(mask) {
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, forceSetTimestamps bool) error {
+ if i.skipSetAttr(mask, forceSetTimestamps) {
return nil
}
as, ans := attr.AccessTime.Unix()
@@ -251,13 +251,14 @@ func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMa
// when:
// - Mask is empty
// - Mask contains only attributes that cannot be set in the gofer
-// - Mask contains only atime and/or mtime, and host FD exists
+// - forceSetTimestamps is false and mask contains only atime and/or mtime
+// and host FD exists
//
// Updates to atime and mtime can be skipped because cached value will be
// "close enough" to host value, given that operation went directly to host FD.
// Skipping atime updates is particularly important to reduce the number of
// operations sent to the Gofer for readonly files.
-func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
+func (i *inodeFileState) skipSetAttr(mask fs.AttrMask, forceSetTimestamps bool) bool {
// First remove attributes that cannot be updated.
cpy := mask
cpy.Type = false
@@ -277,6 +278,12 @@ func (i *inodeFileState) skipSetAttr(mask fs.AttrMask) bool {
return false
}
+ // If forceSetTimestamps was passed, then we cannot skip.
+ if forceSetTimestamps {
+ return false
+ }
+
+ // Skip if we have a host FD.
i.handlesMu.RLock()
defer i.handlesMu.RUnlock()
return (i.readHandles != nil && i.readHandles.Host != nil) ||
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index 894ab01f0..a6e4a09e3 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -114,7 +114,7 @@ func (i *inodeFileState) WriteFromBlocksAt(ctx context.Context, srcs safemem.Blo
}
// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
-func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr) error {
+func (i *inodeFileState) SetMaskedAttributes(ctx context.Context, mask fs.AttrMask, attr fs.UnstableAttr, _ bool) error {
if mask.Empty() {
return nil
}
@@ -163,7 +163,7 @@ func (i *inodeFileState) unstableAttr(ctx context.Context) (fs.UnstableAttr, err
return unstableAttr(i.mops, &s), nil
}
-// SetMaskedAttributes implements fsutil.CachedFileObject.SetMaskedAttributes.
+// Allocate implements fsutil.CachedFileObject.Allocate.
func (i *inodeFileState) Allocate(_ context.Context, offset, length int64) error {
return syscall.Fallocate(i.FD(), 0, offset, length)
}
diff --git a/pkg/sentry/fs/host/tty.go b/pkg/sentry/fs/host/tty.go
index 2526412a4..90331e3b2 100644
--- a/pkg/sentry/fs/host/tty.go
+++ b/pkg/sentry/fs/host/tty.go
@@ -43,12 +43,15 @@ type TTYFileOperations struct {
// fgProcessGroup is the foreground process group that is currently
// connected to this TTY.
fgProcessGroup *kernel.ProcessGroup
+
+ termios linux.KernelTermios
}
// newTTYFile returns a new fs.File that wraps a TTY FD.
func newTTYFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags, iops *inodeOperations) *fs.File {
return fs.NewFile(ctx, dirent, flags, &TTYFileOperations{
fileOperations: fileOperations{iops: iops},
+ termios: linux.DefaultSlaveTermios,
})
}
@@ -97,9 +100,12 @@ func (t *TTYFileOperations) Write(ctx context.Context, file *fs.File, src userme
t.mu.Lock()
defer t.mu.Unlock()
- // Are we allowed to do the write?
- if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
- return 0, err
+ // Check whether TOSTOP is enabled. This corresponds to the check in
+ // drivers/tty/n_tty.c:n_tty_write().
+ if t.termios.LEnabled(linux.TOSTOP) {
+ if err := t.checkChange(ctx, linux.SIGTTOU); err != nil {
+ return 0, err
+ }
}
return t.fileOperations.Write(ctx, file, src, offset)
}
@@ -144,6 +150,9 @@ func (t *TTYFileOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO
return 0, err
}
err := ioctlSetTermios(fd, ioctl, &termios)
+ if err == nil {
+ t.termios.FromTermios(termios)
+ }
return 0, err
case linux.TIOCGPGRP:
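With this change, a write from a background process group only raises SIGTTOU when the slave terminal has TOSTOP set, mirroring drivers/tty/n_tty.c:n_tty_write(). The sketch below shows how a program could enable TOSTOP to exercise that path; it assumes golang.org/x/sys/unix and is not part of this change:

package example

import "golang.org/x/sys/unix"

// enableTOSTOP sets the TOSTOP local flag on the terminal referred to by fd,
// so that subsequent writes from background process groups trigger SIGTTOU
// and take the checkChange(SIGTTOU) path above.
func enableTOSTOP(fd int) error {
	termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
	if err != nil {
		return err
	}
	termios.Lflag |= unix.TOSTOP
	return unix.IoctlSetTermios(fd, unix.TCSETS, termios)
}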
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index c80c770b5..f70239449 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -354,7 +354,7 @@ func writeInetAddr(w io.Writer, family int, i linux.SockAddr) {
}
}
-func commonReadSeqFileDataTCP(n seqfile.SeqHandle, k *kernel.Kernel, ctx context.Context, h seqfile.SeqHandle, fa int, header []byte) ([]seqfile.SeqData, int64) {
+func commonReadSeqFileDataTCP(ctx context.Context, n seqfile.SeqHandle, k *kernel.Kernel, h seqfile.SeqHandle, fa int, header []byte) ([]seqfile.SeqData, int64) {
// t may be nil here if our caller is not part of a task goroutine. This can
// happen for example if we're here for "sentryctl cat". When t is nil,
// degrade gracefully and retrieve what we can.
@@ -498,7 +498,7 @@ func (*netTCP) NeedsUpdate(generation int64) bool {
// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
func (n *netTCP) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
header := []byte(" sl local_address rem_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode \n")
- return commonReadSeqFileDataTCP(n, n.k, ctx, h, linux.AF_INET, header)
+ return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET, header)
}
// netTCP6 implements seqfile.SeqSource for /proc/net/tcp6.
@@ -516,7 +516,7 @@ func (*netTCP6) NeedsUpdate(generation int64) bool {
// ReadSeqFileData implements seqfile.SeqSource.ReadSeqFileData.
func (n *netTCP6) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]seqfile.SeqData, int64) {
header := []byte(" sl local_address remote_address st tx_queue rx_queue tr tm->when retrnsmt uid timeout inode\n")
- return commonReadSeqFileDataTCP(n, n.k, ctx, h, linux.AF_INET6, header)
+ return commonReadSeqFileDataTCP(ctx, n, n.k, h, linux.AF_INET6, header)
}
// netUDP implements seqfile.SeqSource for /proc/net/udp.
diff --git a/pkg/sentry/fs/proc/proc.go b/pkg/sentry/fs/proc/proc.go
index 0ef13f2f5..56e92721e 100644
--- a/pkg/sentry/fs/proc/proc.go
+++ b/pkg/sentry/fs/proc/proc.go
@@ -230,7 +230,7 @@ func (rpf *rootProcFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dent
// But for whatever crazy reason, you can still walk to the given node.
for _, tg := range rpf.iops.pidns.ThreadGroups() {
if leader := tg.Leader(); leader != nil {
- name := strconv.FormatUint(uint64(tg.ID()), 10)
+ name := strconv.FormatUint(uint64(rpf.iops.pidns.IDOfThreadGroup(tg)), 10)
m[name] = fs.GenericDentAttr(fs.SpecialDirectory, device.ProcDevice)
names = append(names, name)
}
diff --git a/pkg/sentry/fs/splice.go b/pkg/sentry/fs/splice.go
index b03b7f836..311798811 100644
--- a/pkg/sentry/fs/splice.go
+++ b/pkg/sentry/fs/splice.go
@@ -139,7 +139,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
// Attempt to do a WriteTo; this is likely the most efficient.
n, err := src.FileOperations.WriteTo(ctx, src, w, opts.Length, opts.Dup)
- if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup {
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup {
// Attempt as a ReadFrom. If a WriteTo, a ReadFrom may also be
// more efficient than a copy if buffers are cached or readily
// available. (It's unlikely that they can actually be donated).
@@ -151,7 +151,7 @@ func Splice(ctx context.Context, dst *File, src *File, opts SpliceOpts) (int64,
// if we block at some point, we could lose data. If the source is
// not a pipe then reading is not destructive; if the destination
// is a regular file, then it is guaranteed not to block writing.
- if n == 0 && err != nil && err != syserror.ErrWouldBlock && !opts.Dup && (!dstPipe || !srcPipe) {
+ if n == 0 && err == syserror.ENOSYS && !opts.Dup && (!dstPipe || !srcPipe) {
// Fallback to an in-kernel copy.
n, err = io.Copy(w, &io.LimitedReader{
R: r,
diff --git a/pkg/sentry/fs/timerfd/timerfd.go b/pkg/sentry/fs/timerfd/timerfd.go
index 59403d9db..f8bf663bb 100644
--- a/pkg/sentry/fs/timerfd/timerfd.go
+++ b/pkg/sentry/fs/timerfd/timerfd.go
@@ -141,9 +141,10 @@ func (t *TimerOperations) Write(context.Context, *fs.File, usermem.IOSequence, i
}
// Notify implements ktime.TimerListener.Notify.
-func (t *TimerOperations) Notify(exp uint64) {
+func (t *TimerOperations) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
atomic.AddUint64(&t.val, exp)
t.events.Notify(waiter.EventIn)
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/fsimpl/ext/directory.go b/pkg/sentry/fsimpl/ext/directory.go
index b51f3e18d..0b471d121 100644
--- a/pkg/sentry/fsimpl/ext/directory.go
+++ b/pkg/sentry/fsimpl/ext/directory.go
@@ -190,10 +190,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
}
if !cb.Handle(vfs.Dirent{
- Name: child.diskDirent.FileName(),
- Type: fs.ToDirentType(childType),
- Ino: uint64(child.diskDirent.Inode()),
- Off: fd.off,
+ Name: child.diskDirent.FileName(),
+ Type: fs.ToDirentType(childType),
+ Ino: uint64(child.diskDirent.Inode()),
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/fsimpl/ext/ext_test.go b/pkg/sentry/fsimpl/ext/ext_test.go
index 63cf7aeaf..1aa2bd6a4 100644
--- a/pkg/sentry/fsimpl/ext/ext_test.go
+++ b/pkg/sentry/fsimpl/ext/ext_test.go
@@ -584,7 +584,7 @@ func TestIterDirents(t *testing.T) {
// Ignore the inode number and offset of dirents because those are likely to
// change as the underlying image changes.
cmpIgnoreFields := cmp.FilterPath(func(p cmp.Path) bool {
- return p.String() == "Ino" || p.String() == "Off"
+ return p.String() == "Ino" || p.String() == "NextOff"
}, cmp.Ignore())
if diff := cmp.Diff(cb.dirents, test.want, cmpIgnoreFields); diff != "" {
t.Errorf("dirents mismatch (-want +got):\n%s", diff)
diff --git a/pkg/sentry/fsimpl/memfs/directory.go b/pkg/sentry/fsimpl/memfs/directory.go
index c52dc781c..c620227c9 100644
--- a/pkg/sentry/fsimpl/memfs/directory.go
+++ b/pkg/sentry/fsimpl/memfs/directory.go
@@ -75,10 +75,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 0 {
if !cb.Handle(vfs.Dirent{
- Name: ".",
- Type: linux.DT_DIR,
- Ino: vfsd.Impl().(*dentry).inode.ino,
- Off: 0,
+ Name: ".",
+ Type: linux.DT_DIR,
+ Ino: vfsd.Impl().(*dentry).inode.ino,
+ NextOff: 1,
}) {
return nil
}
@@ -87,10 +87,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
if fd.off == 1 {
parentInode := vfsd.ParentOrSelf().Impl().(*dentry).inode
if !cb.Handle(vfs.Dirent{
- Name: "..",
- Type: parentInode.direntType(),
- Ino: parentInode.ino,
- Off: 1,
+ Name: "..",
+ Type: parentInode.direntType(),
+ Ino: parentInode.ino,
+ NextOff: 2,
}) {
return nil
}
@@ -112,10 +112,10 @@ func (fd *directoryFD) IterDirents(ctx context.Context, cb vfs.IterDirentsCallba
// Skip other directoryFD iterators.
if child.inode != nil {
if !cb.Handle(vfs.Dirent{
- Name: child.vfsd.Name(),
- Type: child.inode.direntType(),
- Ino: child.inode.ino,
- Off: fd.off,
+ Name: child.vfsd.Name(),
+ Type: child.inode.direntType(),
+ Ino: child.inode.ino,
+ NextOff: fd.off + 1,
}) {
dir.childList.InsertBefore(child, fd.iter)
return nil
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index eaccfd02d..aba2414d4 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -1,5 +1,6 @@
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -84,6 +85,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "uncaught_signal_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":uncaught_signal_proto"],
+)
+
go_proto_library(
name = "uncaught_signal_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/uncaught_signal_go_proto",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 8c1f79ab5..3cda03891 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -24,6 +24,7 @@
// TaskSet.mu
// SignalHandlers.mu
// Task.mu
+// runningTasksMu
//
// Locking SignalHandlers.mu in multiple SignalHandlers requires locking
// TaskSet.mu exclusively first. Locking Task.mu in multiple Tasks at the same
@@ -135,6 +136,22 @@ type Kernel struct {
// syslog is the kernel log.
syslog syslog
+ // runningTasksMu synchronizes disable/enable of cpuClockTicker when
+ // the kernel is idle (runningTasks == 0).
+ //
+ // runningTasksMu is used to exclude critical sections when the timer
+ // disables itself and when the first active task enables the timer,
+ // ensuring that tasks always see a valid cpuClock value.
+ runningTasksMu sync.Mutex `state:"nosave"`
+
+ // runningTasks is the total count of tasks currently in
+ // TaskGoroutineRunningSys or TaskGoroutineRunningApp. i.e., they are
+ // not blocked or stopped.
+ //
+ // runningTasks must be accessed atomically. Increments from 0 to 1 are
+ // further protected by runningTasksMu (see incRunningTasks).
+ runningTasks int64
+
// cpuClock is incremented every linux.ClockTick. cpuClock is used to
// measure task CPU usage, since sampling monotonicClock twice on every
// syscall turns out to be unreasonably expensive. This is similar to how
@@ -150,6 +167,22 @@ type Kernel struct {
// cpuClockTicker increments cpuClock.
cpuClockTicker *ktime.Timer `state:"nosave"`
+ // cpuClockTickerDisabled indicates that cpuClockTicker has been
+ // disabled because no tasks are running.
+ //
+ // cpuClockTickerDisabled is protected by runningTasksMu.
+ cpuClockTickerDisabled bool
+
+ // cpuClockTickerSetting is the ktime.Setting of cpuClockTicker at the
+ // point it was disabled. It is cached here to avoid a lock ordering
+ // violation with cpuClockTicker.mu when runningTasksMu is held.
+ //
+ // cpuClockTickerSetting is only valid when cpuClockTickerDisabled is
+ // true.
+ //
+ // cpuClockTickerSetting is protected by runningTasksMu.
+ cpuClockTickerSetting ktime.Setting
+
// fdMapUids is an ever-increasing counter for generating FDTable uids.
//
// fdMapUids is mutable, and is accessed using atomic memory operations.
@@ -912,6 +945,102 @@ func (k *Kernel) resumeTimeLocked() {
}
}
+func (k *Kernel) incRunningTasks() {
+ for {
+ tasks := atomic.LoadInt64(&k.runningTasks)
+ if tasks != 0 {
+ // Standard case. Simply increment.
+ if !atomic.CompareAndSwapInt64(&k.runningTasks, tasks, tasks+1) {
+ continue
+ }
+ return
+ }
+
+ // Transition from 0 -> 1. Synchronize with other transitions and timer.
+ k.runningTasksMu.Lock()
+ tasks = atomic.LoadInt64(&k.runningTasks)
+ if tasks != 0 {
+ // We're no longer the first task, no need to
+ // re-enable.
+ atomic.AddInt64(&k.runningTasks, 1)
+ k.runningTasksMu.Unlock()
+ return
+ }
+
+ if !k.cpuClockTickerDisabled {
+ // Timer was never disabled.
+ atomic.StoreInt64(&k.runningTasks, 1)
+ k.runningTasksMu.Unlock()
+ return
+ }
+
+ // We need to update cpuClock for all of the ticks missed while we
+ // slept, and then re-enable the timer.
+ //
+ // The Notify in Swap isn't sufficient. kernelCPUClockTicker.Notify
+ // always increments cpuClock by 1 regardless of the number of
+ // expirations as a heuristic to avoid over-accounting in cases of CPU
+ // throttling.
+ //
+ // We want to cover the normal case, when all time should be accounted,
+ // so we increment for all expirations. Throttling is less concerning
+ // here because the ticker is only disabled from Notify. This means
+ // that Notify must schedule and compensate for the throttled period
+ // before the timer is disabled. Throttling while the timer is disabled
+ // doesn't matter, as nothing is running or reading cpuClock anyways.
+ //
+ // S/R also adds complication, as there are two cases. Recall that
+ // monotonicClock will jump forward on restore.
+ //
+ // 1. If the ticker is enabled during save, then on Restore Notify is
+ // called with many expirations, covering the time jump, but cpuClock
+ // is only incremented by 1.
+ //
+ // 2. If the ticker is disabled during save, then after Restore the
+ // first wakeup will call this function and cpuClock will be
+ // incremented by the number of expirations across the S/R.
+ //
+ // These cause very different values of cpuClock. But again, since
+ // nothing was running while the ticker was disabled, those differences
+ // don't matter.
+ setting, exp := k.cpuClockTickerSetting.At(k.monotonicClock.Now())
+ if exp > 0 {
+ atomic.AddUint64(&k.cpuClock, exp)
+ }
+
+ // Now that cpuClock is updated it is safe to allow other tasks to
+ // transition to running.
+ atomic.StoreInt64(&k.runningTasks, 1)
+
+ // N.B. we must unlock before calling Swap to maintain lock ordering.
+ //
+ // cpuClockTickerDisabled need not wait until after Swap to become
+ // true. It is sufficient that the timer *will* be enabled.
+ k.cpuClockTickerDisabled = false
+ k.runningTasksMu.Unlock()
+
+ // This won't call Notify (unless it's been ClockTick since setting.At
+ // above). This means we skip the thread group work in Notify. However,
+ // since nothing was running while we were disabled, none of the timers
+ // could have expired.
+ k.cpuClockTicker.Swap(setting)
+
+ return
+ }
+}
+
+func (k *Kernel) decRunningTasks() {
+ tasks := atomic.AddInt64(&k.runningTasks, -1)
+ if tasks < 0 {
+ panic(fmt.Sprintf("Invalid running count %d", tasks))
+ }
+
+ // Nothing to do. The next CPU clock tick will disable the timer if
+ // there is still nothing running. This provides approximately one tick
+ // of slack in which we can switch back and forth between idle and
+ // active without an expensive transition.
+}
+
// WaitExited blocks until all tasks in k have exited.
func (k *Kernel) WaitExited() {
k.tasks.liveGoroutines.Wait()
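The catch-up step in incRunningTasks can be read in isolation as: re-evaluate the cached Setting at the current monotonic time, credit every missed expiration to cpuClock, and hand the refreshed Setting back to the ticker. A simplified sketch of just that step; the function name and parameters are illustrative, and it assumes ktime aliases pkg/sentry/kernel/time and sync/atomic is imported:

// catchUpCPUClock advances cpuClock by the ticks that expired while the
// ticker was disabled and returns the refreshed setting to pass back to the
// ticker via Swap. It mirrors the logic in incRunningTasks above.
func catchUpCPUClock(cached ktime.Setting, now ktime.Time, cpuClock *uint64) ktime.Setting {
	setting, exp := cached.At(now)
	if exp > 0 {
		atomic.AddUint64(cpuClock, exp)
	}
	return setting
}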
diff --git a/pkg/sentry/kernel/memevent/BUILD b/pkg/sentry/kernel/memevent/BUILD
index ebcfaa619..d7a7d1169 100644
--- a/pkg/sentry/kernel/memevent/BUILD
+++ b/pkg/sentry/kernel/memevent/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -24,6 +25,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "memory_events_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":memory_events_proto"],
+)
+
go_proto_library(
name = "memory_events_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/kernel/memevent/memory_events_go_proto",
diff --git a/pkg/sentry/kernel/posixtimer.go b/pkg/sentry/kernel/posixtimer.go
index c5d095af7..2e861a5a8 100644
--- a/pkg/sentry/kernel/posixtimer.go
+++ b/pkg/sentry/kernel/posixtimer.go
@@ -117,9 +117,9 @@ func (it *IntervalTimer) signalRejectedLocked() {
}
// Notify implements ktime.TimerListener.Notify.
-func (it *IntervalTimer) Notify(exp uint64) {
+func (it *IntervalTimer) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
if it.target == nil {
- return
+ return ktime.Setting{}, false
}
it.target.tg.pidns.owner.mu.RLock()
@@ -129,7 +129,7 @@ func (it *IntervalTimer) Notify(exp uint64) {
if it.sigpending {
it.overrunCur += exp
- return
+ return ktime.Setting{}, false
}
// sigpending must be set before sendSignalTimerLocked() so that it can be
@@ -148,6 +148,8 @@ func (it *IntervalTimer) Notify(exp uint64) {
if err := it.target.sendSignalTimerLocked(si, it.group, it); err != nil {
it.signalRejectedLocked()
}
+
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy. Users of Timer should call
diff --git a/pkg/sentry/kernel/task_identity.go b/pkg/sentry/kernel/task_identity.go
index 78ff14b20..ce3e6ef28 100644
--- a/pkg/sentry/kernel/task_identity.go
+++ b/pkg/sentry/kernel/task_identity.go
@@ -465,8 +465,8 @@ func (t *Task) SetKeepCaps(k bool) {
// disables the features we don't support anyway, is always set. This
// drastically simplifies this function.
//
-// - We don't implement AT_SECURE, because no_new_privs always being set means
-// that the conditions that require AT_SECURE never arise. (Compare Linux's
+// - We don't set AT_SECURE = 1, because no_new_privs always being set means
+// that the conditions that require AT_SECURE = 1 never arise. (Compare Linux's
// security/commoncap.c:cap_bprm_set_creds() and cap_bprm_secureexec().)
//
// - We don't check for CAP_SYS_ADMIN in prctl(PR_SET_SECCOMP), since
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
index e76c069b0..8b148db35 100644
--- a/pkg/sentry/kernel/task_sched.go
+++ b/pkg/sentry/kernel/task_sched.go
@@ -126,12 +126,22 @@ func (t *Task) accountTaskGoroutineEnter(state TaskGoroutineState) {
t.gosched.Timestamp = now
t.gosched.State = state
t.goschedSeq.EndWrite()
+
+ if state != TaskGoroutineRunningApp {
+ // Task is blocking/stopping.
+ t.k.decRunningTasks()
+ }
}
// Preconditions: The caller must be running on the task goroutine, and leaving
// a state indicated by a previous call to
// t.accountTaskGoroutineEnter(state).
func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) {
+ if state != TaskGoroutineRunningApp {
+ // Task is unblocking/continuing.
+ t.k.incRunningTasks()
+ }
+
now := t.k.CPUClockNow()
if t.gosched.State != state {
panic(fmt.Sprintf("Task goroutine switching from state %v (expected %v) to %v", t.gosched.State, state, TaskGoroutineRunningSys))
@@ -330,7 +340,7 @@ func newKernelCPUClockTicker(k *Kernel) *kernelCPUClockTicker {
}
// Notify implements ktime.TimerListener.Notify.
-func (ticker *kernelCPUClockTicker) Notify(exp uint64) {
+func (ticker *kernelCPUClockTicker) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
// Only increment cpuClock by 1 regardless of the number of expirations.
// This approximately compensates for cases where thread throttling or bad
// Go runtime scheduling prevents the kernelCPUClockTicker goroutine, and
@@ -426,6 +436,27 @@ func (ticker *kernelCPUClockTicker) Notify(exp uint64) {
tgs[i] = nil
}
ticker.tgs = tgs[:0]
+
+ // If nothing is running, we can disable the timer.
+ tasks := atomic.LoadInt64(&ticker.k.runningTasks)
+ if tasks == 0 {
+ ticker.k.runningTasksMu.Lock()
+ defer ticker.k.runningTasksMu.Unlock()
+ tasks := atomic.LoadInt64(&ticker.k.runningTasks)
+ if tasks != 0 {
+ // Raced with a 0 -> 1 transition.
+ return setting, false
+ }
+
+ // Stop the timer. We must cache the current setting so the
+ // kernel can access it without violating the lock order.
+ ticker.k.cpuClockTickerSetting = setting
+ ticker.k.cpuClockTickerDisabled = true
+ setting.Enabled = false
+ return setting, true
+ }
+
+ return setting, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go
index 0eef24bfb..72568d296 100644
--- a/pkg/sentry/kernel/thread_group.go
+++ b/pkg/sentry/kernel/thread_group.go
@@ -511,8 +511,9 @@ type itimerRealListener struct {
}
// Notify implements ktime.TimerListener.Notify.
-func (l *itimerRealListener) Notify(exp uint64) {
+func (l *itimerRealListener) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
l.tg.SendSignal(SignalInfoPriv(linux.SIGALRM))
+ return ktime.Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy.
diff --git a/pkg/sentry/kernel/time/time.go b/pkg/sentry/kernel/time/time.go
index aa6c75d25..107394183 100644
--- a/pkg/sentry/kernel/time/time.go
+++ b/pkg/sentry/kernel/time/time.go
@@ -280,13 +280,16 @@ func (ClockEventsQueue) Readiness(mask waiter.EventMask) waiter.EventMask {
// A TimerListener receives expirations from a Timer.
type TimerListener interface {
// Notify is called when its associated Timer expires. exp is the number of
- // expirations.
+ // expirations. setting is the next timer Setting.
//
// Notify is called with the associated Timer's mutex locked, so Notify
// must not take any locks that precede Timer.mu in lock order.
//
+ // If Notify returns true, the timer will use the returned setting
+ // rather than the passed one.
+ //
// Preconditions: exp > 0.
- Notify(exp uint64)
+ Notify(exp uint64, setting Setting) (newSetting Setting, update bool)
// Destroy is called when the timer is destroyed.
Destroy()
@@ -533,7 +536,9 @@ func (t *Timer) Tick() {
s, exp := t.setting.At(now)
t.setting = s
if exp > 0 {
- t.listener.Notify(exp)
+ if newS, ok := t.listener.Notify(exp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
}
@@ -588,7 +593,9 @@ func (t *Timer) Get() (Time, Setting) {
s, exp := t.setting.At(now)
t.setting = s
if exp > 0 {
- t.listener.Notify(exp)
+ if newS, ok := t.listener.Notify(exp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
return now, s
@@ -620,7 +627,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) {
}
oldS, oldExp := t.setting.At(now)
if oldExp > 0 {
- t.listener.Notify(oldExp)
+ t.listener.Notify(oldExp, oldS)
+ // N.B. The returned Setting doesn't matter because we're about
+ // to overwrite.
}
if f != nil {
f()
@@ -628,7 +637,9 @@ func (t *Timer) SwapAnd(s Setting, f func()) (Time, Setting) {
newS, newExp := s.At(now)
t.setting = newS
if newExp > 0 {
- t.listener.Notify(newExp)
+ if newS, ok := t.listener.Notify(newExp, t.setting); ok {
+ t.setting = newS
+ }
}
t.resetKickerLocked(now)
return now, oldS
@@ -683,11 +694,13 @@ func NewChannelNotifier() (TimerListener, <-chan struct{}) {
}
// Notify implements ktime.TimerListener.Notify.
-func (c *ChannelNotifier) Notify(uint64) {
+func (c *ChannelNotifier) Notify(uint64, Setting) (Setting, bool) {
select {
case c.tchan <- struct{}{}:
default:
}
+
+ return Setting{}, false
}
// Destroy implements ktime.TimerListener.Destroy and will close the channel.
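Under the new contract, Notify receives the Timer's post-expiration Setting and may replace it by returning (newSetting, true); returning (Setting{}, false) keeps the previous behavior. A hypothetical listener that disables its timer after the first expiration, assuming ktime refers to pkg/sentry/kernel/time:

// oneShotListener is an illustrative ktime.TimerListener that fires once and
// then asks the Timer to disable itself via the returned Setting.
type oneShotListener struct {
	fired chan struct{}
}

// Notify implements ktime.TimerListener.Notify.
func (l *oneShotListener) Notify(exp uint64, setting ktime.Setting) (ktime.Setting, bool) {
	select {
	case l.fired <- struct{}{}:
	default:
	}
	setting.Enabled = false
	return setting, true // Replace the Timer's setting with a disabled one.
}

// Destroy implements ktime.TimerListener.Destroy.
func (l *oneShotListener) Destroy() {}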
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index ba9c9ce12..2d9251e92 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -323,18 +323,22 @@ func mapSegment(ctx context.Context, m *mm.MemoryManager, f *fs.File, phdr *elf.
return syserror.ENOEXEC
}
+ // N.B. Linux uses vm_brk_flags to map these pages, which only
+ // honors the X bit, always mapping at least RW. These
+ // pages are not included in the final brk region.
+ prot := usermem.ReadWrite
+ if phdr.Flags&elf.PF_X == elf.PF_X {
+ prot.Execute = true
+ }
+
if _, err := m.MMap(ctx, memmap.MMapOpts{
Length: uint64(anonSize),
Addr: anonAddr,
// Fixed without Unmap will fail the mmap if something is
// already at addr.
- Fixed: true,
- Private: true,
- // N.B. Linux uses vm_brk to map these pages, ignoring
- // the segment protections, instead always mapping RW.
- // These pages are not included in the final brk
- // region.
- Perms: usermem.ReadWrite,
+ Fixed: true,
+ Private: true,
+ Perms: prot,
MaxPerms: usermem.AnyAccess,
}); err != nil {
ctx.Infof("Error mapping PT_LOAD segment %v anonymous memory: %v", phdr, err)
diff --git a/pkg/sentry/loader/loader.go b/pkg/sentry/loader/loader.go
index f6f1ae762..089d1635b 100644
--- a/pkg/sentry/loader/loader.go
+++ b/pkg/sentry/loader/loader.go
@@ -308,6 +308,9 @@ func Load(ctx context.Context, m *mm.MemoryManager, mounts *fs.MountNamespace, r
arch.AuxEntry{linux.AT_EUID, usermem.Addr(c.EffectiveKUID.In(c.UserNamespace).OrOverflow())},
arch.AuxEntry{linux.AT_GID, usermem.Addr(c.RealKGID.In(c.UserNamespace).OrOverflow())},
arch.AuxEntry{linux.AT_EGID, usermem.Addr(c.EffectiveKGID.In(c.UserNamespace).OrOverflow())},
+ // The conditions that require AT_SECURE = 1 never arise. See
+ // kernel.Task.updateCredsForExecLocked.
+ arch.AuxEntry{linux.AT_SECURE, 0},
arch.AuxEntry{linux.AT_CLKTCK, linux.CLOCKS_PER_SEC},
arch.AuxEntry{linux.AT_EXECFN, execfn},
arch.AuxEntry{linux.AT_RANDOM, random},
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 4f8f9c5d9..9f0ecfbe4 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -267,7 +267,7 @@ func (s *subprocess) newThread() *thread {
// attach attaches to the thread.
func (t *thread) attach() {
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_ATTACH, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("unable to attach: %v", errno))
}
@@ -417,7 +417,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
for {
// Execute the syscall instruction.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -435,7 +435,7 @@ func (t *thread) syscall(regs *syscall.PtraceRegs) (uintptr, error) {
}
// Complete the actual system call.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSCALL, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -526,17 +526,17 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
for {
// Start running until the next system call.
if isSingleStepping(regs) {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
syscall.PTRACE_SYSEMU_SINGLESTEP,
- uintptr(t.tid), 0); errno != 0 {
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
} else {
- if _, _, errno := syscall.RawSyscall(
+ if _, _, errno := syscall.RawSyscall6(
syscall.SYS_PTRACE,
syscall.PTRACE_SYSEMU,
- uintptr(t.tid), 0); errno != 0 {
+ uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace sysemu failed: %v", errno))
}
}
diff --git a/pkg/sentry/platform/ptrace/subprocess_linux.go b/pkg/sentry/platform/ptrace/subprocess_linux.go
index f09b0b3d0..c075b5f91 100644
--- a/pkg/sentry/platform/ptrace/subprocess_linux.go
+++ b/pkg/sentry/platform/ptrace/subprocess_linux.go
@@ -53,7 +53,7 @@ func probeSeccomp() bool {
for {
// Attempt an emulation.
- if _, _, errno := syscall.RawSyscall(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0); errno != 0 {
+ if _, _, errno := syscall.RawSyscall6(syscall.SYS_PTRACE, syscall.PTRACE_SYSEMU, uintptr(t.tid), 0, 0, 0, 0); errno != 0 {
panic(fmt.Sprintf("ptrace syscall-enter failed: %v", errno))
}
@@ -266,7 +266,7 @@ func attachedThread(flags uintptr, defaultAction linux.BPFAction) (*thread, erro
// Enable cpuid-faulting; this may fail on older kernels or hardware,
// so we just disregard the result. Host CPUID will be enabled.
- syscall.RawSyscall(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0)
+ syscall.RawSyscall6(syscall.SYS_ARCH_PRCTL, linux.ARCH_SET_CPUID, 0, 0, 0, 0, 0)
// Call the stub; should not return.
stubCall(stubStart, ppid)
diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go
index eace3766d..c303435d5 100644
--- a/pkg/sentry/sighandling/sighandling_unsafe.go
+++ b/pkg/sentry/sighandling/sighandling_unsafe.go
@@ -23,7 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
)
-// TODO(b/34161764): Move to pkg/abi/linux along with definitions in
+// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in
// pkg/sentry/arch.
type sigaction struct {
handler uintptr
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 25adca090..5812085fa 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -209,6 +209,10 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
+ // transport.Endpoint.SetSockOptInt.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
@@ -887,8 +891,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -903,8 +907,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -938,6 +942,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return int32(v), nil
+ case linux.SO_BINDTODEVICE:
+ var v tcpip.BindToDeviceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if len(v) == 0 {
+ return []byte{}, nil
+ }
+ if outLen < linux.IFNAMSIZ {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return append([]byte(v), 0), nil
+
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1275,7 +1292,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v)))
case linux.SO_RCVBUF:
if len(optVal) < sizeOfInt32 {
@@ -1283,7 +1300,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
case linux.SO_REUSEADDR:
if len(optVal) < sizeOfInt32 {
@@ -1301,6 +1318,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ case linux.SO_BINDTODEVICE:
+ n := bytes.IndexByte(optVal, 0)
+ if n == -1 {
+ n = len(optVal)
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2317,9 +2341,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
return 0, err
case linux.TIOCOUTQ:
- var v tcpip.SendQueueSizeOption
- if err := ep.GetSockOpt(&v); err != nil {
- return 0, syserr.TranslateNetstackError(err).ToError()
+ v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
}
if v > math.MaxInt32 {
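
The SO_BINDTODEVICE paths added above trim the option value at the first NUL byte on setsockopt and, on getsockopt, return the NUL-terminated device name only when outLen is at least IFNAMSIZ. A minimal standalone sketch of the trimming step (plain Go, no netstack imports; the interface names are made-up examples):

    package main

    import (
        "bytes"
        "fmt"
    )

    // trimDeviceName mirrors the setsockopt handling above: the value is cut at
    // the first NUL byte, or used in full if no NUL is present.
    func trimDeviceName(optVal []byte) string {
        n := bytes.IndexByte(optVal, 0)
        if n == -1 {
            n = len(optVal)
        }
        return string(optVal[:n])
    }

    func main() {
        fmt.Println(trimDeviceName([]byte("eth0\x00\x00"))) // "eth0"
        fmt.Println(trimDeviceName([]byte("lo")))           // "lo"
    }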
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index 421f93dc4..0a9dfa6c3 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -65,7 +65,7 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
- return 0, true, syserr.ErrPermissionDenied
+ return 0, true, syserr.ErrNotPermitted
}
switch protocol {
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
index 5061dcbde..3a6baa308 100644
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -49,6 +50,14 @@ proto_library(
],
)
+cc_proto_library(
+ name = "syscall_rpc_cc_proto",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [":syscall_rpc_proto"],
+)
+
go_proto_library(
name = "syscall_rpc_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2b0ad6395..1867b3a5c 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -175,6 +175,10 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptInt sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
@@ -838,6 +842,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -853,65 +861,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrQueueSizeNotSupported
}
return v, nil
- default:
- return -1, tcpip.ErrUnknownProtocolOption
- }
-}
-
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
- case *tcpip.SendQueueSizeOption:
+ case tcpip.SendQueueSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ v := e.connected.SendQueuedSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
- }
- *o = qs
- return nil
-
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- return nil
+ return int(v), nil
- case *tcpip.SendBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
+ v := e.connected.SendMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- *o = qs
- return nil
+ return int(v), nil
- case *tcpip.ReceiveBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
e.Lock()
if e.receiver == nil {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
+ v := e.receiver.RecvMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.PasscredOption:
+ if e.Passcred() {
+ *o = tcpip.PasscredOption(1)
+ } else {
+ *o = tcpip.PasscredOption(0)
}
- *o = qs
return nil
case *tcpip.KeepaliveEnabledOption:
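
The GetSockOpt cases for queue and buffer sizes move into GetSockOptInt, which returns the value directly instead of filling in a pointer to a dedicated option type. A self-contained sketch of the new calling convention (sockOpt and the endpoint here are stand-ins, not the real tcpip types):

    package main

    import "fmt"

    // sockOpt stands in for tcpip.SockOpt; the constants mirror two of the
    // options GetSockOptInt now handles.
    type sockOpt int

    const (
        sendBufferSizeOption sockOpt = iota
        receiveBufferSizeOption
    )

    type fakeEndpoint struct{ sndBuf, rcvBuf int }

    // GetSockOptInt returns an int and an error, so callers no longer declare a
    // typed option variable and pass its address.
    func (e *fakeEndpoint) GetSockOptInt(opt sockOpt) (int, error) {
        switch opt {
        case sendBufferSizeOption:
            return e.sndBuf, nil
        case receiveBufferSizeOption:
            return e.rcvBuf, nil
        default:
            return -1, fmt.Errorf("unknown option")
        }
    }

    func (e *fakeEndpoint) SetSockOptInt(opt sockOpt, v int) error {
        if opt == sendBufferSizeOption {
            e.sndBuf = v
        }
        return nil
    }

    func main() {
        ep := &fakeEndpoint{}
        _ = ep.SetSockOptInt(sendBufferSizeOption, 64<<10)
        if v, err := ep.GetSockOptInt(sendBufferSizeOption); err == nil {
            fmt.Println(v) // 65536
        }
    }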
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 445d25010..7d7b42eba 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -44,6 +45,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+cc_proto_library(
+ name = "strace_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":strace_proto"],
+)
+
go_proto_library(
name = "strace_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/strace/strace_go_proto",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 18d24ab61..61acd0abd 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -232,7 +232,7 @@ var AMD64 = &kernel.SyscallTable{
184: syscalls.Error("tuxcall", syserror.ENOSYS, "Not implemented in Linux.", nil),
185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil),
186: syscalls.Supported("gettid", Gettid),
- 187: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341)
+ 187: syscalls.Supported("readahead", Readahead),
188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_read.go b/pkg/sentry/syscalls/linux/sys_read.go
index 3ab54271c..cd31e0649 100644
--- a/pkg/sentry/syscalls/linux/sys_read.go
+++ b/pkg/sentry/syscalls/linux/sys_read.go
@@ -72,6 +72,39 @@ func Read(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
return uintptr(n), nil, handleIOError(t, n != 0, err, kernel.ERESTARTSYS, "read", file)
}
+// Readahead implements readahead(2).
+func Readahead(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ fd := args[0].Int()
+ offset := args[1].Int64()
+ size := args[2].SizeT()
+
+ file := t.GetFile(fd)
+ if file == nil {
+ return 0, nil, syserror.EBADF
+ }
+ defer file.DecRef()
+
+ // Check that the file is readable.
+ if !file.Flags().Read {
+ return 0, nil, syserror.EBADF
+ }
+
+ // Check that the size is valid.
+ if int(size) < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Check that the offset is legitimate.
+ if offset < 0 {
+ return 0, nil, syserror.EINVAL
+ }
+
+ // Return EINVAL; if the underlying file type does not support readahead,
+ // then Linux will return EINVAL to indicate as much. In the future, we
+ // may extend this function to actually support readahead hints.
+ return 0, nil, syserror.EINVAL
+}
+
// Pread64 implements linux syscall pread64(2).
func Pread64(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
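
The new Readahead handler validates the descriptor, readability, size, and offset, then reports EINVAL, matching what Linux does for file types that do not support readahead. A hedged userland sketch that issues the syscall on Linux (raw syscall number from the standard library; the file path is just an example):

    // +build linux

    package main

    import (
        "fmt"
        "os"
        "syscall"
    )

    // readahead wraps readahead(2): fd, offset, and a byte count.
    func readahead(fd int, offset int64, count uintptr) error {
        _, _, errno := syscall.Syscall(syscall.SYS_READAHEAD, uintptr(fd), uintptr(offset), count)
        if errno != 0 {
            return errno
        }
        return nil
    }

    func main() {
        f, err := os.Open("/etc/hostname")
        if err != nil {
            fmt.Println(err)
            return
        }
        defer f.Close()
        // On Linux with a regular file this succeeds as a hint; under the
        // handler above it currently returns EINVAL.
        fmt.Println(readahead(int(f.Fd()), 0, 4096))
    }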
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 3bac4d90d..b5a72ce63 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy
return 0, nil, syserror.ENOTSOCK
}
- if optLen <= 0 {
+ if optLen < 0 {
return 0, nil, syserror.EINVAL
}
if optLen > maxOptLen {
diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go
index f0a292f2f..9f705ebca 100644
--- a/pkg/sentry/syscalls/linux/sys_splice.go
+++ b/pkg/sentry/syscalls/linux/sys_splice.go
@@ -245,12 +245,12 @@ func Splice(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
if inOffset != 0 || outOffset != 0 {
return 0, nil, syserror.ESPIPE
}
- default:
- return 0, nil, syserror.EINVAL
- }
- // We may not refer to the same pipe; otherwise it's a continuous loop.
- if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ // We may not refer to the same pipe; otherwise it's a continuous loop.
+ if inFile.Dirent.Inode.StableAttr.InodeID == outFile.Dirent.Inode.StableAttr.InodeID {
+ return 0, nil, syserror.EINVAL
+ }
+ default:
return 0, nil, syserror.EINVAL
}
diff --git a/pkg/sentry/syscalls/linux/sys_time.go b/pkg/sentry/syscalls/linux/sys_time.go
index 4b3f043a2..b887fa9d7 100644
--- a/pkg/sentry/syscalls/linux/sys_time.go
+++ b/pkg/sentry/syscalls/linux/sys_time.go
@@ -15,6 +15,7 @@
package linux
import (
+ "fmt"
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -228,41 +229,35 @@ func clockNanosleepFor(t *kernel.Task, c ktime.Clock, dur time.Duration, rem use
timer.Destroy()
- var remaining time.Duration
- // Did we just block for the entire duration?
- if err == syserror.ETIMEDOUT {
- remaining = 0
- } else {
- remaining = dur - after.Sub(start)
+ switch err {
+ case syserror.ETIMEDOUT:
+ // Slept for entire timeout.
+ return nil
+ case syserror.ErrInterrupted:
+ // Interrupted.
+ remaining := dur - after.Sub(start)
if remaining < 0 {
remaining = time.Duration(0)
}
- }
- // Copy out remaining time.
- if err != nil && rem != usermem.Addr(0) {
- timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
- if err := copyTimespecOut(t, rem, &timeleft); err != nil {
- return err
+ // Copy out remaining time.
+ if rem != 0 {
+ timeleft := linux.NsecToTimespec(remaining.Nanoseconds())
+ if err := copyTimespecOut(t, rem, &timeleft); err != nil {
+ return err
+ }
}
- }
-
- // Did we just block for the entire duration?
- if err == syserror.ETIMEDOUT {
- return nil
- }
- // If interrupted, arrange for a restart with the remaining duration.
- if err == syserror.ErrInterrupted {
+ // Arrange for a restart with the remaining duration.
t.SetSyscallRestartBlock(&clockNanosleepRestartBlock{
c: c,
duration: remaining,
rem: rem,
})
return kernel.ERESTART_RESTARTBLOCK
+ default:
+ panic(fmt.Sprintf("Impossible BlockWithTimer error %v", err))
}
-
- return err
}
// Nanosleep implements linux syscall Nanosleep(2).
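
The rewritten clockNanosleepFor handles exactly three outcomes: a timeout means the full duration was slept; an interruption computes the remaining time (clamped at zero), copies it out, and arranges a restart; and any other error is treated as impossible. A standalone sketch of the remaining-time computation (the errors and durations here are stand-ins for the kernel types):

    package main

    import (
        "errors"
        "fmt"
        "time"
    )

    var (
        errTimedOut    = errors.New("ETIMEDOUT")
        errInterrupted = errors.New("interrupted")
    )

    // remainingSleep mirrors the switch above: zero remaining on timeout,
    // dur-elapsed (never negative) on interruption.
    func remainingSleep(err error, dur, elapsed time.Duration) (time.Duration, error) {
        switch err {
        case errTimedOut:
            return 0, nil
        case errInterrupted:
            remaining := dur - elapsed
            if remaining < 0 {
                remaining = 0
            }
            return remaining, err
        default:
            return 0, fmt.Errorf("impossible block error: %v", err)
        }
    }

    func main() {
        r, _ := remainingSleep(errInterrupted, 3*time.Second, 1200*time.Millisecond)
        fmt.Println(r) // 1.8s
    }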
diff --git a/pkg/sentry/syscalls/linux/sys_utsname.go b/pkg/sentry/syscalls/linux/sys_utsname.go
index 271ace08e..748e8dd8d 100644
--- a/pkg/sentry/syscalls/linux/sys_utsname.go
+++ b/pkg/sentry/syscalls/linux/sys_utsname.go
@@ -79,11 +79,11 @@ func Sethostname(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S
return 0, nil, syserror.EINVAL
}
- name, err := t.CopyInString(nameAddr, int(size))
- if err != nil {
+ name := make([]byte, size)
+ if _, err := t.CopyInBytes(nameAddr, name); err != nil {
return 0, nil, err
}
- utsns.SetHostName(name)
+ utsns.SetHostName(string(name))
return 0, nil, nil
}
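
Sethostname switches from CopyInString to CopyInBytes because sethostname(2) takes a length-delimited buffer, not a NUL-terminated C string, so all size bytes belong to the name. A minimal sketch of the behavioural difference (plain byte slices standing in for user memory):

    package main

    import (
        "bytes"
        "fmt"
    )

    // asCString stops at the first NUL, which is what string-style copying
    // would have done.
    func asCString(buf []byte) string {
        if i := bytes.IndexByte(buf, 0); i >= 0 {
            return string(buf[:i])
        }
        return string(buf)
    }

    // asBuffer keeps every byte, matching the CopyInBytes behaviour above.
    func asBuffer(buf []byte) string {
        return string(buf)
    }

    func main() {
        user := []byte("host\x00name") // passed with size == 9
        fmt.Printf("%q\n", asCString(user)) // "host"
        fmt.Printf("%q\n", asBuffer(user))  // "host\x00name"
    }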
diff --git a/pkg/sentry/unimpl/BUILD b/pkg/sentry/unimpl/BUILD
index b69603da3..fc7614fff 100644
--- a/pkg/sentry/unimpl/BUILD
+++ b/pkg/sentry/unimpl/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -10,6 +11,12 @@ proto_library(
deps = ["//pkg/sentry/arch:registers_proto"],
)
+cc_proto_library(
+ name = "unimplemented_syscall_cc_proto",
+ visibility = ["//visibility:public"],
+ deps = [":unimplemented_syscall_proto"],
+)
+
go_proto_library(
name = "unimplemented_syscall_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/unimpl/unimplemented_syscall_go_proto",
diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go
index 6eced660a..7b1f312b1 100644
--- a/pkg/sentry/usermem/usermem.go
+++ b/pkg/sentry/usermem/usermem.go
@@ -16,6 +16,7 @@
package usermem
import (
+ "bytes"
"errors"
"io"
"strconv"
@@ -270,11 +271,10 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts)
// Look for the terminating zero byte, which may have occurred before
// hitting err.
- for i, c := range buf[done : done+n] {
- if c == 0 {
- return stringFromImmutableBytes(buf[:done+i]), nil
- }
+ if i := bytes.IndexByte(buf[done:done+n], byte(0)); i >= 0 {
+ return stringFromImmutableBytes(buf[:done+i]), nil
}
+
done += n
if err != nil {
return stringFromImmutableBytes(buf[:done]), err
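
CopyStringIn now uses bytes.IndexByte to look for the terminator in just the chunk returned by the latest CopyIn, while still slicing the result from the start of the buffer. A standalone sketch of that chunked scan (the chunk list stands in for successive CopyIn results):

    package main

    import (
        "bytes"
        "fmt"
    )

    // scanForNUL appends chunks and searches only the newly arrived bytes for
    // the terminating zero, as the loop above does.
    func scanForNUL(chunks [][]byte) (string, bool) {
        var buf []byte
        for _, c := range chunks {
            done := len(buf)
            buf = append(buf, c...)
            if i := bytes.IndexByte(buf[done:], 0); i >= 0 {
                return string(buf[:done+i]), true
            }
        }
        return string(buf), false
    }

    func main() {
        s, ok := scanForNUL([][]byte{[]byte("hel"), []byte("lo\x00junk")})
        fmt.Println(s, ok) // hello true
    }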
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index 86bde7fb3..7eb2b2821 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -199,8 +199,11 @@ type Dirent struct {
// Ino is the inode number.
Ino uint64
- // Off is this Dirent's offset.
- Off int64
+ // NextOff is the offset of the *next* Dirent in the directory; that is,
+ // FileDescription.Seek(NextOff, SEEK_SET) (as called by seekdir(3)) will
+ // cause the next call to FileDescription.IterDirents() to yield the next
+ // Dirent. (The offset of the first Dirent in a directory is always 0.)
+ NextOff int64
}
// IterDirentsCallback receives Dirents from FileDescriptionImpl.IterDirents.
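
With the rename to NextOff, each Dirent carries the offset at which iteration should resume to see the entry after it, which is what seekdir(3)-style callers need. A stand-in sketch of the resume pattern (not the real vfs types):

    package main

    import "fmt"

    type dirent struct {
        Name    string
        NextOff int64 // offset of the *next* entry, as documented above
    }

    // iterFrom yields entries starting at offset off; offset 0 is the first entry.
    func iterFrom(entries []dirent, off int64) []dirent {
        return entries[off:]
    }

    func main() {
        entries := []dirent{{"a", 1}, {"b", 2}, {"c", 3}}
        first := iterFrom(entries, 0)[0]
        // Saving first.NextOff and seeking there resumes at "b".
        for _, e := range iterFrom(entries, first.NextOff) {
            fmt.Println(e.Name) // b, then c
        }
    }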
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 672f026b2..8ced960bb 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -60,7 +60,10 @@ func TestTimeouts(t *testing.T) {
func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
// Create the stack and add a NIC.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
+ })
if err := s.CreateNIC(NICID, loopback.New()); err != nil {
return nil, err
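
stack.New no longer takes lists of protocol-name strings; protocols are constructed explicitly and passed in stack.Options, as the updated test shows. A sketch of the new construction pattern (import paths as they appear in gvisor.dev/gvisor at the time of this change):

    package main

    import (
        "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
        "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
        "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
        "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
        "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    )

    func main() {
        // Protocols are concrete values created by their constructors.
        s := stack.New(stack.Options{
            NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
            TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()},
        })
        if err := s.CreateNIC(1, loopback.New()); err != nil {
            panic(err)
        }
    }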
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
index eec430d0a..18adb2085 100644
--- a/pkg/tcpip/link/channel/channel.go
+++ b/pkg/tcpip/link/channel/channel.go
@@ -133,3 +133,6 @@ func (e *Endpoint) WritePacket(_ *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index adcf21371..7636418b1 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -41,6 +41,7 @@ package fdbased
import (
"fmt"
+ "sync"
"syscall"
"golang.org/x/sys/unix"
@@ -81,6 +82,19 @@ const (
PacketMMap
)
+func (p PacketDispatchMode) String() string {
+ switch p {
+ case Readv:
+ return "Readv"
+ case RecvMMsg:
+ return "RecvMMsg"
+ case PacketMMap:
+ return "PacketMMap"
+ default:
+ return fmt.Sprintf("unknown packet dispatch mode %v", p)
+ }
+}
+
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
@@ -114,6 +128,9 @@ type endpoint struct {
// gsoMaxSize is the maximum GSO packet size. It is zero if GSO is
// disabled.
gsoMaxSize uint32
+
+ // wg keeps track of running goroutines.
+ wg sync.WaitGroup
}
// Options specify the details about the fd-based endpoint to be created.
@@ -164,7 +181,8 @@ type Options struct {
// New creates a new fd-based endpoint.
//
// Makes fd non-blocking, but does not take ownership of fd, which must remain
-// open for the lifetime of the returned endpoint.
+// open for the lifetime of the returned endpoint (until after the endpoint has
+// stopped being used and Wait returns).
func New(opts *Options) (stack.LinkEndpoint, error) {
caps := stack.LinkEndpointCapabilities(0)
if opts.RXChecksumOffload {
@@ -290,7 +308,11 @@ func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
// saved, they stop sending outgoing packets and all incoming packets
// are rejected.
for i := range e.inboundDispatchers {
- go e.dispatchLoop(e.inboundDispatchers[i]) // S/R-SAFE: See above.
+ e.wg.Add(1)
+ go func(i int) { // S/R-SAFE: See above.
+ e.dispatchLoop(e.inboundDispatchers[i])
+ e.wg.Done()
+ }(i)
}
}
@@ -320,6 +342,12 @@ func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
+// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop
+// reading from its FD.
+func (e *endpoint) Wait() {
+ e.wg.Wait()
+}
+
// virtioNetHdr is declared in linux/virtio_net.h.
type virtioNetHdr struct {
flags uint8
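
Attach now registers every dispatch goroutine with a sync.WaitGroup, and Wait blocks until they have all returned, so callers know when the FD can safely be closed. A minimal standalone sketch of the same pattern (the dispatch loops here are stand-ins):

    package main

    import (
        "fmt"
        "sync"
    )

    type endpoint struct {
        wg sync.WaitGroup
    }

    // attach starts one goroutine per dispatcher, adding to the WaitGroup
    // before each one starts, as Attach does above.
    func (e *endpoint) attach(dispatchers []func()) {
        for i := range dispatchers {
            e.wg.Add(1)
            go func(i int) {
                defer e.wg.Done()
                dispatchers[i]()
            }(i)
        }
    }

    // wait blocks until every dispatch goroutine has returned.
    func (e *endpoint) wait() { e.wg.Wait() }

    func main() {
        e := &endpoint{}
        e.attach([]func(){
            func() { fmt.Println("dispatch loop 1 done") },
            func() { fmt.Println("dispatch loop 2 done") },
        })
        e.wait()
        fmt.Println("all dispatchers stopped; the fd may now be closed")
    }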
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
index e121ea1a5..b36629d2c 100644
--- a/pkg/tcpip/link/loopback/loopback.go
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -85,3 +85,6 @@ func (e *endpoint) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prependa
return nil
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/link/muxed/injectable.go b/pkg/tcpip/link/muxed/injectable.go
index 3ed7b98d1..7c946101d 100644
--- a/pkg/tcpip/link/muxed/injectable.go
+++ b/pkg/tcpip/link/muxed/injectable.go
@@ -104,6 +104,13 @@ func (m *InjectableEndpoint) WriteRawPacket(dest tcpip.Address, packet []byte) *
return endpoint.WriteRawPacket(dest, packet)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (m *InjectableEndpoint) Wait() {
+ for _, ep := range m.routes {
+ ep.Wait()
+ }
+}
+
// NewInjectableEndpoint creates a new multi-endpoint injectable endpoint.
func NewInjectableEndpoint(routes map[tcpip.Address]stack.InjectableLinkEndpoint) *InjectableEndpoint {
return &InjectableEndpoint{
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
index ba387af73..9e71d4edf 100644
--- a/pkg/tcpip/link/sharedmem/sharedmem.go
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -132,7 +132,8 @@ func (e *endpoint) Close() {
}
}
-// Wait waits until all workers have stopped after a Close() call.
+// Wait implements stack.LinkEndpoint.Wait. It waits until all workers have
+// stopped after a Close() call.
func (e *endpoint) Wait() {
e.completed.Wait()
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index e7b6d7912..e401dce44 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -240,6 +240,9 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
return e.lower.WritePacket(r, gso, hdr, payload, protocol)
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*endpoint) Wait() {}
+
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View, gso *stack.GSO) {
// Figure out the network layer info.
var transProto uint8
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
index 408cc62f7..5a1791cb5 100644
--- a/pkg/tcpip/link/waitable/waitable.go
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -120,3 +120,6 @@ func (e *Endpoint) WaitWrite() {
func (e *Endpoint) WaitDispatch() {
e.dispatchGate.Close()
}
+
+// Wait implements stack.LinkEndpoint.Wait.
+func (e *Endpoint) Wait() {}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
index 1031438b1..ae23c96b7 100644
--- a/pkg/tcpip/link/waitable/waitable_test.go
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -70,6 +70,9 @@ func (e *countedEndpoint) WritePacket(r *stack.Route, _ *stack.GSO, hdr buffer.P
return nil
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*countedEndpoint) Wait() {}
+
func TestWaitWrite(t *testing.T) {
ep := &countedEndpoint{}
wep := New(ep)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
index fd6395fc1..26cf1c528 100644
--- a/pkg/tcpip/network/arp/arp.go
+++ b/pkg/tcpip/network/arp/arp.go
@@ -16,9 +16,9 @@
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
//
-// To use it in the networking stack, pass arp.ProtocolName as one of the
-// network protocols when calling stack.New. Then add an "arp" address to
-// every NIC on the stack that should respond to ARP requests. That is:
+// To use it in the networking stack, pass arp.NewProtocol() as one of the
+// network protocols when calling stack.New. Then add an "arp" address to every
+// NIC on the stack that should respond to ARP requests. That is:
//
// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
// // handle err
@@ -33,9 +33,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ARP protocol name.
- ProtocolName = "arp"
-
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
@@ -200,8 +197,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an ARP network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
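
ARP is no longer registered via an init()-time factory; the protocol value is passed to stack.New and then enabled per NIC by adding the "arp" address, exactly as the updated package comment describes. A sketch of that wiring (import paths as in gvisor.dev/gvisor; error handling reduced to panics):

    package main

    import (
        "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
        "gvisor.dev/gvisor/pkg/tcpip/network/arp"
        "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
        "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    )

    func main() {
        s := stack.New(stack.Options{
            NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
            TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
        })
        if err := s.CreateNIC(1, loopback.New()); err != nil {
            panic(err)
        }
        // Enable ARP on the NIC, per the package comment.
        if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
            panic(err)
        }
    }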
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
index 387fca96e..88b57ec03 100644
--- a/pkg/tcpip/network/arp/arp_test.go
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -44,7 +44,10 @@ type testContext struct {
}
func newTestContext(t *testing.T) *testContext {
- s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
+ })
const defaultMTU = 65536
ep := channel.New(256, defaultMTU, stackLinkAddr)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index 4b3bd74fa..a9741622e 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -144,6 +144,9 @@ func (*testObject) LinkAddress() tcpip.LinkAddress {
return ""
}
+// Wait implements stack.LinkEndpoint.Wait.
+func (*testObject) Wait() {}
+
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
@@ -169,7 +172,10 @@ func (t *testObject) WritePacket(_ *stack.Route, _ *stack.GSO, hdr buffer.Prepen
}
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
@@ -182,7 +188,10 @@ func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
- s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol(), tcp.NewProtocol()},
+ })
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index b7a06f525..b7b07a6c1 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -14,9 +14,9 @@
// Package ipv4 contains the implementation of the ipv4 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv4.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv4.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv4
@@ -32,9 +32,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv4 protocol name.
- ProtocolName = "ipv4"
-
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
@@ -53,6 +50,7 @@ type endpoint struct {
linkEP stack.LinkEndpoint
dispatcher stack.TransportDispatcher
fragmentation *fragmentation.Fragmentation
+ protocol *protocol
}
// NewEndpoint creates a new ipv4 endpoint.
@@ -64,6 +62,7 @@ func (p *protocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWi
linkEP: linkEP,
dispatcher: dispatcher,
fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
+ protocol: p,
}
return e, nil
@@ -204,7 +203,7 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1)
}
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
@@ -267,7 +266,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.Vect
if payload.Size() > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
- id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1)
+ id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, 0 /* protocol */, e.protocol.hashIV)%buckets], 1)
}
ip.SetID(uint16(id))
}
@@ -325,14 +324,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {}
-type protocol struct{}
-
-// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
+type protocol struct {
+ ids []uint32
+ hashIV uint32
}
// Number returns the ipv4 protocol number.
@@ -378,7 +372,7 @@ func calculateMTU(mtu uint32) uint32 {
// hashRoute calculates a hash value for the given route. It uses the source &
// destination address, the transport protocol number, and a random initial
// value (generated once on initialization) to generate the hash.
-func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber, hashIV uint32) uint32 {
t := r.LocalAddress
a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = r.RemoteAddress
@@ -386,22 +380,16 @@ func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
-var (
- ids []uint32
- hashIV uint32
-)
-
-func init() {
- ids = make([]uint32, buckets)
+// NewProtocol returns an IPv4 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ ids := make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
r := hash.RandN32(1 + buckets)
for i := range ids {
ids[i] = r[i]
}
- hashIV = r[buckets]
+ hashIV := r[buckets]
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+ return &protocol{ids: ids, hashIV: hashIV}
}
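
The IPv4 packet-ID counters and hash IV move from package globals into the protocol value, so each stack instance gets independently seeded state. A standalone sketch of the construction and ID assignment (math/rand replaces the hash.RandN32 helper, and the bucket count is illustrative):

    package main

    import (
        "fmt"
        "math/rand"
        "sync/atomic"
    )

    const buckets = 2048 // illustrative bucket count

    type protocol struct {
        ids    []uint32
        hashIV uint32
    }

    // newProtocol mirrors NewProtocol above: ids and hashIV are per-instance,
    // so two stacks no longer share ID sequences.
    func newProtocol() *protocol {
        ids := make([]uint32, buckets)
        for i := range ids {
            ids[i] = rand.Uint32()
        }
        return &protocol{ids: ids, hashIV: rand.Uint32()}
    }

    // nextID bumps the counter for one hash bucket, as WritePacket does.
    func (p *protocol) nextID(bucket uint32) uint32 {
        return atomic.AddUint32(&p.ids[bucket%buckets], 1)
    }

    func main() {
        p1, p2 := newProtocol(), newProtocol()
        // Independently seeded, so the sequences (almost certainly) differ.
        fmt.Println(p1.nextID(7), p2.nextID(7))
    }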
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index ae827ca27..b6641ccc3 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -33,7 +33,10 @@ import (
)
func TestExcludeBroadcast(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
const defaultMTU = 65536
ep := stack.LinkEndpoint(channel.New(256, defaultMTU, ""))
@@ -238,7 +241,9 @@ type context struct {
func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context {
// Make the packet and write it.
- s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ })
ep := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors)
s.CreateNIC(1, ep)
const (
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index 653d984e9..01f5a17ec 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -81,7 +81,10 @@ func (*stubLinkAddressCache) AddLinkAddress(tcpip.NICID, tcpip.Address, tcpip.Li
}
func TestICMPCounts(t *testing.T) {
- s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
{
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -205,8 +208,14 @@ func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapab
func newTestContext(t *testing.T) *testContext {
c := &testContext{
- s0: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
- s1: stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{}),
+ s0: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
+ s1: stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ }),
}
const defaultMTU = 65536
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index 331a8bdaa..7de6a4546 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -14,9 +14,9 @@
// Package ipv6 contains the implementation of the ipv6 network protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
-// network protocols when calling stack.New(). Then endpoints can be created
-// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// activated on the stack by passing ipv6.NewProtocol() as one of the network
+// protocols when calling stack.New(). Then endpoints can be created by passing
+// ipv6.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv6
@@ -28,9 +28,6 @@ import (
)
const (
- // ProtocolName is the string representation of the ipv6 protocol name.
- ProtocolName = "ipv6"
-
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
@@ -160,14 +157,6 @@ func (*endpoint) Close() {}
type protocol struct{}
-// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
-// only for tests that short-circuit the stack. Regular use of the protocol is
-// done via the stack, which gets a protocol descriptor from the init() function
-// below.
-func NewProtocol() stack.NetworkProtocol {
- return &protocol{}
-}
-
// Number returns the ipv6 protocol number.
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
return ProtocolNumber
@@ -221,8 +210,7 @@ func calculateMTU(mtu uint32) uint32 {
return maxPayloadSize
}
-func init() {
- stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
- return &protocol{}
- })
+// NewProtocol returns an IPv6 network protocol.
+func NewProtocol() stack.NetworkProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 57bcd5455..78c674c2c 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -124,17 +124,20 @@ func testReceiveUDP(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst
// UDP packets destined to the IPv6 link-local all-nodes multicast address.
func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
tests := []struct {
- name string
- protocolName string
- rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
}{
- {"ICMP", icmp.ProtocolName6, testReceiveICMP},
- {"UDP", udp.ProtocolName, testReceiveUDP},
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
e := channel.New(10, 1280, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -152,19 +155,22 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) {
// address.
func TestReceiveOnSolicitedNodeAddr(t *testing.T) {
tests := []struct {
- name string
- protocolName string
- rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
+ name string
+ protocolFactory stack.TransportProtocol
+ rxf func(t *testing.T, s *stack.Stack, e *channel.Endpoint, src, dst tcpip.Address, want uint64)
}{
- {"ICMP", icmp.ProtocolName6, testReceiveICMP},
- {"UDP", udp.ProtocolName, testReceiveUDP},
+ {"ICMP", icmp.NewProtocol6(), testReceiveICMP},
+ {"UDP", udp.NewProtocol(), testReceiveUDP},
}
snmc := header.SolicitedNodeAddr(addr2)
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s := stack.New([]string{ProtocolName}, []string{test.protocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{test.protocolFactory},
+ })
e := channel.New(10, 1280, linkAddr1)
if err := s.CreateNIC(1, e); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
@@ -237,7 +243,9 @@ func TestAddIpv6Address(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- s := stack.New([]string{ProtocolName}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ })
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
}
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index 571915d3f..e30791fe3 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -31,7 +31,10 @@ import (
func setupStackAndEndpoint(t *testing.T, llladdr, rlladdr tcpip.Address) (*stack.Stack, stack.NetworkEndpoint) {
t.Helper()
- s := stack.New([]string{ProtocolName}, []string{icmp.ProtocolName6}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol6()},
+ })
if err := s.CreateNIC(1, &stubLinkEndpoint{}); err != nil {
t.Fatalf("CreateNIC(_) = %s", err)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
index 315780c0c..30cea8996 100644
--- a/pkg/tcpip/ports/ports.go
+++ b/pkg/tcpip/ports/ports.go
@@ -19,6 +19,7 @@ import (
"math"
"math/rand"
"sync"
+ "sync/atomic"
"gvisor.dev/gvisor/pkg/tcpip"
)
@@ -27,6 +28,10 @@ const (
// FirstEphemeral is the first ephemeral port.
FirstEphemeral = 16000
+ // numEphemeralPorts is the number of ephemeral ports available to
+ // Netstack.
+ numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1
+
anyIPAddress tcpip.Address = ""
)
@@ -40,6 +45,13 @@ type portDescriptor struct {
type PortManager struct {
mu sync.RWMutex
allocatedPorts map[portDescriptor]bindAddresses
+
+ // hint is used to pick ephemeral ports in a stable order for
+ // a given port offset.
+ //
+ // hint must be accessed using the portHint/incPortHint helpers.
+ // TODO(gvisor.dev/issue/940): S/R this field.
+ hint uint32
}
type portNode struct {
@@ -47,43 +59,76 @@ type portNode struct {
refs int
}
-// bindAddresses is a set of IP addresses.
-type bindAddresses map[tcpip.Address]portNode
+// deviceNode is never empty. When it has no elements, it is removed from the
+// map that references it.
+type deviceNode map[tcpip.NICID]portNode
-// isAvailable checks whether an IP address is available to bind to.
-func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool {
- if addr == anyIPAddress {
- if len(b) == 0 {
- return true
- }
+// isAvailable checks whether binding is possible on the given device. If not
+// binding to a device, check against all portNodes. If binding to a specific
+// device, check against the unspecified device and the provided device.
+func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool {
+ if bindToDevice == 0 {
+ // Trying to bind to all devices.
if !reuse {
+ // Can't bind because the (addr,port) is already bound.
return false
}
- for _, n := range b {
- if !n.reuse {
+ for _, p := range d {
+ if !p.reuse {
+ // Can't bind because the (addr,port) was previously bound without reuse.
return false
}
}
return true
}
- // If all addresses for this portDescriptor are already bound, no
- // address is available.
- if n, ok := b[anyIPAddress]; ok {
- if !reuse {
+ if p, ok := d[0]; ok {
+ if !reuse || !p.reuse {
return false
}
- if !n.reuse {
+ }
+
+ if p, ok := d[bindToDevice]; ok {
+ if !reuse || !p.reuse {
return false
}
}
- if n, ok := b[addr]; ok {
- if !reuse {
+ return true
+}
+
+// bindAddresses is a set of IP addresses.
+type bindAddresses map[tcpip.Address]deviceNode
+
+// isAvailable checks whether an IP address is available to bind to. If the
+// address is the "any" address, check all other addresses. Otherwise, just
+// check against the "any" address and the provided address.
+func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool {
+ if addr == anyIPAddress {
+ // If binding to the "any" address then check that there are no conflicts
+ // with all addresses.
+ for _, d := range b {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+ return true
+ }
+
+ // Check that there is no conflict with the "any" address.
+ if d, ok := b[anyIPAddress]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
return false
}
- return n.reuse
}
+
+ // Check that there is no conflict with the provided address.
+ if d, ok := b[addr]; ok {
+ if !d.isAvailable(reuse, bindToDevice) {
+ return false
+ }
+ }
+
return true
}
@@ -97,11 +142,40 @@ func NewPortManager() *PortManager {
// is suitable for its needs, and stopping when a port is found or an error
// occurs.
func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
- count := uint16(math.MaxUint16 - FirstEphemeral + 1)
- offset := uint16(rand.Int31n(int32(count)))
+ offset := uint32(rand.Int31n(numEphemeralPorts))
+ return s.pickEphemeralPort(offset, numEphemeralPorts, testPort)
+}
+
+// portHint atomically reads and returns the s.hint value.
+func (s *PortManager) portHint() uint32 {
+ return atomic.LoadUint32(&s.hint)
+}
+
+// incPortHint atomically increments s.hint by 1.
+func (s *PortManager) incPortHint() {
+ atomic.AddUint32(&s.hint, 1)
+}
- for i := uint16(0); i < count; i++ {
- port = FirstEphemeral + (offset+i)%count
+// PickEphemeralPortStable starts at the specified offset + s.portHint and
+// iterates over all ephemeral ports, allowing the caller to decide whether a
+// given port is suitable for its needs and stopping when a port is found or an
+// error occurs.
+func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort)
+ if err == nil {
+ s.incPortHint()
+ }
+ return p, err
+}
+
+// pickEphemeralPort starts at the specified offset from the FirstEphemeral
+// port and iterates over the number of ports specified by count, allowing the
+// caller to decide whether a given port is suitable for its needs, and
+// stopping when a port is found or an error occurs.
+func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ for i := uint32(0); i < count; i++ {
+ port = uint16(FirstEphemeral + (offset+i)%count)
ok, err := testPort(port)
if err != nil {
return 0, err
@@ -116,17 +190,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er
}
// IsPortAvailable tests if the given port is available on all given protocols.
-func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
s.mu.Lock()
defer s.mu.Unlock()
- return s.isPortAvailableLocked(networks, transport, addr, port, reuse)
+ return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice)
}
-func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
+func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
- if !addrs.isAvailable(addr, reuse) {
+ if !addrs.isAvailable(addr, reuse, bindToDevice) {
return false
}
}
@@ -138,14 +212,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb
// reserved by another endpoint. If port is zero, ReservePort will search for
// an unreserved ephemeral port and reserve it, returning its value in the
// "port" return value.
-func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) {
+func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
- if !s.reserveSpecificPort(networks, transport, addr, port, reuse) {
+ if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) {
return 0, tcpip.ErrPortInUse
}
return port, nil
@@ -153,13 +227,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp
// A port wasn't specified, so try to find one.
return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
- return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil
+ return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil
})
}
// reserveSpecificPort tries to reserve the given port on all given protocols.
-func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool {
- if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) {
+func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool {
+ if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) {
return false
}
@@ -171,11 +245,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
m = make(bindAddresses)
s.allocatedPorts[desc] = m
}
- if n, ok := m[addr]; ok {
+ d, ok := m[addr]
+ if !ok {
+ d = make(deviceNode)
+ m[addr] = d
+ }
+ if n, ok := d[bindToDevice]; ok {
n.refs++
- m[addr] = n
+ d[bindToDevice] = n
} else {
- m[addr] = portNode{reuse: reuse, refs: 1}
+ d[bindToDevice] = portNode{reuse: reuse, refs: 1}
}
}
@@ -184,22 +263,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
-func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
+func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) {
s.mu.Lock()
defer s.mu.Unlock()
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
- n, ok := m[addr]
+ d, ok := m[addr]
+ if !ok {
+ continue
+ }
+ n, ok := d[bindToDevice]
if !ok {
continue
}
n.refs--
+ d[bindToDevice] = n
if n.refs == 0 {
+ delete(d, bindToDevice)
+ }
+ if len(d) == 0 {
delete(m, addr)
- } else {
- m[addr] = n
}
if len(m) == 0 {
delete(s.allocatedPorts, desc)
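
pickEphemeralPort walks count candidate ports starting at FirstEphemeral plus the offset, wrapping modulo count, and PickEphemeralPortStable feeds it the caller's offset plus the per-manager hint so successive stable picks advance deterministically. A standalone sketch of the search loop (the tester stands in for testPort):

    package main

    import (
        "errors"
        "fmt"
    )

    const (
        firstEphemeral    = 16000
        numEphemeralPorts = 1<<16 - firstEphemeral
    )

    var errNoPortAvailable = errors.New("no port available")

    // pickEphemeralPort mirrors the loop above: offset picks the starting
    // point and every one of count ports is tried before giving up.
    func pickEphemeralPort(offset, count uint32, testPort func(p uint16) bool) (uint16, error) {
        for i := uint32(0); i < count; i++ {
            port := uint16(firstEphemeral + (offset+i)%count)
            if testPort(port) {
                return port, nil
            }
        }
        return 0, errNoPortAvailable
    }

    func main() {
        // Only one acceptable port; the wrap-around search still finds it.
        want := uint16(firstEphemeral + 42)
        port, err := pickEphemeralPort(12345, numEphemeralPorts, func(p uint16) bool { return p == want })
        fmt.Println(port, err) // 16042 <nil>
    }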
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
index 689401661..19f4833fc 100644
--- a/pkg/tcpip/ports/ports_test.go
+++ b/pkg/tcpip/ports/ports_test.go
@@ -15,6 +15,7 @@
package ports
import (
+ "math/rand"
"testing"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -34,6 +35,7 @@ type portReserveTestAction struct {
want *tcpip.Error
reuse bool
release bool
+ device tcpip.NICID
}
func TestPortReservation(t *testing.T) {
@@ -100,6 +102,112 @@ func TestPortReservation(t *testing.T) {
{port: 24, ip: anyIPAddress, release: true},
{port: 24, ip: anyIPAddress, reuse: false, want: nil},
},
+ }, {
+ tname: "bind twice with device fails",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 3, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 1, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 2, want: nil},
+ },
+ }, {
+ tname: "bind to device and then without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind without device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, want: nil},
+ {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind with reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ },
+ }, {
+ tname: "binding with reuse and device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "mixing reuse and not reuse by binding to device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 999, want: nil},
+ },
+ }, {
+ tname: "can't bind to 0 after mixing reuse and not reuse",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "bind and release",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse},
+ {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil},
+
+ // Release the bind to device 0 and try again.
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil},
+ },
+ }, {
+ tname: "bind twice with reuse once",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse},
+ },
+ }, {
+ tname: "release an unreserved device",
+ actions: []portReserveTestAction{
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil},
+ // The below don't exist.
+ {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true},
+ {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ // Release all.
+ {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true},
+ {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true},
+ },
},
} {
t.Run(test.tname, func(t *testing.T) {
@@ -108,12 +216,12 @@ func TestPortReservation(t *testing.T) {
for _, test := range test.actions {
if test.release {
- pm.ReleasePort(net, fakeTransNumber, test.ip, test.port)
+ pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device)
continue
}
- gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse)
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device)
if err != test.want {
- t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want)
+ t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
@@ -125,7 +233,6 @@ func TestPortReservation(t *testing.T) {
}
func TestPickEphemeralPort(t *testing.T) {
- pm := NewPortManager()
customErr := &tcpip.Error{}
for _, test := range []struct {
name string
@@ -169,9 +276,63 @@ func TestPickEphemeralPort(t *testing.T) {
},
} {
t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
}
})
}
}
+
+func TestPickEphemeralPortStable(t *testing.T) {
+ customErr := &tcpip.Error{}
+ for _, test := range []struct {
+ name string
+ f func(port uint16) (bool, *tcpip.Error)
+ wantErr *tcpip.Error
+ wantPort uint16
+ }{
+ {
+ name: "no-port-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ {
+ name: "port-tester-error",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, customErr
+ },
+ wantErr: customErr,
+ },
+ {
+ name: "only-port-16042-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port == FirstEphemeral+42 {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantPort: FirstEphemeral + 42,
+ },
+ {
+ name: "only-port-under-16000-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port < FirstEphemeral {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ pm := NewPortManager()
+ portOffset := uint32(rand.Int31n(int32(numEphemeralPorts)))
+ if port, err := pm.PickEphemeralPortStable(portOffset, test.f); port != test.wantPort || err != test.wantErr {
+ t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index f12189580..2239c1e66 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -126,7 +126,10 @@ func main() {
// Create the stack with ipv4 and tcp protocols, then add a tun-based
// NIC and ipv4 address.
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 329941775..bca73cbb1 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -111,7 +111,10 @@ func main() {
// Create the stack with ip and tcp protocols, then add a tun-based
// NIC and address.
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
mtu, err := rawfile.GetMTU(tunName)
if err != nil {
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 28c49e8ff..baf88bfab 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -36,6 +36,7 @@ go_library(
],
deps = [
"//pkg/ilist",
+ "//pkg/rand",
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
@@ -54,6 +55,7 @@ go_test(
size = "small",
srcs = [
"stack_test.go",
+ "transport_demuxer_test.go",
"transport_test.go",
],
deps = [
@@ -64,6 +66,9 @@ go_test(
"//pkg/tcpip/iptables",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/transport/udp",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go
index f8156be47..3a20839da 100644
--- a/pkg/tcpip/stack/icmp_rate_limit.go
+++ b/pkg/tcpip/stack/icmp_rate_limit.go
@@ -15,8 +15,6 @@
package stack
import (
- "sync"
-
"golang.org/x/time/rate"
)
@@ -33,54 +31,11 @@ const (
// ICMPRateLimiter is a global rate limiter that controls the generation of
// ICMP messages generated by the stack.
type ICMPRateLimiter struct {
- mu sync.RWMutex
- l *rate.Limiter
+ *rate.Limiter
}
// NewICMPRateLimiter returns a global rate limiter for controlling the rate
// at which ICMP messages are generated by the stack.
func NewICMPRateLimiter() *ICMPRateLimiter {
- return &ICMPRateLimiter{l: rate.NewLimiter(icmpLimit, icmpBurst)}
-}
-
-// Allow returns true if we are allowed to send at least 1 message at the
-// moment.
-func (i *ICMPRateLimiter) Allow() bool {
- i.mu.RLock()
- allow := i.l.Allow()
- i.mu.RUnlock()
- return allow
-}
-
-// Limit returns the maximum number of ICMP messages that can be sent in one
-// second.
-func (i *ICMPRateLimiter) Limit() rate.Limit {
- i.mu.RLock()
- defer i.mu.RUnlock()
- return i.l.Limit()
-}
-
-// SetLimit sets the maximum number of ICMP messages that can be sent in one
-// second.
-func (i *ICMPRateLimiter) SetLimit(newLimit rate.Limit) {
- i.mu.RLock()
- defer i.mu.RUnlock()
- i.l.SetLimit(newLimit)
-}
-
-// Burst returns how many ICMP messages can be sent at any single instant.
-func (i *ICMPRateLimiter) Burst() int {
- i.mu.RLock()
- defer i.mu.RUnlock()
- return i.l.Burst()
-}
-
-// SetBurst sets the maximum number of ICMP messages allowed at any single
-// instant.
-//
-// NOTE: Changing Burst causes the underlying rate limiter to be recreated.
-func (i *ICMPRateLimiter) SetBurst(burst int) {
- i.mu.Lock()
- i.l = rate.NewLimiter(i.l.Limit(), burst)
- i.mu.Unlock()
+ return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)}
}
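Because ICMPRateLimiter now embeds *rate.Limiter, the promoted methods from golang.org/x/time/rate (Allow, Limit, SetLimit, Burst, SetBurst) replace the hand-written, mutex-guarded wrappers. A small usage sketch of the resulting API:

limiter := stack.NewICMPRateLimiter()
limiter.SetLimit(1000) // ICMP messages per second (a rate.Limit)
limiter.SetBurst(100)  // maximum messages allowed at a single instant
if limiter.Allow() {
	// Generate and send the ICMP message.
}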
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index a719058b4..f6106f762 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -34,8 +34,6 @@ type NIC struct {
linkEP LinkEndpoint
loopback bool
- demux *transportDemuxer
-
mu sync.RWMutex
spoofing bool
promiscuous bool
@@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback
name: name,
linkEP: ep,
loopback: loopback,
- demux: newTransportDemuxer(stack),
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
mcastJoins: make(map[NetworkEndpointID]int32),
@@ -148,37 +145,6 @@ func (n *NIC) setSpoofing(enable bool) {
n.mu.Unlock()
}
-func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
- n.mu.RLock()
- defer n.mu.RUnlock()
-
- var r *referencedNetworkEndpoint
-
- // Check for a primary endpoint.
- if list, ok := n.primary[protocol]; ok {
- for e := list.Front(); e != nil; e = e.Next() {
- ref := e.(*referencedNetworkEndpoint)
- if ref.getKind() == permanent && ref.tryIncRef() {
- r = ref
- break
- }
- }
-
- }
-
- if r == nil {
- return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress
- }
-
- addressWithPrefix := tcpip.AddressWithPrefix{
- Address: r.ep.ID().LocalAddress,
- PrefixLen: r.ep.PrefixLen(),
- }
- r.decRef()
-
- return addressWithPrefix, nil
-}
-
// primaryEndpoint returns the primary endpoint of n for the given network
// protocol.
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
@@ -398,10 +364,12 @@ func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo
return err
}
-// Addresses returns the addresses associated with this NIC.
-func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+// AllAddresses returns all addresses (primary and non-primary) associated with
+// this NIC.
+func (n *NIC) AllAddresses() []tcpip.ProtocolAddress {
n.mu.RLock()
defer n.mu.RUnlock()
+
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ref := range n.endpoints {
// Don't include expired or temporary endpoints to avoid confusion and
@@ -421,6 +389,34 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress {
return addrs
}
+// PrimaryAddresses returns the primary addresses associated with this NIC.
+func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress {
+ n.mu.RLock()
+ defer n.mu.RUnlock()
+
+ var addrs []tcpip.ProtocolAddress
+ for proto, list := range n.primary {
+ for e := list.Front(); e != nil; e = e.Next() {
+ ref := e.(*referencedNetworkEndpoint)
+ // Don't include expired or temporary endpoints to avoid confusion and
+ // prevent the caller from using those.
+ switch ref.getKind() {
+ case permanentExpired, temporary:
+ continue
+ }
+
+ addrs = append(addrs, tcpip.ProtocolAddress{
+ Protocol: proto,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: ref.ep.ID().LocalAddress,
+ PrefixLen: ref.ep.PrefixLen(),
+ },
+ })
+ }
+ }
+ return addrs
+}
+
// AddAddressRange adds a range of addresses to n, so that it starts accepting
// packets targeted at the given addresses and network protocol. The range is
// given by a subnet address, and all addresses contained in the subnet are
@@ -708,9 +704,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
// Raw socket packets are delivered based solely on the transport
// protocol number. We do not inspect the payload to ensure it's
// validly formed.
- if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) {
- n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
- }
+ n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv)
if len(vv.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
@@ -724,9 +718,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
- if n.demux.deliverPacket(r, protocol, netHeader, vv, id) {
- return
- }
if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) {
return
}
@@ -768,10 +759,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp
}
id := TransportEndpointID{srcPort, local, dstPort, remote}
- if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
- return
- }
- if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+ if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) {
return
}
}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 88a698b18..80101d4bb 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -295,6 +295,15 @@ type LinkEndpoint interface {
// IsAttached returns whether a NetworkDispatcher is attached to the
// endpoint.
IsAttached() bool
+
+ // Wait waits for any worker goroutines owned by the endpoint to stop.
+ //
+ // For now, requesting that an endpoint's worker goroutine(s) stop is
+ // implementation specific.
+ //
+ // Wait will not block if the endpoint hasn't started any goroutines
+ // yet, even if it might later.
+ Wait()
}
// InjectableLinkEndpoint is a LinkEndpoint where inbound packets are
@@ -357,14 +366,6 @@ type LinkAddressCache interface {
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
-// TransportProtocolFactory functions are used by the stack to instantiate
-// transport protocols.
-type TransportProtocolFactory func() TransportProtocol
-
-// NetworkProtocolFactory provides methods to be used by the stack to
-// instantiate network protocols.
-type NetworkProtocolFactory func() NetworkProtocol
-
// UnassociatedEndpointFactory produces endpoints for writing packets not
// associated with a particular transport protocol. Such endpoints can be used
// to write arbitrary packets that include the IP header.
@@ -372,34 +373,6 @@ type UnassociatedEndpointFactory interface {
NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
}
-var (
- transportProtocols = make(map[string]TransportProtocolFactory)
- networkProtocols = make(map[string]NetworkProtocolFactory)
-
- unassociatedFactory UnassociatedEndpointFactory
-)
-
-// RegisterTransportProtocolFactory registers a new transport protocol factory
-// with the stack so that it becomes available to users of the stack. This
-// function is intended to be called by init() functions of the protocols.
-func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
- transportProtocols[name] = p
-}
-
-// RegisterNetworkProtocolFactory registers a new network protocol factory with
-// the stack so that it becomes available to users of the stack. This function
-// is intended to be called by init() functions of the protocols.
-func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
- networkProtocols[name] = p
-}
-
-// RegisterUnassociatedFactory registers a factory to produce endpoints not
-// associated with any particular transport protocol. This function is intended
-// to be called by init() functions of the protocols.
-func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) {
- unassociatedFactory = f
-}
-
// GSOType is the type of GSO segments.
//
// +stateify savable
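The new Wait method lets the stack join an endpoint's worker goroutines before tearing the endpoint down. A minimal sketch of one way an endpoint with a dispatch loop might satisfy it, using sync.WaitGroup; endpoint and dispatchLoop are hypothetical names, not part of this change:

type endpoint struct {
	wg sync.WaitGroup
	// ... link endpoint state ...
}

func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
	e.wg.Add(1)
	go func() {
		defer e.wg.Done()
		e.dispatchLoop(dispatcher) // reads packets until told to stop
	}()
}

// Wait implements stack.LinkEndpoint.Wait. It returns immediately if Attach
// was never called.
func (e *endpoint) Wait() {
	e.wg.Wait()
}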
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 1fe21b68e..90c2cf1be 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -17,18 +17,15 @@
//
// For consumers, the only function of interest is New(), everything else is
// provided by the tcpip/public package.
-//
-// For protocol implementers, RegisterTransportProtocolFactory() and
-// RegisterNetworkProtocolFactory() are used to register protocol factories with
-// the stack, which will then be used to instantiate protocol objects when
-// consumers interact with the stack.
package stack
import (
+ "encoding/binary"
"sync"
"time"
"golang.org/x/time/rate"
+ "gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -351,6 +348,9 @@ type Stack struct {
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+ // unassociatedFactory creates unassociated endpoints. If nil, raw
+ // endpoints are disabled. It is set during Stack creation and is
+ // immutable.
unassociatedFactory UnassociatedEndpointFactory
demux *transportDemuxer
@@ -359,10 +359,6 @@ type Stack struct {
linkAddrCache *linkAddrCache
- // raw indicates whether raw sockets may be created. It is set during
- // Stack creation and is immutable.
- raw bool
-
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
forwarding bool
@@ -394,10 +390,22 @@ type Stack struct {
// icmpRateLimiter is a global rate limiter for all ICMP messages generated
// by the stack.
icmpRateLimiter *ICMPRateLimiter
+
+ // portSeed is a one-time random value initialized at stack startup
+ // and is used to seed TCP port picking for active connections.
+ //
+ // TODO(gvisor.dev/issues/940): S/R this field.
+ portSeed uint32
}
// Options contains optional Stack configuration.
type Options struct {
+ // NetworkProtocols lists the network protocols to enable.
+ NetworkProtocols []NetworkProtocol
+
+ // TransportProtocols lists the transport protocols to enable.
+ TransportProtocols []TransportProtocol
+
// Clock is an optional clock source used for timestamping packets.
//
// If no Clock is specified, the clock source will be time.Now.
@@ -411,8 +419,9 @@ type Options struct {
// stack (false).
HandleLocal bool
- // Raw indicates whether raw sockets may be created.
- Raw bool
+ // UnassociatedFactory produces unassociated (raw) endpoints.
+ // Raw endpoints are enabled only if this is non-nil.
+ UnassociatedFactory UnassociatedEndpointFactory
}
// New allocates a new networking stack with only the requested networking and
@@ -422,7 +431,7 @@ type Options struct {
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
// are supported.
-func New(network []string, transport []string, opts Options) *Stack {
+func New(opts Options) *Stack {
clock := opts.Clock
if clock == nil {
clock = &tcpip.StdClock{}
@@ -438,17 +447,12 @@ func New(network []string, transport []string, opts Options) *Stack {
clock: clock,
stats: opts.Stats.FillIn(),
handleLocal: opts.HandleLocal,
- raw: opts.Raw,
icmpRateLimiter: NewICMPRateLimiter(),
+ portSeed: generateRandUint32(),
}
// Add specified network protocols.
- for _, name := range network {
- netProtoFactory, ok := networkProtocols[name]
- if !ok {
- continue
- }
- netProto := netProtoFactory()
+ for _, netProto := range opts.NetworkProtocols {
s.networkProtocols[netProto.Number()] = netProto
if r, ok := netProto.(LinkAddressResolver); ok {
s.linkAddrResolvers[r.LinkAddressProtocol()] = r
@@ -456,18 +460,14 @@ func New(network []string, transport []string, opts Options) *Stack {
}
// Add specified transport protocols.
- for _, name := range transport {
- transProtoFactory, ok := transportProtocols[name]
- if !ok {
- continue
- }
- transProto := transProtoFactory()
+ for _, transProto := range opts.TransportProtocols {
s.transportProtocols[transProto.Number()] = &transportProtocolState{
proto: transProto,
}
}
- s.unassociatedFactory = unassociatedFactory
+ // Add the factory for unassociated endpoints, if present.
+ s.unassociatedFactory = opts.UnassociatedFactory
// Create the global transport demuxer.
s.demux = newTransportDemuxer(s)
@@ -602,7 +602,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
// protocol. Raw endpoints receive all traffic for a given protocol regardless
// of address.
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if !s.raw {
+ if s.unassociatedFactory == nil {
return nil, tcpip.ErrNotPermitted
}
@@ -738,7 +738,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
nics[id] = NICInfo{
Name: nic.name,
LinkAddress: nic.linkEP.LinkAddress(),
- ProtocolAddresses: nic.Addresses(),
+ ProtocolAddresses: nic.PrimaryAddresses(),
Flags: flags,
MTU: nic.linkEP.MTU(),
Stats: nic.stats,
@@ -845,19 +845,37 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
return tcpip.ErrUnknownNICID
}
-// GetMainNICAddress returns the first primary address (and the subnet that
-// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
-// address if no primary addresses exist. Returns an error if the NIC doesn't
-// exist or has no endpoints.
+// AllAddresses returns a map of NICIDs to their protocol addresses (primary
+// and non-primary).
+func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress)
+ for id, nic := range s.nics {
+ nics[id] = nic.AllAddresses()
+ }
+ return nics
+}
+
+// GetMainNICAddress returns the first primary address and prefix for the given
+// NIC and protocol. It returns an error if the NIC doesn't exist, and an empty
+// value if the NIC has no primary address for the given protocol.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
- if nic, ok := s.nics[id]; ok {
- return nic.getMainNICAddress(protocol)
+ nic, ok := s.nics[id]
+ if !ok {
+ return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
}
- return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
+ for _, a := range nic.PrimaryAddresses() {
+ if a.Protocol == protocol {
+ return a.AddressWithPrefix, nil
+ }
+ }
+ return tcpip.AddressWithPrefix{}, nil
}
func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) {
@@ -1024,73 +1042,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort)
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterEndpoint(netProtos, protocol, id, ep)
- }
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
+ s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
}
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
- if nicID == 0 {
- return s.demux.registerRawEndpoint(netProto, transProto, ep)
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic == nil {
- return tcpip.ErrUnknownNICID
- }
-
- return nic.demux.registerRawEndpoint(netProto, transProto, ep)
+ return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
// UnregisterRawTransportEndpoint removes the endpoint for the transport
// protocol from the stack transport dispatcher.
func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
- if nicID == 0 {
- s.demux.unregisterRawEndpoint(netProto, transProto, ep)
- return
- }
-
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- nic := s.nics[nicID]
- if nic != nil {
- nic.demux.unregisterRawEndpoint(netProto, transProto, ep)
- }
+ s.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
// RegisterRestoredEndpoint records e as an endpoint that has been restored on
@@ -1234,3 +1206,19 @@ func (s *Stack) SetICMPBurst(burst int) {
func (s *Stack) AllowICMPMessage() bool {
return s.icmpRateLimiter.Allow()
}
+
+// PortSeed returns a 32-bit value that can be used as a seed value for port
+// picking.
+//
+// NOTE: The seed is generated once during stack initialization only.
+func (s *Stack) PortSeed() uint32 {
+ return s.portSeed
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint32(b)
+}
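With GetMainNICAddress rebuilt on top of PrimaryAddresses and the new AllAddresses accessor, callers now choose explicitly between the primary-only and complete address views. A brief usage sketch against an already configured stack s with NIC 1 (printing is illustrative only):

// Every address on every NIC, primary and non-primary.
for nicID, addrs := range s.AllAddresses() {
	for _, a := range addrs {
		fmt.Printf("NIC %d: proto %d addr %s/%d\n",
			nicID, a.Protocol, a.AddressWithPrefix.Address, a.AddressWithPrefix.PrefixLen)
	}
}

// First primary IPv4 address; a zero value now means the NIC simply has no
// primary address for that protocol, rather than being an error.
addr, err := s.GetMainNICAddress(1, ipv4.ProtocolNumber)
if err != nil {
	log.Fatalf("GetMainNICAddress failed: %v", err) // NIC 1 doesn't exist
}
if addr == (tcpip.AddressWithPrefix{}) {
	fmt.Println("no primary IPv4 address")
}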
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 0c26c9911..d2dede8a9 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -222,11 +222,17 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
}
}
+func fakeNetFactory() stack.NetworkProtocol {
+ return &fakeNetworkProtocol{}
+}
+
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
ep := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
}
@@ -370,7 +376,9 @@ func TestNetworkSend(t *testing.T) {
// address: 1. The route table sends all packets through the only
// existing nic.
ep := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("NewNIC failed:", err)
}
@@ -395,7 +403,9 @@ func TestNetworkSendMultiRoute(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
@@ -476,7 +486,9 @@ func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
@@ -554,7 +566,9 @@ func TestAddressRemoval(t *testing.T) {
localAddr := tcpip.Address([]byte{localAddrByte})
remoteAddr := tcpip.Address("\x02")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -599,7 +613,9 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) {
localAddr := tcpip.Address([]byte{localAddrByte})
remoteAddr := tcpip.Address("\x02")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -688,7 +704,9 @@ func TestEndpointExpiration(t *testing.T) {
for _, promiscuous := range []bool{true, false} {
for _, spoofing := range []bool{true, false} {
t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicid, ep); err != nil {
@@ -844,7 +862,9 @@ func TestEndpointExpiration(t *testing.T) {
}
func TestPromiscuousMode(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -894,7 +914,9 @@ func TestSpoofingWithAddress(t *testing.T) {
nonExistentLocalAddr := tcpip.Address("\x02")
dstAddr := tcpip.Address("\x03")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -930,10 +952,10 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("FindRoute failed:", err)
}
if r.LocalAddress != nonExistentLocalAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
}
// Sending a packet works.
testSendTo(t, s, dstAddr, ep, nil)
@@ -945,10 +967,10 @@ func TestSpoofingWithAddress(t *testing.T) {
t.Fatal("FindRoute failed:", err)
}
if r.LocalAddress != localAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
}
// Sending a packet using the route works.
testSend(t, r, ep, nil)
@@ -958,7 +980,9 @@ func TestSpoofingNoAddress(t *testing.T) {
nonExistentLocalAddr := tcpip.Address("\x01")
dstAddr := tcpip.Address("\x02")
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -992,10 +1016,10 @@ func TestSpoofingNoAddress(t *testing.T) {
t.Fatal("FindRoute failed:", err)
}
if r.LocalAddress != nonExistentLocalAddr {
- t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr)
+ t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr)
}
if r.RemoteAddress != dstAddr {
- t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+ t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr)
}
// Sending a packet works.
// FIXME(b/139841518): Spoofing doesn't work if there is no primary address.
@@ -1003,7 +1027,9 @@ func TestSpoofingNoAddress(t *testing.T) {
}
func TestBroadcastNeedsNoRoute(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1074,7 +1100,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
{"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"},
} {
t.Run(tc.name, func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1130,7 +1158,9 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) {
// Add a range of addresses, then check that a packet is delivered.
func TestAddressRangeAcceptsMatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1196,7 +1226,9 @@ func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, sub
// existent.
func TestCheckLocalAddressForSubnet(t *testing.T) {
const nicID tcpip.NICID = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicID, ep); err != nil {
@@ -1234,7 +1266,9 @@ func TestCheckLocalAddressForSubnet(t *testing.T) {
// Set a range of addresses, then send a packet to a destination outside the
// range and then check it doesn't get delivered.
func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
@@ -1266,7 +1300,10 @@ func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) {
}
func TestNetworkOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{},
+ })
// Try an unsupported network protocol.
if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -1319,7 +1356,9 @@ func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.S
}
func TestAddresRangeAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1360,7 +1399,9 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) {
for never := 0; never < 3; never++ {
t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1400,20 +1441,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
// address/prefixLen matches what we added.
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
if len(primaryAddrAdded) == 0 {
- // No primary addresses present, expect an error.
- if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress)
+ // No primary addresses present.
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr)
}
} else {
- // At least one primary address was added, expect a valid
- // address and prefixLen.
- gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok {
- t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded)
+ // At least one primary address was added, verify the returned
+ // address is in the list of primary addresses we added.
+ if _, ok := primaryAddrAdded[gotAddr]; !ok {
+ t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded)
}
}
})
@@ -1425,7 +1466,9 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
}
func TestGetMainNICAddressAddRemove(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1452,19 +1495,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
+ gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
t.Fatal("GetMainNICAddress failed:", err)
- } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix)
+ }
+ if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr {
+ t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
}
if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
t.Fatal("RemoveAddress failed:", err)
}
- // Check that we get an error after removal.
- if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
- t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress)
+ // Check that we get no address after removal.
+ gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber)
+ if err != nil {
+ t.Fatal("GetMainNICAddress failed:", err)
+ }
+ if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
+ t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
}
})
}
@@ -1479,8 +1528,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address {
}
func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) {
+ t.Helper()
+
if len(gotAddresses) != len(expectedAddresses) {
- t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses))
+ t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses))
}
sort.Slice(gotAddresses, func(i, j int) bool {
@@ -1500,7 +1551,9 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto
func TestAddAddress(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1519,13 +1572,15 @@ func TestAddAddress(t *testing.T) {
})
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddress(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1551,13 +1606,15 @@ func TestAddProtocolAddress(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddAddressWithOptions(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1580,13 +1637,15 @@ func TestAddAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestAddProtocolAddressWithOptions(t *testing.T) {
const nicid = 1
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(nicid, ep); err != nil {
t.Fatal("CreateNIC failed:", err)
@@ -1615,12 +1674,14 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
}
}
- gotAddresses := s.NICInfo()[nicid].ProtocolAddresses
+ gotAddresses := s.AllAddresses()[nicid]
verifyAddresses(t, expectedAddresses, gotAddresses)
}
func TestNICStats(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
ep1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, ep1); err != nil {
t.Fatal("CreateNIC failed: ", err)
@@ -1666,7 +1727,9 @@ func TestNICStats(t *testing.T) {
func TestNICForwarding(t *testing.T) {
// Create a stack with the fake network protocol, two NICs, each with
// an address.
- s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ })
s.SetForwarding(true)
ep1 := channel.New(10, defaultMTU, "")
@@ -1714,9 +1777,3 @@ func TestNICForwarding(t *testing.T) {
t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want)
}
}
-
-func init() {
- stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
- return &fakeNetworkProtocol{}
- })
-}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index cf8a6d129..8c768c299 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -35,25 +35,109 @@ type protocolIDs struct {
type transportEndpoints struct {
// mu protects all fields of the transportEndpoints.
mu sync.RWMutex
- endpoints map[TransportEndpointID]TransportEndpoint
+ endpoints map[TransportEndpointID]*endpointsByNic
// rawEndpoints contains endpoints for raw sockets, which receive all
// traffic of a given protocol regardless of port.
rawEndpoints []RawTransportEndpoint
}
+type endpointsByNic struct {
+ mu sync.RWMutex
+ endpoints map[tcpip.NICID]*multiPortEndpoint
+ // seed is a random secret for a jenkins hash.
+ seed uint32
+}
+
+// handlePacket is called by the stack when new packets arrive at this
+// transport endpoint.
+func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+
+ mpep, ok := epsByNic.endpoints[r.ref.nic.ID()]
+ if !ok {
+ if mpep, ok = epsByNic.endpoints[0]; !ok {
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+ }
+
+ // If this is a broadcast or multicast datagram, deliver the datagram to all
+ // endpoints bound to the right device.
+ if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
+ mpep.handlePacketAll(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+ return
+ }
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv)
+ epsByNic.mu.RUnlock() // Don't use defer for performance reasons.
+}
+
+// handleControlPacket is called by the stack when a control packet arrives at
+// this transport endpoint.
+func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
+ epsByNic.mu.RLock()
+ defer epsByNic.mu.RUnlock()
+
+ mpep, ok := epsByNic.endpoints[n.ID()]
+ if !ok {
+ mpep, ok = epsByNic.endpoints[0]
+ }
+ if !ok {
+ return
+ }
+
+ // TODO(eyalsoha): Why don't we look at id to see if this packet needs to
+ // be broadcast like we do in handlePacket above?
+
+ // multiPortEndpoints are guaranteed to have at least one element.
+ selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv)
+}
+
+// registerEndpoint registers t on the multiPortEndpoint for bindToDevice,
+// creating the multiPortEndpoint if necessary. It returns tcpip.ErrPortInUse
+// if the existing binding does not allow another endpoint to be registered.
+func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+
+ if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok {
+ // There was already a bind.
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+ }
+
+ // This is a new binding.
+ multiPortEp := &multiPortEndpoint{}
+ multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
+ multiPortEp.reuse = reusePort
+ epsByNic.endpoints[bindToDevice] = multiPortEp
+ return multiPortEp.singleRegisterEndpoint(t, reusePort)
+}
+
+// unregisterEndpoint returns true if endpointsByNic has to be unregistered.
+func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool {
+ epsByNic.mu.Lock()
+ defer epsByNic.mu.Unlock()
+ multiPortEp, ok := epsByNic.endpoints[bindToDevice]
+ if !ok {
+ return false
+ }
+ if multiPortEp.unregisterEndpoint(t) {
+ delete(epsByNic.endpoints, bindToDevice)
+ }
+ return len(epsByNic.endpoints) == 0
+}
+
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) {
+func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
eps.mu.Lock()
defer eps.mu.Unlock()
- e, ok := eps.endpoints[id]
+ epsByNic, ok := eps.endpoints[id]
if !ok {
return
}
- if multiPortEp, ok := e.(*multiPortEndpoint); ok {
- if !multiPortEp.unregisterEndpoint(ep) {
- return
- }
+ if !epsByNic.unregisterEndpoint(bindToDevice, ep) {
+ return
}
delete(eps.endpoints, id)
}
@@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{
- endpoints: make(map[TransportEndpointID]TransportEndpoint),
+ endpoints: make(map[TransportEndpointID]*endpointsByNic),
}
}
}
@@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer {
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
-func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
for i, n := range netProtos {
- if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil {
- d.unregisterEndpoint(netProtos[:i], protocol, id, ep)
+ if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil {
+ d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice)
return err
}
}
@@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum
}
// multiPortEndpoint is a container for TransportEndpoints which are bound to
-// the same pair of address and port.
+// the same pair of address and port. endpointsArr always has at least one
+// element.
type multiPortEndpoint struct {
mu sync.RWMutex
endpointsArr []TransportEndpoint
endpointsMap map[TransportEndpoint]int
- // seed is a random secret for a jenkins hash.
- seed uint32
+ // reuse indicates if more than one endpoint is allowed.
+ reuse bool
}
// reciprocalScale scales a value into range [0, n).
@@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 {
// selectEndpoint calculates a hash of destination and source addresses and
// ports then uses it to select a socket. In this case, all packets from one
// address will be sent to same endpoint.
-func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint {
- ep.mu.RLock()
- defer ep.mu.RUnlock()
+func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint {
+ if len(mpep.endpointsArr) == 1 {
+ return mpep.endpointsArr[0]
+ }
payload := []byte{
byte(id.LocalPort),
@@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd
byte(id.RemotePort >> 8),
}
- h := jenkins.Sum32(ep.seed)
+ h := jenkins.Sum32(seed)
h.Write(payload)
h.Write([]byte(id.LocalAddress))
h.Write([]byte(id.RemoteAddress))
hash := h.Sum32()
- idx := reciprocalScale(hash, uint32(len(ep.endpointsArr)))
- return ep.endpointsArr[idx]
+ idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr)))
+ return mpep.endpointsArr[idx]
}
-// HandlePacket is called by the stack when new packets arrive to this transport
-// endpoint.
-func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
- // If this is a broadcast or multicast datagram, deliver the datagram to all
- // endpoints managed by ep.
- if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) {
- for i, endpoint := range ep.endpointsArr {
- // HandlePacket modifies vv, so each endpoint needs its own copy.
- if i == len(ep.endpointsArr)-1 {
- endpoint.HandlePacket(r, id, vv)
- break
- }
- vvCopy := buffer.NewView(vv.Size())
- copy(vvCopy, vv.ToView())
- endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
+func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) {
+ ep.mu.RLock()
+ for i, endpoint := range ep.endpointsArr {
+ // HandlePacket modifies vv, so each endpoint needs its own copy except for
+ // the final one.
+ if i == len(ep.endpointsArr)-1 {
+ endpoint.HandlePacket(r, id, vv)
+ break
}
- } else {
- ep.selectEndpoint(id).HandlePacket(r, id, vv)
+ vvCopy := buffer.NewView(vv.Size())
+ copy(vvCopy, vv.ToView())
+ endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView())
}
+ ep.mu.RUnlock() // Don't use defer for performance reasons.
}
-// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
-func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) {
- ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv)
-}
-
-func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) {
+// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint
+// list. The list may already contain endpoints, or be empty for a new binding.
+func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- // A new endpoint is added into endpointsArr and its index there is
- // saved in endpointsMap. This will allows to remove endpoint from
- // the array fast.
+ if len(ep.endpointsArr) > 0 {
+ // If it was previously bound, we need to check if we can bind again.
+ if !ep.reuse || !reusePort {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ // A new endpoint is added to endpointsArr and its index is saved in
+ // endpointsMap, which lets us remove the endpoint from the array quickly.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+ return nil
}
// unregisterEndpoint returns true if multiPortEndpoint has to be unregistered.
@@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool {
return true
}
-func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error {
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
if id.RemotePort != 0 {
+ // TODO(eyalsoha): Why?
reusePort = false
}
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
- return nil
+ return tcpip.ErrUnknownProtocol
}
eps.mu.Lock()
defer eps.mu.Unlock()
- var multiPortEp *multiPortEndpoint
- if _, ok := eps.endpoints[id]; ok {
- if !reusePort {
- return tcpip.ErrPortInUse
- }
- multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint)
- if !ok {
- return tcpip.ErrPortInUse
- }
+ if epsByNic, ok := eps.endpoints[id]; ok {
+ // There was already a binding.
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
- if reusePort {
- if multiPortEp == nil {
- multiPortEp = &multiPortEndpoint{}
- multiPortEp.endpointsMap = make(map[TransportEndpoint]int)
- multiPortEp.seed = rand.Uint32()
- eps.endpoints[id] = multiPortEp
- }
-
- multiPortEp.singleRegisterEndpoint(ep)
-
- return nil
+ // This is a new binding.
+ epsByNic := &endpointsByNic{
+ endpoints: make(map[tcpip.NICID]*multiPortEndpoint),
+ seed: rand.Uint32(),
}
- eps.endpoints[id] = ep
+ eps.endpoints[id] = epsByNic
- return nil
+ return epsByNic.registerEndpoint(ep, reusePort, bindToDevice)
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
-func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) {
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
- eps.unregisterEndpoint(id, ep)
+ eps.unregisterEndpoint(id, ep, bindToDevice)
}
}
}
@@ -273,7 +346,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// If the packet is a broadcast, then find all matching transport endpoints.
// Otherwise, try to find a single matching transport endpoint.
- destEps := make([]TransportEndpoint, 0, 1)
+ destEps := make([]*endpointsByNic, 0, 1)
eps.mu.RLock()
if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast {
@@ -299,7 +372,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
// Deliver the packet.
for _, ep := range destEps {
- ep.HandlePacket(r, id, vv)
+ ep.handlePacket(r, id, vv)
}
return true
@@ -331,7 +404,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
-func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
+func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
@@ -348,12 +421,12 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
}
// Deliver the packet.
- ep.HandleControlPacket(id, typ, extra, vv)
+ ep.handleControlPacket(n, id, typ, extra, vv)
return true
}
-func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic {
// Try to find a match with the id as provided.
if ep, ok := eps.endpoints[id]; ok {
return ep
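The demuxer now keys endpoints first by TransportEndpointID and then by bound device, which is how SO_REUSEPORT and SO_BINDTODEVICE compose: packets prefer endpoints bound to the receiving NIC and fall back to endpoints bound to no device (NICID 0). A hedged sketch of how a transport endpoint might register through the new Stack signature shown above; the e.* fields are hypothetical stand-ins for whatever state the endpoint keeps:

// bindToDevice = 0 keeps the endpoint reachable from any device; a non-zero
// NICID restricts delivery to packets received on that NIC.
err := e.stack.RegisterTransportEndpoint(
	0,                    // nicID no longer selects a per-NIC demuxer
	e.effectiveNetProtos, // e.g. []tcpip.NetworkProtocolNumber{ipv4.ProtocolNumber}
	udp.ProtocolNumber,
	e.id,           // TransportEndpointID holding the bound address/port
	e,              // the TransportEndpoint itself
	e.reusePort,    // SO_REUSEPORT
	e.bindToDevice, // NICID from SO_BINDTODEVICE, 0 if unset
)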
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
new file mode 100644
index 000000000..210233dc0
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -0,0 +1,352 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package stack_test
+
+import (
+ "math"
+ "math/rand"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ stackAddr = "\x0a\x00\x00\x01"
+ stackPort = 1234
+ testPort = 4096
+)
+
+type testContext struct {
+ t *testing.T
+ linkEPs map[string]*channel.Endpoint
+ s *stack.Stack
+
+ ep tcpip.Endpoint
+ wq waiter.Queue
+}
+
+func (c *testContext) cleanup() {
+ if c.ep != nil {
+ c.ep.Close()
+ }
+}
+
+func (c *testContext) createV6Endpoint(v6only bool) {
+ var err *tcpip.Error
+ c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ var v tcpip.V6OnlyOption
+ if v6only {
+ v = 1
+ }
+ if err := c.ep.SetSockOpt(v); err != nil {
+ c.t.Fatalf("SetSockOpt failed: %v", err)
+ }
+}
+
+// newDualTestContextMultiNic creates the testing context and one NIC named
+// after each entry in linkEpNames.
+func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ linkEPs := make(map[string]*channel.Endpoint)
+ for i, linkEpName := range linkEpNames {
+ channelEP := channel.New(256, mtu, "")
+ nicid := tcpip.NICID(i + 1)
+ if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ linkEPs[linkEpName] = channelEP
+
+ if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress IPv4 failed: %v", err)
+ }
+
+ if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ t.Fatalf("AddAddress IPv6 failed: %v", err)
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: 1,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: 1,
+ },
+ })
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEPs: linkEPs,
+ }
+}
+
+type headers struct {
+ srcPort uint16
+ dstPort uint16
+}
+
+func newPayload() []byte {
+ b := make([]byte, 30+rand.Intn(100))
+ for i := range b {
+ b[i] = byte(rand.Intn(256))
+ }
+ return b
+}
+
+func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) {
+ // Allocate a buffer for data and headers.
+ buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+ copy(buf[len(buf)-len(payload):], payload)
+
+ // Initialize the IP header.
+ ip := header.IPv6(buf)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+ NextHeader: uint8(udp.ProtocolNumber),
+ HopLimit: 65,
+ SrcAddr: testV6Addr,
+ DstAddr: stackV6Addr,
+ })
+
+ // Initialize the UDP header.
+ u := header.UDP(buf[header.IPv6MinimumSize:])
+ u.Encode(&header.UDPFields{
+ SrcPort: h.srcPort,
+ DstPort: h.dstPort,
+ Length: uint16(header.UDPMinimumSize + len(payload)),
+ })
+
+ // Calculate the UDP pseudo-header checksum.
+ xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u)))
+
+ // Calculate the UDP checksum and set it.
+ xsum = header.Checksum(payload, xsum)
+ u.SetChecksum(^u.CalculateChecksum(xsum))
+
+ // Inject packet.
+ c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+}
+
+func TestTransportDemuxerRegister(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ want *tcpip.Error
+ }{
+ {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol},
+ {"success", ipv4.ProtocolNumber, nil},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+ if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want {
+ t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want)
+ }
+ })
+ }
+}
+
+// TestDistribution injects varied packets on input devices and checks that
+// the distribution of packets received matches expectations.
+func TestDistribution(t *testing.T) {
+ type endpointSockopts struct {
+ reuse int
+ bindToDevice string
+ }
+ for _, test := range []struct {
+ name string
+ // endpoints will receive the injected packets.
+ endpoints []endpointSockopts
+ // wantedDistributions maps each NIC on which packets are injected to
+ // the wanted ratio of packets received on each endpoint.
+ wantedDistributions map[string][]float64
+ }{
+ {
+ "BindPortReuse",
+ // 5 endpoints that all have reuse set.
+ []endpointSockopts{
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed evenly.
+ "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2},
+ },
+ },
+ {
+ "BindToDevice",
+ // 3 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{0, "dev0"},
+ endpointSockopts{0, "dev1"},
+ endpointSockopts{0, "dev2"},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 go only to the endpoint bound to dev0.
+ "dev0": []float64{1, 0, 0},
+ // Injected packets on dev1 go only to the endpoint bound to dev1.
+ "dev1": []float64{0, 1, 0},
+ // Injected packets on dev2 go only to the endpoint bound to dev2.
+ "dev2": []float64{0, 0, 1},
+ },
+ },
+ {
+ "ReuseAndBindToDevice",
+ // 6 endpoints with various bindings.
+ []endpointSockopts{
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev0"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, "dev1"},
+ endpointSockopts{1, ""},
+ },
+ map[string][]float64{
+ // Injected packets on dev0 get distributed among endpoints bound to
+ // dev0.
+ "dev0": []float64{0.5, 0.5, 0, 0, 0, 0},
+ // Injected packets on dev1 get distributed among endpoints bound to
+ // dev1 or unbound.
+ "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0},
+ // Injected packets on dev999 go only to the unbound endpoint.
+ "dev999": []float64{0, 0, 0, 0, 0, 1},
+ },
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ for device, wantedDistribution := range test.wantedDistributions {
+ t.Run(device, func(t *testing.T) {
+ var devices []string
+ for d := range test.wantedDistributions {
+ devices = append(devices, d)
+ }
+ c := newDualTestContextMultiNic(t, defaultMTU, devices)
+ defer c.cleanup()
+
+ c.createV6Endpoint(false)
+
+ eps := make(map[tcpip.Endpoint]int)
+
+ pollChannel := make(chan tcpip.Endpoint)
+ for i, endpoint := range test.endpoints {
+ // Register to be notified when data arrives on this endpoint.
+ wq := waiter.Queue{}
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ // Defers run LIFO: unregister the waiter first, then close ch, so the
+ // waiter never sends on a closed channel.
+ defer close(ch)
+ defer wq.EventUnregister(&we)
+
+ var err *tcpip.Error
+ ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
+ if err != nil {
+ c.t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ eps[ep] = i
+
+ go func(ep tcpip.Endpoint) {
+ for range ch {
+ pollChannel <- ep
+ }
+ }(ep)
+
+ defer ep.Close()
+ reusePortOption := tcpip.ReusePortOption(endpoint.reuse)
+ if err := ep.SetSockOpt(reusePortOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err)
+ }
+ bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice)
+ if err := ep.SetSockOpt(bindToDeviceOption); err != nil {
+ c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
+ t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err)
+ }
+ }
+
+ npackets := 100000
+ nports := 10000
+ if got, want := len(test.endpoints), len(wantedDistribution); got != want {
+ t.Fatalf("got len(test.endpoints) = %d, want %d", got, want)
+ }
+ ports := make(map[uint16]tcpip.Endpoint)
+ stats := make(map[tcpip.Endpoint]int)
+ for i := 0; i < npackets; i++ {
+ // Send a packet.
+ port := uint16(i % nports)
+ payload := newPayload()
+ c.sendV6Packet(payload,
+ &headers{
+ srcPort: testPort + port,
+ dstPort: stackPort},
+ device)
+
+ var addr tcpip.FullAddress
+ ep := <-pollChannel
+ _, _, err := ep.Read(&addr)
+ if err != nil {
+ c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err)
+ }
+ stats[ep]++
+ if i < nports {
+ ports[uint16(i)] = ep
+ } else {
+ // Check that all packets from one client are handled by the same
+ // socket.
+ if want, got := ports[port], ep; want != got {
+ t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got])
+ }
+ }
+ }
+
+ // Check that the packet distribution is as expected.
+ for ep, i := range eps {
+ wantedRatio := wantedDistribution[i]
+ wantedRecv := wantedRatio * float64(npackets)
+ actualRecv := stats[ep]
+ actualRatio := float64(stats[ep]) / float64(npackets)
+ // The actual ratio must be within 5 percentage points of the wanted ratio.
+ if math.Abs(actualRatio-wantedRatio) > 0.05 {
+ t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets)
+ }
+ }
+ })
+ }
+ })
+ }
+}
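
For reference, a minimal sketch of the two knobs TestDistribution exercises (port sharing via tcpip.ReusePortOption and device filtering via tcpip.BindToDeviceOption), written against the helpers defined in this test file. It is illustrative only; defaultMTU, stackV6Addr and stackPort are assumed to be the identifiers defined above, and "dev0" is the first NIC name passed to newDualTestContextMultiNic.

// Sketch only, not part of the change: binds one UDP endpoint that shares
// its port but only accepts packets arriving on "dev0".
func exampleReusePortAndBindToDevice(t *testing.T) {
    c := newDualTestContextMultiNic(t, defaultMTU, []string{"dev0", "dev1"})
    defer c.cleanup()

    var wq waiter.Queue
    ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
    if err != nil {
        t.Fatalf("NewEndpoint failed: %v", err)
    }
    defer ep.Close()

    // Allow the port to be shared with other endpoints...
    if err := ep.SetSockOpt(tcpip.ReusePortOption(1)); err != nil {
        t.Fatalf("SetSockOpt(ReusePortOption) failed: %v", err)
    }
    // ...but only accept packets that arrive on dev0.
    if err := ep.SetSockOpt(tcpip.BindToDeviceOption("dev0")); err != nil {
        t.Fatalf("SetSockOpt(BindToDeviceOption) failed: %v", err)
    }
    if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
        t.Fatalf("Bind failed: %v", err)
    }
}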
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 847d02982..842a16277 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
+// SetSockOptInt sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrUnknownProtocolOption
@@ -122,7 +127,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.id.RemoteAddress = addr.Addr
- err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false)
+ err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false /* reuse */, 0 /* bindToDevice */)
if err != nil {
return err
}
@@ -163,7 +168,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
f,
- false,
+ false, /* reuse */
+ 0, /* bindToDevice */
); err != nil {
return err
}
@@ -277,9 +283,16 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
}
}
+func fakeTransFactory() stack.TransportProtocol {
+ return &fakeTransportProtocol{}
+}
+
func TestTransportReceive(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -341,7 +354,10 @@ func TestTransportReceive(t *testing.T) {
func TestTransportControlReceive(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -409,7 +425,10 @@ func TestTransportControlReceive(t *testing.T) {
func TestTransportSend(t *testing.T) {
linkEP := channel.New(10, defaultMTU, "")
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
if err := s.CreateNIC(1, linkEP); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -452,7 +471,10 @@ func TestTransportSend(t *testing.T) {
}
func TestTransportOptions(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
// Try an unsupported transport protocol.
if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
@@ -493,7 +515,10 @@ func TestTransportOptions(t *testing.T) {
}
func TestTransportForwarding(t *testing.T) {
- s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
+ TransportProtocols: []stack.TransportProtocol{fakeTransFactory()},
+ })
s.SetForwarding(true)
// TODO(b/123449044): Change this to a channel NIC.
@@ -571,9 +596,3 @@ func TestTransportForwarding(t *testing.T) {
t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
}
}
-
-func init() {
- stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol {
- return &fakeTransportProtocol{}
- })
-}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 2534069ab..faaa4a4e3 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -401,6 +401,10 @@ type Endpoint interface {
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
+ // SetSockOptInt sets a socket option for simple cases where the value
+ // has the int type.
+ SetSockOptInt(opt SockOpt, v int) *Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
@@ -446,10 +450,22 @@ type WriteOptions struct {
type SockOpt int
const (
- // ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
- // unread bytes in the input buffer should be returned.
+ // ReceiveQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the input buffer should be returned.
ReceiveQueueSizeOption SockOpt = iota
+ // SendBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the send buffer size option.
+ SendBufferSizeOption
+
+ // ReceiveBufferSizeOption is used by SetSockOptInt/GetSockOptInt to
+ // specify the receive buffer size option.
+ ReceiveBufferSizeOption
+
+ // SendQueueSizeOption is used in GetSockOptInt to specify that the
+ // number of unread bytes in the output buffer should be returned.
+ SendQueueSizeOption
+
// TODO(b/137664753): convert all int socket options to be handled via
// GetSockOptInt.
)
@@ -458,18 +474,6 @@ const (
// the endpoint should be cleared and returned.
type ErrorOption struct{}
-// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
-// buffer size option.
-type SendBufferSizeOption int
-
-// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
-// receive buffer size option.
-type ReceiveBufferSizeOption int
-
-// SendQueueSizeOption is used in GetSockOpt to specify that the number of
-// unread bytes in the output buffer should be returned.
-type SendQueueSizeOption int
-
// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
type V6OnlyOption int
@@ -491,6 +495,10 @@ type ReuseAddressOption int
// to be bound to an identical socket address.
type ReusePortOption int
+// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
+// should bind only on a specific NIC.
+type BindToDeviceOption string
+
// QuickAckOption is stubbed out in SetSockOpt/GetSockOpt.
type QuickAckOption int
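
The tcpip.go hunk above replaces the typed SendBufferSizeOption, ReceiveBufferSizeOption and SendQueueSizeOption values with SockOpt constants handled by SetSockOptInt/GetSockOptInt, and adds the string-typed BindToDeviceOption. A hedged sketch of the resulting call pattern; the endpoint ep, the buffer sizes and the device name "eth0" are placeholders, not values from this change:

// Illustrative sketch of the new int-option surface.
package example

import (
    "fmt"

    "gvisor.dev/gvisor/pkg/tcpip"
)

func tuneEndpoint(ep tcpip.Endpoint) error {
    // Buffer sizes are now plain ints set through SetSockOptInt instead of
    // the removed SendBufferSizeOption/ReceiveBufferSizeOption types.
    if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1<<20); err != nil {
        return fmt.Errorf("SetSockOptInt(ReceiveBufferSizeOption): %v", err)
    }
    if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1<<20); err != nil {
        return fmt.Errorf("SetSockOptInt(SendBufferSizeOption): %v", err)
    }

    // Reads of int-valued options go through GetSockOptInt; endpoints may
    // clamp the requested sizes to their configured min/max.
    if v, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption); err == nil && v < 1<<20 {
        return fmt.Errorf("receive buffer clamped to %d", v)
    }

    // BindToDeviceOption stays on the SetSockOpt/GetSockOpt path because its
    // value is a device name string, not an int.
    if err := ep.SetSockOpt(tcpip.BindToDeviceOption("eth0")); err != nil {
        return fmt.Errorf("SetSockOpt(BindToDeviceOption): %v", err)
    }
    return nil
}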
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 3db060384..a3a910d41 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -104,7 +104,7 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e, 0 /* bindToDevice */)
}
// Close the receive list and drain it.
@@ -319,6 +319,11 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+// SetSockOptInt sets a socket option. Currently not supported; the value is
+// accepted and ignored.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
@@ -331,6 +336,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -341,18 +358,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -538,14 +543,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */)
switch err {
case nil:
return true, nil
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
index 1eb790932..bfb16f7c3 100644
--- a/pkg/tcpip/transport/icmp/protocol.go
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -14,10 +14,9 @@
// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
// protocols for use in ping. To use it in the networking stack, this package
-// must be added to the project, and
-// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or
-// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when
-// calling stack.New(). Then endpoints can be created by passing
+// must be added to the project, and activated on the stack by passing
+// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport
+// protocols when calling stack.New(). Then endpoints can be created by passing
// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number
// when calling Stack.NewEndpoint().
package icmp
@@ -34,15 +33,9 @@ import (
)
const (
- // ProtocolName4 is the string representation of the icmp protocol name.
- ProtocolName4 = "icmp4"
-
// ProtocolNumber4 is the ICMP protocol number.
ProtocolNumber4 = header.ICMPv4ProtocolNumber
- // ProtocolName6 is the string representation of the icmp protocol name.
- ProtocolName6 = "icmp6"
-
// ProtocolNumber6 is the IPv6-ICMP protocol number.
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
@@ -125,12 +118,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber4}
- })
+// NewProtocol4 returns an ICMPv4 transport protocol.
+func NewProtocol4() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+}
- stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
- return &protocol{ProtocolNumber6}
- })
+// NewProtocol6 returns an ICMPv6 transport protocol.
+func NewProtocol6() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
}
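
As the updated package comment describes, transport and network protocols are now handed to stack.New as factory values instead of being registered by name from init. A rough sketch of a ping-capable stack built this way, using only constructors and fields that appear elsewhere in this diff:

// Illustrative only: builds a minimal IPv4 stack with the ICMPv4 transport
// and opens a ping endpoint on it.
package example

import (
    "gvisor.dev/gvisor/pkg/tcpip"
    "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    "gvisor.dev/gvisor/pkg/tcpip/stack"
    "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    "gvisor.dev/gvisor/pkg/waiter"
)

func newPingEndpoint() (tcpip.Endpoint, *tcpip.Error) {
    // Protocols are passed in as values; the old init()-time name
    // registration ("icmp4"/"icmp6") is gone.
    s := stack.New(stack.Options{
        NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
        TransportProtocols: []stack.TransportProtocol{icmp.NewProtocol4()},
    })

    var wq waiter.Queue
    return s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &wq)
}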
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index cf1c5c433..a02731a5d 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -492,6 +492,11 @@ func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
@@ -504,6 +509,19 @@ func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
ep.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ v := ep.sndBufSize
+ ep.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ v := ep.rcvBufSizeMax
+ ep.rcvMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
@@ -515,18 +533,6 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- ep.mu.Lock()
- *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
- ep.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- ep.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
- ep.rcvMu.Unlock()
- return nil
-
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go
index 783c21e6b..a2512d666 100644
--- a/pkg/tcpip/transport/raw/protocol.go
+++ b/pkg/tcpip/transport/raw/protocol.go
@@ -20,13 +20,10 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-type factory struct{}
+// EndpointFactory implements stack.UnassociatedEndpointFactory.
+type EndpointFactory struct{}
// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
-func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+func (EndpointFactory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
}
-
-func init() {
- stack.RegisterUnassociatedFactory(factory{})
-}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 39a839ab7..a42e1f4a2 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -49,6 +49,7 @@ go_library(
"//pkg/sleep",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/hash/jenkins",
"//pkg/tcpip/header",
"//pkg/tcpip/iptables",
"//pkg/tcpip/seqnum",
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 0802e984e..3ae4a5426 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
n.initGSO()
// Register new endpoint so that packets are routed to it.
- if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort, n.bindToDevice); err != nil {
n.Close()
return nil, err
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index dd931f88c..f9d5e0085 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -15,6 +15,7 @@
package tcp
import (
+ "encoding/binary"
"fmt"
"math"
"strings"
@@ -26,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
@@ -280,6 +282,9 @@ type endpoint struct {
// reusePort is set to true if SO_REUSEPORT is enabled.
reusePort bool
+ // bindToDevice is set to the NIC on which to bind or disabled if 0.
+ bindToDevice tcpip.NICID
+
// delay enables Nagle's algorithm.
//
// delay is a boolean (0 is false) and must be accessed atomically.
@@ -564,11 +569,11 @@ func (e *endpoint) Close() {
// in Listen() when trying to register.
if e.state == StateListen && e.isPortReserved {
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.isRegistered = false
}
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -625,12 +630,12 @@ func (e *endpoint) cleanupLocked() {
e.workerCleanup = false
if e.isRegistered {
- e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.isRegistered = false
}
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -952,62 +957,9 @@ func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
}
-// SetSockOpt sets a socket option.
-func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.DelayOption:
- if v == 0 {
- atomic.StoreUint32(&e.delay, 0)
-
- // Handle delayed data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.delay, 1)
- }
- return nil
-
- case tcpip.CorkOption:
- if v == 0 {
- atomic.StoreUint32(&e.cork, 0)
-
- // Handle the corked data.
- e.sndWaker.Assert()
- } else {
- atomic.StoreUint32(&e.cork, 1)
- }
- return nil
-
- case tcpip.ReuseAddressOption:
- e.mu.Lock()
- e.reuseAddr = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.ReusePortOption:
- e.mu.Lock()
- e.reusePort = v != 0
- e.mu.Unlock()
- return nil
-
- case tcpip.QuickAckOption:
- if v == 0 {
- atomic.StoreUint32(&e.slowAck, 1)
- } else {
- atomic.StoreUint32(&e.slowAck, 0)
- }
- return nil
-
- case tcpip.MaxSegOption:
- userMSS := v
- if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
- return tcpip.ErrInvalidOptionValue
- }
- e.mu.Lock()
- e.userMSS = int(userMSS)
- e.mu.Unlock()
- e.notifyProtocolGoroutine(notifyMSSChanged)
- return nil
-
+// SetSockOptInt sets a socket option.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ switch opt {
case tcpip.ReceiveBufferSizeOption:
// Make sure the receive buffer size is within the min and max
// allowed.
@@ -1071,6 +1023,82 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.sndBufMu.Unlock()
return nil
+ default:
+ return nil
+ }
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.DelayOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.delay, 1)
+ }
+ return nil
+
+ case tcpip.CorkOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ return nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.reuseAddr = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
+ case tcpip.QuickAckOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.slowAck, 1)
+ } else {
+ atomic.StoreUint32(&e.slowAck, 0)
+ }
+ return nil
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.mu.Lock()
+ e.userMSS = int(userMSS)
+ e.mu.Unlock()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+ return nil
+
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -1182,6 +1210,18 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
return e.readyReceiveSize()
+ case tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ v := e.sndBufSize
+ e.sndBufMu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ v := e.rcvBufSize
+ e.rcvListMu.Unlock()
+ return v, nil
+
}
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -1204,18 +1244,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
*o = header.TCPDefaultMSS
return nil
- case *tcpip.SendBufferSizeOption:
- e.sndBufMu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.sndBufMu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvListMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
- e.rcvListMu.Unlock()
- return nil
-
case *tcpip.DelayOption:
*o = 0
if v := atomic.LoadUint32(&e.delay); v != 0 {
@@ -1252,6 +1280,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = ""
+ return nil
+
case *tcpip.QuickAckOption:
*o = 1
if v := atomic.LoadUint32(&e.slowAck); v != 0 {
@@ -1458,7 +1496,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
if e.id.LocalPort != 0 {
// The endpoint is bound to a port, attempt to register it.
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1468,17 +1506,32 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// address/port for both local and remote (otherwise this
// endpoint would be trying to connect to itself).
sameAddr := e.id.LocalAddress == e.id.RemoteAddress
- if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+
+ // Calculate a port offset based on the source IP and the
+ // destination IP/port so that, for a given tuple (srcIP, destIP,
+ // destPort), the search always starts at the same offset and the
+ // port space can be cycled through effectively.
+ h := jenkins.Sum32(e.stack.PortSeed())
+ h.Write([]byte(e.id.LocalAddress))
+ h.Write([]byte(e.id.RemoteAddress))
+ portBuf := make([]byte, 2)
+ binary.LittleEndian.PutUint16(portBuf, e.id.RemotePort)
+ h.Write(portBuf)
+ portOffset := h.Sum32()
+
+ if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) {
if sameAddr && p == e.id.RemotePort {
return false, nil
}
- if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ // reusePort is false below because connect cannot reuse a port even if
+ // reusePort was set.
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false /* reusePort */, e.bindToDevice) {
return false, nil
}
id := e.id
id.LocalPort = p
- switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) {
case nil:
e.id = id
return true, nil
@@ -1496,7 +1549,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er
// before Connect: in such a case we don't want to hold on to
// reservations anymore.
if e.isPortReserved {
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice)
e.isPortReserved = false
}
@@ -1640,7 +1693,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
}
// Register the endpoint.
- if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice); err != nil {
return err
}
@@ -1721,7 +1774,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
}
}
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice)
if err != nil {
return err
}
@@ -1731,16 +1784,16 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
e.id.LocalPort = port
// Any failures beyond this point must remove the port registration.
- defer func() {
+ defer func(bindToDevice tcpip.NICID) {
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice)
e.isPortReserved = false
e.effectiveNetProtos = nil
e.id.LocalPort = 0
e.id.LocalAddress = ""
e.boundNICID = 0
}
- }()
+ }(e.bindToDevice)
// If an address is specified, we must ensure that it's one of our
// local addresses.
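
Earlier in this file's diff, connect() derives a stable starting offset for its ephemeral port search by hashing the connection tuple with the stack-wide port seed. The sketch below restates that derivation in isolation to make the determinism property easier to see; it mirrors the hunk above and adds no new behavior:

// Rough sketch: the same (local, remote, remotePort) tuple and seed always
// produce the same offset, so repeated connects to one destination walk the
// ephemeral port range instead of probing from a fresh random point each time.
package example

import (
    "encoding/binary"

    "gvisor.dev/gvisor/pkg/tcpip"
    "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
)

func portOffset(seed uint32, local, remote tcpip.Address, remotePort uint16) uint32 {
    h := jenkins.Sum32(seed)
    h.Write([]byte(local))
    h.Write([]byte(remote))
    portBuf := make([]byte, 2)
    binary.LittleEndian.PutUint16(portBuf, remotePort)
    h.Write(portBuf)
    return h.Sum32()
}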
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2a13b2022..d5d8ab96a 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -14,7 +14,7 @@
// Package tcp contains the implementation of the TCP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// activated on the stack by passing tcp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing tcp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
@@ -34,9 +34,6 @@ import (
)
const (
- // ProtocolName is the string representation of the tcp protocol name.
- ProtocolName = "tcp"
-
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
@@ -254,13 +251,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
}
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{
- sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
- recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
- congestionControl: ccReno,
- availableCongestionControl: []string{ccReno, ccCubic},
- }
- })
+// NewProtocol returns a TCP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{MinBufferSize, DefaultSendBufferSize, MaxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{MinBufferSize, DefaultReceiveBufferSize, MaxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index 272bbcdbd..9fa97528b 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -38,7 +38,7 @@ func TestFastRecovery(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -190,7 +190,7 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -232,7 +232,7 @@ func TestCongestionAvoidance(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -336,7 +336,7 @@ func TestCubicCongestionAvoidance(t *testing.T) {
enableCUBIC(t, c)
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
@@ -445,7 +445,7 @@ func TestRetransmit(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 7
data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 32bb45224..089826a88 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -84,7 +84,7 @@ func TestConnectIncrementActiveConnection(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.ActiveConnectionOpenings.Value() + 1
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -97,7 +97,7 @@ func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
stats := c.Stack().Stats()
want := stats.TCP.FailedConnectionAttempts.Value()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want)
}
@@ -131,7 +131,7 @@ func TestTCPSegmentsSentIncrement(t *testing.T) {
stats := c.Stack().Stats()
// SYN and ACK
want := stats.TCP.SegmentsSent.Value() + 2
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if got := stats.TCP.SegmentsSent.Value(); got != want {
t.Errorf("got stats.TCP.SegmentsSent.Value() = %v, want = %v", got, want)
@@ -299,7 +299,7 @@ func TestTCPResetsReceivedIncrement(t *testing.T) {
want := stats.TCP.ResetsReceived.Value() + 1
iss := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, nil)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -323,7 +323,7 @@ func TestTCPResetsDoNotGenerateResets(t *testing.T) {
want := stats.TCP.ResetsReceived.Value() + 1
iss := seqnum.Value(789)
rcvWnd := seqnum.Size(30000)
- c.CreateConnected(iss, rcvWnd, nil)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -344,14 +344,14 @@ func TestActiveHandshake(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
}
func TestNonBlockingClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -367,7 +367,7 @@ func TestConnectResetAfterClose(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
ep := c.EP
c.EP = nil
@@ -417,7 +417,7 @@ func TestSimpleReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -465,11 +465,71 @@ func TestSimpleReceive(t *testing.T) {
)
}
+func TestConnectBindToDevice(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ device string
+ want tcp.EndpointState
+ }{
+ {"RightDevice", "nic1", tcp.StateEstablished},
+ {"WrongDevice", "nic2", tcp.StateSynSent},
+ {"AnyDevice", "", tcp.StateEstablished},
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+ bindToDevice := tcpip.BindToDeviceOption(test.device)
+ if err := c.EP.SetSockOpt(bindToDevice); err != nil {
+ t.Fatalf("SetSockOpt(%#v) failed: %v", bindToDevice, err)
+ }
+ // Start connection attempt.
+ waitEntry, _ := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ ),
+ )
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcpHdr.DestinationPort(),
+ DstPort: tcpHdr.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ TCPOpts: nil,
+ })
+
+ c.GetPacket()
+ if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want {
+ t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+ }
+ })
+ }
+}
+
func TestOutOfOrderReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -557,8 +617,7 @@ func TestOutOfOrderFlood(t *testing.T) {
defer c.Cleanup()
// Create a new connection with initial window size of 10.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -631,7 +690,7 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -700,7 +759,7 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -785,7 +844,7 @@ func TestShutdownRead(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock)
@@ -804,8 +863,7 @@ func TestFullWindowReceive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -872,11 +930,9 @@ func TestNoWindowShrinking(t *testing.T) {
defer c.Cleanup()
// Start off with a window size of 10, then shrink it to 5.
- opt := tcpip.ReceiveBufferSizeOption(10)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 10)
- opt = 5
- if err := c.EP.SetSockOpt(opt); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
t.Fatalf("SetSockOpt failed: %v", err)
}
@@ -976,7 +1032,7 @@ func TestSimpleSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1017,7 +1073,7 @@ func TestZeroWindowSend(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 0, nil)
+ c.CreateConnected(789, 0, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1075,8 +1131,7 @@ func TestScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -1110,8 +1165,7 @@ func TestNonScaledWindowConnect(t *testing.T) {
defer c.Cleanup()
// Set the window size greater than the maximum non-scaled window.
- opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
- c.CreateConnected(789, 30000, &opt)
+ c.CreateConnected(789, 30000, 65535*3)
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
@@ -1151,7 +1205,7 @@ func TestScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1224,7 +1278,7 @@ func TestNonScaledWindowAccept(t *testing.T) {
defer ep.Close()
// Set the window size greater than the maximum non-scaled window.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1293,8 +1347,7 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// Set the window size such that a window scale of 4 will be used.
const wnd = 65535 * 10
const ws = uint32(4)
- opt := tcpip.ReceiveBufferSizeOption(wnd)
- c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{
header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
})
@@ -1399,7 +1452,7 @@ func TestSegmentMerging(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Prevent the endpoint from processing packets.
test.stop(c.EP)
@@ -1449,7 +1502,7 @@ func TestDelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1497,7 +1550,7 @@ func TestUndelay(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.DelayOption(1))
@@ -1579,7 +1632,7 @@ func TestMSSNotDelayed(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -1695,7 +1748,7 @@ func TestSendGreaterThanMTU(t *testing.T) {
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
testBrokenUpWrite(t, c, maxPayload)
}
@@ -1704,7 +1757,7 @@ func TestActiveSendMSSLessThanMTU(t *testing.T) {
c := context.New(t, 65535)
defer c.Cleanup()
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
testBrokenUpWrite(t, c, maxPayload)
@@ -1727,7 +1780,7 @@ func TestPassiveSendMSSLessThanMTU(t *testing.T) {
// Set the buffer size to a deterministic size so that we can check the
// window scaling option.
const rcvBufferSize = 0x20000
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1871,7 +1924,7 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// window scaling option.
const rcvBufferSize = 0x20000
const wndScale = 2
- if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
t.Fatalf("SetSockOpt failed failed: %v", err)
}
@@ -1973,7 +2026,7 @@ func TestReceiveOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -2010,7 +2063,7 @@ func TestSendOnResetConnection(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send RST segment.
c.SendPacket(nil, &context.Headers{
@@ -2035,7 +2088,7 @@ func TestFinImmediately(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -2078,7 +2131,7 @@ func TestFinRetransmit(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Shutdown immediately, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -2132,7 +2185,7 @@ func TestFinWithNoPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
view := buffer.NewView(10)
@@ -2203,7 +2256,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
@@ -2291,7 +2344,7 @@ func TestFinWithPendingData(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
view := buffer.NewView(10)
@@ -2377,7 +2430,7 @@ func TestFinWithPartialAck(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
@@ -2509,7 +2562,7 @@ func scaledSendWindow(t *testing.T, scale uint8) {
defer c.Cleanup()
maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
- c.CreateConnectedWithRawOptions(789, 0, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
})
@@ -2559,7 +2612,7 @@ func TestScaledSendWindow(t *testing.T) {
func TestReceivedValidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ValidSegmentsReceived.Value() + 1
@@ -2580,7 +2633,7 @@ func TestReceivedValidSegmentCountIncrement(t *testing.T) {
func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.InvalidSegmentsReceived.Value() + 1
vv := c.BuildSegment(nil, &context.Headers{
@@ -2604,7 +2657,7 @@ func TestReceivedInvalidSegmentCountIncrement(t *testing.T) {
func TestReceivedIncorrectChecksumIncrement(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
stats := c.Stack().Stats()
want := stats.TCP.ChecksumErrors.Value() + 1
vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{
@@ -2635,7 +2688,7 @@ func TestReceivedSegmentQueuing(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Send 200 segments.
data := []byte{1, 2, 3}
@@ -2681,7 +2734,7 @@ func TestReadAfterClosedState(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
@@ -2856,8 +2909,8 @@ func TestReusePort(t *testing.T) {
func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2869,8 +2922,8 @@ func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
t.Helper()
- var s tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&s); err != nil {
+ s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
@@ -2880,7 +2933,10 @@ func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
}
func TestDefaultBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2926,7 +2982,10 @@ func TestDefaultBufferSizes(t *testing.T) {
}
func TestMinMaxBufferSizes(t *testing.T) {
- s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Check the default values.
ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
@@ -2945,37 +3004,96 @@ func TestMinMaxBufferSizes(t *testing.T) {
}
// Set values below the min.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, 200)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, 300)
// Set values above the max.
- if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultReceiveBufferSize*20)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20)
- if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultSendBufferSize*30)); err != nil {
+ if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil {
t.Fatalf("GetSockOpt failed: %v", err)
}
checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30)
}
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}})
+
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ defer ep.Close()
+
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
+
+ // Make a nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
+ }
+
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
+ }
+
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
+ }
+}
+
func makeStack() (*stack.Stack, *tcpip.Error) {
- s := stack.New([]string{
- ipv4.ProtocolName,
- ipv6.ProtocolName,
- }, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{
+ ipv4.NewProtocol(),
+ ipv6.NewProtocol(),
+ },
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
id := loopback.New()
if testing.Verbose() {
@@ -3231,7 +3349,7 @@ func TestPathMTUDiscovery(t *testing.T) {
// Create new connection with MSS of 1460.
const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
- c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
})
@@ -3308,7 +3426,7 @@ func TestTCPEndpointProbe(t *testing.T) {
invoked <- struct{}{}
})
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -3482,7 +3600,7 @@ func TestKeepalive(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- c.CreateConnected(789, 30000, nil)
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(10 * time.Millisecond))
c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(10 * time.Millisecond))
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index 16783e716..ef823e4ae 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -137,7 +137,10 @@ type Context struct {
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()},
+ })
// Allow minimum send/receive buffer sizes to be 1 during tests.
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultSendBufferSize, 10 * tcp.DefaultSendBufferSize}); err != nil {
@@ -155,7 +158,14 @@ func New(t *testing.T, mtu uint32) *Context {
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
+ if err := s.CreateNamedNIC(1, "nic1", wep); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+ wep2 := stack.LinkEndpoint(channel.New(1000, mtu, ""))
+ if testing.Verbose() {
+ wep2 = sniffer.New(channel.New(1000, mtu, ""))
+ }
+ if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
@@ -512,7 +522,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
}
// CreateConnected creates a connected TCP endpoint.
-func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) {
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
@@ -585,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
c.Port = tcpHdr.SourcePort()
}
-// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
-// the specified option bytes as the Option field in the initial SYN packet.
-//
-// It also sets the receive buffer for the endpoint to the specified
-// value in epRcvBuf.
-func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+// Create creates a TCP endpoint.
+func (c *Context) Create(epRcvBuf int) {
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -598,11 +604,20 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.
c.t.Fatalf("NewEndpoint failed: %v", err)
}
- if epRcvBuf != nil {
- if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+ if epRcvBuf != -1 {
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+ c.Create(epRcvBuf)
c.Connect(iss, rcvWnd, options)
}
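As a usage note (illustrative, not part of the diff): with the split above, a test can either build the endpoint and connect in two steps, or keep the one-shot helper; an epRcvBuf of -1 leaves the receive buffer at its default.

    // Two-step form: create the endpoint, then perform the handshake.
    c.Create(-1 /* epRcvBuf */)
    c.Connect(789 /* iss */, 30000 /* rcvWnd */, nil /* options */)

    // One-shot form, equivalent to the above.
    c.CreateConnected(789, 30000, -1 /* epRcvBuf */)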
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index c1ca22b35..7a635ab8d 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -52,6 +52,7 @@ go_test(
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 6ac7c067a..52f5af777 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -88,6 +88,7 @@ type endpoint struct {
multicastNICID tcpip.NICID
multicastLoop bool
reusePort bool
+ bindToDevice tcpip.NICID
broadcast bool
// shutdownFlags represent the current shutdown state of the endpoint.
@@ -144,8 +145,8 @@ func (e *endpoint) Close() {
switch e.state {
case StateBound, StateConnected:
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
}
for _, mem := range e.multicastMemberships {
@@ -389,7 +390,12 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
-// SetSockOpt sets a socket option. Currently not supported.
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.V6OnlyOption:
@@ -546,6 +552,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
e.reusePort = v != 0
e.mu.Unlock()
+ case tcpip.BindToDeviceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ if v == "" {
+ e.bindToDevice = 0
+ return nil
+ }
+ for nicid, nic := range e.stack.NICInfo() {
+ if nic.Name == string(v) {
+ e.bindToDevice = nicid
+ return nil
+ }
+ }
+ return tcpip.ErrUnknownDevice
+
case tcpip.BroadcastOption:
e.mu.Lock()
e.broadcast = v != 0
@@ -568,7 +589,20 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
}
e.rcvMu.Unlock()
return v, nil
+
+ case tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ v := e.sndBufSize
+ e.mu.Unlock()
+ return v, nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ v := e.rcvBufSizeMax
+ e.rcvMu.Unlock()
+ return v, nil
}
+
return -1, tcpip.ErrUnknownProtocolOption
}
@@ -578,18 +612,6 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
case tcpip.ErrorOption:
return nil
- case *tcpip.SendBufferSizeOption:
- e.mu.Lock()
- *o = tcpip.SendBufferSizeOption(e.sndBufSize)
- e.mu.Unlock()
- return nil
-
- case *tcpip.ReceiveBufferSizeOption:
- e.rcvMu.Lock()
- *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
- e.rcvMu.Unlock()
- return nil
-
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
@@ -640,6 +662,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
return nil
+ case *tcpip.BindToDeviceOption:
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok {
+ *o = tcpip.BindToDeviceOption(nic.Name)
+ return nil
+ }
+ *o = tcpip.BindToDeviceOption("")
+ return nil
+
case *tcpip.KeepaliveEnabledOption:
*o = 0
return nil
@@ -747,12 +779,12 @@ func (e *endpoint) Disconnect() *tcpip.Error {
} else {
if e.id.LocalPort != 0 {
// Release the ephemeral port.
- e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice)
}
e.state = StateInitial
}
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
e.id = id
e.route.Release()
e.route = stack.Route{}
@@ -829,7 +861,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Remove the old registration.
if e.id.LocalPort != 0 {
- e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice)
}
e.id = id
@@ -892,16 +924,16 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if e.id.LocalPort == 0 {
- port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice)
if err != nil {
return id, err
}
id.LocalPort = port
}
- err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice)
if err != nil {
- e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice)
}
return id, err
}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
index a9edc2c8d..2d0bc5221 100644
--- a/pkg/tcpip/transport/udp/forwarder.go
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -74,7 +74,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID {
// CreateEndpoint creates a connected UDP endpoint for the session request.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(r.stack, r.route.NetProto, queue)
- if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil {
ep.Close()
return nil, err
}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
index 068d9a272..f5cc932dd 100644
--- a/pkg/tcpip/transport/udp/protocol.go
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -14,7 +14,7 @@
// Package udp contains the implementation of the UDP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
-// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// activated on the stack by passing udp.NewProtocol() as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing udp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
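A minimal sketch of the flow this comment describes, using only calls that appear elsewhere in this diff (illustrative, not part of the change):

    import (
        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
        "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
        "gvisor.dev/gvisor/pkg/waiter"
    )

    // newUDPEndpoint builds a stack with the UDP protocol registered and
    // returns an endpoint created against it.
    func newUDPEndpoint() (tcpip.Endpoint, *tcpip.Error) {
        s := stack.New(stack.Options{
            NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
            TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
        })
        var wq waiter.Queue
        return s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
    }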
@@ -30,9 +30,6 @@ import (
)
const (
- // ProtocolName is the string representation of the udp protocol name.
- ProtocolName = "udp"
-
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
)
@@ -182,8 +179,7 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-func init() {
- stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
- return &protocol{}
- })
+// NewProtocol returns a UDP transport protocol.
+func NewProtocol() stack.TransportProtocol {
+ return &protocol{}
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index c6deab892..5059ca22d 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -17,7 +17,6 @@ package udp_test
import (
"bytes"
"fmt"
- "math"
"math/rand"
"testing"
"time"
@@ -27,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
+ "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -274,7 +274,10 @@ type testContext struct {
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
t.Helper()
- s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+ })
ep := channel.New(256, mtu, "")
wep := stack.LinkEndpoint(ep)
@@ -473,87 +476,59 @@ func newMinPayload(minSize int) []byte {
return b
}
-func TestBindPortReuse(t *testing.T) {
- c := newDualTestContext(t, defaultMTU)
- defer c.cleanup()
-
- c.createEndpoint(ipv6.ProtocolNumber)
-
- var eps [5]tcpip.Endpoint
- reusePortOpt := tcpip.ReusePortOption(1)
+func TestBindToDeviceOption(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()},
+ TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
- pollChannel := make(chan tcpip.Endpoint)
- for i := 0; i < len(eps); i++ {
- // Try to receive the data.
- wq := waiter.Queue{}
- we, ch := waiter.NewChannelEntry(nil)
- wq.EventRegister(&we, waiter.EventIn)
- defer wq.EventUnregister(&we)
- defer close(ch)
-
- var err *tcpip.Error
- eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
- if err != nil {
- c.t.Fatalf("NewEndpoint failed: %v", err)
- }
-
- go func(ep tcpip.Endpoint) {
- for range ch {
- pollChannel <- ep
- }
- }(eps[i])
-
- defer eps[i].Close()
- if err := eps[i].SetSockOpt(reusePortOpt); err != nil {
- c.t.Fatalf("SetSockOpt failed failed: %v", err)
- }
- if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil {
- t.Fatalf("ep.Bind(...) failed: %v", err)
- }
+ ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
}
+ defer ep.Close()
- npackets := 100000
- nports := 10000
- ports := make(map[uint16]tcpip.Endpoint)
- stats := make(map[tcpip.Endpoint]int)
- for i := 0; i < npackets; i++ {
- // Send a packet.
- port := uint16(i % nports)
- payload := newPayload()
- c.injectV6Packet(payload, &header4Tuple{
- srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port},
- dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort},
- })
+ if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil {
+ t.Errorf("CreateNamedNIC failed: %v", err)
+ }
- var addr tcpip.FullAddress
- ep := <-pollChannel
- _, _, err := ep.Read(&addr)
- if err != nil {
- c.t.Fatalf("Read failed: %v", err)
- }
- stats[ep]++
- if i < nports {
- ports[uint16(i)] = ep
- } else {
- // Check that all packets from one client are handled
- // by the same socket.
- if ports[port] != ep {
- t.Fatalf("Port mismatch")
- }
- }
+ // Make a nameless NIC.
+ if err := s.CreateNIC(54321, loopback.New()); err != nil {
+ t.Errorf("CreateNIC failed: %v", err)
}
- if len(stats) != len(eps) {
- t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps))
+ // strPtr is used instead of taking the address of string literals, which is
+ // a compiler error.
+ strPtr := func(s string) *string {
+ return &s
}
- // Check that a packet distribution is fair between sockets.
- for _, c := range stats {
- n := float64(npackets) / float64(len(eps))
- // The deviation is less than 10%.
- if math.Abs(float64(c)-n) > n/10 {
- t.Fatal(c, n)
- }
+ testActions := []struct {
+ name string
+ setBindToDevice *string
+ setBindToDeviceError *tcpip.Error
+ getBindToDevice tcpip.BindToDeviceOption
+ }{
+ {"GetDefaultValue", nil, nil, ""},
+ {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""},
+ {"BindToExistent", strPtr("my_device"), nil, "my_device"},
+ {"UnbindToDevice", strPtr(""), nil, ""},
+ }
+ for _, testAction := range testActions {
+ t.Run(testAction.name, func(t *testing.T) {
+ if testAction.setBindToDevice != nil {
+ bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+ if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want {
+ t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want)
+ }
+ }
+ bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt")
+ if ep.GetSockOpt(&bindToDevice) != nil {
+ t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil)
+ }
+ if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+ t.Errorf("bindToDevice got %q, want %q", got, want)
+ }
+ })
}
}
diff --git a/runsc/BUILD b/runsc/BUILD
index 5e7dacb87..e4e8e64a3 100644
--- a/runsc/BUILD
+++ b/runsc/BUILD
@@ -1,7 +1,7 @@
package(licenses = ["notice"]) # Apache 2.0
load("@io_bazel_rules_go//go:def.bzl", "go_binary")
-load("@bazel_tools//tools/build_defs/pkg:pkg.bzl", "pkg_deb", "pkg_tar")
+load("@rules_pkg//:pkg.bzl", "pkg_deb", "pkg_tar")
go_binary(
name = "runsc",
@@ -91,11 +91,6 @@ pkg_deb(
maintainer = "The gVisor Authors <gvisor-dev@googlegroups.com>",
package = "runsc",
postinst = "debian/postinst.sh",
- tags = [
- # TODO(b/135475885): pkg_deb requires python2:
- # https://github.com/bazelbuild/bazel/issues/8443
- "manual",
- ],
version_file = ":version.txt",
visibility = [
"//visibility:public",
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 588bb8851..d90381c0f 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -80,6 +80,7 @@ go_library(
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
+ "//pkg/tcpip/transport/raw",
"//pkg/tcpip/transport/tcp",
"//pkg/tcpip/transport/udp",
"//pkg/urpc",
@@ -109,6 +110,7 @@ go_test(
"//pkg/sentry/arch:registers_go_proto",
"//pkg/sentry/context/contexttest",
"//pkg/sentry/fs",
+ "//pkg/sentry/kernel/auth",
"//pkg/unet",
"//runsc/fsgofer",
"@com_github_opencontainers_runtime-spec//specs-go:go_default_library",
diff --git a/runsc/boot/config.go b/runsc/boot/config.go
index 31103367d..38278d0a2 100644
--- a/runsc/boot/config.go
+++ b/runsc/boot/config.go
@@ -167,6 +167,9 @@ type Config struct {
// Overlay is whether to wrap the root filesystem in an overlay.
Overlay bool
+ // FSGoferHostUDS enables the gofer to mount a host UDS.
+ FSGoferHostUDS bool
+
// Network indicates what type of network to use.
Network NetworkType
@@ -253,6 +256,7 @@ func (c *Config) ToFlags() []string {
"--debug-log-format=" + c.DebugLogFormat,
"--file-access=" + c.FileAccess.String(),
"--overlay=" + strconv.FormatBool(c.Overlay),
+ "--fsgofer-host-uds=" + strconv.FormatBool(c.FSGoferHostUDS),
"--network=" + c.Network.String(),
"--log-packets=" + strconv.FormatBool(c.LogPackets),
"--platform=" + c.Platform,
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 823a34619..adf345490 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -20,7 +20,6 @@ import (
mrand "math/rand"
"os"
"runtime"
- "strings"
"sync"
"sync/atomic"
"syscall"
@@ -55,6 +54,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/raw"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/runsc/boot/filter"
@@ -535,23 +535,12 @@ func (l *Loader) run() error {
return err
}
- // Read /etc/passwd for the user's HOME directory and set the HOME
- // environment variable as required by POSIX if it is not overridden by
- // the user.
- hasHomeEnvv := false
- for _, envv := range l.rootProcArgs.Envv {
- if strings.HasPrefix(envv, "HOME=") {
- hasHomeEnvv = true
- }
- }
- if !hasHomeEnvv {
- homeDir, err := getExecUserHome(ctx, l.rootProcArgs.MountNamespace, uint32(l.rootProcArgs.Credentials.RealKUID))
- if err != nil {
- return fmt.Errorf("error reading exec user: %v", err)
- }
-
- l.rootProcArgs.Envv = append(l.rootProcArgs.Envv, "HOME="+homeDir)
+ // Add the HOME environment variable if it is not already set.
+ envv, err := maybeAddExecUserHome(ctx, l.rootProcArgs.MountNamespace, l.rootProcArgs.Credentials.RealKUID, l.rootProcArgs.Envv)
+ if err != nil {
+ return err
}
+ l.rootProcArgs.Envv = envv
// Create the root container init task. It will begin running
// when the kernel is started.
@@ -815,6 +804,16 @@ func (l *Loader) executeAsync(args *control.ExecArgs) (kernel.ThreadID, error) {
})
defer args.MountNamespace.DecRef()
+ // Add the HOME environment variable if it is not already set.
+ root := args.MountNamespace.Root()
+ defer root.DecRef()
+ ctx := fs.WithRoot(l.k.SupervisorContext(), root)
+ envv, err := maybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv)
+ if err != nil {
+ return 0, err
+ }
+ args.Envv = envv
+
// Start the process.
proc := control.Proc{Kernel: l.k}
args.PIDNamespace = tg.PIDNamespace()
@@ -913,15 +912,17 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) {
case NetworkNone, NetworkSandbox:
// NetworkNone sets up loopback using netstack.
- netProtos := []string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}
- protoNames := []string{tcp.ProtocolName, udp.ProtocolName, icmp.ProtocolName4}
- s := epsocket.Stack{stack.New(netProtos, protoNames, stack.Options{
- Clock: clock,
- Stats: epsocket.Metrics,
- HandleLocal: true,
+ netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol(), arp.NewProtocol()}
+ transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol(), icmp.NewProtocol4()}
+ s := epsocket.Stack{stack.New(stack.Options{
+ NetworkProtocols: netProtos,
+ TransportProtocols: transProtos,
+ Clock: clock,
+ Stats: epsocket.Metrics,
+ HandleLocal: true,
// Enable raw sockets for users with sufficient
// privileges.
- Raw: true,
+ UnassociatedFactory: raw.EndpointFactory{},
})}
// Enable SACK Recovery.
diff --git a/runsc/boot/user.go b/runsc/boot/user.go
index d1d423a5c..56cc12ee0 100644
--- a/runsc/boot/user.go
+++ b/runsc/boot/user.go
@@ -16,6 +16,7 @@ package boot
import (
"bufio"
+ "fmt"
"io"
"strconv"
"strings"
@@ -23,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/usermem"
)
@@ -42,7 +44,7 @@ func (r *fileReader) Read(buf []byte) (int, error) {
// getExecUserHome returns the home directory of the executing user read from
// /etc/passwd as read from the container filesystem.
-func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32) (string, error) {
+func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid auth.KUID) (string, error) {
// The default user home directory to return if no user matching the uid
// is found in the /etc/passwd found in the image.
const defaultHome = "/"
@@ -82,7 +84,7 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32
File: f,
}
- homeDir, err := findHomeInPasswd(uid, r, defaultHome)
+ homeDir, err := findHomeInPasswd(uint32(uid), r, defaultHome)
if err != nil {
return "", err
}
@@ -90,6 +92,28 @@ func getExecUserHome(ctx context.Context, rootMns *fs.MountNamespace, uid uint32
return homeDir, nil
}
+// maybeAddExecUserHome returns a new slice with the HOME environment variable
+// set if the slice does not already contain it, otherwise it returns the
+// original slice unmodified.
+func maybeAddExecUserHome(ctx context.Context, mns *fs.MountNamespace, uid auth.KUID, envv []string) ([]string, error) {
+ // Check if the envv already contains HOME.
+ for _, env := range envv {
+ if strings.HasPrefix(env, "HOME=") {
+ // We have it. Return the original slice unmodified.
+ return envv, nil
+ }
+ }
+
+ // Read /etc/passwd for the user's HOME directory and set the HOME
+ // environment variable as required by POSIX if it is not overridden by
+ // the user.
+ homeDir, err := getExecUserHome(ctx, mns, uid)
+ if err != nil {
+ return nil, fmt.Errorf("error reading exec user: %v", err)
+ }
+ return append(envv, "HOME="+homeDir), nil
+}
+
// findHomeInPasswd parses a passwd file and returns the given user's home
// directory. This function does its best to replicate runc's behavior.
func findHomeInPasswd(uid uint32, passwd io.Reader, defaultHome string) (string, error) {
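For context, the loader side of this helper mirrors the exec path shown earlier in this diff; a sketch of the expected call pattern (illustrative only):

    // Only extends envv when it does not already carry a HOME= entry.
    envv, err := maybeAddExecUserHome(ctx, args.MountNamespace, args.KUID, args.Envv)
    if err != nil {
        return 0, err
    }
    args.Envv = envv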
diff --git a/runsc/boot/user_test.go b/runsc/boot/user_test.go
index 906baf3e5..9aee2ad07 100644
--- a/runsc/boot/user_test.go
+++ b/runsc/boot/user_test.go
@@ -25,6 +25,7 @@ import (
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/sentry/context/contexttest"
"gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
func setupTempDir() (string, error) {
@@ -68,7 +69,7 @@ func setupPasswd(contents string, perms os.FileMode) func() (string, error) {
// TestGetExecUserHome tests the getExecUserHome function.
func TestGetExecUserHome(t *testing.T) {
tests := map[string]struct {
- uid uint32
+ uid auth.KUID
createRoot func() (string, error)
expected string
}{
diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go
index e817eff77..d1e99243b 100644
--- a/runsc/cmd/exec.go
+++ b/runsc/cmd/exec.go
@@ -105,11 +105,11 @@ func (ex *Exec) SetFlags(f *flag.FlagSet) {
// Execute implements subcommands.Command.Execute. It starts a process in an
// already created container.
func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
- e, id, err := ex.parseArgs(f)
+ conf := args[0].(*boot.Config)
+ e, id, err := ex.parseArgs(f, conf.EnableRaw)
if err != nil {
Fatalf("parsing process spec: %v", err)
}
- conf := args[0].(*boot.Config)
waitStatus := args[1].(*syscall.WaitStatus)
c, err := container.Load(conf.RootDir, id)
@@ -117,6 +117,9 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("loading sandbox: %v", err)
}
+ log.Debugf("Exec arguments: %+v", e)
+ log.Debugf("Exec capablities: %+v", e.Capabilities)
+
// Replace empty settings with defaults from container.
if e.WorkingDirectory == "" {
e.WorkingDirectory = c.Spec.Process.Cwd
@@ -127,15 +130,13 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("getting environment variables: %v", err)
}
}
+
if e.Capabilities == nil {
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec
- // requires the capability to be set explicitly, while 'docker
- // run' sets it by default.
- e.Capabilities, err = specutils.Capabilities(true /* enableRaw */, c.Spec.Process.Capabilities)
+ e.Capabilities, err = specutils.Capabilities(conf.EnableRaw, c.Spec.Process.Capabilities)
if err != nil {
Fatalf("creating capabilities: %v", err)
}
+ log.Infof("Using exec capabilities from container: %+v", e.Capabilities)
}
// containerd expects an actual process to represent the container being
@@ -282,14 +283,14 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi
// parseArgs parses exec information from the command line or a JSON file
// depending on whether the --process flag was used. Returns an ExecArgs and
// the ID of the container to be used.
-func (ex *Exec) parseArgs(f *flag.FlagSet) (*control.ExecArgs, string, error) {
+func (ex *Exec) parseArgs(f *flag.FlagSet, enableRaw bool) (*control.ExecArgs, string, error) {
if ex.processPath == "" {
// Requires at least a container ID and command.
if f.NArg() < 2 {
f.Usage()
return nil, "", fmt.Errorf("both a container-id and command are required")
}
- e, err := ex.argsFromCLI(f.Args()[1:])
+ e, err := ex.argsFromCLI(f.Args()[1:], enableRaw)
return e, f.Arg(0), err
}
// Requires only the container ID.
@@ -297,11 +298,11 @@ func (ex *Exec) parseArgs(f *flag.FlagSet) (*control.ExecArgs, string, error) {
f.Usage()
return nil, "", fmt.Errorf("a container-id is required")
}
- e, err := ex.argsFromProcessFile()
+ e, err := ex.argsFromProcessFile(enableRaw)
return e, f.Arg(0), err
}
-func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
+func (ex *Exec) argsFromCLI(argv []string, enableRaw bool) (*control.ExecArgs, error) {
extraKGIDs := make([]auth.KGID, 0, len(ex.extraKGIDs))
for _, s := range ex.extraKGIDs {
kgid, err := strconv.Atoi(s)
@@ -314,7 +315,7 @@ func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
var caps *auth.TaskCapabilities
if len(ex.caps) > 0 {
var err error
- caps, err = capabilities(ex.caps)
+ caps, err = capabilities(ex.caps, enableRaw)
if err != nil {
return nil, fmt.Errorf("capabilities error: %v", err)
}
@@ -332,7 +333,7 @@ func (ex *Exec) argsFromCLI(argv []string) (*control.ExecArgs, error) {
}, nil
}
-func (ex *Exec) argsFromProcessFile() (*control.ExecArgs, error) {
+func (ex *Exec) argsFromProcessFile(enableRaw bool) (*control.ExecArgs, error) {
f, err := os.Open(ex.processPath)
if err != nil {
return nil, fmt.Errorf("error opening process file: %s, %v", ex.processPath, err)
@@ -342,21 +343,21 @@ func (ex *Exec) argsFromProcessFile() (*control.ExecArgs, error) {
if err := json.NewDecoder(f).Decode(&p); err != nil {
return nil, fmt.Errorf("error parsing process file: %s, %v", ex.processPath, err)
}
- return argsFromProcess(&p)
+ return argsFromProcess(&p, enableRaw)
}
// argsFromProcess performs all the non-IO conversion from the Process struct
// to ExecArgs.
-func argsFromProcess(p *specs.Process) (*control.ExecArgs, error) {
+func argsFromProcess(p *specs.Process, enableRaw bool) (*control.ExecArgs, error) {
// Create capabilities.
var caps *auth.TaskCapabilities
if p.Capabilities != nil {
var err error
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec
- // requires the capability to be set explicitly, while 'docker
- // run' sets it by default.
- caps, err = specutils.Capabilities(true /* enableRaw */, p.Capabilities)
+ // Starting from Docker 19, capabilities are explicitly set for exec (instead
+ // of nil like before). So we can't distinguish 'exec' from
+ // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter
+ // CAP_NET_RAW in the same way as container start.
+ caps, err = specutils.Capabilities(enableRaw, p.Capabilities)
if err != nil {
return nil, fmt.Errorf("error creating capabilities: %v", err)
}
@@ -409,7 +410,7 @@ func resolveEnvs(envs ...[]string) ([]string, error) {
// capabilities takes a list of capabilities as strings and returns an
// auth.TaskCapabilities struct with those capabilities in every capability set.
// This mimics runc's behavior.
-func capabilities(cs []string) (*auth.TaskCapabilities, error) {
+func capabilities(cs []string, enableRaw bool) (*auth.TaskCapabilities, error) {
var specCaps specs.LinuxCapabilities
for _, cap := range cs {
specCaps.Ambient = append(specCaps.Ambient, cap)
@@ -418,11 +419,11 @@ func capabilities(cs []string) (*auth.TaskCapabilities, error) {
specCaps.Inheritable = append(specCaps.Inheritable, cap)
specCaps.Permitted = append(specCaps.Permitted, cap)
}
- // enableRaw is set to true to prevent the filtering out of
- // CAP_NET_RAW. This is the opposite of Create() because exec requires
- // the capability to be set explicitly, while 'docker run' sets it by
- // default.
- return specutils.Capabilities(true /* enableRaw */, &specCaps)
+ // Starting from Docker 19, capabilities are explicitly set for exec (instead
+ // of nil like before). So we can't distinguish 'exec' from
+ // 'exec --privileged', as both specify CAP_NET_RAW. Therefore, filter
+ // CAP_NET_RAW in the same way as container start.
+ return specutils.Capabilities(enableRaw, &specCaps)
}
// stringSlice allows a flag to be used multiple times, where each occurrence
diff --git a/runsc/cmd/exec_test.go b/runsc/cmd/exec_test.go
index eb38a431f..a1e980d08 100644
--- a/runsc/cmd/exec_test.go
+++ b/runsc/cmd/exec_test.go
@@ -91,7 +91,7 @@ func TestCLIArgs(t *testing.T) {
}
for _, tc := range testCases {
- e, err := tc.ex.argsFromCLI(tc.argv)
+ e, err := tc.ex.argsFromCLI(tc.argv, true)
if err != nil {
t.Errorf("argsFromCLI(%+v): got error: %+v", tc.ex, err)
} else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) {
@@ -144,7 +144,7 @@ func TestJSONArgs(t *testing.T) {
}
for _, tc := range testCases {
- e, err := argsFromProcess(&tc.p)
+ e, err := argsFromProcess(&tc.p, true)
if err != nil {
t.Errorf("argsFromProcess(%+v): got error: %+v", tc.p, err)
} else if !cmp.Equal(*e, tc.expected, cmpopts.IgnoreUnexported(os.File{})) {
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 9faabf494..fbd579fb8 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -182,6 +182,7 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
cfg := fsgofer.Config{
ROMount: isReadonlyMount(m.Options),
PanicOnWrite: g.panicOnWrite,
+ HostUDS: conf.FSGoferHostUDS,
}
ap, err := fsgofer.NewAttachPoint(m.Destination, cfg)
if err != nil {
@@ -200,6 +201,10 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
Fatalf("too many FDs passed for mounts. mounts: %d, FDs: %d", mountIdx, len(g.ioFDs))
}
+ if conf.FSGoferHostUDS {
+ filter.InstallUDSFilters()
+ }
+
if err := filter.Install(); err != nil {
Fatalf("installing seccomp filters: %v", err)
}
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index bc1fa25e3..26d1cd5ab 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -47,6 +47,7 @@ go_test(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//pkg/log",
"//pkg/sentry/control",
"//pkg/sentry/kernel",
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 2ac12e5b6..519f5ed9b 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -34,6 +34,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -2049,6 +2050,30 @@ func TestMountSymlink(t *testing.T) {
}
}
+// Check that --net-raw disables the CAP_NET_RAW capability.
+func TestNetRaw(t *testing.T) {
+ capNetRaw := strconv.FormatUint(bits.MaskOf64(int(linux.CAP_NET_RAW)), 10)
+ app, err := testutil.FindFile("runsc/container/test_app/test_app")
+ if err != nil {
+ t.Fatal("error finding test_app:", err)
+ }
+
+ for _, enableRaw := range []bool{true, false} {
+ conf := testutil.TestConfig()
+ conf.EnableRaw = enableRaw
+
+ test := "--enabled"
+ if !enableRaw {
+ test = "--disabled"
+ }
+
+ spec := testutil.NewSpecWithArgs(app, "capability", test, capNetRaw)
+ if err := run(spec, conf); err != nil {
+ t.Fatalf("Error running container: %v", err)
+ }
+ }
+}
+
// executeSync synchronously executes a new process.
func (cont *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) {
pid, err := cont.Execute(args)
diff --git a/runsc/container/test_app/test_app.go b/runsc/container/test_app/test_app.go
index 7f735c254..913d781c6 100644
--- a/runsc/container/test_app/test_app.go
+++ b/runsc/container/test_app/test_app.go
@@ -19,10 +19,12 @@ package main
import (
"context"
"fmt"
+ "io/ioutil"
"log"
"net"
"os"
"os/exec"
+ "regexp"
"strconv"
sys "syscall"
"time"
@@ -35,6 +37,7 @@ import (
func main() {
subcommands.Register(subcommands.HelpCommand(), "")
subcommands.Register(subcommands.FlagsCommand(), "")
+ subcommands.Register(new(capability), "")
subcommands.Register(new(fdReceiver), "")
subcommands.Register(new(fdSender), "")
subcommands.Register(new(forkBomb), "")
@@ -287,3 +290,65 @@ func (s *syscall) Execute(ctx context.Context, f *flag.FlagSet, args ...interfac
}
return subcommands.ExitSuccess
}
+
+type capability struct {
+ enabled uint64
+ disabled uint64
+}
+
+// Name implements subcommands.Command.
+func (*capability) Name() string {
+ return "capability"
+}
+
+// Synopsis implements subcommands.Command.
+func (*capability) Synopsis() string {
+ return "checks if effective capabilities are set/unset"
+}
+
+// Usage implements subcommands.Command.
+func (*capability) Usage() string {
+ return "capability [--enabled=number] [--disabled=number]"
+}
+
+// SetFlags implements subcommands.Command.
+func (c *capability) SetFlags(f *flag.FlagSet) {
+ f.Uint64Var(&c.enabled, "enabled", 0, "")
+ f.Uint64Var(&c.disabled, "disabled", 0, "")
+}
+
+// Execute implements subcommands.Command.
+func (c *capability) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ if c.enabled == 0 && c.disabled == 0 {
+ fmt.Println("One of the flags must be set")
+ return subcommands.ExitUsageError
+ }
+
+ status, err := ioutil.ReadFile("/proc/self/status")
+ if err != nil {
+ fmt.Printf("Error reading %q: %v\n", "proc/self/status", err)
+ return subcommands.ExitFailure
+ }
+ re := regexp.MustCompile("CapEff:\t([0-9a-f]+)\n")
+ matches := re.FindStringSubmatch(string(status))
+ if matches == nil || len(matches) != 2 {
+ fmt.Printf("Effective capabilities not found in\n%s\n", status)
+ return subcommands.ExitFailure
+ }
+ caps, err := strconv.ParseUint(matches[1], 16, 64)
+ if err != nil {
+ fmt.Printf("failed to convert capabilities %q: %v\n", matches[1], err)
+ return subcommands.ExitFailure
+ }
+
+ if c.enabled != 0 && (caps&c.enabled) != c.enabled {
+ fmt.Printf("Missing capabilities, want: %#x: got: %#x\n", c.enabled, caps)
+ return subcommands.ExitFailure
+ }
+ if c.disabled != 0 && (caps&c.disabled) != 0 {
+ fmt.Printf("Extra capabilities found, dont_want: %#x: got: %#x\n", c.disabled, caps)
+ return subcommands.ExitFailure
+ }
+
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/dockerutil/dockerutil.go b/runsc/dockerutil/dockerutil.go
index c073d8f75..57f6ae8de 100644
--- a/runsc/dockerutil/dockerutil.go
+++ b/runsc/dockerutil/dockerutil.go
@@ -282,7 +282,22 @@ func (d *Docker) Logs() (string, error) {
// Exec calls 'docker exec' with the arguments provided.
func (d *Docker) Exec(args ...string) (string, error) {
- a := []string{"exec", d.Name}
+ return d.ExecWithFlags(nil, args...)
+}
+
+// ExecWithFlags calls 'docker exec <flags> name <args>'.
+func (d *Docker) ExecWithFlags(flags []string, args ...string) (string, error) {
+ a := []string{"exec"}
+ a = append(a, flags...)
+ a = append(a, d.Name)
+ a = append(a, args...)
+ return do(a...)
+}
+
+// ExecAsUser calls 'docker exec' as the given user with the arguments
+// provided.
+func (d *Docker) ExecAsUser(user string, args ...string) (string, error) {
+ a := []string{"exec", "--user", user, d.Name}
a = append(a, args...)
return do(a...)
}
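Illustrative calls to the new helpers, mirroring the exec tests later in this change (sketch only, assuming a test context and a started container d):

    // docker exec --privileged <name> grep CapEff: /proc/self/status
    got, err := d.ExecWithFlags([]string{"--privileged"}, "grep", "CapEff:", "/proc/self/status")
    if err != nil {
        t.Fatalf("docker exec failed: %v", err)
    }
    t.Logf("exec CapEff: %v", got)

    // docker exec --user 1234 <name> /bin/sh -c 'echo $HOME'
    home, err := d.ExecAsUser("1234", "/bin/sh", "-c", "echo $HOME")
    if err != nil {
        t.Fatalf("docker exec failed: %v", err)
    }
    t.Logf("exec HOME: %v", home)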
diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go
index 2f3f2039a..c7922b54f 100644
--- a/runsc/fsgofer/filter/config.go
+++ b/runsc/fsgofer/filter/config.go
@@ -214,3 +214,16 @@ var allowedSyscalls = seccomp.SyscallRules{
syscall.SYS_UTIMENSAT: {},
syscall.SYS_WRITE: {},
}
+
+var udsSyscalls = seccomp.SyscallRules{
+ syscall.SYS_SOCKET: []seccomp.Rule{
+ {
+ seccomp.AllowValue(syscall.AF_UNIX),
+ },
+ },
+ syscall.SYS_CONNECT: []seccomp.Rule{
+ {
+ seccomp.AllowAny{},
+ },
+ },
+}
diff --git a/runsc/fsgofer/filter/filter.go b/runsc/fsgofer/filter/filter.go
index 65053415f..289886720 100644
--- a/runsc/fsgofer/filter/filter.go
+++ b/runsc/fsgofer/filter/filter.go
@@ -23,11 +23,16 @@ import (
// Install installs seccomp filters.
func Install() error {
- s := allowedSyscalls
-
// Set of additional filters used by -race and -msan. Returns empty
// when not enabled.
- s.Merge(instrumentationFilters())
+ allowedSyscalls.Merge(instrumentationFilters())
+
+ return seccomp.Install(allowedSyscalls)
+}
- return seccomp.Install(s)
+// InstallUDSFilters extends the allowed syscalls to include those necessary for
+// connecting to a host UDS.
+func InstallUDSFilters() {
+ // Add additional filters required for connecting to the host's sockets.
+ allowedSyscalls.Merge(udsSyscalls)
}
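A sketch of the intended call order in the gofer (matching the cmd/gofer.go change earlier in this diff): merge the UDS rules before installing the filter set, and only when the flag is enabled.

    if conf.FSGoferHostUDS {
        // Permit socket(AF_UNIX) and connect(2) for host UDS mounts.
        filter.InstallUDSFilters()
    }
    if err := filter.Install(); err != nil {
        Fatalf("installing seccomp filters: %v", err)
    }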
diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go
index 7c4d2b94e..29a82138e 100644
--- a/runsc/fsgofer/fsgofer.go
+++ b/runsc/fsgofer/fsgofer.go
@@ -21,6 +21,7 @@
package fsgofer
import (
+ "errors"
"fmt"
"io"
"math"
@@ -54,6 +55,7 @@ const (
regular fileType = iota
directory
symlink
+ socket
unknown
)
@@ -66,6 +68,8 @@ func (f fileType) String() string {
return "directory"
case symlink:
return "symlink"
+ case socket:
+ return "socket"
}
return "unknown"
}
@@ -82,6 +86,9 @@ type Config struct {
// PanicOnWrite panics on attempts to write to RO mounts.
PanicOnWrite bool
+
+ // HostUDS signals whether the gofer can mount a host's UDS.
+ HostUDS bool
}
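Illustrative construction of this config from the gofer command, as wired up in cmd/gofer.go above (sketch, not additional code in this change):

    cfg := fsgofer.Config{
        ROMount:      isReadonlyMount(m.Options),
        PanicOnWrite: g.panicOnWrite,
        // Opt in to attaching/connecting host Unix domain sockets.
        HostUDS: conf.FSGoferHostUDS,
    }
    ap, err := fsgofer.NewAttachPoint(m.Destination, cfg)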
type attachPoint struct {
@@ -124,24 +131,50 @@ func (a *attachPoint) Attach() (p9.File, error) {
if err != nil {
return nil, fmt.Errorf("stat file %q, err: %v", a.prefix, err)
}
- mode := syscall.O_RDWR
- if a.conf.ROMount || (stat.Mode&syscall.S_IFMT) == syscall.S_IFDIR {
- mode = syscall.O_RDONLY
- }
-
- // Open the root directory.
- f, err := fd.Open(a.prefix, openFlags|mode, 0)
- if err != nil {
- return nil, fmt.Errorf("unable to open file %q, err: %v", a.prefix, err)
- }
+ // Acquire the attach point lock.
a.attachedMu.Lock()
defer a.attachedMu.Unlock()
+
if a.attached {
- f.Close()
return nil, fmt.Errorf("attach point already attached, prefix: %s", a.prefix)
}
+ // Hold the file descriptor we are converting into a p9.File.
+ var f *fd.FD
+
+ // Apply the S_IFMT bitmask so we can detect file type appropriately.
+ switch fmtStat := stat.Mode & syscall.S_IFMT; fmtStat {
+ case syscall.S_IFSOCK:
+ // Check to see if the CLI option has been set to allow the UDS mount.
+ if !a.conf.HostUDS {
+ return nil, errors.New("host UDS support is disabled")
+ }
+
+ // Attempt to open a connection. Bubble up the failures.
+ f, err = fd.DialUnix(a.prefix)
+ if err != nil {
+ return nil, err
+ }
+
+ default:
+ // Default to Read/Write permissions.
+ mode := syscall.O_RDWR
+
+ // If the configuration is Read Only or the mount point is a directory,
+ // set the mode to Read Only.
+ if a.conf.ROMount || fmtStat == syscall.S_IFDIR {
+ mode = syscall.O_RDONLY
+ }
+
+ // Open the mount point & capture the FD.
+ f, err = fd.Open(a.prefix, openFlags|mode, 0)
+ if err != nil {
+ return nil, fmt.Errorf("unable to open file %q, err: %v", a.prefix, err)
+ }
+ }
+
+ // Return a localFile object to the caller wrapping the FD we opened above.
rv, err := newLocalFile(a, f, a.prefix, stat)
if err != nil {
return nil, err
@@ -295,7 +328,7 @@ func openAnyFile(path string, fn func(mode int) (*fd.FD, error)) (*fd.FD, error)
return file, nil
}
-func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
+func getSupportedFileType(stat syscall.Stat_t, permitSocket bool) (fileType, error) {
var ft fileType
switch stat.Mode & syscall.S_IFMT {
case syscall.S_IFREG:
@@ -304,6 +337,11 @@ func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
ft = directory
case syscall.S_IFLNK:
ft = symlink
+ case syscall.S_IFSOCK:
+ if !permitSocket {
+ return unknown, syscall.EPERM
+ }
+ ft = socket
default:
return unknown, syscall.EPERM
}
@@ -311,7 +349,7 @@ func getSupportedFileType(stat syscall.Stat_t) (fileType, error) {
}
func newLocalFile(a *attachPoint, file *fd.FD, path string, stat syscall.Stat_t) (*localFile, error) {
- ft, err := getSupportedFileType(stat)
+ ft, err := getSupportedFileType(stat, a.conf.HostUDS)
if err != nil {
return nil, err
}
@@ -1026,7 +1064,11 @@ func (l *localFile) Flush() error {
// Connect implements p9.File.
func (l *localFile) Connect(p9.ConnectFlags) (*fd.FD, error) {
- return nil, syscall.ECONNREFUSED
+ // Check to see if the CLI option has been set to allow the UDS mount.
+ if !l.attachPoint.conf.HostUDS {
+ return nil, syscall.ECONNREFUSED
+ }
+ return fd.DialUnix(l.hostPath)
}
// Close implements p9.File.
diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go
index cbbe71019..05af7e397 100644
--- a/runsc/fsgofer/fsgofer_test.go
+++ b/runsc/fsgofer/fsgofer_test.go
@@ -665,7 +665,7 @@ func TestAttachInvalidType(t *testing.T) {
}
f, err := a.Attach()
if f != nil || err == nil {
- t.Fatalf("Attach should have failed, got (%v, nil)", f)
+ t.Fatalf("Attach should have failed, got (%v, %v)", f, err)
}
})
}
diff --git a/runsc/main.go b/runsc/main.go
index ff74c0a3d..7dce9dc00 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -68,6 +68,7 @@ var (
network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.")
gso = flag.Bool("gso", true, "enable generic segmentation offload")
fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.")
+ fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "Allow the gofer to mount Unix Domain Sockets.")
overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.")
watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.")
panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.")
@@ -195,6 +196,7 @@ func main() {
DebugLog: *debugLog,
DebugLogFormat: *debugLogFormat,
FileAccess: fsAccess,
+ FSGoferHostUDS: *fsGoferHostUDS,
Overlay: *overlay,
Network: netType,
GSO: *gso,
@@ -239,7 +241,7 @@ func main() {
// want with them. Since Docker and Containerd both eat boot's stderr, we
// dup our stderr to the provided log FD so that panics will appear in the
// logs, rather than just disappear.
- if err := syscall.Dup2(int(f.Fd()), int(os.Stderr.Fd())); err != nil {
+ if err := syscall.Dup3(int(f.Fd()), int(os.Stderr.Fd()), 0); err != nil {
cmd.Fatalf("error dup'ing fd %d to stderr: %v", f.Fd(), err)
}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index 4c6c83fbd..ee9327fc8 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -352,7 +352,7 @@ func (s *Sandbox) createSandboxProcess(conf *boot.Config, args *Args, startSyncF
}
if conf.DebugLog != "" {
test := ""
- if len(conf.TestOnlyTestNameEnv) == 0 {
+ if len(conf.TestOnlyTestNameEnv) != 0 {
// Fetch test name if one is provided and the test only flag was set.
if t, ok := specutils.EnvVar(args.Spec.Process.Env, conf.TestOnlyTestNameEnv); ok {
test = t
diff --git a/runsc/specutils/BUILD b/runsc/specutils/BUILD
index fbfb8e2f8..fa58313a0 100644
--- a/runsc/specutils/BUILD
+++ b/runsc/specutils/BUILD
@@ -13,6 +13,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//pkg/log",
"//pkg/sentry/kernel/auth",
"@com_github_cenkalti_backoff//:go_default_library",
diff --git a/runsc/specutils/namespace.go b/runsc/specutils/namespace.go
index d441419cb..c7dd3051c 100644
--- a/runsc/specutils/namespace.go
+++ b/runsc/specutils/namespace.go
@@ -33,19 +33,19 @@ import (
func nsCloneFlag(nst specs.LinuxNamespaceType) uintptr {
switch nst {
case specs.IPCNamespace:
- return syscall.CLONE_NEWIPC
+ return unix.CLONE_NEWIPC
case specs.MountNamespace:
- return syscall.CLONE_NEWNS
+ return unix.CLONE_NEWNS
case specs.NetworkNamespace:
- return syscall.CLONE_NEWNET
+ return unix.CLONE_NEWNET
case specs.PIDNamespace:
- return syscall.CLONE_NEWPID
+ return unix.CLONE_NEWPID
case specs.UTSNamespace:
- return syscall.CLONE_NEWUTS
+ return unix.CLONE_NEWUTS
case specs.UserNamespace:
- return syscall.CLONE_NEWUSER
+ return unix.CLONE_NEWUSER
case specs.CgroupNamespace:
- panic("cgroup namespace has no associated clone flag")
+ return unix.CLONE_NEWCGROUP
default:
panic(fmt.Sprintf("unknown namespace %v", nst))
}
diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go
index cb9e58dfb..591abe458 100644
--- a/runsc/specutils/specutils.go
+++ b/runsc/specutils/specutils.go
@@ -31,6 +31,7 @@ import (
"github.com/cenkalti/backoff"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -241,6 +242,15 @@ func AllCapabilities() *specs.LinuxCapabilities {
}
}
+// AllCapabilitiesUint64 returns a bitmask containing all capabilities set.
+func AllCapabilitiesUint64() uint64 {
+ var rv uint64
+ for _, cap := range capFromName {
+ rv |= bits.MaskOf64(int(cap))
+ }
+ return rv
+}
+
var capFromName = map[string]linux.Capability{
"CAP_CHOWN": linux.CAP_CHOWN,
"CAP_DAC_OVERRIDE": linux.CAP_DAC_OVERRIDE,
diff --git a/scripts/build.sh b/scripts/build.sh
index b3a6e4e7a..0b3d1b316 100755
--- a/scripts/build.sh
+++ b/scripts/build.sh
@@ -23,7 +23,7 @@ sudo apt-get update && sudo apt-get install -y dpkg-sig coreutils apt-utils
runsc=$(build -c opt //runsc)
# Build packages.
-pkg=$(build -c opt --host_force_python=py2 //runsc:runsc-debian)
+pkg=$(build -c opt //runsc:runsc-debian)
# Build a repository, if the key is available.
if [[ -v KOKORO_REPO_KEY ]]; then
diff --git a/scripts/common_bazel.sh b/scripts/common_bazel.sh
index dde0b51ed..f8ec967b1 100755
--- a/scripts/common_bazel.sh
+++ b/scripts/common_bazel.sh
@@ -79,9 +79,16 @@ function collect_logs() {
# Collect sentry logs, if any.
if [[ -v RUNSC_LOGS_DIR ]] && [[ -d "${RUNSC_LOGS_DIR}" ]]; then
- local -r logs=$(ls "${RUNSC_LOGS_DIR}")
- if [[ -z "${logs}" ]]; then
- tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${RUNTIME}.tar.gz" -C "${RUNSC_LOGS_DIR}" .
+ # Check if the directory is empty or not (only the first line is needed).
+ local -r logs=$(ls "${RUNSC_LOGS_DIR}" | head -n1)
+ if [[ "${logs}" ]]; then
+ local -r archive=runsc_logs_"${RUNTIME}".tar.gz
+ if [[ -v KOKORO_BUILD_ARTIFACTS_SUBDIR ]]; then
+ echo "runsc logs will be uploaded to:"
+ echo " gsutil cp gs://gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive} /tmp"
+ echo " https://storage.cloud.google.com/gvisor/logs/${KOKORO_BUILD_ARTIFACTS_SUBDIR}/${archive}"
+ fi
+ tar --create --gzip --file="${KOKORO_ARTIFACTS_DIR}/${archive}" -C "${RUNSC_LOGS_DIR}" .
fi
fi
fi
diff --git a/scripts/dev.sh b/scripts/dev.sh
index ee74dcb72..c67003018 100755
--- a/scripts/dev.sh
+++ b/scripts/dev.sh
@@ -58,7 +58,7 @@ if [[ ${REFRESH} -eq 0 ]]; then
echo
echo "Runtimes ${RUNTIME} and ${RUNTIME}-d (debug enabled) setup."
echo "Use --runtime="${RUNTIME}" with your Docker command."
- echo " docker run --rm --runtime="${RUNTIME}" --rm hello-world"
+ echo " docker run --rm --runtime="${RUNTIME}" hello-world"
echo
echo "If you rebuild, use $0 --refresh."
diff --git a/scripts/make_tests.sh b/scripts/make_tests.sh
index 0fa1248be..79426756d 100755
--- a/scripts/make_tests.sh
+++ b/scripts/make_tests.sh
@@ -21,4 +21,5 @@ top_level=$(git rev-parse --show-toplevel 2>/dev/null)
make
make runsc
+make BAZEL_OPTIONS="build //..." bazel
make bazel-shutdown
diff --git a/test/e2e/BUILD b/test/e2e/BUILD
index 99442cffb..4fe03a220 100644
--- a/test/e2e/BUILD
+++ b/test/e2e/BUILD
@@ -19,7 +19,9 @@ go_test(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/bits",
"//runsc/dockerutil",
+ "//runsc/specutils",
"//runsc/testutil",
],
)
diff --git a/test/e2e/exec_test.go b/test/e2e/exec_test.go
index ce2c4f689..88d26e865 100644
--- a/test/e2e/exec_test.go
+++ b/test/e2e/exec_test.go
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package image provides end-to-end integration tests for runsc. These tests
-// require docker and runsc to be installed on the machine.
+// Package integration provides end-to-end integration tests for runsc. These
+// tests require docker and runsc to be installed on the machine.
//
// Each test calls docker commands to start up a container, and tests that it
// is behaving properly, with various runsc commands. The container is killed
@@ -30,14 +30,17 @@ import (
"time"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/runsc/dockerutil"
+ "gvisor.dev/gvisor/runsc/specutils"
)
+// Test that exec uses the exact same capability set as the container.
func TestExecCapabilities(t *testing.T) {
if err := dockerutil.Pull("alpine"); err != nil {
t.Fatalf("docker pull failed: %v", err)
}
- d := dockerutil.MakeDocker("exec-test")
+ d := dockerutil.MakeDocker("exec-capabilities-test")
// Start the container.
if err := d.Run("alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
@@ -52,27 +55,59 @@ func TestExecCapabilities(t *testing.T) {
if len(matches) != 2 {
t.Fatalf("There should be a match for the whole line and the capability bitmask")
}
- capString := matches[1]
- t.Log("Root capabilities:", capString)
+ want := fmt.Sprintf("CapEff:\t%s\n", matches[1])
+ t.Log("Root capabilities:", want)
- // CAP_NET_RAW was in the capability set for the container, but was
- // removed. However, `exec` does not remove it. Verify that it's not
- // set in the container, then re-add it for comparison.
- caps, err := strconv.ParseUint(capString, 16, 64)
+ // Now check that exec'd process capabilities match the root.
+ got, err := d.Exec("grep", "CapEff:", "/proc/self/status")
if err != nil {
- t.Fatalf("failed to convert capabilities %q: %v", capString, err)
+ t.Fatalf("docker exec failed: %v", err)
}
- if caps&(1<<uint64(linux.CAP_NET_RAW)) != 0 {
- t.Fatalf("CAP_NET_RAW should be filtered, but is set in the container: %x", caps)
+ t.Logf("CapEff: %v", got)
+ if got != want {
+ t.Errorf("wrong capabilities, got: %q, want: %q", got, want)
}
- caps |= 1 << uint64(linux.CAP_NET_RAW)
- want := fmt.Sprintf("CapEff:\t%016x\n", caps)
+}
- // Now check that exec'd process capabilities match the root.
- got, err := d.Exec("grep", "CapEff:", "/proc/self/status")
+// Test that 'exec --privileged' adds all capabilities, except for CAP_NET_RAW
+// which is removed from the container when --net-raw=false.
+func TestExecPrivileged(t *testing.T) {
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-privileged-test")
+
+ // Start the container with all capabilities dropped.
+ if err := d.Run("--cap-drop=all", "alpine", "sh", "-c", "cat /proc/self/status; sleep 100"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Check that all capabilities were dropped from the container.
+ matches, err := d.WaitForOutputSubmatch("CapEff:\t([0-9a-f]+)\n", 5*time.Second)
+ if err != nil {
+ t.Fatalf("WaitForOutputSubmatch() timeout: %v", err)
+ }
+ if len(matches) != 2 {
+ t.Fatalf("There should be a match for the whole line and the capability bitmask")
+ }
+ containerCaps, err := strconv.ParseUint(matches[1], 16, 64)
+ if err != nil {
+ t.Fatalf("failed to convert capabilities %q: %v", matches[1], err)
+ }
+ t.Logf("Container capabilities: %#x", containerCaps)
+ if containerCaps != 0 {
+ t.Fatalf("Container should have no capabilities: %x", containerCaps)
+ }
+
+ // Check that 'exec --privileged' adds all capabilities, except
+ // for CAP_NET_RAW.
+ got, err := d.ExecWithFlags([]string{"--privileged"}, "grep", "CapEff:", "/proc/self/status")
if err != nil {
t.Fatalf("docker exec failed: %v", err)
}
+ t.Logf("Exec CapEff: %v", got)
+ want := fmt.Sprintf("CapEff:\t%016x\n", specutils.AllCapabilitiesUint64()&^bits.MaskOf64(int(linux.CAP_NET_RAW)))
if got != want {
t.Errorf("wrong capabilities, got: %q, want: %q", got, want)
}
@@ -154,3 +189,68 @@ func TestExecError(t *testing.T) {
t.Fatalf("docker exec wrong error, got: %s, want: .*%s.*", err.Error(), want)
}
}
+
+// Test that exec inherits environment from run.
+func TestExecEnv(t *testing.T) {
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-env-test")
+
+ // Start the container with env FOO=BAR.
+ if err := d.Run("-e", "FOO=BAR", "alpine", "sleep", "1000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Exec "echo $FOO".
+ got, err := d.Exec("/bin/sh", "-c", "echo $FOO")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := "BAR"; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+}
+
+// Test that exec always has HOME set in its environment, even when it is not set by run.
+func TestExecEnvHasHome(t *testing.T) {
+ // Base alpine image does not have any environment variables set.
+ if err := dockerutil.Pull("alpine"); err != nil {
+ t.Fatalf("docker pull failed: %v", err)
+ }
+ d := dockerutil.MakeDocker("exec-env-home-test")
+
+ // We will check that HOME is set for the root user, and also for a new
+ // non-root user that we create.
+ newUID := 1234
+ newHome := "/foo/bar"
+
+ // Create a new user with a home directory, and then sleep.
+ script := fmt.Sprintf(`
+ mkdir -p -m 777 %s && \
+ adduser foo -D -u %d -h %s && \
+ sleep 1000`, newHome, newUID, newHome)
+ if err := d.Run("alpine", "/bin/sh", "-c", script); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+ defer d.CleanUp()
+
+ // Exec "echo $HOME", and expect to see "/root".
+ got, err := d.Exec("/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := "/root"; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+
+ // Execute the same command as newUID (1234) and expect newHome.
+ got, err = d.ExecAsUser(strconv.Itoa(newUID), "/bin/sh", "-c", "echo $HOME")
+ if err != nil {
+ t.Fatalf("docker exec failed: %v", err)
+ }
+ if want := newHome; !strings.Contains(got, want) {
+ t.Errorf("wanted exec output to contain %q, got %q", want, got)
+ }
+}
diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD
index dfb4e2a97..1cde74cfc 100644
--- a/test/runtimes/BUILD
+++ b/test/runtimes/BUILD
@@ -26,6 +26,7 @@ runtime_test(
)
runtime_test(
+ blacklist_file = "blacklist_nodejs12.4.0.csv",
image = "gcr.io/gvisor-presubmit/nodejs12.4.0",
lang = "nodejs",
)
diff --git a/test/runtimes/blacklist_nodejs12.4.0.csv b/test/runtimes/blacklist_nodejs12.4.0.csv
new file mode 100644
index 000000000..9135d763c
--- /dev/null
+++ b/test/runtimes/blacklist_nodejs12.4.0.csv
@@ -0,0 +1,47 @@
+test name,bug id,comment
+benchmark/test-benchmark-fs.js,,
+benchmark/test-benchmark-module.js,,
+benchmark/test-benchmark-napi.js,,
+doctool/test-make-doc.js,b/68848110,Expected to fail.
+fixtures/test-error-first-line-offset.js,,
+fixtures/test-fs-readfile-error.js,,
+fixtures/test-fs-stat-sync-overflow.js,,
+internet/test-dgram-broadcast-multi-process.js,,
+internet/test-dgram-multicast-multi-process.js,,
+internet/test-dgram-multicast-set-interface-lo.js,,
+parallel/test-cluster-dgram-reuse.js,b/64024294,
+parallel/test-dgram-bind-fd.js,b/132447356,
+parallel/test-dgram-create-socket-handle-fd.js,b/132447238,
+parallel/test-dgram-createSocket-type.js,b/68847739,
+parallel/test-dgram-socket-buffer-size.js,b/68847921,
+parallel/test-fs-access.js,,
+parallel/test-fs-write-stream-double-close.js,,
+parallel/test-fs-write-stream-throw-type-error.js,b/110226209,
+parallel/test-fs-write-stream.js,,
+parallel/test-http2-respond-file-error-pipe-offset.js,,
+parallel/test-os.js,,
+parallel/test-process-uid-gid.js,,
+pseudo-tty/test-assert-colors.js,,
+pseudo-tty/test-assert-no-color.js,,
+pseudo-tty/test-assert-position-indicator.js,,
+pseudo-tty/test-async-wrap-getasyncid-tty.js,,
+pseudo-tty/test-fatal-error.js,,
+pseudo-tty/test-handle-wrap-isrefed-tty.js,,
+pseudo-tty/test-readable-tty-keepalive.js,,
+pseudo-tty/test-set-raw-mode-reset-process-exit.js,,
+pseudo-tty/test-set-raw-mode-reset-signal.js,,
+pseudo-tty/test-set-raw-mode-reset.js,,
+pseudo-tty/test-stderr-stdout-handle-sigwinch.js,,
+pseudo-tty/test-stdout-read.js,,
+pseudo-tty/test-tty-color-support.js,,
+pseudo-tty/test-tty-isatty.js,,
+pseudo-tty/test-tty-stdin-call-end.js,,
+pseudo-tty/test-tty-stdin-end.js,,
+pseudo-tty/test-stdin-write.js,,
+pseudo-tty/test-tty-stdout-end.js,,
+pseudo-tty/test-tty-stdout-resize.js,,
+pseudo-tty/test-tty-stream-constructors.js,,
+pseudo-tty/test-tty-window-size.js,,
+pseudo-tty/test-tty-wrap.js,,
+pummel/test-net-pingpong.js,,
+pummel/test-vm-memleak.js,,
diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl
index 19aceb6fb..7edd12c17 100644
--- a/test/runtimes/build_defs.bzl
+++ b/test/runtimes/build_defs.bzl
@@ -5,20 +5,27 @@
def runtime_test(
lang,
image,
- shard_count = 20,
- size = "enormous"):
+ shard_count = 50,
+ size = "enormous",
+ blacklist_file = ""):
+ args = [
+ "--lang",
+ lang,
+ "--image",
+ image,
+ ]
+ data = [
+ ":runner",
+ ]
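+ # The blacklist is passed both as an argument (its workspace-relative
+ # path, so the runner can locate it) and as a data dependency (so it is
+ # present in the test's runfiles).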
+ if blacklist_file != "":
+ args += ["--blacklist_file", "test/runtimes/" + blacklist_file]
+ data += [blacklist_file]
+
sh_test(
name = lang + "_test",
srcs = ["runner.sh"],
- args = [
- "--lang",
- lang,
- "--image",
- image,
- ],
- data = [
- ":runner",
- ],
+ args = args,
+ data = data,
size = size,
shard_count = shard_count,
tags = [
diff --git a/test/runtimes/images/proctor/proctor.go b/test/runtimes/images/proctor/proctor.go
index e2c198b46..e6178e82b 100644
--- a/test/runtimes/images/proctor/proctor.go
+++ b/test/runtimes/images/proctor/proctor.go
@@ -22,8 +22,10 @@ import (
"log"
"os"
"os/exec"
+ "os/signal"
"path/filepath"
"regexp"
+ "syscall"
)
// TestRunner is an interface that must be implemented for each runtime
@@ -40,11 +42,17 @@ var (
runtime = flag.String("runtime", "", "name of runtime")
list = flag.Bool("list", false, "list all available tests")
test = flag.String("test", "", "run a single test from the list of available tests")
+ pause = flag.Bool("pause", false, "cause container to pause indefinitely, reaping any zombie children")
)
func main() {
flag.Parse()
+ if *pause {
+ pauseAndReap()
+ panic("pauseAndReap should never return")
+ }
+
if *runtime == "" {
log.Fatalf("runtime flag must be provided")
}
@@ -73,7 +81,7 @@ func main() {
cmd := tr.TestCmd(*test)
cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
if err := cmd.Run(); err != nil {
- log.Fatalf("FAIL %q: %v", err)
+ log.Fatalf("FAIL: %v", err)
}
}
@@ -94,6 +102,27 @@ func testRunnerForRuntime(runtime string) (TestRunner, error) {
return nil, fmt.Errorf("invalid runtime %q", runtime)
}
+// pauseAndReap runs forever, like an init process, reaping any zombie children.
+func pauseAndReap() {
+ // Get notified of any new children.
+ ch := make(chan os.Signal, 1)
+ signal.Notify(ch, syscall.SIGCHLD)
+
+ for {
+ if _, ok := <-ch; !ok {
+ // Channel closed. This should not happen.
+ panic("signal channel closed")
+ }
+
+ // Reap the child.
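+ // Multiple SIGCHLDs may coalesce into a single delivery, so keep
+ // reaping until no exited children remain.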
+ for {
+ if cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil); cpid < 1 {
+ break
+ }
+ }
+ }
+}
+
// search is a helper function to find tests in the given directory that match
// the regex.
func search(root string, testFilter *regexp.Regexp) ([]string, error) {
diff --git a/test/runtimes/runner.go b/test/runtimes/runner.go
index 3111963eb..bec37c69d 100644
--- a/test/runtimes/runner.go
+++ b/test/runtimes/runner.go
@@ -16,8 +16,10 @@
package main
import (
+ "encoding/csv"
"flag"
"fmt"
+ "io"
"log"
"os"
"sort"
@@ -30,8 +32,9 @@ import (
)
var (
- lang = flag.String("lang", "", "language runtime to test")
- image = flag.String("image", "", "docker image with runtime tests")
+ lang = flag.String("lang", "", "language runtime to test")
+ image = flag.String("image", "", "docker image with runtime tests")
+ blacklistFile = flag.String("blacklist_file", "", "file containing blacklist of tests to exclude, in CSV format with fields: test name, bug id, comment")
)
// Wait time for each test to run.
@@ -43,30 +46,59 @@ func main() {
fmt.Fprintf(os.Stderr, "lang and image flags must not be empty\n")
os.Exit(1)
}
- tests, err := testsForImage(*lang, *image)
+
+ os.Exit(runTests())
+}
+
+// runTests is a helper that is called by main. It exists so that we can run
+// deferred functions before exiting. It returns an exit code that should be
+// passed to os.Exit.
+func runTests() int {
+ // Get tests to blacklist.
+ blacklist, err := getBlacklist()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error getting blacklist: %s\n", err.Error())
+ return 1
+ }
+
+ // Create a single docker container that will be used for all tests.
+ d := dockerutil.MakeDocker("gvisor-" + *lang)
+ defer d.CleanUp()
+
+ // Get a slice of tests to run. This will also start a single Docker
+ // container that will be used to run each test. The deferred CleanUp
+ // above stops the container once all tests have run.
+ tests, err := getTests(d, blacklist)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err.Error())
- os.Exit(1)
+ return 1
}
- testing.Main(func(a, b string) (bool, error) {
- return a == b, nil
- }, tests, nil, nil)
+ m := testing.MainStart(testDeps{}, tests, nil, nil)
+ return m.Run()
}
-func testsForImage(lang, image string) ([]testing.InternalTest, error) {
- if err := dockerutil.Pull(image); err != nil {
- return nil, fmt.Errorf("docker pull failed: %v", err)
+// getTests returns a slice of tests to run, subject to the shard size and
+// index.
+func getTests(d dockerutil.Docker, blacklist map[string]struct{}) ([]testing.InternalTest, error) {
+ // Pull the image.
+ if err := dockerutil.Pull(*image); err != nil {
+ return nil, fmt.Errorf("docker pull %q failed: %v", *image, err)
}
- c := dockerutil.MakeDocker("gvisor-list")
- list, err := c.RunFg(image, "--runtime", lang, "--list")
- defer c.CleanUp()
- if err != nil {
+ // Run proctor with --pause flag to keep container alive forever.
+ if err := d.Run(*image, "--pause"); err != nil {
return nil, fmt.Errorf("docker run failed: %v", err)
}
- // Get subset of tests corresponding to shard.
+ // Get a list of all tests in the image.
+ list, err := d.Exec("/proctor", "--runtime", *lang, "--list")
+ if err != nil {
+ return nil, fmt.Errorf("docker exec failed: %v", err)
+ }
+
+ // Calculate a subset of tests to run corresponding to the current
+ // shard.
tests := strings.Fields(list)
sort.Strings(tests)
begin, end, err := testutil.TestBoundsForShard(len(tests))
@@ -77,33 +109,91 @@ func testsForImage(lang, image string) ([]testing.InternalTest, error) {
tests = tests[begin:end]
var itests []testing.InternalTest
- for i, tc := range tests {
+ for _, tc := range tests {
// Capture tc in this scope.
tc := tc
itests = append(itests, testing.InternalTest{
Name: tc,
F: func(t *testing.T) {
- d := dockerutil.MakeDocker(fmt.Sprintf("gvisor-test-%d", i))
- defer d.CleanUp()
- if err := d.Run(image, "--runtime", lang, "--test", tc); err != nil {
- t.Fatalf("docker test %q failed to run: %v", tc, err)
+ // Is the test blacklisted?
+ if _, ok := blacklist[tc]; ok {
+ t.Skip("SKIP: blacklisted test %q", tc)
}
- status, err := d.Wait(timeout)
- if err != nil {
- t.Fatalf("docker test %q failed to wait: %v", tc, err)
- }
- logs, err := d.Logs()
- if err != nil {
- t.Fatalf("docker test %q failed to supply logs: %v", tc, err)
- }
- if status == 0 {
- t.Logf("test %q passed", tc)
- return
+ var (
+ now = time.Now()
+ done = make(chan struct{})
+ output string
+ err error
+ )
+
+ go func() {
+ fmt.Printf("RUNNING %s...\n", tc)
+ output, err = d.Exec("/proctor", "--runtime", *lang, "--test", tc)
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ if err == nil {
+ fmt.Printf("PASS: %s (%v)\n\n", tc, time.Since(now))
+ return
+ }
+ t.Errorf("FAIL: %s (%v):\n%s\n", tc, time.Since(now), output)
+ case <-time.After(timeout):
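+ // The exec goroutine is still running at this point, so output is
+ // typically empty.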
+ t.Errorf("TIMEOUT: %s (%v):\n%s\n", tc, time.Since(now), output)
}
- t.Errorf("test %q failed: %v", tc, logs)
},
})
}
return itests, nil
}
+
+// getBlacklist reads the blacklist file and returns a set of test names to
+// exclude.
+func getBlacklist() (map[string]struct{}, error) {
+ blacklist := make(map[string]struct{})
+ if *blacklistFile == "" {
+ return blacklist, nil
+ }
+ file, err := testutil.FindFile(*blacklistFile)
+ if err != nil {
+ return nil, err
+ }
+ f, err := os.Open(file)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ r := csv.NewReader(f)
+
+ // First line is header. Skip it.
+ if _, err := r.Read(); err != nil {
+ return nil, err
+ }
+
+ for {
+ record, err := r.Read()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ return nil, err
+ }
+ blacklist[record[0]] = struct{}{}
+ }
+ return blacklist, nil
+}
+
+// testDeps implements testing.testDeps (an unexported interface), and is
+// required to use testing.MainStart.
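+// Note that testing.MainStart is not covered by the Go 1 compatibility
+// promise, so its signature may change between Go releases.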
+type testDeps struct{}
+
+func (f testDeps) MatchString(a, b string) (bool, error) { return a == b, nil }
+func (f testDeps) StartCPUProfile(io.Writer) error { return nil }
+func (f testDeps) StopCPUProfile() {}
+func (f testDeps) WriteProfileTo(string, io.Writer, int) error { return nil }
+func (f testDeps) ImportPath() string { return "" }
+func (f testDeps) StartTestLog(io.Writer) {}
+func (f testDeps) StopTestLog() error { return nil }
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index 63e4c63dd..341e6b252 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -346,6 +346,11 @@ syscall_test(
)
syscall_test(
+ add_overlay = True,
+ test = "//test/syscalls/linux:readahead_test",
+)
+
+syscall_test(
size = "medium",
shard_count = 5,
test = "//test/syscalls/linux:readv_socket_test",
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index a4cebf46f..b555bd7aa 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -333,6 +333,7 @@ cc_binary(
linkstatic = 1,
deps = [
":socket_test_util",
+ "//test/util:file_descriptor",
"//test/util:test_main",
"//test/util:test_util",
"@com_google_googletest//:gtest",
@@ -1736,6 +1737,20 @@ cc_binary(
)
cc_binary(
+ name = "readahead_test",
+ testonly = 1,
+ srcs = ["readahead.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:file_descriptor",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "readv_test",
testonly = 1,
srcs = [
@@ -2450,6 +2465,63 @@ cc_binary(
)
cc_binary(
+ name = "socket_bind_to_device_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_sequence_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_sequence.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
+ name = "socket_bind_to_device_distribution_test",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_distribution.cc",
+ ],
+ linkstatic = 1,
+ deps = [
+ ":ip_socket_test_util",
+ ":socket_bind_to_device_util",
+ ":socket_test_util",
+ "//test/util:capability_util",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:thread_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_binary(
name = "socket_ip_udp_loopback_non_blocking_test",
testonly = 1,
srcs = [
@@ -2726,6 +2798,23 @@ cc_library(
alwayslink = 1,
)
+cc_library(
+ name = "socket_bind_to_device_util",
+ testonly = 1,
+ srcs = [
+ "socket_bind_to_device_util.cc",
+ ],
+ hdrs = [
+ "socket_bind_to_device_util.h",
+ ],
+ deps = [
+ "//test/util:test_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ ],
+ alwayslink = 1,
+)
+
cc_binary(
name = "socket_stream_local_test",
testonly = 1,
@@ -3149,8 +3238,6 @@ cc_binary(
testonly = 1,
srcs = ["timers.cc"],
linkstatic = 1,
- # FIXME(b/136599201)
- tags = ["flaky"],
deps = [
"//test/util:cleanup",
"//test/util:logging",
@@ -3239,6 +3326,7 @@ cc_binary(
"//test/util:test_main",
"//test/util:test_util",
"//test/util:thread_util",
+ "//test/util:uid_util",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest",
diff --git a/test/syscalls/linux/clock_nanosleep.cc b/test/syscalls/linux/clock_nanosleep.cc
index 52a69d230..b55cddc52 100644
--- a/test/syscalls/linux/clock_nanosleep.cc
+++ b/test/syscalls/linux/clock_nanosleep.cc
@@ -43,7 +43,7 @@ int sys_clock_nanosleep(clockid_t clkid, int flags,
PosixErrorOr<absl::Time> GetTime(clockid_t clk) {
struct timespec ts = {};
- int rc = clock_gettime(clk, &ts);
+ const int rc = clock_gettime(clk, &ts);
MaybeSave();
if (rc < 0) {
return PosixError(errno, "clock_gettime");
@@ -67,31 +67,32 @@ TEST_P(WallClockNanosleepTest, InvalidValues) {
}
TEST_P(WallClockNanosleepTest, SleepOneSecond) {
- absl::Duration const duration = absl::Seconds(1);
- struct timespec dur = absl::ToTimespec(duration);
+ constexpr absl::Duration kSleepDuration = absl::Seconds(1);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
- absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &dur, &dur),
- SyscallSucceeds());
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ EXPECT_THAT(
+ RetryEINTR(sys_clock_nanosleep)(GetParam(), 0, &duration, &duration),
+ SyscallSucceeds());
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- EXPECT_GE(after - before, duration);
+ EXPECT_GE(after - before, kSleepDuration);
}
TEST_P(WallClockNanosleepTest, InterruptedNanosleep) {
- absl::Duration const duration = absl::Seconds(60);
- struct timespec dur = absl::ToTimespec(duration);
+ constexpr absl::Duration kSleepDuration = absl::Seconds(60);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
// Install no-op signal handler for SIGALRM.
struct sigaction sa = {};
sigfillset(&sa.sa_mask);
sa.sa_handler = +[](int signo) {};
- auto const cleanup_sa =
+ const auto cleanup_sa =
ASSERT_NO_ERRNO_AND_VALUE(ScopedSigaction(SIGALRM, sa));
// Measure time since setting the alarm, since the alarm will interrupt the
// sleep and hence determine how long we sleep.
- absl::Time const before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time before = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
// Set an alarm to go off while sleeping.
struct itimerval timer = {};
@@ -99,26 +100,51 @@ TEST_P(WallClockNanosleepTest, InterruptedNanosleep) {
timer.it_value.tv_usec = 0;
timer.it_interval.tv_sec = 1;
timer.it_interval.tv_usec = 0;
- auto const cleanup =
+ const auto cleanup =
ASSERT_NO_ERRNO_AND_VALUE(ScopedItimer(ITIMER_REAL, timer));
- EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &dur, &dur),
+ EXPECT_THAT(sys_clock_nanosleep(GetParam(), 0, &duration, &duration),
SyscallFailsWithErrno(EINTR));
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- absl::Duration const remaining = absl::DurationFromTimespec(dur);
- EXPECT_GE(after - before + remaining, duration);
+ // Remaining time updated.
+ const absl::Duration remaining = absl::DurationFromTimespec(duration);
+ EXPECT_GE(after - before + remaining, kSleepDuration);
+}
+
+// Remaining time is *not* updated if nanosleep completes uninterrupted.
+TEST_P(WallClockNanosleepTest, UninterruptedNanosleep) {
+ constexpr absl::Duration kSleepDuration = absl::Milliseconds(10);
+ const struct timespec duration = absl::ToTimespec(kSleepDuration);
+
+ while (true) {
+ constexpr int kRemainingMagic = 42;
+ struct timespec remaining;
+ remaining.tv_sec = kRemainingMagic;
+ remaining.tv_nsec = kRemainingMagic;
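+ // The sentinel values let us detect whether the kernel wrote to
+ // 'remaining' at all.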
+
+ int ret = sys_clock_nanosleep(GetParam(), 0, &duration, &remaining);
+ if (ret == EINTR) {
+ // Retry from beginning. We want a single uninterrupted call.
+ continue;
+ }
+
+ EXPECT_THAT(ret, SyscallSucceeds());
+ EXPECT_EQ(remaining.tv_sec, kRemainingMagic);
+ EXPECT_EQ(remaining.tv_nsec, kRemainingMagic);
+ break;
+ }
}
TEST_P(WallClockNanosleepTest, SleepUntil) {
- absl::Time const now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
- absl::Time const until = now + absl::Seconds(2);
- struct timespec ts = absl::ToTimespec(until);
+ const absl::Time now = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time until = now + absl::Seconds(2);
+ const struct timespec ts = absl::ToTimespec(until);
EXPECT_THAT(
RetryEINTR(sys_clock_nanosleep)(GetParam(), TIMER_ABSTIME, &ts, nullptr),
SyscallSucceeds());
- absl::Time const after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
+ const absl::Time after = ASSERT_NO_ERRNO_AND_VALUE(GetTime(GetParam()));
EXPECT_GE(after, until);
}
@@ -127,8 +153,8 @@ INSTANTIATE_TEST_SUITE_P(Sleepers, WallClockNanosleepTest,
::testing::Values(CLOCK_REALTIME, CLOCK_MONOTONIC));
TEST(ClockNanosleepProcessTest, SleepFiveSeconds) {
- absl::Duration const kDuration = absl::Seconds(5);
- struct timespec dur = absl::ToTimespec(kDuration);
+ const absl::Duration kSleepDuration = absl::Seconds(5);
+ struct timespec duration = absl::ToTimespec(kSleepDuration);
// Ensure that CLOCK_PROCESS_CPUTIME_ID advances.
std::atomic<bool> done(false);
@@ -136,16 +162,16 @@ TEST(ClockNanosleepProcessTest, SleepFiveSeconds) {
while (!done.load()) {
}
});
- auto const cleanup_done = Cleanup([&] { done.store(true); });
+ const auto cleanup_done = Cleanup([&] { done.store(true); });
- absl::Time const before =
+ const absl::Time before =
ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID));
- EXPECT_THAT(
- RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0, &dur, &dur),
- SyscallSucceeds());
- absl::Time const after =
+ EXPECT_THAT(RetryEINTR(sys_clock_nanosleep)(CLOCK_PROCESS_CPUTIME_ID, 0,
+ &duration, &duration),
+ SyscallSucceeds());
+ const absl::Time after =
ASSERT_NO_ERRNO_AND_VALUE(GetTime(CLOCK_PROCESS_CPUTIME_ID));
- EXPECT_GE(after - before, kDuration);
+ EXPECT_GE(after - before, kSleepDuration);
}
} // namespace
diff --git a/test/syscalls/linux/exec_binary.cc b/test/syscalls/linux/exec_binary.cc
index 91b55015c..0a3931e5a 100644
--- a/test/syscalls/linux/exec_binary.cc
+++ b/test/syscalls/linux/exec_binary.cc
@@ -401,12 +401,17 @@ TEST(ElfTest, DataSegment) {
})));
}
-// Additonal pages beyond filesz are always RW.
+// Additional pages beyond filesz honor (only) execute protections.
//
-// N.B. Linux uses set_brk -> vm_brk to additional pages beyond filesz (even
-// though start_brk itself will always be beyond memsz). As a result, the
-// segment permissions don't apply; the mapping is always RW.
+// N.B. Linux changed this in 4.11 (16e72e9b30986 "powerpc: do not make the
+// entire heap executable"). Previously, extra pages were always RW.
TEST(ElfTest, ExtraMemPages) {
+ // gVisor has the newer behavior.
+ if (!IsRunningOnGvisor()) {
+ auto version = ASSERT_NO_ERRNO_AND_VALUE(GetKernelVersion());
+ SKIP_IF(version.major < 4 || (version.major == 4 && version.minor < 11));
+ }
+
ElfBinary<64> elf = StandardElf();
// Create a standard ELF, but extend to 1.5 pages. The second page will be the
@@ -415,7 +420,7 @@ TEST(ElfTest, ExtraMemPages) {
decltype(elf)::ElfPhdr phdr = {};
phdr.p_type = PT_LOAD;
- // RWX segment. The extra anon page will be RW anyways.
+ // RWX segment. The extra anon page will also be RWX.
//
// N.B. Linux uses clear_user to clear the end of the file-mapped page, which
// respects the mapping protections. Thus if we map this RO with memsz >
@@ -454,7 +459,7 @@ TEST(ElfTest, ExtraMemPages) {
{0x41000, 0x42000, true, true, true, true, kPageSize, 0, 0, 0,
file.path().c_str()},
// extra page from anon.
- {0x42000, 0x43000, true, true, false, true, 0, 0, 0, 0, ""},
+ {0x42000, 0x43000, true, true, true, true, 0, 0, 0, 0, ""},
})));
}
@@ -469,7 +474,7 @@ TEST(ElfTest, AnonOnlySegment) {
phdr.p_offset = 0;
phdr.p_vaddr = 0x41000;
phdr.p_filesz = 0;
- phdr.p_memsz = kPageSize - 0xe8;
+ phdr.p_memsz = kPageSize;
elf.phdrs.push_back(phdr);
elf.UpdateOffsets();
@@ -854,6 +859,11 @@ TEST(ElfTest, ELFInterpreter) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // N.B. Since Linux 4.10 (0036d1f7eb95b "binfmt_elf: fix calculations for bss
+ // padding"), Linux unconditionally zeroes the remainder of the highest mapped
+ // page in an interpreter, failing if the protections don't allow write. Thus
+ // we must mark this writeable.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
@@ -903,15 +913,15 @@ TEST(ElfTest, ELFInterpreter) {
const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
- EXPECT_THAT(child,
- ContainsMappings(std::vector<ProcMapsEntry>({
- // Main binary
- {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
- binary_file.path().c_str()},
- // Interpreter
- {interp_load_addr, interp_load_addr + 0x1000, true, false,
- true, true, 0, 0, 0, 0, interpreter_file.path().c_str()},
- })));
+ EXPECT_THAT(
+ child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // Main binary
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ binary_file.path().c_str()},
+ // Interpreter
+ {interp_load_addr, interp_load_addr + 0x1000, true, true, true,
+ true, 0, 0, 0, 0, interpreter_file.path().c_str()},
+ })));
}
// Test parameter to ElfInterpterStaticTest cases. The first item is a suffix to
@@ -928,6 +938,8 @@ TEST_P(ElfInterpreterStaticTest, Test) {
const int expected_errno = std::get<1>(GetParam());
ElfBinary<64> interpreter = StandardElf();
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.UpdateOffsets();
TempPath interpreter_file =
ASSERT_NO_ERRNO_AND_VALUE(CreateElfWith(interpreter));
@@ -957,7 +969,7 @@ TEST_P(ElfInterpreterStaticTest, Test) {
EXPECT_THAT(child, ContainsMappings(std::vector<ProcMapsEntry>({
// Interpreter.
- {0x40000, 0x41000, true, false, true, true, 0, 0, 0,
+ {0x40000, 0x41000, true, true, true, true, 0, 0, 0,
0, interpreter_file.path().c_str()},
})));
}
@@ -1035,6 +1047,8 @@ TEST(ElfTest, ELFInterpreterRelative) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
@@ -1073,15 +1087,15 @@ TEST(ElfTest, ELFInterpreterRelative) {
const uint64_t interp_load_addr = regs.rip & ~(kPageSize - 1);
- EXPECT_THAT(child,
- ContainsMappings(std::vector<ProcMapsEntry>({
- // Main binary
- {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
- binary_file.path().c_str()},
- // Interpreter
- {interp_load_addr, interp_load_addr + 0x1000, true, false,
- true, true, 0, 0, 0, 0, interpreter_file.path().c_str()},
- })));
+ EXPECT_THAT(
+ child, ContainsMappings(std::vector<ProcMapsEntry>({
+ // Main binary
+ {0x40000, 0x41000, true, false, true, true, 0, 0, 0, 0,
+ binary_file.path().c_str()},
+ // Interpreter
+ {interp_load_addr, interp_load_addr + 0x1000, true, true, true,
+ true, 0, 0, 0, 0, interpreter_file.path().c_str()},
+ })));
}
// ELF interpreter architecture doesn't match the binary.
@@ -1095,6 +1109,8 @@ TEST(ElfTest, ELFInterpreterWrongArch) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
@@ -1174,6 +1190,8 @@ TEST(ElfTest, ElfInterpreterNoExecute) {
// The first segment really needs to start at 0 for a normal PIE binary, and
// thus includes the headers.
uint64_t const offset = interpreter.phdrs[1].p_offset;
+ // See comment in ElfTest.ELFInterpreter.
+ interpreter.phdrs[1].p_flags = PF_R | PF_W | PF_X;
interpreter.phdrs[1].p_offset = 0x0;
interpreter.phdrs[1].p_vaddr = 0x0;
interpreter.phdrs[1].p_filesz += offset;
diff --git a/test/syscalls/linux/itimer.cc b/test/syscalls/linux/itimer.cc
index 51ce323b9..930d2b940 100644
--- a/test/syscalls/linux/itimer.cc
+++ b/test/syscalls/linux/itimer.cc
@@ -336,7 +336,9 @@ int main(int argc, char** argv) {
}
if (arg == gvisor::testing::kSIGPROFFairnessIdle) {
MaskSIGPIPE();
- return gvisor::testing::TestSIGPROFFairness(absl::Milliseconds(10));
+ // Sleep time > ClockTick (10ms) exercises sleeping gVisor's
+ // kernel.cpuClockTicker.
+ return gvisor::testing::TestSIGPROFFairness(absl::Milliseconds(25));
}
}
diff --git a/test/syscalls/linux/packet_socket.cc b/test/syscalls/linux/packet_socket.cc
index 7a3379b9e..37b4e6575 100644
--- a/test/syscalls/linux/packet_socket.cc
+++ b/test/syscalls/linux/packet_socket.cc
@@ -83,9 +83,15 @@ void SendUDPMessage(int sock) {
// Send an IP packet and make sure ETH_P_<something else> doesn't pick it up.
TEST(BasicCookedPacketTest, WrongType) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
FileDescriptor sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_PACKET, SOCK_DGRAM, ETH_P_PUP));
@@ -118,18 +124,27 @@ class CookedPacketTest : public ::testing::TestWithParam<int> {
};
void CookedPacketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
ASSERT_THAT(socket_ = socket(AF_PACKET, SOCK_DGRAM, htons(GetParam())),
SyscallSucceeds());
}
void CookedPacketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
int CookedPacketTest::GetLoopbackIndex() {
@@ -142,9 +157,6 @@ int CookedPacketTest::GetLoopbackIndex() {
// Receive via a packet socket.
TEST_P(CookedPacketTest, Receive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's use a simple IP payload: a UDP datagram.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
@@ -201,9 +213,6 @@ TEST_P(CookedPacketTest, Receive) {
// Send via a packet socket.
TEST_P(CookedPacketTest, Send) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's send a UDP packet and receive it using a regular UDP socket.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
diff --git a/test/syscalls/linux/packet_socket_raw.cc b/test/syscalls/linux/packet_socket_raw.cc
index 9e96460ee..6491453b6 100644
--- a/test/syscalls/linux/packet_socket_raw.cc
+++ b/test/syscalls/linux/packet_socket_raw.cc
@@ -97,9 +97,15 @@ class RawPacketTest : public ::testing::TestWithParam<int> {
};
void RawPacketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_PACKET, SOCK_RAW, htons(GetParam())),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
+
if (!IsRunningOnGvisor()) {
FileDescriptor acceptLocal = ASSERT_NO_ERRNO_AND_VALUE(
Open("/proc/sys/net/ipv4/conf/lo/accept_local", O_RDONLY));
@@ -119,10 +125,13 @@ void RawPacketTest::SetUp() {
}
void RawPacketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ // (b/129292371): Remove once we support packet sockets.
SKIP_IF(IsRunningOnGvisor());
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
int RawPacketTest::GetLoopbackIndex() {
@@ -135,9 +144,6 @@ int RawPacketTest::GetLoopbackIndex() {
// Receive via a packet socket.
TEST_P(RawPacketTest, Receive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's use a simple IP payload: a UDP datagram.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
@@ -208,9 +214,6 @@ TEST_P(RawPacketTest, Receive) {
// Send via a packet socket.
TEST_P(RawPacketTest, Send) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
- SKIP_IF(IsRunningOnGvisor());
-
// Let's send a UDP packet and receive it using a regular UDP socket.
FileDescriptor udp_sock =
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, 0));
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index 6f07803d9..e4c030bbb 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -440,6 +440,11 @@ TEST(ProcSelfAuxv, EntryPresence) {
EXPECT_EQ(auxv_entries.count(AT_PHENT), 1);
EXPECT_EQ(auxv_entries.count(AT_PHNUM), 1);
EXPECT_EQ(auxv_entries.count(AT_BASE), 1);
+ EXPECT_EQ(auxv_entries.count(AT_UID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_EUID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_GID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_EGID), 1);
+ EXPECT_EQ(auxv_entries.count(AT_SECURE), 1);
EXPECT_EQ(auxv_entries.count(AT_CLKTCK), 1);
EXPECT_EQ(auxv_entries.count(AT_RANDOM), 1);
EXPECT_EQ(auxv_entries.count(AT_EXECFN), 1);
diff --git a/test/syscalls/linux/proc_net_tcp.cc b/test/syscalls/linux/proc_net_tcp.cc
index cb81b4e6e..f61795592 100644
--- a/test/syscalls/linux/proc_net_tcp.cc
+++ b/test/syscalls/linux/proc_net_tcp.cc
@@ -269,16 +269,14 @@ struct TCP6Entry {
};
bool IPv6AddrEqual(const struct in6_addr* a1, const struct in6_addr* a2) {
- if (memcmp(a1, a2, sizeof(struct in6_addr)) == 0)
- return true;
- return false;
+ return memcmp(a1, a2, sizeof(struct in6_addr)) == 0;
}
// Finds the first entry in 'entries' for which 'predicate' returns true.
// Returns true on match, and sets 'match' to a copy of the matching entry. If
// 'match' is null, it's ignored.
bool FindBy6(const std::vector<TCP6Entry>& entries, TCP6Entry* match,
- std::function<bool(const TCP6Entry&)> predicate) {
+ std::function<bool(const TCP6Entry&)> predicate) {
for (const TCP6Entry& entry : entries) {
if (predicate(entry)) {
if (match != nullptr) {
@@ -305,7 +303,7 @@ bool FindByLocalAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match,
}
bool FindByRemoteAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match,
- const struct sockaddr* addr) {
+ const struct sockaddr* addr) {
const struct in6_addr* remote = IP6FromInetSockaddr(addr);
uint16_t port = PortFromInetSockaddr(addr);
return FindBy6(entries, match, [remote, port](const TCP6Entry& e) {
@@ -315,14 +313,14 @@ bool FindByRemoteAddr6(const std::vector<TCP6Entry>& entries, TCP6Entry* match,
void ReadIPv6Address(std::string s, struct in6_addr* addr) {
uint32_t a0, a1, a2, a3;
- const char *fmt = "%08X%08X%08X%08X";
+ const char* fmt = "%08X%08X%08X%08X";
EXPECT_EQ(sscanf(s.c_str(), fmt, &a0, &a1, &a2, &a3), 4);
- uint8_t *b = addr->s6_addr;
- *((uint32_t *)&b[0]) = a0;
- *((uint32_t *)&b[4]) = a1;
- *((uint32_t *)&b[8]) = a2;
- *((uint32_t *)&b[12]) = a3;
+ uint8_t* b = addr->s6_addr;
+ *((uint32_t*)&b[0]) = a0;
+ *((uint32_t*)&b[4]) = a1;
+ *((uint32_t*)&b[8]) = a2;
+ *((uint32_t*)&b[12]) = a3;
}
// Returns a parsed representation of /proc/net/tcp6 entries.
@@ -429,7 +427,8 @@ TEST(ProcNetTCP6, InodeReasonable) {
TCP6Entry accepted_entry;
- ASSERT_TRUE(FindByLocalAddr6(entries, &accepted_entry, sockets->first_addr()));
+ ASSERT_TRUE(
+ FindByLocalAddr6(entries, &accepted_entry, sockets->first_addr()));
EXPECT_NE(accepted_entry.inode, 0);
TCP6Entry client_entry;
@@ -481,14 +480,14 @@ TEST(ProcNetTCP6, State) {
entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetTCP6Entries());
TCP6Entry accepted_entry;
- ASSERT_TRUE(FindBy6(entries, &accepted_entry,
- [client_entry, local,
- accepted_local_port](const TCP6Entry& e) {
- return IPv6AddrEqual(&e.local_addr, local) &&
- e.local_port == accepted_local_port &&
- IPv6AddrEqual(&e.remote_addr, &client_entry.local_addr) &&
- e.remote_port == client_entry.local_port;
- }));
+ ASSERT_TRUE(FindBy6(
+ entries, &accepted_entry,
+ [client_entry, local, accepted_local_port](const TCP6Entry& e) {
+ return IPv6AddrEqual(&e.local_addr, local) &&
+ e.local_port == accepted_local_port &&
+ IPv6AddrEqual(&e.remote_addr, &client_entry.local_addr) &&
+ e.remote_port == client_entry.local_port;
+ }));
EXPECT_EQ(accepted_entry.state, TCP_ESTABLISHED);
}
diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc
index 286388316..99a0df235 100644
--- a/test/syscalls/linux/pty.cc
+++ b/test/syscalls/linux/pty.cc
@@ -1292,10 +1292,9 @@ TEST_F(JobControlTest, ReleaseTTY) {
// Make sure we're ignoring SIGHUP, which will be sent to this process once we
// disconnect they TTY.
- struct sigaction sa = {
- .sa_handler = SIG_IGN,
- .sa_flags = 0,
- };
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ sa.sa_flags = 0;
sigemptyset(&sa.sa_mask);
struct sigaction old_sa;
EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds());
@@ -1362,10 +1361,9 @@ TEST_F(JobControlTest, ReleaseTTYSignals) {
ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
received = 0;
- struct sigaction sa = {
- .sa_handler = sig_handler,
- .sa_flags = 0,
- };
+ struct sigaction sa = {};
+ sa.sa_handler = sig_handler;
+ sa.sa_flags = 0;
sigemptyset(&sa.sa_mask);
sigaddset(&sa.sa_mask, SIGHUP);
sigaddset(&sa.sa_mask, SIGCONT);
@@ -1403,10 +1401,9 @@ TEST_F(JobControlTest, ReleaseTTYSignals) {
// Make sure we're ignoring SIGHUP, which will be sent to this process once we
// disconnect they TTY.
- struct sigaction sighup_sa = {
- .sa_handler = SIG_IGN,
- .sa_flags = 0,
- };
+ struct sigaction sighup_sa = {};
+ sighup_sa.sa_handler = SIG_IGN;
+ sighup_sa.sa_flags = 0;
sigemptyset(&sighup_sa.sa_mask);
struct sigaction old_sa;
EXPECT_THAT(sigaction(SIGHUP, &sighup_sa, &old_sa), SyscallSucceeds());
@@ -1456,10 +1453,9 @@ TEST_F(JobControlTest, SetForegroundProcessGroup) {
ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
// Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp.
- struct sigaction sa = {
- .sa_handler = SIG_IGN,
- .sa_flags = 0,
- };
+ struct sigaction sa = {};
+ sa.sa_handler = SIG_IGN;
+ sa.sa_flags = 0;
sigemptyset(&sa.sa_mask);
sigaction(SIGTTOU, &sa, NULL);
@@ -1531,27 +1527,36 @@ TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) {
TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) {
ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds());
+ int sync_setsid[2];
+ int sync_exit[2];
+ ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds());
+ ASSERT_THAT(pipe(sync_exit), SyscallSucceeds());
+
// Create a new process and put it in a new session.
pid_t child = fork();
if (!child) {
TEST_PCHECK(setsid() >= 0);
// Tell the parent we're in a new session.
- TEST_PCHECK(!raise(SIGSTOP));
- TEST_PCHECK(!pause());
- _exit(1);
+ char c = 'c';
+ TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1);
+ TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1);
+ _exit(0);
}
// Wait for the child to tell us it's in a new session.
- int wstatus;
- EXPECT_THAT(waitpid(child, &wstatus, WUNTRACED),
- SyscallSucceedsWithValue(child));
- EXPECT_TRUE(WSTOPSIG(wstatus));
+ char c = 'c';
+ ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1));
// Child is in a new session, so we can't make it the foregroup process group.
EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child),
SyscallFailsWithErrno(EPERM));
- EXPECT_THAT(kill(child, SIGKILL), SyscallSucceeds());
+ EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1));
+
+ int wstatus;
+ EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFEXITED(wstatus));
+ EXPECT_EQ(WEXITSTATUS(wstatus), 0);
}
// Verify that we don't hang when creating a new session from an orphaned
diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc
index a070817eb..0a27506aa 100644
--- a/test/syscalls/linux/raw_socket_hdrincl.cc
+++ b/test/syscalls/linux/raw_socket_hdrincl.cc
@@ -63,7 +63,11 @@ class RawHDRINCL : public ::testing::Test {
};
void RawHDRINCL::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
SyscallSucceeds());
@@ -76,9 +80,10 @@ void RawHDRINCL::SetUp() {
}
void RawHDRINCL::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(socket_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(socket_), SyscallSucceeds());
+ }
}
struct iphdr RawHDRINCL::LoopbackHeader() {
@@ -123,8 +128,6 @@ bool RawHDRINCL::FillPacket(char* buf, size_t buf_size, int port,
// We should be able to create multiple IPPROTO_RAW sockets. RawHDRINCL::Setup
// creates the first one, so we only have to create one more here.
TEST_F(RawHDRINCL, MultipleCreation) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int s2;
ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds());
@@ -133,23 +136,17 @@ TEST_F(RawHDRINCL, MultipleCreation) {
// Test that shutting down an unconnected socket fails.
TEST_F(RawHDRINCL, FailShutdownWithoutConnect) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(shutdown(socket_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
ASSERT_THAT(shutdown(socket_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
}
// Test that listen() fails.
TEST_F(RawHDRINCL, FailListen) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(listen(socket_, 1), SyscallFailsWithErrno(ENOTSUP));
}
// Test that accept() fails.
TEST_F(RawHDRINCL, FailAccept) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct sockaddr saddr;
socklen_t addrlen;
ASSERT_THAT(accept(socket_, &saddr, &addrlen),
@@ -158,8 +155,6 @@ TEST_F(RawHDRINCL, FailAccept) {
// Test that the socket is writable immediately.
TEST_F(RawHDRINCL, PollWritableImmediately) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct pollfd pfd = {};
pfd.fd = socket_;
pfd.events = POLLOUT;
@@ -168,8 +163,6 @@ TEST_F(RawHDRINCL, PollWritableImmediately) {
// Test that the socket isn't readable.
TEST_F(RawHDRINCL, NotReadable) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
// Try to receive data with MSG_DONTWAIT, which returns immediately if there's
// nothing to be read.
char buf[117];
@@ -179,16 +172,12 @@ TEST_F(RawHDRINCL, NotReadable) {
// Test that we can connect() to a valid IP (loopback).
TEST_F(RawHDRINCL, ConnectToLoopback) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
}
TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct iphdr hdr = LoopbackHeader();
ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0),
SyscallSucceedsWithValue(sizeof(hdr)));
@@ -197,8 +186,6 @@ TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
// HDRINCL implies write-only. Verify that we can't read a packet sent to
// loopback.
TEST_F(RawHDRINCL, NotReadableAfterWrite) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
@@ -221,8 +208,6 @@ TEST_F(RawHDRINCL, NotReadableAfterWrite) {
}
TEST_F(RawHDRINCL, WriteTooSmall) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
sizeof(addr_)),
SyscallSucceeds());
@@ -235,8 +220,6 @@ TEST_F(RawHDRINCL, WriteTooSmall) {
// Bind to localhost.
TEST_F(RawHDRINCL, BindToLocalhost) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
ASSERT_THAT(
bind(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
SyscallSucceeds());
@@ -244,8 +227,6 @@ TEST_F(RawHDRINCL, BindToLocalhost) {
// Bind to a different address.
TEST_F(RawHDRINCL, BindToInvalid) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
struct sockaddr_in bind_addr = {};
bind_addr.sin_family = AF_INET;
bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
@@ -256,8 +237,6 @@ TEST_F(RawHDRINCL, BindToInvalid) {
// Send and receive a packet.
TEST_F(RawHDRINCL, SendAndReceive) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
@@ -302,8 +281,6 @@ TEST_F(RawHDRINCL, SendAndReceive) {
// Send and receive a packet with nonzero IP ID.
TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
@@ -349,8 +326,6 @@ TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
// Send and receive a packet where the sendto address is not the same as the
// provided destination.
TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
int port = 40000;
if (!IsRunningOnGvisor()) {
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc
index 971592d7d..8bcaba6f1 100644
--- a/test/syscalls/linux/raw_socket_icmp.cc
+++ b/test/syscalls/linux/raw_socket_icmp.cc
@@ -77,7 +77,11 @@ class RawSocketICMPTest : public ::testing::Test {
};
void RawSocketICMPTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, IPPROTO_ICMP),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP), SyscallSucceeds());
@@ -90,9 +94,10 @@ void RawSocketICMPTest::SetUp() {
}
void RawSocketICMPTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(s_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
}
// We'll only read an echo in this case, as the kernel won't respond to the
diff --git a/test/syscalls/linux/raw_socket_ipv4.cc b/test/syscalls/linux/raw_socket_ipv4.cc
index 352037c88..cde2f07c9 100644
--- a/test/syscalls/linux/raw_socket_ipv4.cc
+++ b/test/syscalls/linux/raw_socket_ipv4.cc
@@ -67,7 +67,11 @@ class RawSocketTest : public ::testing::TestWithParam<int> {
};
void RawSocketTest::SetUp() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+ if (!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ ASSERT_THAT(socket(AF_INET, SOCK_RAW, Protocol()),
+ SyscallFailsWithErrno(EPERM));
+ GTEST_SKIP();
+ }
ASSERT_THAT(s_ = socket(AF_INET, SOCK_RAW, Protocol()), SyscallSucceeds());
@@ -79,9 +83,10 @@ void RawSocketTest::SetUp() {
}
void RawSocketTest::TearDown() {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
-
- EXPECT_THAT(close(s_), SyscallSucceeds());
+ // TearDown will be run even if we skip the test.
+ if (ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) {
+ EXPECT_THAT(close(s_), SyscallSucceeds());
+ }
}
// We should be able to create multiple raw sockets for the same protocol.
diff --git a/test/syscalls/linux/readahead.cc b/test/syscalls/linux/readahead.cc
new file mode 100644
index 000000000..09703b5c1
--- /dev/null
+++ b/test/syscalls/linux/readahead.cc
@@ -0,0 +1,91 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <errno.h>
+#include <fcntl.h>
+
+#include "gtest/gtest.h"
+#include "test/util/file_descriptor.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+TEST(ReadaheadTest, InvalidFD) {
+ EXPECT_THAT(readahead(-1, 1, 1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ReadaheadTest, InvalidOffset) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ EXPECT_THAT(readahead(fd.get(), -1, 1), SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(ReadaheadTest, ValidOffset) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+
+ // N.B. The implementation of readahead is filesystem-specific, and a file
+ // backed by ram may return EINVAL because there is nothing to be read.
+ EXPECT_THAT(readahead(fd.get(), 1, 1), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+TEST(ReadaheadTest, PastEnd) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ // See above.
+ EXPECT_THAT(readahead(fd.get(), 2, 2), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+TEST(ReadaheadTest, CrossesEnd) {
+ constexpr char kData[] = "123";
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ GetAbsoluteTestTmpdir(), kData, TempPath::kDefaultFileMode));
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ // See above.
+ EXPECT_THAT(readahead(fd.get(), 4, 2), AnyOf(SyscallSucceedsWithValue(0),
+ SyscallFailsWithErrno(EINVAL)));
+}
+
+TEST(ReadaheadTest, WriteOnly) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_WRONLY));
+ EXPECT_THAT(readahead(fd.get(), 0, 1), SyscallFailsWithErrno(EBADF));
+}
+
+TEST(ReadaheadTest, InvalidSize) {
+ const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ const FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDWR));
+ EXPECT_THAT(readahead(fd.get(), 0, -1), SyscallFailsWithErrno(EINVAL));
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket.cc b/test/syscalls/linux/socket.cc
index caae215b8..3a07ac8d2 100644
--- a/test/syscalls/linux/socket.cc
+++ b/test/syscalls/linux/socket.cc
@@ -17,6 +17,7 @@
#include "gtest/gtest.h"
#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
#include "test/util/test_util.h"
namespace gvisor {
@@ -57,5 +58,28 @@ TEST(SocketTest, ProtocolInet) {
}
}
+using SocketOpenTest = ::testing::TestWithParam<int>;
+
+// UDS cannot be opened.
+TEST_P(SocketOpenTest, Unix) {
+ // FIXME(b/142001530): Open incorrectly succeeds on gVisor.
+ SKIP_IF(IsRunningOnGvisor());
+
+ FileDescriptor bound =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, PF_UNIX));
+
+ struct sockaddr_un addr =
+ ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(/*abstract=*/false, AF_UNIX));
+
+ ASSERT_THAT(bind(bound.get(), reinterpret_cast<struct sockaddr*>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+
+ EXPECT_THAT(open(addr.sun_path, GetParam()), SyscallFailsWithErrno(ENXIO));
+}
+
+INSTANTIATE_TEST_SUITE_P(OpenModes, SocketOpenTest,
+ ::testing::Values(O_RDONLY, O_RDWR));
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc
new file mode 100644
index 000000000..d20821cac
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device.cc
@@ -0,0 +1,314 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+
+// Test fixture for SO_BINDTODEVICE tests.
+class BindToDeviceTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s\n", GetParam().description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+
+ interface_name_ = "eth1";
+ auto interface_names = GetInterfaceNames();
+ if (interface_names.find(interface_name_) == interface_names.end()) {
+ // Need a tunnel.
+ tunnel_ = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New());
+ interface_name_ = tunnel_->GetName();
+ ASSERT_FALSE(interface_name_.empty());
+ }
+ socket_ = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create());
+ }
+
+ string interface_name() const { return interface_name_; }
+
+ int socket_fd() const { return socket_->get(); }
+
+ private:
+ std::unique_ptr<Tunnel> tunnel_;
+ string interface_name_;
+ std::unique_ptr<FileDescriptor> socket_;
+};
+
+constexpr char kIllegalIfnameChar = '/';
+
+// Tests getsockopt of the default value.
+TEST_P(BindToDeviceTest, GetsockoptDefault) {
+ char name_buffer[IFNAMSIZ * 2];
+ char original_name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Read the default SO_BINDTODEVICE.
+ memset(original_name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ for (size_t i = 0; i <= sizeof(name_buffer); i++) {
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = i;
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(name_buffer_size, 0);
+ EXPECT_EQ(memcmp(name_buffer, original_name_buffer, sizeof(name_buffer)),
+ 0);
+ }
+}
+
+// Tests setsockopt of invalid device name.
+TEST_P(BindToDeviceTest, SetsockoptInvalidDeviceName) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Set an invalid device name.
+ memset(name_buffer, kIllegalIfnameChar, 5);
+ name_buffer_size = 5;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+}
+
+// Tests setsockopt of a buffer holding a valid device name that is not
+// null-terminated, with different buffer sizes.
+TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithoutNullTermination) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1);
+ // Intentionally overwrite the null at the end.
+ memset(name_buffer + interface_name().size(), kIllegalIfnameChar,
+ sizeof(name_buffer) - interface_name().size());
+ for (size_t i = 1; i <= sizeof(name_buffer); i++) {
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // It should only work if the size provided is exactly right.
+ if (name_buffer_size == interface_name().size()) {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallSucceeds());
+ } else {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+ }
+ }
+}
+
+// Tests setsockopt of a buffer holding a valid, null-terminated device name,
+// with different buffer sizes.
+TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithNullTermination) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1);
+ // Don't overwrite the null at the end.
+ memset(name_buffer + interface_name().size() + 1, kIllegalIfnameChar,
+ sizeof(name_buffer) - interface_name().size() - 1);
+ for (size_t i = 1; i <= sizeof(name_buffer); i++) {
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // It should only work if the size provided is at least the right size.
+ if (name_buffer_size >= interface_name().size()) {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallSucceeds());
+ } else {
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, name_buffer_size),
+ SyscallFailsWithErrno(ENODEV));
+ }
+ }
+}
+
+// Tests that setsockopt of an invalid device name doesn't unset the previous
+// valid setsockopt.
+TEST_P(BindToDeviceTest, SetsockoptValidThenInvalid) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Write unsuccessfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = 5;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallFailsWithErrno(ENODEV));
+
+ // Read it back successfully, it's unchanged.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+}
+
+// Tests that setsockopt of a zero-length string correctly unsets the previous
+// value.
+TEST_P(BindToDeviceTest, SetsockoptValidThenClear) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Clear it successfully.
+ name_buffer_size = 0;
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallSucceeds());
+
+ // Read it back successfully, it's cleared.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, 0);
+}
+
+// Tests that setsockopt of an empty string correctly unsets the previous
+// value.
+TEST_P(BindToDeviceTest, SetsockoptValidThenClearWithNull) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+
+ // Clear it successfully.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer[0] = 0;
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ name_buffer_size),
+ SyscallSucceeds());
+
+ // Read it back successfully, it's cleared.
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = sizeof(name_buffer);
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, 0);
+}
+
+// Tests getsockopt with different buffer sizes.
+TEST_P(BindToDeviceTest, GetsockoptDevice) {
+ char name_buffer[IFNAMSIZ * 2];
+ socklen_t name_buffer_size;
+
+ // Write successfully.
+ strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer));
+ ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer,
+ sizeof(name_buffer)),
+ SyscallSucceeds());
+
+ // Read it back at various buffer sizes.
+ for (size_t i = 0; i <= sizeof(name_buffer); i++) {
+ memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer));
+ name_buffer_size = i;
+ SCOPED_TRACE(absl::StrCat("Buffer size: ", i));
+ // Linux only allows a buffer of at least IFNAMSIZ bytes, even if less would
+ // suffice for this interface name.
+ if (name_buffer_size >= IFNAMSIZ) {
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallSucceeds());
+ EXPECT_EQ(name_buffer_size, interface_name().size() + 1);
+ EXPECT_STREQ(name_buffer, interface_name().c_str());
+ } else {
+ EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE,
+ name_buffer, &name_buffer_size),
+ SyscallFailsWithErrno(EINVAL));
+ EXPECT_EQ(name_buffer_size, i);
+ }
+ }
+}
+
+INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceTest,
+ ::testing::Values(IPv4UDPUnboundSocket(0),
+ IPv4TCPUnboundSocket(0)));
+
+} // namespace testing
+} // namespace gvisor
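The tests above pin down the SO_BINDTODEVICE get/set semantics: the option
length must cover the name, getsockopt demands a buffer of at least IFNAMSIZ
bytes, and an empty name clears the binding. Below is a minimal standalone
sketch of those calls outside the test fixture; the interface name "eth0", the
use of a UDP socket, and running with CAP_NET_RAW are illustrative assumptions,
not part of the patch.

    #include <net/if.h>
    #include <sys/socket.h>
    #include <unistd.h>

    #include <cstdio>
    #include <cstring>

    int main() {
      int fd = socket(AF_INET, SOCK_DGRAM, 0);
      if (fd < 0) {
        perror("socket");
        return 1;
      }

      // Bind the socket to a device. Passing the terminating NUL is optional;
      // without it the length must be exactly the name length.
      const char ifname[] = "eth0";
      if (setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, ifname,
                     std::strlen(ifname) + 1) < 0) {
        perror("setsockopt(SO_BINDTODEVICE)");
        close(fd);
        return 1;
      }

      // Read it back. Linux rejects buffers smaller than IFNAMSIZ with EINVAL.
      char bound[IFNAMSIZ] = {};
      socklen_t len = sizeof(bound);
      if (getsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, bound, &len) < 0) {
        perror("getsockopt(SO_BINDTODEVICE)");
        close(fd);
        return 1;
      }
      std::printf("bound to %s (optlen=%u)\n", bound,
                  static_cast<unsigned>(len));
      close(fd);
      return 0;
    }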
diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc
new file mode 100644
index 000000000..4d2400328
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc
@@ -0,0 +1,381 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <atomic>
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+using std::vector;
+
+struct EndpointConfig {
+ std::string bind_to_device;
+ double expected_ratio;
+};
+
+struct DistributionTestCase {
+ std::string name;
+ std::vector<EndpointConfig> endpoints;
+};
+
+struct ListenerConnector {
+ TestAddress listener;
+ TestAddress connector;
+};
+
+// Test fixture for SO_BINDTODEVICE tests that check the distribution of
+// packets received across sockets with varying SO_BINDTODEVICE settings.
+class BindToDeviceDistributionTest
+ : public ::testing::TestWithParam<
+ ::testing::tuple<ListenerConnector, DistributionTestCase>> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s, listener=%s, connector=%s\n",
+ ::testing::get<1>(GetParam()).name.c_str(),
+ ::testing::get<0>(GetParam()).listener.description.c_str(),
+ ::testing::get<0>(GetParam()).connector.description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+ }
+};
+
+PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) {
+ switch (family) {
+ case AF_INET:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in const*>(&addr)->sin_port);
+ case AF_INET6:
+ return static_cast<uint16_t>(
+ reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port);
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) {
+ switch (family) {
+ case AF_INET:
+ reinterpret_cast<sockaddr_in*>(addr)->sin_port = port;
+ return NoError();
+ case AF_INET6:
+ reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port;
+ return NoError();
+ default:
+ return PosixError(EINVAL,
+ absl::StrCat("unknown socket family: ", family));
+ }
+}
+
+// Binds sockets to different devices and then creates many TCP connections.
+// Checks that the distribution of connections received on the sockets matches
+// the expectation.
+TEST_P(BindToDeviceDistributionTest, Tcp) {
+ auto const& [listener_connector, test] = GetParam();
+
+ TestAddress const& listener = listener_connector.listener;
+ TestAddress const& connector = listener_connector.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+
+ auto interface_names = GetInterfaceNames();
+
+ // Create the listening sockets.
+ std::vector<FileDescriptor> listener_fds;
+ std::vector<std::unique_ptr<Tunnel>> all_tunnels;
+ for (auto const& endpoint : test.endpoints) {
+ if (!endpoint.bind_to_device.empty() &&
+ interface_names.find(endpoint.bind_to_device) ==
+ interface_names.end()) {
+ all_tunnels.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
+ interface_names.insert(endpoint.bind_to_device);
+ }
+
+ listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)));
+ int fd = listener_fds.back().get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
+ endpoint.bind_to_device.c_str(),
+ endpoint.bind_to_device.size() + 1),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+ ASSERT_THAT(listen(fd, 40), SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (listener_fds.size() > 1) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> connects_received = ATOMIC_VAR_INIT(0);
+ std::vector<int> accept_counts(listener_fds.size(), 0);
+ std::vector<std::unique_ptr<ScopedThread>> listen_threads(
+ listener_fds.size());
+
+ for (int i = 0; i < listener_fds.size(); i++) {
+ listen_threads[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &accept_counts, &connects_received, i,
+ kConnectAttempts]() {
+ do {
+ auto fd = Accept(listener_fds[i].get(), nullptr, nullptr);
+ if (!fd.ok()) {
+ // Another thread has shut down our read side, causing the accept
+ // to fail.
+ ASSERT_GE(connects_received, kConnectAttempts)
+ << "errno = " << fd.error();
+ return;
+ }
+ // Receive some data from a socket to be sure that the connect()
+ // system call has been completed on another side.
+ int data;
+ EXPECT_THAT(
+ RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ accept_counts[i]++;
+ } while (++connects_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (auto const& listener_fd : listener_fds) {
+ shutdown(listener_fd.get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ for (int i = 0; i < kConnectAttempts; i++) {
+ FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+
+ // Join threads to be sure that all connections have been counted.
+ for (auto const& listen_thread : listen_threads) {
+ listen_thread->Join();
+ }
+ // Check that connections are distributed correctly among listening sockets.
+ for (int i = 0; i < accept_counts.size(); i++) {
+ EXPECT_THAT(
+ accept_counts[i],
+ EquivalentWithin(static_cast<int>(kConnectAttempts *
+ test.endpoints[i].expected_ratio),
+ 0.10))
+ << "endpoint " << i << " got the wrong number of packets";
+ }
+}
+
+// Binds sockets to different devices and then sends many UDP packets. Checks
+// that the distribution of packets received on the sockets matches the
+// expectation.
+TEST_P(BindToDeviceDistributionTest, Udp) {
+ auto const& [listener_connector, test] = GetParam();
+
+ TestAddress const& listener = listener_connector.listener;
+ TestAddress const& connector = listener_connector.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+
+ auto interface_names = GetInterfaceNames();
+
+ // Create the receiving sockets.
+ std::vector<FileDescriptor> listener_fds;
+ std::vector<std::unique_ptr<Tunnel>> all_tunnels;
+ for (auto const& endpoint : test.endpoints) {
+ if (!endpoint.bind_to_device.empty() &&
+ interface_names.find(endpoint.bind_to_device) ==
+ interface_names.end()) {
+ all_tunnels.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device)));
+ interface_names.insert(endpoint.bind_to_device);
+ }
+
+ listener_fds.push_back(
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0)));
+ int fd = listener_fds.back().get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE,
+ endpoint.bind_to_device.c_str(),
+ endpoint.bind_to_device.size() + 1),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (listener_fds.size() > 1) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10000;
+ std::atomic<int> packets_received = ATOMIC_VAR_INIT(0);
+ std::vector<int> packets_per_socket(listener_fds.size(), 0);
+ std::vector<std::unique_ptr<ScopedThread>> receiver_threads(
+ listener_fds.size());
+
+ for (int i = 0; i < listener_fds.size(); i++) {
+ receiver_threads[i] = absl::make_unique<ScopedThread>(
+ [&listener_fds, &packets_per_socket, &packets_received, i]() {
+ do {
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ int data;
+
+ auto ret = RetryEINTR(recvfrom)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr), &addrlen);
+
+ if (packets_received < kConnectAttempts) {
+ ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ if (ret != sizeof(data)) {
+ // Another thread may have shut down our read side, causing the
+ // recvfrom to fail.
+ break;
+ }
+
+ packets_received++;
+ packets_per_socket[i]++;
+
+ // A response is required to synchronize with the main thread;
+ // otherwise the main thread can send more packets than fit into the
+ // receive queues.
+ EXPECT_THAT(RetryEINTR(sendto)(
+ listener_fds[i].get(), &data, sizeof(data), 0,
+ reinterpret_cast<sockaddr*>(&addr), addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ } while (packets_received < kConnectAttempts);
+
+ // Shutdown all sockets to wake up other threads.
+ for (auto const& listener_fd : listener_fds) {
+ shutdown(listener_fd.get(), SHUT_RDWR);
+ }
+ });
+ }
+
+ for (int i = 0; i < kConnectAttempts; i++) {
+ FileDescriptor const fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ int data;
+ EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0),
+ SyscallSucceedsWithValue(sizeof(data)));
+ }
+
+ // Join threads to be sure that all connections have been counted.
+ for (auto const& receiver_thread : receiver_threads) {
+ receiver_thread->Join();
+ }
+ // Check that packets are distributed correctly among listening sockets.
+ for (int i = 0; i < packets_per_socket.size(); i++) {
+ EXPECT_THAT(
+ packets_per_socket[i],
+ EquivalentWithin(static_cast<int>(kConnectAttempts *
+ test.endpoints[i].expected_ratio),
+ 0.10))
+ << "endpoint " << i << " got the wrong number of packets";
+ }
+}
+
+std::vector<DistributionTestCase> GetDistributionTestCases() {
+ return std::vector<DistributionTestCase>{
+ {"Even distribution among sockets not bound to device",
+ {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}},
+ {"Sockets bound to other interfaces get no packets",
+ {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}},
+ {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}},
+ {"Even distribution among sockets bound to device",
+ {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}},
+ };
+}
+
+INSTANTIATE_TEST_SUITE_P(
+ BindToDeviceTest, BindToDeviceDistributionTest,
+ ::testing::Combine(::testing::Values(
+ // Listeners bound to IPv4 addresses refuse
+ // connections using IPv6 addresses.
+ ListenerConnector{V4Any(), V4Loopback()},
+ ListenerConnector{V4Loopback(), V4MappedLoopback()}),
+ ::testing::ValuesIn(GetDistributionTestCases())));
+
+} // namespace testing
+} // namespace gvisor
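For reference, the distribution check above is simple expected-count
arithmetic: with kConnectAttempts = 10000 and an expected ratio of 1/3, each
listener should see roughly 3333 connections, and the 0.10 argument allows a
10% deviation. The sketch below restates that arithmetic; treating
EquivalentWithin as a relative tolerance is an assumption about that helper,
not something this patch defines.

    #include <cstdio>
    #include <cstdlib>

    // A stand-in for the tolerance check, assumed to be a relative bound.
    bool RoughlyEquivalent(int got, int expected, double tolerance) {
      return std::abs(got - expected) <= expected * tolerance;
    }

    int main() {
      const int kConnectAttempts = 10000;
      const double kExpectedRatio = 1.0 / 3;  // e.g. three unbound listeners.
      const int expected = static_cast<int>(kConnectAttempts * kExpectedRatio);
      // expected is 3333, so counts in roughly [3000, 3666] pass a 10% bound.
      std::printf("expected=%d ok(3200)=%d ok(2500)=%d\n", expected,
                  RoughlyEquivalent(3200, expected, 0.10),
                  RoughlyEquivalent(2500, expected, 0.10));
      return 0;
    }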
diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc
new file mode 100644
index 000000000..a7365d139
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc
@@ -0,0 +1,316 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <arpa/inet.h>
+#include <linux/capability.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "gtest/gtest.h"
+#include "test/syscalls/linux/ip_socket_test_util.h"
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/capability_util.h"
+#include "test/util/test_util.h"
+#include "test/util/thread_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+using std::vector;
+
+// Test fixture for SO_BINDTODEVICE tests that check the results of sequences
+// of socket binds.
+class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> {
+ protected:
+ void SetUp() override {
+ printf("Testing case: %s\n", GetParam().description.c_str());
+ ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)))
+ << "CAP_NET_RAW is required to use SO_BINDTODEVICE";
+ socket_factory_ = GetParam();
+
+ interface_names_ = GetInterfaceNames();
+ }
+
+ PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const {
+ return socket_factory_.Create();
+ }
+
+ // Gets a device by device_id. If the device_id has been seen before, returns
+ // the previously returned device. If not, finds or creates a new device,
+ // registering a fatal test failure if no device can be provided.
+ void GetDevice(int device_id, string *device_name) {
+ auto device = devices_.find(device_id);
+ if (device != devices_.end()) {
+ *device_name = device->second;
+ return;
+ }
+
+ // Need to pick a new device. Try ethernet first.
+ *device_name = absl::StrCat("eth", next_unused_eth_);
+ if (interface_names_.find(*device_name) != interface_names_.end()) {
+ devices_[device_id] = *device_name;
+ next_unused_eth_++;
+ return;
+ }
+
+ // Need to make a new tunnel device. gVisor tests should have enough
+ // ethernet devices to never reach here.
+ ASSERT_FALSE(IsRunningOnGvisor());
+ // Need a tunnel.
+ tunnels_.push_back(ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New()));
+ devices_[device_id] = tunnels_.back()->GetName();
+ *device_name = devices_[device_id];
+ }
+
+ // Releases the socket created by a previous call to BindSocket.
+ void ReleaseSocket(int socket_id) {
+ // Close the socket that was made in a previous action. The socket_id
+ // indicates which socket to close based on index into the list of actions.
+ sockets_to_close_.erase(socket_id);
+ }
+
+ // Binds a socket with the given reuse and bind_to_device options. Checks
+ // that all setup steps succeed and that the result of the bind call matches
+ // want. If socket_id is not nullptr, sets it to uniquely identify the bound
+ // socket.
+ void BindSocket(bool reuse, int device_id = 0, int want = 0,
+ int *socket_id = nullptr) {
+ next_socket_id_++;
+ sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto socket_fd = sockets_to_close_[next_socket_id_]->get();
+ if (socket_id != nullptr) {
+ *socket_id = next_socket_id_;
+ }
+
+ // If reuse is indicated, do that.
+ if (reuse) {
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+ }
+
+ // If the device is non-zero, bind to that device.
+ if (device_id != 0) {
+ string device_name;
+ ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name));
+ EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE,
+ device_name.c_str(), device_name.size() + 1),
+ SyscallSucceedsWithValue(0));
+ char get_device[100];
+ socklen_t get_device_size = 100;
+ EXPECT_THAT(getsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, get_device,
+ &get_device_size),
+ SyscallSucceedsWithValue(0));
+ }
+
+ struct sockaddr_in addr = {};
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ addr.sin_port = port_;
+ if (want == 0) {
+ ASSERT_THAT(
+ bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr),
+ sizeof(addr)),
+ SyscallSucceeds());
+ } else {
+ ASSERT_THAT(
+ bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr),
+ sizeof(addr)),
+ SyscallFailsWithErrno(want));
+ }
+
+ if (port_ == 0) {
+ // We don't yet know what port we'll be using so we need to fetch it and
+ // remember it for future commands.
+ socklen_t addr_size = sizeof(addr);
+ ASSERT_THAT(
+ getsockname(socket_fd, reinterpret_cast<struct sockaddr *>(&addr),
+ &addr_size),
+ SyscallSucceeds());
+ port_ = addr.sin_port;
+ }
+ }
+
+ private:
+ SocketKind socket_factory_;
+ // devices_ maps from the device id in the test case to the device name.
+ std::unordered_map<int, string> devices_;
+ // These are the tunnels that were created for the test and will be destroyed
+ // by the destructor.
+ vector<std::unique_ptr<Tunnel>> tunnels_;
+ // A list of all interface names before the test started.
+ std::unordered_set<string> interface_names_;
+ // The next ethernet device to use when a new device is requested.
+ int next_unused_eth_ = 1;
+ // The port for all tests. Originally 0 (any) and later set to the port that
+ // all further commands will use.
+ in_port_t port_ = 0;
+ // sockets_to_close_ is a map from action index to the socket that was
+ // created.
+ std::unordered_map<int,
+ std::unique_ptr<gvisor::testing::FileDescriptor>>
+ sockets_to_close_;
+ int next_socket_id_ = 0;
+};
+
+TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 3));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindToDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 1));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 2));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindWithReuse) {
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 789));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0));
+}
+
+TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 456));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindAndRelease) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 123));
+ int to_release;
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 789));
+ // Release the bind to device 0 and try again.
+ ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 345));
+}
+
+TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) {
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ false, /* bind_to_device */ 123));
+ ASSERT_NO_FATAL_FAILURE(
+ BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE));
+}
+
+INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest,
+ ::testing::Values(IPv4UDPUnboundSocket(0),
+ IPv4TCPUnboundSocket(0)));
+
+} // namespace testing
+} // namespace gvisor
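The sequence tests above encode the core conflict rule: two sockets may bind
the same address and port when they are bound to different devices, and a
later bind with no device then fails with EADDRINUSE. A plain-syscall sketch
of that rule follows; the device names "eth1" and "eth2" and port 54321 are
illustrative assumptions, and the setsockopt calls require CAP_NET_RAW.

    #include <arpa/inet.h>
    #include <netinet/in.h>
    #include <sys/socket.h>

    #include <cerrno>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Creates a UDP socket, optionally binds it to `dev`, then binds it to
    // loopback on the given port and reports the result.
    static int BindOnPort(const char* dev, uint16_t port) {
      int fd = socket(AF_INET, SOCK_DGRAM, 0);
      if (fd < 0) return -1;
      if (dev != nullptr) {
        setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, dev, std::strlen(dev) + 1);
      }
      sockaddr_in addr = {};
      addr.sin_family = AF_INET;
      addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
      addr.sin_port = htons(port);
      int ret = bind(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr));
      std::printf("bind(dev=%s): %s\n", dev ? dev : "<none>",
                  ret == 0 ? "ok" : std::strerror(errno));
      return fd;
    }

    int main() {
      BindOnPort("eth1", 54321);    // Expected to succeed.
      BindOnPort("eth2", 54321);    // Expected to succeed: different device.
      BindOnPort(nullptr, 54321);   // Expected to fail with EADDRINUSE.
      return 0;
    }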
diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc
new file mode 100644
index 000000000..f4ee775bd
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_util.cc
@@ -0,0 +1,75 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/syscalls/linux/socket_bind_to_device_util.h"
+
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+using std::string;
+
+PosixErrorOr<std::unique_ptr<Tunnel>> Tunnel::New(string tunnel_name) {
+ int fd;
+ RETURN_ERROR_IF_SYSCALL_FAIL(fd = open("/dev/net/tun", O_RDWR));
+
+ // Using `new` to access a non-public constructor.
+ auto new_tunnel = absl::WrapUnique(new Tunnel(fd));
+
+ ifreq ifr = {};
+ ifr.ifr_flags = IFF_TUN;
+ strncpy(ifr.ifr_name, tunnel_name.c_str(), sizeof(ifr.ifr_name));
+
+ RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd, TUNSETIFF, &ifr));
+ new_tunnel->name_ = ifr.ifr_name;
+ return new_tunnel;
+}
+
+std::unordered_set<string> GetInterfaceNames() {
+ struct if_nameindex* interfaces = if_nameindex();
+ std::unordered_set<string> names;
+ if (interfaces == nullptr) {
+ return names;
+ }
+ for (auto interface = interfaces;
+ interface->if_index != 0 || interface->if_name != nullptr; interface++) {
+ names.insert(interface->if_name);
+ }
+ if_freenameindex(interfaces);
+ return names;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/socket_bind_to_device_util.h b/test/syscalls/linux/socket_bind_to_device_util.h
new file mode 100644
index 000000000..f941ccc86
--- /dev/null
+++ b/test/syscalls/linux/socket_bind_to_device_util.h
@@ -0,0 +1,67 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
+#define GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
+
+#include <arpa/inet.h>
+#include <linux/if_tun.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include <cstdio>
+#include <cstring>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+class Tunnel {
+ public:
+ static PosixErrorOr<std::unique_ptr<Tunnel>> New(
+ std::string tunnel_name = "");
+ const std::string& GetName() const { return name_; }
+
+ ~Tunnel() {
+ if (fd_ != -1) {
+ close(fd_);
+ }
+ }
+
+ private:
+ Tunnel(int fd) : fd_(fd) {}
+ int fd_ = -1;
+ std::string name_;
+};
+
+std::unordered_set<std::string> GetInterfaceNames();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_
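A brief usage sketch of the helpers declared above, mirroring how the test
fixtures use them. It assumes the ASSIGN_OR_RETURN_ERRNO and NoError helpers
from test/util/posix_error.h, which other files touched by this patch already
use.

    #include <cstdio>

    #include "test/syscalls/linux/socket_bind_to_device_util.h"
    #include "test/util/posix_error.h"

    namespace gvisor {
    namespace testing {

    PosixError TunnelDemo() {
      // Create a TUN device; the kernel picks a name when none is supplied,
      // and the device goes away when the Tunnel's descriptor is closed.
      ASSIGN_OR_RETURN_ERRNO(auto tunnel, Tunnel::New());
      std::printf("created %s\n", tunnel->GetName().c_str());

      // The new device shows up in the interface list used by the fixtures.
      for (const auto& name : GetInterfaceNames()) {
        std::printf("interface: %s\n", name.c_str());
      }
      return NoError();
    }

    }  // namespace testing
    }  // namespace gvisor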
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index a43cf9bce..bfa7943b1 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -117,7 +117,7 @@ TEST_P(TCPSocketPairTest, RSTCausesPollHUP) {
struct pollfd poll_fd3 = {sockets->first_fd(), POLLHUP, 0};
ASSERT_THAT(RetryEINTR(poll)(&poll_fd3, 1, kPollTimeoutMs),
SyscallSucceedsWithValue(1));
- ASSERT_NE(poll_fd.revents & (POLLHUP | POLLIN), 0);
+ ASSERT_NE(poll_fd3.revents & POLLHUP, 0);
}
// This test validates that even if a RST is sent the other end will not
diff --git a/test/syscalls/linux/socket_test_util.h b/test/syscalls/linux/socket_test_util.h
index 6efa8055f..70710195c 100644
--- a/test/syscalls/linux/socket_test_util.h
+++ b/test/syscalls/linux/socket_test_util.h
@@ -83,6 +83,8 @@ inline ssize_t SendFd(int fd, void* buf, size_t count, int flags) {
count);
}
+PosixErrorOr<struct sockaddr_un> UniqueUnixAddr(bool abstract, int domain);
+
// A Creator<T> is a function that attempts to create and return a new T. (This
// is copy/pasted from cloud/gvisor/api/sandbox_util.h and is just duplicated
// here for clarity.)
diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc
index d48453a93..6218fbce1 100644
--- a/test/syscalls/linux/uidgid.cc
+++ b/test/syscalls/linux/uidgid.cc
@@ -25,6 +25,7 @@
#include "test/util/posix_error.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"
+#include "test/util/uid_util.h"
ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID");
ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID");
@@ -68,30 +69,6 @@ TEST(UidGidTest, Getgroups) {
// here; see the setgroups test below.
}
-// If the caller's real/effective/saved user/group IDs are all 0, IsRoot returns
-// true. Otherwise IsRoot logs an explanatory message and returns false.
-PosixErrorOr<bool> IsRoot() {
- uid_t ruid, euid, suid;
- int rc = getresuid(&ruid, &euid, &suid);
- MaybeSave();
- if (rc < 0) {
- return PosixError(errno, "getresuid");
- }
- if (ruid != 0 || euid != 0 || suid != 0) {
- return false;
- }
- gid_t rgid, egid, sgid;
- rc = getresgid(&rgid, &egid, &sgid);
- MaybeSave();
- if (rc < 0) {
- return PosixError(errno, "getresgid");
- }
- if (rgid != 0 || egid != 0 || sgid != 0) {
- return false;
- }
- return true;
-}
-
// Checks that the calling process' real/effective/saved user IDs are
// ruid/euid/suid respectively.
PosixError CheckUIDs(uid_t ruid, uid_t euid, uid_t suid) {
diff --git a/test/syscalls/linux/uname.cc b/test/syscalls/linux/uname.cc
index 0a5d91017..d8824b171 100644
--- a/test/syscalls/linux/uname.cc
+++ b/test/syscalls/linux/uname.cc
@@ -41,6 +41,19 @@ TEST(UnameTest, Sanity) {
TEST(UnameTest, SetNames) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ char hostname[65];
+ ASSERT_THAT(sethostname("0123456789", 3), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "012");
+
+ ASSERT_THAT(sethostname("0123456789\0xxx", 11), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "0123456789");
+
+ ASSERT_THAT(sethostname("0123456789\0xxx", 12), SyscallSucceeds());
+ EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
+ EXPECT_EQ(absl::string_view(hostname), "0123456789");
+
constexpr char kHostname[] = "wubbalubba";
ASSERT_THAT(sethostname(kHostname, sizeof(kHostname)), SyscallSucceeds());
@@ -54,7 +67,6 @@ TEST(UnameTest, SetNames) {
EXPECT_EQ(absl::string_view(buf.domainname), kDomainname);
// These should just be glibc wrappers that also call uname(2).
- char hostname[65];
EXPECT_THAT(gethostname(hostname, sizeof(hostname)), SyscallSucceeds());
EXPECT_EQ(absl::string_view(hostname), kHostname);
diff --git a/test/util/BUILD b/test/util/BUILD
index 25ed9c944..5d2a9cc2c 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -324,3 +324,14 @@ cc_library(
":test_util",
],
)
+
+cc_library(
+ name = "uid_util",
+ testonly = 1,
+ srcs = ["uid_util.cc"],
+ hdrs = ["uid_util.h"],
+ deps = [
+ ":posix_error",
+ ":save_util",
+ ],
+)
diff --git a/test/util/proc_util.cc b/test/util/proc_util.cc
index 75b24da37..34d636ba9 100644
--- a/test/util/proc_util.cc
+++ b/test/util/proc_util.cc
@@ -88,7 +88,7 @@ PosixErrorOr<std::vector<ProcMapsEntry>> ParseProcMaps(
std::vector<ProcMapsEntry> entries;
auto lines = absl::StrSplit(contents, '\n', absl::SkipEmpty());
for (const auto& l : lines) {
- std::cout << "line: " << l;
+ std::cout << "line: " << l << std::endl;
ASSIGN_OR_RETURN_ERRNO(auto entry, ParseProcMapsLine(l));
entries.push_back(entry);
}
diff --git a/test/util/uid_util.cc b/test/util/uid_util.cc
new file mode 100644
index 000000000..b131b4b99
--- /dev/null
+++ b/test/util/uid_util.cc
@@ -0,0 +1,44 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/posix_error.h"
+#include "test/util/save_util.h"
+
+namespace gvisor {
+namespace testing {
+
+PosixErrorOr<bool> IsRoot() {
+ uid_t ruid, euid, suid;
+ int rc = getresuid(&ruid, &euid, &suid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresuid");
+ }
+ if (ruid != 0 || euid != 0 || suid != 0) {
+ return false;
+ }
+ gid_t rgid, egid, sgid;
+ rc = getresgid(&rgid, &egid, &sgid);
+ MaybeSave();
+ if (rc < 0) {
+ return PosixError(errno, "getresgid");
+ }
+ if (rgid != 0 || egid != 0 || sgid != 0) {
+ return false;
+ }
+ return true;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/uid_util.h b/test/util/uid_util.h
new file mode 100644
index 000000000..2cd387fb0
--- /dev/null
+++ b/test/util/uid_util.h
@@ -0,0 +1,29 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_SYSCALLS_UID_UTIL_H_
+#define GVISOR_TEST_SYSCALLS_UID_UTIL_H_
+
+#include "test/util/posix_error.h"
+
+namespace gvisor {
+namespace testing {
+
+// Returns true if the caller's real/effective/saved user/group IDs are all 0.
+PosixErrorOr<bool> IsRoot();
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_SYSCALLS_UID_UTIL_H_
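With IsRoot() relocated to test/util, other tests can gate on it the same way
uidgid.cc does. A hypothetical example, assuming the SKIP_IF and
ASSERT_NO_ERRNO_AND_VALUE macros from test/util/test_util.h:

    #include <unistd.h>

    #include "gtest/gtest.h"
    #include "test/util/test_util.h"
    #include "test/util/uid_util.h"

    namespace gvisor {
    namespace testing {
    namespace {

    TEST(UidUtilExample, SkipsWhenNotRoot) {
      // Skip rather than fail when the test is not running as real root.
      SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsRoot()));
      EXPECT_EQ(getuid(), 0u);
    }

    }  // namespace
    }  // namespace testing
    }  // namespace gvisor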