summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.buildkite/pipeline.yaml3
-rw-r--r--Makefile11
-rw-r--r--nogo.yaml2
-rw-r--r--pkg/abi/linux/fs.go3
-rw-r--r--pkg/abi/linux/ptrace_amd64.go5
-rw-r--r--pkg/abi/linux/ptrace_arm64.go5
-rw-r--r--pkg/merkletree/merkletree.go1
-rw-r--r--pkg/ring0/kernel_amd64.go10
-rw-r--r--pkg/ring0/kernel_arm64.go8
-rw-r--r--pkg/ring0/lib_arm64.go3
-rw-r--r--pkg/ring0/lib_arm64.s8
-rw-r--r--pkg/sentry/devices/memdev/zero.go1
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/BUILD47
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/base.go233
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cgroupfs.go412
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cpu.go70
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cpuacct.go114
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cpuset.go39
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/memory.go74
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go19
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go1
-rw-r--r--pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go10
-rw-r--r--pkg/sentry/fsimpl/kernfs/kernfs.go10
-rw-r--r--pkg/sentry/fsimpl/proc/filesystem.go6
-rw-r--r--pkg/sentry/fsimpl/proc/task.go23
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go29
-rw-r--r--pkg/sentry/fsimpl/proc/tasks.go19
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_files.go16
-rw-r--r--pkg/sentry/fsimpl/tmpfs/regular_file.go1
-rw-r--r--pkg/sentry/fsimpl/verity/BUILD2
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go5
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go266
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go3
-rw-r--r--pkg/sentry/kernel/BUILD3
-rw-r--r--pkg/sentry/kernel/cgroup.go281
-rw-r--r--pkg/sentry/kernel/kernel.go52
-rw-r--r--pkg/sentry/kernel/task.go6
-rw-r--r--pkg/sentry/kernel/task_cgroup.go138
-rw-r--r--pkg/sentry/kernel/task_exit.go4
-rw-r--r--pkg/sentry/kernel/task_start.go5
-rw-r--r--pkg/sentry/kernel/threads.go9
-rw-r--r--pkg/sentry/memmap/memmap.go5
-rw-r--r--pkg/sentry/platform/kvm/BUILD2
-rw-r--r--pkg/sentry/platform/kvm/kvm_amd64_test.go11
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go2
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go21
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go21
-rw-r--r--pkg/sentry/vfs/file_description_impl_util.go3
-rw-r--r--pkg/sentry/vfs/mount.go17
-rw-r--r--pkg/tcpip/header/ipv4.go32
-rw-r--r--pkg/tcpip/header/ipv4_test.go75
-rw-r--r--pkg/tcpip/header/ipv6.go87
-rw-r--r--pkg/tcpip/header/ipv6_test.go86
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol.go57
-rw-r--r--pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go27
-rw-r--r--pkg/tcpip/network/ip_test.go78
-rw-r--r--pkg/tcpip/network/ipv4/igmp.go12
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go75
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go4
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go79
-rw-r--r--pkg/tcpip/network/ipv6/mld.go22
-rw-r--r--pkg/tcpip/network/ipv6/mld_test.go140
-rw-r--r--pkg/tcpip/network/ipv6/ndp.go10
-rw-r--r--pkg/tcpip/network/multicast_group_test.go6
-rw-r--r--pkg/tcpip/stack/ndp_test.go6
-rw-r--r--pkg/tcpip/stack/route.go2
-rw-r--r--pkg/tcpip/stack/stack.go4
-rw-r--r--pkg/tcpip/tcpip.go20
-rw-r--r--pkg/tcpip/tests/integration/BUILD2
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go192
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go9
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go9
-rw-r--r--pkg/tcpip/tests/utils/utils.go8
-rw-r--r--pkg/tcpip/transport/tcp/accept.go179
-rw-r--r--pkg/tcpip/transport/tcp/connect.go20
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go14
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go74
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go14
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go85
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go8
-rw-r--r--pkg/test/dockerutil/container.go9
-rw-r--r--runsc/BUILD1
-rw-r--r--runsc/boot/BUILD2
-rw-r--r--runsc/boot/controller.go2
-rw-r--r--runsc/boot/fs.go46
-rw-r--r--runsc/boot/loader.go2
-rw-r--r--runsc/boot/loader_test.go17
-rw-r--r--runsc/boot/vfs.go94
-rw-r--r--runsc/cli/main.go1
-rw-r--r--runsc/cmd/BUILD1
-rw-r--r--runsc/cmd/do.go108
-rw-r--r--runsc/cmd/verity_prepare.go108
-rw-r--r--runsc/config/config.go3
-rw-r--r--runsc/config/flags.go1
-rw-r--r--runsc/sandbox/sandbox.go4
-rw-r--r--runsc/specutils/fs.go18
-rw-r--r--shim/BUILD1
-rw-r--r--test/benchmarks/base/BUILD3
-rw-r--r--test/benchmarks/database/BUILD1
-rw-r--r--test/benchmarks/fs/BUILD2
-rw-r--r--test/benchmarks/media/BUILD1
-rw-r--r--test/benchmarks/ml/BUILD1
-rw-r--r--test/benchmarks/network/BUILD5
-rw-r--r--test/e2e/BUILD1
-rw-r--r--test/e2e/integration_test.go31
-rw-r--r--test/e2e/regression_test.go47
-rw-r--r--test/fsstress/BUILD4
-rw-r--r--test/fsstress/fsstress_test.go43
-rw-r--r--test/image/image_test.go5
-rw-r--r--test/packetimpact/runner/defs.bzl6
-rw-r--r--test/packetimpact/tests/BUILD20
-rw-r--r--test/packetimpact/tests/tcp_listen_backlog_test.go86
-rw-r--r--test/packetimpact/tests/tcp_syncookie_test.go70
-rw-r--r--test/perf/BUILD9
-rw-r--r--test/perf/linux/getpid_benchmark.cc18
-rw-r--r--test/runtimes/defs.bzl1
-rw-r--r--test/syscalls/BUILD18
-rw-r--r--test/syscalls/linux/BUILD52
-rw-r--r--test/syscalls/linux/accept_bind.cc36
-rw-r--r--test/syscalls/linux/cgroup.cc421
-rw-r--r--test/syscalls/linux/fpsig_fork.cc57
-rw-r--r--test/syscalls/linux/semaphore.cc8
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc334
-rw-r--r--test/syscalls/linux/verity_ioctl.cc133
-rw-r--r--test/syscalls/linux/verity_mount.cc53
-rw-r--r--test/util/BUILD18
-rw-r--r--test/util/cgroup_util.cc223
-rw-r--r--test/util/cgroup_util.h111
-rw-r--r--test/util/fs_util.cc44
-rw-r--r--test/util/fs_util.h12
-rw-r--r--tools/nogo/analyzers.go6
-rw-r--r--tools/nogo/check/main.go17
-rw-r--r--tools/nogo/defs.bzl42
134 files changed, 5007 insertions, 829 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml
index 3bc5041c0..9163db56d 100644
--- a/.buildkite/pipeline.yaml
+++ b/.buildkite/pipeline.yaml
@@ -90,6 +90,9 @@ steps:
label: ":person_in_lotus_position: KVM tests"
command: make kvm-tests
- <<: *common
+ label: ":weight_lifter: Fsstress test"
+ command: make fsstress-test
+ - <<: *common
label: ":docker: Containerd 1.3.9 tests"
command: make containerd-test-1.3.9
- <<: *common
diff --git a/Makefile b/Makefile
index 0f79b6a18..ea0674f77 100644
--- a/Makefile
+++ b/Makefile
@@ -144,6 +144,7 @@ dev: $(RUNTIME_BIN) ## Installs a set of local runtimes. Requires sudo.
@$(call configure_noreload,$(RUNTIME)-p,--net-raw --profile)
@$(call configure_noreload,$(RUNTIME)-vfs2-d,--net-raw --debug --strace --log-packets --vfs2)
@$(call configure_noreload,$(RUNTIME)-vfs2-fuse-d,--net-raw --debug --strace --log-packets --vfs2 --fuse)
+ @$(call configure_noreload,$(RUNTIME)-vfs2-cgroup-d,--net-raw --debug --strace --log-packets --vfs2 --cgroupfs)
@$(call reload_docker)
.PHONY: dev
@@ -340,7 +341,8 @@ BENCHMARKS_FILTER := .
BENCHMARKS_OPTIONS := -test.benchtime=30s
BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) $(BENCHMARKS_OPTIONS)
BENCHMARKS_PROFILE := -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex
-BENCH_RUNTIME_ARGS ?= --vfs2
+BENCH_VFS := --vfs2
+BENCH_RUNTIME_ARGS ?=
init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema.
@$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE))
@@ -361,13 +363,14 @@ run_benchmark = \
benchmark-platforms: load-benchmarks $(RUNTIME_BIN) ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS.
@$(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \
- $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS)) && \
- ) true
+ $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS) --vfs2) && \
+ $(call run_benchmark,$(PLATFORM)_vfs1,--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS)) && \
+ ) true
@$(call run_benchmark,runc)
.PHONY: benchmark-platforms
run-benchmark: load-benchmarks $(RUNTIME_BIN) ## Runs single benchmark and optionally sends data to BigQuery.
- @$(call run_benchmark,$(RUNTIME),$(BENCH_RUNTIME_ARGS))
+ @$(call run_benchmark,$(RUNTIME)$(BENCH_VFS),$(BENCH_RUNTIME_ARGS) $(BENCH_VFS))
.PHONY: run-benchmark
##
diff --git a/nogo.yaml b/nogo.yaml
index c0445a837..1e72d9e29 100644
--- a/nogo.yaml
+++ b/nogo.yaml
@@ -55,8 +55,6 @@ global:
# Same story for underscores.
- "should not use ALL_CAPS in Go names"
- "should not use underscores in Go names"
- # TODO(b/179817829): Upgrade to flock to v0.8.0.
- - "flock.NewFlock is deprecated: Use New instead"
exclude:
# Generated: exempt all.
- pkg/shim/runtimeoptions/runtimeoptions_cri.go
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index 0d921ed6f..cad24fcc7 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -19,8 +19,10 @@ package linux
// See linux/magic.h.
const (
ANON_INODE_FS_MAGIC = 0x09041934
+ CGROUP_SUPER_MAGIC = 0x27e0eb
DEVPTS_SUPER_MAGIC = 0x00001cd1
EXT_SUPER_MAGIC = 0xef53
+ FUSE_SUPER_MAGIC = 0x65735546
OVERLAYFS_SUPER_MAGIC = 0x794c7630
PIPEFS_MAGIC = 0x50495045
PROC_SUPER_MAGIC = 0x9fa0
@@ -29,7 +31,6 @@ const (
SYSFS_MAGIC = 0x62656572
TMPFS_MAGIC = 0x01021994
V9FS_MAGIC = 0x01021997
- FUSE_SUPER_MAGIC = 0x65735546
)
// Filesystem path limits, from uapi/linux/limits.h.
diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go
index 50e22fe7e..e722971f1 100644
--- a/pkg/abi/linux/ptrace_amd64.go
+++ b/pkg/abi/linux/ptrace_amd64.go
@@ -61,3 +61,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 {
func (p *PtraceRegs) StackPointer() uint64 {
return p.Rsp
}
+
+// SetStackPointer sets the stack pointer to the specified value.
+func (p *PtraceRegs) SetStackPointer(sp uint64) {
+ p.Rsp = sp
+}
diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go
index da36811d2..3d0906565 100644
--- a/pkg/abi/linux/ptrace_arm64.go
+++ b/pkg/abi/linux/ptrace_arm64.go
@@ -38,3 +38,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 {
func (p *PtraceRegs) StackPointer() uint64 {
return p.Sp
}
+
+// SetStackPointer sets the stack pointer to the specified value.
+func (p *PtraceRegs) SetStackPointer(sp uint64) {
+ p.Sp = sp
+}
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 961bd4dcf..6450f664c 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -238,6 +238,7 @@ func Generate(params *GenerateParams) ([]byte, error) {
Mode: params.Mode,
UID: params.UID,
GID: params.GID,
+ Children: params.Children,
SymlinkTarget: params.SymlinkTarget,
}
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index 6e17fb796..41dfd0bf9 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -250,7 +250,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
}
SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point.
WriteFS(uintptr(c.registers.Fs_base)) // escapes: no. Restore kernel FS.
- ldmxcsr(&kernelMXCSR) // escapes: no. Restore kernel MXCSR.
+ RestoreKernelFPState() // escapes: no. Restore kernel MXCSR.
return
}
@@ -329,6 +329,14 @@ func ReadCR2() uintptr {
// at src/cmd/compile/abi-internal.md in the golang sources for more details.
var kernelMXCSR uint32
+// RestoreKernelFPState restores the Sentry floating point state.
+//
+//go:nosplit
+func RestoreKernelFPState() {
+ // Restore the MXCSR control configuration.
+ ldmxcsr(&kernelMXCSR)
+}
+
func init() {
stmxcsr(&kernelMXCSR)
}
diff --git a/pkg/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go
index 7975e5f92..21db910a2 100644
--- a/pkg/ring0/kernel_arm64.go
+++ b/pkg/ring0/kernel_arm64.go
@@ -65,7 +65,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
storeEl0Fpstate(switchOpts.FloatingPointState.BytePointer())
if switchOpts.Flush {
- FlushTlbByASID(uintptr(switchOpts.UserASID))
+ LocalFlushTlbByASID(uintptr(switchOpts.UserASID))
}
regs := switchOpts.Registers
@@ -89,3 +89,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
return
}
+
+// RestoreKernelFPState restores the Sentry floating point state.
+//
+//go:nosplit
+func RestoreKernelFPState() {
+}
diff --git a/pkg/ring0/lib_arm64.go b/pkg/ring0/lib_arm64.go
index e44df00a6..5eabd4296 100644
--- a/pkg/ring0/lib_arm64.go
+++ b/pkg/ring0/lib_arm64.go
@@ -31,6 +31,9 @@ func FlushTlbByVA(addr uintptr)
// FlushTlbByASID invalidates tlb by ASID/Inner-Shareable.
func FlushTlbByASID(asid uintptr)
+// LocalFlushTlbByASID invalidates tlb by ASID.
+func LocalFlushTlbByASID(asid uintptr)
+
// FlushTlbAll invalidates all tlb.
func FlushTlbAll()
diff --git a/pkg/ring0/lib_arm64.s b/pkg/ring0/lib_arm64.s
index e39b32841..69ebaf519 100644
--- a/pkg/ring0/lib_arm64.s
+++ b/pkg/ring0/lib_arm64.s
@@ -32,6 +32,14 @@ TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8
DSB $11 // dsb(ish)
RET
+TEXT ·LocalFlushTlbByASID(SB),NOSPLIT,$0-8
+ MOVD asid+0(FP), R1
+ LSL $TLBI_ASID_SHIFT, R1, R1
+ DSB $10 // dsb(ishst)
+ WORD $0xd5088741 // tlbi aside1, x1
+ DSB $11 // dsb(ish)
+ RET
+
TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0
DSB $6 // dsb(nshst)
WORD $0xd508871f // __tlbi(vmalle1)
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
index 1929e41cd..49c53452a 100644
--- a/pkg/sentry/devices/memdev/zero.go
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -93,6 +93,7 @@ func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro
// "/dev/zero (deleted)".
opts.Offset = 0
opts.MappingIdentity = &fd.vfsfd
+ opts.SentryOwnedContent = true
opts.MappingIdentity.IncRef()
return nil
}
diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD
new file mode 100644
index 000000000..48913068a
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/BUILD
@@ -0,0 +1,47 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+ name = "dir_refs",
+ out = "dir_refs.go",
+ package = "cgroupfs",
+ prefix = "dir",
+ template = "//pkg/refsvfs2:refs_template",
+ types = {
+ "T": "dir",
+ },
+)
+
+go_library(
+ name = "cgroupfs",
+ srcs = [
+ "base.go",
+ "cgroupfs.go",
+ "cpu.go",
+ "cpuacct.go",
+ "cpuset.go",
+ "dir_refs.go",
+ "memory.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/coverage",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/refsvfs2",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/fsimpl/kernfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
+ "//pkg/sentry/usage",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserror",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go
new file mode 100644
index 000000000..39c1013e1
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/base.go
@@ -0,0 +1,233 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// controllerCommon implements kernel.CgroupController.
+//
+// Must call init before use.
+//
+// +stateify savable
+type controllerCommon struct {
+ ty kernel.CgroupControllerType
+ fs *filesystem
+}
+
+func (c *controllerCommon) init(ty kernel.CgroupControllerType, fs *filesystem) {
+ c.ty = ty
+ c.fs = fs
+}
+
+// Type implements kernel.CgroupController.Type.
+func (c *controllerCommon) Type() kernel.CgroupControllerType {
+ return kernel.CgroupControllerType(c.ty)
+}
+
+// HierarchyID implements kernel.CgroupController.HierarchyID.
+func (c *controllerCommon) HierarchyID() uint32 {
+ return c.fs.hierarchyID
+}
+
+// NumCgroups implements kernel.CgroupController.NumCgroups.
+func (c *controllerCommon) NumCgroups() uint64 {
+ return atomic.LoadUint64(&c.fs.numCgroups)
+}
+
+// Enabled implements kernel.CgroupController.Enabled.
+//
+// Controllers are currently always enabled.
+func (c *controllerCommon) Enabled() bool {
+ return true
+}
+
+// Filesystem implements kernel.CgroupController.Filesystem.
+func (c *controllerCommon) Filesystem() *vfs.Filesystem {
+ return c.fs.VFSFilesystem()
+}
+
+// RootCgroup implements kernel.CgroupController.RootCgroup.
+func (c *controllerCommon) RootCgroup() kernel.Cgroup {
+ return c.fs.rootCgroup()
+}
+
+// controller is an interface for common functionality related to all cgroups.
+// It is an extension of the public cgroup interface, containing cgroup
+// functionality private to cgroupfs.
+type controller interface {
+ kernel.CgroupController
+
+ // AddControlFiles should extend the contents map with inodes representing
+ // control files defined by this controller.
+ AddControlFiles(ctx context.Context, creds *auth.Credentials, c *cgroupInode, contents map[string]kernfs.Inode)
+}
+
+// cgroupInode implements kernel.CgroupImpl and kernfs.Inode.
+//
+// +stateify savable
+type cgroupInode struct {
+ dir
+ fs *filesystem
+
+ // ts is the list of tasks in this cgroup. The kernel is responsible for
+ // removing tasks from this list before they're destroyed, so any tasks on
+ // this list are always valid.
+ //
+ // ts, and cgroup membership in general is protected by fs.tasksMu.
+ ts map[*kernel.Task]struct{}
+}
+
+var _ kernel.CgroupImpl = (*cgroupInode)(nil)
+
+func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
+ c := &cgroupInode{
+ fs: fs,
+ ts: make(map[*kernel.Task]struct{}),
+ }
+
+ contents := make(map[string]kernfs.Inode)
+ contents["cgroup.procs"] = fs.newControllerFile(ctx, creds, &cgroupProcsData{c})
+ contents["tasks"] = fs.newControllerFile(ctx, creds, &tasksData{c})
+
+ for _, ctl := range fs.controllers {
+ ctl.AddControlFiles(ctx, creds, c, contents)
+ }
+
+ c.dir.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|linux.FileMode(0555))
+ c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+ c.dir.InitRefs()
+ c.dir.IncLinks(c.dir.OrderedChildren.Populate(contents))
+
+ atomic.AddUint64(&fs.numCgroups, 1)
+
+ return c
+}
+
+func (c *cgroupInode) HierarchyID() uint32 {
+ return c.fs.hierarchyID
+}
+
+// Controllers implements kernel.CgroupImpl.Controllers.
+func (c *cgroupInode) Controllers() []kernel.CgroupController {
+ return c.fs.kcontrollers
+}
+
+// Enter implements kernel.CgroupImpl.Enter.
+func (c *cgroupInode) Enter(t *kernel.Task) {
+ c.fs.tasksMu.Lock()
+ c.ts[t] = struct{}{}
+ c.fs.tasksMu.Unlock()
+}
+
+// Leave implements kernel.CgroupImpl.Leave.
+func (c *cgroupInode) Leave(t *kernel.Task) {
+ c.fs.tasksMu.Lock()
+ delete(c.ts, t)
+ c.fs.tasksMu.Unlock()
+}
+
+func sortTIDs(tids []kernel.ThreadID) {
+ sort.Slice(tids, func(i, j int) bool { return tids[i] < tids[j] })
+}
+
+// +stateify savable
+type cgroupProcsData struct {
+ *cgroupInode
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ t := kernel.TaskFromContext(ctx)
+ currPidns := t.ThreadGroup().PIDNamespace()
+
+ pgids := make(map[kernel.ThreadID]struct{})
+
+ d.fs.tasksMu.RLock()
+ defer d.fs.tasksMu.RUnlock()
+
+ for task := range d.ts {
+ // Map dedups pgid, since iterating over all tasks produces multiple
+ // entries for the group leaders.
+ if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 {
+ pgids[pgid] = struct{}{}
+ }
+ }
+
+ pgidList := make([]kernel.ThreadID, 0, len(pgids))
+ for pgid, _ := range pgids {
+ pgidList = append(pgidList, pgid)
+ }
+ sortTIDs(pgidList)
+
+ for _, pgid := range pgidList {
+ fmt.Fprintf(buf, "%d\n", pgid)
+ }
+
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *cgroupProcsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
+ return src.NumBytes(), nil
+}
+
+// +stateify savable
+type tasksData struct {
+ *cgroupInode
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ t := kernel.TaskFromContext(ctx)
+ currPidns := t.ThreadGroup().PIDNamespace()
+
+ var pids []kernel.ThreadID
+
+ d.fs.tasksMu.RLock()
+ defer d.fs.tasksMu.RUnlock()
+
+ for task := range d.ts {
+ if pid := currPidns.IDOfTask(task); pid != 0 {
+ pids = append(pids, pid)
+ }
+ }
+ sortTIDs(pids)
+
+ for _, pid := range pids {
+ fmt.Fprintf(buf, "%d\n", pid)
+ }
+
+ return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *tasksData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+ // TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
+ return src.NumBytes(), nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
new file mode 100644
index 000000000..ca8caee5f
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -0,0 +1,412 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cgroupfs implements cgroupfs.
+//
+// A cgroup is a collection of tasks on the system, organized into a tree-like
+// structure similar to a filesystem directory tree. In fact, each cgroup is
+// represented by a directory on cgroupfs, and is manipulated through control
+// files in the directory.
+//
+// All cgroups on a system are organized into hierarchies. Hierarchies are a
+// distinct tree of cgroups, with a common set of controllers. One or more
+// cgroupfs mounts may point to each hierarchy. These mounts provide a common
+// view into the same tree of cgroups.
+//
+// A controller (also known as a "resource controller", or a cgroup "subsystem")
+// determines the behaviour of each cgroup.
+//
+// In addition to cgroupfs, the kernel has a cgroup registry that tracks
+// system-wide state related to cgroups such as active hierarchies and the
+// controllers associated with them.
+//
+// Since cgroupfs doesn't allow hardlinks, there is a unique mapping between
+// cgroupfs dentries and inodes.
+//
+// # Synchronization
+//
+// Cgroup hierarchy creation and destruction is protected by the
+// kernel.CgroupRegistry.mu. Once created, a hierarchy's set of controllers, the
+// filesystem associated with it, and the root cgroup for the hierarchy are
+// immutable.
+//
+// Membership of tasks within cgroups is protected by
+// cgroupfs.filesystem.tasksMu. Tasks also maintain a set of all cgroups they're
+// in, and this list is protected by Task.mu.
+//
+// Lock order:
+//
+// kernel.CgroupRegistry.mu
+// cgroupfs.filesystem.mu
+// Task.mu
+// cgroupfs.filesystem.tasksMu.
+package cgroupfs
+
+import (
+ "fmt"
+ "sort"
+ "strconv"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+ // Name is the default filesystem name.
+ Name = "cgroup"
+ readonlyFileMode = linux.FileMode(0444)
+ writableFileMode = linux.FileMode(0644)
+ defaultMaxCachedDentries = uint64(1000)
+)
+
+const (
+ controllerCPU = kernel.CgroupControllerType("cpu")
+ controllerCPUAcct = kernel.CgroupControllerType("cpuacct")
+ controllerCPUSet = kernel.CgroupControllerType("cpuset")
+ controllerMemory = kernel.CgroupControllerType("memory")
+)
+
+var allControllers = []kernel.CgroupControllerType{controllerCPU, controllerCPUAcct, controllerCPUSet, controllerMemory}
+
+// SupportedMountOptions is the set of supported mount options for cgroupfs.
+var SupportedMountOptions = []string{"all", "cpu", "cpuacct", "cpuset", "memory"}
+
+// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
+type FilesystemType struct{}
+
+// InternalData contains internal data passed in to the cgroupfs mount via
+// vfs.GetFilesystemOptions.InternalData.
+//
+// +stateify savable
+type InternalData struct {
+ DefaultControlValues map[string]int64
+}
+
+// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
+type filesystem struct {
+ kernfs.Filesystem
+ devMinor uint32
+
+ // hierarchyID is the id the cgroup registry assigns to this hierarchy. Has
+ // the value kernel.InvalidCgroupHierarchyID until the FS is fully
+ // initialized.
+ //
+ // hierarchyID is immutable after initialization.
+ hierarchyID uint32
+
+ // controllers and kcontrollers are both the list of controllers attached to
+ // this cgroupfs. Both lists are the same set of controllers, but typecast
+ // to different interfaces for convenience. Both must stay in sync, and are
+ // immutable.
+ controllers []controller
+ kcontrollers []kernel.CgroupController
+
+ numCgroups uint64 // Protected by atomic ops.
+
+ root *kernfs.Dentry
+
+ // tasksMu serializes task membership changes across all cgroups within a
+ // filesystem.
+ tasksMu sync.RWMutex `state:"nosave"`
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+ return Name
+}
+
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ devMinor, err := vfsObj.GetAnonBlockDevMinor()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ maxCachedDentries := defaultMaxCachedDentries
+ if str, ok := mopts["dentry_cache_limit"]; ok {
+ delete(mopts, "dentry_cache_limit")
+ maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+ if err != nil {
+ ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+ return nil, nil, syserror.EINVAL
+ }
+ }
+
+ var wantControllers []kernel.CgroupControllerType
+ if _, ok := mopts["cpu"]; ok {
+ delete(mopts, "cpu")
+ wantControllers = append(wantControllers, controllerCPU)
+ }
+ if _, ok := mopts["cpuacct"]; ok {
+ delete(mopts, "cpuacct")
+ wantControllers = append(wantControllers, controllerCPUAcct)
+ }
+ if _, ok := mopts["cpuset"]; ok {
+ delete(mopts, "cpuset")
+ wantControllers = append(wantControllers, controllerCPUSet)
+ }
+ if _, ok := mopts["memory"]; ok {
+ delete(mopts, "memory")
+ wantControllers = append(wantControllers, controllerMemory)
+ }
+ if _, ok := mopts["all"]; ok {
+ if len(wantControllers) > 0 {
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: other controllers specified with all: %v", wantControllers)
+ return nil, nil, syserror.EINVAL
+ }
+
+ delete(mopts, "all")
+ wantControllers = allControllers
+ }
+
+ if len(wantControllers) == 0 {
+ // Specifying no controllers implies all controllers.
+ wantControllers = allControllers
+ }
+
+ if len(mopts) != 0 {
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ k := kernel.KernelFromContext(ctx)
+ r := k.CgroupRegistry()
+
+ // "It is not possible to mount the same controller against multiple
+ // cgroup hierarchies. For example, it is not possible to mount both
+ // the cpu and cpuacct controllers against one hierarchy, and to mount
+ // the cpu controller alone against another hierarchy." - man cgroups(7)
+ //
+ // Is there a hierarchy available with all the controllers we want? If so,
+ // this mount is a view into the same hierarchy.
+ //
+ // Note: we're guaranteed to have at least one requested controller, since
+ // no explicit controller name implies all controllers.
+ if vfsfs := r.FindHierarchy(wantControllers); vfsfs != nil {
+ fs := vfsfs.Impl().(*filesystem)
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: mounting new view to hierarchy %v", fs.hierarchyID)
+ fs.root.IncRef()
+ return vfsfs, fs.root.VFSDentry(), nil
+ }
+
+ // No existing hierarchy with the exactly controllers found. Make a new
+ // one. Note that it's possible this mount creation is unsatisfiable, if one
+ // or more of the requested controllers are already on existing
+ // hierarchies. We'll find out about such collisions when we try to register
+ // the new hierarchy later.
+ fs := &filesystem{
+ devMinor: devMinor,
+ }
+ fs.MaxCachedDentries = maxCachedDentries
+ fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+ var defaults map[string]int64
+ if opts.InternalData != nil {
+ ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults)
+ defaults = opts.InternalData.(*InternalData).DefaultControlValues
+ }
+
+ for _, ty := range wantControllers {
+ var c controller
+ switch ty {
+ case controllerMemory:
+ c = newMemoryController(fs, defaults)
+ case controllerCPU:
+ c = newCPUController(fs, defaults)
+ case controllerCPUAcct:
+ c = newCPUAcctController(fs)
+ case controllerCPUSet:
+ c = newCPUSetController(fs)
+ default:
+ panic(fmt.Sprintf("Unreachable: unknown cgroup controller %q", ty))
+ }
+ fs.controllers = append(fs.controllers, c)
+ }
+
+ if len(defaults) != 0 {
+ // Internal data is always provided at sentry startup and unused values
+ // indicate a problem with the sandbox config. Fail fast.
+ panic(fmt.Sprintf("cgroupfs.FilesystemType.GetFilesystem: unknown internal mount data: %v", defaults))
+ }
+
+ // Controllers usually appear in alphabetical order when displayed. Sort it
+ // here now, so it never needs to be sorted elsewhere.
+ sort.Slice(fs.controllers, func(i, j int) bool { return fs.controllers[i].Type() < fs.controllers[j].Type() })
+ fs.kcontrollers = make([]kernel.CgroupController, 0, len(fs.controllers))
+ for _, c := range fs.controllers {
+ fs.kcontrollers = append(fs.kcontrollers, c)
+ }
+
+ root := fs.newCgroupInode(ctx, creds)
+ var rootD kernfs.Dentry
+ rootD.InitRoot(&fs.Filesystem, root)
+ fs.root = &rootD
+
+ // Register controllers. The registry may be modified concurrently, so if we
+ // get an error, we raced with someone else who registered the same
+ // controllers first.
+ hid, err := r.Register(fs.kcontrollers)
+ if err != nil {
+ ctx.Infof("cgroupfs.FilesystemType.GetFilesystem: failed to register new hierarchy with controllers %v: %v", wantControllers, err)
+ rootD.DecRef(ctx)
+ fs.VFSFilesystem().DecRef(ctx)
+ return nil, nil, syserror.EBUSY
+ }
+ fs.hierarchyID = hid
+
+ // Move all existing tasks to the root of the new hierarchy.
+ k.PopulateNewCgroupHierarchy(fs.rootCgroup())
+
+ return fs.VFSFilesystem(), rootD.VFSDentry(), nil
+}
+
+func (fs *filesystem) rootCgroup() kernel.Cgroup {
+ return kernel.Cgroup{
+ Dentry: fs.root,
+ CgroupImpl: fs.root.Inode().(kernel.CgroupImpl),
+ }
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+ k := kernel.KernelFromContext(ctx)
+ r := k.CgroupRegistry()
+
+ if fs.hierarchyID != kernel.InvalidCgroupHierarchyID {
+ k.ReleaseCgroupHierarchy(fs.hierarchyID)
+ r.Unregister(fs.hierarchyID)
+ }
+
+ fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+ fs.Filesystem.Release(ctx)
+}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+ var cnames []string
+ for _, c := range fs.controllers {
+ cnames = append(cnames, string(c.Type()))
+ }
+ return strings.Join(cnames, ",")
+}
+
+// +stateify savable
+type implStatFS struct{}
+
+// StatFS implements kernfs.Inode.StatFS.
+func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) {
+ return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
+}
+
+// dir implements kernfs.Inode for a generic cgroup resource controller
+// directory. Specific controllers extend this to add their own functionality.
+//
+// +stateify savable
+type dir struct {
+ dirRefs
+ kernfs.InodeAlwaysValid
+ kernfs.InodeAttrs
+ kernfs.InodeNotSymlink
+ kernfs.InodeDirectoryNoNewChildren // TODO(b/183137098): Implement mkdir.
+ kernfs.OrderedChildren
+ implStatFS
+
+ locks vfs.FileLocks
+}
+
+// Keep implements kernfs.Inode.Keep.
+func (*dir) Keep() bool {
+ return true
+}
+
+// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
+func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+ return syserror.EPERM
+}
+
+// Open implements kernfs.Inode.Open.
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
+ SeekEnd: kernfs.SeekEndStaticEntries,
+ })
+ if err != nil {
+ return nil, err
+ }
+ return fd.VFSFileDescription(), nil
+}
+
+// DecRef implements kernfs.Inode.DecRef.
+func (d *dir) DecRef(ctx context.Context) {
+ d.dirRefs.DecRef(func() { d.Destroy(ctx) })
+}
+
+// StatFS implements kernfs.Inode.StatFS.
+func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) {
+ return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
+}
+
+// controllerFile represents a generic control file that appears within a cgroup
+// directory.
+//
+// +stateify savable
+type controllerFile struct {
+ kernfs.DynamicBytesFile
+}
+
+func (fs *filesystem) newControllerFile(ctx context.Context, creds *auth.Credentials, data vfs.DynamicBytesSource) kernfs.Inode {
+ f := &controllerFile{}
+ f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, readonlyFileMode)
+ return f
+}
+
+func (fs *filesystem) newControllerWritableFile(ctx context.Context, creds *auth.Credentials, data vfs.WritableDynamicBytesSource) kernfs.Inode {
+ f := &controllerFile{}
+ f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, writableFileMode)
+ return f
+}
+
+// staticControllerFile represents a generic control file that appears within a
+// cgroup directory which always returns the same data when read.
+// staticControllerFiles are not writable.
+//
+// +stateify savable
+type staticControllerFile struct {
+ kernfs.DynamicBytesFile
+ vfs.StaticData
+}
+
+// Note: We let the caller provide the mode so that static files may be used to
+// fake both readable and writable control files. However, static files are
+// effectively readonly, as attempting to write to them will return EIO
+// regardless of the mode.
+func (fs *filesystem) newStaticControllerFile(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, data string) kernfs.Inode {
+ f := &staticControllerFile{StaticData: vfs.StaticData{Data: data}}
+ f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), f, mode)
+ return f
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpu.go b/pkg/sentry/fsimpl/cgroupfs/cpu.go
new file mode 100644
index 000000000..24d86a277
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpu.go
@@ -0,0 +1,70 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// +stateify savable
+type cpuController struct {
+ controllerCommon
+
+ // CFS bandwidth control parameters, values in microseconds.
+ cfsPeriod int64
+ cfsQuota int64
+
+ // CPU shares, values should be (num core * 1024).
+ shares int64
+}
+
+var _ controller = (*cpuController)(nil)
+
+func newCPUController(fs *filesystem, defaults map[string]int64) *cpuController {
+ // Default values for controller parameters from Linux.
+ c := &cpuController{
+ cfsPeriod: 100000,
+ cfsQuota: -1,
+ shares: 1024,
+ }
+
+ if val, ok := defaults["cpu.cfs_period_us"]; ok {
+ c.cfsPeriod = val
+ delete(defaults, "cpu.cfs_period_us")
+ }
+ if val, ok := defaults["cpu.cfs_quota_us"]; ok {
+ c.cfsQuota = val
+ delete(defaults, "cpu.cfs_quota_us")
+ }
+ if val, ok := defaults["cpu.shares"]; ok {
+ c.shares = val
+ delete(defaults, "cpu.shares")
+ }
+
+ c.controllerCommon.init(controllerCPU, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpuController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ contents["cpu.cfs_period_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsPeriod))
+ contents["cpu.cfs_quota_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsQuota))
+ contents["cpu.shares"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.shares))
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuacct.go b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
new file mode 100644
index 000000000..d4104a00e
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
@@ -0,0 +1,114 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// +stateify savable
+type cpuacctController struct {
+ controllerCommon
+}
+
+var _ controller = (*cpuacctController)(nil)
+
+func newCPUAcctController(fs *filesystem) *cpuacctController {
+ c := &cpuacctController{}
+ c.controllerCommon.init(controllerCPUAcct, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpuacctController) AddControlFiles(ctx context.Context, creds *auth.Credentials, cg *cgroupInode, contents map[string]kernfs.Inode) {
+ cpuacctCG := &cpuacctCgroup{cg}
+ contents["cpuacct.stat"] = c.fs.newControllerFile(ctx, creds, &cpuacctStatData{cpuacctCG})
+ contents["cpuacct.usage"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageData{cpuacctCG})
+ contents["cpuacct.usage_user"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageUserData{cpuacctCG})
+ contents["cpuacct.usage_sys"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageSysData{cpuacctCG})
+}
+
+// +stateify savable
+type cpuacctCgroup struct {
+ *cgroupInode
+}
+
+func (c *cpuacctCgroup) collectCPUStats() usage.CPUStats {
+ var cs usage.CPUStats
+ c.fs.tasksMu.RLock()
+ // Note: This isn't very accurate, since the tasks are potentially
+ // still running as we accumulate their stats.
+ for t := range c.ts {
+ cs.Accumulate(t.CPUStats())
+ }
+ c.fs.tasksMu.RUnlock()
+ return cs
+}
+
+// +stateify savable
+type cpuacctStatData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "user %d\n", linux.ClockTFromDuration(cs.UserTime))
+ fmt.Fprintf(buf, "system %d\n", linux.ClockTFromDuration(cs.SysTime))
+ return nil
+}
+
+// +stateify savable
+type cpuacctUsageData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds()+cs.SysTime.Nanoseconds())
+ return nil
+}
+
+// +stateify savable
+type cpuacctUsageUserData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageUserData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds())
+ return nil
+}
+
+// +stateify savable
+type cpuacctUsageSysData struct {
+ *cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageSysData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ cs := d.collectCPUStats()
+ fmt.Fprintf(buf, "%d\n", cs.SysTime.Nanoseconds())
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuset.go b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
new file mode 100644
index 000000000..ac547f8e2
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
@@ -0,0 +1,39 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// +stateify savable
+type cpusetController struct {
+ controllerCommon
+}
+
+var _ controller = (*cpusetController)(nil)
+
+func newCPUSetController(fs *filesystem) *cpusetController {
+ c := &cpusetController{}
+ c.controllerCommon.init(controllerCPUSet, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpusetController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ // This controller is currently intentionally empty.
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/memory.go b/pkg/sentry/fsimpl/cgroupfs/memory.go
new file mode 100644
index 000000000..485c98376
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/memory.go
@@ -0,0 +1,74 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// +stateify savable
+type memoryController struct {
+ controllerCommon
+
+ limitBytes int64
+}
+
+var _ controller = (*memoryController)(nil)
+
+func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryController {
+ c := &memoryController{
+ // Linux sets this to (PAGE_COUNTER_MAX * PAGE_SIZE) by default, which
+ // is ~ 2**63 on a 64-bit system. So essentially, inifinity. The exact
+ // value isn't very important.
+ limitBytes: math.MaxInt64,
+ }
+ if val, ok := defaults["memory.limit_in_bytes"]; ok {
+ c.limitBytes = val
+ delete(defaults, "memory.limit_in_bytes")
+ }
+ c.controllerCommon.init(controllerMemory, fs)
+ return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *memoryController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+ contents["memory.usage_in_bytes"] = c.fs.newControllerFile(ctx, creds, &memoryUsageInBytesData{})
+ contents["memory.limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.limitBytes))
+}
+
+// +stateify savable
+type memoryUsageInBytesData struct{}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *memoryUsageInBytesData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // TODO(b/183151557): This is a giant hack, we're using system-wide
+ // accounting since we know there is only one cgroup.
+ k := kernel.KernelFromContext(ctx)
+ mf := k.MemoryFile()
+ mf.UpdateUsage()
+ _, totalBytes := usage.MemoryAccounting.Copy()
+
+ fmt.Fprintf(buf, "%d\n", totalBytes)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index a0c05231a..526136324 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1104,24 +1104,27 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
defer d.metadataMu.Unlock()
// As with Linux, if the UID, GID, or file size is changing, we have to
- // clear permission bits. Note that when set, clearSGID causes
- // permissions to be updated, but does not modify stat.Mask, as
- // modification would cause an extra inotify flag to be set.
- clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) ||
- stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) ||
+ // clear permission bits. Note that when set, clearSGID may cause
+ // permissions to be updated.
+ clearSGID := (stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid)) ||
+ (stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid)) ||
stat.Mask&linux.STATX_SIZE != 0
if clearSGID {
if stat.Mask&linux.STATX_MODE != 0 {
stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode)))
} else {
- stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode)))
+ oldMode := atomic.LoadUint32(&d.mode)
+ if updatedMode := vfs.ClearSUIDAndSGID(oldMode); updatedMode != oldMode {
+ stat.Mode = uint16(updatedMode)
+ stat.Mask |= linux.STATX_MODE
+ }
}
}
if !d.isSynthetic() {
if stat.Mask != 0 {
if err := d.file.setAttr(ctx, p9.SetAttrMask{
- Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID,
+ Permissions: stat.Mask&linux.STATX_MODE != 0,
UID: stat.Mask&linux.STATX_UID != 0,
GID: stat.Mask&linux.STATX_GID != 0,
Size: stat.Mask&linux.STATX_SIZE != 0,
@@ -1156,7 +1159,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
return nil
}
}
- if stat.Mask&linux.STATX_MODE != 0 || clearSGID {
+ if stat.Mask&linux.STATX_MODE != 0 {
atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
}
if stat.Mask&linux.STATX_UID != 0 {
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 47563538c..713f0a480 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -701,6 +701,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
}
// After this point, d may be used as a memmap.Mappable.
d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init)
+ opts.SentryOwnedContent = d.fs.opts.forcePageCache
return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts)
}
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 65054b0ea..84b1c3745 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -25,8 +25,10 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// DynamicBytesFile implements kernfs.Inode and represents a read-only
-// file whose contents are backed by a vfs.DynamicBytesSource.
+// DynamicBytesFile implements kernfs.Inode and represents a read-only file
+// whose contents are backed by a vfs.DynamicBytesSource. If data additionally
+// implements vfs.WritableDynamicBytesSource, the file also supports dispatching
+// writes to the implementer, but note that this will not update the source data.
//
// Must be instantiated with NewDynamicBytesFile or initialized with Init
// before first use.
@@ -40,7 +42,9 @@ type DynamicBytesFile struct {
InodeNotSymlink
locks vfs.FileLocks
- data vfs.DynamicBytesSource
+ // data can additionally implement vfs.WritableDynamicBytesSource to support
+ // writes.
+ data vfs.DynamicBytesSource
}
var _ Inode = (*DynamicBytesFile)(nil)
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 565d723f0..16486eeae 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -61,6 +61,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refsvfs2"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -508,6 +509,15 @@ func (d *Dentry) Inode() Inode {
return d.inode
}
+// FSLocalPath returns an absolute path to d, relative to the root of its
+// filesystem.
+func (d *Dentry) FSLocalPath() string {
+ var b fspath.Builder
+ _ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b)
+ b.PrependByte('/')
+ return b.String()
+}
+
// The Inode interface maps filesystem-level operations that operate on paths to
// equivalent operations on specific filesystem nodes.
//
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 254a8b062..ce8f55b1f 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -86,13 +86,13 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF
procfs.MaxCachedDentries = maxCachedDentries
procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
- var cgroups map[string]string
+ var fakeCgroupControllers map[string]string
if opts.InternalData != nil {
data := opts.InternalData.(*InternalData)
- cgroups = data.Cgroups
+ fakeCgroupControllers = data.Cgroups
}
- inode := procfs.newTasksInode(ctx, k, pidns, cgroups)
+ inode := procfs.newTasksInode(ctx, k, pidns, fakeCgroupControllers)
var dentry kernfs.Dentry
dentry.InitRoot(&procfs.Filesystem, inode)
return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index fea138f93..d05cc1508 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -47,7 +47,7 @@ type taskInode struct {
var _ kernfs.Inode = (*taskInode)(nil)
-func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) {
+func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, fakeCgroupControllers map[string]string) (kernfs.Inode, error) {
if task.ExitState() == kernel.TaskExitDead {
return nil, syserror.ESRCH
}
@@ -82,10 +82,12 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns
"uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
if isThreadGroup {
- contents["task"] = fs.newSubtasks(ctx, task, pidns, cgroupControllers)
+ contents["task"] = fs.newSubtasks(ctx, task, pidns, fakeCgroupControllers)
}
- if len(cgroupControllers) > 0 {
- contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newCgroupData(cgroupControllers))
+ if len(fakeCgroupControllers) > 0 {
+ contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newFakeCgroupData(fakeCgroupControllers))
+ } else {
+ contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskCgroupData{task: task})
}
taskInode := &taskInode{task: task}
@@ -226,11 +228,14 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
return &ioData{ioUsage: t}
}
-// newCgroupData creates inode that shows cgroup information.
-// From man 7 cgroups: "For each cgroup hierarchy of which the process is a
-// member, there is one entry containing three colon-separated fields:
-// hierarchy-ID:controller-list:cgroup-path"
-func newCgroupData(controllers map[string]string) dynamicInode {
+// newFakeCgroupData creates an inode that shows fake cgroup
+// information passed in as mount options. From man 7 cgroups: "For
+// each cgroup hierarchy of which the process is a member, there is
+// one entry containing three colon-separated fields:
+// hierarchy-ID:controller-list:cgroup-path"
+//
+// TODO(b/182488796): Remove once all users adopt cgroupfs.
+func newFakeCgroupData(controllers map[string]string) dynamicInode {
var buf bytes.Buffer
// The hierarchy ids must be positive integers (for cgroup v1), but the
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 85909d551..b294dfd6a 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -1100,3 +1100,32 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err
func (fd *namespaceFD) Release(ctx context.Context) {
fd.inode.DecRef(ctx)
}
+
+// taskCgroupData generates data for /proc/[pid]/cgroup.
+//
+// +stateify savable
+type taskCgroupData struct {
+ dynamicBytesFileSetAttr
+ task *kernel.Task
+}
+
+var _ dynamicInode = (*taskCgroupData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *taskCgroupData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ // When a task is existing on Linux, a task's cgroup set is cleared and
+ // reset to the initial cgroup set, which is essentially the set of root
+ // cgroups. Because of this, the /proc/<pid>/cgroup file is always readable
+ // on Linux throughout a task's lifetime.
+ //
+ // The sentry removes tasks from cgroups during the exit process, but
+ // doesn't move them into an initial cgroup set, so partway through task
+ // exit this file show a task is in no cgroups, which is incorrect. Instead,
+ // once a task has left its cgroups, we return an error.
+ if d.task.ExitState() >= kernel.TaskExitInitiated {
+ return syserror.ESRCH
+ }
+
+ d.task.GenerateProcTaskCgroup(buf)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index fdc580610..7c7543f14 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -54,15 +54,15 @@ type tasksInode struct {
// '/proc/self' and '/proc/thread-self' have custom directory offsets in
// Linux. So handle them outside of OrderedChildren.
- // cgroupControllers is a map of controller name to directory in the
+ // fakeCgroupControllers is a map of controller name to directory in the
// cgroup hierarchy. These controllers are immutable and will be listed
// in /proc/pid/cgroup if not nil.
- cgroupControllers map[string]string
+ fakeCgroupControllers map[string]string
}
var _ kernfs.Inode = (*tasksInode)(nil)
-func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
+func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode {
root := auth.NewRootCredentials(pidns.UserNamespace())
contents := map[string]kernfs.Inode{
"cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))),
@@ -76,11 +76,16 @@ func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns
"uptime": fs.newInode(ctx, root, 0444, &uptimeData{}),
"version": fs.newInode(ctx, root, 0444, &versionData{}),
}
+ // If fakeCgroupControllers are provided, don't create a cgroupfs backed
+ // /proc/cgroup as it will not match the fake controllers.
+ if len(fakeCgroupControllers) == 0 {
+ contents["cgroups"] = fs.newInode(ctx, root, 0444, &cgroupsData{})
+ }
inode := &tasksInode{
- pidns: pidns,
- fs: fs,
- cgroupControllers: cgroupControllers,
+ pidns: pidns,
+ fs: fs,
+ fakeCgroupControllers: fakeCgroupControllers,
}
inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
inode.InitRefs()
@@ -118,7 +123,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err
return nil, syserror.ENOENT
}
- return i.fs.newTaskInode(ctx, task, i.pidns, true, i.cgroupControllers)
+ return i.fs.newTaskInode(ctx, task, i.pidns, true, i.fakeCgroupControllers)
}
// IterDirents implements kernfs.inodeDirectory.IterDirents.
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index f0029cda6..e1a8b4409 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -384,3 +384,19 @@ func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error
k.VFS().GenerateProcFilesystems(buf)
return nil
}
+
+// cgroupsData backs /proc/cgroups.
+//
+// +stateify savable
+type cgroupsData struct {
+ dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*cgroupsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*cgroupsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ r := kernel.KernelFromContext(ctx).CgroupRegistry()
+ r.GenerateProcCgroups(buf)
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index cd849e87e..c45bddff6 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -488,6 +488,7 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (
// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
file := fd.inode().impl.(*regularFile)
+ opts.SentryOwnedContent = true
return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
}
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index 2da251233..d473a922d 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -18,10 +18,12 @@ go_library(
"//pkg/marshal/primitive",
"//pkg/merkletree",
"//pkg/refsvfs2",
+ "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/fs/lock",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/memmap",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/sync",
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 6cb1a23e0..214ffd095 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -632,8 +632,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
childVD.IncRef()
childMerkleVD.IncRef()
- parent.IncRef()
- child.parent = parent
child.name = name
child.mode = uint32(stat.Mode)
@@ -657,6 +655,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry,
}
}
+ parent.IncRef()
+ child.parent = parent
+
return child, nil
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index a7d92a878..06f2c211c 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -34,6 +34,8 @@
package verity
import (
+ "bytes"
+ "encoding/hex"
"encoding/json"
"fmt"
"math"
@@ -44,19 +46,20 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/merkletree"
"gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/memmap"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
const (
@@ -103,6 +106,13 @@ var (
verityMu sync.RWMutex
)
+// Mount option names for verityfs.
+const (
+ moptLowerPath = "lower_path"
+ moptRootHash = "root_hash"
+ moptRootName = "root_name"
+)
+
// HashAlgorithm is a type specifying the algorithm used to hash the file
// content.
type HashAlgorithm int
@@ -169,6 +179,9 @@ type filesystem struct {
// system.
alg HashAlgorithm
+ // opts is the string mount options passed to opts.Data.
+ opts string
+
// renameMu synchronizes renaming with non-renaming operations in order
// to ensure consistent lock ordering between dentry.dirMu in different
// dentries.
@@ -191,9 +204,6 @@ type filesystem struct {
//
// +stateify savable
type InternalFilesystemOptions struct {
- // RootMerkleFileName is the name of the verity root Merkle tree file.
- RootMerkleFileName string
-
// LowerName is the name of the filesystem wrapped by verity fs.
LowerName string
@@ -201,9 +211,6 @@ type InternalFilesystemOptions struct {
// system.
Alg HashAlgorithm
- // RootHash is the root hash of the overall verity file system.
- RootHash []byte
-
// AllowRuntimeEnable specifies whether the verity file system allows
// enabling verification for files (i.e. building Merkle trees) during
// runtime.
@@ -237,28 +244,99 @@ func alertIntegrityViolation(msg string) error {
// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+ mopts := vfs.GenericParseMountOptions(opts.Data)
+ var rootHash []byte
+ if encodedRootHash, ok := mopts[moptRootHash]; ok {
+ delete(mopts, moptRootHash)
+ hash, err := hex.DecodeString(encodedRootHash)
+ if err != nil {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: Failed to decode root hash: %v", err)
+ return nil, nil, syserror.EINVAL
+ }
+ rootHash = hash
+ }
+ var lowerPathname string
+ if path, ok := mopts[moptLowerPath]; ok {
+ delete(mopts, moptLowerPath)
+ lowerPathname = path
+ }
+ rootName := "root"
+ if root, ok := mopts[moptRootName]; ok {
+ delete(mopts, moptRootName)
+ rootName = root
+ }
+
+ // Check for unparsed options.
+ if len(mopts) != 0 {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+ return nil, nil, syserror.EINVAL
+ }
+
+ // Handle internal options.
iopts, ok := opts.InternalData.(InternalFilesystemOptions)
- if !ok {
+ if len(lowerPathname) == 0 && !ok {
ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs")
return nil, nil, syserror.EINVAL
}
+ if len(lowerPathname) != 0 {
+ if ok {
+ ctx.Warningf("verity.FilesystemType.GetFilesystem: unexpected verity configs with specified lower path")
+ return nil, nil, syserror.EINVAL
+ }
+ iopts = InternalFilesystemOptions{
+ AllowRuntimeEnable: len(rootHash) == 0,
+ Action: ErrorOnViolation,
+ }
+ }
action = iopts.Action
- // Mount the lower file system. The lower file system is wrapped inside
- // verity, and should not be exposed or connected.
- mopts := &vfs.MountOptions{
- GetFilesystemOptions: iopts.LowerGetFSOptions,
- InternalMount: true,
- }
- mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts)
- if err != nil {
- return nil, nil, err
+ var lowerMount *vfs.Mount
+ var mountedLowerVD vfs.VirtualDentry
+ // Use an existing mount if lowerPath is provided.
+ if len(lowerPathname) != 0 {
+ vfsroot := vfs.RootFromContext(ctx)
+ if vfsroot.Ok() {
+ defer vfsroot.DecRef(ctx)
+ }
+ lowerPath := fspath.Parse(lowerPathname)
+ if !lowerPath.Absolute {
+ ctx.Infof("verity.FilesystemType.GetFilesystem: lower_path %q must be absolute", lowerPathname)
+ return nil, nil, syserror.EINVAL
+ }
+ var err error
+ mountedLowerVD, err = vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{
+ Root: vfsroot,
+ Start: vfsroot,
+ Path: lowerPath,
+ FollowFinalSymlink: true,
+ }, &vfs.GetDentryOptions{
+ CheckSearchable: true,
+ })
+ if err != nil {
+ ctx.Infof("verity.FilesystemType.GetFilesystem: failed to resolve lower_path %q: %v", lowerPathname, err)
+ return nil, nil, err
+ }
+ lowerMount = mountedLowerVD.Mount()
+ defer mountedLowerVD.DecRef(ctx)
+ } else {
+ // Mount the lower file system. The lower file system is wrapped inside
+ // verity, and should not be exposed or connected.
+ mountOpts := &vfs.MountOptions{
+ GetFilesystemOptions: iopts.LowerGetFSOptions,
+ InternalMount: true,
+ }
+ mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mountOpts)
+ if err != nil {
+ return nil, nil, err
+ }
+ lowerMount = mnt
}
fs := &filesystem{
creds: creds.Fork(),
alg: iopts.Alg,
- lowerMount: mnt,
+ lowerMount: lowerMount,
+ opts: opts.Data,
allowRuntimeEnable: iopts.AllowRuntimeEnable,
}
fs.vfsfs.Init(vfsObj, &fstype, fs)
@@ -266,11 +344,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
// Construct the root dentry.
d := fs.newDentry()
d.refs = 1
- lowerVD := vfs.MakeVirtualDentry(mnt, mnt.Root())
+ lowerVD := vfs.MakeVirtualDentry(lowerMount, lowerMount.Root())
lowerVD.IncRef()
d.lowerVD = lowerVD
- rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName
+ rootMerkleName := merkleRootPrefix + rootName
lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
Root: lowerVD,
@@ -350,9 +428,15 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
d.mode = uint32(stat.Mode)
d.uid = stat.UID
d.gid = stat.GID
- d.hash = make([]byte, len(iopts.RootHash))
d.childrenNames = make(map[string]struct{})
+ d.hashMu.Lock()
+ d.hash = make([]byte, len(rootHash))
+ copy(d.hash, rootHash)
+ d.hashMu.Unlock()
+
+ fs.rootDentry = d
+
if !d.isDir() {
ctx.Warningf("verity root must be a directory")
return nil, nil, syserror.EINVAL
@@ -424,13 +508,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
}
}
- d.hashMu.Lock()
- copy(d.hash, iopts.RootHash)
- d.hashMu.Unlock()
d.vfsd.Init(d)
- fs.rootDentry = d
-
return &fs.vfsfs, &d.vfsd, nil
}
@@ -441,7 +520,7 @@ func (fs *filesystem) Release(ctx context.Context) {
// MountOptions implements vfs.FilesystemImpl.MountOptions.
func (fs *filesystem) MountOptions() string {
- return ""
+ return fs.opts
}
// dentry implements vfs.DentryImpl.
@@ -722,6 +801,10 @@ type fileDescription struct {
// underlying file system.
lowerFD *vfs.FileDescription
+ // lowerMappable is the memmap.Mappable corresponding to this file in the
+ // underlying file system.
+ lowerMappable memmap.Mappable
+
// merkleReader is the read-only FileDescription corresponding to the
// Merkle tree file in the underlying file system.
merkleReader *vfs.FileDescription
@@ -1201,6 +1284,24 @@ func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, op
return 0, syserror.EROFS
}
+// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
+func (fd *fileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
+ if err := fd.lowerFD.ConfigureMMap(ctx, opts); err != nil {
+ return err
+ }
+ fd.lowerMappable = opts.Mappable
+ if opts.MappingIdentity != nil {
+ opts.MappingIdentity.DecRef(ctx)
+ opts.MappingIdentity = nil
+ }
+
+ // Check if mmap is allowed on the lower filesystem.
+ if !opts.SentryOwnedContent {
+ return syserror.ENODEV
+ }
+ return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts)
+}
+
// LockBSD implements vfs.FileDescriptionImpl.LockBSD.
func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error {
return fd.lowerFD.LockBSD(ctx, ownerPID, t, block)
@@ -1226,6 +1327,115 @@ func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t
return fd.lowerFD.TestPOSIX(ctx, uid, t, r)
}
+// Translate implements memmap.Mappable.Translate.
+func (fd *fileDescription) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) {
+ ts, err := fd.lowerMappable.Translate(ctx, required, optional, at)
+ if err != nil {
+ return ts, err
+ }
+
+ // dataSize is the size of the whole file.
+ dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{
+ Name: merkleSizeXattr,
+ Size: sizeOfStringInt32,
+ })
+
+ // The Merkle tree file for the child should have been created and
+ // contains the expected xattrs. If the xattr does not exist, it
+ // indicates unexpected modifications to the file system.
+ if err == syserror.ENODATA {
+ return ts, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err))
+ }
+ if err != nil {
+ return ts, err
+ }
+
+ // The dataSize xattr should be an integer. If it's not, it indicates
+ // unexpected modifications to the file system.
+ size, err := strconv.Atoi(dataSize)
+ if err != nil {
+ return ts, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
+ }
+
+ merkleReader := FileReadWriteSeeker{
+ FD: fd.merkleReader,
+ Ctx: ctx,
+ }
+
+ for _, t := range ts {
+ // Content integrity relies on sentry owning the backing data. MapInternal is guaranteed
+ // to fetch sentry owned memory because we disallow verity mmaps otherwise.
+ ims, err := t.File.MapInternal(memmap.FileRange{t.Offset, t.Offset + t.Source.Length()}, hostarch.Read)
+ if err != nil {
+ return nil, err
+ }
+ dataReader := mmapReadSeeker{ims, t.Source.Start}
+ var buf bytes.Buffer
+ _, err = merkletree.Verify(&merkletree.VerifyParams{
+ Out: &buf,
+ File: &dataReader,
+ Tree: &merkleReader,
+ Size: int64(size),
+ Name: fd.d.name,
+ Mode: fd.d.mode,
+ UID: fd.d.uid,
+ GID: fd.d.gid,
+ HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(),
+ ReadOffset: int64(t.Source.Start),
+ ReadSize: int64(t.Source.Length()),
+ Expected: fd.d.hash,
+ DataAndTreeInSameFile: false,
+ })
+ if err != nil {
+ return ts, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err))
+ }
+ }
+ return ts, err
+}
+
+// AddMapping implements memmap.Mappable.AddMapping.
+func (fd *fileDescription) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error {
+ return fd.lowerMappable.AddMapping(ctx, ms, ar, offset, writable)
+}
+
+// RemoveMapping implements memmap.Mappable.RemoveMapping.
+func (fd *fileDescription) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) {
+ fd.lowerMappable.RemoveMapping(ctx, ms, ar, offset, writable)
+}
+
+// CopyMapping implements memmap.Mappable.CopyMapping.
+func (fd *fileDescription) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error {
+ return fd.lowerMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable)
+}
+
+// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable.
+func (fd *fileDescription) InvalidateUnsavable(context.Context) error {
+ return nil
+}
+
+// mmapReadSeeker is a helper struct used by fileDescription.Translate to pass
+// a safemem.BlockSeq pointing to the mapped region as io.ReaderAt.
+type mmapReadSeeker struct {
+ safemem.BlockSeq
+ Offset uint64
+}
+
+// ReadAt implements io.ReaderAt.ReadAt. off is the offset into the mapped file.
+func (r *mmapReadSeeker) ReadAt(p []byte, off int64) (int, error) {
+ bs := r.BlockSeq
+ // Adjust the offset into the mapped file to get the offset into the internally
+ // mapped region.
+ readOffset := off - int64(r.Offset)
+ if readOffset < 0 {
+ return 0, syserror.EINVAL
+ }
+ bs.DropFirst64(uint64(readOffset))
+ view := bs.TakeFirst64(uint64(len(p)))
+ dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(p))
+ n, err := safemem.CopySeq(dst, view)
+ return int(n), err
+}
+
// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as
// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
type FileReadWriteSeeker struct {
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index 57bd65202..5c78a0019 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -89,10 +89,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
AllowUserMount: true,
})
+ data := "root_name=" + rootMerkleFilename
mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
GetFilesystemOptions: vfs.GetFilesystemOptions{
+ Data: data,
InternalData: InternalFilesystemOptions{
- RootMerkleFileName: rootMerkleFilename,
LowerName: "tmpfs",
Alg: hashAlg,
AllowRuntimeEnable: true,
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index e9eb89378..a1ec6daab 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -141,6 +141,7 @@ go_library(
srcs = [
"abstract_socket_namespace.go",
"aio.go",
+ "cgroup.go",
"context.go",
"fd_table.go",
"fd_table_refs.go",
@@ -178,6 +179,7 @@ go_library(
"task.go",
"task_acct.go",
"task_block.go",
+ "task_cgroup.go",
"task_clone.go",
"task_context.go",
"task_exec.go",
@@ -241,6 +243,7 @@ go_library(
"//pkg/sentry/fs/lock",
"//pkg/sentry/fs/timerfd",
"//pkg/sentry/fsbridge",
+ "//pkg/sentry/fsimpl/kernfs",
"//pkg/sentry/fsimpl/pipefs",
"//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/fsimpl/timerfd",
diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go
new file mode 100644
index 000000000..1f1c63f37
--- /dev/null
+++ b/pkg/sentry/kernel/cgroup.go
@@ -0,0 +1,281 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// InvalidCgroupHierarchyID indicates an uninitialized hierarchy ID.
+const InvalidCgroupHierarchyID uint32 = 0
+
+// CgroupControllerType is the name of a cgroup controller.
+type CgroupControllerType string
+
+// CgroupController is the common interface to cgroup controllers available to
+// the entire sentry. The controllers themselves are defined by cgroupfs.
+//
+// Callers of this interface are often unable access synchronization needed to
+// ensure returned values remain valid. Some of values returned from this
+// interface are thus snapshots in time, and may become stale. This is ok for
+// many callers like procfs.
+type CgroupController interface {
+ // Returns the type of this cgroup controller (ex "memory", "cpu"). Returned
+ // value is valid for the lifetime of the controller.
+ Type() CgroupControllerType
+
+ // Hierarchy returns the ID of the hierarchy this cgroup controller is
+ // attached to. Returned value is valid for the lifetime of the controller.
+ HierarchyID() uint32
+
+ // Filesystem returns the filesystem this controller is attached to.
+ // Returned value is valid for the lifetime of the controller.
+ Filesystem() *vfs.Filesystem
+
+ // RootCgroup returns the root cgroup for this controller. Returned value is
+ // valid for the lifetime of the controller.
+ RootCgroup() Cgroup
+
+ // NumCgroups returns the number of cgroups managed by this controller.
+ // Returned value is a snapshot in time.
+ NumCgroups() uint64
+
+ // Enabled returns whether this controller is enabled. Returned value is a
+ // snapshot in time.
+ Enabled() bool
+}
+
+// Cgroup represents a named pointer to a cgroup in cgroupfs. When a task enters
+// a cgroup, it holds a reference on the underlying dentry pointing to the
+// cgroup.
+//
+// +stateify savable
+type Cgroup struct {
+ *kernfs.Dentry
+ CgroupImpl
+}
+
+func (c *Cgroup) decRef() {
+ c.Dentry.DecRef(context.Background())
+}
+
+// Path returns the absolute path of c, relative to its hierarchy root.
+func (c *Cgroup) Path() string {
+ return c.FSLocalPath()
+}
+
+// HierarchyID returns the id of the hierarchy that contains this cgroup.
+func (c *Cgroup) HierarchyID() uint32 {
+ // Note: a cgroup is guaranteed to have at least one controller.
+ return c.Controllers()[0].HierarchyID()
+}
+
+// CgroupImpl is the common interface to cgroups.
+type CgroupImpl interface {
+ Controllers() []CgroupController
+ Enter(t *Task)
+ Leave(t *Task)
+}
+
+// hierarchy represents a cgroupfs filesystem instance, with a unique set of
+// controllers attached to it. Multiple cgroupfs mounts may reference the same
+// hierarchy.
+//
+// +stateify savable
+type hierarchy struct {
+ id uint32
+ // These are a subset of the controllers in CgroupRegistry.controllers,
+ // grouped here by hierarchy for conveninent lookup.
+ controllers map[CgroupControllerType]CgroupController
+ // fs is not owned by hierarchy. The FS is responsible for unregistering the
+ // hierarchy on destruction, which removes this association.
+ fs *vfs.Filesystem
+}
+
+func (h *hierarchy) match(ctypes []CgroupControllerType) bool {
+ if len(ctypes) != len(h.controllers) {
+ return false
+ }
+ for _, ty := range ctypes {
+ if _, ok := h.controllers[ty]; !ok {
+ return false
+ }
+ }
+ return true
+}
+
+// CgroupRegistry tracks the active set of cgroup controllers on the system.
+//
+// +stateify savable
+type CgroupRegistry struct {
+ // lastHierarchyID is the id of the last allocated cgroup hierarchy. Valid
+ // ids are from 1 to math.MaxUint32. Must be accessed through atomic ops.
+ //
+ lastHierarchyID uint32
+
+ mu sync.Mutex `state:"nosave"`
+
+ // controllers is the set of currently known cgroup controllers on the
+ // system. Protected by mu.
+ //
+ // +checklocks:mu
+ controllers map[CgroupControllerType]CgroupController
+
+ // hierarchies is the active set of cgroup hierarchies. Protected by mu.
+ //
+ // +checklocks:mu
+ hierarchies map[uint32]hierarchy
+}
+
+func newCgroupRegistry() *CgroupRegistry {
+ return &CgroupRegistry{
+ controllers: make(map[CgroupControllerType]CgroupController),
+ hierarchies: make(map[uint32]hierarchy),
+ }
+}
+
+// nextHierarchyID returns a newly allocated, unique hierarchy ID.
+func (r *CgroupRegistry) nextHierarchyID() (uint32, error) {
+ if hid := atomic.AddUint32(&r.lastHierarchyID, 1); hid != 0 {
+ return hid, nil
+ }
+ return InvalidCgroupHierarchyID, fmt.Errorf("cgroup hierarchy ID overflow")
+}
+
+// FindHierarchy returns a cgroup filesystem containing exactly the set of
+// controllers named in names. If no such FS is found, FindHierarchy return
+// nil. FindHierarchy takes a reference on the returned FS, which is transferred
+// to the caller.
+func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Filesystem {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ for _, h := range r.hierarchies {
+ if h.match(ctypes) {
+ h.fs.IncRef()
+ return h.fs
+ }
+ }
+
+ return nil
+}
+
+// Register registers the provided set of controllers with the registry as a new
+// hierarchy. If any controller is already registered, the function returns an
+// error without modifying the registry. The hierarchy can be later referenced
+// by the returned id.
+func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if len(cs) == 0 {
+ return InvalidCgroupHierarchyID, fmt.Errorf("can't register hierarchy with no controllers")
+ }
+
+ for _, c := range cs {
+ if _, ok := r.controllers[c.Type()]; ok {
+ return InvalidCgroupHierarchyID, fmt.Errorf("controllers may only be mounted on a single hierarchy")
+ }
+ }
+
+ hid, err := r.nextHierarchyID()
+ if err != nil {
+ return hid, err
+ }
+
+ h := hierarchy{
+ id: hid,
+ controllers: make(map[CgroupControllerType]CgroupController),
+ fs: cs[0].Filesystem(),
+ }
+ for _, c := range cs {
+ n := c.Type()
+ r.controllers[n] = c
+ h.controllers[n] = c
+ }
+ r.hierarchies[hid] = h
+ return hid, nil
+}
+
+// Unregister removes a previously registered hierarchy from the registry. If
+// the controller was not previously registered, Unregister is a no-op.
+func (r *CgroupRegistry) Unregister(hid uint32) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if h, ok := r.hierarchies[hid]; ok {
+ for name, _ := range h.controllers {
+ delete(r.controllers, name)
+ }
+ delete(r.hierarchies, hid)
+ }
+}
+
+// computeInitialGroups takes a reference on each of the returned cgroups. The
+// caller takes ownership of this returned reference.
+func (r *CgroupRegistry) computeInitialGroups(inherit map[Cgroup]struct{}) map[Cgroup]struct{} {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ ctlSet := make(map[CgroupControllerType]CgroupController)
+ cgset := make(map[Cgroup]struct{})
+
+ // Remember controllers from the inherited cgroups set...
+ for cg, _ := range inherit {
+ cg.IncRef() // Ref transferred to caller.
+ for _, ctl := range cg.Controllers() {
+ ctlSet[ctl.Type()] = ctl
+ cgset[cg] = struct{}{}
+ }
+ }
+
+ // ... and add the root cgroups of all the missing controllers.
+ for name, ctl := range r.controllers {
+ if _, ok := ctlSet[name]; !ok {
+ cg := ctl.RootCgroup()
+ cg.IncRef() // Ref transferred to caller.
+ cgset[cg] = struct{}{}
+ }
+ }
+ return cgset
+}
+
+// GenerateProcCgroups writes the contents of /proc/cgroups to buf.
+func (r *CgroupRegistry) GenerateProcCgroups(buf *bytes.Buffer) {
+ r.mu.Lock()
+ entries := make([]string, 0, len(r.controllers))
+ for _, c := range r.controllers {
+ en := 0
+ if c.Enabled() {
+ en = 1
+ }
+ entries = append(entries, fmt.Sprintf("%s\t%d\t%d\t%d\n", c.Type(), c.HierarchyID(), c.NumCgroups(), en))
+ }
+ r.mu.Unlock()
+
+ sort.Strings(entries)
+ fmt.Fprint(buf, "#subsys_name\thierarchy\tnum_cgroups\tenabled\n")
+ for _, e := range entries {
+ fmt.Fprint(buf, e)
+ }
+}
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 43065b45a..9a4fd64cb 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -294,6 +294,11 @@ type Kernel struct {
// YAMAPtraceScope is the current level of YAMA ptrace restrictions.
YAMAPtraceScope int32
+
+ // cgroupRegistry contains the set of active cgroup controllers on the
+ // system. It is controller by cgroupfs. Nil if cgroupfs is unavailable on
+ // the system.
+ cgroupRegistry *CgroupRegistry
}
// InitKernelArgs holds arguments to Init.
@@ -438,6 +443,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.socketMount = socketMount
k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord)
+
+ k.cgroupRegistry = newCgroupRegistry()
}
return nil
}
@@ -1815,6 +1822,11 @@ func (k *Kernel) SocketMount() *vfs.Mount {
return k.socketMount
}
+// CgroupRegistry returns the cgroup registry.
+func (k *Kernel) CgroupRegistry() *CgroupRegistry {
+ return k.cgroupRegistry
+}
+
// Release releases resources owned by k.
//
// Precondition: This should only be called after the kernel is fully
@@ -1831,3 +1843,43 @@ func (k *Kernel) Release() {
k.timekeeper.Destroy()
k.vdso.Release(ctx)
}
+
+// PopulateNewCgroupHierarchy moves all tasks into a newly created cgroup
+// hierarchy.
+//
+// Precondition: root must be a new cgroup with no tasks. This implies the
+// controllers for root are also new and currently manage no task, which in turn
+// implies the new cgroup can be populated without migrating tasks between
+// cgroups.
+func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) {
+ k.tasks.mu.RLock()
+ k.tasks.forEachTaskLocked(func(t *Task) {
+ if t.ExitState() != TaskExitNone {
+ return
+ }
+ t.mu.Lock()
+ t.enterCgroupLocked(root)
+ t.mu.Unlock()
+ })
+ k.tasks.mu.RUnlock()
+}
+
+// ReleaseCgroupHierarchy moves all tasks out of all cgroups belonging to the
+// hierarchy with the provided id. This is intended for use during hierarchy
+// teardown, as otherwise the tasks would be orphaned w.r.t to some controllers.
+func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) {
+ k.tasks.mu.RLock()
+ k.tasks.forEachTaskLocked(func(t *Task) {
+ if t.ExitState() != TaskExitNone {
+ return
+ }
+ t.mu.Lock()
+ for cg, _ := range t.cgroups {
+ if cg.HierarchyID() == hid {
+ t.leaveCgroupLocked(cg)
+ }
+ }
+ t.mu.Unlock()
+ })
+ k.tasks.mu.RUnlock()
+}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 399985039..be1371855 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -587,6 +587,12 @@ type Task struct {
//
// kcov is exclusive to the task goroutine.
kcov *Kcov
+
+ // cgroups is the set of cgroups this task belongs to. This may be empty if
+ // no cgroup controllers are enabled. Protected by mu.
+ //
+ // +checklocks:mu
+ cgroups map[Cgroup]struct{}
}
func (t *Task) savePtraceTracer() *Task {
diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go
new file mode 100644
index 000000000..25d2504fa
--- /dev/null
+++ b/pkg/sentry/kernel/task_cgroup.go
@@ -0,0 +1,138 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+ "bytes"
+ "fmt"
+ "sort"
+ "strings"
+
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// EnterInitialCgroups moves t into an initial set of cgroups.
+//
+// Precondition: t isn't in any cgroups yet, t.cgs is empty.
+//
+// +checklocksignore parent.mu is conditionally acquired.
+func (t *Task) EnterInitialCgroups(parent *Task) {
+ var inherit map[Cgroup]struct{}
+ if parent != nil {
+ parent.mu.Lock()
+ defer parent.mu.Unlock()
+ inherit = parent.cgroups
+ }
+ joinSet := t.k.cgroupRegistry.computeInitialGroups(inherit)
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // Transfer ownership of joinSet refs to the task's cgset.
+ t.cgroups = joinSet
+ for c, _ := range t.cgroups {
+ // Since t isn't in any cgroup yet, we can skip the check against
+ // existing cgroups.
+ c.Enter(t)
+ }
+}
+
+// EnterCgroup moves t into c.
+func (t *Task) EnterCgroup(c Cgroup) error {
+ newControllers := make(map[CgroupControllerType]struct{})
+ for _, ctl := range c.Controllers() {
+ newControllers[ctl.Type()] = struct{}{}
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ for oldCG, _ := range t.cgroups {
+ for _, oldCtl := range oldCG.Controllers() {
+ if _, ok := newControllers[oldCtl.Type()]; ok {
+ // Already in a cgroup with the same controller as one of the
+ // new ones. Requires migration between cgroups.
+ //
+ // TODO(b/183137098): Implement cgroup migration.
+ log.Warningf("Cgroup migration is not implemented")
+ return syserror.EBUSY
+ }
+ }
+ }
+
+ // No migration required.
+ t.enterCgroupLocked(c)
+
+ return nil
+}
+
+// +checklocks:t.mu
+func (t *Task) enterCgroupLocked(c Cgroup) {
+ c.IncRef()
+ t.cgroups[c] = struct{}{}
+ c.Enter(t)
+}
+
+// LeaveCgroups removes t out from all its cgroups.
+func (t *Task) LeaveCgroups() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ for c, _ := range t.cgroups {
+ t.leaveCgroupLocked(c)
+ }
+}
+
+// +checklocks:t.mu
+func (t *Task) leaveCgroupLocked(c Cgroup) {
+ c.Leave(t)
+ delete(t.cgroups, c)
+ c.decRef()
+}
+
+// taskCgroupEntry represents a line in /proc/<pid>/cgroup, and is used to
+// format a cgroup for display.
+type taskCgroupEntry struct {
+ hierarchyID uint32
+ controllers string
+ path string
+}
+
+// GenerateProcTaskCgroup writes the contents of /proc/<pid>/cgroup for t to buf.
+func (t *Task) GenerateProcTaskCgroup(buf *bytes.Buffer) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ cgEntries := make([]taskCgroupEntry, 0, len(t.cgroups))
+ for c, _ := range t.cgroups {
+ ctls := c.Controllers()
+ ctlNames := make([]string, 0, len(ctls))
+ for _, ctl := range ctls {
+ ctlNames = append(ctlNames, string(ctl.Type()))
+ }
+
+ cgEntries = append(cgEntries, taskCgroupEntry{
+ // Note: We're guaranteed to have at least one controller, and all
+ // controllers are guaranteed to be on the same hierarchy.
+ hierarchyID: ctls[0].HierarchyID(),
+ controllers: strings.Join(ctlNames, ","),
+ path: c.Path(),
+ })
+ }
+
+ sort.Slice(cgEntries, func(i, j int) bool { return cgEntries[i].hierarchyID > cgEntries[j].hierarchyID })
+ for _, cgE := range cgEntries {
+ fmt.Fprintf(buf, "%d:%s:%s\n", cgE.hierarchyID, cgE.controllers, cgE.path)
+ }
+}
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index ad59e4f60..b1af1a7ef 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -275,6 +275,10 @@ func (*runExitMain) execute(t *Task) taskRunState {
t.fsContext.DecRef(t)
t.fdTable.DecRef(t)
+ // Detach task from all cgroups. This must happen before potentially the
+ // last ref to the cgroupfs mount is dropped below.
+ t.LeaveCgroups()
+
t.mu.Lock()
if t.mountNamespaceVFS2 != nil {
t.mountNamespaceVFS2.DecRef(t)
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index fc18b6253..32031cd70 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -151,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
rseqSignature: cfg.RSeqSignature,
futexWaiter: futex.NewWaiter(),
containerID: cfg.ContainerID,
+ cgroups: make(map[Cgroup]struct{}),
}
t.creds.Store(cfg.Credentials)
t.endStopCond.L = &t.tg.signalHandlers.mu
@@ -189,6 +190,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
t.parent.children[t] = struct{}{}
}
+ if VFS2Enabled {
+ t.EnterInitialCgroups(t.parent)
+ }
+
if tg.leader == nil {
// New thread group.
tg.leader = t
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 09d070ec8..77ad62445 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -114,6 +114,15 @@ func (ts *TaskSet) forEachThreadGroupLocked(f func(tg *ThreadGroup)) {
}
}
+// forEachTaskLocked applies f to each Task in ts.
+//
+// Preconditions: ts.mu must be locked (for reading or writing).
+func (ts *TaskSet) forEachTaskLocked(f func(t *Task)) {
+ for t := range ts.Root.tids {
+ f(t)
+ }
+}
+
// A PIDNamespace represents a PID namespace, a bimap between thread IDs and
// tasks. See the pid_namespaces(7) man page for further details.
//
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 72868646a..610686ea0 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -375,6 +375,11 @@ type MMapOpts struct {
//
// If Force is true, Unmap and Fixed must be true.
Force bool
+
+ // SentryOwnedContent indicates the sentry exclusively controls the
+ // underlying memory backing the mapping thus the memory content is
+ // guaranteed not to be modified outside the sentry's purview.
+ SentryOwnedContent bool
}
// File represents a host file that may be mapped into an platform.AddressSpace.
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 47b294312..b307832fd 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -65,8 +65,8 @@ go_test(
name = "kvm_test",
srcs = [
"kvm_amd64_test.go",
- "kvm_arm64_test.go",
"kvm_amd64_test.s",
+ "kvm_arm64_test.go",
"kvm_test.go",
"virtual_map_test.go",
],
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
index 382953cf7..b8dd1e4a5 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -62,10 +62,17 @@ func TestMXCSR(t *testing.T) {
PageTables: pt,
FullRestore: true,
}
+
+ const mxcsrControllMask = uint32(0x1f80)
mxcsrBefore := uint32(0)
mxcsrAfter := uint32(0)
stmxcsr(&mxcsrBefore)
- switchOpts.FloatingPointState.SetMXCSR(mxcsrBefore ^ 0x8)
+ if mxcsrBefore == 0 {
+ // goruntime sets mxcsr to 0x1f80 and it never changes
+ // the control configuration.
+ panic("mxcsr is zero")
+ }
+ switchOpts.FloatingPointState.SetMXCSR(0)
if _, err := c.SwitchToUser(
switchOpts, &si); err == platform.ErrContextInterrupt {
return true // Retry.
@@ -73,7 +80,7 @@ func TestMXCSR(t *testing.T) {
t.Errorf("application syscall failed: %v", err)
}
stmxcsr(&mxcsrAfter)
- if mxcsrAfter != mxcsrBefore {
+ if mxcsrAfter&mxcsrControllMask != mxcsrBefore&mxcsrControllMask {
t.Errorf("mxcsr = %x (expected %x)", mxcsrBefore, mxcsrAfter)
}
return false
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 03e84d804..cd912f922 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -47,7 +47,7 @@ const (
// Beyond a relatively small number, there are likely few perform
// benefits, since the TLB has likely long since lost any translations
// from more than a few PCIDs past.
- poolPCIDs = 8
+ poolPCIDs = 128
)
func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 9bdf6d3d8..eff251cec 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -35,12 +35,6 @@ import (
// LINT.IfChange
-// minListenBacklog is the minimum reasonable backlog for listening sockets.
-const minListenBacklog = 8
-
-// maxListenBacklog is the maximum allowed backlog for listening sockets.
-const maxListenBacklog = 1024
-
// maxAddrLen is the maximum socket address length we're willing to accept.
const maxAddrLen = 200
@@ -52,6 +46,9 @@ const maxOptLen = 1024 * 8
// buffers upto INT_MAX.
const maxControlLen = 10 * 1024 * 1024
+// maxListenBacklog is the maximum limit of listen backlog supported.
+const maxListenBacklog = 1024
+
// nameLenOffset is the offset from the start of the MessageHeader64 struct to
// the NameLen field.
const nameLenOffset = 8
@@ -367,7 +364,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Listen implements the linux syscall listen(2).
func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
- backlog := args[1].Int()
+ backlog := args[1].Uint()
// Get socket from the file descriptor.
file := t.GetFile(fd)
@@ -382,11 +379,13 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ENOTSOCK
}
- // Per Linux, the backlog is silently capped to reasonable values.
- if backlog <= 0 {
- backlog = minListenBacklog
- }
if backlog > maxListenBacklog {
+ // Linux treats incoming backlog as uint with a limit defined by
+ // sysctl_somaxconn.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
+ //
+ // We use the backlog to allocate a channel of that size, hence enforce
+ // a hard limit for the backlog.
backlog = maxListenBacklog
}
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index a87a66146..936614eab 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -35,12 +35,6 @@ import (
"gvisor.dev/gvisor/pkg/hostarch"
)
-// minListenBacklog is the minimum reasonable backlog for listening sockets.
-const minListenBacklog = 8
-
-// maxListenBacklog is the maximum allowed backlog for listening sockets.
-const maxListenBacklog = 1024
-
// maxAddrLen is the maximum socket address length we're willing to accept.
const maxAddrLen = 200
@@ -52,6 +46,9 @@ const maxOptLen = 1024 * 8
// buffers upto INT_MAX.
const maxControlLen = 10 * 1024 * 1024
+// maxListenBacklog is the maximum limit of listen backlog supported.
+const maxListenBacklog = 1024
+
// nameLenOffset is the offset from the start of the MessageHeader64 struct to
// the NameLen field.
const nameLenOffset = 8
@@ -371,7 +368,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
// Listen implements the linux syscall listen(2).
func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
fd := args[0].Int()
- backlog := args[1].Int()
+ backlog := args[1].Uint()
// Get socket from the file descriptor.
file := t.GetFileVFS2(fd)
@@ -386,11 +383,13 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return 0, nil, syserror.ENOTSOCK
}
- // Per Linux, the backlog is silently capped to reasonable values.
- if backlog <= 0 {
- backlog = minListenBacklog
- }
if backlog > maxListenBacklog {
+ // Linux treats incoming backlog as uint with a limit defined by
+ // sysctl_somaxconn.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
+ //
+ // We use the backlog to allocate a channel of that size, hence enforce
+ // a hard limit for the backlog.
backlog = maxListenBacklog
}
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 1556b41a3..b87d9690a 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -252,6 +252,9 @@ type WritableDynamicBytesSource interface {
// are backed by a bytes.Buffer that is regenerated when necessary, consistent
// with Linux's fs/seq_file.c:single_open().
//
+// If data additionally implements WritableDynamicBytesSource, writes are
+// dispatched to the implementer. The source data is not automatically modified.
+//
// DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first
// use.
//
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 922f9e697..7cdab6945 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -970,17 +970,22 @@ func superBlockOpts(mountPath string, mnt *Mount) string {
opts += "," + mopts
}
- // NOTE(b/147673608): If the mount is a cgroup, we also need to include
- // the cgroup name in the options. For now we just read that from the
- // path.
+ // NOTE(b/147673608): If the mount is a ramdisk-based fake cgroupfs, we also
+ // need to include the cgroup name in the options. For now we just read that
+ // from the path. Note that this is only possible when "cgroup" isn't
+ // registered as a valid filesystem type.
//
- // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
- // should get this value from the cgroup itself, and not rely on the
- // path.
+ // TODO(gvisor.dev/issue/190): Once we removed fake cgroupfs support, we
+ // should remove this.
+ if cgroupfs := mnt.vfs.getFilesystemType("cgroup"); cgroupfs != nil && cgroupfs.opts.AllowUserMount {
+ // Real cgroupfs available.
+ return opts
+ }
if mnt.fs.FilesystemType().Name() == "cgroup" {
splitPath := strings.Split(mountPath, "/")
cgroupType := splitPath[len(splitPath)-1]
opts += "," + cgroupType
}
+
return opts
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index f588311e0..85bd164cd 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -178,6 +178,26 @@ const (
IPv4FlagDontFragment
)
+// ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined
+// by RFC 3927 section 1.
+var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00"))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
+// ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as
+// defined by RFC 5771 section 4.
+var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet {
+ subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00"))
+ if err != nil {
+ panic(err)
+ }
+ return subnet
+}()
+
// IPv4EmptySubnet is the empty IPv4 subnet.
var IPv4EmptySubnet = func() tcpip.Subnet {
subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any))
@@ -423,6 +443,18 @@ func (b IPv4) IsValid(pktSize int) bool {
return true
}
+// IsV4LinkLocalUnicastAddress determines if the provided address is an IPv4
+// link-local unicast address.
+func IsV4LinkLocalUnicastAddress(addr tcpip.Address) bool {
+ return ipv4LinkLocalUnicastSubnet.Contains(addr)
+}
+
+// IsV4LinkLocalMulticastAddress determines if the provided address is an IPv4
+// link-local multicast address.
+func IsV4LinkLocalMulticastAddress(addr tcpip.Address) bool {
+ return ipv4LinkLocalMulticastSubnet.Contains(addr)
+}
+
// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
// will be 1110 = 0xe0.
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go
index 6475cd694..c02fe898b 100644
--- a/pkg/tcpip/header/ipv4_test.go
+++ b/pkg/tcpip/header/ipv4_test.go
@@ -18,6 +18,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
)
@@ -177,3 +178,77 @@ func TestIPv4EncodeOptions(t *testing.T) {
})
}
}
+
+func TestIsV4LinkLocalUnicastAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid (lowest)",
+ addr: "\xa9\xfe\x00\x00",
+ expected: true,
+ },
+ {
+ name: "Valid (highest)",
+ addr: "\xa9\xfe\xff\xff",
+ expected: true,
+ },
+ {
+ name: "Invalid (before subnet)",
+ addr: "\xa9\xfd\xff\xff",
+ expected: false,
+ },
+ {
+ name: "Invalid (after subnet)",
+ addr: "\xa9\xff\x00\x00",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV4LinkLocalUnicastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV4LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
+
+func TestIsV4LinkLocalMulticastAddress(t *testing.T) {
+ tests := []struct {
+ name string
+ addr tcpip.Address
+ expected bool
+ }{
+ {
+ name: "Valid (lowest)",
+ addr: "\xe0\x00\x00\x00",
+ expected: true,
+ },
+ {
+ name: "Valid (highest)",
+ addr: "\xe0\x00\x00\xff",
+ expected: true,
+ },
+ {
+ name: "Invalid (before subnet)",
+ addr: "\xdf\xff\xff\xff",
+ expected: false,
+ },
+ {
+ name: "Invalid (after subnet)",
+ addr: "\xe0\x00\x01\x00",
+ expected: false,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ if got := header.IsV4LinkLocalMulticastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV4LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
index f2403978c..c3a0407ac 100644
--- a/pkg/tcpip/header/ipv6.go
+++ b/pkg/tcpip/header/ipv6.go
@@ -98,12 +98,27 @@ const (
// The address is ff02::1.
IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
- // IPv6AllRoutersMulticastAddress is a link-local multicast group that
- // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local
+ // multicast group that all IPv6 routers MUST join, as per RFC 4291, section
+ // 2.8. Packets destined to this address will reach the router on an
+ // interface.
+ //
+ // The address is ff01::2.
+ IPv6AllRoutersInterfaceLocalMulticastAddress tcpip.Address = "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ // IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group
+ // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
// destined to this address will reach all routers on a link.
//
// The address is ff02::2.
- IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ IPv6AllRoutersLinkLocalMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ // IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group
+ // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
+ // destined to this address will reach all routers in a site.
+ //
+ // The address is ff05::2.
+ IPv6AllRoutersSiteLocalMulticastAddress tcpip.Address = "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200,
// section 5:
@@ -142,11 +157,6 @@ const (
// ipv6MulticastAddressScopeMask is the mask for the scope (scop) field,
// within the byte holding the field, as per RFC 4291 section 2.7.
ipv6MulticastAddressScopeMask = 0xF
-
- // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within
- // a multicast IPv6 address that indicates the address has link-local scope,
- // as per RFC 4291 section 2.7.
- ipv6LinkLocalMulticastScope = 2
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
@@ -381,25 +391,25 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
return tcpip.Address(lladdrb[:])
}
-// IsV6LinkLocalAddress determines if the provided address is an IPv6
-// link-local address (fe80::/10).
-func IsV6LinkLocalAddress(addr tcpip.Address) bool {
+// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6
+// link-local unicast address, as defined by RFC 4291 section 2.5.6.
+func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool {
if len(addr) != IPv6AddressSize {
return false
}
return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80
}
-// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback
-// address.
+// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback
+// address, as defined by RFC 4291 section 2.5.3.
func IsV6LoopbackAddress(addr tcpip.Address) bool {
return addr == IPv6Loopback
}
-// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6
-// link-local multicast address.
+// IsV6LinkLocalMulticastAddress returns true iff the provided address is an
+// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7.
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
- return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope
+ return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope
}
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
@@ -462,7 +472,7 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) {
case IsV6LinkLocalMulticastAddress(addr):
return LinkLocalScope, nil
- case IsV6LinkLocalAddress(addr):
+ case IsV6LinkLocalUnicastAddress(addr):
return LinkLocalScope, nil
default:
@@ -520,3 +530,46 @@ func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address)
PrefixLen: IIDOffsetInIPv6Address * 8,
}
}
+
+// IPv6MulticastScope is the scope of a multicast IPv6 address, as defined by
+// RFC 7346 section 2.
+type IPv6MulticastScope uint8
+
+// The various values for IPv6 multicast scopes, as per RFC 7346 section 2:
+//
+// +------+--------------------------+-------------------------+
+// | scop | NAME | REFERENCE |
+// +------+--------------------------+-------------------------+
+// | 0 | Reserved | [RFC4291], RFC 7346 |
+// | 1 | Interface-Local scope | [RFC4291], RFC 7346 |
+// | 2 | Link-Local scope | [RFC4291], RFC 7346 |
+// | 3 | Realm-Local scope | [RFC4291], RFC 7346 |
+// | 4 | Admin-Local scope | [RFC4291], RFC 7346 |
+// | 5 | Site-Local scope | [RFC4291], RFC 7346 |
+// | 6 | Unassigned | |
+// | 7 | Unassigned | |
+// | 8 | Organization-Local scope | [RFC4291], RFC 7346 |
+// | 9 | Unassigned | |
+// | A | Unassigned | |
+// | B | Unassigned | |
+// | C | Unassigned | |
+// | D | Unassigned | |
+// | E | Global scope | [RFC4291], RFC 7346 |
+// | F | Reserved | [RFC4291], RFC 7346 |
+// +------+--------------------------+-------------------------+
+const (
+ IPv6Reserved0MulticastScope = IPv6MulticastScope(0x0)
+ IPv6InterfaceLocalMulticastScope = IPv6MulticastScope(0x1)
+ IPv6LinkLocalMulticastScope = IPv6MulticastScope(0x2)
+ IPv6RealmLocalMulticastScope = IPv6MulticastScope(0x3)
+ IPv6AdminLocalMulticastScope = IPv6MulticastScope(0x4)
+ IPv6SiteLocalMulticastScope = IPv6MulticastScope(0x5)
+ IPv6OrganizationLocalMulticastScope = IPv6MulticastScope(0x8)
+ IPv6GlobalMulticastScope = IPv6MulticastScope(0xE)
+ IPv6ReservedFMulticastScope = IPv6MulticastScope(0xF)
+)
+
+// V6MulticastScope returns the scope of a multicast address.
+func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope {
+ return IPv6MulticastScope(addr[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask)
+}
diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go
index f10f446a6..ccee9000e 100644
--- a/pkg/tcpip/header/ipv6_test.go
+++ b/pkg/tcpip/header/ipv6_test.go
@@ -252,7 +252,7 @@ func TestIsV6LinkLocalMulticastAddress(t *testing.T) {
}
}
-func TestIsV6LinkLocalAddress(t *testing.T) {
+func TestIsV6LinkLocalUnicastAddress(t *testing.T) {
tests := []struct {
name string
addr tcpip.Address
@@ -287,8 +287,8 @@ func TestIsV6LinkLocalAddress(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected {
- t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected)
+ if got := header.IsV6LinkLocalUnicastAddress(test.addr); got != test.expected {
+ t.Errorf("got header.IsV6LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected)
}
})
}
@@ -373,3 +373,83 @@ func TestSolicitedNodeAddr(t *testing.T) {
})
}
}
+
+func TestV6MulticastScope(t *testing.T) {
+ tests := []struct {
+ addr tcpip.Address
+ want header.IPv6MulticastScope
+ }{
+ {
+ addr: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6Reserved0MulticastScope,
+ },
+ {
+ addr: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6InterfaceLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6LinkLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6RealmLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6AdminLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6SiteLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(6),
+ },
+ {
+ addr: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(7),
+ },
+ {
+ addr: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6OrganizationLocalMulticastScope,
+ },
+ {
+ addr: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(9),
+ },
+ {
+ addr: "\xff\x0a\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(10),
+ },
+ {
+ addr: "\xff\x0b\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(11),
+ },
+ {
+ addr: "\xff\x0c\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(12),
+ },
+ {
+ addr: "\xff\x0d\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6MulticastScope(13),
+ },
+ {
+ addr: "\xff\x0e\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6GlobalMulticastScope,
+ },
+ {
+ addr: "\xff\x0f\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ want: header.IPv6ReservedFMulticastScope,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) {
+ if got := header.V6MulticastScope(test.addr); got != test.want {
+ t.Fatalf("got header.V6MulticastScope(%s) = %d, want = %d", test.addr, got, test.want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
index b9f129728..ac35d81e7 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go
@@ -156,14 +156,6 @@ type GenericMulticastProtocolOptions struct {
//
// Unsolicited reports are transmitted when a group is newly joined.
MaxUnsolicitedReportDelay time.Duration
-
- // AllNodesAddress is a multicast address that all nodes on a network should
- // be a member of.
- //
- // This address will not have the generic multicast protocol performed on it;
- // it will be left in the non member/listener state, and packets will never
- // be sent for it.
- AllNodesAddress tcpip.Address
}
// MulticastGroupProtocol is a multicast group protocol whose core state machine
@@ -188,6 +180,10 @@ type MulticastGroupProtocol interface {
// SendLeave sends a multicast leave for the specified group address.
SendLeave(groupAddress tcpip.Address) tcpip.Error
+
+ // ShouldPerformProtocol returns true iff the protocol should be performed for
+ // the specified group.
+ ShouldPerformProtocol(tcpip.Address) bool
}
// GenericMulticastProtocolState is the per interface generic multicast protocol
@@ -455,20 +451,7 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t
info.lastToSendReport = false
- if groupAddress == g.opts.AllNodesAddress {
- // As per RFC 2236 section 6 page 10 (for IGMPv2),
- //
- // The all-systems group (address 224.0.0.1) is handled as a special
- // case. The host starts in Idle Member state for that group on every
- // interface, never transitions to another state, and never sends a
- // report for that group.
- //
- // As per RFC 2710 section 5 page 10 (for MLDv1),
- //
- // The link-scope all-nodes address (FF02::1) is handled as a special
- // case. The node starts in Idle Listener state for that address on
- // every interface, never transitions to another state, and never sends
- // a Report or Done for that address.
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) {
info.state = idleMember
return
}
@@ -537,20 +520,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres
return
}
- if groupAddress == g.opts.AllNodesAddress {
- // As per RFC 2236 section 6 page 10 (for IGMPv2),
- //
- // The all-systems group (address 224.0.0.1) is handled as a special
- // case. The host starts in Idle Member state for that group on every
- // interface, never transitions to another state, and never sends a
- // report for that group.
- //
- // As per RFC 2710 section 5 page 10 (for MLDv1),
- //
- // The link-scope all-nodes address (FF02::1) is handled as a special
- // case. The node starts in Idle Listener state for that address on
- // every interface, never transitions to another state, and never sends
- // a Report or Done for that address.
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) {
return
}
@@ -627,20 +597,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr
return
}
- if groupAddress == g.opts.AllNodesAddress {
- // As per RFC 2236 section 6 page 10 (for IGMPv2),
- //
- // The all-systems group (address 224.0.0.1) is handled as a special
- // case. The host starts in Idle Member state for that group on every
- // interface, never transitions to another state, and never sends a
- // report for that group.
- //
- // As per RFC 2710 section 5 page 10 (for MLDv1),
- //
- // The link-scope all-nodes address (FF02::1) is handled as a special
- // case. The node starts in Idle Listener state for that address on
- // every interface, never transitions to another state, and never sends
- // a Report or Done for that address.
+ if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) {
return
}
diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
index 381460c82..0b51563cd 100644
--- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
+++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go
@@ -43,6 +43,8 @@ type mockMulticastGroupProtocolProtectedFields struct {
type mockMulticastGroupProtocol struct {
t *testing.T
+ skipProtocolAddress tcpip.Address
+
mu mockMulticastGroupProtocolProtectedFields
}
@@ -165,6 +167,11 @@ func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip
return nil
}
+// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
+func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
+ return groupAddress != m.skipProtocolAddress
+}
+
func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string {
m.mu.Lock()
defer m.mu.Unlock()
@@ -193,10 +200,11 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr
cmp.FilterPath(
func(p cmp.Path) bool {
switch p.Last().String() {
- case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup":
+ case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress":
return true
+ default:
+ return false
}
- return false
},
cmp.Ignore(),
),
@@ -225,14 +233,13 @@ func TestJoinGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(0)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr2,
})
// Joining a group should send a report immediately and another after
@@ -279,14 +286,13 @@ func TestLeaveGroup(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(1)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr2,
})
mgp.joinGroup(test.addr)
@@ -356,14 +362,13 @@ func TestHandleReport(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(2)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr3,
})
mgp.joinGroup(addr1)
@@ -446,14 +451,13 @@ func TestHandleQuery(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr3,
})
mgp.joinGroup(addr1)
@@ -574,14 +578,13 @@ func TestJoinCount(t *testing.T) {
}
func TestMakeAllNonMemberAndInitialize(t *testing.T) {
- mgp := mockMulticastGroupProtocol{t: t}
+ mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3}
clock := faketime.NewManualClock()
mgp.init(ip.GenericMulticastProtocolOptions{
Rand: rand.New(rand.NewSource(3)),
Clock: clock,
MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay,
- AllNodesAddress: addr3,
})
mgp.joinGroup(addr1)
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
index a4edc69c7..58fd18af8 100644
--- a/pkg/tcpip/network/ip_test.go
+++ b/pkg/tcpip/network/ip_test.go
@@ -15,6 +15,7 @@
package ip_test
import (
+ "fmt"
"strings"
"testing"
@@ -1938,3 +1939,80 @@ func TestICMPInclusionSize(t *testing.T) {
})
}
}
+
+func TestJoinLeaveAllRoutersGroup(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ protoFactory stack.NetworkProtocolFactory
+ allRoutersAddr tcpip.Address
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ protoFactory: ipv4.NewProtocol,
+ allRoutersAddr: header.IPv4AllRoutersGroup,
+ },
+ {
+ name: "IPv6 Interface Local",
+ netProto: ipv6.ProtocolNumber,
+ protoFactory: ipv6.NewProtocol,
+ allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress,
+ },
+ {
+ name: "IPv6 Link Local",
+ netProto: ipv6.ProtocolNumber,
+ protoFactory: ipv6.NewProtocol,
+ allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress,
+ },
+ {
+ name: "IPv6 Site Local",
+ netProto: ipv6.ProtocolNumber,
+ protoFactory: ipv6.NewProtocol,
+ allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, nicDisabled := range [...]bool{true, false} {
+ t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ })
+ opts := stack.NICOptions{Disabled: nicDisabled}
+ if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil {
+ t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err)
+ }
+
+ if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
+ t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
+ } else if got {
+ t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
+ }
+
+ if err := s.SetForwarding(test.netProto, true); err != nil {
+ t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err)
+ }
+ if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
+ t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
+ } else if !got {
+ t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr)
+ }
+
+ if err := s.SetForwarding(test.netProto, false); err != nil {
+ t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err)
+ }
+ if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil {
+ t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err)
+ } else if got {
+ t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr)
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go
index f3fc1c87e..b1ac29294 100644
--- a/pkg/tcpip/network/ipv4/igmp.go
+++ b/pkg/tcpip/network/ipv4/igmp.go
@@ -126,6 +126,17 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
return err
}
+// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
+func (igmp *igmpState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
+ // As per RFC 2236 section 6 page 10,
+ //
+ // The all-systems group (address 224.0.0.1) is handled as a special
+ // case. The host starts in Idle Member state for that group on every
+ // interface, never transitions to another state, and never sends a
+ // report for that group.
+ return groupAddress != header.IPv4AllSystems
+}
+
// init sets up an igmpState struct, and is required to be called before using
// a new igmpState.
//
@@ -137,7 +148,6 @@ func (igmp *igmpState) init(ep *endpoint) {
Clock: ep.protocol.stack.Clock(),
Protocol: igmp,
MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
- AllNodesAddress: header.IPv4AllSystems,
})
igmp.igmpV1Present = igmpV1PresentDefault
igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() {
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 1a5661ca4..2e44f8523 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -150,6 +150,38 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
delete(p.mu.eps, nicID)
}
+// transitionForwarding transitions the endpoint's forwarding status to
+// forwarding.
+//
+// Must only be called when the forwarding status changes.
+func (e *endpoint) transitionForwarding(forwarding bool) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ if forwarding {
+ // There does not seem to be an RFC requirement for a node to join the all
+ // routers multicast address but
+ // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml
+ // specifies the address as a group for all routers on a subnet so we join
+ // the group here.
+ if err := e.joinGroupLocked(header.IPv4AllRoutersGroup); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a
+ // valid IPv4 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err))
+ }
+
+ return
+ }
+
+ switch err := e.leaveGroupLocked(header.IPv4AllRoutersGroup).(type) {
+ case nil:
+ case *tcpip.ErrBadLocalAddress:
+ // The endpoint may have already left the multicast group.
+ default:
+ panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err))
+ }
+}
+
// Enable implements stack.NetworkEndpoint.
func (e *endpoint) Enable() tcpip.Error {
e.mu.Lock()
@@ -226,7 +258,7 @@ func (e *endpoint) disableLocked() {
}
// The endpoint may have already left the multicast group.
- switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) {
+ switch err := e.leaveGroupLocked(header.IPv4AllSystems).(type) {
case nil, *tcpip.ErrBadLocalAddress:
default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err))
@@ -551,6 +583,22 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// forwardPacket attempts to forward a packet to its final destination.
func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
h := header.IPv4(pkt.NetworkHeader().View())
+
+ dstAddr := h.DestinationAddress()
+ if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) {
+ // As per RFC 3927 section 7,
+ //
+ // A router MUST NOT forward a packet with an IPv4 Link-Local source or
+ // destination address, irrespective of the router's default route
+ // configuration or routes obtained from dynamic routing protocols.
+ //
+ // A router which receives a packet with an IPv4 Link-Local source or
+ // destination address MUST NOT forward the packet. This prevents
+ // forwarding of packets back onto the network segment from which they
+ // originated, or to any other segment.
+ return nil
+ }
+
ttl := h.TTL()
if ttl == 0 {
// As per RFC 792 page 6, Time Exceeded Message,
@@ -589,8 +637,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
}
}
- dstAddr := h.DestinationAddress()
-
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
ep.handleValidatedPacket(h, pkt)
@@ -1168,12 +1214,27 @@ func (p *protocol) Forwarding() bool {
return uint8(atomic.LoadUint32(&p.forwarding)) == 1
}
+// setForwarding sets the forwarding status for the protocol.
+//
+// Returns true if the forwarding status was updated.
+func (p *protocol) setForwarding(v bool) bool {
+ if v {
+ return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
+ }
+ return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
+}
+
// SetForwarding implements stack.ForwardingNetworkProtocol.
func (p *protocol) SetForwarding(v bool) {
- if v {
- atomic.StoreUint32(&p.forwarding, 1)
- } else {
- atomic.StoreUint32(&p.forwarding, 0)
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ if !p.setForwarding(v) {
+ return
+ }
+
+ for _, ep := range p.mu.eps {
+ ep.transitionForwarding(v)
}
}
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
index a142b76c1..b2a80e1e9 100644
--- a/pkg/tcpip/network/ipv6/icmp.go
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -273,7 +273,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP
if iph.HopLimit() != header.MLDHopLimit {
return false
}
- if !header.IsV6LinkLocalAddress(iph.SourceAddress()) {
+ if !header.IsV6LinkLocalUnicastAddress(iph.SourceAddress()) {
return false
}
return true
@@ -804,7 +804,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r
routerAddr := srcAddr
// Is the IP Source Address a link-local address?
- if !header.IsV6LinkLocalAddress(routerAddr) {
+ if !header.IsV6LinkLocalUnicastAddress(routerAddr) {
// ...No, silently drop the packet.
received.invalid.Increment()
return
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index c6d9d8f0d..d36cefcd0 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -314,7 +314,7 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) {
// Snooping switches MUST manage multicast forwarding state based on MLD
// Report and Done messages sent with the unspecified address as the
// IPv6 source address.
- if header.IsV6LinkLocalAddress(addr) {
+ if header.IsV6LinkLocalUnicastAddress(addr) {
e.mu.mld.sendQueuedReports()
}
}
@@ -410,22 +410,65 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t
//
// Must only be called when the forwarding status changes.
func (e *endpoint) transitionForwarding(forwarding bool) {
+ allRoutersGroups := [...]tcpip.Address{
+ header.IPv6AllRoutersInterfaceLocalMulticastAddress,
+ header.IPv6AllRoutersLinkLocalMulticastAddress,
+ header.IPv6AllRoutersSiteLocalMulticastAddress,
+ }
+
e.mu.Lock()
defer e.mu.Unlock()
- if !e.Enabled() {
- return
- }
-
if forwarding {
// When transitioning into an IPv6 router, host-only state (NDP discovered
// routers, discovered on-link prefixes, and auto-generated addresses) is
// cleaned up/invalidated and NDP router solicitations are stopped.
e.mu.ndp.stopSolicitingRouters()
e.mu.ndp.cleanupState(true /* hostOnly */)
- } else {
- // When transitioning into an IPv6 host, NDP router solicitations are
- // started.
+
+ // As per RFC 4291 section 2.8:
+ //
+ // A router is required to recognize all addresses that a host is
+ // required to recognize, plus the following addresses as identifying
+ // itself:
+ //
+ // o The All-Routers multicast addresses defined in Section 2.7.1.
+ //
+ // As per RFC 4291 section 2.7.1,
+ //
+ // All Routers Addresses: FF01:0:0:0:0:0:0:2
+ // FF02:0:0:0:0:0:0:2
+ // FF05:0:0:0:0:0:0:2
+ //
+ // The above multicast addresses identify the group of all IPv6 routers,
+ // within scope 1 (interface-local), 2 (link-local), or 5 (site-local).
+ for _, g := range allRoutersGroups {
+ if err := e.joinGroupLocked(g); err != nil {
+ // joinGroupLocked only returns an error if the group address is not a
+ // valid IPv6 multicast address.
+ panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err))
+ }
+ }
+
+ return
+ }
+
+ for _, g := range allRoutersGroups {
+ switch err := e.leaveGroupLocked(g).(type) {
+ case nil:
+ case *tcpip.ErrBadLocalAddress:
+ // The endpoint may have already left the multicast group.
+ default:
+ panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err))
+ }
+ }
+
+ // When transitioning into an IPv6 host, NDP router solicitations are
+ // started if the endpoint is enabled.
+ //
+ // If the endpoint is not currently enabled, routers will be solicited when
+ // the endpoint becomes enabled (if it is still a host).
+ if e.Enabled() {
e.mu.ndp.startSolicitingRouters()
}
}
@@ -573,7 +616,7 @@ func (e *endpoint) disableLocked() {
e.mu.ndp.cleanupState(false /* hostOnly */)
// The endpoint may have already left the multicast group.
- switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) {
+ switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) {
case nil, *tcpip.ErrBadLocalAddress:
default:
panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err))
@@ -869,6 +912,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu
// forwardPacket attempts to forward a packet to its final destination.
func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
h := header.IPv6(pkt.NetworkHeader().View())
+
+ dstAddr := h.DestinationAddress()
+ if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) {
+ // As per RFC 4291 section 2.5.6,
+ //
+ // Routers must not forward any packets with Link-Local source or
+ // destination addresses to other links.
+ return nil
+ }
+
hopLimit := h.HopLimit()
if hopLimit <= 1 {
// As per RFC 4443 section 3.3,
@@ -881,8 +934,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt)
}
- dstAddr := h.DestinationAddress()
-
// Check if the destination is owned by the stack.
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
ep.handleValidatedPacket(h, pkt)
@@ -1571,7 +1622,7 @@ func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address {
var linkLocalAddr tcpip.Address
e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool {
if addressEndpoint.IsAssigned(false /* allowExpired */) {
- if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) {
+ if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalUnicastAddress(addr) {
linkLocalAddr = addr
return false
}
@@ -1979,9 +2030,9 @@ func (p *protocol) Forwarding() bool {
// Returns true if the forwarding status was updated.
func (p *protocol) setForwarding(v bool) bool {
if v {
- return atomic.SwapUint32(&p.forwarding, 1) == 0
+ return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */)
}
- return atomic.SwapUint32(&p.forwarding, 0) == 1
+ return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */)
}
// SetForwarding implements stack.ForwardingNetworkProtocol.
diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go
index dd153466d..165b7d2d2 100644
--- a/pkg/tcpip/network/ipv6/mld.go
+++ b/pkg/tcpip/network/ipv6/mld.go
@@ -76,10 +76,29 @@ func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error)
//
// Precondition: mld.ep.mu must be read locked.
func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error {
- _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
+ _, err := mld.writePacket(header.IPv6AllRoutersLinkLocalMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone)
return err
}
+// ShouldPerformProtocol implements ip.MulticastGroupProtocol.
+func (mld *mldState) ShouldPerformProtocol(groupAddress tcpip.Address) bool {
+ // As per RFC 2710 section 5 page 10,
+ //
+ // The link-scope all-nodes address (FF02::1) is handled as a special
+ // case. The node starts in Idle Listener state for that address on
+ // every interface, never transitions to another state, and never sends
+ // a Report or Done for that address.
+ //
+ // MLD messages are never sent for multicast addresses whose scope is 0
+ // (reserved) or 1 (node-local).
+ if groupAddress == header.IPv6AllNodesMulticastAddress {
+ return false
+ }
+
+ scope := header.V6MulticastScope(groupAddress)
+ return scope != header.IPv6Reserved0MulticastScope && scope != header.IPv6InterfaceLocalMulticastScope
+}
+
// init sets up an mldState struct, and is required to be called before using
// a new mldState.
//
@@ -91,7 +110,6 @@ func (mld *mldState) init(ep *endpoint) {
Clock: ep.protocol.stack.Clock(),
Protocol: mld,
MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
- AllNodesAddress: header.IPv6AllNodesMulticastAddress,
})
}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go
index 85a8f9944..146b300f1 100644
--- a/pkg/tcpip/network/ipv6/mld_test.go
+++ b/pkg/tcpip/network/ipv6/mld_test.go
@@ -93,7 +93,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
if p, ok := e.Read(); !ok {
t.Fatal("expected a done message to be sent")
} else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
}
}
@@ -464,3 +464,141 @@ func TestMLDPacketValidation(t *testing.T) {
})
}
}
+
+func TestMLDSkipProtocol(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ group tcpip.Address
+ expectReport bool
+ }{
+ {
+ name: "Reserverd0",
+ group: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: false,
+ },
+ {
+ name: "Interface Local",
+ group: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: false,
+ },
+ {
+ name: "Link Local",
+ group: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Realm Local",
+ group: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Admin Local",
+ group: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Site Local",
+ group: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(6)",
+ group: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(7)",
+ group: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Organization Local",
+ group: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(9)",
+ group: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(A)",
+ group: "\xff\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(B)",
+ group: "\xff\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(C)",
+ group: "\xff\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(D)",
+ group: "\xff\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Global",
+ group: "\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "ReservedF",
+ group: "\xff\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ })},
+ })
+ e := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID, e); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
+ }
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err)
+ }
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC)
+ }
+
+ if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, test.group); err != nil {
+ t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, test.group, err)
+ }
+ if isInGroup, err := s.IsInGroup(nicID, test.group); err != nil {
+ t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.group, err)
+ } else if !isInGroup {
+ t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.group)
+ }
+
+ if !test.expectReport {
+ if p, ok := e.Read(); ok {
+ t.Fatalf("got e.Read() = (%#v, true), want = (_, false)", p)
+ }
+
+ return
+ }
+
+ if p, ok := e.Read(); !ok {
+ t.Fatal("expected a report message to be sent")
+ } else {
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, test.group, header.ICMPv6MulticastListenerReport, test.group)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go
index 536493f87..a110faa54 100644
--- a/pkg/tcpip/network/ipv6/ndp.go
+++ b/pkg/tcpip/network/ipv6/ndp.go
@@ -737,7 +737,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) {
prefix := opt.Subnet()
// Is the prefix a link-local?
- if header.IsV6LinkLocalAddress(prefix.ID()) {
+ if header.IsV6LinkLocalUnicastAddress(prefix.ID()) {
// ...Yes, skip as per RFC 4861 section 6.3.4,
// and RFC 4862 section 5.5.3.b (for SLAAC).
continue
@@ -1703,7 +1703,7 @@ func (ndp *ndpState) startSolicitingRouters() {
// the unspecified address if no address is assigned
// to the sending interface.
localAddr := header.IPv6Any
- if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil {
+ if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersLinkLocalMulticastAddress, false); addressEndpoint != nil {
localAddr = addressEndpoint.AddressWithPrefix().Address
addressEndpoint.DecRef()
}
@@ -1730,7 +1730,7 @@ func (ndp *ndpState) startSolicitingRouters() {
icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpData,
Src: localAddr,
- Dst: header.IPv6AllRoutersMulticastAddress,
+ Dst: header.IPv6AllRoutersLinkLocalMulticastAddress,
}))
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
@@ -1739,14 +1739,14 @@ func (ndp *ndpState) startSolicitingRouters() {
})
sent := ndp.ep.stats.icmp.packetsSent
- if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{
+ if err := addIPHeader(localAddr, header.IPv6AllRoutersLinkLocalMulticastAddress, pkt, stack.NetworkHeaderParams{
Protocol: header.ICMPv6ProtocolNumber,
TTL: header.NDPHopLimit,
}, nil /* extensionHeaders */); err != nil {
panic(fmt.Sprintf("failed to add IP header: %s", err))
}
- if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
+ if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil {
sent.dropped.Increment()
// Don't send any more messages if we had an error.
remaining = 0
diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go
index ecd5003a7..2aa4e6d75 100644
--- a/pkg/tcpip/network/multicast_group_test.go
+++ b/pkg/tcpip/network/multicast_group_test.go
@@ -194,7 +194,7 @@ func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, c
if p, ok := e.Read(); !ok {
t.Fatal("expected a report message to be sent")
} else {
- validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC)
+ validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC)
}
// Should not send any more packets.
@@ -606,7 +606,7 @@ func TestMGPLeaveGroup(t *testing.T) {
validateLeave: func(t *testing.T, p channel.PacketInfo) {
t.Helper()
- validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
+ validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1)
},
checkInitialGroups: checkInitialIPv6Groups,
},
@@ -1014,7 +1014,7 @@ func TestMGPWithNICLifecycle(t *testing.T) {
validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) {
t.Helper()
- validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr)
+ validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr)
},
getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address {
t.Helper()
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 14124ae66..a869cce38 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -5204,13 +5204,13 @@ func TestRouterSolicitation(t *testing.T) {
}
// Make sure the right remote link address is used.
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want {
+ if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want {
t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want)
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(test.expectedSrcAddr),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
checker.TTL(header.NDPHopLimit),
checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)),
)
@@ -5362,7 +5362,7 @@ func TestStopStartSolicitingRouters(t *testing.T) {
}
checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()),
checker.SrcAddr(header.IPv6Any),
- checker.DstAddr(header.IPv6AllRoutersMulticastAddress),
+ checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress),
checker.TTL(header.NDPHopLimit),
checker.NDPRS())
}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 39344808d..4ae6bed5a 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -132,7 +132,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp
localAddr = addressEndpoint.AddressWithPrefix().Address
}
- if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) {
+ if localAddressNIC != outgoingNIC && header.IsV6LinkLocalUnicastAddress(localAddr) {
addressEndpoint.DecRef()
return nil
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 931a97ddc..f23112410 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1344,7 +1344,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
s.mu.RLock()
defer s.mu.RUnlock()
- isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
+ isLinkLocal := header.IsV6LinkLocalUnicastAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr)
isLocalBroadcast := remoteAddr == header.IPv4Broadcast
isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)
isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr)
@@ -1381,7 +1381,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n
return nil, &tcpip.ErrNetworkUnreachable{}
}
- canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal
+ canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal
// Find a route to the remote with the route table.
var chosenRoute tcpip.Route
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 87ea09a5e..60de16579 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -786,6 +786,13 @@ func (*TCPRecovery) isGettableTransportProtocolOption() {}
func (*TCPRecovery) isSettableTransportProtocolOption() {}
+// TCPAlwaysUseSynCookies indicates unconditional usage of syncookies.
+type TCPAlwaysUseSynCookies bool
+
+func (*TCPAlwaysUseSynCookies) isGettableTransportProtocolOption() {}
+
+func (*TCPAlwaysUseSynCookies) isSettableTransportProtocolOption() {}
+
const (
// TCPRACKLossDetection indicates RACK is used for loss detection and
// recovery.
@@ -1020,19 +1027,6 @@ func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {}
func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {}
-// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
-// the number of endpoints that can be in SYN-RCVD state before the stack
-// switches to using SYN cookies.
-type TCPSynRcvdCountThresholdOption uint64
-
-func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {}
-
-func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {}
-
// TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide
// default for number of times SYN is retransmitted before aborting a connect.
type TCPSynRetriesOption uint8
diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD
index 3cc8c36f1..3b51e4be0 100644
--- a/pkg/tcpip/tests/integration/BUILD
+++ b/pkg/tcpip/tests/integration/BUILD
@@ -9,6 +9,8 @@ go_test(
deps = [
"//pkg/tcpip",
"//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index d10ae05c2..0de5079e8 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -21,6 +21,8 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/checker"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
@@ -312,3 +314,193 @@ func TestForwarding(t *testing.T) {
})
}
}
+
+func TestMulticastForwarding(t *testing.T) {
+ const (
+ nicID1 = 1
+ nicID2 = 2
+
+ ipv4LinkLocalUnicastAddr = tcpip.Address("\xa9\xfe\x00\x0a")
+ ipv4LinkLocalMulticastAddr = tcpip.Address("\xe0\x00\x00\x0a")
+ ipv4GlobalMulticastAddr = tcpip.Address("\xe0\x00\x01\x0a")
+
+ ipv6LinkLocalUnicastAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a")
+ ipv6LinkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a")
+ ipv6GlobalMulticastAddr = tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a")
+
+ ttl = 64
+ )
+
+ rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv4EchoRequest(e, src, dst, ttl)
+ }
+
+ rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) {
+ utils.RxICMPv6EchoRequest(e, src, dst, ttl)
+ }
+
+ v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv4(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Echo)))
+ }
+
+ v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) {
+ checker.IPv6(t, b,
+ checker.SrcAddr(src),
+ checker.DstAddr(dst),
+ checker.TTL(ttl-1),
+ checker.ICMPv6(
+ checker.ICMPv6Type(header.ICMPv6EchoRequest)))
+ }
+
+ tests := []struct {
+ name string
+ srcAddr, dstAddr tcpip.Address
+ rx func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ expectForward bool
+ checker func(*testing.T, []byte)
+ }{
+ {
+ name: "IPv4 link-local multicast destination",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4LinkLocalMulticastAddr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv4 link-local source",
+ srcAddr: ipv4LinkLocalUnicastAddr,
+ dstAddr: utils.RemoteIPv4Addr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv4 link-local destination",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4LinkLocalUnicastAddr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv4 non-link-local unicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv4EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv4 non-link-local multicast",
+ srcAddr: utils.RemoteIPv4Addr,
+ dstAddr: ipv4GlobalMulticastAddr,
+ rx: rxICMPv4EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr)
+ },
+ },
+
+ {
+ name: "IPv6 link-local multicast destination",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6LinkLocalMulticastAddr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv6 link-local source",
+ srcAddr: ipv6LinkLocalUnicastAddr,
+ dstAddr: utils.RemoteIPv6Addr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv6 link-local destination",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6LinkLocalUnicastAddr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: false,
+ },
+ {
+ name: "IPv6 non-link-local unicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address,
+ rx: rxICMPv6EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address)
+ },
+ },
+ {
+ name: "IPv6 non-link-local multicast",
+ srcAddr: utils.RemoteIPv6Addr,
+ dstAddr: ipv6GlobalMulticastAddr,
+ rx: rxICMPv6EchoRequest,
+ expectForward: true,
+ checker: func(t *testing.T, b []byte) {
+ v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
+ })
+
+ e1 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID1, e1); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err)
+ }
+
+ e2 := channel.New(1, header.IPv6MinimumMTU, "")
+ if err := s.CreateNIC(nicID2, e2); err != nil {
+ t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err)
+ }
+
+ if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err)
+ }
+ if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil {
+ t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err)
+ }
+
+ if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: header.IPv4EmptySubnet,
+ NIC: nicID2,
+ },
+ {
+ Destination: header.IPv6EmptySubnet,
+ NIC: nicID2,
+ },
+ })
+
+ test.rx(e1, test.srcAddr, test.dstAddr)
+
+ p, ok := e2.Read()
+ if ok != test.expectForward {
+ t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward)
+ }
+
+ if test.expectForward {
+ test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader()))
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 2c538a43e..82c2e11ab 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -513,22 +513,23 @@ func TestExternalLoopbackTraffic(t *testing.T) {
ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01")
numPackets = 1
+ ttl = 64
)
loopbackSourcedICMPv4 := func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address)
+ utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl)
}
loopbackSourcedICMPv6 := func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address)
+ utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl)
}
loopbackDestinedICMPv4 := func(e *channel.Endpoint) {
- utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback)
+ utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl)
}
loopbackDestinedICMPv6 := func(e *channel.Endpoint) {
- utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback)
+ utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl)
}
invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter {
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index c6a9c2393..09ff3b892 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -43,12 +43,15 @@ const (
// to a multicast or broadcast address uses a unicast source address for the
// reply.
func TestPingMulticastBroadcast(t *testing.T) {
- const nicID = 1
+ const (
+ nicID = 1
+ ttl = 64
+ )
tests := []struct {
name string
protoNum tcpip.NetworkProtocolNumber
- rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address)
+ rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8)
srcAddr tcpip.Address
dstAddr tcpip.Address
expectedSrc tcpip.Address
@@ -136,7 +139,7 @@ func TestPingMulticastBroadcast(t *testing.T) {
},
})
- test.rxICMP(e, test.srcAddr, test.dstAddr)
+ test.rxICMP(e, test.srcAddr, test.dstAddr, ttl)
pkt, ok := e.Read()
if !ok {
t.Fatal("expected ICMP response")
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index d1c9f3a94..8fd9be32b 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -48,10 +48,6 @@ const (
LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09")
)
-const (
- ttl = 255
-)
-
// Common IP addresses used by tests.
var (
Ipv4Addr = tcpip.AddressWithPrefix{
@@ -322,7 +318,7 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
// RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on
// the provided endpoint.
-func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
@@ -347,7 +343,7 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
// RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on
// the provided endpoint.
-func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) {
+func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) {
totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize
hdr := buffer.NewPrependable(totalLen)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize))
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 025b134e2..a485064a1 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -51,11 +51,6 @@ const (
// timestamp and the current timestamp. If the difference is greater
// than maxTSDiff, the cookie is expired.
maxTSDiff = 2
-
- // SynRcvdCountThreshold is the default global maximum number of
- // connections that are allowed to be in SYN-RCVD state before TCP
- // starts using SYN cookies to accept connections.
- SynRcvdCountThreshold uint64 = 1000
)
var (
@@ -80,9 +75,6 @@ func encodeMSS(mss uint16) uint32 {
type listenContext struct {
stack *stack.Stack
- // synRcvdCount is a reference to the stack level synRcvdCount.
- synRcvdCount *synRcvdCounter
-
// rcvWnd is the receive window that is sent by this listening context
// in the initial SYN-ACK.
rcvWnd seqnum.Size
@@ -138,11 +130,6 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size,
listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
- p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol)
- if !ok {
- panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk))
- }
- l.synRcvdCount = p.SynRcvdCounter()
rand.Read(l.nonce[0][:])
rand.Read(l.nonce[1][:])
@@ -199,6 +186,14 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu
return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
}
+func (l *listenContext) useSynCookies() bool {
+ var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
+ if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
+ panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
+ }
+ return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull())
+}
+
// createConnectingEndpoint creates a new endpoint in a connecting state, with
// the connection parameters given by the arguments.
func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) {
@@ -307,6 +302,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q
// Initialize and start the handshake.
h := ep.newPassiveHandshake(isn, irs, opts, deferAccept)
+ h.listenEP = l.listenEP
h.start()
return h, nil
}
@@ -485,7 +481,6 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
go func() {
- defer ctx.synRcvdCount.dec()
if err := h.complete(); err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
@@ -497,24 +492,29 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
h.ep.startAcceptedLoop()
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
e.deliverAccepted(h.ep, false /*withSynCookie*/)
- }() // S/R-SAFE: synRcvdCount is the barrier.
+ }()
return nil
}
-func (e *endpoint) incSynRcvdCount() bool {
+func (e *endpoint) synRcvdBacklogFull() bool {
e.acceptMu.Lock()
- canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan)
+ acceptedChanCap := cap(e.acceptedChan)
e.acceptMu.Unlock()
- if canInc {
- atomic.AddInt32(&e.synRcvdCount, 1)
- }
- return canInc
+ // The allocated accepted channel size would always be one greater than the
+ // listen backlog. But, the SYNRCVD connections count is always checked
+ // against the listen backlog value for Linux parity reason.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+ //
+ // We maintain an equality check here as the synRcvdCount is incremented
+ // and compared only from a single listener context and the capacity of
+ // the accepted channel can only increase by a new listen call.
+ return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedChanCap-1
}
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
- full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan)
+ full := len(e.acceptedChan) == cap(e.acceptedChan)
e.acceptMu.Unlock()
return full
}
@@ -538,69 +538,55 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
switch {
case s.flags == header.TCPFlagSyn:
- opts := parseSynSegmentOptions(s)
- if ctx.synRcvdCount.inc() {
- // Only handle the syn if the following conditions hold
- // - accept queue is not full.
- // - number of connections in synRcvd state is less than the
- // backlog.
- if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
- s.incRef()
- _ = e.handleSynSegment(ctx, s, &opts)
- return nil
- }
- ctx.synRcvdCount.dec()
+ if e.acceptQueueIsFull() {
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
return nil
- } else {
- // If cookies are in use but the endpoint accept queue
- // is full then drop the syn.
- if e.acceptQueueIsFull() {
- e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
- e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
- e.stack.Stats().DroppedPackets.Increment()
- return nil
- }
- cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ }
- route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
- if err != nil {
- return err
- }
- defer route.Release()
+ opts := parseSynSegmentOptions(s)
+ if !ctx.useSynCookies() {
+ s.incRef()
+ atomic.AddInt32(&e.synRcvdCount, 1)
+ return e.handleSynSegment(ctx, s, &opts)
+ }
+ route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
- // Send SYN without window scaling because we currently
- // don't encode this information in the cookie.
- //
- // Enable Timestamp option if the original syn did have
- // the timestamp option specified.
- //
- // Use the user supplied MSS on the listening socket for
- // new connections, if available.
- synOpts := header.TCPSynOptions{
- WS: -1,
- TS: opts.TS,
- TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
- TSEcr: opts.TSVal,
- MSS: calculateAdvertisedMSS(e.userMSS, route),
- }
- fields := tcpFields{
- id: s.id,
- ttl: e.ttl,
- tos: e.sendTOS,
- flags: header.TCPFlagSyn | header.TCPFlagAck,
- seq: cookie,
- ack: s.sequenceNumber + 1,
- rcvWnd: ctx.rcvWnd,
- }
- if err := e.sendSynTCP(route, fields, synOpts); err != nil {
- return err
- }
- e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
- return nil
+ // Send SYN without window scaling because we currently
+ // don't encode this information in the cookie.
+ //
+ // Enable Timestamp option if the original syn did have
+ // the timestamp option specified.
+ //
+ // Use the user supplied MSS on the listening socket for
+ // new connections, if available.
+ synOpts := header.TCPSynOptions{
+ WS: -1,
+ TS: opts.TS,
+ TSVal: tcpTimeStamp(time.Now(), timeStampOffset()),
+ TSEcr: opts.TSVal,
+ MSS: calculateAdvertisedMSS(e.userMSS, route),
+ }
+ cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ fields := tcpFields{
+ id: s.id,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: header.TCPFlagSyn | header.TCPFlagAck,
+ seq: cookie,
+ ack: s.sequenceNumber + 1,
+ rcvWnd: ctx.rcvWnd,
+ }
+ if err := e.sendSynTCP(route, fields, synOpts); err != nil {
+ return err
}
+ e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ return nil
case (s.flags & header.TCPFlagAck) != 0:
if e.acceptQueueIsFull() {
@@ -615,25 +601,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
}
- if !ctx.synRcvdCount.synCookiesInUse() {
- // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
- // Any acknowledgment is bad if it arrives on a connection still in
- // the LISTEN state. An acceptable reset segment should be formed
- // for any arriving ACK-bearing segment. The RST should be
- // formatted as follows:
- //
- // <SEQ=SEG.ACK><CTL=RST>
- //
- // Send a reset as this is an ACK for which there is no
- // half open connections and we are not using cookies
- // yet.
- //
- // The only time we should reach here when a connection
- // was opened and closed really quickly and a delayed
- // ACK was received from the sender.
- return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
- }
-
iss := s.ackNumber - 1
irs := s.sequenceNumber - 1
@@ -651,7 +618,23 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
if !ok || int(data) >= len(mssTable) {
e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
e.stack.Stats().DroppedPackets.Increment()
- return nil
+
+ // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
+ // Any acknowledgment is bad if it arrives on a connection still in
+ // the LISTEN state. An acceptable reset segment should be formed
+ // for any arriving ACK-bearing segment. The RST should be
+ // formatted as follows:
+ //
+ // <SEQ=SEG.ACK><CTL=RST>
+ //
+ // Send a reset as this is an ACK for which there is no
+ // half open connections and we are not using cookies
+ // yet.
+ //
+ // The only time we should reach here when a connection
+ // was opened and closed really quickly and a delayed
+ // ACK was received from the sender.
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
}
e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
// Create newly accepted endpoint and deliver it.
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index a9e978cf6..8f0f0c3e9 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -65,11 +65,12 @@ const (
// NOTE: handshake.ep.mu is held during handshake processing. It is released if
// we are going to block and reacquired when we start processing an event.
type handshake struct {
- ep *endpoint
- state handshakeState
- active bool
- flags header.TCPFlags
- ackNum seqnum.Value
+ ep *endpoint
+ listenEP *endpoint
+ state handshakeState
+ active bool
+ flags header.TCPFlags
+ ackNum seqnum.Value
// iss is the initial send sequence number, as defined in RFC 793.
iss seqnum.Value
@@ -394,6 +395,15 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
return nil
}
+ // Drop the ACK if the accept queue is full.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_ipv4.c#L1523
+ // We could abort the connection as well with a tunable as in
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_minisocks.c#L788
+ if listenEP := h.listenEP; listenEP != nil && listenEP.acceptQueueIsFull() {
+ listenEP.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
// Update timestamp if required. See RFC7323, section-4.3.
if h.ep.sendTSOk && s.parsedOptions.TS {
h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index f6a16f96e..d6d68f128 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -565,17 +565,15 @@ func TestV4AcceptOnV4(t *testing.T) {
}
func testV4ListenClose(t *testing.T, c *context.Context) {
- // Set the SynRcvd threshold to zero to force a syn cookie based accept
- // to happen.
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
- const n = uint16(32)
+ const n = 32
// Start listening.
- if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
+ if err := c.EP.Listen(n); err != nil {
t.Fatalf("Listen failed: %v", err)
}
@@ -591,9 +589,9 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
})
}
- // Each of these ACK's will cause a syn-cookie based connection to be
+ // Each of these ACKs will cause a syn-cookie based connection to be
// accepted and delivered to the listening endpoint.
- for i := uint16(0); i < n; i++ {
+ for i := 0; i < n; i++ {
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
iss := seqnum.Value(tcp.SequenceNumber())
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index c5daba232..5001d222e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -2474,6 +2474,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
func (e *endpoint) Listen(backlog int) tcpip.Error {
+ // Accept one more than the configured listen backlog to keep in parity with
+ // Linux. Ref, because of missing equality check here:
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937
+ backlog++
err := e.listen(backlog)
if err != nil {
if !err.IgnoreStats() {
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 2a4667906..fe0d7f10f 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -75,63 +75,6 @@ const (
ccCubic = "cubic"
)
-// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
-// value is protected by a mutex so that we can increment only when it's
-// guaranteed not to go above a threshold.
-type synRcvdCounter struct {
- sync.Mutex
- value uint64
- pending sync.WaitGroup
- threshold uint64
-}
-
-// inc tries to increment the global number of endpoints in SYN-RCVD state. It
-// succeeds if the increment doesn't make the count go beyond the threshold, and
-// fails otherwise.
-func (s *synRcvdCounter) inc() bool {
- s.Lock()
- defer s.Unlock()
- if s.value >= s.threshold {
- return false
- }
-
- s.pending.Add(1)
- s.value++
-
- return true
-}
-
-// dec atomically decrements the global number of endpoints in SYN-RCVD
-// state. It must only be called if a previous call to inc succeeded.
-func (s *synRcvdCounter) dec() {
- s.Lock()
- defer s.Unlock()
- s.value--
- s.pending.Done()
-}
-
-// synCookiesInUse returns true if the synRcvdCount is greater than
-// SynRcvdCountThreshold.
-func (s *synRcvdCounter) synCookiesInUse() bool {
- s.Lock()
- defer s.Unlock()
- return s.value >= s.threshold
-}
-
-// SetThreshold sets synRcvdCounter.Threshold to ths new threshold.
-func (s *synRcvdCounter) SetThreshold(threshold uint64) {
- s.Lock()
- defer s.Unlock()
- s.threshold = threshold
-}
-
-// Threshold returns the current value of synRcvdCounter.Threhsold.
-func (s *synRcvdCounter) Threshold() uint64 {
- s.Lock()
- defer s.Unlock()
- return s.threshold
-}
-
type protocol struct {
stack *stack.Stack
@@ -139,6 +82,7 @@ type protocol struct {
sackEnabled bool
recovery tcpip.TCPRecovery
delayEnabled bool
+ alwaysUseSynCookies bool
sendBufferSize tcpip.TCPSendBufferSizeRangeOption
recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption
congestionControl string
@@ -150,7 +94,6 @@ type protocol struct {
minRTO time.Duration
maxRTO time.Duration
maxRetries uint32
- synRcvdCount synRcvdCounter
synRetries uint8
dispatcher dispatcher
}
@@ -373,9 +316,9 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip
p.mu.Unlock()
return nil
- case *tcpip.TCPSynRcvdCountThresholdOption:
+ case *tcpip.TCPAlwaysUseSynCookies:
p.mu.Lock()
- p.synRcvdCount.SetThreshold(uint64(*v))
+ p.alwaysUseSynCookies = bool(*v)
p.mu.Unlock()
return nil
@@ -480,9 +423,9 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er
p.mu.RUnlock()
return nil
- case *tcpip.TCPSynRcvdCountThresholdOption:
+ case *tcpip.TCPAlwaysUseSynCookies:
p.mu.RLock()
- *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold())
+ *v = tcpip.TCPAlwaysUseSynCookies(p.alwaysUseSynCookies)
p.mu.RUnlock()
return nil
@@ -507,12 +450,6 @@ func (p *protocol) Wait() {
p.dispatcher.wait()
}
-// SynRcvdCounter returns a reference to the synRcvdCount for this protocol
-// instance.
-func (p *protocol) SynRcvdCounter() *synRcvdCounter {
- return &p.synRcvdCount
-}
-
// Parse implements stack.TransportProtocol.Parse.
func (*protocol) Parse(pkt *stack.PacketBuffer) bool {
return parse.TCP(pkt)
@@ -537,7 +474,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol {
lingerTimeout: DefaultTCPLingerTimeout,
timeWaitTimeout: DefaultTCPTimeWaitTimeout,
timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly,
- synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
synRetries: DefaultSynRetries,
minRTO: MinRTO,
maxRTO: MaxRTO,
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 81f800cad..20c9761f2 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -160,12 +160,9 @@ func TestSackPermittedAccept(t *testing.T) {
defer c.Cleanup()
if tc.cookieEnabled {
- // Set the SynRcvd threshold to
- // zero to force a syn cookie
- // based accept to happen.
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
setStackSACKPermitted(t, c, sackEnabled)
@@ -235,12 +232,9 @@ func TestSackDisabledAccept(t *testing.T) {
defer c.Cleanup()
if tc.cookieEnabled {
- // Set the SynRcvd threshold to
- // zero to force a syn cookie
- // based accept to happen.
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9c23469f2..5605a4390 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -955,11 +955,7 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) {
// when completing the handshake for a new TCP connection from a TCP
// listening socket. It should be present in the sent TCP SYN-ACK segment.
func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
- const (
- nonSynCookieAccepts = 2
- totalAccepts = 4
- mtu = 5000
- )
+ const mtu = 5000
ips := []struct {
name string
@@ -1033,12 +1029,6 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
ip.createEP(c)
- // Set the SynRcvd threshold to force a syn cookie based accept to happen.
- opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil {
t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err)
}
@@ -1048,13 +1038,17 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) {
t.Fatalf("Bind(%+v): %s:", bindAddr, err)
}
- if err := c.EP.Listen(totalAccepts); err != nil {
- t.Fatalf("Listen(%d): %s:", totalAccepts, err)
+ backlog := 5
+ // Keep the number of client requests twice to the backlog
+ // such that half of the connections do not use syncookies
+ // and the other half does.
+ clientConnects := backlog * 2
+
+ if err := c.EP.Listen(backlog); err != nil {
+ t.Fatalf("Listen(%d): %s:", backlog, err)
}
- // The first nonSynCookieAccepts packets sent will trigger a gorooutine
- // based accept. The rest will trigger a cookie based accept.
- for i := 0; i < totalAccepts; i++ {
+ for i := 0; i < clientConnects; i++ {
// Send a SYN requests.
iss := seqnum.Value(i)
srcPort := context.TestPort + uint16(i)
@@ -3087,11 +3081,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
c := context.New(t, mtu)
defer c.Cleanup()
- // Set the SynRcvd threshold to zero to force a syn cookie based accept
- // to happen.
- opt := tcpip.TCPSynRcvdCountThresholdOption(0)
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
// Create EP and start listening.
@@ -5363,7 +5355,7 @@ func TestListenBacklogFull(t *testing.T) {
}
lastPortOffset := uint16(0)
- for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ {
+ for ; int(lastPortOffset) < listenBacklog+1; lastPortOffset++ {
executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */)
}
@@ -5671,15 +5663,13 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
}
// Test acceptance.
- // Start listening.
- listenBacklog := 1
- if err := c.EP.Listen(listenBacklog); err != nil {
+ if err := c.EP.Listen(0); err != nil {
t.Fatalf("Listen failed: %s", err)
}
// Send two SYN's the first one should get a SYN-ACK, the
// second one should not get any response and is dropped as
- // the synRcvd count will be equal to backlog.
+ // the accept queue is full.
irs := seqnum.Value(context.TestInitialSequenceNumber)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -5701,23 +5691,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
}
checker.IPv4(t, b, checker.TCP(tcpCheckers...))
- // Now execute send one more SYN. The stack should not respond as the backlog
- // is full at this point.
- //
- // NOTE: we did not complete the handshake for the previous one so the
- // accept backlog should be empty and there should be one connection in
- // synRcvd state.
- c.SendPacket(nil, &context.Headers{
- SrcPort: context.TestPort + 1,
- DstPort: context.StackPort,
- Flags: header.TCPFlagSyn,
- SeqNum: seqnum.Value(889),
- RcvWnd: 30000,
- })
- c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
-
- // Now complete the previous connection and verify that there is a connection
- // to accept.
+ // Now complete the previous connection.
// Send ACK.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
@@ -5728,11 +5702,24 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
RcvWnd: 30000,
})
- // Try to accept the connections in the backlog.
+ // Verify if that is delivered to the accept queue.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.ReadableEvents)
defer c.WQ.EventUnregister(&we)
+ <-ch
+
+ // Now execute send one more SYN. The stack should not respond as the backlog
+ // is full at this point.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + 1,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: seqnum.Value(889),
+ RcvWnd: 30000,
+ })
+ c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
+ // Try to accept the connections in the backlog.
newEP, _, err := c.EP.Accept(nil)
if _, ok := err.(*tcpip.ErrWouldBlock); ok {
// Wait for connection to be established.
@@ -5764,11 +5751,6 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
- opt := tcpip.TCPSynRcvdCountThresholdOption(1)
- if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
- }
-
// Create TCP endpoint.
var err tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
@@ -5781,9 +5763,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
t.Fatalf("Bind failed: %s", err)
}
- // Start listening.
- listenBacklog := 1
- if err := c.EP.Listen(listenBacklog); err != nil {
+ // Test for SynCookies usage after filling up the backlog.
+ if err := c.EP.Listen(0); err != nil {
t.Fatalf("Listen failed: %s", err)
}
@@ -6066,7 +6047,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
- if err := c.EP.Listen(1); err != nil {
+ if err := c.EP.Listen(0); err != nil {
t.Fatalf("Listen failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 2949588ce..1deb1fe4d 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -139,9 +139,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
defer c.Cleanup()
if cookieEnabled {
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
@@ -202,9 +202,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
defer c.Cleanup()
if cookieEnabled {
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
}
}
diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go
index 41fcf4978..06152a444 100644
--- a/pkg/test/dockerutil/container.go
+++ b/pkg/test/dockerutil/container.go
@@ -434,7 +434,14 @@ func (c *Container) Wait(ctx context.Context) error {
select {
case err := <-errChan:
return err
- case <-statusChan:
+ case res := <-statusChan:
+ if res.StatusCode != 0 {
+ var msg string
+ if res.Error != nil {
+ msg = res.Error.Message
+ }
+ return fmt.Errorf("container returned non-zero status: %d, msg: %q", res.StatusCode, msg)
+ }
return nil
}
}
diff --git a/runsc/BUILD b/runsc/BUILD
index 3b91b984a..e99404eb1 100644
--- a/runsc/BUILD
+++ b/runsc/BUILD
@@ -9,6 +9,7 @@ go_binary(
"version.go",
],
pure = True,
+ tags = ["staging"],
visibility = [
"//visibility:public",
],
diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD
index 67307ab3c..579edaa2c 100644
--- a/runsc/boot/BUILD
+++ b/runsc/boot/BUILD
@@ -57,6 +57,7 @@ go_library(
"//pkg/sentry/fs/tmpfs",
"//pkg/sentry/fs/tty",
"//pkg/sentry/fs/user",
+ "//pkg/sentry/fsimpl/cgroupfs",
"//pkg/sentry/fsimpl/devpts",
"//pkg/sentry/fsimpl/devtmpfs",
"//pkg/sentry/fsimpl/fuse",
@@ -66,6 +67,7 @@ go_library(
"//pkg/sentry/fsimpl/proc",
"//pkg/sentry/fsimpl/sys",
"//pkg/sentry/fsimpl/tmpfs",
+ "//pkg/sentry/fsimpl/verity",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel:uncaught_signal_go_proto",
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 1ae76d7d7..05b721b28 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -400,7 +400,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
// Set up the restore environment.
ctx := k.SupervisorContext()
- mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled)
+ mntr := newContainerMounter(&cm.l.root, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled)
if kernel.VFS2Enabled {
ctx, err = mntr.configureRestore(ctx, cm.l.root.conf)
if err != nil {
diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go
index 19ced9b0e..3c0cef6db 100644
--- a/runsc/boot/fs.go
+++ b/runsc/boot/fs.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/gofer"
"gvisor.dev/gvisor/pkg/sentry/fs/ramfs"
"gvisor.dev/gvisor/pkg/sentry/fs/user"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/cgroupfs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
gofervfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer"
@@ -103,7 +104,7 @@ func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name
// compileMounts returns the supported mounts from the mount spec, adding any
// mandatory mounts that are required by the OCI specification.
-func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount {
+func compileMounts(spec *specs.Spec, conf *config.Config, vfs2Enabled bool) []specs.Mount {
// Keep track of whether proc and sys were mounted.
var procMounted, sysMounted, devMounted, devptsMounted bool
var mounts []specs.Mount
@@ -114,6 +115,11 @@ func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount {
log.Warningf("ignoring dev mount at %q", m.Destination)
continue
}
+ // Unconditionally drop any cgroupfs mounts. If requested, we'll add our
+ // own below.
+ if m.Type == cgroupfs.Name {
+ continue
+ }
switch filepath.Clean(m.Destination) {
case "/proc":
procMounted = true
@@ -132,6 +138,24 @@ func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount {
// Mount proc and sys even if the user did not ask for it, as the spec
// says we SHOULD.
var mandatoryMounts []specs.Mount
+
+ if conf.Cgroupfs {
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: tmpfsvfs2.Name,
+ Destination: "/sys/fs/cgroup",
+ })
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: cgroupfs.Name,
+ Destination: "/sys/fs/cgroup/memory",
+ Options: []string{"memory"},
+ })
+ mandatoryMounts = append(mandatoryMounts, specs.Mount{
+ Type: cgroupfs.Name,
+ Destination: "/sys/fs/cgroup/cpu",
+ Options: []string{"cpu"},
+ })
+ }
+
if !procMounted {
mandatoryMounts = append(mandatoryMounts, specs.Mount{
Type: procvfs2.Name,
@@ -248,6 +272,10 @@ func isSupportedMountFlag(fstype, opt string) bool {
ok, err := parseMountOption(opt, tmpfsAllowedData...)
return ok && err == nil
}
+ if fstype == cgroupfs.Name {
+ ok, err := parseMountOption(opt, cgroupfs.SupportedMountOptions...)
+ return ok && err == nil
+ }
return false
}
@@ -572,11 +600,11 @@ type containerMounter struct {
hints *podMountHints
}
-func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints, vfs2Enabled bool) *containerMounter {
+func newContainerMounter(info *containerInfo, k *kernel.Kernel, hints *podMountHints, vfs2Enabled bool) *containerMounter {
return &containerMounter{
- root: spec.Root,
- mounts: compileMounts(spec, vfs2Enabled),
- fds: fdDispenser{fds: goferFDs},
+ root: info.spec.Root,
+ mounts: compileMounts(info.spec, info.conf, vfs2Enabled),
+ fds: fdDispenser{fds: info.goferFDs},
k: k,
hints: hints,
}
@@ -795,7 +823,13 @@ func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.M
opts = p9MountData(fd, c.getMountAccessType(conf, m), conf.VFS2)
// If configured, add overlay to all writable mounts.
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
-
+ case cgroupfs.Name:
+ fsName = m.Type
+ var err error
+ opts, err = parseAndFilterOptions(m.Options, cgroupfs.SupportedMountOptions...)
+ if err != nil {
+ return "", nil, false, err
+ }
default:
log.Warningf("ignoring unknown filesystem type %q", m.Type)
}
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 774621970..95daf1f00 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -752,7 +752,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn
// Setup the child container file system.
l.startGoferMonitor(cid, info.goferFDs)
- mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints, kernel.VFS2Enabled)
+ mntr := newContainerMounter(info, l.k, l.mountHints, kernel.VFS2Enabled)
if root {
if err := mntr.processHints(info.conf, info.procArgs.Credentials); err != nil {
return nil, nil, nil, err
diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go
index 8b39bc59a..93c476971 100644
--- a/runsc/boot/loader_test.go
+++ b/runsc/boot/loader_test.go
@@ -439,7 +439,13 @@ func TestCreateMountNamespace(t *testing.T) {
}
defer cleanup()
- mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{}, false /* vfs2Enabled */)
+ info := containerInfo{
+ conf: conf,
+ spec: &tc.spec,
+ goferFDs: []*fd.FD{fd.New(sandEnd)},
+ }
+
+ mntr := newContainerMounter(&info, nil, &podMountHints{}, false /* vfs2Enabled */)
mns, err := mntr.createMountNamespace(ctx, conf)
if err != nil {
t.Fatalf("failed to create mount namespace: %v", err)
@@ -479,7 +485,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) {
defer l.Destroy()
defer loaderCleanup()
- mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints, true /* vfs2Enabled */)
+ mntr := newContainerMounter(&l.root, l.k, l.mountHints, true /* vfs2Enabled */)
if err := mntr.processHints(l.root.conf, l.root.procArgs.Credentials); err != nil {
t.Fatalf("failed process hints: %v", err)
}
@@ -702,7 +708,12 @@ func TestRestoreEnvironment(t *testing.T) {
for _, ioFD := range tc.ioFDs {
ioFDs = append(ioFDs, fd.New(ioFD))
}
- mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{}, false /* vfs2Enabled */)
+ info := containerInfo{
+ conf: conf,
+ spec: tc.spec,
+ goferFDs: ioFDs,
+ }
+ mntr := newContainerMounter(&info, nil, &podMountHints{}, false /* vfs2Enabled */)
actualRenv, err := mntr.createRestoreEnvironment(conf)
if !tc.errorExpected && err != nil {
t.Fatalf("could not create restore environment for test:%s", tc.name)
diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go
index 9b3dacf46..7d8fd0483 100644
--- a/runsc/boot/vfs.go
+++ b/runsc/boot/vfs.go
@@ -16,6 +16,7 @@ package boot
import (
"fmt"
+ "path"
"sort"
"strings"
@@ -29,6 +30,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/devices/ttydev"
"gvisor.dev/gvisor/pkg/sentry/devices/tundev"
"gvisor.dev/gvisor/pkg/sentry/fs/user"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/cgroupfs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/fuse"
@@ -37,6 +39,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fsimpl/proc"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sys"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/verity"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -50,6 +53,10 @@ func registerFilesystems(k *kernel.Kernel) error {
creds := auth.NewRootCredentials(k.RootUserNamespace())
vfsObj := k.VFS()
+ vfsObj.MustRegisterFilesystemType(cgroupfs.Name, &cgroupfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
vfsObj.MustRegisterFilesystemType(devpts.Name, &devpts.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserList: true,
// TODO(b/29356795): Users may mount this once the terminals are in a
@@ -60,6 +67,10 @@ func registerFilesystems(k *kernel.Kernel) error {
AllowUserMount: true,
AllowUserList: true,
})
+ vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
+ AllowUserMount: true,
+ AllowUserList: true,
+ })
vfsObj.MustRegisterFilesystemType(gofer.Name, &gofer.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserList: true,
})
@@ -79,9 +90,9 @@ func registerFilesystems(k *kernel.Kernel) error {
AllowUserMount: true,
AllowUserList: true,
})
- vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
- AllowUserMount: true,
+ vfsObj.MustRegisterFilesystemType(verity.Name, &verity.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{
AllowUserList: true,
+ AllowUserMount: true,
})
// Setup files in devtmpfs.
@@ -472,6 +483,12 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
var data []string
var iopts interface{}
+ verityData, verityOpts, verityRequested, remainingMOpts, err := parseVerityMountOptions(m.Options)
+ if err != nil {
+ return "", nil, false, err
+ }
+ m.Options = remainingMOpts
+
// Find filesystem name and FS specific data field.
switch m.Type {
case devpts.Name, devtmpfs.Name, proc.Name, sys.Name:
@@ -502,6 +519,13 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
// If configured, add overlay to all writable mounts.
useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly
+ case cgroupfs.Name:
+ var err error
+ data, err = parseAndFilterOptions(m.Options, cgroupfs.SupportedMountOptions...)
+ if err != nil {
+ return "", nil, false, err
+ }
+
default:
log.Warningf("ignoring unknown filesystem type %q", m.Type)
return "", nil, false, nil
@@ -530,9 +554,75 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo
}
}
+ if verityRequested {
+ verityData = verityData + "root_name=" + path.Base(m.Mount.Destination)
+ verityOpts.LowerName = fsName
+ verityOpts.LowerGetFSOptions = opts.GetFilesystemOptions
+ fsName = verity.Name
+ opts = &vfs.MountOptions{
+ GetFilesystemOptions: vfs.GetFilesystemOptions{
+ Data: verityData,
+ InternalData: verityOpts,
+ },
+ InternalMount: true,
+ }
+ }
+
return fsName, opts, useOverlay, nil
}
+func parseKeyValue(s string) (string, string, bool) {
+ tokens := strings.SplitN(s, "=", 2)
+ if len(tokens) < 2 {
+ return "", "", false
+ }
+ return strings.TrimSpace(tokens[0]), strings.TrimSpace(tokens[1]), true
+}
+
+// parseAndFilterOptions scans the provided mount options for verity-related
+// mount options. It returns the parsed set of verity mount options, as well as
+// the filtered set of mount options unrelated to verity.
+func parseVerityMountOptions(mopts []string) (string, verity.InternalFilesystemOptions, bool, []string, error) {
+ nonVerity := []string{}
+ found := false
+ var rootHash string
+ verityOpts := verity.InternalFilesystemOptions{
+ Action: verity.PanicOnViolation,
+ }
+ for _, o := range mopts {
+ if !strings.HasPrefix(o, "verity.") {
+ nonVerity = append(nonVerity, o)
+ continue
+ }
+
+ k, v, ok := parseKeyValue(o)
+ if !ok {
+ return "", verityOpts, found, nonVerity, fmt.Errorf("invalid verity mount option with no value: %q", o)
+ }
+
+ found = true
+ switch k {
+ case "verity.roothash":
+ rootHash = v
+ case "verity.action":
+ switch v {
+ case "error":
+ verityOpts.Action = verity.ErrorOnViolation
+ case "panic":
+ verityOpts.Action = verity.PanicOnViolation
+ default:
+ log.Warningf("Invalid verity action %q", v)
+ verityOpts.Action = verity.PanicOnViolation
+ }
+ default:
+ return "", verityOpts, found, nonVerity, fmt.Errorf("unknown verity mount option: %q", k)
+ }
+ }
+ verityOpts.AllowRuntimeEnable = len(rootHash) == 0
+ verityData := "root_hash=" + rootHash + ","
+ return verityData, verityOpts, found, nonVerity, nil
+}
+
// mountTmpVFS2 mounts an internal tmpfs at '/tmp' if it's safe to do so.
// Technically we don't have to mount tmpfs at /tmp, as we could just rely on
// the host /tmp, but this is a nice optimization, and fixes some apps that call
diff --git a/runsc/cli/main.go b/runsc/cli/main.go
index a3c515f4b..6db6614cc 100644
--- a/runsc/cli/main.go
+++ b/runsc/cli/main.go
@@ -86,6 +86,7 @@ func Main(version string) {
subcommands.Register(new(cmd.Symbolize), "")
subcommands.Register(new(cmd.Wait), "")
subcommands.Register(new(cmd.Mitigate), "")
+ subcommands.Register(new(cmd.VerityPrepare), "")
// Register internal commands with the internal group name. This causes
// them to be sorted below the user-facing commands with empty group.
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD
index 2c3b4058b..4b9987cf6 100644
--- a/runsc/cmd/BUILD
+++ b/runsc/cmd/BUILD
@@ -35,6 +35,7 @@ go_library(
"statefile.go",
"symbolize.go",
"syscalls.go",
+ "verity_prepare.go",
"wait.go",
],
visibility = [
diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go
index 455c57692..5485db149 100644
--- a/runsc/cmd/do.go
+++ b/runsc/cmd/do.go
@@ -126,9 +126,8 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su
Hostname: hostname,
}
- specutils.LogSpec(spec)
-
cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000))
+
if conf.Network == config.NetworkNone {
addNamespace(spec, specs.LinuxNamespace{Type: specs.NetworkNamespace})
@@ -154,55 +153,7 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su
}
}
- out, err := json.Marshal(spec)
- if err != nil {
- return Errorf("Error to marshal spec: %v", err)
- }
- tmpDir, err := ioutil.TempDir("", "runsc-do")
- if err != nil {
- return Errorf("Error to create tmp dir: %v", err)
- }
- defer os.RemoveAll(tmpDir)
-
- log.Infof("Changing configuration RootDir to %q", tmpDir)
- conf.RootDir = tmpDir
-
- cfgPath := filepath.Join(tmpDir, "config.json")
- if err := ioutil.WriteFile(cfgPath, out, 0755); err != nil {
- return Errorf("Error write spec: %v", err)
- }
-
- containerArgs := container.Args{
- ID: cid,
- Spec: spec,
- BundleDir: tmpDir,
- Attached: true,
- }
- ct, err := container.New(conf, containerArgs)
- if err != nil {
- return Errorf("creating container: %v", err)
- }
- defer ct.Destroy()
-
- if err := ct.Start(conf); err != nil {
- return Errorf("starting container: %v", err)
- }
-
- // Forward signals to init in the container. Thus if we get SIGINT from
- // ^C, the container gracefully exit, and we can clean up.
- //
- // N.B. There is a still a window before this where a signal may kill
- // this process, skipping cleanup.
- stopForwarding := ct.ForwardSignals(0 /* pid */, false /* fgProcess */)
- defer stopForwarding()
-
- ws, err := ct.Wait()
- if err != nil {
- return Errorf("waiting for container: %v", err)
- }
-
- *waitStatus = ws
- return subcommands.ExitSuccess
+ return startContainerAndWait(spec, conf, cid, waitStatus)
}
func addNamespace(spec *specs.Spec, ns specs.LinuxNamespace) {
@@ -397,3 +348,58 @@ func calculatePeerIP(ip string) (string, error) {
}
return fmt.Sprintf("%s.%s.%s.%d", parts[0], parts[1], parts[2], n), nil
}
+
+func startContainerAndWait(spec *specs.Spec, conf *config.Config, cid string, waitStatus *unix.WaitStatus) subcommands.ExitStatus {
+ specutils.LogSpec(spec)
+
+ out, err := json.Marshal(spec)
+ if err != nil {
+ return Errorf("Error to marshal spec: %v", err)
+ }
+ tmpDir, err := ioutil.TempDir("", "runsc-do")
+ if err != nil {
+ return Errorf("Error to create tmp dir: %v", err)
+ }
+ defer os.RemoveAll(tmpDir)
+
+ log.Infof("Changing configuration RootDir to %q", tmpDir)
+ conf.RootDir = tmpDir
+
+ cfgPath := filepath.Join(tmpDir, "config.json")
+ if err := ioutil.WriteFile(cfgPath, out, 0755); err != nil {
+ return Errorf("Error write spec: %v", err)
+ }
+
+ containerArgs := container.Args{
+ ID: cid,
+ Spec: spec,
+ BundleDir: tmpDir,
+ Attached: true,
+ }
+
+ ct, err := container.New(conf, containerArgs)
+ if err != nil {
+ return Errorf("creating container: %v", err)
+ }
+ defer ct.Destroy()
+
+ if err := ct.Start(conf); err != nil {
+ return Errorf("starting container: %v", err)
+ }
+
+ // Forward signals to init in the container. Thus if we get SIGINT from
+ // ^C, the container gracefully exit, and we can clean up.
+ //
+ // N.B. There is a still a window before this where a signal may kill
+ // this process, skipping cleanup.
+ stopForwarding := ct.ForwardSignals(0 /* pid */, false /* fgProcess */)
+ defer stopForwarding()
+
+ ws, err := ct.Wait()
+ if err != nil {
+ return Errorf("waiting for container: %v", err)
+ }
+
+ *waitStatus = ws
+ return subcommands.ExitSuccess
+}
diff --git a/runsc/cmd/verity_prepare.go b/runsc/cmd/verity_prepare.go
new file mode 100644
index 000000000..66128b2a3
--- /dev/null
+++ b/runsc/cmd/verity_prepare.go
@@ -0,0 +1,108 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cmd
+
+import (
+ "context"
+ "fmt"
+ "math/rand"
+ "os"
+
+ "github.com/google/subcommands"
+ specs "github.com/opencontainers/runtime-spec/specs-go"
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/runsc/config"
+ "gvisor.dev/gvisor/runsc/flag"
+ "gvisor.dev/gvisor/runsc/specutils"
+)
+
+// VerityPrepare implements subcommands.Commands for the "verity-prepare"
+// command. It sets up a sandbox with a writable verity mount mapped to "--dir",
+// and executes the verity measure tool specified by "--tool" in the sandbox. It
+// is intended to prepare --dir to be mounted as a verity filesystem.
+type VerityPrepare struct {
+ root string
+ tool string
+ dir string
+}
+
+// Name implements subcommands.Command.Name.
+func (*VerityPrepare) Name() string {
+ return "verity-prepare"
+}
+
+// Synopsis implements subcommands.Command.Synopsis.
+func (*VerityPrepare) Synopsis() string {
+ return "Generates the data structures necessary to enable verityfs on a filesystem."
+}
+
+// Usage implements subcommands.Command.Usage.
+func (*VerityPrepare) Usage() string {
+ return "verity-prepare --tool=<measure_tool> --dir=<path>"
+}
+
+// SetFlags implements subcommands.Command.SetFlags.
+func (c *VerityPrepare) SetFlags(f *flag.FlagSet) {
+ f.StringVar(&c.root, "root", "/", `path to the root directory, defaults to "/"`)
+ f.StringVar(&c.tool, "tool", "", "path to the verity measure_tool")
+ f.StringVar(&c.dir, "dir", "", "path to the directory to be hashed")
+}
+
+// Execute implements subcommands.Command.Execute.
+func (c *VerityPrepare) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus {
+ conf := args[0].(*config.Config)
+ waitStatus := args[1].(*unix.WaitStatus)
+
+ hostname, err := os.Hostname()
+ if err != nil {
+ return Errorf("Error to retrieve hostname: %v", err)
+ }
+
+ // Map the entire host file system.
+ absRoot, err := resolvePath(c.root)
+ if err != nil {
+ return Errorf("Error resolving root: %v", err)
+ }
+
+ spec := &specs.Spec{
+ Root: &specs.Root{
+ Path: absRoot,
+ },
+ Process: &specs.Process{
+ Cwd: absRoot,
+ Args: []string{c.tool, "--path", "/verityroot"},
+ Env: os.Environ(),
+ Capabilities: specutils.AllCapabilities(),
+ },
+ Hostname: hostname,
+ Mounts: []specs.Mount{
+ specs.Mount{
+ Source: c.dir,
+ Destination: "/verityroot",
+ Type: "bind",
+ Options: []string{"verity.roothash="},
+ },
+ },
+ }
+
+ cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000))
+
+ // Force no networking, it is not necessary to run the verity measure tool.
+ conf.Network = config.NetworkNone
+
+ conf.Verity = true
+
+ return startContainerAndWait(spec, conf, cid, waitStatus)
+}
diff --git a/runsc/config/config.go b/runsc/config/config.go
index 1e5858837..0b2b97cc5 100644
--- a/runsc/config/config.go
+++ b/runsc/config/config.go
@@ -172,6 +172,9 @@ type Config struct {
// Enables seccomp inside the sandbox.
OCISeccomp bool `flag:"oci-seccomp"`
+ // Mounts the cgroup filesystem backed by the sentry's cgroupfs.
+ Cgroupfs bool `flag:"cgroupfs"`
+
// TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in
// tests. It allows runsc to start the sandbox process as the current
// user, and without chrooting the sandbox process. This can be
diff --git a/runsc/config/flags.go b/runsc/config/flags.go
index 1d996c841..13a1a0163 100644
--- a/runsc/config/flags.go
+++ b/runsc/config/flags.go
@@ -75,6 +75,7 @@ func RegisterFlags() {
flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.")
flag.Bool("vfs2", false, "enables VFSv2. This uses the new VFS layer that is faster than the previous one.")
flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.")
+ flag.Bool("cgroupfs", false, "Automatically mount cgroupfs.")
// Flags that control sandbox runtime behavior: network related.
flag.Var(networkTypePtr(NetworkSandbox), "network", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.")
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index 450f92645..47da2dd10 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -486,7 +486,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn
}
if deviceFile, err := gPlatform.OpenDevice(); err != nil {
- return fmt.Errorf("opening device file for platform %q: %v", gPlatform, err)
+ return fmt.Errorf("opening device file for platform %q: %v", conf.Platform, err)
} else if deviceFile != nil {
defer deviceFile.Close()
cmd.ExtraFiles = append(cmd.ExtraFiles, deviceFile)
@@ -1174,7 +1174,7 @@ func deviceFileForPlatform(name string) (*os.File, error) {
f, err := p.OpenDevice()
if err != nil {
- return nil, fmt.Errorf("opening device file for platform %q: %v", p, err)
+ return nil, fmt.Errorf("opening device file for platform %q: %w", name, err)
}
return f, nil
}
diff --git a/runsc/specutils/fs.go b/runsc/specutils/fs.go
index b62504a8c..9ecd0fde6 100644
--- a/runsc/specutils/fs.go
+++ b/runsc/specutils/fs.go
@@ -18,6 +18,7 @@ import (
"fmt"
"math/bits"
"path"
+ "strings"
specs "github.com/opencontainers/runtime-spec/specs-go"
"golang.org/x/sys/unix"
@@ -64,6 +65,12 @@ var optionsMap = map[string]mapping{
"sync": {set: true, val: unix.MS_SYNCHRONOUS},
}
+// verityMountOptions is the set of valid verity mount option keys.
+var verityMountOptions = map[string]struct{}{
+ "verity.roothash": struct{}{},
+ "verity.action": struct{}{},
+}
+
// propOptionsMap is similar to optionsMap, but it lists propagation options
// that cannot be used together with other flags.
var propOptionsMap = map[string]mapping{
@@ -117,6 +124,14 @@ func validateMount(mnt *specs.Mount) error {
return nil
}
+func moptKey(opt string) string {
+ if len(opt) == 0 {
+ return opt
+ }
+ // Guaranteed to have at least one token, since opt is not empty.
+ return strings.SplitN(opt, "=", 2)[0]
+}
+
// ValidateMountOptions validates that mount options are correct.
func ValidateMountOptions(opts []string) error {
for _, o := range opts {
@@ -125,7 +140,8 @@ func ValidateMountOptions(opts []string) error {
}
_, ok1 := optionsMap[o]
_, ok2 := propOptionsMap[o]
- if !ok1 && !ok2 {
+ _, ok3 := verityMountOptions[moptKey(o)]
+ if !ok1 && !ok2 && !ok3 {
return fmt.Errorf("unknown mount option %q", o)
}
if err := validatePropagation(o); err != nil {
diff --git a/shim/BUILD b/shim/BUILD
index 434269d31..695f61eb9 100644
--- a/shim/BUILD
+++ b/shim/BUILD
@@ -6,6 +6,7 @@ go_binary(
name = "containerd-shim-runsc-v1",
srcs = ["main.go"],
static = True,
+ tags = ["staging"],
visibility = [
"//visibility:public",
],
diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD
index 697ab5837..a5a3cf2c1 100644
--- a/test/benchmarks/base/BUILD
+++ b/test/benchmarks/base/BUILD
@@ -17,7 +17,6 @@ go_library(
benchmark_test(
name = "startup_test",
- size = "enormous",
srcs = ["startup_test.go"],
visibility = ["//:sandbox"],
deps = [
@@ -29,7 +28,6 @@ benchmark_test(
benchmark_test(
name = "size_test",
- size = "enormous",
srcs = ["size_test.go"],
visibility = ["//:sandbox"],
deps = [
@@ -42,7 +40,6 @@ benchmark_test(
benchmark_test(
name = "sysbench_test",
- size = "enormous",
srcs = ["sysbench_test.go"],
visibility = ["//:sandbox"],
deps = [
diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD
index 0b1743603..fee2695ff 100644
--- a/test/benchmarks/database/BUILD
+++ b/test/benchmarks/database/BUILD
@@ -11,7 +11,6 @@ go_library(
benchmark_test(
name = "redis_test",
- size = "enormous",
srcs = ["redis_test.go"],
library = ":database",
visibility = ["//:sandbox"],
diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD
index dc82e63b2..c2b981a07 100644
--- a/test/benchmarks/fs/BUILD
+++ b/test/benchmarks/fs/BUILD
@@ -4,7 +4,6 @@ package(licenses = ["notice"])
benchmark_test(
name = "bazel_test",
- size = "enormous",
srcs = ["bazel_test.go"],
visibility = ["//:sandbox"],
deps = [
@@ -18,7 +17,6 @@ benchmark_test(
benchmark_test(
name = "fio_test",
- size = "enormous",
srcs = ["fio_test.go"],
visibility = ["//:sandbox"],
deps = [
diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD
index 380783f0b..ad2ef3a55 100644
--- a/test/benchmarks/media/BUILD
+++ b/test/benchmarks/media/BUILD
@@ -11,7 +11,6 @@ go_library(
benchmark_test(
name = "ffmpeg_test",
- size = "enormous",
srcs = ["ffmpeg_test.go"],
library = ":media",
visibility = ["//:sandbox"],
diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD
index 3425b8dad..56a4d4f39 100644
--- a/test/benchmarks/ml/BUILD
+++ b/test/benchmarks/ml/BUILD
@@ -11,7 +11,6 @@ go_library(
benchmark_test(
name = "tensorflow_test",
- size = "enormous",
srcs = ["tensorflow_test.go"],
library = ":ml",
visibility = ["//:sandbox"],
diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD
index 2741570f5..e047020bf 100644
--- a/test/benchmarks/network/BUILD
+++ b/test/benchmarks/network/BUILD
@@ -18,7 +18,6 @@ go_library(
benchmark_test(
name = "iperf_test",
- size = "enormous",
srcs = [
"iperf_test.go",
],
@@ -34,7 +33,6 @@ benchmark_test(
benchmark_test(
name = "node_test",
- size = "enormous",
srcs = [
"node_test.go",
],
@@ -49,7 +47,6 @@ benchmark_test(
benchmark_test(
name = "ruby_test",
- size = "enormous",
srcs = [
"ruby_test.go",
],
@@ -64,7 +61,6 @@ benchmark_test(
benchmark_test(
name = "nginx_test",
- size = "enormous",
srcs = [
"nginx_test.go",
],
@@ -79,7 +75,6 @@ benchmark_test(
benchmark_test(
name = "httpd_test",
- size = "enormous",
srcs = [
"httpd_test.go",
],
diff --git a/test/e2e/BUILD b/test/e2e/BUILD
index 29a84f184..3b3dadf04 100644
--- a/test/e2e/BUILD
+++ b/test/e2e/BUILD
@@ -8,7 +8,6 @@ go_test(
srcs = [
"exec_test.go",
"integration_test.go",
- "regression_test.go",
],
library = ":integration",
tags = [
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index e83576722..1accc3b3b 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -168,13 +168,6 @@ func TestCheckpointRestore(t *testing.T) {
t.Skip("Pause/resume is not supported.")
}
- // TODO(gvisor.dev/issue/3373): Remove after implementing.
- if usingVFS2, err := dockerutil.UsingVFS2(); usingVFS2 {
- t.Skip("CheckpointRestore not implemented in VFS2.")
- } else if err != nil {
- t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err)
- }
-
ctx := context.Background()
d := dockerutil.MakeContainer(ctx, t)
defer d.CleanUp(ctx)
@@ -592,6 +585,30 @@ func runIntegrationTest(t *testing.T, capAdd []string, args ...string) {
}
}
+// Test that UDS can be created using overlay when parent directory is in lower
+// layer only (b/134090485).
+//
+// Prerequisite: the directory where the socket file is created must not have
+// been open for write before bind(2) is called.
+func TestBindOverlay(t *testing.T) {
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ // Run the container.
+ got, err := d.Run(ctx, dockerutil.RunOpts{
+ Image: "basic/ubuntu",
+ }, "bash", "-c", "nc -q -1 -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -q 0 -U /var/run/sock && wait $p")
+ if err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // Check the output contains what we want.
+ if want := "foobar-asdf"; !strings.Contains(got, want) {
+ t.Fatalf("docker run output is missing %q: %s", want, got)
+ }
+}
+
func TestMain(m *testing.M) {
dockerutil.EnsureSupportedDockerVersion()
flag.Parse()
diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go
deleted file mode 100644
index 84564cdaa..000000000
--- a/test/e2e/regression_test.go
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package integration
-
-import (
- "context"
- "strings"
- "testing"
-
- "gvisor.dev/gvisor/pkg/test/dockerutil"
-)
-
-// Test that UDS can be created using overlay when parent directory is in lower
-// layer only (b/134090485).
-//
-// Prerequisite: the directory where the socket file is created must not have
-// been open for write before bind(2) is called.
-func TestBindOverlay(t *testing.T) {
- ctx := context.Background()
- d := dockerutil.MakeContainer(ctx, t)
- defer d.CleanUp(ctx)
-
- // Run the container.
- got, err := d.Run(ctx, dockerutil.RunOpts{
- Image: "basic/ubuntu",
- }, "bash", "-c", "nc -q -1 -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -q 0 -U /var/run/sock && wait $p")
- if err != nil {
- t.Fatalf("docker run failed: %v", err)
- }
-
- // Check the output contains what we want.
- if want := "foobar-asdf"; !strings.Contains(got, want) {
- t.Fatalf("docker run output is missing %q: %s", want, got)
- }
-}
diff --git a/test/fsstress/BUILD b/test/fsstress/BUILD
index d262c8554..e74e7fff2 100644
--- a/test/fsstress/BUILD
+++ b/test/fsstress/BUILD
@@ -14,9 +14,7 @@ go_test(
"manual",
"local",
],
- deps = [
- "//pkg/test/dockerutil",
- ],
+ deps = ["//pkg/test/dockerutil"],
)
go_library(
diff --git a/test/fsstress/fsstress_test.go b/test/fsstress/fsstress_test.go
index 300c21ceb..d53c8f90d 100644
--- a/test/fsstress/fsstress_test.go
+++ b/test/fsstress/fsstress_test.go
@@ -17,7 +17,9 @@ package fsstress
import (
"context"
+ "flag"
"math/rand"
+ "os"
"strconv"
"strings"
"testing"
@@ -30,33 +32,44 @@ func init() {
rand.Seed(int64(time.Now().Nanosecond()))
}
-func fsstress(t *testing.T, dir string) {
+func TestMain(m *testing.M) {
+ dockerutil.EnsureSupportedDockerVersion()
+ flag.Parse()
+ os.Exit(m.Run())
+}
+
+type config struct {
+ operations string
+ processes string
+ target string
+}
+
+func fsstress(t *testing.T, conf config) {
ctx := context.Background()
d := dockerutil.MakeContainer(ctx, t)
defer d.CleanUp(ctx)
- const (
- operations = "10000"
- processes = "100"
- image = "basic/fsstress"
- )
+ const image = "basic/fsstress"
seed := strconv.FormatUint(uint64(rand.Uint32()), 10)
- args := []string{"-d", dir, "-n", operations, "-p", processes, "-s", seed, "-X"}
- t.Logf("Repro: docker run --rm --runtime=runsc %s %s", image, strings.Join(args, ""))
+ args := []string{"-d", conf.target, "-n", conf.operations, "-p", conf.processes, "-s", seed, "-X"}
+ t.Logf("Repro: docker run --rm --runtime=%s gvisor.dev/images/%s %s", dockerutil.Runtime(), image, strings.Join(args, " "))
out, err := d.Run(ctx, dockerutil.RunOpts{Image: image}, args...)
if err != nil {
t.Fatalf("docker run failed: %v\noutput: %s", err, out)
}
- lines := strings.SplitN(out, "\n", 2)
- if len(lines) > 1 || !strings.HasPrefix(out, "seed =") {
+ // This is to catch cases where fsstress spews out error messages during clean
+ // up but doesn't return error.
+ if len(out) > 0 {
t.Fatalf("unexpected output: %s", out)
}
}
-func TestFsstressGofer(t *testing.T) {
- fsstress(t, "/test")
-}
-
func TestFsstressTmpfs(t *testing.T) {
- fsstress(t, "/tmp")
+ // This takes between 10s to run on my machine. Adjust as needed.
+ cfg := config{
+ operations: "5000",
+ processes: "20",
+ target: "/tmp",
+ }
+ fsstress(t, cfg)
}
diff --git a/test/image/image_test.go b/test/image/image_test.go
index 968e62f63..952264173 100644
--- a/test/image/image_test.go
+++ b/test/image/image_test.go
@@ -183,7 +183,10 @@ func TestMysql(t *testing.T) {
// Start the container.
if err := server.Spawn(ctx, dockerutil.RunOpts{
Image: "basic/mysql",
- Env: []string{"MYSQL_ROOT_PASSWORD=foobar123"},
+ Env: []string{
+ "MYSQL_ROOT_PASSWORD=foobar123",
+ "MYSQL_ROOT_HOST=%", // Allow anyone to connect to the server.
+ },
}); err != nil {
t.Fatalf("docker run failed: %v", err)
}
diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl
index 34e83ec49..634c15727 100644
--- a/test/packetimpact/runner/defs.bzl
+++ b/test/packetimpact/runner/defs.bzl
@@ -246,6 +246,12 @@ ALL_TESTS = [
expect_netstack_failure = True,
),
PacketimpactTestInfo(
+ name = "tcp_listen_backlog",
+ ),
+ PacketimpactTestInfo(
+ name = "tcp_syncookie",
+ ),
+ PacketimpactTestInfo(
name = "icmpv6_param_problem",
),
PacketimpactTestInfo(
diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD
index 92103c1e9..83ff70951 100644
--- a/test/packetimpact/tests/BUILD
+++ b/test/packetimpact/tests/BUILD
@@ -385,6 +385,26 @@ packetimpact_testbench(
],
)
+packetimpact_testbench(
+ name = "tcp_listen_backlog",
+ srcs = ["tcp_listen_backlog_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
+packetimpact_testbench(
+ name = "tcp_syncookie",
+ srcs = ["tcp_syncookie_test.go"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//test/packetimpact/testbench",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
+
validate_all_tests()
[packetimpact_go_test(
diff --git a/test/packetimpact/tests/tcp_listen_backlog_test.go b/test/packetimpact/tests/tcp_listen_backlog_test.go
new file mode 100644
index 000000000..26c812d0a
--- /dev/null
+++ b/test/packetimpact/tests/tcp_listen_backlog_test.go
@@ -0,0 +1,86 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_listen_backlog_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+// TestTCPListenBacklog tests for a listening endpoint behavior:
+// (1) reply to more SYNs than what is configured as listen backlog
+// (2) ignore ACKs (that complete a handshake) when the accept queue is full
+// (3) ignore incoming SYNs when the accept queue is full
+func TestTCPListenBacklog(t *testing.T) {
+ dut := testbench.NewDUT(t)
+
+ // Listening endpoint accepts one more connection than the listen backlog.
+ listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 0 /*backlog*/)
+
+ var establishedConn testbench.TCPIPv4
+ var incompleteConn testbench.TCPIPv4
+
+ // Test if the DUT listener replies to more SYNs than listen backlog+1
+ for i, conn := range []*testbench.TCPIPv4{&establishedConn, &incompleteConn} {
+ *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ // Expect dut connection to have transitioned to SYN-RCVD state.
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN-ACK for %d connection, %s", i, err)
+ }
+ }
+ defer establishedConn.Close(t)
+ defer incompleteConn.Close(t)
+
+ // Send the ACK to complete handshake.
+ establishedConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ dut.PollOne(t, listenFd, unix.POLLIN, time.Second)
+
+ // Send the ACK to complete handshake, expect this to be ignored by the
+ // listener.
+ incompleteConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+
+ // Drain the accept queue to enable poll for subsequent connections on the
+ // listener.
+ dut.Accept(t, listenFd)
+
+ // The ACK for the incomplete connection should be ignored by the
+ // listening endpoint and the poll on listener should now time out.
+ if pfds := dut.Poll(t, []unix.PollFd{{Fd: listenFd, Events: unix.POLLIN}}, time.Second); len(pfds) != 0 {
+ t.Fatalf("got dut.Poll(...) = %#v", pfds)
+ }
+
+ // Re-send the ACK to complete handshake and re-fill the accept-queue.
+ incompleteConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)})
+ dut.PollOne(t, listenFd, unix.POLLIN, time.Second)
+
+ // Now initiate a new connection when the accept queue is full.
+ connectingConn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ defer connectingConn.Close(t)
+ // Expect dut connection to drop the SYN and let the client stay in SYN_SENT state.
+ connectingConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if got, err := connectingConn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err == nil {
+ t.Fatalf("expected no SYN-ACK, but got %s", got)
+ }
+}
diff --git a/test/packetimpact/tests/tcp_syncookie_test.go b/test/packetimpact/tests/tcp_syncookie_test.go
new file mode 100644
index 000000000..1c21c62ff
--- /dev/null
+++ b/test/packetimpact/tests/tcp_syncookie_test.go
@@ -0,0 +1,70 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_syncookie_test
+
+import (
+ "flag"
+ "testing"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/test/packetimpact/testbench"
+)
+
+func init() {
+ testbench.Initialize(flag.CommandLine)
+}
+
+// TestSynCookie test if the DUT listener is replying back using syn cookies.
+// The test does not complete the handshake by not sending the ACK to SYNACK.
+// When syncookies are not used, this forces the listener to retransmit SYNACK.
+// And when syncookies are being used, there is no such retransmit.
+func TestTCPSynCookie(t *testing.T) {
+ dut := testbench.NewDUT(t)
+
+ // Listening endpoint accepts one more connection than the listen backlog.
+ _, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/)
+
+ var withoutSynCookieConn testbench.TCPIPv4
+ var withSynCookieConn testbench.TCPIPv4
+
+ // Test if the DUT listener replies to more SYNs than listen backlog+1
+ for _, conn := range []*testbench.TCPIPv4{&withoutSynCookieConn, &withSynCookieConn} {
+ *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort})
+ }
+ defer withoutSynCookieConn.Close(t)
+ defer withSynCookieConn.Close(t)
+
+ checkSynAck := func(t *testing.T, conn *testbench.TCPIPv4, expectRetransmit bool) {
+ // Expect dut connection to have transitioned to SYN-RCVD state.
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)})
+ if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil {
+ t.Fatalf("expected SYN-ACK, but got %s", err)
+ }
+
+ // If the DUT listener is using syn cookies, it will not retransmit SYNACK
+ got, err := conn.ExpectData(t, &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)), Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, 2*time.Second)
+ if expectRetransmit && err != nil {
+ t.Fatalf("expected retransmitted SYN-ACK, but got %s", err)
+ }
+ if !expectRetransmit && err == nil {
+ t.Fatalf("expected no retransmitted SYN-ACK, but got %s", got)
+ }
+ }
+
+ t.Run("without syncookies", func(t *testing.T) { checkSynAck(t, &withoutSynCookieConn, true /*expectRetransmit*/) })
+ t.Run("with syncookies", func(t *testing.T) { checkSynAck(t, &withSynCookieConn, false /*expectRetransmit*/) })
+}
diff --git a/test/perf/BUILD b/test/perf/BUILD
index ed899ac22..71982fc4d 100644
--- a/test/perf/BUILD
+++ b/test/perf/BUILD
@@ -35,7 +35,7 @@ syscall_test(
)
syscall_test(
- size = "enormous",
+ size = "large",
debug = False,
tags = ["nogotsan"],
test = "//test/perf/linux:getdents_benchmark",
@@ -48,7 +48,7 @@ syscall_test(
)
syscall_test(
- size = "enormous",
+ size = "large",
debug = False,
tags = ["nogotsan"],
test = "//test/perf/linux:gettid_benchmark",
@@ -106,7 +106,7 @@ syscall_test(
)
syscall_test(
- size = "enormous",
+ size = "large",
debug = False,
test = "//test/perf/linux:signal_benchmark",
)
@@ -124,9 +124,10 @@ syscall_test(
)
syscall_test(
- size = "enormous",
+ size = "large",
add_overlay = True,
debug = False,
+ tags = ["nogotsan"],
test = "//test/perf/linux:unlink_benchmark",
)
diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc
index db74cb264..047a034bd 100644
--- a/test/perf/linux/getpid_benchmark.cc
+++ b/test/perf/linux/getpid_benchmark.cc
@@ -31,6 +31,24 @@ void BM_Getpid(benchmark::State& state) {
BENCHMARK(BM_Getpid);
+#ifdef __x86_64__
+
+#define SYSNO_STR1(x) #x
+#define SYSNO_STR(x) SYSNO_STR1(x)
+
+// BM_GetpidOpt uses the most often pattern of calling system calls:
+// mov $SYS_XXX, %eax; syscall.
+void BM_GetpidOpt(benchmark::State& state) {
+ for (auto s : state) {
+ __asm__("movl $" SYSNO_STR(SYS_getpid) ", %%eax\n"
+ "syscall\n"
+ : : : "rax", "rcx", "r11");
+ }
+}
+
+BENCHMARK(BM_GetpidOpt);
+#endif // __x86_64__
+
} // namespace
} // namespace testing
diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl
index 702522d86..2550b61a3 100644
--- a/test/runtimes/defs.bzl
+++ b/test/runtimes/defs.bzl
@@ -75,7 +75,6 @@ def runtime_test(name, **kwargs):
"local",
"manual",
],
- size = "enormous",
**kwargs
)
diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD
index ef299799e..affcae8fd 100644
--- a/test/syscalls/BUILD
+++ b/test/syscalls/BUILD
@@ -244,6 +244,10 @@ syscall_test(
)
syscall_test(
+ test = "//test/syscalls/linux:verity_ioctl_test",
+)
+
+syscall_test(
test = "//test/syscalls/linux:iptables_test",
)
@@ -318,6 +322,10 @@ syscall_test(
)
syscall_test(
+ test = "//test/syscalls/linux:verity_mount_test",
+)
+
+syscall_test(
size = "medium",
test = "//test/syscalls/linux:mremap_test",
)
@@ -772,8 +780,7 @@ syscall_test(
)
syscall_test(
- # NOTE(b/116636318): Large sendmsg may stall a long time.
- size = "enormous",
+ flaky = 1, # NOTE(b/116636318): Large sendmsg may stall a long time.
shard_count = more_shards,
test = "//test/syscalls/linux:socket_unix_dgram_local_test",
)
@@ -791,8 +798,7 @@ syscall_test(
)
syscall_test(
- # NOTE(b/116636318): Large sendmsg may stall a long time.
- size = "enormous",
+ flaky = 1, # NOTE(b/116636318): Large sendmsg may stall a long time.
shard_count = more_shards,
test = "//test/syscalls/linux:socket_unix_seqpacket_local_test",
)
@@ -995,3 +1001,7 @@ syscall_test(
syscall_test(
test = "//test/syscalls/linux:processes_test",
)
+
+syscall_test(
+ test = "//test/syscalls/linux:cgroup_test",
+)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index e565c6e77..bc2c7c0e3 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -154,7 +154,6 @@ cc_library(
defines = select_system(),
deps = default_net_util() + [
gtest,
- "//net/util:ports",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -1015,6 +1014,22 @@ cc_binary(
],
)
+cc_binary(
+ name = "verity_ioctl_test",
+ testonly = 1,
+ srcs = ["verity_ioctl.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:mount_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
cc_library(
name = "iptables_types",
testonly = 1,
@@ -1305,6 +1320,20 @@ cc_binary(
)
cc_binary(
+ name = "verity_mount_test",
+ testonly = 1,
+ srcs = ["verity_mount.cc"],
+ linkstatic = 1,
+ deps = [
+ gtest,
+ "//test/util:capability_util",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ ],
+)
+
+cc_binary(
name = "mremap_test",
testonly = 1,
srcs = ["mremap.cc"],
@@ -4206,3 +4235,24 @@ cc_binary(
"//test/util:test_util",
],
)
+
+cc_binary(
+ name = "cgroup_test",
+ testonly = 1,
+ srcs = ["cgroup.cc"],
+ linkstatic = 1,
+ deps = [
+ "//test/util:capability_util",
+ "//test/util:cgroup_util",
+ "//test/util:file_descriptor",
+ "//test/util:fs_util",
+ "@com_google_absl//absl/strings",
+ gtest,
+ "//test/util:posix_error",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ ],
+)
diff --git a/test/syscalls/linux/accept_bind.cc b/test/syscalls/linux/accept_bind.cc
index f65a14fb8..119a1466b 100644
--- a/test/syscalls/linux/accept_bind.cc
+++ b/test/syscalls/linux/accept_bind.cc
@@ -67,6 +67,42 @@ TEST_P(AllSocketPairTest, ListenDecreaseBacklog) {
SyscallSucceeds());
}
+TEST_P(AllSocketPairTest, ListenBacklogSizes_NoRandomSave) {
+ DisableSave ds;
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+
+ int type;
+ socklen_t typelen = sizeof(type);
+ EXPECT_THAT(
+ getsockopt(sockets->first_fd(), SOL_SOCKET, SO_TYPE, &type, &typelen),
+ SyscallSucceeds());
+
+ std::array<int, 3> backlogs = {-1, 0, 1};
+ for (auto& backlog : backlogs) {
+ ASSERT_THAT(listen(sockets->first_fd(), backlog), SyscallSucceeds());
+
+ int expected_accepts = backlog;
+ if (backlog < 0) {
+ expected_accepts = 1024;
+ }
+ for (int i = 0; i < expected_accepts; i++) {
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ // Connect to the listening socket.
+ const FileDescriptor client =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, type, 0));
+ ASSERT_THAT(connect(client.get(), sockets->first_addr(),
+ sockets->first_addr_size()),
+ SyscallSucceeds());
+ const FileDescriptor accepted = ASSERT_NO_ERRNO_AND_VALUE(
+ Accept(sockets->first_fd(), nullptr, nullptr));
+ }
+ }
+}
+
TEST_P(AllSocketPairTest, ListenWithoutBind) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
ASSERT_THAT(listen(sockets->first_fd(), 0), SyscallFailsWithErrno(EINVAL));
diff --git a/test/syscalls/linux/cgroup.cc b/test/syscalls/linux/cgroup.cc
new file mode 100644
index 000000000..a1006a978
--- /dev/null
+++ b/test/syscalls/linux/cgroup.cc
@@ -0,0 +1,421 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// All tests in this file rely on being about to mount and unmount cgroupfs,
+// which isn't expected to work, or be safe on a general linux system.
+
+#include <sys/mount.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/str_split.h"
+#include "test/util/capability_util.h"
+#include "test/util/cgroup_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+namespace {
+
+using ::testing::_;
+using ::testing::Ge;
+using ::testing::Gt;
+
+std::vector<std::string> known_controllers = {"cpu", "cpuset", "cpuacct",
+ "memory"};
+
+bool CgroupsAvailable() {
+ return IsRunningOnGvisor() && !IsRunningWithVFS1() &&
+ TEST_CHECK_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN));
+}
+
+TEST(Cgroup, MountSucceeds) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+ EXPECT_NO_ERRNO(c.ContainsCallingProcess());
+}
+
+TEST(Cgroup, SeparateMounts) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+
+ for (const auto& ctl : known_controllers) {
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(ctl));
+ EXPECT_NO_ERRNO(c.ContainsCallingProcess());
+ }
+}
+
+TEST(Cgroup, AllControllersImplicit) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+
+ absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ for (const auto& ctl : known_controllers) {
+ EXPECT_TRUE(cgroups_entries.contains(ctl))
+ << absl::StreamFormat("ctl=%s", ctl);
+ }
+ EXPECT_EQ(cgroups_entries.size(), known_controllers.size());
+}
+
+TEST(Cgroup, AllControllersExplicit) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("all"));
+
+ absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ for (const auto& ctl : known_controllers) {
+ EXPECT_TRUE(cgroups_entries.contains(ctl))
+ << absl::StreamFormat("ctl=%s", ctl);
+ }
+ EXPECT_EQ(cgroups_entries.size(), known_controllers.size());
+}
+
+TEST(Cgroup, ProcsAndTasks) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+ absl::flat_hash_set<pid_t> pids = ASSERT_NO_ERRNO_AND_VALUE(c.Procs());
+ absl::flat_hash_set<pid_t> tids = ASSERT_NO_ERRNO_AND_VALUE(c.Tasks());
+
+ EXPECT_GE(tids.size(), pids.size()) << "Found more processes than threads";
+
+ // Pids should be a strict subset of tids.
+ for (auto it = pids.begin(); it != pids.end(); ++it) {
+ EXPECT_TRUE(tids.contains(*it))
+ << absl::StreamFormat("Have pid %d, but no such tid", *it);
+ }
+}
+
+TEST(Cgroup, ControllersMustBeInUniqueHierarchy) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ // Hierarchy #1: all controllers.
+ Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+ // Hierarchy #2: memory.
+ //
+ // This should conflict since memory is already in hierarchy #1, and the two
+ // hierarchies have different sets of controllers, so this mount can't be a
+ // view into hierarchy #1.
+ EXPECT_THAT(m.MountCgroupfs("memory"), PosixErrorIs(EBUSY, _))
+ << "Memory controller mounted on two hierarchies";
+ EXPECT_THAT(m.MountCgroupfs("cpu"), PosixErrorIs(EBUSY, _))
+ << "CPU controller mounted on two hierarchies";
+}
+
+TEST(Cgroup, UnmountFreesControllers) {
+ SKIP_IF(!CgroupsAvailable());
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+ // All controllers are now attached to all's hierarchy. Attempting new mount
+ // with any individual controller should fail.
+ EXPECT_THAT(m.MountCgroupfs("memory"), PosixErrorIs(EBUSY, _))
+ << "Memory controller mounted on two hierarchies";
+
+ // Unmount the "all" hierarchy. This should enable any controller to be
+ // mounted on a new hierarchy again.
+ ASSERT_NO_ERRNO(m.Unmount(all));
+ EXPECT_NO_ERRNO(m.MountCgroupfs("memory"));
+ EXPECT_NO_ERRNO(m.MountCgroupfs("cpu"));
+}
+
+TEST(Cgroup, OnlyContainsControllerSpecificFiles) {
+ SKIP_IF(!CgroupsAvailable());
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup mem = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory"));
+ EXPECT_THAT(Exists(mem.Relpath("memory.usage_in_bytes")),
+ IsPosixErrorOkAndHolds(true));
+ // CPU files shouldn't exist in memory cgroups.
+ EXPECT_THAT(Exists(mem.Relpath("cpu.cfs_period_us")),
+ IsPosixErrorOkAndHolds(false));
+ EXPECT_THAT(Exists(mem.Relpath("cpu.cfs_quota_us")),
+ IsPosixErrorOkAndHolds(false));
+ EXPECT_THAT(Exists(mem.Relpath("cpu.shares")), IsPosixErrorOkAndHolds(false));
+
+ Cgroup cpu = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu"));
+ EXPECT_THAT(Exists(cpu.Relpath("cpu.cfs_period_us")),
+ IsPosixErrorOkAndHolds(true));
+ EXPECT_THAT(Exists(cpu.Relpath("cpu.cfs_quota_us")),
+ IsPosixErrorOkAndHolds(true));
+ EXPECT_THAT(Exists(cpu.Relpath("cpu.shares")), IsPosixErrorOkAndHolds(true));
+ // Memory files shouldn't exist in cpu cgroups.
+ EXPECT_THAT(Exists(cpu.Relpath("memory.usage_in_bytes")),
+ IsPosixErrorOkAndHolds(false));
+}
+
+TEST(Cgroup, InvalidController) {
+ SKIP_IF(!CgroupsAvailable());
+
+ TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string mopts = "this-controller-is-invalid";
+ EXPECT_THAT(
+ mount("none", mountpoint.path().c_str(), "cgroup", 0, mopts.c_str()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(Cgroup, MoptAllMustBeExclusive) {
+ SKIP_IF(!CgroupsAvailable());
+
+ TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string mopts = "all,cpu";
+ EXPECT_THAT(
+ mount("none", mountpoint.path().c_str(), "cgroup", 0, mopts.c_str()),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+TEST(MemoryCgroup, MemoryUsageInBytes) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory"));
+ EXPECT_THAT(c.ReadIntegerControlFile("memory.usage_in_bytes"),
+ IsPosixErrorOkAndHolds(Gt(0)));
+}
+
+TEST(CPUCgroup, ControlFilesHaveDefaultValues) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu"));
+ EXPECT_THAT(c.ReadIntegerControlFile("cpu.cfs_quota_us"),
+ IsPosixErrorOkAndHolds(-1));
+ EXPECT_THAT(c.ReadIntegerControlFile("cpu.cfs_period_us"),
+ IsPosixErrorOkAndHolds(100000));
+ EXPECT_THAT(c.ReadIntegerControlFile("cpu.shares"),
+ IsPosixErrorOkAndHolds(1024));
+}
+
+TEST(CPUAcctCgroup, CPUAcctUsage) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuacct"));
+
+ const int64_t usage =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage"));
+ const int64_t usage_user =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage_user"));
+ const int64_t usage_sys =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage_sys"));
+
+ EXPECT_GE(usage, 0);
+ EXPECT_GE(usage_user, 0);
+ EXPECT_GE(usage_sys, 0);
+
+ EXPECT_GE(usage_user + usage_sys, usage);
+}
+
+TEST(CPUAcctCgroup, CPUAcctStat) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuacct"));
+
+ std::string stat =
+ ASSERT_NO_ERRNO_AND_VALUE(c.ReadControlFile("cpuacct.stat"));
+
+ // We're expecting the contents of "cpuacct.stat" to look similar to this:
+ //
+ // user 377986
+ // system 220662
+
+ std::vector<absl::string_view> lines =
+ absl::StrSplit(stat, '\n', absl::SkipEmpty());
+ ASSERT_EQ(lines.size(), 2);
+
+ std::vector<absl::string_view> user_tokens =
+ StrSplit(lines[0], absl::ByChar(' '));
+ EXPECT_EQ(user_tokens[0], "user");
+ EXPECT_THAT(Atoi<int64_t>(user_tokens[1]), IsPosixErrorOkAndHolds(Ge(0)));
+
+ std::vector<absl::string_view> sys_tokens =
+ StrSplit(lines[1], absl::ByChar(' '));
+ EXPECT_EQ(sys_tokens[0], "system");
+ EXPECT_THAT(Atoi<int64_t>(sys_tokens[1]), IsPosixErrorOkAndHolds(Ge(0)));
+}
+
+TEST(ProcCgroups, Empty) {
+ SKIP_IF(!CgroupsAvailable());
+
+ absl::flat_hash_map<std::string, CgroupsEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ // No cgroups mounted yet, we should have no entries.
+ EXPECT_TRUE(entries.empty());
+}
+
+TEST(ProcCgroups, ProcCgroupsEntries) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+
+ Cgroup mem = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory"));
+ absl::flat_hash_map<std::string, CgroupsEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ EXPECT_EQ(entries.size(), 1);
+ ASSERT_TRUE(entries.contains("memory"));
+ CgroupsEntry mem_e = entries["memory"];
+ EXPECT_EQ(mem_e.subsys_name, "memory");
+ EXPECT_GE(mem_e.hierarchy, 1);
+ // Expect a single root cgroup.
+ EXPECT_EQ(mem_e.num_cgroups, 1);
+ // Cgroups are currently always enabled when mounted.
+ EXPECT_TRUE(mem_e.enabled);
+
+ // Add a second cgroup, and check for new entry.
+
+ Cgroup cpu = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu"));
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ EXPECT_EQ(entries.size(), 2);
+ EXPECT_TRUE(entries.contains("memory")); // Still have memory entry.
+ ASSERT_TRUE(entries.contains("cpu"));
+ CgroupsEntry cpu_e = entries["cpu"];
+ EXPECT_EQ(cpu_e.subsys_name, "cpu");
+ EXPECT_GE(cpu_e.hierarchy, 1);
+ EXPECT_EQ(cpu_e.num_cgroups, 1);
+ EXPECT_TRUE(cpu_e.enabled);
+
+ // Separate hierarchies, since controllers were mounted separately.
+ EXPECT_NE(mem_e.hierarchy, cpu_e.hierarchy);
+}
+
+TEST(ProcCgroups, UnmountRemovesEntries) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup cg = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu,memory"));
+ absl::flat_hash_map<std::string, CgroupsEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ EXPECT_EQ(entries.size(), 2);
+
+ ASSERT_NO_ERRNO(m.Unmount(cg));
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ EXPECT_TRUE(entries.empty());
+}
+
+TEST(ProcPIDCgroup, Empty) {
+ SKIP_IF(!CgroupsAvailable());
+
+ absl::flat_hash_map<std::string, PIDCgroupEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+ EXPECT_TRUE(entries.empty());
+}
+
+TEST(ProcPIDCgroup, Entries) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory"));
+
+ absl::flat_hash_map<std::string, PIDCgroupEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+ EXPECT_EQ(entries.size(), 1);
+ PIDCgroupEntry mem_e = entries["memory"];
+ EXPECT_GE(mem_e.hierarchy, 1);
+ EXPECT_EQ(mem_e.controllers, "memory");
+ EXPECT_EQ(mem_e.path, "/");
+
+ Cgroup c1 = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu"));
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+ EXPECT_EQ(entries.size(), 2);
+ EXPECT_TRUE(entries.contains("memory")); // Still have memory entry.
+ PIDCgroupEntry cpu_e = entries["cpu"];
+ EXPECT_GE(cpu_e.hierarchy, 1);
+ EXPECT_EQ(cpu_e.controllers, "cpu");
+ EXPECT_EQ(cpu_e.path, "/");
+
+ // Separate hierarchies, since controllers were mounted separately.
+ EXPECT_NE(mem_e.hierarchy, cpu_e.hierarchy);
+}
+
+TEST(ProcPIDCgroup, UnmountRemovesEntries) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(""));
+
+ absl::flat_hash_map<std::string, PIDCgroupEntry> entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+ EXPECT_GT(entries.size(), 0);
+
+ ASSERT_NO_ERRNO(m.Unmount(all));
+
+ entries = ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+ EXPECT_TRUE(entries.empty());
+}
+
+TEST(ProcCgroup, PIDCgroupMatchesCgroups) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory"));
+ Cgroup c1 = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu"));
+
+ absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+ absl::flat_hash_map<std::string, PIDCgroupEntry> pid_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+
+ CgroupsEntry cgroup_mem = cgroups_entries["memory"];
+ PIDCgroupEntry pid_mem = pid_entries["memory"];
+
+ EXPECT_EQ(cgroup_mem.hierarchy, pid_mem.hierarchy);
+
+ CgroupsEntry cgroup_cpu = cgroups_entries["cpu"];
+ PIDCgroupEntry pid_cpu = pid_entries["cpu"];
+
+ EXPECT_EQ(cgroup_cpu.hierarchy, pid_cpu.hierarchy);
+ EXPECT_NE(cgroup_mem.hierarchy, cgroup_cpu.hierarchy);
+ EXPECT_NE(pid_mem.hierarchy, pid_cpu.hierarchy);
+}
+
+TEST(ProcCgroup, MultiControllerHierarchy) {
+ SKIP_IF(!CgroupsAvailable());
+
+ Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()));
+ Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory,cpu"));
+
+ absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries());
+
+ CgroupsEntry mem_e = cgroups_entries["memory"];
+ CgroupsEntry cpu_e = cgroups_entries["cpu"];
+
+ // Both controllers should have the same hierarchy ID.
+ EXPECT_EQ(mem_e.hierarchy, cpu_e.hierarchy);
+
+ absl::flat_hash_map<std::string, PIDCgroupEntry> pid_entries =
+ ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid()));
+
+ // Expecting an entry listing both controllers, that matches the previous
+ // hierarchy ID. Note that the controllers are listed in alphabetical order.
+ PIDCgroupEntry pid_e = pid_entries["cpu,memory"];
+ EXPECT_EQ(pid_e.hierarchy, mem_e.hierarchy);
+}
+
+} // namespace
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/fpsig_fork.cc b/test/syscalls/linux/fpsig_fork.cc
index c47567b4e..79b0596c4 100644
--- a/test/syscalls/linux/fpsig_fork.cc
+++ b/test/syscalls/linux/fpsig_fork.cc
@@ -44,6 +44,8 @@ namespace {
#define SET_FP0(var) SET_FPREG(var, d0)
#endif
+#define DEFAULT_MXCSR 0x1f80
+
int parent, child;
void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
@@ -57,6 +59,12 @@ void sigusr1(int s, siginfo_t* siginfo, void* _uc) {
uint64_t got;
GET_FP0(got);
TEST_CHECK_MSG(val == got, "Basic FP check failed in sigusr1()");
+
+#ifdef __x86_64
+ uint32_t mxcsr;
+ __asm__("STMXCSR %0" : "=m"(mxcsr));
+ TEST_CHECK_MSG(mxcsr == DEFAULT_MXCSR, "Unexpected mxcsr");
+#endif
}
TEST(FPSigTest, Fork) {
@@ -125,6 +133,55 @@ TEST(FPSigTest, Fork) {
}
}
+#ifdef __x86_64__
+TEST(FPSigTest, ForkWithZeroMxcsr) {
+ parent = getpid();
+ pid_t parent_tid = gettid();
+
+ struct sigaction sa = {};
+ sigemptyset(&sa.sa_mask);
+ sa.sa_flags = SA_SIGINFO;
+ sa.sa_sigaction = sigusr1;
+ ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds());
+
+ // The control bits of the MXCSR register are callee-saved (preserved across
+ // calls), while the status bits are caller-saved (not preserved).
+ uint32_t expected = 0, origin;
+ __asm__("STMXCSR %0" : "=m"(origin));
+ __asm__("LDMXCSR %0" : : "m"(expected));
+
+ asm volatile(
+ "movl %[killnr], %%eax;"
+ "movl %[parent], %%edi;"
+ "movl %[tid], %%esi;"
+ "movl %[sig], %%edx;"
+ "syscall;"
+ :
+ : [killnr] "i"(__NR_tgkill), [parent] "rm"(parent),
+ [tid] "rm"(parent_tid), [sig] "i"(SIGUSR1)
+ : "rax", "rdi", "rsi", "rdx",
+ // Clobbered by syscall.
+ "rcx", "r11");
+
+ uint32_t got;
+ __asm__("STMXCSR %0" : "=m"(got));
+ __asm__("LDMXCSR %0" : : "m"(origin));
+
+ if (getpid() == parent) { // Parent.
+ int status;
+ ASSERT_THAT(waitpid(child, &status, 0), SyscallSucceedsWithValue(child));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+ }
+
+ // TEST_CHECK_MSG since this may run in the child.
+ TEST_CHECK_MSG(expected == got, "Bad mxcsr value");
+
+ if (getpid() != parent) { // Child.
+ _exit(0);
+ }
+}
+#endif
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
index 28f51a3bf..8c5732147 100644
--- a/test/syscalls/linux/semaphore.cc
+++ b/test/syscalls/linux/semaphore.cc
@@ -234,14 +234,6 @@ TEST(SemaphoreTest, SemTimedOpBlock) {
AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT));
ASSERT_THAT(sem.get(), SyscallSucceeds());
- ScopedThread th([&sem] {
- absl::SleepFor(absl::Milliseconds(100));
-
- struct sembuf buf = {};
- buf.sem_op = 1;
- ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds());
- });
-
struct sembuf buf = {};
buf.sem_op = -1;
struct timespec timeout = {};
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 597b5bcb1..d391363fb 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -489,13 +489,6 @@ void TestListenWhileConnect(const TestParam& param,
TestAddress const& listener = param.listener;
TestAddress const& connector = param.connector;
- constexpr int kBacklog = 2;
- // Linux completes one more connection than the listen backlog argument.
- // To ensure that there is at least one client connection that stays in
- // connecting state, keep 2 more client connections than the listen backlog.
- // gVisor differs in this behavior though, gvisor.dev/issue/3153.
- constexpr int kClients = kBacklog + 2;
-
// Create the listening socket.
FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
@@ -503,6 +496,13 @@ void TestListenWhileConnect(const TestParam& param,
ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
listener.addr_len),
SyscallSucceeds());
+ // This test is only interested in deterministically getting a socket in
+ // connecting state. For that, we use a listen backlog of zero which would
+ // mean there is exactly one connection that gets established and is enqueued
+ // to the accept queue. We poll on the listener to ensure that is enqueued.
+ // After that the subsequent client connect will stay in connecting state as
+ // the accept queue is full.
+ constexpr int kBacklog = 0;
ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
// Get the port bound by the listening socket.
@@ -515,42 +515,49 @@ void TestListenWhileConnect(const TestParam& param,
sockaddr_storage conn_addr = connector.addr;
ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- std::vector<FileDescriptor> clients;
- for (int i = 0; i < kClients; i++) {
- FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
- int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
- if (ret != 0) {
- EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- clients.push_back(std::move(client));
- }
+ FileDescriptor established_client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ connect(established_client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Ensure that the accept queue has the completed connection.
+ constexpr int kTimeout = 10000;
+ pollfd pfd = {
+ .fd = listen_fd.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN);
+
+ FileDescriptor connecting_client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ // Keep the last client in connecting state.
+ int ret =
+ connect(connecting_client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
}
stopListen(listen_fd);
- for (auto& client : clients) {
- constexpr int kTimeout = 10000;
+ std::array<std::pair<int, int>, 2> sockets = {
+ std::make_pair(established_client.get(), ECONNRESET),
+ std::make_pair(connecting_client.get(), ECONNREFUSED),
+ };
+ for (size_t i = 0; i < sockets.size(); i++) {
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ auto [fd, expected_errno] = sockets[i];
pollfd pfd = {
- .fd = client.get(),
- .events = POLLIN,
+ .fd = fd,
};
- // When the listening socket is closed, then we expect the remote to reset
- // the connection.
- ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
- ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR);
+ // When the listening socket is closed, the peer would reset the connection.
+ EXPECT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(pfd.revents, POLLHUP | POLLERR);
char c;
- // Subsequent read can fail with:
- // ECONNRESET: If the client connection was established and was reset by the
- // remote.
- // ECONNREFUSED: If the client connection failed to be established.
- ASSERT_THAT(read(client.get(), &c, sizeof(c)),
- AnyOf(SyscallFailsWithErrno(ECONNRESET),
- SyscallFailsWithErrno(ECONNREFUSED)));
- // The last client connection would be in connecting (SYN_SENT) state.
- if (client.get() == clients[kClients - 1].get()) {
- ASSERT_EQ(errno, ECONNREFUSED) << strerror(errno);
- }
+ EXPECT_THAT(read(fd, &c, sizeof(c)), SyscallFailsWithErrno(expected_errno));
}
}
@@ -570,7 +577,59 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) {
// random save as established connections which can't be delivered to the accept
// queue because the queue is full are not correctly delivered after restore
// causing the last accept to timeout on the restore.
-TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
+TEST_P(SocketInetLoopbackTest, TCPAcceptBacklogSizes_NoRandomSave) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+
+ // Create the listening socket.
+ const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ std::array<int, 3> backlogs = {-1, 0, 1};
+ for (auto& backlog : backlogs) {
+ ASSERT_THAT(listen(listen_fd.get(), backlog), SyscallSucceeds());
+
+ int expected_accepts;
+ if (backlog < 0) {
+ expected_accepts = 1024;
+ } else {
+ expected_accepts = backlog + 1;
+ }
+ for (int i = 0; i < expected_accepts; i++) {
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ // Connect to the listening socket.
+ const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ ASSERT_THAT(
+ RetryEINTR(connect)(conn_fd.get(),
+ reinterpret_cast<struct sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ const FileDescriptor accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ }
+ }
+}
+
+// TODO(b/157236388): Remove _NoRandomSave once bug is fixed. Test fails w/
+// random save as established connections which can't be delivered to the accept
+// queue because the queue is full are not correctly delivered after restore
+// causing the last accept to timeout on the restore.
+TEST_P(SocketInetLoopbackTest, TCPBacklog_NoRandomSave) {
auto const& param = GetParam();
TestAddress const& listener = param.listener;
@@ -595,6 +654,7 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
int i = 0;
while (1) {
+ SCOPED_TRACE(absl::StrCat("i=", i));
int ret;
// Connect to the listening socket.
@@ -620,103 +680,133 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
i++;
}
+ int client_conns = i;
+ int accepted_conns = 0;
for (; i != 0; i--) {
- // Accept the connection.
- //
- // We have to assign a name to the accepted socket, as unamed temporary
- // objects are destructed upon full evaluation of the expression it is in,
- // potentially causing the connecting socket to fail to shutdown properly.
- auto accepted =
- ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ pollfd pfd = {
+ .fd = listen_fd.get(),
+ .events = POLLIN,
+ };
+ // Look for incoming connections to accept. The last connect request could
+ // be established from the client side, but the ACK of the handshake could
+ // be dropped by the listener if the accept queue was filled up by the
+ // previous connect.
+ int ret;
+ ASSERT_THAT(ret = poll(&pfd, 1, 3000), SyscallSucceeds());
+ if (ret == 0) break;
+ if (pfd.revents == POLLIN) {
+ // Accept the connection.
+ //
+ // We have to assign a name to the accepted socket, as unamed temporary
+ // objects are destructed upon full evaluation of the expression it is in,
+ // potentially causing the connecting socket to fail to shutdown properly.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ accepted_conns++;
+ }
}
+ // We should accept at least listen backlog + 1 connections. As the stack is
+ // enqueuing established connections to the accept queue, newer SYNs could
+ // still be replied to causing those client connections would be accepted as
+ // we start dequeuing the queue.
+ ASSERT_GE(accepted_conns, kBacklogSize + 1);
+ ASSERT_GE(client_conns, accepted_conns);
}
-// Test if the stack completes atmost listen backlog number of client
-// connections. It exercises the path of the stack that enqueues completed
-// connections to accept queue vs new incoming SYNs.
-TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) {
- const auto& param = GetParam();
- const TestAddress& listener = param.listener;
- const TestAddress& connector = param.connector;
+// TODO(b/157236388): Remove _NoRandomSave once bug is fixed. Test fails w/
+// random save as established connections which can't be delivered to the accept
+// queue because the queue is full are not correctly delivered after restore
+// causing the last accept to timeout on the restore.
+TEST_P(SocketInetLoopbackTest, TCPBacklogAcceptAll_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
constexpr int kBacklog = 1;
- // Keep the number of client connections more than the listen backlog.
- // Linux completes one more connection than the listen backlog argument.
- // gVisor differs in this behavior though, gvisor.dev/issue/3153.
- int kClients = kBacklog + 2;
- if (IsRunningOnGvisor()) {
- kClients--;
- }
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
- // Run the following test for few iterations to test race between accept queue
- // getting filled with incoming SYNs.
- for (int num = 0; num < 10; num++) {
- FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
- sockaddr_storage listen_addr = listener.addr;
- ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- listener.addr_len),
- SyscallSucceeds());
- ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(getsockname(listen_fd.get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
- socklen_t addrlen = listener.addr_len;
- ASSERT_THAT(
- getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
- &addrlen),
- SyscallSucceeds());
- uint16_t const port =
- ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
- sockaddr_storage conn_addr = connector.addr;
- ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ sockaddr_storage conn_addr = connector.addr;
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- std::vector<FileDescriptor> clients;
- // Issue multiple non-blocking client connects.
- for (int i = 0; i < kClients; i++) {
- FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
- int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
- if (ret != 0) {
- EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- }
- clients.push_back(std::move(client));
+ // Fill up the accept queue and trigger more client connections which would be
+ // waiting to be accepted.
+ std::array<FileDescriptor, kBacklog + 1> established_clients;
+ for (auto& fd : established_clients) {
+ fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(connect(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+ }
+ std::array<FileDescriptor, kBacklog> waiting_clients;
+ for (auto& fd : waiting_clients) {
+ fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ int ret = connect(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
}
+ }
- // Now that client connects are issued, wait for the accept queue to get
- // filled and ensure no new client connection is completed.
- for (int i = 0; i < kClients; i++) {
- pollfd pfd = {
- .fd = clients[i].get(),
- .events = POLLOUT,
- };
- if (i < kClients - 1) {
- // Poll for client side connection completions with a large timeout.
- // We cannot poll on the listener side without calling accept as poll
- // stays level triggered with non-zero accept queue length.
- //
- // Client side poll would not guarantee that the completed connection
- // has been enqueued in to the acccept queue, but the fact that the
- // listener ACKd the SYN, means that it cannot complete any new incoming
- // SYNs when it has already ACKd for > backlog number of SYNs.
- ASSERT_THAT(poll(&pfd, 1, 10000), SyscallSucceedsWithValue(1))
- << "num=" << num << " i=" << i << " kClients=" << kClients;
- ASSERT_EQ(pfd.revents, POLLOUT) << "num=" << num << " i=" << i;
- } else {
- // Now that we expect accept queue filled up, ensure that the last
- // client connection never completes with a smaller poll timeout.
- ASSERT_THAT(poll(&pfd, 1, 1000), SyscallSucceedsWithValue(0))
- << "num=" << num << " i=" << i;
- }
+ auto accept_connection = [&]() {
+ constexpr int kTimeout = 10000;
+ pollfd pfd = {
+ .fd = listen_fd.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN);
+ // Accept the connection.
+ //
+ // We have to assign a name to the accepted socket, as unamed temporary
+ // objects are destructed upon full evaluation of the expression it is in,
+ // potentially causing the connecting socket to fail to shutdown properly.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ };
- ASSERT_THAT(close(clients[i].release()), SyscallSucceedsWithValue(0))
- << "num=" << num << " i=" << i;
- }
- clients.clear();
- // We close the listening side and open a new listener. We could instead
- // drain the accept queue by calling accept() and reuse the listener, but
- // that is racy as the retransmitted SYNs could get ACKd as we make room in
- // the accept queue.
- ASSERT_THAT(close(listen_fd.release()), SyscallSucceedsWithValue(0));
+ // Ensure that we accept all client connections. The waiting connections would
+ // get enqueued as we drain the accept queue.
+ for (int i = 0; i < std::size(established_clients); i++) {
+ SCOPED_TRACE(absl::StrCat("established clients i=", i));
+ accept_connection();
+ }
+
+ // The waiting client connections could be in one of these 2 states:
+ // (1) SYN_SENT: if the SYN was dropped because accept queue was full
+ // (2) ESTABLISHED: if the listener sent back a SYNACK, but may have dropped
+ // the ACK from the client if the accept queue was full (send out a data to
+ // re-send that ACK, to address that case).
+ for (int i = 0; i < std::size(waiting_clients); i++) {
+ SCOPED_TRACE(absl::StrCat("waiting clients i=", i));
+ constexpr int kTimeout = 10000;
+ pollfd pfd = {
+ .fd = waiting_clients[i].get(),
+ .events = POLLOUT,
+ };
+ EXPECT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(pfd.revents, POLLOUT);
+ char c;
+ EXPECT_THAT(RetryEINTR(send)(waiting_clients[i].get(), &c, sizeof(c), 0),
+ SyscallSucceedsWithValue(sizeof(c)));
+ accept_connection();
}
}
diff --git a/test/syscalls/linux/verity_ioctl.cc b/test/syscalls/linux/verity_ioctl.cc
new file mode 100644
index 000000000..dcd28f2c3
--- /dev/null
+++ b/test/syscalls/linux/verity_ioctl.cc
@@ -0,0 +1,133 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/mount.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/fs_util.h"
+#include "test/util/mount_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#ifndef FS_IOC_ENABLE_VERITY
+#define FS_IOC_ENABLE_VERITY 1082156677
+#endif
+
+#ifndef FS_IOC_MEASURE_VERITY
+#define FS_IOC_MEASURE_VERITY 3221513862
+#endif
+
+#ifndef FS_VERITY_FL
+#define FS_VERITY_FL 1048576
+#endif
+
+#ifndef FS_IOC_GETFLAGS
+#define FS_IOC_GETFLAGS 2148034049
+#endif
+
+struct fsverity_digest {
+ __u16 digest_algorithm;
+ __u16 digest_size; /* input/output */
+ __u8 digest[];
+};
+
+const int fsverity_max_digest_size = 64;
+const int fsverity_default_digest_size = 32;
+
+class IoctlTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ // Verity is implemented in VFS2.
+ SKIP_IF(IsRunningWithVFS1());
+
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ // Mount a tmpfs file system, to be wrapped by a verity fs.
+ tmpfs_dir_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(mount("", tmpfs_dir_.path().c_str(), "tmpfs", 0, ""),
+ SyscallSucceeds());
+
+ // Create a new file in the tmpfs mount.
+ constexpr char kContents[] = "foobarbaz";
+ file_ = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(tmpfs_dir_.path(), kContents, 0777));
+ filename_ = Basename(file_.path());
+ }
+
+ TempPath tmpfs_dir_;
+ TempPath file_;
+ std::string filename_;
+};
+
+TEST_F(IoctlTest, Enable) {
+ // mount a verity fs on the existing tmpfs mount.
+ std::string mount_opts = "lower_path=" + tmpfs_dir_.path();
+ auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(
+ mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str()),
+ SyscallSucceeds());
+
+ printf("verity path: %s, filename: %s\n", verity_dir.path().c_str(),
+ filename_.c_str());
+ fflush(nullptr);
+ // Confirm that the verity flag is absent.
+ int flag = 0;
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(JoinPath(verity_dir.path(), filename_), O_RDONLY, 0777));
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_GETFLAGS, &flag), SyscallSucceeds());
+ EXPECT_EQ(flag & FS_VERITY_FL, 0);
+
+ // Enable the file and confirm that the verity flag is present.
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_ENABLE_VERITY), SyscallSucceeds());
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_GETFLAGS, &flag), SyscallSucceeds());
+ EXPECT_EQ(flag & FS_VERITY_FL, FS_VERITY_FL);
+}
+
+TEST_F(IoctlTest, Measure) {
+ // mount a verity fs on the existing tmpfs mount.
+ std::string mount_opts = "lower_path=" + tmpfs_dir_.path();
+ auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(
+ mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str()),
+ SyscallSucceeds());
+
+ // Confirm that the file cannot be measured.
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(JoinPath(verity_dir.path(), filename_), O_RDONLY, 0777));
+ int digest_size = sizeof(struct fsverity_digest) + fsverity_max_digest_size;
+ struct fsverity_digest *digest =
+ reinterpret_cast<struct fsverity_digest *>(malloc(digest_size));
+ memset(digest, 0, digest_size);
+ digest->digest_size = fsverity_max_digest_size;
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_MEASURE_VERITY, digest),
+ SyscallFailsWithErrno(ENODATA));
+
+ // Enable the file and confirm that the file can be measured.
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_ENABLE_VERITY), SyscallSucceeds());
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_MEASURE_VERITY, digest),
+ SyscallSucceeds());
+ EXPECT_EQ(digest->digest_size, fsverity_default_digest_size);
+ free(digest);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/verity_mount.cc b/test/syscalls/linux/verity_mount.cc
new file mode 100644
index 000000000..e73dd5599
--- /dev/null
+++ b/test/syscalls/linux/verity_mount.cc
@@ -0,0 +1,53 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/mount.h>
+
+#include <iomanip>
+#include <sstream>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Mount verity file system on an existing gofer mount.
+TEST(MountTest, MountExisting) {
+ // Verity is implemented in VFS2.
+ SKIP_IF(IsRunningWithVFS1());
+
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ // Mount a new tmpfs file system.
+ auto const tmpfs_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(mount("", tmpfs_dir.path().c_str(), "tmpfs", 0, ""),
+ SyscallSucceeds());
+
+ // Mount a verity file system on the existing gofer mount.
+ auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string opts = "lower_path=" + tmpfs_dir.path();
+ EXPECT_THAT(mount("", verity_dir.path().c_str(), "verity", 0, opts.c_str()),
+ SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/BUILD b/test/util/BUILD
index e561f3daa..383de00ed 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -94,6 +94,7 @@ cc_library(
":file_descriptor",
":posix_error",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
gtest,
],
)
@@ -368,3 +369,20 @@ cc_library(
testonly = 1,
hdrs = ["temp_umask.h"],
)
+
+cc_library(
+ name = "cgroup_util",
+ testonly = 1,
+ srcs = ["cgroup_util.cc"],
+ hdrs = ["cgroup_util.h"],
+ deps = [
+ ":cleanup",
+ ":fs_util",
+ ":mount_util",
+ ":posix_error",
+ ":temp_path",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/test/util/cgroup_util.cc b/test/util/cgroup_util.cc
new file mode 100644
index 000000000..65d9c4986
--- /dev/null
+++ b/test/util/cgroup_util.cc
@@ -0,0 +1,223 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "test/util/cgroup_util.h"
+
+#include <sys/syscall.h>
+#include <unistd.h>
+
+#include "absl/strings/str_split.h"
+#include "test/util/fs_util.h"
+#include "test/util/mount_util.h"
+
+namespace gvisor {
+namespace testing {
+
+Cgroup::Cgroup(std::string path) : cgroup_path_(path) {
+ id_ = ++Cgroup::next_id_;
+ std::cerr << absl::StreamFormat("[cg#%d] <= %s", id_, cgroup_path_)
+ << std::endl;
+}
+
+PosixErrorOr<std::string> Cgroup::ReadControlFile(
+ absl::string_view name) const {
+ std::string buf;
+ RETURN_IF_ERRNO(GetContents(Relpath(name), &buf));
+
+ const std::string alias_path = absl::StrFormat("[cg#%d]/%s", id_, name);
+ std::cerr << absl::StreamFormat("<contents of %s>", alias_path) << std::endl;
+ std::cerr << buf;
+ std::cerr << absl::StreamFormat("<end of %s>", alias_path) << std::endl;
+
+ return buf;
+}
+
+PosixErrorOr<int64_t> Cgroup::ReadIntegerControlFile(
+ absl::string_view name) const {
+ ASSIGN_OR_RETURN_ERRNO(const std::string buf, ReadControlFile(name));
+ ASSIGN_OR_RETURN_ERRNO(const int64_t val, Atoi<int64_t>(buf));
+ return val;
+}
+
+PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::Procs() const {
+ ASSIGN_OR_RETURN_ERRNO(std::string buf, ReadControlFile("cgroup.procs"));
+ return ParsePIDList(buf);
+}
+
+PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::Tasks() const {
+ ASSIGN_OR_RETURN_ERRNO(std::string buf, ReadControlFile("tasks"));
+ return ParsePIDList(buf);
+}
+
+PosixError Cgroup::ContainsCallingProcess() const {
+ ASSIGN_OR_RETURN_ERRNO(const absl::flat_hash_set<pid_t> procs, Procs());
+ ASSIGN_OR_RETURN_ERRNO(const absl::flat_hash_set<pid_t> tasks, Tasks());
+ const pid_t pid = getpid();
+ const pid_t tid = syscall(SYS_gettid);
+ if (!procs.contains(pid)) {
+ return PosixError(
+ ENOENT, absl::StrFormat("Cgroup doesn't contain process %d", pid));
+ }
+ if (!tasks.contains(tid)) {
+ return PosixError(ENOENT,
+ absl::StrFormat("Cgroup doesn't contain task %d", tid));
+ }
+ return NoError();
+}
+
+PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::ParsePIDList(
+ absl::string_view data) const {
+ absl::flat_hash_set<pid_t> res;
+ std::vector<absl::string_view> lines = absl::StrSplit(data, '\n');
+ for (const std::string_view& line : lines) {
+ if (line.empty()) {
+ continue;
+ }
+ ASSIGN_OR_RETURN_ERRNO(const int32_t pid, Atoi<int32_t>(line));
+ res.insert(static_cast<pid_t>(pid));
+ }
+ return res;
+}
+
+int64_t Cgroup::next_id_ = 0;
+
+PosixErrorOr<Cgroup> Mounter::MountCgroupfs(std::string mopts) {
+ ASSIGN_OR_RETURN_ERRNO(TempPath mountpoint,
+ TempPath::CreateDirIn(root_.path()));
+ ASSIGN_OR_RETURN_ERRNO(
+ Cleanup mount, Mount("none", mountpoint.path(), "cgroup", 0, mopts, 0));
+ const std::string mountpath = mountpoint.path();
+ std::cerr << absl::StreamFormat(
+ "Mount(\"none\", \"%s\", \"cgroup\", 0, \"%s\", 0) => OK",
+ mountpath, mopts)
+ << std::endl;
+ Cgroup cg = Cgroup(mountpath);
+ mountpoints_[cg.id()] = std::move(mountpoint);
+ mounts_[cg.id()] = std::move(mount);
+ return cg;
+}
+
+PosixError Mounter::Unmount(const Cgroup& c) {
+ auto mount = mounts_.find(c.id());
+ auto mountpoint = mountpoints_.find(c.id());
+
+ if (mount == mounts_.end() || mountpoint == mountpoints_.end()) {
+ return PosixError(
+ ESRCH, absl::StrFormat("No mount found for cgroupfs containing cg#%d",
+ c.id()));
+ }
+
+ std::cerr << absl::StreamFormat("Unmount([cg#%d])", c.id()) << std::endl;
+
+ // Simply delete the entries, their destructors will unmount and delete the
+ // mountpoint. Note the order is important to avoid errors: mount then
+ // mountpoint.
+ mounts_.erase(mount);
+ mountpoints_.erase(mountpoint);
+
+ return NoError();
+}
+
+constexpr char kProcCgroupsHeader[] =
+ "#subsys_name\thierarchy\tnum_cgroups\tenabled";
+
+PosixErrorOr<absl::flat_hash_map<std::string, CgroupsEntry>>
+ProcCgroupsEntries() {
+ std::string content;
+ RETURN_IF_ERRNO(GetContents("/proc/cgroups", &content));
+
+ bool found_header = false;
+ absl::flat_hash_map<std::string, CgroupsEntry> entries;
+ std::vector<std::string> lines = absl::StrSplit(content, '\n');
+ std::cerr << "<contents of /proc/cgroups>" << std::endl;
+ for (const std::string& line : lines) {
+ std::cerr << line << std::endl;
+
+ if (!found_header) {
+ EXPECT_EQ(line, kProcCgroupsHeader);
+ found_header = true;
+ continue;
+ }
+ if (line.empty()) {
+ continue;
+ }
+
+ // Parse a single entry from /proc/cgroups.
+ //
+ // Example entries, fields are tab separated in the real file:
+ //
+ // #subsys_name hierarchy num_cgroups enabled
+ // cpuset 12 35 1
+ // cpu 3 222 1
+ // ^ ^ ^ ^
+ // 0 1 2 3
+
+ CgroupsEntry entry;
+ std::vector<std::string> fields =
+ StrSplit(line, absl::ByAnyChar(": \t"), absl::SkipEmpty());
+
+ entry.subsys_name = fields[0];
+ ASSIGN_OR_RETURN_ERRNO(entry.hierarchy, Atoi<uint32_t>(fields[1]));
+ ASSIGN_OR_RETURN_ERRNO(entry.num_cgroups, Atoi<uint64_t>(fields[2]));
+ ASSIGN_OR_RETURN_ERRNO(const int enabled, Atoi<int>(fields[3]));
+ entry.enabled = enabled != 0;
+
+ entries[entry.subsys_name] = entry;
+ }
+ std::cerr << "<end of /proc/cgroups>" << std::endl;
+
+ return entries;
+}
+
+PosixErrorOr<absl::flat_hash_map<std::string, PIDCgroupEntry>>
+ProcPIDCgroupEntries(pid_t pid) {
+ const std::string path = absl::StrFormat("/proc/%d/cgroup", pid);
+ std::string content;
+ RETURN_IF_ERRNO(GetContents(path, &content));
+
+ absl::flat_hash_map<std::string, PIDCgroupEntry> entries;
+ std::vector<std::string> lines = absl::StrSplit(content, '\n');
+
+ std::cerr << absl::StreamFormat("<contents of %s>", path) << std::endl;
+ for (const std::string& line : lines) {
+ std::cerr << line << std::endl;
+
+ if (line.empty()) {
+ continue;
+ }
+
+ // Parse a single entry from /proc/<pid>/cgroup.
+ //
+ // Example entries:
+ //
+ // 2:cpu:/path/to/cgroup
+ // 1:memory:/
+
+ PIDCgroupEntry entry;
+ std::vector<std::string> fields =
+ absl::StrSplit(line, absl::ByChar(':'), absl::SkipEmpty());
+
+ ASSIGN_OR_RETURN_ERRNO(entry.hierarchy, Atoi<uint32_t>(fields[0]));
+ entry.controllers = fields[1];
+ entry.path = fields[2];
+
+ entries[entry.controllers] = entry;
+ }
+ std::cerr << absl::StreamFormat("<end of %s>", path) << std::endl;
+
+ return entries;
+}
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/cgroup_util.h b/test/util/cgroup_util.h
new file mode 100644
index 000000000..b049559df
--- /dev/null
+++ b/test/util/cgroup_util.h
@@ -0,0 +1,111 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_CGROUP_UTIL_H_
+#define GVISOR_TEST_UTIL_CGROUP_UTIL_H_
+
+#include <unistd.h>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "test/util/cleanup.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+
+namespace gvisor {
+namespace testing {
+
+// Cgroup represents a cgroup directory on a mounted cgroupfs.
+class Cgroup {
+ public:
+ Cgroup(std::string path);
+
+ uint64_t id() const { return id_; }
+
+ std::string Relpath(absl::string_view leaf) const {
+ return JoinPath(cgroup_path_, leaf);
+ }
+
+ // Returns the contents of a cgroup control file with the given name.
+ PosixErrorOr<std::string> ReadControlFile(absl::string_view name) const;
+
+ // Reads the contents of a cgroup control with the given name, and attempts
+ // to parse it as an integer.
+ PosixErrorOr<int64_t> ReadIntegerControlFile(absl::string_view name) const;
+
+ // Returns the thread ids of the leaders of thread groups managed by this
+ // cgroup.
+ PosixErrorOr<absl::flat_hash_set<pid_t>> Procs() const;
+
+ PosixErrorOr<absl::flat_hash_set<pid_t>> Tasks() const;
+
+ // ContainsCallingProcess checks whether the calling process is part of the
+ PosixError ContainsCallingProcess() const;
+
+ private:
+ PosixErrorOr<absl::flat_hash_set<pid_t>> ParsePIDList(
+ absl::string_view data) const;
+
+ static int64_t next_id_;
+ int64_t id_;
+ const std::string cgroup_path_;
+};
+
+// Mounter is a utility for creating cgroupfs mounts. It automatically manages
+// the lifetime of created mounts.
+class Mounter {
+ public:
+ Mounter(TempPath root) : root_(std::move(root)) {}
+
+ PosixErrorOr<Cgroup> MountCgroupfs(std::string mopts);
+
+ PosixError Unmount(const Cgroup& c);
+
+ private:
+ // The destruction order of these members avoids errors during cleanup. We
+ // first unmount (by executing the mounts_ cleanups), then delete the
+ // mountpoint subdirs, then delete the root.
+ TempPath root_;
+ absl::flat_hash_map<int64_t, TempPath> mountpoints_;
+ absl::flat_hash_map<int64_t, Cleanup> mounts_;
+};
+
+// Represents a line from /proc/cgroups.
+struct CgroupsEntry {
+ std::string subsys_name;
+ uint32_t hierarchy;
+ uint64_t num_cgroups;
+ bool enabled;
+};
+
+// Returns a parsed representation of /proc/cgroups.
+PosixErrorOr<absl::flat_hash_map<std::string, CgroupsEntry>>
+ProcCgroupsEntries();
+
+// Represents a line from /proc/<pid>/cgroup.
+struct PIDCgroupEntry {
+ uint32_t hierarchy;
+ std::string controllers;
+ std::string path;
+};
+
+// Returns a parsed representation of /proc/<pid>/cgroup.
+PosixErrorOr<absl::flat_hash_map<std::string, PIDCgroupEntry>>
+ProcPIDCgroupEntries(pid_t pid);
+
+} // namespace testing
+} // namespace gvisor
+
+#endif // GVISOR_TEST_UTIL_CGROUP_UTIL_H_
diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc
index 5f1ce0d8a..483ae848d 100644
--- a/test/util/fs_util.cc
+++ b/test/util/fs_util.cc
@@ -28,6 +28,8 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "test/util/cleanup.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
@@ -366,6 +368,48 @@ PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath,
return files;
}
+PosixError DirContains(absl::string_view path,
+ const std::vector<std::string>& expect,
+ const std::vector<std::string>& exclude) {
+ ASSIGN_OR_RETURN_ERRNO(auto listing, ListDir(path, false));
+
+ for (auto& expected_entry : expect) {
+ auto cursor = std::find(listing.begin(), listing.end(), expected_entry);
+ if (cursor == listing.end()) {
+ return PosixError(ENOENT, absl::StrFormat("Failed to find '%s' in '%s'",
+ expected_entry, path));
+ }
+ }
+ for (auto& excluded_entry : exclude) {
+ auto cursor = std::find(listing.begin(), listing.end(), excluded_entry);
+ if (cursor != listing.end()) {
+ return PosixError(ENOENT, absl::StrCat("File '", excluded_entry,
+ "' found in path '", path, "'"));
+ }
+ }
+ return NoError();
+}
+
+PosixError EventuallyDirContains(absl::string_view path,
+ const std::vector<std::string>& expect,
+ const std::vector<std::string>& exclude) {
+ constexpr int kRetryCount = 100;
+ const absl::Duration kRetryDelay = absl::Milliseconds(100);
+
+ for (int i = 0; i < kRetryCount; ++i) {
+ auto res = DirContains(path, expect, exclude);
+ if (res.ok()) {
+ return res;
+ }
+ if (i < kRetryCount - 1) {
+ // Sleep if this isn't the final iteration.
+ absl::SleepFor(kRetryDelay);
+ }
+ }
+ return PosixError(ETIMEDOUT,
+ "Timed out while waiting for directory to contain files ");
+}
+
PosixError RecursivelyDelete(absl::string_view path, int* undeleted_dirs,
int* undeleted_files) {
ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(path));
diff --git a/test/util/fs_util.h b/test/util/fs_util.h
index 2190c3bca..bb2d1d3c8 100644
--- a/test/util/fs_util.h
+++ b/test/util/fs_util.h
@@ -129,6 +129,18 @@ PosixError WalkTree(
PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath,
bool skipdots);
+// Check that a directory contains children nodes named in expect, and does not
+// contain any children nodes named in exclude.
+PosixError DirContains(absl::string_view path,
+ const std::vector<std::string>& expect,
+ const std::vector<std::string>& exclude);
+
+// Same as DirContains, but adds a retry. Suitable for checking a directory
+// being modified asynchronously.
+PosixError EventuallyDirContains(absl::string_view path,
+ const std::vector<std::string>& expect,
+ const std::vector<std::string>& exclude);
+
// Attempt to recursively delete a directory or file. Returns an error and
// the number of undeleted directories and files. If either
// undeleted_dirs or undeleted_files is nullptr then it will not be used.
diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go
index 8b4bff3b6..2b3c03fec 100644
--- a/tools/nogo/analyzers.go
+++ b/tools/nogo/analyzers.go
@@ -83,11 +83,6 @@ var AllAnalyzers = []*analysis.Analyzer{
checklocks.Analyzer,
}
-// EscapeAnalyzers is a list of escape-related analyzers.
-var EscapeAnalyzers = []*analysis.Analyzer{
- checkescape.EscapeAnalyzer,
-}
-
func register(all []*analysis.Analyzer) {
// Register all fact types.
//
@@ -129,5 +124,4 @@ func init() {
// Register lists.
register(AllAnalyzers)
- register(EscapeAnalyzers)
}
diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go
index 69bdfe502..4194770be 100644
--- a/tools/nogo/check/main.go
+++ b/tools/nogo/check/main.go
@@ -31,7 +31,6 @@ var (
stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)")
findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)")
factsOutput = flag.String("facts", "", "output file for facts (optional)")
- escapesOutput = flag.String("escapes", "", "output file for escapes (optional)")
)
func loadConfig(file string, config interface{}) interface{} {
@@ -66,25 +65,13 @@ func main() {
// Run the configuration.
if *stdlibFile != "" {
- // Perform basic analysis.
+ // Perform stdlib analysis.
c := loadConfig(*stdlibFile, new(nogo.StdlibConfig)).(*nogo.StdlibConfig)
findings, factData, err = nogo.CheckStdlib(c, nogo.AllAnalyzers)
-
} else if *packageFile != "" {
- // Perform basic analysis.
+ // Perform standard analysis.
c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig)
findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil)
-
- // Do we need to do escape analysis?
- if *escapesOutput != "" {
- escapes, _, err := nogo.CheckPackage(c, nogo.EscapeAnalyzers, nil)
- if err != nil {
- log.Fatalf("error performing escape analysis: %v", err)
- }
- if err := nogo.WriteFindingsToFile(escapes, *escapesOutput); err != nil {
- log.Fatalf("error writing escapes to %q: %v", *escapesOutput, err)
- }
- }
} else {
log.Fatalf("please provide at least one of package or stdlib!")
}
diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl
index 0c48a7a5a..cb407a736 100644
--- a/tools/nogo/defs.bzl
+++ b/tools/nogo/defs.bzl
@@ -174,7 +174,6 @@ NogoInfo = provider(
fields = {
"facts": "serialized package facts",
"raw_findings": "raw package findings (if relevant)",
- "escapes": "escape-only findings (if relevant)",
"importpath": "package import path",
"binaries": "package binary files",
"srcs": "srcs (for go_test support)",
@@ -281,7 +280,6 @@ def _nogo_aspect_impl(target, ctx):
go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch)
facts = ctx.actions.declare_file(target.label.name + ".facts")
raw_findings = ctx.actions.declare_file(target.label.name + ".raw_findings")
- escapes = ctx.actions.declare_file(target.label.name + ".escapes")
config = struct(
ImportPath = importpath,
GoFiles = [src.path for src in srcs if src.path.endswith(".go")],
@@ -298,7 +296,7 @@ def _nogo_aspect_impl(target, ctx):
inputs.append(config_file)
ctx.actions.run(
inputs = inputs,
- outputs = [facts, raw_findings, escapes],
+ outputs = [facts, raw_findings],
tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool),
executable = ctx.files._nogo_check[0],
mnemonic = "NogoAnalysis",
@@ -309,7 +307,6 @@ def _nogo_aspect_impl(target, ctx):
"-package=%s" % config_file.path,
"-findings=%s" % raw_findings.path,
"-facts=%s" % facts.path,
- "-escapes=%s" % escapes.path,
],
)
@@ -322,15 +319,16 @@ def _nogo_aspect_impl(target, ctx):
all_raw_findings = [stdlib_info.raw_findings] + depset(all_raw_findings).to_list() + [raw_findings]
# Return the package facts as output.
- return [NogoInfo(
- facts = facts,
- raw_findings = all_raw_findings,
- escapes = escapes,
- importpath = importpath,
- binaries = binaries,
- srcs = srcs,
- deps = deps,
- )]
+ return [
+ NogoInfo(
+ facts = facts,
+ raw_findings = all_raw_findings,
+ importpath = importpath,
+ binaries = binaries,
+ srcs = srcs,
+ deps = deps,
+ ),
+ ]
nogo_aspect = go_rule(
aspect,
@@ -367,7 +365,6 @@ def _nogo_test_impl(ctx):
if len(ctx.attr.deps) != 1:
fail("nogo_test requires exactly one dep.")
raw_findings = ctx.attr.deps[0][NogoInfo].raw_findings
- escapes = ctx.attr.deps[0][NogoInfo].escapes
# Build a step that applies the configuration.
config_srcs = ctx.attr.config[NogoConfigInfo].srcs
@@ -409,8 +406,6 @@ def _nogo_test_impl(ctx):
# pays attention to the mnemoic above, so this must be
# what is expected by the tooling.
nogo_findings = depset([findings]),
- # Expose all escape analysis findings (see above).
- nogo_escapes = depset([escapes]),
)]
nogo_test = rule(
@@ -432,3 +427,18 @@ nogo_test = rule(
},
test = True,
)
+
+def _nogo_aspect_tricorder_impl(target, ctx):
+ if ctx.rule.kind != "nogo_test" or OutputGroupInfo not in target:
+ return []
+ if not hasattr(target[OutputGroupInfo], "nogo_findings"):
+ return []
+ return [
+ OutputGroupInfo(tricorder = target[OutputGroupInfo].nogo_findings),
+ ]
+
+# Trivial aspect that forwards the findings from a nogo_test rule to
+# go/tricorder, which reads from the `tricorder` output group.
+nogo_aspect_tricorder = aspect(
+ implementation = _nogo_aspect_tricorder_impl,
+)