134 files changed, 5007 insertions, 829 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml
index 3bc5041c0..9163db56d 100644
--- a/.buildkite/pipeline.yaml
+++ b/.buildkite/pipeline.yaml
@@ -90,6 +90,9 @@ steps:
     label: ":person_in_lotus_position: KVM tests"
     command: make kvm-tests
   - <<: *common
+    label: ":weight_lifter: Fsstress test"
+    command: make fsstress-test
+  - <<: *common
     label: ":docker: Containerd 1.3.9 tests"
     command: make containerd-test-1.3.9
   - <<: *common
diff --git a/Makefile b/Makefile
--- a/Makefile
+++ b/Makefile
@@ -144,6 +144,7 @@ dev: $(RUNTIME_BIN) ## Installs a set of local runtimes. Requires sudo.
 	@$(call configure_noreload,$(RUNTIME)-p,--net-raw --profile)
 	@$(call configure_noreload,$(RUNTIME)-vfs2-d,--net-raw --debug --strace --log-packets --vfs2)
 	@$(call configure_noreload,$(RUNTIME)-vfs2-fuse-d,--net-raw --debug --strace --log-packets --vfs2 --fuse)
+	@$(call configure_noreload,$(RUNTIME)-vfs2-cgroup-d,--net-raw --debug --strace --log-packets --vfs2 --cgroupfs)
 	@$(call reload_docker)
 .PHONY: dev
@@ -340,7 +341,8 @@ BENCHMARKS_FILTER := .
 BENCHMARKS_OPTIONS := -test.benchtime=30s
 BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) $(BENCHMARKS_OPTIONS)
 BENCHMARKS_PROFILE := -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex
-BENCH_RUNTIME_ARGS ?= --vfs2
+BENCH_VFS := --vfs2
+BENCH_RUNTIME_ARGS ?=
 
 init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema.
 	@$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE))
@@ -361,13 +363,14 @@ run_benchmark = \
 benchmark-platforms: load-benchmarks $(RUNTIME_BIN) ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS.
 	@$(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \
-		$(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS)) && \
-	) true
+		$(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS) --vfs2) && \
+		$(call run_benchmark,$(PLATFORM)_vfs1,--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS)) && \
+	) true
 	@$(call run_benchmark,runc)
 .PHONY: benchmark-platforms
 
 run-benchmark: load-benchmarks $(RUNTIME_BIN) ## Runs single benchmark and optionally sends data to BigQuery.
-	@$(call run_benchmark,$(RUNTIME),$(BENCH_RUNTIME_ARGS))
+	@$(call run_benchmark,$(RUNTIME)$(BENCH_VFS),$(BENCH_RUNTIME_ARGS) $(BENCH_VFS))
 .PHONY: run-benchmark
 
 ##
diff --git a/nogo.yaml b/nogo.yaml
--- a/nogo.yaml
+++ b/nogo.yaml
@@ -55,8 +55,6 @@ global:
       # Same story for underscores.
       - "should not use ALL_CAPS in Go names"
       - "should not use underscores in Go names"
-      # TODO(b/179817829): Upgrade to flock to v0.8.0.
-      - "flock.NewFlock is deprecated: Use New instead"
   exclude:
     # Generated: exempt all.
     - pkg/shim/runtimeoptions/runtimeoptions_cri.go
diff --git a/pkg/abi/linux/fs.go b/pkg/abi/linux/fs.go
index 0d921ed6f..cad24fcc7 100644
--- a/pkg/abi/linux/fs.go
+++ b/pkg/abi/linux/fs.go
@@ -19,8 +19,10 @@ package linux
 // See linux/magic.h.
 const (
 	ANON_INODE_FS_MAGIC = 0x09041934
+	CGROUP_SUPER_MAGIC  = 0x27e0eb
 	DEVPTS_SUPER_MAGIC  = 0x00001cd1
 	EXT_SUPER_MAGIC     = 0xef53
+	FUSE_SUPER_MAGIC    = 0x65735546
 	OVERLAYFS_SUPER_MAGIC = 0x794c7630
 	PIPEFS_MAGIC          = 0x50495045
 	PROC_SUPER_MAGIC      = 0x9fa0
@@ -29,7 +31,6 @@ const (
 	SYSFS_MAGIC = 0x62656572
 	TMPFS_MAGIC = 0x01021994
 	V9FS_MAGIC  = 0x01021997
-	FUSE_SUPER_MAGIC = 0x65735546
 )
 
 // Filesystem path limits, from uapi/linux/limits.h.
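Note: the constants added to pkg/abi/linux/fs.go mirror Linux's include/uapi/linux/magic.h (FUSE_SUPER_MAGIC only moves to keep the list sorted). As a hedged illustration of where these magics surface, a test could confirm them through statfs(2); the mount path below is an assumption, not something from this change.

// Minimal standalone sketch: check that a mount reports CGROUP_SUPER_MAGIC.
package main

import (
    "fmt"

    "golang.org/x/sys/unix"
)

const cgroupSuperMagic = 0x27e0eb // linux.CGROUP_SUPER_MAGIC above

func main() {
    var st unix.Statfs_t
    if err := unix.Statfs("/sys/fs/cgroup/memory", &st); err != nil {
        fmt.Println("statfs:", err)
        return
    }
    // st.Type carries f_type; cgroupfs returns it via
    // vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC) later in this diff.
    fmt.Printf("f_type = %#x, cgroupfs = %v\n", st.Type, st.Type == cgroupSuperMagic)
}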
diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go
index 50e22fe7e..e722971f1 100644
--- a/pkg/abi/linux/ptrace_amd64.go
+++ b/pkg/abi/linux/ptrace_amd64.go
@@ -61,3 +61,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 {
 func (p *PtraceRegs) StackPointer() uint64 {
 	return p.Rsp
 }
+
+// SetStackPointer sets the stack pointer to the specified value.
+func (p *PtraceRegs) SetStackPointer(sp uint64) {
+	p.Rsp = sp
+}
diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go
index da36811d2..3d0906565 100644
--- a/pkg/abi/linux/ptrace_arm64.go
+++ b/pkg/abi/linux/ptrace_arm64.go
@@ -38,3 +38,8 @@ func (p *PtraceRegs) InstructionPointer() uint64 {
 func (p *PtraceRegs) StackPointer() uint64 {
 	return p.Sp
 }
+
+// SetStackPointer sets the stack pointer to the specified value.
+func (p *PtraceRegs) SetStackPointer(sp uint64) {
+	p.Sp = sp
+}
diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go
index 961bd4dcf..6450f664c 100644
--- a/pkg/merkletree/merkletree.go
+++ b/pkg/merkletree/merkletree.go
@@ -238,6 +238,7 @@ func Generate(params *GenerateParams) ([]byte, error) {
 		Mode:          params.Mode,
 		UID:           params.UID,
 		GID:           params.GID,
+		Children:      params.Children,
 		SymlinkTarget: params.SymlinkTarget,
 	}
 
diff --git a/pkg/ring0/kernel_amd64.go b/pkg/ring0/kernel_amd64.go
index 6e17fb796..41dfd0bf9 100644
--- a/pkg/ring0/kernel_amd64.go
+++ b/pkg/ring0/kernel_amd64.go
@@ -250,7 +250,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
 	}
 	SaveFloatingPoint(switchOpts.FloatingPointState.BytePointer()) // escapes: no. Copy out floating point.
 	WriteFS(uintptr(c.registers.Fs_base))                          // escapes: no. Restore kernel FS.
-	ldmxcsr(&kernelMXCSR)                                          // escapes: no. Restore kernel MXCSR.
+	RestoreKernelFPState()                                         // escapes: no. Restore kernel MXCSR.
 	return
 }
 
@@ -329,6 +329,14 @@ func ReadCR2() uintptr {
 // at src/cmd/compile/abi-internal.md in the golang sources for more details.
 var kernelMXCSR uint32
 
+// RestoreKernelFPState restores the Sentry floating point state.
+//
+//go:nosplit
+func RestoreKernelFPState() {
+	// Restore the MXCSR control configuration.
+	ldmxcsr(&kernelMXCSR)
+}
+
 func init() {
 	stmxcsr(&kernelMXCSR)
 }
diff --git a/pkg/ring0/kernel_arm64.go b/pkg/ring0/kernel_arm64.go
index 7975e5f92..21db910a2 100644
--- a/pkg/ring0/kernel_arm64.go
+++ b/pkg/ring0/kernel_arm64.go
@@ -65,7 +65,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
 	storeEl0Fpstate(switchOpts.FloatingPointState.BytePointer())
 
 	if switchOpts.Flush {
-		FlushTlbByASID(uintptr(switchOpts.UserASID))
+		LocalFlushTlbByASID(uintptr(switchOpts.UserASID))
 	}
 
 	regs := switchOpts.Registers
@@ -89,3 +89,9 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
 
 	return
 }
+
+// RestoreKernelFPState restores the Sentry floating point state.
+//
+//go:nosplit
+func RestoreKernelFPState() {
+}
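Note: SetStackPointer gives amd64 and arm64 a symmetric mutator next to the existing InstructionPointer/StackPointer accessors, so callers no longer have to touch p.Rsp or p.Sp directly. A hypothetical arch-neutral caller (alignStack is invented for illustration; 16-byte alignment follows the usual SysV/AAPCS64 convention):

// alignStack rounds a candidate stack top down to 16 bytes and installs it,
// identically on both architectures, via the new accessor.
func alignStack(regs *linux.PtraceRegs, top uint64) {
    const align = 16
    regs.SetStackPointer(top &^ (align - 1))
}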
diff --git a/pkg/ring0/lib_arm64.go b/pkg/ring0/lib_arm64.go
index e44df00a6..5eabd4296 100644
--- a/pkg/ring0/lib_arm64.go
+++ b/pkg/ring0/lib_arm64.go
@@ -31,6 +31,9 @@ func FlushTlbByVA(addr uintptr)
 
 // FlushTlbByASID invalidates tlb by ASID/Inner-Shareable.
 func FlushTlbByASID(asid uintptr)
 
+// LocalFlushTlbByASID invalidates tlb by ASID.
+func LocalFlushTlbByASID(asid uintptr)
+
 // FlushTlbAll invalidates all tlb.
 func FlushTlbAll()
diff --git a/pkg/ring0/lib_arm64.s b/pkg/ring0/lib_arm64.s
index e39b32841..69ebaf519 100644
--- a/pkg/ring0/lib_arm64.s
+++ b/pkg/ring0/lib_arm64.s
@@ -32,6 +32,14 @@ TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8
 	DSB $11 // dsb(ish)
 	RET
 
+TEXT ·LocalFlushTlbByASID(SB),NOSPLIT,$0-8
+	MOVD asid+0(FP), R1
+	LSL $TLBI_ASID_SHIFT, R1, R1
+	DSB $10 // dsb(ishst)
+	WORD $0xd5088741 // tlbi aside1, x1
+	DSB $11 // dsb(ish)
+	RET
+
 TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0
 	DSB $6 // dsb(nshst)
 	WORD $0xd508871f // __tlbi(vmalle1)
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
index 1929e41cd..49c53452a 100644
--- a/pkg/sentry/devices/memdev/zero.go
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -93,6 +93,7 @@ func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
 	// "/dev/zero (deleted)".
 	opts.Offset = 0
 	opts.MappingIdentity = &fd.vfsfd
+	opts.SentryOwnedContent = true
 	opts.MappingIdentity.IncRef()
 	return nil
 }
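Note: the opts.SentryOwnedContent flag set here (and for tmpfs and page-cache-backed gofer files later in this diff) records that a mapping's backing memory is provided by the sentry rather than by an arbitrary host file; verityfs below refuses to mmap anything without it. A compressed sketch of the recurring pattern, using an invented file type rather than any real one from this change:

// sentryBackedFD stands in for zeroFD and the tmpfs/gofer FDs in this diff:
// an FD whose pages come from sentry-owned memory marks its mappings as such.
type sentryBackedFD struct {
    vfsfd    vfs.FileDescription
    mappable memmap.Mappable
}

// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
func (fd *sentryBackedFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
    opts.SentryOwnedContent = true // content is sentry-provided, so verifiable
    return vfs.GenericConfigureMMap(&fd.vfsfd, fd.mappable, opts)
}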
diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD
new file mode 100644
index 000000000..48913068a
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/BUILD
@@ -0,0 +1,47 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+    name = "dir_refs",
+    out = "dir_refs.go",
+    package = "cgroupfs",
+    prefix = "dir",
+    template = "//pkg/refsvfs2:refs_template",
+    types = {
+        "T": "dir",
+    },
+)
+
+go_library(
+    name = "cgroupfs",
+    srcs = [
+        "base.go",
+        "cgroupfs.go",
+        "cpu.go",
+        "cpuacct.go",
+        "cpuset.go",
+        "dir_refs.go",
+        "memory.go",
+    ],
+    visibility = ["//pkg/sentry:internal"],
+    deps = [
+        "//pkg/abi/linux",
+        "//pkg/context",
+        "//pkg/coverage",
+        "//pkg/log",
+        "//pkg/refs",
+        "//pkg/refsvfs2",
+        "//pkg/sentry/arch",
+        "//pkg/sentry/fsimpl/kernfs",
+        "//pkg/sentry/kernel",
+        "//pkg/sentry/kernel/auth",
+        "//pkg/sentry/memmap",
+        "//pkg/sentry/usage",
+        "//pkg/sentry/vfs",
+        "//pkg/sync",
+        "//pkg/syserror",
+        "//pkg/usermem",
+    ],
+)
diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go
new file mode 100644
index 000000000..39c1013e1
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/base.go
@@ -0,0 +1,233 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"bytes"
+	"fmt"
+	"sort"
+	"sync/atomic"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/usermem"
+)
+
+// controllerCommon implements kernel.CgroupController.
+//
+// Must call init before use.
+//
+// +stateify savable
+type controllerCommon struct {
+	ty kernel.CgroupControllerType
+	fs *filesystem
+}
+
+func (c *controllerCommon) init(ty kernel.CgroupControllerType, fs *filesystem) {
+	c.ty = ty
+	c.fs = fs
+}
+
+// Type implements kernel.CgroupController.Type.
+func (c *controllerCommon) Type() kernel.CgroupControllerType {
+	return kernel.CgroupControllerType(c.ty)
+}
+
+// HierarchyID implements kernel.CgroupController.HierarchyID.
+func (c *controllerCommon) HierarchyID() uint32 {
+	return c.fs.hierarchyID
+}
+
+// NumCgroups implements kernel.CgroupController.NumCgroups.
+func (c *controllerCommon) NumCgroups() uint64 {
+	return atomic.LoadUint64(&c.fs.numCgroups)
+}
+
+// Enabled implements kernel.CgroupController.Enabled.
+//
+// Controllers are currently always enabled.
+func (c *controllerCommon) Enabled() bool {
+	return true
+}
+
+// Filesystem implements kernel.CgroupController.Filesystem.
+func (c *controllerCommon) Filesystem() *vfs.Filesystem {
+	return c.fs.VFSFilesystem()
+}
+
+// RootCgroup implements kernel.CgroupController.RootCgroup.
+func (c *controllerCommon) RootCgroup() kernel.Cgroup {
+	return c.fs.rootCgroup()
+}
+
+// controller is an interface for common functionality related to all cgroups.
+// It is an extension of the public cgroup interface, containing cgroup
+// functionality private to cgroupfs.
+type controller interface {
+	kernel.CgroupController
+
+	// AddControlFiles should extend the contents map with inodes representing
+	// control files defined by this controller.
+	AddControlFiles(ctx context.Context, creds *auth.Credentials, c *cgroupInode, contents map[string]kernfs.Inode)
+}
+
+// cgroupInode implements kernel.CgroupImpl and kernfs.Inode.
+//
+// +stateify savable
+type cgroupInode struct {
+	dir
+	fs *filesystem
+
+	// ts is the list of tasks in this cgroup. The kernel is responsible for
+	// removing tasks from this list before they're destroyed, so any tasks on
+	// this list are always valid.
+	//
+	// ts, and cgroup membership in general is protected by fs.tasksMu.
+	ts map[*kernel.Task]struct{}
+}
+
+var _ kernel.CgroupImpl = (*cgroupInode)(nil)
+
+func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
+	c := &cgroupInode{
+		fs: fs,
+		ts: make(map[*kernel.Task]struct{}),
+	}
+
+	contents := make(map[string]kernfs.Inode)
+	contents["cgroup.procs"] = fs.newControllerFile(ctx, creds, &cgroupProcsData{c})
+	contents["tasks"] = fs.newControllerFile(ctx, creds, &tasksData{c})
+
+	for _, ctl := range fs.controllers {
+		ctl.AddControlFiles(ctx, creds, c, contents)
+	}
+
+	c.dir.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|linux.FileMode(0555))
+	c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+	c.dir.InitRefs()
+	c.dir.IncLinks(c.dir.OrderedChildren.Populate(contents))
+
+	atomic.AddUint64(&fs.numCgroups, 1)
+
+	return c
+}
+
+func (c *cgroupInode) HierarchyID() uint32 {
+	return c.fs.hierarchyID
+}
+
+// Controllers implements kernel.CgroupImpl.Controllers.
+func (c *cgroupInode) Controllers() []kernel.CgroupController {
+	return c.fs.kcontrollers
+}
+
+// Enter implements kernel.CgroupImpl.Enter.
+func (c *cgroupInode) Enter(t *kernel.Task) {
+	c.fs.tasksMu.Lock()
+	c.ts[t] = struct{}{}
+	c.fs.tasksMu.Unlock()
+}
+
+// Leave implements kernel.CgroupImpl.Leave.
+func (c *cgroupInode) Leave(t *kernel.Task) {
+	c.fs.tasksMu.Lock()
+	delete(c.ts, t)
+	c.fs.tasksMu.Unlock()
+}
+
+func sortTIDs(tids []kernel.ThreadID) {
+	sort.Slice(tids, func(i, j int) bool { return tids[i] < tids[j] })
+}
+
+// +stateify savable
+type cgroupProcsData struct {
+	*cgroupInode
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	t := kernel.TaskFromContext(ctx)
+	currPidns := t.ThreadGroup().PIDNamespace()
+
+	pgids := make(map[kernel.ThreadID]struct{})
+
+	d.fs.tasksMu.RLock()
+	defer d.fs.tasksMu.RUnlock()
+
+	for task := range d.ts {
+		// Map dedups pgid, since iterating over all tasks produces multiple
+		// entries for the group leaders.
+		if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 {
+			pgids[pgid] = struct{}{}
+		}
+	}
+
+	pgidList := make([]kernel.ThreadID, 0, len(pgids))
+	for pgid := range pgids {
+		pgidList = append(pgidList, pgid)
+	}
+	sortTIDs(pgidList)
+
+	for _, pgid := range pgidList {
+		fmt.Fprintf(buf, "%d\n", pgid)
+	}
+
+	return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *cgroupProcsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+	// TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
+	return src.NumBytes(), nil
+}
+
+// +stateify savable
+type tasksData struct {
+	*cgroupInode
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	t := kernel.TaskFromContext(ctx)
+	currPidns := t.ThreadGroup().PIDNamespace()
+
+	var pids []kernel.ThreadID
+
+	d.fs.tasksMu.RLock()
+	defer d.fs.tasksMu.RUnlock()
+
+	for task := range d.ts {
+		if pid := currPidns.IDOfTask(task); pid != 0 {
+			pids = append(pids, pid)
+		}
+	}
+	sortTIDs(pids)
+
+	for _, pid := range pids {
+		fmt.Fprintf(buf, "%d\n", pid)
+	}
+
+	return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *tasksData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+	// TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
+	return src.NumBytes(), nil
+}
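Note: base.go's two control files differ only in granularity: tasks lists every thread's TID, while cgroup.procs lists one TGID per thread group, which is why cgroupProcsData dedups through a map before sorting. The same logic in miniature, runnable standalone:

package main

import (
    "fmt"
    "sort"
)

func main() {
    // Thread-group IDs seen while iterating tasks; non-leader threads repeat
    // their leader's ID, exactly as in cgroupProcsData.Generate above.
    tgids := []int{12, 12, 12, 40, 40, 7}
    set := map[int]struct{}{}
    for _, id := range tgids {
        set[id] = struct{}{}
    }
    out := make([]int, 0, len(set))
    for id := range set {
        out = append(out, id)
    }
    sort.Ints(out)
    fmt.Println(out) // [7 12 40]: one line per thread group in cgroup.procs
}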
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
new file mode 100644
index 000000000..ca8caee5f
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -0,0 +1,412 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cgroupfs implements cgroupfs.
+//
+// A cgroup is a collection of tasks on the system, organized into a tree-like
+// structure similar to a filesystem directory tree. In fact, each cgroup is
+// represented by a directory on cgroupfs, and is manipulated through control
+// files in the directory.
+//
+// All cgroups on a system are organized into hierarchies. A hierarchy is a
+// distinct tree of cgroups, with a common set of controllers. One or more
+// cgroupfs mounts may point to each hierarchy. These mounts provide a common
+// view into the same tree of cgroups.
+//
+// A controller (also known as a "resource controller", or a cgroup "subsystem")
+// determines the behaviour of each cgroup.
+//
+// In addition to cgroupfs, the kernel has a cgroup registry that tracks
+// system-wide state related to cgroups such as active hierarchies and the
+// controllers associated with them.
+//
+// Since cgroupfs doesn't allow hardlinks, there is a unique mapping between
+// cgroupfs dentries and inodes.
+//
+// # Synchronization
+//
+// Cgroup hierarchy creation and destruction is protected by the
+// kernel.CgroupRegistry.mu. Once created, a hierarchy's set of controllers,
+// the filesystem associated with it, and the root cgroup for the hierarchy
+// are immutable.
+//
+// Membership of tasks within cgroups is protected by
+// cgroupfs.filesystem.tasksMu. Tasks also maintain a set of all cgroups
+// they're in, and this list is protected by Task.mu.
+//
+// Lock order:
+//
+//	kernel.CgroupRegistry.mu
+//	  cgroupfs.filesystem.mu
+//	    Task.mu
+//	      cgroupfs.filesystem.tasksMu.
+package cgroupfs
+
+import (
+	"fmt"
+	"sort"
+	"strconv"
+	"strings"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/sync"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+	// Name is the default filesystem name.
+	Name                     = "cgroup"
+	readonlyFileMode         = linux.FileMode(0444)
+	writableFileMode         = linux.FileMode(0644)
+	defaultMaxCachedDentries = uint64(1000)
+)
+
+const (
+	controllerCPU     = kernel.CgroupControllerType("cpu")
+	controllerCPUAcct = kernel.CgroupControllerType("cpuacct")
+	controllerCPUSet  = kernel.CgroupControllerType("cpuset")
+	controllerMemory  = kernel.CgroupControllerType("memory")
+)
+
+var allControllers = []kernel.CgroupControllerType{controllerCPU, controllerCPUAcct, controllerCPUSet, controllerMemory}
+
+// SupportedMountOptions is the set of supported mount options for cgroupfs.
+var SupportedMountOptions = []string{"all", "cpu", "cpuacct", "cpuset", "memory"}
+
+// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
+type FilesystemType struct{}
+
+// InternalData contains internal data passed in to the cgroupfs mount via
+// vfs.GetFilesystemOptions.InternalData.
+//
+// +stateify savable
+type InternalData struct {
+	DefaultControlValues map[string]int64
+}
+
+// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
+type filesystem struct {
+	kernfs.Filesystem
+	devMinor uint32
+
+	// hierarchyID is the id the cgroup registry assigns to this hierarchy. Has
+	// the value kernel.InvalidCgroupHierarchyID until the FS is fully
+	// initialized.
+	//
+	// hierarchyID is immutable after initialization.
+	hierarchyID uint32
+
+	// controllers and kcontrollers are both the list of controllers attached
+	// to this cgroupfs. Both lists are the same set of controllers, but
+	// typecast to different interfaces for convenience. Both must stay in
+	// sync, and are immutable.
+	controllers  []controller
+	kcontrollers []kernel.CgroupController
+
+	numCgroups uint64 // Protected by atomic ops.
+
+	root *kernfs.Dentry
+
+	// tasksMu serializes task membership changes across all cgroups within a
+	// filesystem.
+	tasksMu sync.RWMutex `state:"nosave"`
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+	return Name
+}
+
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+	devMinor, err := vfsObj.GetAnonBlockDevMinor()
+	if err != nil {
+		return nil, nil, err
+	}
+
+	mopts := vfs.GenericParseMountOptions(opts.Data)
+	maxCachedDentries := defaultMaxCachedDentries
+	if str, ok := mopts["dentry_cache_limit"]; ok {
+		delete(mopts, "dentry_cache_limit")
+		maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+		if err != nil {
+			ctx.Warningf("cgroupfs.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+			return nil, nil, syserror.EINVAL
+		}
+	}
+
+	var wantControllers []kernel.CgroupControllerType
+	if _, ok := mopts["cpu"]; ok {
+		delete(mopts, "cpu")
+		wantControllers = append(wantControllers, controllerCPU)
+	}
+	if _, ok := mopts["cpuacct"]; ok {
+		delete(mopts, "cpuacct")
+		wantControllers = append(wantControllers, controllerCPUAcct)
+	}
+	if _, ok := mopts["cpuset"]; ok {
+		delete(mopts, "cpuset")
+		wantControllers = append(wantControllers, controllerCPUSet)
+	}
+	if _, ok := mopts["memory"]; ok {
+		delete(mopts, "memory")
+		wantControllers = append(wantControllers, controllerMemory)
+	}
+	if _, ok := mopts["all"]; ok {
+		if len(wantControllers) > 0 {
+			ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: other controllers specified with all: %v", wantControllers)
+			return nil, nil, syserror.EINVAL
+		}
+
+		delete(mopts, "all")
+		wantControllers = allControllers
+	}
+
+	if len(wantControllers) == 0 {
+		// Specifying no controllers implies all controllers.
+		wantControllers = allControllers
+	}
+
+	if len(mopts) != 0 {
+		ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+		return nil, nil, syserror.EINVAL
+	}
+
+	k := kernel.KernelFromContext(ctx)
+	r := k.CgroupRegistry()
+
+	// "It is not possible to mount the same controller against multiple
+	// cgroup hierarchies. For example, it is not possible to mount both
+	// the cpu and cpuacct controllers against one hierarchy, and to mount
+	// the cpu controller alone against another hierarchy." - man cgroups(7)
+	//
+	// Is there a hierarchy available with all the controllers we want? If so,
+	// this mount is a view into the same hierarchy.
+	//
+	// Note: we're guaranteed to have at least one requested controller, since
+	// no explicit controller name implies all controllers.
+	if vfsfs := r.FindHierarchy(wantControllers); vfsfs != nil {
+		fs := vfsfs.Impl().(*filesystem)
+		ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: mounting new view to hierarchy %v", fs.hierarchyID)
+		fs.root.IncRef()
+		return vfsfs, fs.root.VFSDentry(), nil
+	}
+
+	// No existing hierarchy with exactly the requested controllers was found.
+	// Make a new one. Note that it's possible this mount creation is
+	// unsatisfiable, if one or more of the requested controllers are already
+	// on existing hierarchies. We'll find out about such collisions when we
+	// try to register the new hierarchy later.
+	fs := &filesystem{
+		devMinor: devMinor,
+	}
+	fs.MaxCachedDentries = maxCachedDentries
+	fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+	var defaults map[string]int64
+	if opts.InternalData != nil {
+		defaults = opts.InternalData.(*InternalData).DefaultControlValues
+		ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults)
+	}
+
+	for _, ty := range wantControllers {
+		var c controller
+		switch ty {
+		case controllerMemory:
+			c = newMemoryController(fs, defaults)
+		case controllerCPU:
+			c = newCPUController(fs, defaults)
+		case controllerCPUAcct:
+			c = newCPUAcctController(fs)
+		case controllerCPUSet:
+			c = newCPUSetController(fs)
+		default:
+			panic(fmt.Sprintf("Unreachable: unknown cgroup controller %q", ty))
+		}
+		fs.controllers = append(fs.controllers, c)
+	}
+
+	if len(defaults) != 0 {
+		// Internal data is always provided at sentry startup and unused values
+		// indicate a problem with the sandbox config. Fail fast.
+		panic(fmt.Sprintf("cgroupfs.FilesystemType.GetFilesystem: unknown internal mount data: %v", defaults))
+	}
+
+	// Controllers usually appear in alphabetical order when displayed. Sort it
+	// here now, so it never needs to be sorted elsewhere.
+	sort.Slice(fs.controllers, func(i, j int) bool { return fs.controllers[i].Type() < fs.controllers[j].Type() })
+	fs.kcontrollers = make([]kernel.CgroupController, 0, len(fs.controllers))
+	for _, c := range fs.controllers {
+		fs.kcontrollers = append(fs.kcontrollers, c)
+	}
+
+	root := fs.newCgroupInode(ctx, creds)
+	var rootD kernfs.Dentry
+	rootD.InitRoot(&fs.Filesystem, root)
+	fs.root = &rootD
+
+	// Register controllers. The registry may be modified concurrently, so if
+	// we get an error, we raced with someone else who registered the same
+	// controllers first.
+	hid, err := r.Register(fs.kcontrollers)
+	if err != nil {
+		ctx.Infof("cgroupfs.FilesystemType.GetFilesystem: failed to register new hierarchy with controllers %v: %v", wantControllers, err)
+		rootD.DecRef(ctx)
+		fs.VFSFilesystem().DecRef(ctx)
+		return nil, nil, syserror.EBUSY
+	}
+	fs.hierarchyID = hid
+
+	// Move all existing tasks to the root of the new hierarchy.
+	k.PopulateNewCgroupHierarchy(fs.rootCgroup())
+
+	return fs.VFSFilesystem(), rootD.VFSDentry(), nil
+}
+
+func (fs *filesystem) rootCgroup() kernel.Cgroup {
+	return kernel.Cgroup{
+		Dentry:     fs.root,
+		CgroupImpl: fs.root.Inode().(kernel.CgroupImpl),
+	}
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+	k := kernel.KernelFromContext(ctx)
+	r := k.CgroupRegistry()
+
+	if fs.hierarchyID != kernel.InvalidCgroupHierarchyID {
+		k.ReleaseCgroupHierarchy(fs.hierarchyID)
+		r.Unregister(fs.hierarchyID)
+	}
+
+	fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+	fs.Filesystem.Release(ctx)
+}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+	var cnames []string
+	for _, c := range fs.controllers {
+		cnames = append(cnames, string(c.Type()))
+	}
+	return strings.Join(cnames, ",")
+}
+
+// +stateify savable
+type implStatFS struct{}
+
+// StatFS implements kernfs.Inode.StatFS.
+func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) {
+	return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
+}
+
+// dir implements kernfs.Inode for a generic cgroup resource controller
+// directory. Specific controllers extend this to add their own functionality.
+//
+// +stateify savable
+type dir struct {
+	dirRefs
+	kernfs.InodeAlwaysValid
+	kernfs.InodeAttrs
+	kernfs.InodeNotSymlink
+	kernfs.InodeDirectoryNoNewChildren // TODO(b/183137098): Implement mkdir.
+	kernfs.OrderedChildren
+	implStatFS
+
+	locks vfs.FileLocks
+}
+
+// Keep implements kernfs.Inode.Keep.
+func (*dir) Keep() bool {
+	return true
+}
+
+// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
+func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+	return syserror.EPERM
+}
+
+// Open implements kernfs.Inode.Open.
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+	fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
+		SeekEnd: kernfs.SeekEndStaticEntries,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return fd.VFSFileDescription(), nil
+}
+
+// DecRef implements kernfs.Inode.DecRef.
+func (d *dir) DecRef(ctx context.Context) {
+	d.dirRefs.DecRef(func() { d.Destroy(ctx) })
+}
+
+// StatFS implements kernfs.Inode.StatFS.
+func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) {
+	return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
+}
+
+// controllerFile represents a generic control file that appears within a
+// cgroup directory.
+//
+// +stateify savable
+type controllerFile struct {
+	kernfs.DynamicBytesFile
+}
+
+func (fs *filesystem) newControllerFile(ctx context.Context, creds *auth.Credentials, data vfs.DynamicBytesSource) kernfs.Inode {
+	f := &controllerFile{}
+	f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, readonlyFileMode)
+	return f
+}
+
+func (fs *filesystem) newControllerWritableFile(ctx context.Context, creds *auth.Credentials, data vfs.WritableDynamicBytesSource) kernfs.Inode {
+	f := &controllerFile{}
+	f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, writableFileMode)
+	return f
+}
+
+// staticControllerFile represents a generic control file that appears within
+// a cgroup directory which always returns the same data when read.
+// staticControllerFiles are not writable.
+//
+// +stateify savable
+type staticControllerFile struct {
+	kernfs.DynamicBytesFile
+	vfs.StaticData
+}
+
+// Note: We let the caller provide the mode so that static files may be used to
+// fake both readable and writable control files. However, static files are
+// effectively readonly, as attempting to write to them will return EIO
+// regardless of the mode.
+func (fs *filesystem) newStaticControllerFile(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, data string) kernfs.Inode {
+	f := &staticControllerFile{StaticData: vfs.StaticData{Data: data}}
+	f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), f, mode)
+	return f
+}
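Note: the hierarchy rules GetFilesystem implements are observable from inside a sandbox: remounting the same controller set returns a view of the existing hierarchy, while re-claiming a subset of already-attached controllers fails registration with EBUSY. An illustrative sketch only; the mount points are placeholders and the calls need appropriate privileges:

package main

import "golang.org/x/sys/unix"

func main() {
    // First mount claims cpu+cpuacct for a brand-new hierarchy.
    _ = unix.Mount("none", "/tmp/cg1", "cgroup", 0, "cpu,cpuacct")
    // Identical controller set: FindHierarchy matches, so this is another
    // view into the same hierarchy rather than a new one.
    _ = unix.Mount("none", "/tmp/cg2", "cgroup", 0, "cpu,cpuacct")
    // Overlapping but different set: cpu is already attached elsewhere, so
    // registry.Register fails and the mount returns EBUSY.
    err := unix.Mount("none", "/tmp/cg3", "cgroup", 0, "cpu")
    _ = err // expected: unix.EBUSY
}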
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpu.go b/pkg/sentry/fsimpl/cgroupfs/cpu.go
new file mode 100644
index 000000000..24d86a277
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpu.go
@@ -0,0 +1,70 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// +stateify savable
+type cpuController struct {
+	controllerCommon
+
+	// CFS bandwidth control parameters, values in microseconds.
+	cfsPeriod int64
+	cfsQuota  int64
+
+	// CPU shares, values should be (num cores * 1024).
+	shares int64
+}
+
+var _ controller = (*cpuController)(nil)
+
+func newCPUController(fs *filesystem, defaults map[string]int64) *cpuController {
+	// Default values for controller parameters from Linux.
+	c := &cpuController{
+		cfsPeriod: 100000,
+		cfsQuota:  -1,
+		shares:    1024,
+	}
+
+	if val, ok := defaults["cpu.cfs_period_us"]; ok {
+		c.cfsPeriod = val
+		delete(defaults, "cpu.cfs_period_us")
+	}
+	if val, ok := defaults["cpu.cfs_quota_us"]; ok {
+		c.cfsQuota = val
+		delete(defaults, "cpu.cfs_quota_us")
+	}
+	if val, ok := defaults["cpu.shares"]; ok {
+		c.shares = val
+		delete(defaults, "cpu.shares")
+	}
+
+	c.controllerCommon.init(controllerCPU, fs)
+	return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpuController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+	contents["cpu.cfs_period_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsPeriod))
+	contents["cpu.cfs_quota_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsQuota))
+	contents["cpu.shares"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.shares))
+}
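Note: the three cpu control files use the standard CFS interpretation, which the defaults above mirror: cfs_quota_us of -1 means no bandwidth limit, otherwise quota/period is the CPU ceiling, and shares is a relative weight. That arithmetic in a small standalone illustration:

package main

import "fmt"

// effectiveCPUs applies the usual cpu.cfs_quota_us / cpu.cfs_period_us rule:
// a negative quota means unlimited bandwidth.
func effectiveCPUs(quotaUs, periodUs int64) (float64, bool) {
    if quotaUs < 0 {
        return 0, false // unlimited
    }
    return float64(quotaUs) / float64(periodUs), true
}

func main() {
    if n, ok := effectiveCPUs(250000, 100000); ok {
        fmt.Printf("limit: %.1f CPUs\n", n) // 2.5 CPUs
    }
    if _, ok := effectiveCPUs(-1, 100000); !ok {
        fmt.Println("no limit (the default above)")
    }
}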
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuacct.go b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
new file mode 100644
index 000000000..d4104a00e
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
@@ -0,0 +1,114 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"bytes"
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// +stateify savable
+type cpuacctController struct {
+	controllerCommon
+}
+
+var _ controller = (*cpuacctController)(nil)
+
+func newCPUAcctController(fs *filesystem) *cpuacctController {
+	c := &cpuacctController{}
+	c.controllerCommon.init(controllerCPUAcct, fs)
+	return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpuacctController) AddControlFiles(ctx context.Context, creds *auth.Credentials, cg *cgroupInode, contents map[string]kernfs.Inode) {
+	cpuacctCG := &cpuacctCgroup{cg}
+	contents["cpuacct.stat"] = c.fs.newControllerFile(ctx, creds, &cpuacctStatData{cpuacctCG})
+	contents["cpuacct.usage"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageData{cpuacctCG})
+	contents["cpuacct.usage_user"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageUserData{cpuacctCG})
+	contents["cpuacct.usage_sys"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageSysData{cpuacctCG})
+}
+
+// +stateify savable
+type cpuacctCgroup struct {
+	*cgroupInode
+}
+
+func (c *cpuacctCgroup) collectCPUStats() usage.CPUStats {
+	var cs usage.CPUStats
+	c.fs.tasksMu.RLock()
+	// Note: This isn't very accurate, since the tasks are potentially
+	// still running as we accumulate their stats.
+	for t := range c.ts {
+		cs.Accumulate(t.CPUStats())
+	}
+	c.fs.tasksMu.RUnlock()
+	return cs
+}
+
+// +stateify savable
+type cpuacctStatData struct {
+	*cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	cs := d.collectCPUStats()
+	fmt.Fprintf(buf, "user %d\n", linux.ClockTFromDuration(cs.UserTime))
+	fmt.Fprintf(buf, "system %d\n", linux.ClockTFromDuration(cs.SysTime))
+	return nil
+}
+
+// +stateify savable
+type cpuacctUsageData struct {
+	*cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	cs := d.collectCPUStats()
+	fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds()+cs.SysTime.Nanoseconds())
+	return nil
+}
+
+// +stateify savable
+type cpuacctUsageUserData struct {
+	*cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageUserData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	cs := d.collectCPUStats()
+	fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds())
+	return nil
+}
+
+// +stateify savable
+type cpuacctUsageSysData struct {
+	*cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageSysData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	cs := d.collectCPUStats()
+	fmt.Fprintf(buf, "%d\n", cs.SysTime.Nanoseconds())
+	return nil
+}
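Note: cpuacct.stat reports in clock ticks (via linux.ClockTFromDuration) while the cpuacct.usage* files report nanoseconds, matching Linux. The unit conversion in isolation, assuming the usual USER_HZ of 100:

package main

import (
    "fmt"
    "time"
)

const userHZ = 100 // Linux's USER_HZ, assumed here

// clockT mirrors what linux.ClockTFromDuration does: duration in ticks.
func clockT(d time.Duration) int64 {
    return int64(d / (time.Second / userHZ))
}

func main() {
    user, sys := 1500*time.Millisecond, 250*time.Millisecond
    fmt.Printf("user %d\nsystem %d\n", clockT(user), clockT(sys)) // ticks
    fmt.Printf("usage %d\n", (user + sys).Nanoseconds())          // nanoseconds
}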
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuset.go b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
new file mode 100644
index 000000000..ac547f8e2
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
@@ -0,0 +1,39 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// +stateify savable
+type cpusetController struct {
+	controllerCommon
+}
+
+var _ controller = (*cpusetController)(nil)
+
+func newCPUSetController(fs *filesystem) *cpusetController {
+	c := &cpusetController{}
+	c.controllerCommon.init(controllerCPUSet, fs)
+	return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpusetController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+	// This controller is currently intentionally empty.
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/memory.go b/pkg/sentry/fsimpl/cgroupfs/memory.go
new file mode 100644
index 000000000..485c98376
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/memory.go
@@ -0,0 +1,74 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"bytes"
+	"fmt"
+	"math"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// +stateify savable
+type memoryController struct {
+	controllerCommon
+
+	limitBytes int64
+}
+
+var _ controller = (*memoryController)(nil)
+
+func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryController {
+	c := &memoryController{
+		// Linux sets this to (PAGE_COUNTER_MAX * PAGE_SIZE) by default, which
+		// is ~ 2**63 on a 64-bit system. So essentially, infinity. The exact
+		// value isn't very important.
+		limitBytes: math.MaxInt64,
+	}
+	if val, ok := defaults["memory.limit_in_bytes"]; ok {
+		c.limitBytes = val
+		delete(defaults, "memory.limit_in_bytes")
+	}
+	c.controllerCommon.init(controllerMemory, fs)
+	return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *memoryController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+	contents["memory.usage_in_bytes"] = c.fs.newControllerFile(ctx, creds, &memoryUsageInBytesData{})
+	contents["memory.limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.limitBytes))
+}
+
+// +stateify savable
+type memoryUsageInBytesData struct{}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *memoryUsageInBytesData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	// TODO(b/183151557): This is a giant hack, we're using system-wide
+	// accounting since we know there is only one cgroup.
+	k := kernel.KernelFromContext(ctx)
+	mf := k.MemoryFile()
+	mf.UpdateUsage()
+	_, totalBytes := usage.MemoryAccounting.Copy()
+
+	fmt.Fprintf(buf, "%d\n", totalBytes)
+	return nil
+}
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index a0c05231a..526136324 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1104,24 +1104,27 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error {
 	defer d.metadataMu.Unlock()
 
 	// As with Linux, if the UID, GID, or file size is changing, we have to
-	// clear permission bits. Note that when set, clearSGID causes
-	// permissions to be updated, but does not modify stat.Mask, as
-	// modification would cause an extra inotify flag to be set.
-	clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) ||
-		stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) ||
+	// clear permission bits. Note that when set, clearSGID may cause
+	// permissions to be updated.
+	clearSGID := (stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid)) ||
+		(stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid)) ||
 		stat.Mask&linux.STATX_SIZE != 0
 	if clearSGID {
 		if stat.Mask&linux.STATX_MODE != 0 {
 			stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode)))
 		} else {
-			stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode)))
+			oldMode := atomic.LoadUint32(&d.mode)
+			if updatedMode := vfs.ClearSUIDAndSGID(oldMode); updatedMode != oldMode {
+				stat.Mode = uint16(updatedMode)
+				stat.Mask |= linux.STATX_MODE
+			}
 		}
 	}
 
 	if !d.isSynthetic() {
 		if stat.Mask != 0 {
 			if err := d.file.setAttr(ctx, p9.SetAttrMask{
-				Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID,
+				Permissions: stat.Mask&linux.STATX_MODE != 0,
 				UID:         stat.Mask&linux.STATX_UID != 0,
 				GID:         stat.Mask&linux.STATX_GID != 0,
 				Size:        stat.Mask&linux.STATX_SIZE != 0,
@@ -1156,7 +1159,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs.SetStatOptions) error {
 			return nil
 		}
 	}
-	if stat.Mask&linux.STATX_MODE != 0 || clearSGID {
+	if stat.Mask&linux.STATX_MODE != 0 {
 		atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode))
 	}
 	if stat.Mask&linux.STATX_UID != 0 {
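Note: the clearSGID path above leans on vfs.ClearSUIDAndSGID; the rule it applies is standard Linux setattr behavior. Sketched below under that assumption (the real helper lives in pkg/sentry/vfs): setuid always drops on ownership or size changes, but setgid drops only when group-execute is set, since S_ISGID without S_IXGRP marks mandatory locking rather than privilege.

// clearSUIDAndSGID is an illustrative re-statement of the rule, not the
// actual gVisor implementation.
func clearSUIDAndSGID(mode uint32) uint32 {
    const (
        setuid    = 0o4000 // S_ISUID
        setgid    = 0o2000 // S_ISGID
        groupExec = 0o0010 // S_IXGRP
    )
    mode &^= setuid // always cleared
    if mode&groupExec != 0 {
        mode &^= setgid // cleared only when it actually grants privilege
    }
    return mode
}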
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 47563538c..713f0a480 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -701,6 +701,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
 	}
 	// After this point, d may be used as a memmap.Mappable.
 	d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init)
+	opts.SentryOwnedContent = d.fs.opts.forcePageCache
 	return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts)
 }
 
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 65054b0ea..84b1c3745 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -25,8 +25,10 @@ import (
 	"gvisor.dev/gvisor/pkg/usermem"
 )
 
-// DynamicBytesFile implements kernfs.Inode and represents a read-only
-// file whose contents are backed by a vfs.DynamicBytesSource.
+// DynamicBytesFile implements kernfs.Inode and represents a read-only file
+// whose contents are backed by a vfs.DynamicBytesSource. If data additionally
+// implements vfs.WritableDynamicBytesSource, the file also supports dispatching
+// writes to the implementer, but note that this will not update the source data.
 //
 // Must be instantiated with NewDynamicBytesFile or initialized with Init
 // before first use.
@@ -40,7 +42,9 @@ type DynamicBytesFile struct {
 	InodeNotSymlink
 
 	locks vfs.FileLocks
-	data  vfs.DynamicBytesSource
+	// data can additionally implement vfs.WritableDynamicBytesSource to support
+	// writes.
+	data vfs.DynamicBytesSource
 }
 
 var _ Inode = (*DynamicBytesFile)(nil)
diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go
index 565d723f0..16486eeae 100644
--- a/pkg/sentry/fsimpl/kernfs/kernfs.go
+++ b/pkg/sentry/fsimpl/kernfs/kernfs.go
@@ -61,6 +61,7 @@ import (
 	"gvisor.dev/gvisor/pkg/abi/linux"
 	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/fspath"
 	"gvisor.dev/gvisor/pkg/refsvfs2"
 	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
 	"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -508,6 +509,15 @@ func (d *Dentry) Inode() Inode {
 	return d.inode
 }
 
+// FSLocalPath returns an absolute path to d, relative to the root of its
+// filesystem.
+func (d *Dentry) FSLocalPath() string {
+	var b fspath.Builder
+	_ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b)
+	b.PrependByte('/')
+	return b.String()
+}
+
 // The Inode interface maps filesystem-level operations that operate on paths to
 // equivalent operations on specific filesystem nodes.
 //
diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go
index 254a8b062..ce8f55b1f 100644
--- a/pkg/sentry/fsimpl/proc/filesystem.go
+++ b/pkg/sentry/fsimpl/proc/filesystem.go
@@ -86,13 +86,13 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
 	procfs.MaxCachedDentries = maxCachedDentries
 	procfs.VFSFilesystem().Init(vfsObj, &ft, procfs)
 
-	var cgroups map[string]string
+	var fakeCgroupControllers map[string]string
 	if opts.InternalData != nil {
 		data := opts.InternalData.(*InternalData)
-		cgroups = data.Cgroups
+		fakeCgroupControllers = data.Cgroups
 	}
 
-	inode := procfs.newTasksInode(ctx, k, pidns, cgroups)
+	inode := procfs.newTasksInode(ctx, k, pidns, fakeCgroupControllers)
 	var dentry kernfs.Dentry
 	dentry.InitRoot(&procfs.Filesystem, inode)
 	return procfs.VFSFilesystem(), dentry.VFSDentry(), nil
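Note: per the DynamicBytesFile change above, any source that also implements vfs.WritableDynamicBytesSource (as cgroupfs's cgroup.procs and tasks files do) receives writes, though nothing automatically feeds them back into what Generate returns. A minimal, hypothetical implementer of both interfaces; the signatures match the ones used throughout this diff:

// procsWriteStub accepts writes but, as the doc comment above warns, nothing
// it does in Write feeds back into the data Generate produces.
type procsWriteStub struct{}

// Generate implements vfs.DynamicBytesSource.Generate.
func (*procsWriteStub) Generate(ctx context.Context, buf *bytes.Buffer) error {
    buf.WriteString("42\n")
    return nil
}

// Write implements vfs.WritableDynamicBytesSource.Write.
func (*procsWriteStub) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
    // Consume and acknowledge the payload, as the cgroupfs stubs do today.
    return src.NumBytes(), nil
}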
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index fea138f93..d05cc1508 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -47,7 +47,7 @@ type taskInode struct {
 
 var _ kernfs.Inode = (*taskInode)(nil)
 
-func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) {
+func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, fakeCgroupControllers map[string]string) (kernfs.Inode, error) {
 	if task.ExitState() == kernel.TaskExitDead {
 		return nil, syserror.ESRCH
 	}
@@ -82,10 +82,12 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, fakeCgroupControllers map[string]string) (kernfs.Inode, error) {
 		"uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
 	}
 	if isThreadGroup {
-		contents["task"] = fs.newSubtasks(ctx, task, pidns, cgroupControllers)
+		contents["task"] = fs.newSubtasks(ctx, task, pidns, fakeCgroupControllers)
 	}
-	if len(cgroupControllers) > 0 {
-		contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newCgroupData(cgroupControllers))
+	if len(fakeCgroupControllers) > 0 {
+		contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newFakeCgroupData(fakeCgroupControllers))
+	} else {
+		contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskCgroupData{task: task})
 	}
 
 	taskInode := &taskInode{task: task}
@@ -226,11 +228,14 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData {
 	return &ioData{ioUsage: t}
 }
 
-// newCgroupData creates inode that shows cgroup information.
-// From man 7 cgroups: "For each cgroup hierarchy of which the process is a
-// member, there is one entry containing three colon-separated fields:
-// hierarchy-ID:controller-list:cgroup-path"
-func newCgroupData(controllers map[string]string) dynamicInode {
+// newFakeCgroupData creates an inode that shows fake cgroup
+// information passed in as mount options. From man 7 cgroups: "For
+// each cgroup hierarchy of which the process is a member, there is
+// one entry containing three colon-separated fields:
+// hierarchy-ID:controller-list:cgroup-path"
+//
+// TODO(b/182488796): Remove once all users adopt cgroupfs.
+func newFakeCgroupData(controllers map[string]string) dynamicInode {
 	var buf bytes.Buffer
 
 	// The hierarchy ids must be positive integers (for cgroup v1), but the
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 85909d551..b294dfd6a 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -1100,3 +1100,32 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
 func (fd *namespaceFD) Release(ctx context.Context) {
 	fd.inode.DecRef(ctx)
 }
+
+// taskCgroupData generates data for /proc/[pid]/cgroup.
+//
+// +stateify savable
+type taskCgroupData struct {
+	dynamicBytesFileSetAttr
+	task *kernel.Task
+}
+
+var _ dynamicInode = (*taskCgroupData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *taskCgroupData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	// When a task is exiting on Linux, a task's cgroup set is cleared and
+	// reset to the initial cgroup set, which is essentially the set of root
+	// cgroups. Because of this, the /proc/<pid>/cgroup file is always readable
+	// on Linux throughout a task's lifetime.
+	//
+	// The sentry removes tasks from cgroups during the exit process, but
+	// doesn't move them into an initial cgroup set, so partway through task
+	// exit this file shows a task in no cgroups, which is incorrect. Instead,
+	// once a task has left its cgroups, we return an error.
+	if d.task.ExitState() >= kernel.TaskExitInitiated {
+		return syserror.ESRCH
+	}
+
+	d.task.GenerateProcTaskCgroup(buf)
+	return nil
+}
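Note: both newFakeCgroupData and the new taskCgroupData emit the cgroups(7) record format quoted above. For concreteness, a one-line illustration with made-up values:

package main

import "fmt"

// A /proc/[pid]/cgroup entry, per cgroups(7):
//   hierarchy-ID:controller-list:cgroup-path
func main() {
    hid, controllers, path := 3, "cpu,cpuacct", "/"
    fmt.Printf("%d:%s:%s\n", hid, controllers, path) // 3:cpu,cpuacct:/
}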
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index fdc580610..7c7543f14 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -54,15 +54,15 @@ type tasksInode struct {
 	// '/proc/self' and '/proc/thread-self' have custom directory offsets in
 	// Linux. So handle them outside of OrderedChildren.
 
-	// cgroupControllers is a map of controller name to directory in the
+	// fakeCgroupControllers is a map of controller name to directory in the
 	// cgroup hierarchy. These controllers are immutable and will be listed
 	// in /proc/pid/cgroup if not nil.
-	cgroupControllers map[string]string
+	fakeCgroupControllers map[string]string
 }
 
 var _ kernfs.Inode = (*tasksInode)(nil)
 
-func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode {
+func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode {
 	root := auth.NewRootCredentials(pidns.UserNamespace())
 	contents := map[string]kernfs.Inode{
 		"cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))),
@@ -76,11 +76,16 @@ func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode {
 		"uptime":  fs.newInode(ctx, root, 0444, &uptimeData{}),
 		"version": fs.newInode(ctx, root, 0444, &versionData{}),
 	}
+	// If fakeCgroupControllers are provided, don't create a cgroupfs-backed
+	// /proc/cgroups as it will not match the fake controllers.
+	if len(fakeCgroupControllers) == 0 {
+		contents["cgroups"] = fs.newInode(ctx, root, 0444, &cgroupsData{})
+	}
 
 	inode := &tasksInode{
-		pidns:             pidns,
-		fs:                fs,
-		cgroupControllers: cgroupControllers,
+		pidns:                 pidns,
+		fs:                    fs,
+		fakeCgroupControllers: fakeCgroupControllers,
 	}
 	inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555)
 	inode.InitRefs()
@@ -118,7 +123,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
 		return nil, syserror.ENOENT
 	}
 
-	return i.fs.newTaskInode(ctx, task, i.pidns, true, i.cgroupControllers)
+	return i.fs.newTaskInode(ctx, task, i.pidns, true, i.fakeCgroupControllers)
 }
 
 // IterDirents implements kernfs.inodeDirectory.IterDirents.
diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go
index f0029cda6..e1a8b4409 100644
--- a/pkg/sentry/fsimpl/proc/tasks_files.go
+++ b/pkg/sentry/fsimpl/proc/tasks_files.go
@@ -384,3 +384,19 @@ func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
 	k.VFS().GenerateProcFilesystems(buf)
 	return nil
 }
+
+// cgroupsData backs /proc/cgroups.
+//
+// +stateify savable
+type cgroupsData struct {
+	dynamicBytesFileSetAttr
+}
+
+var _ dynamicInode = (*cgroupsData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (*cgroupsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	r := kernel.KernelFromContext(ctx).CgroupRegistry()
+	r.GenerateProcCgroups(buf)
+	return nil
+}
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go
index cd849e87e..c45bddff6 100644
--- a/pkg/sentry/fsimpl/tmpfs/regular_file.go
+++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go
@@ -488,6 +488,7 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) {
 
 // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap.
 func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error {
 	file := fd.inode().impl.(*regularFile)
+	opts.SentryOwnedContent = true
 	return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts)
 }
 
diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD
index 2da251233..d473a922d 100644
--- a/pkg/sentry/fsimpl/verity/BUILD
+++ b/pkg/sentry/fsimpl/verity/BUILD
@@ -18,10 +18,12 @@ go_library(
         "//pkg/marshal/primitive",
         "//pkg/merkletree",
         "//pkg/refsvfs2",
+        "//pkg/safemem",
         "//pkg/sentry/arch",
         "//pkg/sentry/fs/lock",
         "//pkg/sentry/kernel",
         "//pkg/sentry/kernel/auth",
+        "//pkg/sentry/memmap",
         "//pkg/sentry/socket/unix/transport",
         "//pkg/sentry/vfs",
         "//pkg/sync",
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 6cb1a23e0..214ffd095 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -632,8 +632,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
 	childVD.IncRef()
 	childMerkleVD.IncRef()
 
-	parent.IncRef()
-	child.parent = parent
 	child.name = name
 
 	child.mode = uint32(stat.Mode)
@@ -657,6 +655,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) {
 		}
 	}
 
+	parent.IncRef()
+	child.parent = parent
+
 	return child, nil
 }
 
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index a7d92a878..06f2c211c 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -34,6 +34,8 @@ package verity
 
 import (
+	"bytes"
+	"encoding/hex"
 	"encoding/json"
 	"fmt"
 	"math"
@@ -44,19 +46,20 @@ import (
 	"gvisor.dev/gvisor/pkg/abi/linux"
 	"gvisor.dev/gvisor/pkg/context"
 	"gvisor.dev/gvisor/pkg/fspath"
+	"gvisor.dev/gvisor/pkg/hostarch"
 	"gvisor.dev/gvisor/pkg/marshal/primitive"
 	"gvisor.dev/gvisor/pkg/merkletree"
 	"gvisor.dev/gvisor/pkg/refsvfs2"
+	"gvisor.dev/gvisor/pkg/safemem"
 	"gvisor.dev/gvisor/pkg/sentry/arch"
 	fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
 	"gvisor.dev/gvisor/pkg/sentry/kernel"
 	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/memmap"
 	"gvisor.dev/gvisor/pkg/sentry/vfs"
 	"gvisor.dev/gvisor/pkg/sync"
 	"gvisor.dev/gvisor/pkg/syserror"
 	"gvisor.dev/gvisor/pkg/usermem"
-
-	"gvisor.dev/gvisor/pkg/hostarch"
 )
 
 const (
@@ -103,6 +106,13 @@ var (
 	verityMu sync.RWMutex
 )
 
+// Mount option names for verityfs.
+const (
+	moptLowerPath = "lower_path"
+	moptRootHash  = "root_hash"
+	moptRootName  = "root_name"
+)
+
 // HashAlgorithm is a type specifying the algorithm used to hash the file
 // content.
 type HashAlgorithm int
@@ -169,6 +179,9 @@ type filesystem struct {
 	// system.
 	alg HashAlgorithm
 
+	// opts is the string mount options passed in via opts.Data.
+	opts string
+
 	// renameMu synchronizes renaming with non-renaming operations in order
 	// to ensure consistent lock ordering between dentry.dirMu in different
 	// dentries.
@@ -191,9 +204,6 @@ type filesystem struct {
 //
 // +stateify savable
 type InternalFilesystemOptions struct {
-	// RootMerkleFileName is the name of the verity root Merkle tree file.
-	RootMerkleFileName string
-
 	// LowerName is the name of the filesystem wrapped by verity fs.
 	LowerName string
 
@@ -201,9 +211,6 @@ type InternalFilesystemOptions struct {
 	// system.
 	Alg HashAlgorithm
 
-	// RootHash is the root hash of the overall verity file system.
-	RootHash []byte
-
 	// AllowRuntimeEnable specifies whether the verity file system allows
 	// enabling verification for files (i.e. building Merkle trees) during
 	// runtime.
@@ -237,28 +244,99 @@ func alertIntegrityViolation(msg string) error { // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + mopts := vfs.GenericParseMountOptions(opts.Data) + var rootHash []byte + if encodedRootHash, ok := mopts[moptRootHash]; ok { + delete(mopts, moptRootHash) + hash, err := hex.DecodeString(encodedRootHash) + if err != nil { + ctx.Warningf("verity.FilesystemType.GetFilesystem: Failed to decode root hash: %v", err) + return nil, nil, syserror.EINVAL + } + rootHash = hash + } + var lowerPathname string + if path, ok := mopts[moptLowerPath]; ok { + delete(mopts, moptLowerPath) + lowerPathname = path + } + rootName := "root" + if root, ok := mopts[moptRootName]; ok { + delete(mopts, moptRootName) + rootName = root + } + + // Check for unparsed options. + if len(mopts) != 0 { + ctx.Warningf("verity.FilesystemType.GetFilesystem: unknown options: %v", mopts) + return nil, nil, syserror.EINVAL + } + + // Handle internal options. iopts, ok := opts.InternalData.(InternalFilesystemOptions) - if !ok { + if len(lowerPathname) == 0 && !ok { ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") return nil, nil, syserror.EINVAL } + if len(lowerPathname) != 0 { + if ok { + ctx.Warningf("verity.FilesystemType.GetFilesystem: unexpected verity configs with specified lower path") + return nil, nil, syserror.EINVAL + } + iopts = InternalFilesystemOptions{ + AllowRuntimeEnable: len(rootHash) == 0, + Action: ErrorOnViolation, + } + } action = iopts.Action - // Mount the lower file system. The lower file system is wrapped inside - // verity, and should not be exposed or connected. - mopts := &vfs.MountOptions{ - GetFilesystemOptions: iopts.LowerGetFSOptions, - InternalMount: true, - } - mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts) - if err != nil { - return nil, nil, err + var lowerMount *vfs.Mount + var mountedLowerVD vfs.VirtualDentry + // Use an existing mount if lowerPath is provided. + if len(lowerPathname) != 0 { + vfsroot := vfs.RootFromContext(ctx) + if vfsroot.Ok() { + defer vfsroot.DecRef(ctx) + } + lowerPath := fspath.Parse(lowerPathname) + if !lowerPath.Absolute { + ctx.Infof("verity.FilesystemType.GetFilesystem: lower_path %q must be absolute", lowerPathname) + return nil, nil, syserror.EINVAL + } + var err error + mountedLowerVD, err = vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ + Root: vfsroot, + Start: vfsroot, + Path: lowerPath, + FollowFinalSymlink: true, + }, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + ctx.Infof("verity.FilesystemType.GetFilesystem: failed to resolve lower_path %q: %v", lowerPathname, err) + return nil, nil, err + } + lowerMount = mountedLowerVD.Mount() + defer mountedLowerVD.DecRef(ctx) + } else { + // Mount the lower file system. The lower file system is wrapped inside + // verity, and should not be exposed or connected. 
+ mountOpts := &vfs.MountOptions{ + GetFilesystemOptions: iopts.LowerGetFSOptions, + InternalMount: true, + } + mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mountOpts) + if err != nil { + return nil, nil, err + } + lowerMount = mnt } fs := &filesystem{ creds: creds.Fork(), alg: iopts.Alg, - lowerMount: mnt, + lowerMount: lowerMount, + opts: opts.Data, allowRuntimeEnable: iopts.AllowRuntimeEnable, } fs.vfsfs.Init(vfsObj, &fstype, fs) @@ -266,11 +344,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Construct the root dentry. d := fs.newDentry() d.refs = 1 - lowerVD := vfs.MakeVirtualDentry(mnt, mnt.Root()) + lowerVD := vfs.MakeVirtualDentry(lowerMount, lowerMount.Root()) lowerVD.IncRef() d.lowerVD = lowerVD - rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName + rootMerkleName := merkleRootPrefix + rootName lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, @@ -350,9 +428,15 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt d.mode = uint32(stat.Mode) d.uid = stat.UID d.gid = stat.GID - d.hash = make([]byte, len(iopts.RootHash)) d.childrenNames = make(map[string]struct{}) + d.hashMu.Lock() + d.hash = make([]byte, len(rootHash)) + copy(d.hash, rootHash) + d.hashMu.Unlock() + + fs.rootDentry = d + if !d.isDir() { ctx.Warningf("verity root must be a directory") return nil, nil, syserror.EINVAL @@ -424,13 +508,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } - d.hashMu.Lock() - copy(d.hash, iopts.RootHash) - d.hashMu.Unlock() d.vfsd.Init(d) - fs.rootDentry = d - return &fs.vfsfs, &d.vfsd, nil } @@ -441,7 +520,7 @@ func (fs *filesystem) Release(ctx context.Context) { // MountOptions implements vfs.FilesystemImpl.MountOptions. func (fs *filesystem) MountOptions() string { - return "" + return fs.opts } // dentry implements vfs.DentryImpl. @@ -722,6 +801,10 @@ type fileDescription struct { // underlying file system. lowerFD *vfs.FileDescription + // lowerMappable is the memmap.Mappable corresponding to this file in the + // underlying file system. + lowerMappable memmap.Mappable + // merkleReader is the read-only FileDescription corresponding to the // Merkle tree file in the underlying file system. merkleReader *vfs.FileDescription @@ -1201,6 +1284,24 @@ func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, op return 0, syserror.EROFS } +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *fileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + if err := fd.lowerFD.ConfigureMMap(ctx, opts); err != nil { + return err + } + fd.lowerMappable = opts.Mappable + if opts.MappingIdentity != nil { + opts.MappingIdentity.DecRef(ctx) + opts.MappingIdentity = nil + } + + // Check if mmap is allowed on the lower filesystem. + if !opts.SentryOwnedContent { + return syserror.ENODEV + } + return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts) +} + // LockBSD implements vfs.FileDescriptionImpl.LockBSD. func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return fd.lowerFD.LockBSD(ctx, ownerPID, t, block) @@ -1226,6 +1327,115 @@ func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t return fd.lowerFD.TestPOSIX(ctx, uid, t, r) } +// Translate implements memmap.Mappable.Translate. 
+func (fd *fileDescription) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { + ts, err := fd.lowerMappable.Translate(ctx, required, optional, at) + if err != nil { + return ts, err + } + + // dataSize is the size of the whole file. + dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the xattr does not exist, it + // indicates unexpected modifications to the file system. + if err == syserror.ENODATA { + return ts, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + } + if err != nil { + return ts, err + } + + // The dataSize xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + size, err := strconv.Atoi(dataSize) + if err != nil { + return ts, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + } + + merkleReader := FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + + for _, t := range ts { + // Content integrity relies on sentry owning the backing data. MapInternal is guaranteed + // to fetch sentry owned memory because we disallow verity mmaps otherwise. + ims, err := t.File.MapInternal(memmap.FileRange{t.Offset, t.Offset + t.Source.Length()}, hostarch.Read) + if err != nil { + return nil, err + } + dataReader := mmapReadSeeker{ims, t.Source.Start} + var buf bytes.Buffer + _, err = merkletree.Verify(&merkletree.VerifyParams{ + Out: &buf, + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), + ReadOffset: int64(t.Source.Start), + ReadSize: int64(t.Source.Length()), + Expected: fd.d.hash, + DataAndTreeInSameFile: false, + }) + if err != nil { + return ts, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) + } + } + return ts, err +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (fd *fileDescription) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { + return fd.lowerMappable.AddMapping(ctx, ms, ar, offset, writable) +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +func (fd *fileDescription) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { + fd.lowerMappable.RemoveMapping(ctx, ms, ar, offset, writable) +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +func (fd *fileDescription) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { + return fd.lowerMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable) +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +func (fd *fileDescription) InvalidateUnsavable(context.Context) error { + return nil +} + +// mmapReadSeeker is a helper struct used by fileDescription.Translate to pass +// a safemem.BlockSeq pointing to the mapped region as io.ReaderAt. +type mmapReadSeeker struct { + safemem.BlockSeq + Offset uint64 +} + +// ReadAt implements io.ReaderAt.ReadAt. off is the offset into the mapped file. 
+func (r *mmapReadSeeker) ReadAt(p []byte, off int64) (int, error) { + bs := r.BlockSeq + // Adjust the offset into the mapped file to get the offset into the internally + // mapped region. + readOffset := off - int64(r.Offset) + if readOffset < 0 { + return 0, syserror.EINVAL + } + bs.DropFirst64(uint64(readOffset)) + view := bs.TakeFirst64(uint64(len(p))) + dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(p)) + n, err := safemem.CopySeq(dst, view) + return int(n), err +} + // FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as // io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. type FileReadWriteSeeker struct { diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 57bd65202..5c78a0019 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -89,10 +89,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, AllowUserMount: true, }) + data := "root_name=" + rootMerkleFilename mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ + Data: data, InternalData: InternalFilesystemOptions{ - RootMerkleFileName: rootMerkleFilename, LowerName: "tmpfs", Alg: hashAlg, AllowRuntimeEnable: true, diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index e9eb89378..a1ec6daab 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -141,6 +141,7 @@ go_library( srcs = [ "abstract_socket_namespace.go", "aio.go", + "cgroup.go", "context.go", "fd_table.go", "fd_table_refs.go", @@ -178,6 +179,7 @@ go_library( "task.go", "task_acct.go", "task_block.go", + "task_cgroup.go", "task_clone.go", "task_context.go", "task_exec.go", @@ -241,6 +243,7 @@ go_library( "//pkg/sentry/fs/lock", "//pkg/sentry/fs/timerfd", "//pkg/sentry/fsbridge", + "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/pipefs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/fsimpl/timerfd", diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go new file mode 100644 index 000000000..1f1c63f37 --- /dev/null +++ b/pkg/sentry/kernel/cgroup.go @@ -0,0 +1,281 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "bytes" + "fmt" + "sort" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" +) + +// InvalidCgroupHierarchyID indicates an uninitialized hierarchy ID. +const InvalidCgroupHierarchyID uint32 = 0 + +// CgroupControllerType is the name of a cgroup controller. +type CgroupControllerType string + +// CgroupController is the common interface to cgroup controllers available to +// the entire sentry. The controllers themselves are defined by cgroupfs. 
+//
+// Callers of this interface are often unable to access the synchronization
+// needed to ensure returned values remain valid. Some of the values returned
+// from this interface are thus snapshots in time, and may become stale. This
+// is ok for many callers like procfs.
+type CgroupController interface {
+	// Type returns the type of this cgroup controller (e.g. "memory", "cpu").
+	// Returned value is valid for the lifetime of the controller.
+	Type() CgroupControllerType
+
+	// HierarchyID returns the ID of the hierarchy this cgroup controller is
+	// attached to. Returned value is valid for the lifetime of the controller.
+	HierarchyID() uint32
+
+	// Filesystem returns the filesystem this controller is attached to.
+	// Returned value is valid for the lifetime of the controller.
+	Filesystem() *vfs.Filesystem
+
+	// RootCgroup returns the root cgroup for this controller. Returned value is
+	// valid for the lifetime of the controller.
+	RootCgroup() Cgroup
+
+	// NumCgroups returns the number of cgroups managed by this controller.
+	// Returned value is a snapshot in time.
+	NumCgroups() uint64
+
+	// Enabled returns whether this controller is enabled. Returned value is a
+	// snapshot in time.
+	Enabled() bool
+}
+
+// Cgroup represents a named pointer to a cgroup in cgroupfs. When a task enters
+// a cgroup, it holds a reference on the underlying dentry pointing to the
+// cgroup.
+//
+// +stateify savable
+type Cgroup struct {
+	*kernfs.Dentry
+	CgroupImpl
+}
+
+func (c *Cgroup) decRef() {
+	c.Dentry.DecRef(context.Background())
+}
+
+// Path returns the absolute path of c, relative to its hierarchy root.
+func (c *Cgroup) Path() string {
+	return c.FSLocalPath()
+}
+
+// HierarchyID returns the id of the hierarchy that contains this cgroup.
+func (c *Cgroup) HierarchyID() uint32 {
+	// Note: a cgroup is guaranteed to have at least one controller.
+	return c.Controllers()[0].HierarchyID()
+}
+
+// CgroupImpl is the common interface to cgroups.
+type CgroupImpl interface {
+	Controllers() []CgroupController
+	Enter(t *Task)
+	Leave(t *Task)
+}
+
+// hierarchy represents a cgroupfs filesystem instance, with a unique set of
+// controllers attached to it. Multiple cgroupfs mounts may reference the same
+// hierarchy.
+//
+// +stateify savable
+type hierarchy struct {
+	id uint32
+	// These are a subset of the controllers in CgroupRegistry.controllers,
+	// grouped here by hierarchy for convenient lookup.
+	controllers map[CgroupControllerType]CgroupController
+	// fs is not owned by hierarchy. The FS is responsible for unregistering the
+	// hierarchy on destruction, which removes this association.
+	fs *vfs.Filesystem
+}
+
+func (h *hierarchy) match(ctypes []CgroupControllerType) bool {
+	if len(ctypes) != len(h.controllers) {
+		return false
+	}
+	for _, ty := range ctypes {
+		if _, ok := h.controllers[ty]; !ok {
+			return false
+		}
+	}
+	return true
+}
+
+// CgroupRegistry tracks the active set of cgroup controllers on the system.
+//
+// +stateify savable
+type CgroupRegistry struct {
+	// lastHierarchyID is the id of the last allocated cgroup hierarchy. Valid
+	// ids are from 1 to math.MaxUint32. Must be accessed through atomic ops.
+	//
+	lastHierarchyID uint32
+
+	mu sync.Mutex `state:"nosave"`
+
+	// controllers is the set of currently known cgroup controllers on the
+	// system. Protected by mu.
+	//
+	// +checklocks:mu
+	controllers map[CgroupControllerType]CgroupController
+
+	// hierarchies is the active set of cgroup hierarchies. Protected by mu.
+	//
+	// +checklocks:mu
+	hierarchies map[uint32]hierarchy
+}
+
+func newCgroupRegistry() *CgroupRegistry {
+	return &CgroupRegistry{
+		controllers: make(map[CgroupControllerType]CgroupController),
+		hierarchies: make(map[uint32]hierarchy),
+	}
+}
+
+// nextHierarchyID returns a newly allocated, unique hierarchy ID.
+func (r *CgroupRegistry) nextHierarchyID() (uint32, error) {
+	if hid := atomic.AddUint32(&r.lastHierarchyID, 1); hid != 0 {
+		return hid, nil
+	}
+	return InvalidCgroupHierarchyID, fmt.Errorf("cgroup hierarchy ID overflow")
+}
+
+// FindHierarchy returns a cgroup filesystem containing exactly the set of
+// controllers named in ctypes. If no such FS is found, FindHierarchy returns
+// nil. FindHierarchy takes a reference on the returned FS, which is transferred
+// to the caller.
+func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Filesystem {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	for _, h := range r.hierarchies {
+		if h.match(ctypes) {
+			h.fs.IncRef()
+			return h.fs
+		}
+	}
+
+	return nil
+}
+
+// Register registers the provided set of controllers with the registry as a new
+// hierarchy. If any controller is already registered, the function returns an
+// error without modifying the registry. The hierarchy can later be referenced
+// by the returned id.
+func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if len(cs) == 0 {
+		return InvalidCgroupHierarchyID, fmt.Errorf("can't register hierarchy with no controllers")
+	}
+
+	for _, c := range cs {
+		if _, ok := r.controllers[c.Type()]; ok {
+			return InvalidCgroupHierarchyID, fmt.Errorf("controllers may only be mounted on a single hierarchy")
+		}
+	}
+
+	hid, err := r.nextHierarchyID()
+	if err != nil {
+		return hid, err
+	}
+
+	h := hierarchy{
+		id:          hid,
+		controllers: make(map[CgroupControllerType]CgroupController),
+		fs:          cs[0].Filesystem(),
+	}
+	for _, c := range cs {
+		n := c.Type()
+		r.controllers[n] = c
+		h.controllers[n] = c
+	}
+	r.hierarchies[hid] = h
+	return hid, nil
+}
+
+// Unregister removes a previously registered hierarchy from the registry. If
+// the hierarchy was not previously registered, Unregister is a no-op.
+func (r *CgroupRegistry) Unregister(hid uint32) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if h, ok := r.hierarchies[hid]; ok {
+		for name := range h.controllers {
+			delete(r.controllers, name)
+		}
+		delete(r.hierarchies, hid)
+	}
+}
+
+// computeInitialGroups takes a reference on each of the returned cgroups. The
+// caller takes ownership of these returned references.
+func (r *CgroupRegistry) computeInitialGroups(inherit map[Cgroup]struct{}) map[Cgroup]struct{} {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	ctlSet := make(map[CgroupControllerType]CgroupController)
+	cgset := make(map[Cgroup]struct{})
+
+	// Remember controllers from the inherited cgroups set...
+	for cg := range inherit {
+		cg.IncRef() // Ref transferred to caller.
+		for _, ctl := range cg.Controllers() {
+			ctlSet[ctl.Type()] = ctl
+			cgset[cg] = struct{}{}
+		}
+	}
+
+	// ... and add the root cgroups of all the missing controllers.
+	for name, ctl := range r.controllers {
+		if _, ok := ctlSet[name]; !ok {
+			cg := ctl.RootCgroup()
+			cg.IncRef() // Ref transferred to caller.
+			cgset[cg] = struct{}{}
+		}
+	}
+	return cgset
+}
+
+// GenerateProcCgroups writes the contents of /proc/cgroups to buf.
+func (r *CgroupRegistry) GenerateProcCgroups(buf *bytes.Buffer) {
+	r.mu.Lock()
+	entries := make([]string, 0, len(r.controllers))
+	for _, c := range r.controllers {
+		en := 0
+		if c.Enabled() {
+			en = 1
+		}
+		entries = append(entries, fmt.Sprintf("%s\t%d\t%d\t%d\n", c.Type(), c.HierarchyID(), c.NumCgroups(), en))
+	}
+	r.mu.Unlock()
+
+	sort.Strings(entries)
+	fmt.Fprint(buf, "#subsys_name\thierarchy\tnum_cgroups\tenabled\n")
+	for _, e := range entries {
+		fmt.Fprint(buf, e)
+	}
+}
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 43065b45a..9a4fd64cb 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -294,6 +294,11 @@ type Kernel struct {
 
 	// YAMAPtraceScope is the current level of YAMA ptrace restrictions.
 	YAMAPtraceScope int32
+
+	// cgroupRegistry contains the set of active cgroup controllers on the
+	// system. It is controlled by cgroupfs. Nil if cgroupfs is unavailable on
+	// the system.
+	cgroupRegistry *CgroupRegistry
 }
 
 // InitKernelArgs holds arguments to Init.
@@ -438,6 +443,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
 		k.socketMount = socketMount
 
 		k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord)
+
+		k.cgroupRegistry = newCgroupRegistry()
 	}
 	return nil
 }
@@ -1815,6 +1822,11 @@ func (k *Kernel) SocketMount() *vfs.Mount {
 	return k.socketMount
 }
 
+// CgroupRegistry returns the cgroup registry.
+func (k *Kernel) CgroupRegistry() *CgroupRegistry {
+	return k.cgroupRegistry
+}
+
 // Release releases resources owned by k.
 //
 // Precondition: This should only be called after the kernel is fully
@@ -1831,3 +1843,43 @@ func (k *Kernel) Release() {
 	k.timekeeper.Destroy()
 	k.vdso.Release(ctx)
 }
+
+// PopulateNewCgroupHierarchy moves all tasks into a newly created cgroup
+// hierarchy.
+//
+// Precondition: root must be a new cgroup with no tasks. This implies the
+// controllers for root are also new and currently manage no task, which in turn
+// implies the new cgroup can be populated without migrating tasks between
+// cgroups.
+func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) {
+	k.tasks.mu.RLock()
+	k.tasks.forEachTaskLocked(func(t *Task) {
+		if t.ExitState() != TaskExitNone {
+			return
+		}
+		t.mu.Lock()
+		t.enterCgroupLocked(root)
+		t.mu.Unlock()
+	})
+	k.tasks.mu.RUnlock()
+}
+
+// ReleaseCgroupHierarchy moves all tasks out of all cgroups belonging to the
+// hierarchy with the provided id. This is intended for use during hierarchy
+// teardown, as otherwise the tasks would be orphaned w.r.t. some controllers.
+func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) {
+	k.tasks.mu.RLock()
+	k.tasks.forEachTaskLocked(func(t *Task) {
+		if t.ExitState() != TaskExitNone {
+			return
+		}
+		t.mu.Lock()
+		for cg := range t.cgroups {
+			if cg.HierarchyID() == hid {
+				t.leaveCgroupLocked(cg)
+			}
+		}
+		t.mu.Unlock()
+	})
+	k.tasks.mu.RUnlock()
+}
diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go
index 399985039..be1371855 100644
--- a/pkg/sentry/kernel/task.go
+++ b/pkg/sentry/kernel/task.go
@@ -587,6 +587,12 @@ type Task struct {
 	//
 	// kcov is exclusive to the task goroutine.
 	kcov *Kcov
+
+	// cgroups is the set of cgroups this task belongs to. This may be empty if
+	// no cgroup controllers are enabled. Protected by mu.
+	//
+	// +checklocks:mu
+	cgroups map[Cgroup]struct{}
 }
 
 func (t *Task) savePtraceTracer() *Task {
diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go
new file mode 100644
index 000000000..25d2504fa
--- /dev/null
+++ b/pkg/sentry/kernel/task_cgroup.go
@@ -0,0 +1,138 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kernel
+
+import (
+	"bytes"
+	"fmt"
+	"sort"
+	"strings"
+
+	"gvisor.dev/gvisor/pkg/log"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+// EnterInitialCgroups moves t into an initial set of cgroups.
+//
+// Precondition: t isn't in any cgroups yet, t.cgroups is empty.
+//
+// +checklocksignore parent.mu is conditionally acquired.
+func (t *Task) EnterInitialCgroups(parent *Task) {
+	var inherit map[Cgroup]struct{}
+	if parent != nil {
+		parent.mu.Lock()
+		defer parent.mu.Unlock()
+		inherit = parent.cgroups
+	}
+	joinSet := t.k.cgroupRegistry.computeInitialGroups(inherit)
+
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	// Transfer ownership of joinSet refs to the task's cgset.
+	t.cgroups = joinSet
+	for c := range t.cgroups {
+		// Since t isn't in any cgroup yet, we can skip the check against
+		// existing cgroups.
+		c.Enter(t)
+	}
+}
+
+// EnterCgroup moves t into c.
+func (t *Task) EnterCgroup(c Cgroup) error {
+	newControllers := make(map[CgroupControllerType]struct{})
+	for _, ctl := range c.Controllers() {
+		newControllers[ctl.Type()] = struct{}{}
+	}
+
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	for oldCG := range t.cgroups {
+		for _, oldCtl := range oldCG.Controllers() {
+			if _, ok := newControllers[oldCtl.Type()]; ok {
+				// Already in a cgroup with the same controller as one of the
+				// new ones. Requires migration between cgroups.
+				//
+				// TODO(b/183137098): Implement cgroup migration.
+				log.Warningf("Cgroup migration is not implemented")
+				return syserror.EBUSY
+			}
+		}
+	}
+
+	// No migration required.
+	t.enterCgroupLocked(c)
+
+	return nil
+}
+
+// +checklocks:t.mu
+func (t *Task) enterCgroupLocked(c Cgroup) {
+	c.IncRef()
+	t.cgroups[c] = struct{}{}
+	c.Enter(t)
+}
+
+// LeaveCgroups removes t from all its cgroups.
+func (t *Task) LeaveCgroups() {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	for c := range t.cgroups {
+		t.leaveCgroupLocked(c)
+	}
+}
+
+// +checklocks:t.mu
+func (t *Task) leaveCgroupLocked(c Cgroup) {
+	c.Leave(t)
+	delete(t.cgroups, c)
+	c.decRef()
+}
+
+// taskCgroupEntry represents a line in /proc/<pid>/cgroup, and is used to
+// format a cgroup for display.
+type taskCgroupEntry struct {
+	hierarchyID uint32
+	controllers string
+	path        string
+}
+
+// GenerateProcTaskCgroup writes the contents of /proc/<pid>/cgroup for t to buf.
+func (t *Task) GenerateProcTaskCgroup(buf *bytes.Buffer) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	cgEntries := make([]taskCgroupEntry, 0, len(t.cgroups))
+	for c := range t.cgroups {
+		ctls := c.Controllers()
+		ctlNames := make([]string, 0, len(ctls))
+		for _, ctl := range ctls {
+			ctlNames = append(ctlNames, string(ctl.Type()))
+		}
+
+		cgEntries = append(cgEntries, taskCgroupEntry{
+			// Note: We're guaranteed to have at least one controller, and all
+			// controllers are guaranteed to be on the same hierarchy.
+			hierarchyID: ctls[0].HierarchyID(),
+			controllers: strings.Join(ctlNames, ","),
+			path:        c.Path(),
+		})
+	}
+
+	sort.Slice(cgEntries, func(i, j int) bool { return cgEntries[i].hierarchyID > cgEntries[j].hierarchyID })
+	for _, cgE := range cgEntries {
+		fmt.Fprintf(buf, "%d:%s:%s\n", cgE.hierarchyID, cgE.controllers, cgE.path)
+	}
+}
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index ad59e4f60..b1af1a7ef 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -275,6 +275,10 @@ func (*runExitMain) execute(t *Task) taskRunState {
 	t.fsContext.DecRef(t)
 	t.fdTable.DecRef(t)
 
+	// Detach the task from all cgroups. This must happen before the last ref
+	// to the cgroupfs mount is potentially dropped below.
+	t.LeaveCgroups()
+
 	t.mu.Lock()
 	if t.mountNamespaceVFS2 != nil {
 		t.mountNamespaceVFS2.DecRef(t)
diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go
index fc18b6253..32031cd70 100644
--- a/pkg/sentry/kernel/task_start.go
+++ b/pkg/sentry/kernel/task_start.go
@@ -151,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
 		rseqSignature: cfg.RSeqSignature,
 		futexWaiter: futex.NewWaiter(),
 		containerID: cfg.ContainerID,
+		cgroups: make(map[Cgroup]struct{}),
 	}
 	t.creds.Store(cfg.Credentials)
 	t.endStopCond.L = &t.tg.signalHandlers.mu
@@ -189,6 +190,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) {
 		t.parent.children[t] = struct{}{}
 	}
 
+	if VFS2Enabled {
+		t.EnterInitialCgroups(t.parent)
+	}
+
 	if tg.leader == nil {
 		// New thread group.
 		tg.leader = t
diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go
index 09d070ec8..77ad62445 100644
--- a/pkg/sentry/kernel/threads.go
+++ b/pkg/sentry/kernel/threads.go
@@ -114,6 +114,15 @@ func (ts *TaskSet) forEachThreadGroupLocked(f func(tg *ThreadGroup)) {
 	}
 }
 
+// forEachTaskLocked applies f to each Task in ts.
+//
+// Preconditions: ts.mu must be locked (for reading or writing).
+func (ts *TaskSet) forEachTaskLocked(f func(t *Task)) {
+	for t := range ts.Root.tids {
+		f(t)
+	}
+}
+
 // A PIDNamespace represents a PID namespace, a bimap between thread IDs and
 // tasks. See the pid_namespaces(7) man page for further details.
 //
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 72868646a..610686ea0 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -375,6 +375,11 @@ type MMapOpts struct {
 	//
 	// If Force is true, Unmap and Fixed must be true.
 	Force bool
+
+	// SentryOwnedContent indicates the sentry exclusively controls the
+	// underlying memory backing the mapping, thus the memory content is
+	// guaranteed not to be modified outside the sentry's purview.
+	SentryOwnedContent bool
 }
 
 // File represents a host file that may be mapped into a platform.AddressSpace.
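GenerateProcTaskCgroup above emits one hierarchyID:controllers:path line per cgroup, sorted by descending hierarchy ID, matching Linux's /proc/<pid>/cgroup format. A small self-contained sketch of just that formatting step; the two entries are made-up examples, not data from the change:

package main

import (
	"bytes"
	"fmt"
	"sort"
)

type entry struct {
	hierarchyID uint32
	controllers string // comma-joined controller names
	path        string // cgroup path relative to the hierarchy root
}

func main() {
	entries := []entry{
		{1, "cpu,cpuacct", "/"},
		{2, "memory", "/app"},
	}
	// Sort by descending hierarchy ID, as GenerateProcTaskCgroup does.
	sort.Slice(entries, func(i, j int) bool { return entries[i].hierarchyID > entries[j].hierarchyID })

	var buf bytes.Buffer
	for _, e := range entries {
		fmt.Fprintf(&buf, "%d:%s:%s\n", e.hierarchyID, e.controllers, e.path)
	}
	fmt.Print(buf.String())
	// Prints:
	// 2:memory:/app
	// 1:cpu,cpuacct:/
}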
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 47b294312..b307832fd 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -65,8 +65,8 @@ go_test(
     name = "kvm_test",
     srcs = [
         "kvm_amd64_test.go",
-        "kvm_arm64_test.go",
         "kvm_amd64_test.s",
+        "kvm_arm64_test.go",
         "kvm_test.go",
         "virtual_map_test.go",
     ],
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
index 382953cf7..b8dd1e4a5 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -62,10 +62,17 @@ func TestMXCSR(t *testing.T) {
 			PageTables:  pt,
 			FullRestore: true,
 		}
+
+		const mxcsrControlMask = uint32(0x1f80)
 		mxcsrBefore := uint32(0)
 		mxcsrAfter := uint32(0)
 		stmxcsr(&mxcsrBefore)
-		switchOpts.FloatingPointState.SetMXCSR(mxcsrBefore ^ 0x8)
+		if mxcsrBefore == 0 {
+			// The Go runtime sets mxcsr to 0x1f80 and never changes
+			// the control configuration.
+			panic("mxcsr is zero")
+		}
+		switchOpts.FloatingPointState.SetMXCSR(0)
 		if _, err := c.SwitchToUser(
 			switchOpts, &si); err == platform.ErrContextInterrupt {
 			return true // Retry.
@@ -73,7 +80,7 @@ func TestMXCSR(t *testing.T) {
 			t.Errorf("application syscall failed: %v", err)
 		}
 		stmxcsr(&mxcsrAfter)
-		if mxcsrAfter != mxcsrBefore {
+		if mxcsrAfter&mxcsrControlMask != mxcsrBefore&mxcsrControlMask {
 			t.Errorf("mxcsr = %x (expected %x)", mxcsrAfter, mxcsrBefore)
 		}
 		return false
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 03e84d804..cd912f922 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -47,7 +47,7 @@ const (
 	// Beyond a relatively small number, there are likely few performance
 	// benefits, since the TLB has likely long since lost any translations
 	// from more than a few PCIDs past.
-	poolPCIDs = 8
+	poolPCIDs = 128
 )
 
 func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 9bdf6d3d8..eff251cec 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -35,12 +35,6 @@ import (
 
 // LINT.IfChange
 
-// minListenBacklog is the minimum reasonable backlog for listening sockets.
-const minListenBacklog = 8
-
-// maxListenBacklog is the maximum allowed backlog for listening sockets.
-const maxListenBacklog = 1024
-
 // maxAddrLen is the maximum socket address length we're willing to accept.
 const maxAddrLen = 200
 
@@ -52,6 +46,9 @@ const maxOptLen = 1024 * 8
 // buffers up to INT_MAX.
 const maxControlLen = 10 * 1024 * 1024
 
+// maxListenBacklog is the maximum backlog we support for listening sockets.
+const maxListenBacklog = 1024
+
 // nameLenOffset is the offset from the start of the MessageHeader64 struct to
 // the NameLen field.
 const nameLenOffset = 8
 
@@ -367,7 +364,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
 
 // Listen implements the linux syscall listen(2).
 func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
 	fd := args[0].Int()
-	backlog := args[1].Int()
+	backlog := args[1].Uint()
 
 	// Get socket from the file descriptor.
 	file := t.GetFile(fd)
@@ -382,11 +379,13 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
 		return 0, nil, syserror.ENOTSOCK
 	}
 
-	// Per Linux, the backlog is silently capped to reasonable values.
-	if backlog <= 0 {
-		backlog = minListenBacklog
-	}
 	if backlog > maxListenBacklog {
+		// Linux treats the incoming backlog as a uint, with a limit defined by
+		// sysctl_somaxconn.
+		// https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
+		//
+		// We use the backlog to allocate a channel of that size, hence enforce
+		// a hard limit for the backlog.
 		backlog = maxListenBacklog
 	}
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index a87a66146..936614eab 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -35,12 +35,6 @@ import (
 	"gvisor.dev/gvisor/pkg/hostarch"
 )
 
-// minListenBacklog is the minimum reasonable backlog for listening sockets.
-const minListenBacklog = 8
-
-// maxListenBacklog is the maximum allowed backlog for listening sockets.
-const maxListenBacklog = 1024
-
 // maxAddrLen is the maximum socket address length we're willing to accept.
 const maxAddrLen = 200
 
@@ -52,6 +46,9 @@ const maxOptLen = 1024 * 8
 // buffers up to INT_MAX.
 const maxControlLen = 10 * 1024 * 1024
 
+// maxListenBacklog is the maximum backlog we support for listening sockets.
+const maxListenBacklog = 1024
+
 // nameLenOffset is the offset from the start of the MessageHeader64 struct to
 // the NameLen field.
 const nameLenOffset = 8
 
@@ -371,7 +368,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
 
 // Listen implements the linux syscall listen(2).
 func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
 	fd := args[0].Int()
-	backlog := args[1].Int()
+	backlog := args[1].Uint()
 
 	// Get socket from the file descriptor.
 	file := t.GetFileVFS2(fd)
@@ -386,11 +383,13 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
 		return 0, nil, syserror.ENOTSOCK
 	}
 
-	// Per Linux, the backlog is silently capped to reasonable values.
-	if backlog <= 0 {
-		backlog = minListenBacklog
-	}
 	if backlog > maxListenBacklog {
+		// Linux treats the incoming backlog as a uint, with a limit defined by
+		// sysctl_somaxconn.
+		// https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666
+		//
+		// We use the backlog to allocate a channel of that size, hence enforce
+		// a hard limit for the backlog.
 		backlog = maxListenBacklog
 	}
diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go
index 1556b41a3..b87d9690a 100644
--- a/pkg/sentry/vfs/file_description_impl_util.go
+++ b/pkg/sentry/vfs/file_description_impl_util.go
@@ -252,6 +252,9 @@ type WritableDynamicBytesSource interface {
 // are backed by a bytes.Buffer that is regenerated when necessary, consistent
 // with Linux's fs/seq_file.c:single_open().
 //
+// If data additionally implements WritableDynamicBytesSource, writes are
+// dispatched to the implementer. The source data is not automatically modified.
+//
 // DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first
 // use.
 //
diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go
index 922f9e697..7cdab6945 100644
--- a/pkg/sentry/vfs/mount.go
+++ b/pkg/sentry/vfs/mount.go
@@ -970,17 +970,22 @@ func superBlockOpts(mountPath string, mnt *Mount) string {
 		opts += "," + mopts
 	}
 
-	// NOTE(b/147673608): If the mount is a cgroup, we also need to include
-	// the cgroup name in the options. For now we just read that from the
-	// path.
+	// NOTE(b/147673608): If the mount is a ramdisk-based fake cgroupfs, we also
+	// need to include the cgroup name in the options. For now we just read that
+	// from the path. Note that this is only possible when "cgroup" isn't
+	// registered as a valid filesystem type.
 	//
-	// TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we
-	// should get this value from the cgroup itself, and not rely on the
-	// path.
+	// TODO(gvisor.dev/issue/190): Once we remove fake cgroupfs support, we
+	// should remove this.
+	if cgroupfs := mnt.vfs.getFilesystemType("cgroup"); cgroupfs != nil && cgroupfs.opts.AllowUserMount {
+		// Real cgroupfs available.
+		return opts
+	}
 	if mnt.fs.FilesystemType().Name() == "cgroup" {
 		splitPath := strings.Split(mountPath, "/")
 		cgroupType := splitPath[len(splitPath)-1]
 		opts += "," + cgroupType
 	}
+
 	return opts
 }
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
index f588311e0..85bd164cd 100644
--- a/pkg/tcpip/header/ipv4.go
+++ b/pkg/tcpip/header/ipv4.go
@@ -178,6 +178,26 @@ const (
 	IPv4FlagDontFragment
 )
 
+// ipv4LinkLocalUnicastSubnet is the IPv4 link local unicast subnet as defined
+// by RFC 3927 section 1.
+var ipv4LinkLocalUnicastSubnet = func() tcpip.Subnet {
+	subnet, err := tcpip.NewSubnet("\xa9\xfe\x00\x00", tcpip.AddressMask("\xff\xff\x00\x00"))
+	if err != nil {
+		panic(err)
+	}
+	return subnet
+}()
+
+// ipv4LinkLocalMulticastSubnet is the IPv4 link local multicast subnet as
+// defined by RFC 5771 section 4.
+var ipv4LinkLocalMulticastSubnet = func() tcpip.Subnet {
+	subnet, err := tcpip.NewSubnet("\xe0\x00\x00\x00", tcpip.AddressMask("\xff\xff\xff\x00"))
+	if err != nil {
+		panic(err)
+	}
+	return subnet
+}()
+
 // IPv4EmptySubnet is the empty IPv4 subnet.
 var IPv4EmptySubnet = func() tcpip.Subnet {
 	subnet, err := tcpip.NewSubnet(IPv4Any, tcpip.AddressMask(IPv4Any))
@@ -423,6 +443,18 @@ func (b IPv4) IsValid(pktSize int) bool {
 	return true
 }
 
+// IsV4LinkLocalUnicastAddress determines if the provided address is an IPv4
+// link-local unicast address.
+func IsV4LinkLocalUnicastAddress(addr tcpip.Address) bool {
+	return ipv4LinkLocalUnicastSubnet.Contains(addr)
+}
+
+// IsV4LinkLocalMulticastAddress determines if the provided address is an IPv4
+// link-local multicast address.
+func IsV4LinkLocalMulticastAddress(addr tcpip.Address) bool {
+	return ipv4LinkLocalMulticastSubnet.Contains(addr)
+}
+
 // IsV4MulticastAddress determines if the provided address is an IPv4 multicast
 // address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
 // will be 1110 = 0xe0.
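The two subnets added to ipv4.go above pin down the link-local ranges from RFC 3927 (169.254.0.0/16 unicast) and RFC 5771 (224.0.0.0/24, the Local Network Control Block). A standalone sketch of the same containment checks using only the standard library net package rather than the tcpip types, for illustration:

package main

import (
	"fmt"
	"net"
)

var (
	// 169.254.0.0/16, RFC 3927 section 1.
	v4LinkLocalUnicast = &net.IPNet{IP: net.IPv4(169, 254, 0, 0).To4(), Mask: net.CIDRMask(16, 32)}
	// 224.0.0.0/24, RFC 5771 section 4.
	v4LinkLocalMulticast = &net.IPNet{IP: net.IPv4(224, 0, 0, 0).To4(), Mask: net.CIDRMask(24, 32)}
)

func main() {
	for _, addr := range []string{"169.254.1.1", "224.0.0.251", "224.0.1.1", "8.8.8.8"} {
		ip := net.ParseIP(addr)
		fmt.Printf("%s: linkLocalUnicast=%t linkLocalMulticast=%t\n",
			addr, v4LinkLocalUnicast.Contains(ip), v4LinkLocalMulticast.Contains(ip))
	}
}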
diff --git a/pkg/tcpip/header/ipv4_test.go b/pkg/tcpip/header/ipv4_test.go index 6475cd694..c02fe898b 100644 --- a/pkg/tcpip/header/ipv4_test.go +++ b/pkg/tcpip/header/ipv4_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -177,3 +178,77 @@ func TestIPv4EncodeOptions(t *testing.T) { }) } } + +func TestIsV4LinkLocalUnicastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid (lowest)", + addr: "\xa9\xfe\x00\x00", + expected: true, + }, + { + name: "Valid (highest)", + addr: "\xa9\xfe\xff\xff", + expected: true, + }, + { + name: "Invalid (before subnet)", + addr: "\xa9\xfd\xff\xff", + expected: false, + }, + { + name: "Invalid (after subnet)", + addr: "\xa9\xff\x00\x00", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV4LinkLocalUnicastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV4LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} + +func TestIsV4LinkLocalMulticastAddress(t *testing.T) { + tests := []struct { + name string + addr tcpip.Address + expected bool + }{ + { + name: "Valid (lowest)", + addr: "\xe0\x00\x00\x00", + expected: true, + }, + { + name: "Valid (highest)", + addr: "\xe0\x00\x00\xff", + expected: true, + }, + { + name: "Invalid (before subnet)", + addr: "\xdf\xff\xff\xff", + expected: false, + }, + { + name: "Invalid (after subnet)", + addr: "\xe0\x00\x01\x00", + expected: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := header.IsV4LinkLocalMulticastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV4LinkLocalMulticastAddress(%s) = %t, want = %t", test.addr, got, test.expected) + } + }) + } +} diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index f2403978c..c3a0407ac 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -98,12 +98,27 @@ const ( // The address is ff02::1. IPv6AllNodesMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - // IPv6AllRoutersMulticastAddress is a link-local multicast group that - // all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local + // multicast group that all IPv6 routers MUST join, as per RFC 4291, section + // 2.8. Packets destined to this address will reach the router on an + // interface. + // + // The address is ff01::2. + IPv6AllRoutersInterfaceLocalMulticastAddress tcpip.Address = "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets // destined to this address will reach all routers on a link. // // The address is ff02::2. - IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + IPv6AllRoutersLinkLocalMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group + // that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets + // destined to this address will reach all routers in a site. 
+ // + // The address is ff05::2. + IPv6AllRoutersSiteLocalMulticastAddress tcpip.Address = "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, // section 5: @@ -142,11 +157,6 @@ const ( // ipv6MulticastAddressScopeMask is the mask for the scope (scop) field, // within the byte holding the field, as per RFC 4291 section 2.7. ipv6MulticastAddressScopeMask = 0xF - - // ipv6LinkLocalMulticastScope is the value of the scope (scop) field within - // a multicast IPv6 address that indicates the address has link-local scope, - // as per RFC 4291 section 2.7. - ipv6LinkLocalMulticastScope = 2 ) // IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the @@ -381,25 +391,25 @@ func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address { return tcpip.Address(lladdrb[:]) } -// IsV6LinkLocalAddress determines if the provided address is an IPv6 -// link-local address (fe80::/10). -func IsV6LinkLocalAddress(addr tcpip.Address) bool { +// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6 +// link-local unicast address, as defined by RFC 4291 section 2.5.6. +func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool { if len(addr) != IPv6AddressSize { return false } return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } -// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback -// address. +// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback +// address, as defined by RFC 4291 section 2.5.3. func IsV6LoopbackAddress(addr tcpip.Address) bool { return addr == IPv6Loopback } -// IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 -// link-local multicast address. +// IsV6LinkLocalMulticastAddress returns true iff the provided address is an +// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7. func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { - return IsV6MulticastAddress(addr) && addr[ipv6MulticastAddressScopeByteIdx]&ipv6MulticastAddressScopeMask == ipv6LinkLocalMulticastScope + return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope } // AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier @@ -462,7 +472,7 @@ func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) { case IsV6LinkLocalMulticastAddress(addr): return LinkLocalScope, nil - case IsV6LinkLocalAddress(addr): + case IsV6LinkLocalUnicastAddress(addr): return LinkLocalScope, nil default: @@ -520,3 +530,46 @@ func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) PrefixLen: IIDOffsetInIPv6Address * 8, } } + +// IPv6MulticastScope is the scope of a multicast IPv6 address, as defined by +// RFC 7346 section 2. 
+type IPv6MulticastScope uint8 + +// The various values for IPv6 multicast scopes, as per RFC 7346 section 2: +// +// +------+--------------------------+-------------------------+ +// | scop | NAME | REFERENCE | +// +------+--------------------------+-------------------------+ +// | 0 | Reserved | [RFC4291], RFC 7346 | +// | 1 | Interface-Local scope | [RFC4291], RFC 7346 | +// | 2 | Link-Local scope | [RFC4291], RFC 7346 | +// | 3 | Realm-Local scope | [RFC4291], RFC 7346 | +// | 4 | Admin-Local scope | [RFC4291], RFC 7346 | +// | 5 | Site-Local scope | [RFC4291], RFC 7346 | +// | 6 | Unassigned | | +// | 7 | Unassigned | | +// | 8 | Organization-Local scope | [RFC4291], RFC 7346 | +// | 9 | Unassigned | | +// | A | Unassigned | | +// | B | Unassigned | | +// | C | Unassigned | | +// | D | Unassigned | | +// | E | Global scope | [RFC4291], RFC 7346 | +// | F | Reserved | [RFC4291], RFC 7346 | +// +------+--------------------------+-------------------------+ +const ( + IPv6Reserved0MulticastScope = IPv6MulticastScope(0x0) + IPv6InterfaceLocalMulticastScope = IPv6MulticastScope(0x1) + IPv6LinkLocalMulticastScope = IPv6MulticastScope(0x2) + IPv6RealmLocalMulticastScope = IPv6MulticastScope(0x3) + IPv6AdminLocalMulticastScope = IPv6MulticastScope(0x4) + IPv6SiteLocalMulticastScope = IPv6MulticastScope(0x5) + IPv6OrganizationLocalMulticastScope = IPv6MulticastScope(0x8) + IPv6GlobalMulticastScope = IPv6MulticastScope(0xE) + IPv6ReservedFMulticastScope = IPv6MulticastScope(0xF) +) + +// V6MulticastScope returns the scope of a multicast address. +func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope { + return IPv6MulticastScope(addr[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask) +} diff --git a/pkg/tcpip/header/ipv6_test.go b/pkg/tcpip/header/ipv6_test.go index f10f446a6..ccee9000e 100644 --- a/pkg/tcpip/header/ipv6_test.go +++ b/pkg/tcpip/header/ipv6_test.go @@ -252,7 +252,7 @@ func TestIsV6LinkLocalMulticastAddress(t *testing.T) { } } -func TestIsV6LinkLocalAddress(t *testing.T) { +func TestIsV6LinkLocalUnicastAddress(t *testing.T) { tests := []struct { name string addr tcpip.Address @@ -287,8 +287,8 @@ func TestIsV6LinkLocalAddress(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if got := header.IsV6LinkLocalAddress(test.addr); got != test.expected { - t.Errorf("got header.IsV6LinkLocalAddress(%s) = %t, want = %t", test.addr, got, test.expected) + if got := header.IsV6LinkLocalUnicastAddress(test.addr); got != test.expected { + t.Errorf("got header.IsV6LinkLocalUnicastAddress(%s) = %t, want = %t", test.addr, got, test.expected) } }) } @@ -373,3 +373,83 @@ func TestSolicitedNodeAddr(t *testing.T) { }) } } + +func TestV6MulticastScope(t *testing.T) { + tests := []struct { + addr tcpip.Address + want header.IPv6MulticastScope + }{ + { + addr: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6Reserved0MulticastScope, + }, + { + addr: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6InterfaceLocalMulticastScope, + }, + { + addr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6LinkLocalMulticastScope, + }, + { + addr: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6RealmLocalMulticastScope, + }, + { + addr: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6AdminLocalMulticastScope, + }, + { + addr: 
"\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6SiteLocalMulticastScope, + }, + { + addr: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(6), + }, + { + addr: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(7), + }, + { + addr: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6OrganizationLocalMulticastScope, + }, + { + addr: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(9), + }, + { + addr: "\xff\x0a\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(10), + }, + { + addr: "\xff\x0b\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(11), + }, + { + addr: "\xff\x0c\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(12), + }, + { + addr: "\xff\x0d\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6MulticastScope(13), + }, + { + addr: "\xff\x0e\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6GlobalMulticastScope, + }, + { + addr: "\xff\x0f\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + want: header.IPv6ReservedFMulticastScope, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%s", test.addr), func(t *testing.T) { + if got := header.V6MulticastScope(test.addr); got != test.want { + t.Fatalf("got header.V6MulticastScope(%s) = %d, want = %d", test.addr, got, test.want) + } + }) + } +} diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go index b9f129728..ac35d81e7 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol.go @@ -156,14 +156,6 @@ type GenericMulticastProtocolOptions struct { // // Unsolicited reports are transmitted when a group is newly joined. MaxUnsolicitedReportDelay time.Duration - - // AllNodesAddress is a multicast address that all nodes on a network should - // be a member of. - // - // This address will not have the generic multicast protocol performed on it; - // it will be left in the non member/listener state, and packets will never - // be sent for it. - AllNodesAddress tcpip.Address } // MulticastGroupProtocol is a multicast group protocol whose core state machine @@ -188,6 +180,10 @@ type MulticastGroupProtocol interface { // SendLeave sends a multicast leave for the specified group address. SendLeave(groupAddress tcpip.Address) tcpip.Error + + // ShouldPerformProtocol returns true iff the protocol should be performed for + // the specified group. + ShouldPerformProtocol(tcpip.Address) bool } // GenericMulticastProtocolState is the per interface generic multicast protocol @@ -455,20 +451,7 @@ func (g *GenericMulticastProtocolState) initializeNewMemberLocked(groupAddress t info.lastToSendReport = false - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. 
- // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. + if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { info.state = idleMember return } @@ -537,20 +520,7 @@ func (g *GenericMulticastProtocolState) maybeSendLeave(groupAddress tcpip.Addres return } - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. + if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { return } @@ -627,20 +597,7 @@ func (g *GenericMulticastProtocolState) setDelayTimerForAddressRLocked(groupAddr return } - if groupAddress == g.opts.AllNodesAddress { - // As per RFC 2236 section 6 page 10 (for IGMPv2), - // - // The all-systems group (address 224.0.0.1) is handled as a special - // case. The host starts in Idle Member state for that group on every - // interface, never transitions to another state, and never sends a - // report for that group. - // - // As per RFC 2710 section 5 page 10 (for MLDv1), - // - // The link-scope all-nodes address (FF02::1) is handled as a special - // case. The node starts in Idle Listener state for that address on - // every interface, never transitions to another state, and never sends - // a Report or Done for that address. + if !g.opts.Protocol.ShouldPerformProtocol(groupAddress) { return } diff --git a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go index 381460c82..0b51563cd 100644 --- a/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go +++ b/pkg/tcpip/network/internal/ip/generic_multicast_protocol_test.go @@ -43,6 +43,8 @@ type mockMulticastGroupProtocolProtectedFields struct { type mockMulticastGroupProtocol struct { t *testing.T + skipProtocolAddress tcpip.Address + mu mockMulticastGroupProtocolProtectedFields } @@ -165,6 +167,11 @@ func (m *mockMulticastGroupProtocol) SendLeave(groupAddress tcpip.Address) tcpip return nil } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. 
+func (m *mockMulticastGroupProtocol) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + return groupAddress != m.skipProtocolAddress +} + func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Address, sendLeaveGroupAddresses []tcpip.Address) string { m.mu.Lock() defer m.mu.Unlock() @@ -193,10 +200,11 @@ func (m *mockMulticastGroupProtocol) check(sendReportGroupAddresses []tcpip.Addr cmp.FilterPath( func(p cmp.Path) bool { switch p.Last().String() { - case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup": + case ".RWMutex", ".t", ".makeQueuePackets", ".disabled", ".genericMulticastGroup", ".skipProtocolAddress": return true + default: + return false } - return false }, cmp.Ignore(), ), @@ -225,14 +233,13 @@ func TestJoinGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(0)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, }) // Joining a group should send a report immediately and another after @@ -279,14 +286,13 @@ func TestLeaveGroup(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr2} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(1)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr2, }) mgp.joinGroup(test.addr) @@ -356,14 +362,13 @@ func TestHandleReport(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(2)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) @@ -446,14 +451,13 @@ func TestHandleQuery(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) @@ -574,14 +578,13 @@ func TestJoinCount(t *testing.T) { } func TestMakeAllNonMemberAndInitialize(t *testing.T) { - mgp := mockMulticastGroupProtocol{t: t} + mgp := mockMulticastGroupProtocol{t: t, skipProtocolAddress: addr3} clock := faketime.NewManualClock() mgp.init(ip.GenericMulticastProtocolOptions{ Rand: rand.New(rand.NewSource(3)), Clock: clock, MaxUnsolicitedReportDelay: maxUnsolicitedReportDelay, - AllNodesAddress: addr3, }) mgp.joinGroup(addr1) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index a4edc69c7..58fd18af8 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -15,6 +15,7 @@ package ip_test import ( + "fmt" "strings" "testing" @@ -1938,3 +1939,80 @@ func TestICMPInclusionSize(t *testing.T) { }) } } + +func TestJoinLeaveAllRoutersGroup(t *testing.T) { + const nicID = 1 + + tests := 
[]struct { + name string + netProto tcpip.NetworkProtocolNumber + protoFactory stack.NetworkProtocolFactory + allRoutersAddr tcpip.Address + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + protoFactory: ipv4.NewProtocol, + allRoutersAddr: header.IPv4AllRoutersGroup, + }, + { + name: "IPv6 Interface Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersInterfaceLocalMulticastAddress, + }, + { + name: "IPv6 Link Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersLinkLocalMulticastAddress, + }, + { + name: "IPv6 Site Local", + netProto: ipv6.ProtocolNumber, + protoFactory: ipv6.NewProtocol, + allRoutersAddr: header.IPv6AllRoutersSiteLocalMulticastAddress, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, nicDisabled := range [...]bool{true, false} { + t.Run(fmt.Sprintf("NIC Disabled = %t", nicDisabled), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + }) + opts := stack.NICOptions{Disabled: nicDisabled} + if err := s.CreateNICWithOptions(nicID, channel.New(0, 0, ""), opts); err != nil { + t.Fatalf("CreateNICWithOptions(%d, _, %#v) = %s", nicID, opts, err) + } + + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if got { + t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) + } + + if err := s.SetForwarding(test.netProto, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", test.netProto, err) + } + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if !got { + t.Fatalf("got s.IsInGroup(%d, %s) = false, want = true", nicID, test.allRoutersAddr) + } + + if err := s.SetForwarding(test.netProto, false); err != nil { + t.Fatalf("s.SetForwarding(%d, false): %s", test.netProto, err) + } + if got, err := s.IsInGroup(nicID, test.allRoutersAddr); err != nil { + t.Fatalf("s.IsInGroup(%d, %s): %s", nicID, test.allRoutersAddr, err) + } else if got { + t.Fatalf("got s.IsInGroup(%d, %s) = true, want = false", nicID, test.allRoutersAddr) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv4/igmp.go b/pkg/tcpip/network/ipv4/igmp.go index f3fc1c87e..b1ac29294 100644 --- a/pkg/tcpip/network/ipv4/igmp.go +++ b/pkg/tcpip/network/ipv4/igmp.go @@ -126,6 +126,17 @@ func (igmp *igmpState) SendLeave(groupAddress tcpip.Address) tcpip.Error { return err } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. +func (igmp *igmpState) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + // As per RFC 2236 section 6 page 10, + // + // The all-systems group (address 224.0.0.1) is handled as a special + // case. The host starts in Idle Member state for that group on every + // interface, never transitions to another state, and never sends a + // report for that group. + return groupAddress != header.IPv4AllSystems +} + // init sets up an igmpState struct, and is required to be called before using // a new igmpState. 
// @@ -137,7 +148,6 @@ func (igmp *igmpState) init(ep *endpoint) { Clock: ep.protocol.stack.Clock(), Protocol: igmp, MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax, - AllNodesAddress: header.IPv4AllSystems, }) igmp.igmpV1Present = igmpV1PresentDefault igmp.igmpV1Job = ep.protocol.stack.NewJob(&ep.mu, func() { diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1a5661ca4..2e44f8523 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -150,6 +150,38 @@ func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { delete(p.mu.eps, nicID) } +// transitionForwarding transitions the endpoint's forwarding status to +// forwarding. +// +// Must only be called when the forwarding status changes. +func (e *endpoint) transitionForwarding(forwarding bool) { + e.mu.Lock() + defer e.mu.Unlock() + + if forwarding { + // There does not seem to be an RFC requirement for a node to join the all + // routers multicast address but + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml + // specifies the address as a group for all routers on a subnet so we join + // the group here. + if err := e.joinGroupLocked(header.IPv4AllRoutersGroup); err != nil { + // joinGroupLocked only returns an error if the group address is not a + // valid IPv4 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err)) + } + + return + } + + switch err := e.leaveGroupLocked(header.IPv4AllRoutersGroup).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", header.IPv4AllRoutersGroup, err)) + } +} + // Enable implements stack.NetworkEndpoint. func (e *endpoint) Enable() tcpip.Error { e.mu.Lock() @@ -226,7 +258,7 @@ func (e *endpoint) disableLocked() { } // The endpoint may have already left the multicast group. - switch err := e.leaveGroupLocked(header.IPv4AllSystems); err.(type) { + switch err := e.leaveGroupLocked(header.IPv4AllSystems).(type) { case nil, *tcpip.ErrBadLocalAddress: default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv4AllSystems, err)) @@ -551,6 +583,22 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // forwardPacket attempts to forward a packet to its final destination. func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { h := header.IPv4(pkt.NetworkHeader().View()) + + dstAddr := h.DestinationAddress() + if header.IsV4LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV4LinkLocalUnicastAddress(dstAddr) || header.IsV4LinkLocalMulticastAddress(dstAddr) { + // As per RFC 3927 section 7, + // + // A router MUST NOT forward a packet with an IPv4 Link-Local source or + // destination address, irrespective of the router's default route + // configuration or routes obtained from dynamic routing protocols. + // + // A router which receives a packet with an IPv4 Link-Local source or + // destination address MUST NOT forward the packet. This prevents + // forwarding of packets back onto the network segment from which they + // originated, or to any other segment. + return nil + } + ttl := h.TTL() if ttl == 0 { // As per RFC 792 page 6, Time Exceeded Message, @@ -589,8 +637,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { } } - dstAddr := h.DestinationAddress() - // Check if the destination is owned by the stack. 
if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { ep.handleValidatedPacket(h, pkt) @@ -1168,12 +1214,27 @@ func (p *protocol) Forwarding() bool { return uint8(atomic.LoadUint32(&p.forwarding)) == 1 } +// setForwarding sets the forwarding status for the protocol. +// +// Returns true if the forwarding status was updated. +func (p *protocol) setForwarding(v bool) bool { + if v { + return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) + } + return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) +} + // SetForwarding implements stack.ForwardingNetworkProtocol. func (p *protocol) SetForwarding(v bool) { - if v { - atomic.StoreUint32(&p.forwarding, 1) - } else { - atomic.StoreUint32(&p.forwarding, 0) + p.mu.Lock() + defer p.mu.Unlock() + + if !p.setForwarding(v) { + return + } + + for _, ep := range p.mu.eps { + ep.transitionForwarding(v) } } diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index a142b76c1..b2a80e1e9 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -273,7 +273,7 @@ func isMLDValid(pkt *stack.PacketBuffer, iph header.IPv6, routerAlert *header.IP if iph.HopLimit() != header.MLDHopLimit { return false } - if !header.IsV6LinkLocalAddress(iph.SourceAddress()) { + if !header.IsV6LinkLocalUnicastAddress(iph.SourceAddress()) { return false } return true @@ -804,7 +804,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r routerAddr := srcAddr // Is the IP Source Address a link-local address? - if !header.IsV6LinkLocalAddress(routerAddr) { + if !header.IsV6LinkLocalUnicastAddress(routerAddr) { // ...No, silently drop the packet. received.invalid.Increment() return diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c6d9d8f0d..d36cefcd0 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -314,7 +314,7 @@ func (e *endpoint) onAddressAssignedLocked(addr tcpip.Address) { // Snooping switches MUST manage multicast forwarding state based on MLD // Report and Done messages sent with the unspecified address as the // IPv6 source address. - if header.IsV6LinkLocalAddress(addr) { + if header.IsV6LinkLocalUnicastAddress(addr) { e.mu.mld.sendQueuedReports() } } @@ -410,22 +410,65 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr t // // Must only be called when the forwarding status changes. func (e *endpoint) transitionForwarding(forwarding bool) { + allRoutersGroups := [...]tcpip.Address{ + header.IPv6AllRoutersInterfaceLocalMulticastAddress, + header.IPv6AllRoutersLinkLocalMulticastAddress, + header.IPv6AllRoutersSiteLocalMulticastAddress, + } + e.mu.Lock() defer e.mu.Unlock() - if !e.Enabled() { - return - } - if forwarding { // When transitioning into an IPv6 router, host-only state (NDP discovered // routers, discovered on-link prefixes, and auto-generated addresses) is // cleaned up/invalidated and NDP router solicitations are stopped. e.mu.ndp.stopSolicitingRouters() e.mu.ndp.cleanupState(true /* hostOnly */) - } else { - // When transitioning into an IPv6 host, NDP router solicitations are - // started. + + // As per RFC 4291 section 2.8: + // + // A router is required to recognize all addresses that a host is + // required to recognize, plus the following addresses as identifying + // itself: + // + // o The All-Routers multicast addresses defined in Section 2.7.1. 
+ // + // As per RFC 4291 section 2.7.1, + // + // All Routers Addresses: FF01:0:0:0:0:0:0:2 + // FF02:0:0:0:0:0:0:2 + // FF05:0:0:0:0:0:0:2 + // + // The above multicast addresses identify the group of all IPv6 routers, + // within scope 1 (interface-local), 2 (link-local), or 5 (site-local). + for _, g := range allRoutersGroups { + if err := e.joinGroupLocked(g); err != nil { + // joinGroupLocked only returns an error if the group address is not a + // valid IPv6 multicast address. + panic(fmt.Sprintf("e.joinGroupLocked(%s): %s", g, err)) + } + } + + return + } + + for _, g := range allRoutersGroups { + switch err := e.leaveGroupLocked(g).(type) { + case nil: + case *tcpip.ErrBadLocalAddress: + // The endpoint may have already left the multicast group. + default: + panic(fmt.Sprintf("e.leaveGroupLocked(%s): %s", g, err)) + } + } + + // When transitioning into an IPv6 host, NDP router solicitations are + // started if the endpoint is enabled. + // + // If the endpoint is not currently enabled, routers will be solicited when + // the endpoint becomes enabled (if it is still a host). + if e.Enabled() { e.mu.ndp.startSolicitingRouters() } } @@ -573,7 +616,7 @@ func (e *endpoint) disableLocked() { e.mu.ndp.cleanupState(false /* hostOnly */) // The endpoint may have already left the multicast group. - switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) { + switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress).(type) { case nil, *tcpip.ErrBadLocalAddress: default: panic(fmt.Sprintf("unexpected error when leaving group = %s: %s", header.IPv6AllNodesMulticastAddress, err)) @@ -869,6 +912,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // forwardPacket attempts to forward a packet to its final destination. func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { h := header.IPv6(pkt.NetworkHeader().View()) + + dstAddr := h.DestinationAddress() + if header.IsV6LinkLocalUnicastAddress(h.SourceAddress()) || header.IsV6LinkLocalUnicastAddress(dstAddr) || header.IsV6LinkLocalMulticastAddress(dstAddr) { + // As per RFC 4291 section 2.5.6, + // + // Routers must not forward any packets with Link-Local source or + // destination addresses to other links. + return nil + } + hopLimit := h.HopLimit() if hopLimit <= 1 { // As per RFC 4443 section 3.3, @@ -881,8 +934,6 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { return e.protocol.returnError(&icmpReasonHopLimitExceeded{}, pkt) } - dstAddr := h.DestinationAddress() - // Check if the destination is owned by the stack. if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { ep.handleValidatedPacket(h, pkt) @@ -1571,7 +1622,7 @@ func (e *endpoint) getLinkLocalAddressRLocked() tcpip.Address { var linkLocalAddr tcpip.Address e.mu.addressableEndpointState.ForEachPrimaryEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { if addressEndpoint.IsAssigned(false /* allowExpired */) { - if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalAddress(addr) { + if addr := addressEndpoint.AddressWithPrefix().Address; header.IsV6LinkLocalUnicastAddress(addr) { linkLocalAddr = addr return false } @@ -1979,9 +2030,9 @@ func (p *protocol) Forwarding() bool { // Returns true if the forwarding status was updated. 
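//
// Editorial note: the compare-and-swap form below succeeds only when the
// stored value actually changes, which is what lets SetForwarding skip the
// per-endpoint transitionForwarding work when the requested status is
// already in effect.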
func (p *protocol) setForwarding(v bool) bool { if v { - return atomic.SwapUint32(&p.forwarding, 1) == 0 + return atomic.CompareAndSwapUint32(&p.forwarding, 0 /* old */, 1 /* new */) } - return atomic.SwapUint32(&p.forwarding, 0) == 1 + return atomic.CompareAndSwapUint32(&p.forwarding, 1 /* old */, 0 /* new */) } // SetForwarding implements stack.ForwardingNetworkProtocol. diff --git a/pkg/tcpip/network/ipv6/mld.go b/pkg/tcpip/network/ipv6/mld.go index dd153466d..165b7d2d2 100644 --- a/pkg/tcpip/network/ipv6/mld.go +++ b/pkg/tcpip/network/ipv6/mld.go @@ -76,10 +76,29 @@ func (mld *mldState) SendReport(groupAddress tcpip.Address) (bool, tcpip.Error) // // Precondition: mld.ep.mu must be read locked. func (mld *mldState) SendLeave(groupAddress tcpip.Address) tcpip.Error { - _, err := mld.writePacket(header.IPv6AllRoutersMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) + _, err := mld.writePacket(header.IPv6AllRoutersLinkLocalMulticastAddress, groupAddress, header.ICMPv6MulticastListenerDone) return err } +// ShouldPerformProtocol implements ip.MulticastGroupProtocol. +func (mld *mldState) ShouldPerformProtocol(groupAddress tcpip.Address) bool { + // As per RFC 2710 section 5 page 10, + // + // The link-scope all-nodes address (FF02::1) is handled as a special + // case. The node starts in Idle Listener state for that address on + // every interface, never transitions to another state, and never sends + // a Report or Done for that address. + // + // MLD messages are never sent for multicast addresses whose scope is 0 + // (reserved) or 1 (node-local). + if groupAddress == header.IPv6AllNodesMulticastAddress { + return false + } + + scope := header.V6MulticastScope(groupAddress) + return scope != header.IPv6Reserved0MulticastScope && scope != header.IPv6InterfaceLocalMulticastScope +} + // init sets up an mldState struct, and is required to be called before using // a new mldState. 
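The scope check above leans on how IPv6 multicast addresses encode their scope: for an address of the form ffXY::..., the scope Y is the low nibble of the second byte (RFC 4291 section 2.7). A self-contained sketch of that extraction (illustrative only, not the header package's actual code):

package main

import "fmt"

// v6MulticastScope mirrors what header.V6MulticastScope is used for above:
// it returns the scope nibble of an IPv6 multicast address.
func v6MulticastScope(addr []byte) uint8 {
	return addr[1] & 0x0f
}

func main() {
	// ff01::11 has interface-local scope (1), so MLD never runs for it.
	interfaceLocal := []byte{0xff, 0x01, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x11}
	fmt.Println(v6MulticastScope(interfaceLocal)) // 1

	// ff02::11 has link-local scope (2), so reports are sent as usual.
	linkLocal := []byte{0xff, 0x02, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x11}
	fmt.Println(v6MulticastScope(linkLocal)) // 2
}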
// @@ -91,7 +110,6 @@ func (mld *mldState) init(ep *endpoint) {
 Clock: ep.protocol.stack.Clock(),
 Protocol: mld,
 MaxUnsolicitedReportDelay: UnsolicitedReportIntervalMax,
- AllNodesAddress: header.IPv6AllNodesMulticastAddress,
 })
}
diff --git a/pkg/tcpip/network/ipv6/mld_test.go b/pkg/tcpip/network/ipv6/mld_test.go index 85a8f9944..146b300f1 100644 --- a/pkg/tcpip/network/ipv6/mld_test.go +++ b/pkg/tcpip/network/ipv6/mld_test.go @@ -93,7 +93,7 @@ func TestIPv6JoinLeaveSolicitedNodeAddressPerformsMLD(t *testing.T) {
 if p, ok := e.Read(); !ok {
 t.Fatal("expected a done message to be sent")
 } else {
- validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
+ validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), header.IPv6Any, header.IPv6AllRoutersLinkLocalMulticastAddress, header.ICMPv6MulticastListenerDone, linkLocalAddrSNMC)
 }
}
@@ -464,3 +464,141 @@ func TestMLDPacketValidation(t *testing.T) {
 })
 }
}
+
+func TestMLDSkipProtocol(t *testing.T) {
+ const nicID = 1
+
+ tests := []struct {
+ name string
+ group tcpip.Address
+ expectReport bool
+ }{
+ {
+ name: "Reserved0",
+ group: "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: false,
+ },
+ {
+ name: "Interface Local",
+ group: "\xff\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: false,
+ },
+ {
+ name: "Link Local",
+ group: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Realm Local",
+ group: "\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Admin Local",
+ group: "\xff\x04\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Site Local",
+ group: "\xff\x05\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(6)",
+ group: "\xff\x06\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(7)",
+ group: "\xff\x07\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Organization Local",
+ group: "\xff\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(9)",
+ group: "\xff\x09\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(A)",
+ group: "\xff\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(B)",
+ group: "\xff\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(C)",
+ group: "\xff\x0c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Unassigned(D)",
+ group: "\xff\x0d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "Global",
+ group: "\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ {
+ name: "ReservedF",
+ group: "\xff\x0f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11",
+ expectReport: true,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv6.NewProtocolWithOptions(ipv6.Options{
+ MLD: ipv6.MLDOptions{
+ Enabled: true,
+ },
+ 
})}, + }) + e := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv6.ProtocolNumber, linkLocalAddr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ipv6.ProtocolNumber, linkLocalAddr, err) + } + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, linkLocalAddrSNMC, header.ICMPv6MulticastListenerReport, linkLocalAddrSNMC) + } + + if err := s.JoinGroup(ipv6.ProtocolNumber, nicID, test.group); err != nil { + t.Fatalf("s.JoinGroup(%d, %d, %s): %s", ipv6.ProtocolNumber, nicID, test.group, err) + } + if isInGroup, err := s.IsInGroup(nicID, test.group); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.group, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.group) + } + + if !test.expectReport { + if p, ok := e.Read(); ok { + t.Fatalf("got e.Read() = (%#v, true), want = (_, false)", p) + } + + return + } + + if p, ok := e.Read(); !ok { + t.Fatal("expected a report message to be sent") + } else { + validateMLDPacket(t, stack.PayloadSince(p.Pkt.NetworkHeader()), linkLocalAddr, test.group, header.ICMPv6MulticastListenerReport, test.group) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index 536493f87..a110faa54 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -737,7 +737,7 @@ func (ndp *ndpState) handleRA(ip tcpip.Address, ra header.NDPRouterAdvert) { prefix := opt.Subnet() // Is the prefix a link-local? - if header.IsV6LinkLocalAddress(prefix.ID()) { + if header.IsV6LinkLocalUnicastAddress(prefix.ID()) { // ...Yes, skip as per RFC 4861 section 6.3.4, // and RFC 4862 section 5.5.3.b (for SLAAC). continue @@ -1703,7 +1703,7 @@ func (ndp *ndpState) startSolicitingRouters() { // the unspecified address if no address is assigned // to the sending interface. 
localAddr := header.IPv6Any - if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersMulticastAddress, false); addressEndpoint != nil { + if addressEndpoint := ndp.ep.AcquireOutgoingPrimaryAddress(header.IPv6AllRoutersLinkLocalMulticastAddress, false); addressEndpoint != nil { localAddr = addressEndpoint.AddressWithPrefix().Address addressEndpoint.DecRef() } @@ -1730,7 +1730,7 @@ func (ndp *ndpState) startSolicitingRouters() { icmpData.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpData, Src: localAddr, - Dst: header.IPv6AllRoutersMulticastAddress, + Dst: header.IPv6AllRoutersLinkLocalMulticastAddress, })) pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -1739,14 +1739,14 @@ func (ndp *ndpState) startSolicitingRouters() { }) sent := ndp.ep.stats.icmp.packetsSent - if err := addIPHeader(localAddr, header.IPv6AllRoutersMulticastAddress, pkt, stack.NetworkHeaderParams{ + if err := addIPHeader(localAddr, header.IPv6AllRoutersLinkLocalMulticastAddress, pkt, stack.NetworkHeaderParams{ Protocol: header.ICMPv6ProtocolNumber, TTL: header.NDPHopLimit, }, nil /* extensionHeaders */); err != nil { panic(fmt.Sprintf("failed to add IP header: %s", err)) } - if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { + if err := ndp.ep.nic.WritePacketToRemote(header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress), nil /* gso */, ProtocolNumber, pkt); err != nil { sent.dropped.Increment() // Don't send any more messages if we had an error. remaining = 0 diff --git a/pkg/tcpip/network/multicast_group_test.go b/pkg/tcpip/network/multicast_group_test.go index ecd5003a7..2aa4e6d75 100644 --- a/pkg/tcpip/network/multicast_group_test.go +++ b/pkg/tcpip/network/multicast_group_test.go @@ -194,7 +194,7 @@ func checkInitialIPv6Groups(t *testing.T, e *channel.Endpoint, s *stack.Stack, c if p, ok := e.Read(); !ok { t.Fatal("expected a report message to be sent") } else { - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6AddrSNMC) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6AddrSNMC) } // Should not send any more packets. @@ -606,7 +606,7 @@ func TestMGPLeaveGroup(t *testing.T) { validateLeave: func(t *testing.T, p channel.PacketInfo) { t.Helper() - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, ipv6MulticastAddr1) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, ipv6MulticastAddr1) }, checkInitialGroups: checkInitialIPv6Groups, }, @@ -1014,7 +1014,7 @@ func TestMGPWithNICLifecycle(t *testing.T) { validateLeave: func(t *testing.T, p channel.PacketInfo, addr tcpip.Address) { t.Helper() - validateMLDPacket(t, p, header.IPv6AllRoutersMulticastAddress, mldDone, 0, addr) + validateMLDPacket(t, p, header.IPv6AllRoutersLinkLocalMulticastAddress, mldDone, 0, addr) }, getAndCheckGroupAddress: func(t *testing.T, seen map[tcpip.Address]bool, p channel.PacketInfo) tcpip.Address { t.Helper() diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 14124ae66..a869cce38 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -5204,13 +5204,13 @@ func TestRouterSolicitation(t *testing.T) { } // Make sure the right remote link address is used. 
- if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersMulticastAddress); p.Route.RemoteLinkAddress != want { + if want := header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllRoutersLinkLocalMulticastAddress); p.Route.RemoteLinkAddress != want { t.Errorf("got remote link address = %s, want = %s", p.Route.RemoteLinkAddress, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(test.expectedSrcAddr), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS(checker.NDPRSOptions(test.expectedNDPOpts)), ) @@ -5362,7 +5362,7 @@ func TestStopStartSolicitingRouters(t *testing.T) { } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(header.IPv6Any), - checker.DstAddr(header.IPv6AllRoutersMulticastAddress), + checker.DstAddr(header.IPv6AllRoutersLinkLocalMulticastAddress), checker.TTL(header.NDPHopLimit), checker.NDPRS()) } diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 39344808d..4ae6bed5a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -132,7 +132,7 @@ func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndp localAddr = addressEndpoint.AddressWithPrefix().Address } - if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(localAddr) { + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalUnicastAddress(localAddr) { addressEndpoint.DecRef() return nil } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 931a97ddc..f23112410 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1344,7 +1344,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n s.mu.RLock() defer s.mu.RUnlock() - isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) + isLinkLocal := header.IsV6LinkLocalUnicastAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) @@ -1381,7 +1381,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, n return nil, &tcpip.ErrNetworkUnreachable{} } - canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalUnicastAddress(localAddr) && !isLinkLocal // Find a route to the remote with the route table. var chosenRoute tcpip.Route diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 87ea09a5e..60de16579 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -786,6 +786,13 @@ func (*TCPRecovery) isGettableTransportProtocolOption() {} func (*TCPRecovery) isSettableTransportProtocolOption() {} +// TCPAlwaysUseSynCookies indicates unconditional usage of syncookies. +type TCPAlwaysUseSynCookies bool + +func (*TCPAlwaysUseSynCookies) isGettableTransportProtocolOption() {} + +func (*TCPAlwaysUseSynCookies) isSettableTransportProtocolOption() {} + const ( // TCPRACKLossDetection indicates RACK is used for loss detection and // recovery. 
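The replacement option above is set like any other transport protocol option; the updated tests later in this diff do exactly this. A minimal sketch, assuming a configured *stack.Stack (the helper name is illustrative):

// Assumed imports:
//   "gvisor.dev/gvisor/pkg/tcpip"
//   "gvisor.dev/gvisor/pkg/tcpip/stack"
//   "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
func enableSynCookies(s *stack.Stack) tcpip.Error {
	// Force SYN cookies for every passive connection instead of only
	// once the listener's SYN-RCVD backlog fills up.
	opt := tcpip.TCPAlwaysUseSynCookies(true)
	return s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)
}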
@@ -1020,19 +1027,6 @@ func (*TCPMaxRetriesOption) isGettableTransportProtocolOption() {} func (*TCPMaxRetriesOption) isSettableTransportProtocolOption() {} -// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify -// the number of endpoints that can be in SYN-RCVD state before the stack -// switches to using SYN cookies. -type TCPSynRcvdCountThresholdOption uint64 - -func (*TCPSynRcvdCountThresholdOption) isGettableSocketOption() {} - -func (*TCPSynRcvdCountThresholdOption) isSettableSocketOption() {} - -func (*TCPSynRcvdCountThresholdOption) isGettableTransportProtocolOption() {} - -func (*TCPSynRcvdCountThresholdOption) isSettableTransportProtocolOption() {} - // TCPSynRetriesOption is used by SetSockOpt/GetSockOpt to specify stack-wide // default for number of times SYN is retransmitted before aborting a connect. type TCPSynRetriesOption uint8 diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 3cc8c36f1..3b51e4be0 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -9,6 +9,8 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index d10ae05c2..0de5079e8 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -21,6 +21,8 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -312,3 +314,193 @@ func TestForwarding(t *testing.T) { }) } } + +func TestMulticastForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + ipv4LinkLocalUnicastAddr = tcpip.Address("\xa9\xfe\x00\x0a") + ipv4LinkLocalMulticastAddr = tcpip.Address("\xe0\x00\x00\x0a") + ipv4GlobalMulticastAddr = tcpip.Address("\xe0\x00\x01\x0a") + + ipv6LinkLocalUnicastAddr = tcpip.Address("\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a") + ipv6LinkLocalMulticastAddr = tcpip.Address("\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a") + ipv6GlobalMulticastAddr = tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0a") + + ttl = 64 + ) + + rxICMPv4EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv4EchoRequest(e, src, dst, ttl) + } + + rxICMPv6EchoRequest := func(e *channel.Endpoint, src, dst tcpip.Address) { + utils.RxICMPv6EchoRequest(e, src, dst, ttl) + } + + v4Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv4(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4Echo))) + } + + v6Checker := func(t *testing.T, b []byte, src, dst tcpip.Address) { + checker.IPv6(t, b, + checker.SrcAddr(src), + checker.DstAddr(dst), + checker.TTL(ttl-1), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoRequest))) + } + + tests := []struct { + name string + srcAddr, dstAddr tcpip.Address + rx func(*channel.Endpoint, tcpip.Address, tcpip.Address) + expectForward bool + checker func(*testing.T, []byte) + }{ + { + name: "IPv4 link-local multicast destination", + srcAddr: utils.RemoteIPv4Addr, + 
dstAddr: ipv4LinkLocalMulticastAddr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 link-local source", + srcAddr: ipv4LinkLocalUnicastAddr, + dstAddr: utils.RemoteIPv4Addr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 link-local destination", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4LinkLocalUnicastAddr, + rx: rxICMPv4EchoRequest, + expectForward: false, + }, + { + name: "IPv4 non-link-local unicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + rx: rxICMPv4EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v4Checker(t, b, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv4 non-link-local multicast", + srcAddr: utils.RemoteIPv4Addr, + dstAddr: ipv4GlobalMulticastAddr, + rx: rxICMPv4EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v4Checker(t, b, utils.RemoteIPv4Addr, ipv4GlobalMulticastAddr) + }, + }, + + { + name: "IPv6 link-local multicast destination", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6LinkLocalMulticastAddr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 link-local source", + srcAddr: ipv6LinkLocalUnicastAddr, + dstAddr: utils.RemoteIPv6Addr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 link-local destination", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6LinkLocalUnicastAddr, + rx: rxICMPv6EchoRequest, + expectForward: false, + }, + { + name: "IPv6 non-link-local unicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + rx: rxICMPv6EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v6Checker(t, b, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address) + }, + }, + { + name: "IPv6 non-link-local multicast", + srcAddr: utils.RemoteIPv6Addr, + dstAddr: ipv6GlobalMulticastAddr, + rx: rxICMPv6EchoRequest, + expectForward: true, + checker: func(t *testing.T, b []byte) { + v6Checker(t, b, utils.RemoteIPv6Addr, ipv6GlobalMulticastAddr) + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID1, e1); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNIC(nicID2, e2); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID2, err) + } + + if err := s.AddAddress(nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv4.ProtocolNumber, utils.Ipv4Addr.Address, err) + } + if err := s.AddAddress(nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID2, ipv6.ProtocolNumber, utils.Ipv6Addr.Address, err) + } + + if err := s.SetForwarding(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwarding(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwarding(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID2, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 
nicID2, + }, + }) + + test.rx(e1, test.srcAddr, test.dstAddr) + + p, ok := e2.Read() + if ok != test.expectForward { + t.Fatalf("got e2.Read() = (%#v, %t), want = (_, %t)", p, ok, test.expectForward) + } + + if test.expectForward { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + }) + } +} diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 2c538a43e..82c2e11ab 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -513,22 +513,23 @@ func TestExternalLoopbackTraffic(t *testing.T) { ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") numPackets = 1 + ttl = 64 ) loopbackSourcedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address) + utils.RxICMPv4EchoRequest(e, ipv4Loopback, utils.Ipv4Addr.Address, ttl) } loopbackSourcedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address) + utils.RxICMPv6EchoRequest(e, header.IPv6Loopback, utils.Ipv6Addr.Address, ttl) } loopbackDestinedICMPv4 := func(e *channel.Endpoint) { - utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback) + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, ipv4Loopback, ttl) } loopbackDestinedICMPv6 := func(e *channel.Endpoint) { - utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback) + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, header.IPv6Loopback, ttl) } invalidSrcAddrStat := func(s tcpip.IPStats) *tcpip.StatCounter { diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index c6a9c2393..09ff3b892 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -43,12 +43,15 @@ const ( // to a multicast or broadcast address uses a unicast source address for the // reply. func TestPingMulticastBroadcast(t *testing.T) { - const nicID = 1 + const ( + nicID = 1 + ttl = 64 + ) tests := []struct { name string protoNum tcpip.NetworkProtocolNumber - rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address) + rxICMP func(*channel.Endpoint, tcpip.Address, tcpip.Address, uint8) srcAddr tcpip.Address dstAddr tcpip.Address expectedSrc tcpip.Address @@ -136,7 +139,7 @@ func TestPingMulticastBroadcast(t *testing.T) { }, }) - test.rxICMP(e, test.srcAddr, test.dstAddr) + test.rxICMP(e, test.srcAddr, test.dstAddr, ttl) pkt, ok := e.Read() if !ok { t.Fatal("expected ICMP response") diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index d1c9f3a94..8fd9be32b 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -48,10 +48,6 @@ const ( LinkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") ) -const ( - ttl = 255 -) - // Common IP addresses used by tests. var ( Ipv4Addr = tcpip.AddressWithPrefix{ @@ -322,7 +318,7 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. // RxICMPv4EchoRequest constructs and injects an ICMPv4 echo request packet on // the provided endpoint. 
-func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { +func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) @@ -347,7 +343,7 @@ func RxICMPv4EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { // RxICMPv6EchoRequest constructs and injects an ICMPv6 echo request packet on // the provided endpoint. -func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address) { +func RxICMPv6EchoRequest(e *channel.Endpoint, src, dst tcpip.Address, ttl uint8) { totalLen := header.IPv6MinimumSize + header.ICMPv6MinimumSize hdr := buffer.NewPrependable(totalLen) pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 025b134e2..a485064a1 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -51,11 +51,6 @@ const ( // timestamp and the current timestamp. If the difference is greater // than maxTSDiff, the cookie is expired. maxTSDiff = 2 - - // SynRcvdCountThreshold is the default global maximum number of - // connections that are allowed to be in SYN-RCVD state before TCP - // starts using SYN cookies to accept connections. - SynRcvdCountThreshold uint64 = 1000 ) var ( @@ -80,9 +75,6 @@ func encodeMSS(mss uint16) uint32 { type listenContext struct { stack *stack.Stack - // synRcvdCount is a reference to the stack level synRcvdCount. - synRcvdCount *synRcvdCounter - // rcvWnd is the receive window that is sent by this listening context // in the initial SYN-ACK. rcvWnd seqnum.Size @@ -138,11 +130,6 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, listenEP: listenEP, pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint), } - p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol) - if !ok { - panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk)) - } - l.synRcvdCount = p.SynRcvdCounter() rand.Read(l.nonce[0][:]) rand.Read(l.nonce[1][:]) @@ -199,6 +186,14 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } +func (l *listenContext) useSynCookies() bool { + var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies + if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { + panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) + } + return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull()) +} + // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { @@ -307,6 +302,7 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // Initialize and start the handshake. 
h := ep.newPassiveHandshake(isn, irs, opts, deferAccept)
+ h.listenEP = l.listenEP
+
 h.start()
 return h, nil
}
@@ -485,7 +481,6 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
 }
 go func() {
- defer ctx.synRcvdCount.dec()
 if err := h.complete(); err != nil {
 e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
 e.stats.FailedConnectionAttempts.Increment()
@@ -497,24 +492,29 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
 h.ep.startAcceptedLoop()
 e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
 e.deliverAccepted(h.ep, false /*withSynCookie*/)
- }() // S/R-SAFE: synRcvdCount is the barrier.
+ }()
 return nil
}
-func (e *endpoint) incSynRcvdCount() bool {
+func (e *endpoint) synRcvdBacklogFull() bool {
 e.acceptMu.Lock()
- canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan)
+ acceptedChanCap := cap(e.acceptedChan)
 e.acceptMu.Unlock()
- if canInc {
- atomic.AddInt32(&e.synRcvdCount, 1)
- }
- return canInc
+ // The capacity of the accepted channel is always one greater than the
+ // listen backlog, but the count of connections in SYN-RCVD state is
+ // checked against the listen backlog value for parity with Linux.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+ //
+ // An equality check suffices here because synRcvdCount is incremented
+ // and compared only from a single listener context, and the capacity
+ // of the accepted channel can only grow via a new listen call.
+ return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedChanCap-1
}
 func (e *endpoint) acceptQueueIsFull() bool {
 e.acceptMu.Lock()
- full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan)
+ full := len(e.acceptedChan) == cap(e.acceptedChan)
 e.acceptMu.Unlock()
 return full
}
@@ -538,69 +538,55 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
 switch {
 case s.flags == header.TCPFlagSyn:
- opts := parseSynSegmentOptions(s)
- if ctx.synRcvdCount.inc() {
- // Only handle the syn if the following conditions hold
- // - accept queue is not full.
- // - number of connections in synRcvd state is less than the
- // backlog.
- if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
- s.incRef()
- _ = e.handleSynSegment(ctx, s, &opts)
- return nil
- }
- ctx.synRcvdCount.dec()
+ if e.acceptQueueIsFull() {
 e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
 e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
 e.stack.Stats().DroppedPackets.Increment()
 return nil
- } else {
- // If cookies are in use but the endpoint accept queue
- // is full then drop the syn.
- if e.acceptQueueIsFull() { - e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() - e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() - e.stack.Stats().DroppedPackets.Increment() - return nil - } - cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + } - route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) - if err != nil { - return err - } - defer route.Release() + opts := parseSynSegmentOptions(s) + if !ctx.useSynCookies() { + s.incRef() + atomic.AddInt32(&e.synRcvdCount, 1) + return e.handleSynSegment(ctx, s, &opts) + } + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() - // Send SYN without window scaling because we currently - // don't encode this information in the cookie. - // - // Enable Timestamp option if the original syn did have - // the timestamp option specified. - // - // Use the user supplied MSS on the listening socket for - // new connections, if available. - synOpts := header.TCPSynOptions{ - WS: -1, - TS: opts.TS, - TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), - TSEcr: opts.TSVal, - MSS: calculateAdvertisedMSS(e.userMSS, route), - } - fields := tcpFields{ - id: s.id, - ttl: e.ttl, - tos: e.sendTOS, - flags: header.TCPFlagSyn | header.TCPFlagAck, - seq: cookie, - ack: s.sequenceNumber + 1, - rcvWnd: ctx.rcvWnd, - } - if err := e.sendSynTCP(route, fields, synOpts); err != nil { - return err - } - e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() - return nil + // Send SYN without window scaling because we currently + // don't encode this information in the cookie. + // + // Enable Timestamp option if the original syn did have + // the timestamp option specified. + // + // Use the user supplied MSS on the listening socket for + // new connections, if available. + synOpts := header.TCPSynOptions{ + WS: -1, + TS: opts.TS, + TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), + TSEcr: opts.TSVal, + MSS: calculateAdvertisedMSS(e.userMSS, route), + } + cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + fields := tcpFields{ + id: s.id, + ttl: e.ttl, + tos: e.sendTOS, + flags: header.TCPFlagSyn | header.TCPFlagAck, + seq: cookie, + ack: s.sequenceNumber + 1, + rcvWnd: ctx.rcvWnd, + } + if err := e.sendSynTCP(route, fields, synOpts); err != nil { + return err } + e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() + return nil case (s.flags & header.TCPFlagAck) != 0: if e.acceptQueueIsFull() { @@ -615,25 +601,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil } - if !ctx.synRcvdCount.synCookiesInUse() { - // When not using SYN cookies, as per RFC 793, section 3.9, page 64: - // Any acknowledgment is bad if it arrives on a connection still in - // the LISTEN state. An acceptable reset segment should be formed - // for any arriving ACK-bearing segment. The RST should be - // formatted as follows: - // - // <SEQ=SEG.ACK><CTL=RST> - // - // Send a reset as this is an ACK for which there is no - // half open connections and we are not using cookies - // yet. - // - // The only time we should reach here when a connection - // was opened and closed really quickly and a delayed - // ACK was received from the sender. 
- return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
- }
-
 iss := s.ackNumber - 1
 irs := s.sequenceNumber - 1
@@ -651,7 +618,23 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
 if !ok || int(data) >= len(mssTable) {
 e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
 e.stack.Stats().DroppedPackets.Increment()
- return nil
+
+ // When not using SYN cookies, as per RFC 793, section 3.9, page 64:
+ // Any acknowledgment is bad if it arrives on a connection still in
+ // the LISTEN state. An acceptable reset segment should be formed
+ // for any arriving ACK-bearing segment. The RST should be
+ // formatted as follows:
+ //
+ // <SEQ=SEG.ACK><CTL=RST>
+ //
+ // Send a reset as this is an ACK for which there are no
+ // half-open connections and we are not using cookies
+ // yet.
+ //
+ // The only time we should reach here is when a connection
+ // was opened and closed really quickly and a delayed
+ // ACK was received from the sender.
+ return replyWithReset(e.stack, s, e.sendTOS, e.ttl)
 }
 e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
 // Create newly accepted endpoint and deliver it.
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index a9e978cf6..8f0f0c3e9 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -65,11 +65,12 @@ const (
// NOTE: handshake.ep.mu is held during handshake processing. It is released if
// we are going to block and reacquired when we start processing an event.
type handshake struct {
- ep *endpoint
- state handshakeState
- active bool
- flags header.TCPFlags
- ackNum seqnum.Value
+ ep *endpoint
+ listenEP *endpoint
+ state handshakeState
+ active bool
+ flags header.TCPFlags
+ ackNum seqnum.Value
 // iss is the initial send sequence number, as defined in RFC 793.
 iss seqnum.Value
@@ -394,6 +395,15 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error {
 return nil
 }
+ // Drop the ACK if the accept queue is full.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_ipv4.c#L1523
+ // We could abort the connection as well with a tunable as in
+ // https://github.com/torvalds/linux/blob/7acac4b3196/net/ipv4/tcp_minisocks.c#L788
+ if listenEP := h.listenEP; listenEP != nil && listenEP.acceptQueueIsFull() {
+ listenEP.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
 // Update timestamp if required. See RFC7323, section-4.3.
 if h.ep.sendTSOk && s.parsedOptions.TS {
 h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index f6a16f96e..d6d68f128 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -565,17 +565,15 @@ func TestV4AcceptOnV4(t *testing.T) {
 }
 func testV4ListenClose(t *testing.T, c *context.Context) {
- // Set the SynRcvd threshold to zero to force a syn cookie based accept
- // to happen.
- var opt tcpip.TCPSynRcvdCountThresholdOption
+ opt := tcpip.TCPAlwaysUseSynCookies(true)
 if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
- t.Fatalf("setting TCPSynRcvdCountThresholdOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err)
+ t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
 }
- const n = uint16(32)
+ const n = 32
 // Start listening.
- if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
+ if err := c.EP.Listen(n); err != nil {
 t.Fatalf("Listen failed: %v", err)
 }
@@ -591,9 +589,9 @@ func testV4ListenClose(t *testing.T, c *context.Context) {
 })
 }
- // Each of these ACK's will cause a syn-cookie based connection to be
+ // Each of these ACKs will cause a syn-cookie based connection to be
 accepted and delivered to the listening endpoint.
- for i := uint16(0); i < n; i++ {
+ for i := 0; i < n; i++ {
 b := c.GetPacket()
 tcp := header.TCP(header.IPv4(b).Payload())
 iss := seqnum.Value(tcp.SequenceNumber())
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index c5daba232..5001d222e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2474,6 +2474,10 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) tcpip.Error {
 // Listen puts the endpoint in "listen" mode, which allows it to accept
 // new connections.
 func (e *endpoint) Listen(backlog int) tcpip.Error {
+ // Accept one more than the configured listen backlog to keep in parity with
+ // Linux, whose accept-queue-full check uses a strict inequality and so
+ // admits one extra connection; see:
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937
+ backlog++
 err := e.listen(backlog)
 if err != nil {
 if !err.IgnoreStats() {
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 2a4667906..fe0d7f10f 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -75,63 +75,6 @@ const (
 ccCubic = "cubic"
 )
-// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
-// value is protected by a mutex so that we can increment only when it's
-// guaranteed not to go above a threshold.
-type synRcvdCounter struct {
- sync.Mutex
- value uint64
- pending sync.WaitGroup
- threshold uint64
-}
-
-// inc tries to increment the global number of endpoints in SYN-RCVD state. It
-// succeeds if the increment doesn't make the count go beyond the threshold, and
-// fails otherwise.
-func (s *synRcvdCounter) inc() bool {
- s.Lock()
- defer s.Unlock()
- if s.value >= s.threshold {
- return false
- }
-
- s.pending.Add(1)
- s.value++
-
- return true
-}
-
-// dec atomically decrements the global number of endpoints in SYN-RCVD
-// state. It must only be called if a previous call to inc succeeded.
-func (s *synRcvdCounter) dec() {
- s.Lock()
- defer s.Unlock()
- s.value--
- s.pending.Done()
-}
-
-// synCookiesInUse returns true if the synRcvdCount is greater than
-// SynRcvdCountThreshold.
-func (s *synRcvdCounter) synCookiesInUse() bool {
- s.Lock()
- defer s.Unlock()
- return s.value >= s.threshold
-}
-
-// SetThreshold sets synRcvdCounter.Threshold to ths new threshold.
-func (s *synRcvdCounter) SetThreshold(threshold uint64) {
- s.Lock()
- defer s.Unlock()
- s.threshold = threshold
-}
-
-// Threshold returns the current value of synRcvdCounter.Threhsold.
-func (s *synRcvdCounter) Threshold() uint64 { - s.Lock() - defer s.Unlock() - return s.threshold -} - type protocol struct { stack *stack.Stack @@ -139,6 +82,7 @@ type protocol struct { sackEnabled bool recovery tcpip.TCPRecovery delayEnabled bool + alwaysUseSynCookies bool sendBufferSize tcpip.TCPSendBufferSizeRangeOption recvBufferSize tcpip.TCPReceiveBufferSizeRangeOption congestionControl string @@ -150,7 +94,6 @@ type protocol struct { minRTO time.Duration maxRTO time.Duration maxRetries uint32 - synRcvdCount synRcvdCounter synRetries uint8 dispatcher dispatcher } @@ -373,9 +316,9 @@ func (p *protocol) SetOption(option tcpip.SettableTransportProtocolOption) tcpip p.mu.Unlock() return nil - case *tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPAlwaysUseSynCookies: p.mu.Lock() - p.synRcvdCount.SetThreshold(uint64(*v)) + p.alwaysUseSynCookies = bool(*v) p.mu.Unlock() return nil @@ -480,9 +423,9 @@ func (p *protocol) Option(option tcpip.GettableTransportProtocolOption) tcpip.Er p.mu.RUnlock() return nil - case *tcpip.TCPSynRcvdCountThresholdOption: + case *tcpip.TCPAlwaysUseSynCookies: p.mu.RLock() - *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold()) + *v = tcpip.TCPAlwaysUseSynCookies(p.alwaysUseSynCookies) p.mu.RUnlock() return nil @@ -507,12 +450,6 @@ func (p *protocol) Wait() { p.dispatcher.wait() } -// SynRcvdCounter returns a reference to the synRcvdCount for this protocol -// instance. -func (p *protocol) SynRcvdCounter() *synRcvdCounter { - return &p.synRcvdCount -} - // Parse implements stack.TransportProtocol.Parse. func (*protocol) Parse(pkt *stack.PacketBuffer) bool { return parse.TCP(pkt) @@ -537,7 +474,6 @@ func NewProtocol(s *stack.Stack) stack.TransportProtocol { lingerTimeout: DefaultTCPLingerTimeout, timeWaitTimeout: DefaultTCPTimeWaitTimeout, timeWaitReuse: tcpip.TCPTimeWaitReuseLoopbackOnly, - synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, synRetries: DefaultSynRetries, minRTO: MinRTO, maxRTO: MaxRTO, diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 81f800cad..20c9761f2 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -160,12 +160,9 @@ func TestSackPermittedAccept(t *testing.T) { defer c.Cleanup() if tc.cookieEnabled { - // Set the SynRcvd threshold to - // zero to force a syn cookie - // based accept to happen. - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } setStackSACKPermitted(t, c, sackEnabled) @@ -235,12 +232,9 @@ func TestSackDisabledAccept(t *testing.T) { defer c.Cleanup() if tc.cookieEnabled { - // Set the SynRcvd threshold to - // zero to force a syn cookie - // based accept to happen. 
- var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9c23469f2..5605a4390 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -955,11 +955,7 @@ func TestUserSuppliedMSSOnConnect(t *testing.T) { // when completing the handshake for a new TCP connection from a TCP // listening socket. It should be present in the sent TCP SYN-ACK segment. func TestUserSuppliedMSSOnListenAccept(t *testing.T) { - const ( - nonSynCookieAccepts = 2 - totalAccepts = 4 - mtu = 5000 - ) + const mtu = 5000 ips := []struct { name string @@ -1033,12 +1029,6 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { ip.createEP(c) - // Set the SynRcvd threshold to force a syn cookie based accept to happen. - opt := tcpip.TCPSynRcvdCountThresholdOption(nonSynCookieAccepts) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { t.Fatalf("SetSockOptInt(MaxSegOption, %d): %s", test.setMSS, err) } @@ -1048,13 +1038,17 @@ func TestUserSuppliedMSSOnListenAccept(t *testing.T) { t.Fatalf("Bind(%+v): %s:", bindAddr, err) } - if err := c.EP.Listen(totalAccepts); err != nil { - t.Fatalf("Listen(%d): %s:", totalAccepts, err) + backlog := 5 + // Keep the number of client requests twice to the backlog + // such that half of the connections do not use syncookies + // and the other half does. + clientConnects := backlog * 2 + + if err := c.EP.Listen(backlog); err != nil { + t.Fatalf("Listen(%d): %s:", backlog, err) } - // The first nonSynCookieAccepts packets sent will trigger a gorooutine - // based accept. The rest will trigger a cookie based accept. - for i := 0; i < totalAccepts; i++ { + for i := 0; i < clientConnects; i++ { // Send a SYN requests. iss := seqnum.Value(i) srcPort := context.TestPort + uint16(i) @@ -3087,11 +3081,9 @@ func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { c := context.New(t, mtu) defer c.Cleanup() - // Set the SynRcvd threshold to zero to force a syn cookie based accept - // to happen. - opt := tcpip.TCPSynRcvdCountThresholdOption(0) + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } // Create EP and start listening. @@ -5363,7 +5355,7 @@ func TestListenBacklogFull(t *testing.T) { } lastPortOffset := uint16(0) - for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { + for ; int(lastPortOffset) < listenBacklog+1; lastPortOffset++ { executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) } @@ -5671,15 +5663,13 @@ func TestListenSynRcvdQueueFull(t *testing.T) { } // Test acceptance. - // Start listening. 
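The tests above flip SYN cookie behavior with the new all-or-nothing option instead of a numeric SYN-RCVD threshold. A minimal sketch of setting it on a freshly constructed netstack stack, reusing only the option and call shown in this change (the stack construction is illustrative boilerplate):

package main

import (
	"log"

	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
)

func main() {
	s := stack.New(stack.Options{
		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol},
		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
	})
	// Replaces the old TCPSynRcvdCountThresholdOption(0) idiom: SYN cookies
	// are now either always on, or used only when the accept queue fills up.
	opt := tcpip.TCPAlwaysUseSynCookies(true)
	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil {
		log.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err)
	}
}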
- listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { + if err := c.EP.Listen(0); err != nil { t.Fatalf("Listen failed: %s", err) } // Send two SYN's the first one should get a SYN-ACK, the // second one should not get any response and is dropped as - // the synRcvd count will be equal to backlog. + // the accept queue is full. irs := seqnum.Value(context.TestInitialSequenceNumber) c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -5701,23 +5691,7 @@ func TestListenSynRcvdQueueFull(t *testing.T) { } checker.IPv4(t, b, checker.TCP(tcpCheckers...)) - // Now execute send one more SYN. The stack should not respond as the backlog - // is full at this point. - // - // NOTE: we did not complete the handshake for the previous one so the - // accept backlog should be empty and there should be one connection in - // synRcvd state. - c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 1, - DstPort: context.StackPort, - Flags: header.TCPFlagSyn, - SeqNum: seqnum.Value(889), - RcvWnd: 30000, - }) - c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) - - // Now complete the previous connection and verify that there is a connection - // to accept. + // Now complete the previous connection. // Send ACK. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, @@ -5728,11 +5702,24 @@ func TestListenSynRcvdQueueFull(t *testing.T) { RcvWnd: 30000, }) - // Try to accept the connections in the backlog. + // Verify that it is delivered to the accept queue. we, ch := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&we, waiter.ReadableEvents) defer c.WQ.EventUnregister(&we) + <-ch + + // Now send one more SYN. The stack should not respond as the backlog is + // full at this point. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort + 1, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: seqnum.Value(889), + RcvWnd: 30000, + }) + c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) + // Try to accept the connections in the backlog. newEP, _, err := c.EP.Accept(nil) if _, ok := err.(*tcpip.ErrWouldBlock); ok { // Wait for connection to be established. @@ -5764,11 +5751,6 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() - opt := tcpip.TCPSynRcvdCountThresholdOption(1) - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) - } - // Create TCP endpoint. var err tcpip.Error c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -5781,9 +5763,8 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) { t.Fatalf("Bind failed: %s", err) } - // Start listening. - listenBacklog := 1 - if err := c.EP.Listen(listenBacklog); err != nil { + // Test for SynCookies usage after filling up the backlog.
+ if err := c.EP.Listen(0); err != nil { t.Fatalf("Listen failed: %s", err) } @@ -6066,7 +6047,7 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } - if err := c.EP.Listen(1); err != nil { + if err := c.EP.Listen(0); err != nil { t.Fatalf("Listen failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 2949588ce..1deb1fe4d 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -139,9 +139,9 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS defer c.Cleanup() if cookieEnabled { - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } @@ -202,9 +202,9 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd defer c.Cleanup() if cookieEnabled { - var opt tcpip.TCPSynRcvdCountThresholdOption + opt := tcpip.TCPAlwaysUseSynCookies(true) if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, &opt); err != nil { - t.Fatalf("SetTransportProtocolOption(%d, &%T(%d)): %s", tcp.ProtocolNumber, opt, opt, err) + t.Fatalf("SetTransportProtocolOption(%d, &%T(%t)): %s", tcp.ProtocolNumber, opt, opt, err) } } diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 41fcf4978..06152a444 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -434,7 +434,14 @@ func (c *Container) Wait(ctx context.Context) error { select { case err := <-errChan: return err - case <-statusChan: + case res := <-statusChan: + if res.StatusCode != 0 { + var msg string + if res.Error != nil { + msg = res.Error.Message + } + return fmt.Errorf("container returned non-zero status: %d, msg: %q", res.StatusCode, msg) + } return nil } } diff --git a/runsc/BUILD b/runsc/BUILD index 3b91b984a..e99404eb1 100644 --- a/runsc/BUILD +++ b/runsc/BUILD @@ -9,6 +9,7 @@ go_binary( "version.go", ], pure = True, + tags = ["staging"], visibility = [ "//visibility:public", ], diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 67307ab3c..579edaa2c 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -57,6 +57,7 @@ go_library( "//pkg/sentry/fs/tmpfs", "//pkg/sentry/fs/tty", "//pkg/sentry/fs/user", + "//pkg/sentry/fsimpl/cgroupfs", "//pkg/sentry/fsimpl/devpts", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/fsimpl/fuse", @@ -66,6 +67,7 @@ go_library( "//pkg/sentry/fsimpl/proc", "//pkg/sentry/fsimpl/sys", "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/fsimpl/verity", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel:uncaught_signal_go_proto", diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 1ae76d7d7..05b721b28 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -400,7 +400,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { // Set up the restore environment. 
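Several tests above now call Listen(0) and still expect one successful handshake, relying on the backlog+1 semantics introduced in endpoint.Listen. The same user-visible behavior can be reproduced against the host kernel; a hedged, Linux-specific sketch (not part of this change, error handling mostly elided):

package main

import (
	"fmt"
	"net"

	"golang.org/x/sys/unix"
)

func main() {
	fd, _ := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0)
	defer unix.Close(fd)
	_ = unix.Bind(fd, &unix.SockaddrInet4{Addr: [4]byte{127, 0, 0, 1}})
	_ = unix.Listen(fd, 0) // a backlog of zero...

	sa, _ := unix.Getsockname(fd)
	port := sa.(*unix.SockaddrInet4).Port

	// ...still lets one connection complete its handshake and be accepted.
	conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	nfd, _, _ := unix.Accept(fd)
	fmt.Println("accepted fd:", nfd)
	unix.Close(nfd)
}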
ctx := k.SupervisorContext() - mntr := newContainerMounter(cm.l.root.spec, cm.l.root.goferFDs, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled) + mntr := newContainerMounter(&cm.l.root, cm.l.k, cm.l.mountHints, kernel.VFS2Enabled) if kernel.VFS2Enabled { ctx, err = mntr.configureRestore(ctx, cm.l.root.conf) if err != nil { diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 19ced9b0e..3c0cef6db 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/gofer" "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/fs/user" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/cgroupfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" gofervfs2 "gvisor.dev/gvisor/pkg/sentry/fsimpl/gofer" @@ -103,7 +104,7 @@ func addOverlay(ctx context.Context, conf *config.Config, lower *fs.Inode, name // compileMounts returns the supported mounts from the mount spec, adding any // mandatory mounts that are required by the OCI specification. -func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount { +func compileMounts(spec *specs.Spec, conf *config.Config, vfs2Enabled bool) []specs.Mount { // Keep track of whether proc and sys were mounted. var procMounted, sysMounted, devMounted, devptsMounted bool var mounts []specs.Mount @@ -114,6 +115,11 @@ func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount { log.Warningf("ignoring dev mount at %q", m.Destination) continue } + // Unconditionally drop any cgroupfs mounts. If requested, we'll add our + // own below. + if m.Type == cgroupfs.Name { + continue + } switch filepath.Clean(m.Destination) { case "/proc": procMounted = true @@ -132,6 +138,24 @@ func compileMounts(spec *specs.Spec, vfs2Enabled bool) []specs.Mount { // Mount proc and sys even if the user did not ask for it, as the spec // says we SHOULD. var mandatoryMounts []specs.Mount + + if conf.Cgroupfs { + mandatoryMounts = append(mandatoryMounts, specs.Mount{ + Type: tmpfsvfs2.Name, + Destination: "/sys/fs/cgroup", + }) + mandatoryMounts = append(mandatoryMounts, specs.Mount{ + Type: cgroupfs.Name, + Destination: "/sys/fs/cgroup/memory", + Options: []string{"memory"}, + }) + mandatoryMounts = append(mandatoryMounts, specs.Mount{ + Type: cgroupfs.Name, + Destination: "/sys/fs/cgroup/cpu", + Options: []string{"cpu"}, + }) + } + if !procMounted { mandatoryMounts = append(mandatoryMounts, specs.Mount{ Type: procvfs2.Name, @@ -248,6 +272,10 @@ func isSupportedMountFlag(fstype, opt string) bool { ok, err := parseMountOption(opt, tmpfsAllowedData...) return ok && err == nil } + if fstype == cgroupfs.Name { + ok, err := parseMountOption(opt, cgroupfs.SupportedMountOptions...) 
+ return ok && err == nil + } return false } @@ -572,11 +600,11 @@ type containerMounter struct { hints *podMountHints } -func newContainerMounter(spec *specs.Spec, goferFDs []*fd.FD, k *kernel.Kernel, hints *podMountHints, vfs2Enabled bool) *containerMounter { +func newContainerMounter(info *containerInfo, k *kernel.Kernel, hints *podMountHints, vfs2Enabled bool) *containerMounter { return &containerMounter{ - root: spec.Root, - mounts: compileMounts(spec, vfs2Enabled), - fds: fdDispenser{fds: goferFDs}, + root: info.spec.Root, + mounts: compileMounts(info.spec, info.conf, vfs2Enabled), + fds: fdDispenser{fds: info.goferFDs}, k: k, hints: hints, } @@ -795,7 +823,13 @@ func (c *containerMounter) getMountNameAndOptions(conf *config.Config, m specs.M opts = p9MountData(fd, c.getMountAccessType(conf, m), conf.VFS2) // If configured, add overlay to all writable mounts. useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly - + case cgroupfs.Name: + fsName = m.Type + var err error + opts, err = parseAndFilterOptions(m.Options, cgroupfs.SupportedMountOptions...) + if err != nil { + return "", nil, false, err + } default: log.Warningf("ignoring unknown filesystem type %q", m.Type) } diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 774621970..95daf1f00 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -752,7 +752,7 @@ func (l *Loader) createContainerProcess(root bool, cid string, info *containerIn // Setup the child container file system. l.startGoferMonitor(cid, info.goferFDs) - mntr := newContainerMounter(info.spec, info.goferFDs, l.k, l.mountHints, kernel.VFS2Enabled) + mntr := newContainerMounter(info, l.k, l.mountHints, kernel.VFS2Enabled) if root { if err := mntr.processHints(info.conf, info.procArgs.Credentials); err != nil { return nil, nil, nil, err diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 8b39bc59a..93c476971 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -439,7 +439,13 @@ func TestCreateMountNamespace(t *testing.T) { } defer cleanup() - mntr := newContainerMounter(&tc.spec, []*fd.FD{fd.New(sandEnd)}, nil, &podMountHints{}, false /* vfs2Enabled */) + info := containerInfo{ + conf: conf, + spec: &tc.spec, + goferFDs: []*fd.FD{fd.New(sandEnd)}, + } + + mntr := newContainerMounter(&info, nil, &podMountHints{}, false /* vfs2Enabled */) mns, err := mntr.createMountNamespace(ctx, conf) if err != nil { t.Fatalf("failed to create mount namespace: %v", err) @@ -479,7 +485,7 @@ func TestCreateMountNamespaceVFS2(t *testing.T) { defer l.Destroy() defer loaderCleanup() - mntr := newContainerMounter(l.root.spec, l.root.goferFDs, l.k, l.mountHints, true /* vfs2Enabled */) + mntr := newContainerMounter(&l.root, l.k, l.mountHints, true /* vfs2Enabled */) if err := mntr.processHints(l.root.conf, l.root.procArgs.Credentials); err != nil { t.Fatalf("failed process hints: %v", err) } @@ -702,7 +708,12 @@ func TestRestoreEnvironment(t *testing.T) { for _, ioFD := range tc.ioFDs { ioFDs = append(ioFDs, fd.New(ioFD)) } - mntr := newContainerMounter(tc.spec, ioFDs, nil, &podMountHints{}, false /* vfs2Enabled */) + info := containerInfo{ + conf: conf, + spec: tc.spec, + goferFDs: ioFDs, + } + mntr := newContainerMounter(&info, nil, &podMountHints{}, false /* vfs2Enabled */) actualRenv, err := mntr.createRestoreEnvironment(conf) if !tc.errorExpected && err != nil { t.Fatalf("could not create restore environment for test:%s", tc.name) diff --git a/runsc/boot/vfs.go b/runsc/boot/vfs.go index 9b3dacf46..7d8fd0483 100644 --- 
a/runsc/boot/vfs.go +++ b/runsc/boot/vfs.go @@ -16,6 +16,7 @@ package boot import ( "fmt" + "path" "sort" "strings" @@ -29,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/devices/ttydev" "gvisor.dev/gvisor/pkg/sentry/devices/tundev" "gvisor.dev/gvisor/pkg/sentry/fs/user" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/cgroupfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devpts" "gvisor.dev/gvisor/pkg/sentry/fsimpl/devtmpfs" "gvisor.dev/gvisor/pkg/sentry/fsimpl/fuse" @@ -37,6 +39,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/proc" "gvisor.dev/gvisor/pkg/sentry/fsimpl/sys" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/verity" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -50,6 +53,10 @@ func registerFilesystems(k *kernel.Kernel) error { creds := auth.NewRootCredentials(k.RootUserNamespace()) vfsObj := k.VFS() + vfsObj.MustRegisterFilesystemType(cgroupfs.Name, &cgroupfs.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) vfsObj.MustRegisterFilesystemType(devpts.Name, &devpts.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserList: true, // TODO(b/29356795): Users may mount this once the terminals are in a @@ -60,6 +67,10 @@ func registerFilesystems(k *kernel.Kernel) error { AllowUserMount: true, AllowUserList: true, }) + vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ + AllowUserMount: true, + AllowUserList: true, + }) vfsObj.MustRegisterFilesystemType(gofer.Name, &gofer.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserList: true, }) @@ -79,9 +90,9 @@ func registerFilesystems(k *kernel.Kernel) error { AllowUserMount: true, AllowUserList: true, }) - vfsObj.MustRegisterFilesystemType(fuse.Name, &fuse.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ - AllowUserMount: true, + vfsObj.MustRegisterFilesystemType(verity.Name, &verity.FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ AllowUserList: true, + AllowUserMount: true, }) // Setup files in devtmpfs. @@ -472,6 +483,12 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo var data []string var iopts interface{} + verityData, verityOpts, verityRequested, remainingMOpts, err := parseVerityMountOptions(m.Options) + if err != nil { + return "", nil, false, err + } + m.Options = remainingMOpts + // Find filesystem name and FS specific data field. switch m.Type { case devpts.Name, devtmpfs.Name, proc.Name, sys.Name: @@ -502,6 +519,13 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo // If configured, add overlay to all writable mounts. useOverlay = conf.Overlay && !mountFlags(m.Options).ReadOnly + case cgroupfs.Name: + var err error + data, err = parseAndFilterOptions(m.Options, cgroupfs.SupportedMountOptions...) 
+ if err != nil { + return "", nil, false, err + } + default: log.Warningf("ignoring unknown filesystem type %q", m.Type) return "", nil, false, nil @@ -530,9 +554,75 @@ func (c *containerMounter) getMountNameAndOptionsVFS2(conf *config.Config, m *mo } } + if verityRequested { + verityData = verityData + "root_name=" + path.Base(m.Mount.Destination) + verityOpts.LowerName = fsName + verityOpts.LowerGetFSOptions = opts.GetFilesystemOptions + fsName = verity.Name + opts = &vfs.MountOptions{ + GetFilesystemOptions: vfs.GetFilesystemOptions{ + Data: verityData, + InternalData: verityOpts, + }, + InternalMount: true, + } + } + return fsName, opts, useOverlay, nil } +func parseKeyValue(s string) (string, string, bool) { + tokens := strings.SplitN(s, "=", 2) + if len(tokens) < 2 { + return "", "", false + } + return strings.TrimSpace(tokens[0]), strings.TrimSpace(tokens[1]), true +} + +// parseVerityMountOptions scans the provided mount options for verity-related +// mount options. It returns the parsed set of verity mount options, as well as +// the filtered set of mount options unrelated to verity. +func parseVerityMountOptions(mopts []string) (string, verity.InternalFilesystemOptions, bool, []string, error) { + nonVerity := []string{} + found := false + var rootHash string + verityOpts := verity.InternalFilesystemOptions{ + Action: verity.PanicOnViolation, + } + for _, o := range mopts { + if !strings.HasPrefix(o, "verity.") { + nonVerity = append(nonVerity, o) + continue + } + + k, v, ok := parseKeyValue(o) + if !ok { + return "", verityOpts, found, nonVerity, fmt.Errorf("invalid verity mount option with no value: %q", o) + } + + found = true + switch k { + case "verity.roothash": + rootHash = v + case "verity.action": + switch v { + case "error": + verityOpts.Action = verity.ErrorOnViolation + case "panic": + verityOpts.Action = verity.PanicOnViolation + default: + log.Warningf("Invalid verity action %q", v) + verityOpts.Action = verity.PanicOnViolation + } + default: + return "", verityOpts, found, nonVerity, fmt.Errorf("unknown verity mount option: %q", k) + } + } + verityOpts.AllowRuntimeEnable = len(rootHash) == 0 + verityData := "root_hash=" + rootHash + "," + return verityData, verityOpts, found, nonVerity, nil +} + // mountTmpVFS2 mounts an internal tmpfs at '/tmp' if it's safe to do so. // Technically we don't have to mount tmpfs at /tmp, as we could just rely on the host /tmp, but this is a nice optimization, and fixes some apps that call diff --git a/runsc/cli/main.go b/runsc/cli/main.go index a3c515f4b..6db6614cc 100644 --- a/runsc/cli/main.go +++ b/runsc/cli/main.go @@ -86,6 +86,7 @@ func Main(version string) { subcommands.Register(new(cmd.Symbolize), "") subcommands.Register(new(cmd.Wait), "") subcommands.Register(new(cmd.Mitigate), "") + subcommands.Register(new(cmd.VerityPrepare), "") // Register internal commands with the internal group name. This causes // them to be sorted below the user-facing commands with empty group.
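The verity options ride alongside regular mount options and are split off by the "verity." prefix, using the key=value convention that parseKeyValue implements. A self-contained demo of that convention (the names below are local to this example, not part of the change):

package main

import (
	"fmt"
	"strings"
)

// splitKeyValue behaves like the parseKeyValue helper above.
func splitKeyValue(opt string) (key, value string, ok bool) {
	tokens := strings.SplitN(opt, "=", 2)
	if len(tokens) < 2 {
		return "", "", false
	}
	return tokens[0], tokens[1], true
}

func main() {
	mopts := []string{"ro", "verity.roothash=deadbeef", "verity.action=error"}
	var nonVerity []string
	for _, o := range mopts {
		if !strings.HasPrefix(o, "verity.") {
			nonVerity = append(nonVerity, o) // passed through to the lower filesystem
			continue
		}
		k, v, _ := splitKeyValue(o)
		fmt.Printf("verity option %s = %q\n", k, v)
	}
	fmt.Println("remaining options:", nonVerity) // [ro]
}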
diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index 2c3b4058b..4b9987cf6 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -35,6 +35,7 @@ go_library( "statefile.go", "symbolize.go", "syscalls.go", + "verity_prepare.go", "wait.go", ], visibility = [ diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go index 455c57692..5485db149 100644 --- a/runsc/cmd/do.go +++ b/runsc/cmd/do.go @@ -126,9 +126,8 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su Hostname: hostname, } - specutils.LogSpec(spec) - cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000)) + if conf.Network == config.NetworkNone { addNamespace(spec, specs.LinuxNamespace{Type: specs.NetworkNamespace}) @@ -154,55 +153,7 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su } } - out, err := json.Marshal(spec) - if err != nil { - return Errorf("Error to marshal spec: %v", err) - } - tmpDir, err := ioutil.TempDir("", "runsc-do") - if err != nil { - return Errorf("Error to create tmp dir: %v", err) - } - defer os.RemoveAll(tmpDir) - - log.Infof("Changing configuration RootDir to %q", tmpDir) - conf.RootDir = tmpDir - - cfgPath := filepath.Join(tmpDir, "config.json") - if err := ioutil.WriteFile(cfgPath, out, 0755); err != nil { - return Errorf("Error write spec: %v", err) - } - - containerArgs := container.Args{ - ID: cid, - Spec: spec, - BundleDir: tmpDir, - Attached: true, - } - ct, err := container.New(conf, containerArgs) - if err != nil { - return Errorf("creating container: %v", err) - } - defer ct.Destroy() - - if err := ct.Start(conf); err != nil { - return Errorf("starting container: %v", err) - } - - // Forward signals to init in the container. Thus if we get SIGINT from - // ^C, the container gracefully exit, and we can clean up. - // - // N.B. There is a still a window before this where a signal may kill - // this process, skipping cleanup. - stopForwarding := ct.ForwardSignals(0 /* pid */, false /* fgProcess */) - defer stopForwarding() - - ws, err := ct.Wait() - if err != nil { - return Errorf("waiting for container: %v", err) - } - - *waitStatus = ws - return subcommands.ExitSuccess + return startContainerAndWait(spec, conf, cid, waitStatus) } func addNamespace(spec *specs.Spec, ns specs.LinuxNamespace) { @@ -397,3 +348,58 @@ func calculatePeerIP(ip string) (string, error) { } return fmt.Sprintf("%s.%s.%s.%d", parts[0], parts[1], parts[2], n), nil } + +func startContainerAndWait(spec *specs.Spec, conf *config.Config, cid string, waitStatus *unix.WaitStatus) subcommands.ExitStatus { + specutils.LogSpec(spec) + + out, err := json.Marshal(spec) + if err != nil { + return Errorf("Error to marshal spec: %v", err) + } + tmpDir, err := ioutil.TempDir("", "runsc-do") + if err != nil { + return Errorf("Error to create tmp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + log.Infof("Changing configuration RootDir to %q", tmpDir) + conf.RootDir = tmpDir + + cfgPath := filepath.Join(tmpDir, "config.json") + if err := ioutil.WriteFile(cfgPath, out, 0755); err != nil { + return Errorf("Error write spec: %v", err) + } + + containerArgs := container.Args{ + ID: cid, + Spec: spec, + BundleDir: tmpDir, + Attached: true, + } + + ct, err := container.New(conf, containerArgs) + if err != nil { + return Errorf("creating container: %v", err) + } + defer ct.Destroy() + + if err := ct.Start(conf); err != nil { + return Errorf("starting container: %v", err) + } + + // Forward signals to init in the container. 
Thus if we get SIGINT from + ^C, the container gracefully exits, and we can clean up. + // + // N.B. There is still a window before this where a signal may kill + // this process, skipping cleanup. + stopForwarding := ct.ForwardSignals(0 /* pid */, false /* fgProcess */) + defer stopForwarding() + + ws, err := ct.Wait() + if err != nil { + return Errorf("waiting for container: %v", err) + } + + *waitStatus = ws + return subcommands.ExitSuccess +} diff --git a/runsc/cmd/verity_prepare.go b/runsc/cmd/verity_prepare.go new file mode 100644 index 000000000..66128b2a3 --- /dev/null +++ b/runsc/cmd/verity_prepare.go @@ -0,0 +1,108 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "context" + "fmt" + "math/rand" + "os" + + "github.com/google/subcommands" + specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/runsc/config" + "gvisor.dev/gvisor/runsc/flag" + "gvisor.dev/gvisor/runsc/specutils" +) + +// VerityPrepare implements subcommands.Command for the "verity-prepare" +// command. It sets up a sandbox with a writable verity mount mapped to "--dir", +// and executes the verity measure tool specified by "--tool" in the sandbox. It +// is intended to prepare --dir to be mounted as a verity filesystem. +type VerityPrepare struct { + root string + tool string + dir string +} + +// Name implements subcommands.Command.Name. +func (*VerityPrepare) Name() string { + return "verity-prepare" +} + +// Synopsis implements subcommands.Command.Synopsis. +func (*VerityPrepare) Synopsis() string { + return "Generates the data structures necessary to enable verityfs on a filesystem." +} + +// Usage implements subcommands.Command.Usage. +func (*VerityPrepare) Usage() string { + return "verity-prepare --tool=<measure_tool> --dir=<path>" +} + +// SetFlags implements subcommands.Command.SetFlags. +func (c *VerityPrepare) SetFlags(f *flag.FlagSet) { + f.StringVar(&c.root, "root", "/", `path to the root directory, defaults to "/"`) + f.StringVar(&c.tool, "tool", "", "path to the verity measure_tool") + f.StringVar(&c.dir, "dir", "", "path to the directory to be hashed") +} + +// Execute implements subcommands.Command.Execute. +func (c *VerityPrepare) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + conf := args[0].(*config.Config) + waitStatus := args[1].(*unix.WaitStatus) + + hostname, err := os.Hostname() + if err != nil { + return Errorf("Error retrieving hostname: %v", err) + } + + // Map the entire host file system.
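The spec built just below bind-mounts --dir at /verityroot with a deliberately empty verity.roothash=. Per parseVerityMountOptions above, an empty hash sets AllowRuntimeEnable, so the measure tool can enable verity and emit hashes at runtime. Assuming only the runtime-spec Go types, the injected mount amounts to:

package main

import (
	"fmt"

	specs "github.com/opencontainers/runtime-spec/specs-go"
)

func main() {
	m := specs.Mount{
		Source:      "/path/to/dir", // the value of --dir
		Destination: "/verityroot",
		Type:        "bind",
		// An empty hash flips the verity mount into runtime-enable mode.
		Options: []string{"verity.roothash="},
	}
	fmt.Printf("%+v\n", m)
}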
+ absRoot, err := resolvePath(c.root) + if err != nil { + return Errorf("Error resolving root: %v", err) + } + + spec := &specs.Spec{ + Root: &specs.Root{ + Path: absRoot, + }, + Process: &specs.Process{ + Cwd: absRoot, + Args: []string{c.tool, "--path", "/verityroot"}, + Env: os.Environ(), + Capabilities: specutils.AllCapabilities(), + }, + Hostname: hostname, + Mounts: []specs.Mount{ + specs.Mount{ + Source: c.dir, + Destination: "/verityroot", + Type: "bind", + Options: []string{"verity.roothash="}, + }, + }, + } + + cid := fmt.Sprintf("runsc-%06d", rand.Int31n(1000000)) + + // Force no networking; it is not needed to run the verity measure tool. + conf.Network = config.NetworkNone + + conf.Verity = true + + return startContainerAndWait(spec, conf, cid, waitStatus) +} diff --git a/runsc/config/config.go b/runsc/config/config.go index 1e5858837..0b2b97cc5 100644 --- a/runsc/config/config.go +++ b/runsc/config/config.go @@ -172,6 +172,9 @@ type Config struct { // Enables seccomp inside the sandbox. OCISeccomp bool `flag:"oci-seccomp"` + // Mounts the cgroup filesystem backed by the sentry's cgroupfs. + Cgroupfs bool `flag:"cgroupfs"` + // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in // tests. It allows runsc to start the sandbox process as the current // user, and without chrooting the sandbox process. This can be diff --git a/runsc/config/flags.go b/runsc/config/flags.go index 1d996c841..13a1a0163 100644 --- a/runsc/config/flags.go +++ b/runsc/config/flags.go @@ -75,6 +75,7 @@ func RegisterFlags() { flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") flag.Bool("vfs2", false, "enables VFSv2. This uses the new VFS layer that is faster than the previous one.") flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.") + flag.Bool("cgroupfs", false, "Automatically mount cgroupfs.") // Flags that control sandbox runtime behavior: network related. flag.Var(networkTypePtr(NetworkSandbox), "network", "specifies which network to use: sandbox (default), host, none.
Using network inside the sandbox is more secure because it's isolated from the host network.") diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index 450f92645..47da2dd10 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -486,7 +486,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn } if deviceFile, err := gPlatform.OpenDevice(); err != nil { - return fmt.Errorf("opening device file for platform %q: %v", gPlatform, err) + return fmt.Errorf("opening device file for platform %q: %v", conf.Platform, err) } else if deviceFile != nil { defer deviceFile.Close() cmd.ExtraFiles = append(cmd.ExtraFiles, deviceFile) @@ -1174,7 +1174,7 @@ func deviceFileForPlatform(name string) (*os.File, error) { f, err := p.OpenDevice() if err != nil { - return nil, fmt.Errorf("opening device file for platform %q: %v", p, err) + return nil, fmt.Errorf("opening device file for platform %q: %w", name, err) } return f, nil } diff --git a/runsc/specutils/fs.go b/runsc/specutils/fs.go index b62504a8c..9ecd0fde6 100644 --- a/runsc/specutils/fs.go +++ b/runsc/specutils/fs.go @@ -18,6 +18,7 @@ import ( "fmt" "math/bits" "path" + "strings" specs "github.com/opencontainers/runtime-spec/specs-go" "golang.org/x/sys/unix" @@ -64,6 +65,12 @@ var optionsMap = map[string]mapping{ "sync": {set: true, val: unix.MS_SYNCHRONOUS}, } +// verityMountOptions is the set of valid verity mount option keys. +var verityMountOptions = map[string]struct{}{ + "verity.roothash": struct{}{}, + "verity.action": struct{}{}, +} + // propOptionsMap is similar to optionsMap, but it lists propagation options // that cannot be used together with other flags. var propOptionsMap = map[string]mapping{ @@ -117,6 +124,14 @@ func validateMount(mnt *specs.Mount) error { return nil } +func moptKey(opt string) string { + if len(opt) == 0 { + return opt + } + // Guaranteed to have at least one token, since opt is not empty. + return strings.SplitN(opt, "=", 2)[0] +} + // ValidateMountOptions validates that mount options are correct. 
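A side note on the %v-to-%w switch in deviceFileForPlatform above: wrapping with %w keeps the underlying error reachable through errors.Is/errors.As, where %v flattens it into text. A minimal standard-library illustration (the function here is hypothetical):

package main

import (
	"errors"
	"fmt"
	"os"
)

func openDevice(name string) error {
	if _, err := os.Open("/dev/" + name); err != nil {
		// %w preserves the cause; %v would not.
		return fmt.Errorf("opening device file for platform %q: %w", name, err)
	}
	return nil
}

func main() {
	err := openDevice("no-such-platform")
	fmt.Println(errors.Is(err, os.ErrNotExist)) // true: the cause survives wrapping
}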
func ValidateMountOptions(opts []string) error { for _, o := range opts { @@ -125,7 +140,8 @@ func ValidateMountOptions(opts []string) error { } _, ok1 := optionsMap[o] _, ok2 := propOptionsMap[o] - if !ok1 && !ok2 { + _, ok3 := verityMountOptions[moptKey(o)] + if !ok1 && !ok2 && !ok3 { return fmt.Errorf("unknown mount option %q", o) } if err := validatePropagation(o); err != nil { diff --git a/shim/BUILD b/shim/BUILD index 434269d31..695f61eb9 100644 --- a/shim/BUILD +++ b/shim/BUILD @@ -6,6 +6,7 @@ go_binary( name = "containerd-shim-runsc-v1", srcs = ["main.go"], static = True, + tags = ["staging"], visibility = [ "//visibility:public", ], diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD index 697ab5837..a5a3cf2c1 100644 --- a/test/benchmarks/base/BUILD +++ b/test/benchmarks/base/BUILD @@ -17,7 +17,6 @@ go_library( benchmark_test( name = "startup_test", - size = "enormous", srcs = ["startup_test.go"], visibility = ["//:sandbox"], deps = [ @@ -29,7 +28,6 @@ benchmark_test( benchmark_test( name = "size_test", - size = "enormous", srcs = ["size_test.go"], visibility = ["//:sandbox"], deps = [ @@ -42,7 +40,6 @@ benchmark_test( benchmark_test( name = "sysbench_test", - size = "enormous", srcs = ["sysbench_test.go"], visibility = ["//:sandbox"], deps = [ diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD index 0b1743603..fee2695ff 100644 --- a/test/benchmarks/database/BUILD +++ b/test/benchmarks/database/BUILD @@ -11,7 +11,6 @@ go_library( benchmark_test( name = "redis_test", - size = "enormous", srcs = ["redis_test.go"], library = ":database", visibility = ["//:sandbox"], diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD index dc82e63b2..c2b981a07 100644 --- a/test/benchmarks/fs/BUILD +++ b/test/benchmarks/fs/BUILD @@ -4,7 +4,6 @@ package(licenses = ["notice"]) benchmark_test( name = "bazel_test", - size = "enormous", srcs = ["bazel_test.go"], visibility = ["//:sandbox"], deps = [ @@ -18,7 +17,6 @@ benchmark_test( benchmark_test( name = "fio_test", - size = "enormous", srcs = ["fio_test.go"], visibility = ["//:sandbox"], deps = [ diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD index 380783f0b..ad2ef3a55 100644 --- a/test/benchmarks/media/BUILD +++ b/test/benchmarks/media/BUILD @@ -11,7 +11,6 @@ go_library( benchmark_test( name = "ffmpeg_test", - size = "enormous", srcs = ["ffmpeg_test.go"], library = ":media", visibility = ["//:sandbox"], diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD index 3425b8dad..56a4d4f39 100644 --- a/test/benchmarks/ml/BUILD +++ b/test/benchmarks/ml/BUILD @@ -11,7 +11,6 @@ go_library( benchmark_test( name = "tensorflow_test", - size = "enormous", srcs = ["tensorflow_test.go"], library = ":ml", visibility = ["//:sandbox"], diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD index 2741570f5..e047020bf 100644 --- a/test/benchmarks/network/BUILD +++ b/test/benchmarks/network/BUILD @@ -18,7 +18,6 @@ go_library( benchmark_test( name = "iperf_test", - size = "enormous", srcs = [ "iperf_test.go", ], @@ -34,7 +33,6 @@ benchmark_test( benchmark_test( name = "node_test", - size = "enormous", srcs = [ "node_test.go", ], @@ -49,7 +47,6 @@ benchmark_test( benchmark_test( name = "ruby_test", - size = "enormous", srcs = [ "ruby_test.go", ], @@ -64,7 +61,6 @@ benchmark_test( benchmark_test( name = "nginx_test", - size = "enormous", srcs = [ "nginx_test.go", ], @@ -79,7 +75,6 @@ benchmark_test( benchmark_test( name = "httpd_test", - size = "enormous", srcs = [ 
"httpd_test.go", ], diff --git a/test/e2e/BUILD b/test/e2e/BUILD index 29a84f184..3b3dadf04 100644 --- a/test/e2e/BUILD +++ b/test/e2e/BUILD @@ -8,7 +8,6 @@ go_test( srcs = [ "exec_test.go", "integration_test.go", - "regression_test.go", ], library = ":integration", tags = [ diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index e83576722..1accc3b3b 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -168,13 +168,6 @@ func TestCheckpointRestore(t *testing.T) { t.Skip("Pause/resume is not supported.") } - // TODO(gvisor.dev/issue/3373): Remove after implementing. - if usingVFS2, err := dockerutil.UsingVFS2(); usingVFS2 { - t.Skip("CheckpointRestore not implemented in VFS2.") - } else if err != nil { - t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err) - } - ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) defer d.CleanUp(ctx) @@ -592,6 +585,30 @@ func runIntegrationTest(t *testing.T, capAdd []string, args ...string) { } } +// Test that UDS can be created using overlay when parent directory is in lower +// layer only (b/134090485). +// +// Prerequisite: the directory where the socket file is created must not have +// been open for write before bind(2) is called. +func TestBindOverlay(t *testing.T) { + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + // Run the container. + got, err := d.Run(ctx, dockerutil.RunOpts{ + Image: "basic/ubuntu", + }, "bash", "-c", "nc -q -1 -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -q 0 -U /var/run/sock && wait $p") + if err != nil { + t.Fatalf("docker run failed: %v", err) + } + + // Check the output contains what we want. + if want := "foobar-asdf"; !strings.Contains(got, want) { + t.Fatalf("docker run output is missing %q: %s", want, got) + } +} + func TestMain(m *testing.M) { dockerutil.EnsureSupportedDockerVersion() flag.Parse() diff --git a/test/e2e/regression_test.go b/test/e2e/regression_test.go deleted file mode 100644 index 84564cdaa..000000000 --- a/test/e2e/regression_test.go +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package integration - -import ( - "context" - "strings" - "testing" - - "gvisor.dev/gvisor/pkg/test/dockerutil" -) - -// Test that UDS can be created using overlay when parent directory is in lower -// layer only (b/134090485). -// -// Prerequisite: the directory where the socket file is created must not have -// been open for write before bind(2) is called. -func TestBindOverlay(t *testing.T) { - ctx := context.Background() - d := dockerutil.MakeContainer(ctx, t) - defer d.CleanUp(ctx) - - // Run the container. - got, err := d.Run(ctx, dockerutil.RunOpts{ - Image: "basic/ubuntu", - }, "bash", "-c", "nc -q -1 -l -U /var/run/sock & p=$! 
&& sleep 1 && echo foobar-asdf | nc -q 0 -U /var/run/sock && wait $p") - if err != nil { - t.Fatalf("docker run failed: %v", err) - } - - // Check the output contains what we want. - if want := "foobar-asdf"; !strings.Contains(got, want) { - t.Fatalf("docker run output is missing %q: %s", want, got) - } -} diff --git a/test/fsstress/BUILD b/test/fsstress/BUILD index d262c8554..e74e7fff2 100644 --- a/test/fsstress/BUILD +++ b/test/fsstress/BUILD @@ -14,9 +14,7 @@ go_test( "manual", "local", ], - deps = [ - "//pkg/test/dockerutil", - ], + deps = ["//pkg/test/dockerutil"], ) go_library( diff --git a/test/fsstress/fsstress_test.go b/test/fsstress/fsstress_test.go index 300c21ceb..d53c8f90d 100644 --- a/test/fsstress/fsstress_test.go +++ b/test/fsstress/fsstress_test.go @@ -17,7 +17,9 @@ package fsstress import ( "context" + "flag" "math/rand" + "os" "strconv" "strings" "testing" @@ -30,33 +32,44 @@ func init() { rand.Seed(int64(time.Now().Nanosecond())) } -func fsstress(t *testing.T, dir string) { +func TestMain(m *testing.M) { + dockerutil.EnsureSupportedDockerVersion() + flag.Parse() + os.Exit(m.Run()) +} + +type config struct { + operations string + processes string + target string +} + +func fsstress(t *testing.T, conf config) { ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) defer d.CleanUp(ctx) - const ( - operations = "10000" - processes = "100" - image = "basic/fsstress" - ) + const image = "basic/fsstress" seed := strconv.FormatUint(uint64(rand.Uint32()), 10) - args := []string{"-d", dir, "-n", operations, "-p", processes, "-s", seed, "-X"} - t.Logf("Repro: docker run --rm --runtime=runsc %s %s", image, strings.Join(args, "")) + args := []string{"-d", conf.target, "-n", conf.operations, "-p", conf.processes, "-s", seed, "-X"} + t.Logf("Repro: docker run --rm --runtime=%s gvisor.dev/images/%s %s", dockerutil.Runtime(), image, strings.Join(args, " ")) out, err := d.Run(ctx, dockerutil.RunOpts{Image: image}, args...) if err != nil { t.Fatalf("docker run failed: %v\noutput: %s", err, out) } - lines := strings.SplitN(out, "\n", 2) - if len(lines) > 1 || !strings.HasPrefix(out, "seed =") { + // This is to catch cases where fsstress spews out error messages during clean + // up but doesn't return error. + if len(out) > 0 { t.Fatalf("unexpected output: %s", out) } } -func TestFsstressGofer(t *testing.T) { - fsstress(t, "/test") -} - func TestFsstressTmpfs(t *testing.T) { - fsstress(t, "/tmp") + // This takes between 10s to run on my machine. Adjust as needed. + cfg := config{ + operations: "5000", + processes: "20", + target: "/tmp", + } + fsstress(t, cfg) } diff --git a/test/image/image_test.go b/test/image/image_test.go index 968e62f63..952264173 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -183,7 +183,10 @@ func TestMysql(t *testing.T) { // Start the container. if err := server.Spawn(ctx, dockerutil.RunOpts{ Image: "basic/mysql", - Env: []string{"MYSQL_ROOT_PASSWORD=foobar123"}, + Env: []string{ + "MYSQL_ROOT_PASSWORD=foobar123", + "MYSQL_ROOT_HOST=%", // Allow anyone to connect to the server. 
+ }, }); err != nil { t.Fatalf("docker run failed: %v", err) } diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 34e83ec49..634c15727 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -246,6 +246,12 @@ ALL_TESTS = [ expect_netstack_failure = True, ), PacketimpactTestInfo( + name = "tcp_listen_backlog", + ), + PacketimpactTestInfo( + name = "tcp_syncookie", + ), + PacketimpactTestInfo( name = "icmpv6_param_problem", ), PacketimpactTestInfo( diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 92103c1e9..83ff70951 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -385,6 +385,26 @@ packetimpact_testbench( ], ) +packetimpact_testbench( + name = "tcp_listen_backlog", + srcs = ["tcp_listen_backlog_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_testbench( + name = "tcp_syncookie", + srcs = ["tcp_syncookie_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + validate_all_tests() [packetimpact_go_test( diff --git a/test/packetimpact/tests/tcp_listen_backlog_test.go b/test/packetimpact/tests/tcp_listen_backlog_test.go new file mode 100644 index 000000000..26c812d0a --- /dev/null +++ b/test/packetimpact/tests/tcp_listen_backlog_test.go @@ -0,0 +1,86 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_listen_backlog_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" ) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +// TestTCPListenBacklog tests the following behaviors of a listening endpoint: +// (1) reply to more SYNs than what is configured as listen backlog +// (2) ignore ACKs (that complete a handshake) when the accept queue is full +// (3) ignore incoming SYNs when the accept queue is full +func TestTCPListenBacklog(t *testing.T) { + dut := testbench.NewDUT(t) + + // Listening endpoint accepts one more connection than the listen backlog. + listenFd, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 0 /*backlog*/) + + var establishedConn testbench.TCPIPv4 + var incompleteConn testbench.TCPIPv4 + + // Test if the DUT listener replies to more SYNs than the listen backlog+1. + for i, conn := range []*testbench.TCPIPv4{&establishedConn, &incompleteConn} { + *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + // Expect dut connection to have transitioned to SYN-RCVD state.
+ conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN-ACK for connection %d: %s", i, err) + } + } + defer establishedConn.Close(t) + defer incompleteConn.Close(t) + + // Send the ACK to complete handshake. + establishedConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + dut.PollOne(t, listenFd, unix.POLLIN, time.Second) + + // Send the ACK to complete handshake, expect this to be ignored by the + // listener. + incompleteConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + + // Drain the accept queue to enable poll for subsequent connections on the + // listener. + dut.Accept(t, listenFd) + + // The ACK for the incomplete connection should be ignored by the + // listening endpoint and the poll on listener should now time out. + if pfds := dut.Poll(t, []unix.PollFd{{Fd: listenFd, Events: unix.POLLIN}}, time.Second); len(pfds) != 0 { + t.Fatalf("got dut.Poll(...) = %#v", pfds) + } + + // Re-send the ACK to complete handshake and re-fill the accept-queue. + incompleteConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + dut.PollOne(t, listenFd, unix.POLLIN, time.Second) + + // Now initiate a new connection when the accept queue is full. + connectingConn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer connectingConn.Close(t) + // Expect dut connection to drop the SYN and let the client stay in SYN_SENT state. + connectingConn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}) + if got, err := connectingConn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err == nil { + t.Fatalf("expected no SYN-ACK, but got %s", got) + } +} diff --git a/test/packetimpact/tests/tcp_syncookie_test.go b/test/packetimpact/tests/tcp_syncookie_test.go new file mode 100644 index 000000000..1c21c62ff --- /dev/null +++ b/test/packetimpact/tests/tcp_syncookie_test.go @@ -0,0 +1,70 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_syncookie_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +// TestTCPSynCookie tests whether the DUT listener replies using SYN cookies. +// The test deliberately does not send the ACK to the SYN-ACK, leaving the +// handshake incomplete. When SYN cookies are not used, this forces the +// listener to retransmit the SYN-ACK; with SYN cookies, there is no retransmit. +func TestTCPSynCookie(t *testing.T) { + dut := testbench.NewDUT(t) + + // Listening endpoint accepts one more connection than the listen backlog.
+ _, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + + var withoutSynCookieConn testbench.TCPIPv4 + var withSynCookieConn testbench.TCPIPv4 + + // Set up one connection to be handled without SYN cookies and one with them. + for _, conn := range []*testbench.TCPIPv4{&withoutSynCookieConn, &withSynCookieConn} { + *conn = dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + } + defer withoutSynCookieConn.Close(t) + defer withSynCookieConn.Close(t) + + checkSynAck := func(t *testing.T, conn *testbench.TCPIPv4, expectRetransmit bool) { + // Expect dut connection to have transitioned to SYN-RCVD state. + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected SYN-ACK, but got %s", err) + } + + // If the DUT listener is using SYN cookies, it will not retransmit the SYN-ACK. + got, err := conn.ExpectData(t, &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)), Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, 2*time.Second) + if expectRetransmit && err != nil { + t.Fatalf("expected retransmitted SYN-ACK, but got %s", err) + } + if !expectRetransmit && err == nil { + t.Fatalf("expected no retransmitted SYN-ACK, but got %s", got) + } + } + + t.Run("without syncookies", func(t *testing.T) { checkSynAck(t, &withoutSynCookieConn, true /*expectRetransmit*/) }) + t.Run("with syncookies", func(t *testing.T) { checkSynAck(t, &withSynCookieConn, false /*expectRetransmit*/) }) +} diff --git a/test/perf/BUILD b/test/perf/BUILD index ed899ac22..71982fc4d 100644 --- a/test/perf/BUILD +++ b/test/perf/BUILD @@ -35,7 +35,7 @@ syscall_test( ) syscall_test( - size = "enormous", + size = "large", debug = False, tags = ["nogotsan"], test = "//test/perf/linux:getdents_benchmark", @@ -48,7 +48,7 @@ syscall_test( ) syscall_test( - size = "enormous", + size = "large", debug = False, tags = ["nogotsan"], test = "//test/perf/linux:gettid_benchmark", @@ -106,7 +106,7 @@ syscall_test( ) syscall_test( - size = "enormous", + size = "large", debug = False, test = "//test/perf/linux:signal_benchmark", ) @@ -124,9 +124,10 @@ syscall_test( ) syscall_test( - size = "enormous", + size = "large", add_overlay = True, debug = False, + tags = ["nogotsan"], test = "//test/perf/linux:unlink_benchmark", ) diff --git a/test/perf/linux/getpid_benchmark.cc b/test/perf/linux/getpid_benchmark.cc index db74cb264..047a034bd 100644 --- a/test/perf/linux/getpid_benchmark.cc +++ b/test/perf/linux/getpid_benchmark.cc @@ -31,6 +31,24 @@ void BM_Getpid(benchmark::State& state) { BENCHMARK(BM_Getpid); +#ifdef __x86_64__ + +#define SYSNO_STR1(x) #x +#define SYSNO_STR(x) SYSNO_STR1(x) + +// BM_GetpidOpt uses the most common pattern of calling system calls: +// mov $SYS_XXX, %eax; syscall.
+void BM_GetpidOpt(benchmark::State& state) { + for (auto s : state) { + __asm__("movl $" SYSNO_STR(SYS_getpid) ", %%eax\n" + "syscall\n" + : : : "rax", "rcx", "r11"); + } +} + +BENCHMARK(BM_GetpidOpt); +#endif // __x86_64__ + } // namespace } // namespace testing diff --git a/test/runtimes/defs.bzl b/test/runtimes/defs.bzl index 702522d86..2550b61a3 100644 --- a/test/runtimes/defs.bzl +++ b/test/runtimes/defs.bzl @@ -75,7 +75,6 @@ def runtime_test(name, **kwargs): "local", "manual", ], - size = "enormous", **kwargs ) diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index ef299799e..affcae8fd 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -244,6 +244,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:verity_ioctl_test", +) + +syscall_test( test = "//test/syscalls/linux:iptables_test", ) @@ -318,6 +322,10 @@ syscall_test( ) syscall_test( + test = "//test/syscalls/linux:verity_mount_test", +) + +syscall_test( size = "medium", test = "//test/syscalls/linux:mremap_test", ) @@ -772,8 +780,7 @@ syscall_test( ) syscall_test( - # NOTE(b/116636318): Large sendmsg may stall a long time. - size = "enormous", + flaky = 1, # NOTE(b/116636318): Large sendmsg may stall a long time. shard_count = more_shards, test = "//test/syscalls/linux:socket_unix_dgram_local_test", ) @@ -791,8 +798,7 @@ syscall_test( ) syscall_test( - # NOTE(b/116636318): Large sendmsg may stall a long time. - size = "enormous", + flaky = 1, # NOTE(b/116636318): Large sendmsg may stall a long time. shard_count = more_shards, test = "//test/syscalls/linux:socket_unix_seqpacket_local_test", ) @@ -995,3 +1001,7 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:processes_test", ) + +syscall_test( + test = "//test/syscalls/linux:cgroup_test", +) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index e565c6e77..bc2c7c0e3 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -154,7 +154,6 @@ cc_library( defines = select_system(), deps = default_net_util() + [ gtest, - "//net/util:ports", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1015,6 +1014,22 @@ cc_binary( ], ) +cc_binary( + name = "verity_ioctl_test", + testonly = 1, + srcs = ["verity_ioctl.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + gtest, + "//test/util:fs_util", + "//test/util:mount_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + cc_library( name = "iptables_types", testonly = 1, @@ -1305,6 +1320,20 @@ cc_binary( ) cc_binary( + name = "verity_mount_test", + testonly = 1, + srcs = ["verity_mount.cc"], + linkstatic = 1, + deps = [ + gtest, + "//test/util:capability_util", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + ], +) + +cc_binary( name = "mremap_test", testonly = 1, srcs = ["mremap.cc"], @@ -4206,3 +4235,24 @@ cc_binary( "//test/util:test_util", ], ) + +cc_binary( + name = "cgroup_test", + testonly = 1, + srcs = ["cgroup.cc"], + linkstatic = 1, + deps = [ + "//test/util:capability_util", + "//test/util:cgroup_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "@com_google_absl//absl/strings", + gtest, + "//test/util:posix_error", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) diff --git a/test/syscalls/linux/accept_bind.cc 
b/test/syscalls/linux/accept_bind.cc index f65a14fb8..119a1466b 100644 --- a/test/syscalls/linux/accept_bind.cc +++ b/test/syscalls/linux/accept_bind.cc @@ -67,6 +67,42 @@ TEST_P(AllSocketPairTest, ListenDecreaseBacklog) { SyscallSucceeds()); } +TEST_P(AllSocketPairTest, ListenBacklogSizes_NoRandomSave) { + DisableSave ds; + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + int type; + socklen_t typelen = sizeof(type); + EXPECT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_TYPE, &type, &typelen), + SyscallSucceeds()); + + std::array<int, 3> backlogs = {-1, 0, 1}; + for (auto& backlog : backlogs) { + ASSERT_THAT(listen(sockets->first_fd(), backlog), SyscallSucceeds()); + + int expected_accepts = backlog; + if (backlog < 0) { + expected_accepts = 1024; + } + for (int i = 0; i < expected_accepts; i++) { + SCOPED_TRACE(absl::StrCat("i=", i)); + // Connect to the listening socket. + const FileDescriptor client = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, type, 0)); + ASSERT_THAT(connect(client.get(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + const FileDescriptor accepted = ASSERT_NO_ERRNO_AND_VALUE( + Accept(sockets->first_fd(), nullptr, nullptr)); + } + } +} + TEST_P(AllSocketPairTest, ListenWithoutBind) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); ASSERT_THAT(listen(sockets->first_fd(), 0), SyscallFailsWithErrno(EINVAL)); diff --git a/test/syscalls/linux/cgroup.cc b/test/syscalls/linux/cgroup.cc new file mode 100644 index 000000000..a1006a978 --- /dev/null +++ b/test/syscalls/linux/cgroup.cc @@ -0,0 +1,421 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// All tests in this file rely on being able to mount and unmount cgroupfs, +// which isn't expected to work, or be safe, on a general Linux system.
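The tests in this file drive the sentry's cgroupfs with plain mount(2). A rough Go equivalent of what the C++ Mounter helper presumably does for a single controller; the target path is an assumption for illustration, and CAP_SYS_ADMIN is required:

package main

import (
	"log"

	"golang.org/x/sys/unix"
)

func main() {
	target := "/tmp/cgmem" // hypothetical mount point; any empty directory works
	if err := unix.Mkdir(target, 0o755); err != nil && err != unix.EEXIST {
		log.Fatal(err)
	}
	// The data string names the controller, mirroring MountCgroupfs("memory").
	if err := unix.Mount("none", target, "cgroup", 0, "memory"); err != nil {
		log.Fatalf("mount cgroupfs: %v", err)
	}
	defer unix.Unmount(target, 0)
	log.Printf("memory cgroupfs mounted at %s", target)
}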
+ +#include <sys/mount.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_split.h" +#include "test/util/capability_util.h" +#include "test/util/cgroup_util.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { +namespace { + +using ::testing::_; +using ::testing::Ge; +using ::testing::Gt; + +std::vector<std::string> known_controllers = {"cpu", "cpuset", "cpuacct", + "memory"}; + +bool CgroupsAvailable() { + return IsRunningOnGvisor() && !IsRunningWithVFS1() && + TEST_CHECK_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)); +} + +TEST(Cgroup, MountSucceeds) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + EXPECT_NO_ERRNO(c.ContainsCallingProcess()); +} + +TEST(Cgroup, SeparateMounts) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + + for (const auto& ctl : known_controllers) { + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs(ctl)); + EXPECT_NO_ERRNO(c.ContainsCallingProcess()); + } +} + +TEST(Cgroup, AllControllersImplicit) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + + absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + for (const auto& ctl : known_controllers) { + EXPECT_TRUE(cgroups_entries.contains(ctl)) + << absl::StreamFormat("ctl=%s", ctl); + } + EXPECT_EQ(cgroups_entries.size(), known_controllers.size()); +} + +TEST(Cgroup, AllControllersExplicit) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("all")); + + absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + for (const auto& ctl : known_controllers) { + EXPECT_TRUE(cgroups_entries.contains(ctl)) + << absl::StreamFormat("ctl=%s", ctl); + } + EXPECT_EQ(cgroups_entries.size(), known_controllers.size()); +} + +TEST(Cgroup, ProcsAndTasks) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + absl::flat_hash_set<pid_t> pids = ASSERT_NO_ERRNO_AND_VALUE(c.Procs()); + absl::flat_hash_set<pid_t> tids = ASSERT_NO_ERRNO_AND_VALUE(c.Tasks()); + + EXPECT_GE(tids.size(), pids.size()) << "Found more processes than threads"; + + // Pids should be a strict subset of tids. + for (auto it = pids.begin(); it != pids.end(); ++it) { + EXPECT_TRUE(tids.contains(*it)) + << absl::StreamFormat("Have pid %d, but no such tid", *it); + } +} + +TEST(Cgroup, ControllersMustBeInUniqueHierarchy) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + // Hierarchy #1: all controllers. + Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + // Hierarchy #2: memory. + // + // This should conflict since memory is already in hierarchy #1, and the two + // hierarchies have different sets of controllers, so this mount can't be a + // view into hierarchy #1. 
+ EXPECT_THAT(m.MountCgroupfs("memory"), PosixErrorIs(EBUSY, _)) + << "Memory controller mounted on two hierarchies"; + EXPECT_THAT(m.MountCgroupfs("cpu"), PosixErrorIs(EBUSY, _)) + << "CPU controller mounted on two hierarchies"; +} + +TEST(Cgroup, UnmountFreesControllers) { + SKIP_IF(!CgroupsAvailable()); + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + // All controllers are now attached to all's hierarchy. Attempting new mount + // with any individual controller should fail. + EXPECT_THAT(m.MountCgroupfs("memory"), PosixErrorIs(EBUSY, _)) + << "Memory controller mounted on two hierarchies"; + + // Unmount the "all" hierarchy. This should enable any controller to be + // mounted on a new hierarchy again. + ASSERT_NO_ERRNO(m.Unmount(all)); + EXPECT_NO_ERRNO(m.MountCgroupfs("memory")); + EXPECT_NO_ERRNO(m.MountCgroupfs("cpu")); +} + +TEST(Cgroup, OnlyContainsControllerSpecificFiles) { + SKIP_IF(!CgroupsAvailable()); + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup mem = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory")); + EXPECT_THAT(Exists(mem.Relpath("memory.usage_in_bytes")), + IsPosixErrorOkAndHolds(true)); + // CPU files shouldn't exist in memory cgroups. + EXPECT_THAT(Exists(mem.Relpath("cpu.cfs_period_us")), + IsPosixErrorOkAndHolds(false)); + EXPECT_THAT(Exists(mem.Relpath("cpu.cfs_quota_us")), + IsPosixErrorOkAndHolds(false)); + EXPECT_THAT(Exists(mem.Relpath("cpu.shares")), IsPosixErrorOkAndHolds(false)); + + Cgroup cpu = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu")); + EXPECT_THAT(Exists(cpu.Relpath("cpu.cfs_period_us")), + IsPosixErrorOkAndHolds(true)); + EXPECT_THAT(Exists(cpu.Relpath("cpu.cfs_quota_us")), + IsPosixErrorOkAndHolds(true)); + EXPECT_THAT(Exists(cpu.Relpath("cpu.shares")), IsPosixErrorOkAndHolds(true)); + // Memory files shouldn't exist in cpu cgroups. 
+ EXPECT_THAT(Exists(cpu.Relpath("memory.usage_in_bytes")), + IsPosixErrorOkAndHolds(false)); +} + +TEST(Cgroup, InvalidController) { + SKIP_IF(!CgroupsAvailable()); + + TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string mopts = "this-controller-is-invalid"; + EXPECT_THAT( + mount("none", mountpoint.path().c_str(), "cgroup", 0, mopts.c_str()), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(Cgroup, MoptAllMustBeExclusive) { + SKIP_IF(!CgroupsAvailable()); + + TempPath mountpoint = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + std::string mopts = "all,cpu"; + EXPECT_THAT( + mount("none", mountpoint.path().c_str(), "cgroup", 0, mopts.c_str()), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(MemoryCgroup, MemoryUsageInBytes) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory")); + EXPECT_THAT(c.ReadIntegerControlFile("memory.usage_in_bytes"), + IsPosixErrorOkAndHolds(Gt(0))); +} + +TEST(CPUCgroup, ControlFilesHaveDefaultValues) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu")); + EXPECT_THAT(c.ReadIntegerControlFile("cpu.cfs_quota_us"), + IsPosixErrorOkAndHolds(-1)); + EXPECT_THAT(c.ReadIntegerControlFile("cpu.cfs_period_us"), + IsPosixErrorOkAndHolds(100000)); + EXPECT_THAT(c.ReadIntegerControlFile("cpu.shares"), + IsPosixErrorOkAndHolds(1024)); +} + +TEST(CPUAcctCgroup, CPUAcctUsage) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuacct")); + + const int64_t usage = + ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage")); + const int64_t usage_user = + ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage_user")); + const int64_t usage_sys = + ASSERT_NO_ERRNO_AND_VALUE(c.ReadIntegerControlFile("cpuacct.usage_sys")); + + EXPECT_GE(usage, 0); + EXPECT_GE(usage_user, 0); + EXPECT_GE(usage_sys, 0); + + EXPECT_GE(usage_user + usage_sys, usage); +} + +TEST(CPUAcctCgroup, CPUAcctStat) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpuacct")); + + std::string stat = + ASSERT_NO_ERRNO_AND_VALUE(c.ReadControlFile("cpuacct.stat")); + + // We're expecting the contents of "cpuacct.stat" to look similar to this: + // + // user 377986 + // system 220662 + + std::vector<absl::string_view> lines = + absl::StrSplit(stat, '\n', absl::SkipEmpty()); + ASSERT_EQ(lines.size(), 2); + + std::vector<absl::string_view> user_tokens = + StrSplit(lines[0], absl::ByChar(' ')); + EXPECT_EQ(user_tokens[0], "user"); + EXPECT_THAT(Atoi<int64_t>(user_tokens[1]), IsPosixErrorOkAndHolds(Ge(0))); + + std::vector<absl::string_view> sys_tokens = + StrSplit(lines[1], absl::ByChar(' ')); + EXPECT_EQ(sys_tokens[0], "system"); + EXPECT_THAT(Atoi<int64_t>(sys_tokens[1]), IsPosixErrorOkAndHolds(Ge(0))); +} + +TEST(ProcCgroups, Empty) { + SKIP_IF(!CgroupsAvailable()); + + absl::flat_hash_map<std::string, CgroupsEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + // No cgroups mounted yet, we should have no entries. 
+ EXPECT_TRUE(entries.empty()); +} + +TEST(ProcCgroups, ProcCgroupsEntries) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + + Cgroup mem = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory")); + absl::flat_hash_map<std::string, CgroupsEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + EXPECT_EQ(entries.size(), 1); + ASSERT_TRUE(entries.contains("memory")); + CgroupsEntry mem_e = entries["memory"]; + EXPECT_EQ(mem_e.subsys_name, "memory"); + EXPECT_GE(mem_e.hierarchy, 1); + // Expect a single root cgroup. + EXPECT_EQ(mem_e.num_cgroups, 1); + // Cgroups are currently always enabled when mounted. + EXPECT_TRUE(mem_e.enabled); + + // Add a second cgroup, and check for new entry. + + Cgroup cpu = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu")); + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + EXPECT_EQ(entries.size(), 2); + EXPECT_TRUE(entries.contains("memory")); // Still have memory entry. + ASSERT_TRUE(entries.contains("cpu")); + CgroupsEntry cpu_e = entries["cpu"]; + EXPECT_EQ(cpu_e.subsys_name, "cpu"); + EXPECT_GE(cpu_e.hierarchy, 1); + EXPECT_EQ(cpu_e.num_cgroups, 1); + EXPECT_TRUE(cpu_e.enabled); + + // Separate hierarchies, since controllers were mounted separately. + EXPECT_NE(mem_e.hierarchy, cpu_e.hierarchy); +} + +TEST(ProcCgroups, UnmountRemovesEntries) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup cg = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu,memory")); + absl::flat_hash_map<std::string, CgroupsEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + EXPECT_EQ(entries.size(), 2); + + ASSERT_NO_ERRNO(m.Unmount(cg)); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + EXPECT_TRUE(entries.empty()); +} + +TEST(ProcPIDCgroup, Empty) { + SKIP_IF(!CgroupsAvailable()); + + absl::flat_hash_map<std::string, PIDCgroupEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + EXPECT_TRUE(entries.empty()); +} + +TEST(ProcPIDCgroup, Entries) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory")); + + absl::flat_hash_map<std::string, PIDCgroupEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + EXPECT_EQ(entries.size(), 1); + PIDCgroupEntry mem_e = entries["memory"]; + EXPECT_GE(mem_e.hierarchy, 1); + EXPECT_EQ(mem_e.controllers, "memory"); + EXPECT_EQ(mem_e.path, "/"); + + Cgroup c1 = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu")); + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + EXPECT_EQ(entries.size(), 2); + EXPECT_TRUE(entries.contains("memory")); // Still have memory entry. + PIDCgroupEntry cpu_e = entries["cpu"]; + EXPECT_GE(cpu_e.hierarchy, 1); + EXPECT_EQ(cpu_e.controllers, "cpu"); + EXPECT_EQ(cpu_e.path, "/"); + + // Separate hierarchies, since controllers were mounted separately. 
+ EXPECT_NE(mem_e.hierarchy, cpu_e.hierarchy); +} + +TEST(ProcPIDCgroup, UnmountRemovesEntries) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup all = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("")); + + absl::flat_hash_map<std::string, PIDCgroupEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + EXPECT_GT(entries.size(), 0); + + ASSERT_NO_ERRNO(m.Unmount(all)); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + EXPECT_TRUE(entries.empty()); +} + +TEST(ProcCgroup, PIDCgroupMatchesCgroups) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory")); + Cgroup c1 = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("cpu")); + + absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + absl::flat_hash_map<std::string, PIDCgroupEntry> pid_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + + CgroupsEntry cgroup_mem = cgroups_entries["memory"]; + PIDCgroupEntry pid_mem = pid_entries["memory"]; + + EXPECT_EQ(cgroup_mem.hierarchy, pid_mem.hierarchy); + + CgroupsEntry cgroup_cpu = cgroups_entries["cpu"]; + PIDCgroupEntry pid_cpu = pid_entries["cpu"]; + + EXPECT_EQ(cgroup_cpu.hierarchy, pid_cpu.hierarchy); + EXPECT_NE(cgroup_mem.hierarchy, cgroup_cpu.hierarchy); + EXPECT_NE(pid_mem.hierarchy, pid_cpu.hierarchy); +} + +TEST(ProcCgroup, MultiControllerHierarchy) { + SKIP_IF(!CgroupsAvailable()); + + Mounter m(ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir())); + Cgroup c = ASSERT_NO_ERRNO_AND_VALUE(m.MountCgroupfs("memory,cpu")); + + absl::flat_hash_map<std::string, CgroupsEntry> cgroups_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcCgroupsEntries()); + + CgroupsEntry mem_e = cgroups_entries["memory"]; + CgroupsEntry cpu_e = cgroups_entries["cpu"]; + + // Both controllers should have the same hierarchy ID. + EXPECT_EQ(mem_e.hierarchy, cpu_e.hierarchy); + + absl::flat_hash_map<std::string, PIDCgroupEntry> pid_entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcPIDCgroupEntries(getpid())); + + // Expecting an entry listing both controllers, that matches the previous + // hierarchy ID. Note that the controllers are listed in alphabetical order. 
+ PIDCgroupEntry pid_e = pid_entries["cpu,memory"]; + EXPECT_EQ(pid_e.hierarchy, mem_e.hierarchy); +} + +} // namespace +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/fpsig_fork.cc b/test/syscalls/linux/fpsig_fork.cc index c47567b4e..79b0596c4 100644 --- a/test/syscalls/linux/fpsig_fork.cc +++ b/test/syscalls/linux/fpsig_fork.cc @@ -44,6 +44,8 @@ namespace { #define SET_FP0(var) SET_FPREG(var, d0) #endif +#define DEFAULT_MXCSR 0x1f80 + int parent, child; void sigusr1(int s, siginfo_t* siginfo, void* _uc) { @@ -57,6 +59,12 @@ void sigusr1(int s, siginfo_t* siginfo, void* _uc) { uint64_t got; GET_FP0(got); TEST_CHECK_MSG(val == got, "Basic FP check failed in sigusr1()"); + +#ifdef __x86_64 + uint32_t mxcsr; + __asm__("STMXCSR %0" : "=m"(mxcsr)); + TEST_CHECK_MSG(mxcsr == DEFAULT_MXCSR, "Unexpected mxcsr"); +#endif } TEST(FPSigTest, Fork) { @@ -125,6 +133,55 @@ TEST(FPSigTest, Fork) { } } +#ifdef __x86_64__ +TEST(FPSigTest, ForkWithZeroMxcsr) { + parent = getpid(); + pid_t parent_tid = gettid(); + + struct sigaction sa = {}; + sigemptyset(&sa.sa_mask); + sa.sa_flags = SA_SIGINFO; + sa.sa_sigaction = sigusr1; + ASSERT_THAT(sigaction(SIGUSR1, &sa, nullptr), SyscallSucceeds()); + + // The control bits of the MXCSR register are callee-saved (preserved across + // calls), while the status bits are caller-saved (not preserved). + uint32_t expected = 0, origin; + __asm__("STMXCSR %0" : "=m"(origin)); + __asm__("LDMXCSR %0" : : "m"(expected)); + + asm volatile( + "movl %[killnr], %%eax;" + "movl %[parent], %%edi;" + "movl %[tid], %%esi;" + "movl %[sig], %%edx;" + "syscall;" + : + : [killnr] "i"(__NR_tgkill), [parent] "rm"(parent), + [tid] "rm"(parent_tid), [sig] "i"(SIGUSR1) + : "rax", "rdi", "rsi", "rdx", + // Clobbered by syscall. + "rcx", "r11"); + + uint32_t got; + __asm__("STMXCSR %0" : "=m"(got)); + __asm__("LDMXCSR %0" : : "m"(origin)); + + if (getpid() == parent) { // Parent. + int status; + ASSERT_THAT(waitpid(child, &status, 0), SyscallSucceedsWithValue(child)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + } + + // TEST_CHECK_MSG since this may run in the child. + TEST_CHECK_MSG(expected == got, "Bad mxcsr value"); + + if (getpid() != parent) { // Child. + _exit(0); + } +} +#endif + } // namespace } // namespace testing diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc index 28f51a3bf..8c5732147 100644 --- a/test/syscalls/linux/semaphore.cc +++ b/test/syscalls/linux/semaphore.cc @@ -234,14 +234,6 @@ TEST(SemaphoreTest, SemTimedOpBlock) { AutoSem sem(semget(IPC_PRIVATE, 1, 0600 | IPC_CREAT)); ASSERT_THAT(sem.get(), SyscallSucceeds()); - ScopedThread th([&sem] { - absl::SleepFor(absl::Milliseconds(100)); - - struct sembuf buf = {}; - buf.sem_op = 1; - ASSERT_THAT(RetryEINTR(semop)(sem.get(), &buf, 1), SyscallSucceeds()); - }); - struct sembuf buf = {}; buf.sem_op = -1; struct timespec timeout = {}; diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 597b5bcb1..d391363fb 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -489,13 +489,6 @@ void TestListenWhileConnect(const TestParam& param, TestAddress const& listener = param.listener; TestAddress const& connector = param.connector; - constexpr int kBacklog = 2; - // Linux completes one more connection than the listen backlog argument. 
- // To ensure that there is at least one client connection that stays in
- // connecting state, keep 2 more client connections than the listen backlog.
- // gVisor differs in this behavior though, gvisor.dev/issue/3153.
- constexpr int kClients = kBacklog + 2;
-
 // Create the listening socket.
 FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
 Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
@@ -503,6 +496,13 @@ void TestListenWhileConnect(const TestParam& param,
 ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
 listener.addr_len),
 SyscallSucceeds());
+ // This test is only interested in deterministically getting a socket in
+ // connecting state. For that, we use a listen backlog of zero which would
+ // mean there is exactly one connection that gets established and is enqueued
+ // to the accept queue. We poll on the listener to ensure that it is enqueued.
+ // After that, the subsequent client connect will stay in connecting state as
+ // the accept queue is full.
+ constexpr int kBacklog = 0;
 ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
 // Get the port bound by the listening socket.
@@ -515,42 +515,49 @@ void TestListenWhileConnect(const TestParam& param,
 sockaddr_storage conn_addr = connector.addr;
 ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
- std::vector<FileDescriptor> clients;
- for (int i = 0; i < kClients; i++) {
- FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE(
- Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
- int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
- connector.addr_len);
- if (ret != 0) {
- EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
- clients.push_back(std::move(client));
- }
+ FileDescriptor established_client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP));
+ ASSERT_THAT(
+ connect(established_client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceeds());
+
+ // Ensure that the accept queue has the completed connection.
+ constexpr int kTimeout = 10000;
+ pollfd pfd = {
+ .fd = listen_fd.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN);
+
+ FileDescriptor connecting_client = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP));
+ // Keep the last client in connecting state.
+ int ret =
+ connect(connecting_client.get(), reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len);
+ if (ret != 0) {
+ EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS));
 }
 stopListen(listen_fd);
- for (auto& client : clients) {
- constexpr int kTimeout = 10000;
+ std::array<std::pair<int, int>, 2> sockets = {
+ std::make_pair(established_client.get(), ECONNRESET),
+ std::make_pair(connecting_client.get(), ECONNREFUSED),
+ };
+ for (size_t i = 0; i < sockets.size(); i++) {
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ auto [fd, expected_errno] = sockets[i];
 pollfd pfd = {
- .fd = client.get(),
- .events = POLLIN,
+ .fd = fd,
 };
- // When the listening socket is closed, then we expect the remote to reset
- // the connection.
- ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
- ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR);
+ // When the listening socket is closed, the peer would reset the connection.
+ EXPECT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1)); + EXPECT_EQ(pfd.revents, POLLHUP | POLLERR); char c; - // Subsequent read can fail with: - // ECONNRESET: If the client connection was established and was reset by the - // remote. - // ECONNREFUSED: If the client connection failed to be established. - ASSERT_THAT(read(client.get(), &c, sizeof(c)), - AnyOf(SyscallFailsWithErrno(ECONNRESET), - SyscallFailsWithErrno(ECONNREFUSED))); - // The last client connection would be in connecting (SYN_SENT) state. - if (client.get() == clients[kClients - 1].get()) { - ASSERT_EQ(errno, ECONNREFUSED) << strerror(errno); - } + EXPECT_THAT(read(fd, &c, sizeof(c)), SyscallFailsWithErrno(expected_errno)); } } @@ -570,7 +577,59 @@ TEST_P(SocketInetLoopbackTest, TCPListenShutdownWhileConnect) { // random save as established connections which can't be delivered to the accept // queue because the queue is full are not correctly delivered after restore // causing the last accept to timeout on the restore. -TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { +TEST_P(SocketInetLoopbackTest, TCPAcceptBacklogSizes_NoRandomSave) { + auto const& param = GetParam(); + + TestAddress const& listener = param.listener; + TestAddress const& connector = param.connector; + + // Create the listening socket. + const FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage listen_addr = listener.addr; + ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), + listener.addr_len), + SyscallSucceeds()); + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + std::array<int, 3> backlogs = {-1, 0, 1}; + for (auto& backlog : backlogs) { + ASSERT_THAT(listen(listen_fd.get(), backlog), SyscallSucceeds()); + + int expected_accepts; + if (backlog < 0) { + expected_accepts = 1024; + } else { + expected_accepts = backlog + 1; + } + for (int i = 0; i < expected_accepts; i++) { + SCOPED_TRACE(absl::StrCat("i=", i)); + // Connect to the listening socket. + const FileDescriptor conn_fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + ASSERT_THAT( + RetryEINTR(connect)(conn_fd.get(), + reinterpret_cast<struct sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + const FileDescriptor accepted = + ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr)); + } + } +} + +// TODO(b/157236388): Remove _NoRandomSave once bug is fixed. Test fails w/ +// random save as established connections which can't be delivered to the accept +// queue because the queue is full are not correctly delivered after restore +// causing the last accept to timeout on the restore. +TEST_P(SocketInetLoopbackTest, TCPBacklog_NoRandomSave) { auto const& param = GetParam(); TestAddress const& listener = param.listener; @@ -595,6 +654,7 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) { ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); int i = 0; while (1) { + SCOPED_TRACE(absl::StrCat("i=", i)); int ret; // Connect to the listening socket. 
@@ -620,103 +680,133 @@ TEST_P(SocketInetLoopbackTest, TCPbacklog_NoRandomSave) {
 i++;
 }
+ int client_conns = i;
+ int accepted_conns = 0;
 for (; i != 0; i--) {
- // Accept the connection.
- //
- // We have to assign a name to the accepted socket, as unamed temporary
- // objects are destructed upon full evaluation of the expression it is in,
- // potentially causing the connecting socket to fail to shutdown properly.
- auto accepted =
- ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ SCOPED_TRACE(absl::StrCat("i=", i));
+ pollfd pfd = {
+ .fd = listen_fd.get(),
+ .events = POLLIN,
+ };
+ // Look for incoming connections to accept. The last connect request could
+ // be established from the client side, but the ACK of the handshake could
+ // be dropped by the listener if the accept queue was filled up by the
+ // previous connect.
+ int ret;
+ ASSERT_THAT(ret = poll(&pfd, 1, 3000), SyscallSucceeds());
+ if (ret == 0) break;
+ if (pfd.revents == POLLIN) {
+ // Accept the connection.
+ //
+ // We have to assign a name to the accepted socket, as unnamed temporary
+ // objects are destructed upon full evaluation of the expression it is in,
+ // potentially causing the connecting socket to fail to shutdown properly.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ accepted_conns++;
+ }
 }
+ // We should accept at least listen backlog + 1 connections. As the stack is
+ // enqueuing established connections to the accept queue, newer SYNs could
+ // still be replied to, causing those client connections to be accepted as
+ // we start dequeuing the queue.
+ ASSERT_GE(accepted_conns, kBacklogSize + 1);
+ ASSERT_GE(client_conns, accepted_conns);
 }
-// Test if the stack completes atmost listen backlog number of client
-// connections. It exercises the path of the stack that enqueues completed
-// connections to accept queue vs new incoming SYNs.
-TEST_P(SocketInetLoopbackTest, TCPConnectBacklog_NoRandomSave) {
- const auto& param = GetParam();
- const TestAddress& listener = param.listener;
- const TestAddress& connector = param.connector;
+// TODO(b/157236388): Remove _NoRandomSave once bug is fixed. Test fails w/
+// random save as established connections which can't be delivered to the accept
+// queue because the queue is full are not correctly delivered after restore
+// causing the last accept to timeout on the restore.
+TEST_P(SocketInetLoopbackTest, TCPBacklogAcceptAll_NoRandomSave) {
+ auto const& param = GetParam();
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ // Create the listening socket.
+ FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP));
+ sockaddr_storage listen_addr = listener.addr;
+ ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr),
+ listener.addr_len),
+ SyscallSucceeds());
 constexpr int kBacklog = 1;
- // Keep the number of client connections more than the listen backlog.
- // Linux completes one more connection than the listen backlog argument.
- // gVisor differs in this behavior though, gvisor.dev/issue/3153.
- int kClients = kBacklog + 2;
- if (IsRunningOnGvisor()) {
- kClients--;
- }
+ ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds());
- // Run the following test for few iterations to test race between accept queue
- // getting filled with incoming SYNs.
- for (int num = 0; num < 10; num++) { - FileDescriptor listen_fd = ASSERT_NO_ERRNO_AND_VALUE( - Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP)); - sockaddr_storage listen_addr = listener.addr; - ASSERT_THAT(bind(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - listener.addr_len), - SyscallSucceeds()); - ASSERT_THAT(listen(listen_fd.get(), kBacklog), SyscallSucceeds()); + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT(getsockname(listen_fd.get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - socklen_t addrlen = listener.addr_len; - ASSERT_THAT( - getsockname(listen_fd.get(), reinterpret_cast<sockaddr*>(&listen_addr), - &addrlen), - SyscallSucceeds()); - uint16_t const port = - ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); - sockaddr_storage conn_addr = connector.addr; - ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + sockaddr_storage conn_addr = connector.addr; + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); - std::vector<FileDescriptor> clients; - // Issue multiple non-blocking client connects. - for (int i = 0; i < kClients; i++) { - FileDescriptor client = ASSERT_NO_ERRNO_AND_VALUE( - Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); - int ret = connect(client.get(), reinterpret_cast<sockaddr*>(&conn_addr), - connector.addr_len); - if (ret != 0) { - EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); - } - clients.push_back(std::move(client)); + // Fill up the accept queue and trigger more client connections which would be + // waiting to be accepted. + std::array<FileDescriptor, kBacklog + 1> established_clients; + for (auto& fd : established_clients) { + fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT(connect(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + } + std::array<FileDescriptor, kBacklog> waiting_clients; + for (auto& fd : waiting_clients) { + fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP)); + int ret = connect(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len); + if (ret != 0) { + EXPECT_THAT(ret, SyscallFailsWithErrno(EINPROGRESS)); } + } - // Now that client connects are issued, wait for the accept queue to get - // filled and ensure no new client connection is completed. - for (int i = 0; i < kClients; i++) { - pollfd pfd = { - .fd = clients[i].get(), - .events = POLLOUT, - }; - if (i < kClients - 1) { - // Poll for client side connection completions with a large timeout. - // We cannot poll on the listener side without calling accept as poll - // stays level triggered with non-zero accept queue length. - // - // Client side poll would not guarantee that the completed connection - // has been enqueued in to the acccept queue, but the fact that the - // listener ACKd the SYN, means that it cannot complete any new incoming - // SYNs when it has already ACKd for > backlog number of SYNs. 
- ASSERT_THAT(poll(&pfd, 1, 10000), SyscallSucceedsWithValue(1))
- << "num=" << num << " i=" << i << " kClients=" << kClients;
- ASSERT_EQ(pfd.revents, POLLOUT) << "num=" << num << " i=" << i;
- } else {
- // Now that we expect accept queue filled up, ensure that the last
- // client connection never completes with a smaller poll timeout.
- ASSERT_THAT(poll(&pfd, 1, 1000), SyscallSucceedsWithValue(0))
- << "num=" << num << " i=" << i;
- }
+ auto accept_connection = [&]() {
+ constexpr int kTimeout = 10000;
+ pollfd pfd = {
+ .fd = listen_fd.get(),
+ .events = POLLIN,
+ };
+ ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ ASSERT_EQ(pfd.revents, POLLIN);
+ // Accept the connection.
+ //
+ // We have to assign a name to the accepted socket, as unnamed temporary
+ // objects are destructed upon full evaluation of the expression it is in,
+ // potentially causing the connecting socket to fail to shutdown properly.
+ auto accepted =
+ ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_fd.get(), nullptr, nullptr));
+ };
- ASSERT_THAT(close(clients[i].release()), SyscallSucceedsWithValue(0))
- << "num=" << num << " i=" << i;
- }
- clients.clear();
- // We close the listening side and open a new listener. We could instead
- // drain the accept queue by calling accept() and reuse the listener, but
- // that is racy as the retransmitted SYNs could get ACKd as we make room in
- // the accept queue.
- ASSERT_THAT(close(listen_fd.release()), SyscallSucceedsWithValue(0));
+ // Ensure that we accept all client connections. The waiting connections would
+ // get enqueued as we drain the accept queue.
+ for (int i = 0; i < std::size(established_clients); i++) {
+ SCOPED_TRACE(absl::StrCat("established clients i=", i));
+ accept_connection();
+ }
+
+ // The waiting client connections could be in one of these 2 states:
+ // (1) SYN_SENT: if the SYN was dropped because the accept queue was full
+ // (2) ESTABLISHED: if the listener sent back a SYNACK, but may have dropped
+ // the ACK from the client if the accept queue was full (send out data to
+ // re-send that ACK, to address that case).
+ for (int i = 0; i < std::size(waiting_clients); i++) {
+ SCOPED_TRACE(absl::StrCat("waiting clients i=", i));
+ constexpr int kTimeout = 10000;
+ pollfd pfd = {
+ .fd = waiting_clients[i].get(),
+ .events = POLLOUT,
+ };
+ EXPECT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
+ EXPECT_EQ(pfd.revents, POLLOUT);
+ char c;
+ EXPECT_THAT(RetryEINTR(send)(waiting_clients[i].get(), &c, sizeof(c), 0),
+ SyscallSucceedsWithValue(sizeof(c)));
+ accept_connection();
 }
 }
diff --git a/test/syscalls/linux/verity_ioctl.cc b/test/syscalls/linux/verity_ioctl.cc
new file mode 100644
index 000000000..dcd28f2c3
--- /dev/null
+++ b/test/syscalls/linux/verity_ioctl.cc
@@ -0,0 +1,133 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
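[Editor's note -- illustrative sketch, not part of this change] The tests below enable fs-verity on a file and then ask for its measurement. A condensed version of that ioctl flow is sketched here, mirroring the tests' usage: FS_IOC_ENABLE_VERITY is issued with no argument, which is what gVisor's verity fs accepts (upstream Linux instead takes a struct fsverity_enable_arg*), and FS_IOC_MEASURE_VERITY treats digest_size as in/out. The file path is hypothetical, and <linux/fsverity.h> is assumed to provide the ioctl numbers and struct fsverity_digest (the test itself #defines fallbacks below):

    #include <fcntl.h>
    #include <linux/fsverity.h>  // FS_IOC_*_VERITY, struct fsverity_digest.
    #include <stdio.h>
    #include <stdlib.h>
    #include <sys/ioctl.h>

    int main(void) {
      int fd = open("/mnt/verity/file", O_RDONLY);  // Hypothetical path.
      if (fd < 0) { perror("open"); return 1; }
      // Enable verity on the file. gVisor's verity fs takes no argument here;
      // upstream Linux would require a struct fsverity_enable_arg*.
      if (ioctl(fd, FS_IOC_ENABLE_VERITY) != 0) { perror("enable"); return 1; }
      // Measure: digest_size is in/out. Set it to the buffer capacity before
      // the call; the filesystem writes back the actual size and digest bytes.
      struct fsverity_digest* d = static_cast<struct fsverity_digest*>(
          calloc(1, sizeof(*d) + 64));
      d->digest_size = 64;
      if (ioctl(fd, FS_IOC_MEASURE_VERITY, d) != 0) { perror("measure"); return 1; }
      printf("algorithm %u, digest of %u bytes\n", d->digest_algorithm,
             d->digest_size);
      free(d);
      return 0;
    }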
+
+#include <sys/mount.h>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/fs_util.h"
+#include "test/util/mount_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+#ifndef FS_IOC_ENABLE_VERITY
+#define FS_IOC_ENABLE_VERITY 1082156677
+#endif
+
+#ifndef FS_IOC_MEASURE_VERITY
+#define FS_IOC_MEASURE_VERITY 3221513862
+#endif
+
+#ifndef FS_VERITY_FL
+#define FS_VERITY_FL 1048576
+#endif
+
+#ifndef FS_IOC_GETFLAGS
+#define FS_IOC_GETFLAGS 2148034049
+#endif
+
+struct fsverity_digest {
+ __u16 digest_algorithm;
+ __u16 digest_size; /* input/output */
+ __u8 digest[];
+};
+
+const int fsverity_max_digest_size = 64;
+const int fsverity_default_digest_size = 32;
+
+class IoctlTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ // Verity is implemented in VFS2.
+ SKIP_IF(IsRunningWithVFS1());
+
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+ // Mount a tmpfs file system, to be wrapped by a verity fs.
+ tmpfs_dir_ = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(mount("", tmpfs_dir_.path().c_str(), "tmpfs", 0, ""),
+ SyscallSucceeds());
+
+ // Create a new file in the tmpfs mount.
+ constexpr char kContents[] = "foobarbaz";
+ file_ = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(tmpfs_dir_.path(), kContents, 0777));
+ filename_ = Basename(file_.path());
+ }
+
+ TempPath tmpfs_dir_;
+ TempPath file_;
+ std::string filename_;
+};
+
+TEST_F(IoctlTest, Enable) {
+ // Mount a verity fs on the existing tmpfs mount.
+ std::string mount_opts = "lower_path=" + tmpfs_dir_.path();
+ auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(
+ mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str()),
+ SyscallSucceeds());
+
+ // Confirm that the verity flag is absent.
+ int flag = 0;
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(JoinPath(verity_dir.path(), filename_), O_RDONLY, 0777));
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_GETFLAGS, &flag), SyscallSucceeds());
+ EXPECT_EQ(flag & FS_VERITY_FL, 0);
+
+ // Enable the file and confirm that the verity flag is present.
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_ENABLE_VERITY), SyscallSucceeds());
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_GETFLAGS, &flag), SyscallSucceeds());
+ EXPECT_EQ(flag & FS_VERITY_FL, FS_VERITY_FL);
+}
+
+TEST_F(IoctlTest, Measure) {
+ // Mount a verity fs on the existing tmpfs mount.
+ std::string mount_opts = "lower_path=" + tmpfs_dir_.path();
+ auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(
+ mount("", verity_dir.path().c_str(), "verity", 0, mount_opts.c_str()),
+ SyscallSucceeds());
+
+ // Confirm that the file cannot be measured.
+ auto const fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(JoinPath(verity_dir.path(), filename_), O_RDONLY, 0777));
+ int digest_size = sizeof(struct fsverity_digest) + fsverity_max_digest_size;
+ struct fsverity_digest *digest =
+ reinterpret_cast<struct fsverity_digest *>(malloc(digest_size));
+ memset(digest, 0, digest_size);
+ digest->digest_size = fsverity_max_digest_size;
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_MEASURE_VERITY, digest),
+ SyscallFailsWithErrno(ENODATA));
+
+ // Enable the file and confirm that the file can be measured.
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_ENABLE_VERITY), SyscallSucceeds());
+ ASSERT_THAT(ioctl(fd.get(), FS_IOC_MEASURE_VERITY, digest),
+ SyscallSucceeds());
+ EXPECT_EQ(digest->digest_size, fsverity_default_digest_size);
+ free(digest);
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/syscalls/linux/verity_mount.cc b/test/syscalls/linux/verity_mount.cc
new file mode 100644
index 000000000..e73dd5599
--- /dev/null
+++ b/test/syscalls/linux/verity_mount.cc
@@ -0,0 +1,53 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <sys/mount.h>
+
+#include <iomanip>
+#include <sstream>
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include "test/util/capability_util.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Mount a verity file system on an existing tmpfs mount.
+TEST(MountTest, MountExisting) {
+ // Verity is implemented in VFS2.
+ SKIP_IF(IsRunningWithVFS1());
+
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN)));
+
+ // Mount a new tmpfs file system.
+ auto const tmpfs_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ ASSERT_THAT(mount("", tmpfs_dir.path().c_str(), "tmpfs", 0, ""),
+ SyscallSucceeds());
+
+ // Mount a verity file system on the existing tmpfs mount.
+ auto const verity_dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ std::string opts = "lower_path=" + tmpfs_dir.path();
+ EXPECT_THAT(mount("", verity_dir.path().c_str(), "verity", 0, opts.c_str()),
+ SyscallSucceeds());
+}
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/util/BUILD b/test/util/BUILD
index e561f3daa..383de00ed 100644
--- a/test/util/BUILD
+++ b/test/util/BUILD
@@ -94,6 +94,7 @@ cc_library(
 ":file_descriptor",
 ":posix_error",
 "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
 gtest,
 ],
)
@@ -368,3 +369,20 @@ cc_library(
 testonly = 1,
 hdrs = ["temp_umask.h"],
)
+
+cc_library(
+ name = "cgroup_util",
+ testonly = 1,
+ srcs = ["cgroup_util.cc"],
+ hdrs = ["cgroup_util.h"],
+ deps = [
+ ":cleanup",
+ ":fs_util",
+ ":mount_util",
+ ":posix_error",
+ ":temp_path",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
+ ],
+)
diff --git a/test/util/cgroup_util.cc b/test/util/cgroup_util.cc
new file mode 100644
index 000000000..65d9c4986
--- /dev/null
+++ b/test/util/cgroup_util.cc
@@ -0,0 +1,223 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/cgroup_util.h" + +#include <sys/syscall.h> +#include <unistd.h> + +#include "absl/strings/str_split.h" +#include "test/util/fs_util.h" +#include "test/util/mount_util.h" + +namespace gvisor { +namespace testing { + +Cgroup::Cgroup(std::string path) : cgroup_path_(path) { + id_ = ++Cgroup::next_id_; + std::cerr << absl::StreamFormat("[cg#%d] <= %s", id_, cgroup_path_) + << std::endl; +} + +PosixErrorOr<std::string> Cgroup::ReadControlFile( + absl::string_view name) const { + std::string buf; + RETURN_IF_ERRNO(GetContents(Relpath(name), &buf)); + + const std::string alias_path = absl::StrFormat("[cg#%d]/%s", id_, name); + std::cerr << absl::StreamFormat("<contents of %s>", alias_path) << std::endl; + std::cerr << buf; + std::cerr << absl::StreamFormat("<end of %s>", alias_path) << std::endl; + + return buf; +} + +PosixErrorOr<int64_t> Cgroup::ReadIntegerControlFile( + absl::string_view name) const { + ASSIGN_OR_RETURN_ERRNO(const std::string buf, ReadControlFile(name)); + ASSIGN_OR_RETURN_ERRNO(const int64_t val, Atoi<int64_t>(buf)); + return val; +} + +PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::Procs() const { + ASSIGN_OR_RETURN_ERRNO(std::string buf, ReadControlFile("cgroup.procs")); + return ParsePIDList(buf); +} + +PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::Tasks() const { + ASSIGN_OR_RETURN_ERRNO(std::string buf, ReadControlFile("tasks")); + return ParsePIDList(buf); +} + +PosixError Cgroup::ContainsCallingProcess() const { + ASSIGN_OR_RETURN_ERRNO(const absl::flat_hash_set<pid_t> procs, Procs()); + ASSIGN_OR_RETURN_ERRNO(const absl::flat_hash_set<pid_t> tasks, Tasks()); + const pid_t pid = getpid(); + const pid_t tid = syscall(SYS_gettid); + if (!procs.contains(pid)) { + return PosixError( + ENOENT, absl::StrFormat("Cgroup doesn't contain process %d", pid)); + } + if (!tasks.contains(tid)) { + return PosixError(ENOENT, + absl::StrFormat("Cgroup doesn't contain task %d", tid)); + } + return NoError(); +} + +PosixErrorOr<absl::flat_hash_set<pid_t>> Cgroup::ParsePIDList( + absl::string_view data) const { + absl::flat_hash_set<pid_t> res; + std::vector<absl::string_view> lines = absl::StrSplit(data, '\n'); + for (const std::string_view& line : lines) { + if (line.empty()) { + continue; + } + ASSIGN_OR_RETURN_ERRNO(const int32_t pid, Atoi<int32_t>(line)); + res.insert(static_cast<pid_t>(pid)); + } + return res; +} + +int64_t Cgroup::next_id_ = 0; + +PosixErrorOr<Cgroup> Mounter::MountCgroupfs(std::string mopts) { + ASSIGN_OR_RETURN_ERRNO(TempPath mountpoint, + TempPath::CreateDirIn(root_.path())); + ASSIGN_OR_RETURN_ERRNO( + Cleanup mount, Mount("none", mountpoint.path(), "cgroup", 0, mopts, 0)); + const std::string mountpath = mountpoint.path(); + std::cerr << absl::StreamFormat( + "Mount(\"none\", \"%s\", \"cgroup\", 0, \"%s\", 0) => OK", + mountpath, mopts) + << std::endl; + Cgroup cg = Cgroup(mountpath); + mountpoints_[cg.id()] = std::move(mountpoint); + mounts_[cg.id()] = std::move(mount); + return cg; +} + +PosixError Mounter::Unmount(const Cgroup& c) { + auto mount = mounts_.find(c.id()); + auto mountpoint = 
mountpoints_.find(c.id()); + + if (mount == mounts_.end() || mountpoint == mountpoints_.end()) { + return PosixError( + ESRCH, absl::StrFormat("No mount found for cgroupfs containing cg#%d", + c.id())); + } + + std::cerr << absl::StreamFormat("Unmount([cg#%d])", c.id()) << std::endl; + + // Simply delete the entries, their destructors will unmount and delete the + // mountpoint. Note the order is important to avoid errors: mount then + // mountpoint. + mounts_.erase(mount); + mountpoints_.erase(mountpoint); + + return NoError(); +} + +constexpr char kProcCgroupsHeader[] = + "#subsys_name\thierarchy\tnum_cgroups\tenabled"; + +PosixErrorOr<absl::flat_hash_map<std::string, CgroupsEntry>> +ProcCgroupsEntries() { + std::string content; + RETURN_IF_ERRNO(GetContents("/proc/cgroups", &content)); + + bool found_header = false; + absl::flat_hash_map<std::string, CgroupsEntry> entries; + std::vector<std::string> lines = absl::StrSplit(content, '\n'); + std::cerr << "<contents of /proc/cgroups>" << std::endl; + for (const std::string& line : lines) { + std::cerr << line << std::endl; + + if (!found_header) { + EXPECT_EQ(line, kProcCgroupsHeader); + found_header = true; + continue; + } + if (line.empty()) { + continue; + } + + // Parse a single entry from /proc/cgroups. + // + // Example entries, fields are tab separated in the real file: + // + // #subsys_name hierarchy num_cgroups enabled + // cpuset 12 35 1 + // cpu 3 222 1 + // ^ ^ ^ ^ + // 0 1 2 3 + + CgroupsEntry entry; + std::vector<std::string> fields = + StrSplit(line, absl::ByAnyChar(": \t"), absl::SkipEmpty()); + + entry.subsys_name = fields[0]; + ASSIGN_OR_RETURN_ERRNO(entry.hierarchy, Atoi<uint32_t>(fields[1])); + ASSIGN_OR_RETURN_ERRNO(entry.num_cgroups, Atoi<uint64_t>(fields[2])); + ASSIGN_OR_RETURN_ERRNO(const int enabled, Atoi<int>(fields[3])); + entry.enabled = enabled != 0; + + entries[entry.subsys_name] = entry; + } + std::cerr << "<end of /proc/cgroups>" << std::endl; + + return entries; +} + +PosixErrorOr<absl::flat_hash_map<std::string, PIDCgroupEntry>> +ProcPIDCgroupEntries(pid_t pid) { + const std::string path = absl::StrFormat("/proc/%d/cgroup", pid); + std::string content; + RETURN_IF_ERRNO(GetContents(path, &content)); + + absl::flat_hash_map<std::string, PIDCgroupEntry> entries; + std::vector<std::string> lines = absl::StrSplit(content, '\n'); + + std::cerr << absl::StreamFormat("<contents of %s>", path) << std::endl; + for (const std::string& line : lines) { + std::cerr << line << std::endl; + + if (line.empty()) { + continue; + } + + // Parse a single entry from /proc/<pid>/cgroup. + // + // Example entries: + // + // 2:cpu:/path/to/cgroup + // 1:memory:/ + + PIDCgroupEntry entry; + std::vector<std::string> fields = + absl::StrSplit(line, absl::ByChar(':'), absl::SkipEmpty()); + + ASSIGN_OR_RETURN_ERRNO(entry.hierarchy, Atoi<uint32_t>(fields[0])); + entry.controllers = fields[1]; + entry.path = fields[2]; + + entries[entry.controllers] = entry; + } + std::cerr << absl::StreamFormat("<end of %s>", path) << std::endl; + + return entries; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/cgroup_util.h b/test/util/cgroup_util.h new file mode 100644 index 000000000..b049559df --- /dev/null +++ b/test/util/cgroup_util.h @@ -0,0 +1,111 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GVISOR_TEST_UTIL_CGROUP_UTIL_H_
+#define GVISOR_TEST_UTIL_CGROUP_UTIL_H_
+
+#include <unistd.h>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/strings/string_view.h"
+#include "test/util/cleanup.h"
+#include "test/util/fs_util.h"
+#include "test/util/temp_path.h"
+
+namespace gvisor {
+namespace testing {
+
+// Cgroup represents a cgroup directory on a mounted cgroupfs.
+class Cgroup {
+ public:
+ Cgroup(std::string path);
+
+ uint64_t id() const { return id_; }
+
+ std::string Relpath(absl::string_view leaf) const {
+ return JoinPath(cgroup_path_, leaf);
+ }
+
+ // Returns the contents of a cgroup control file with the given name.
+ PosixErrorOr<std::string> ReadControlFile(absl::string_view name) const;
+
+ // Reads the contents of a cgroup control file with the given name, and
+ // attempts to parse it as an integer.
+ PosixErrorOr<int64_t> ReadIntegerControlFile(absl::string_view name) const;
+
+ // Returns the thread ids of the leaders of thread groups managed by this
+ // cgroup.
+ PosixErrorOr<absl::flat_hash_set<pid_t>> Procs() const;
+
+ PosixErrorOr<absl::flat_hash_set<pid_t>> Tasks() const;
+
+ // ContainsCallingProcess checks whether the calling process is part of the
+ // cgroup.
+ PosixError ContainsCallingProcess() const;
+
+ private:
+ PosixErrorOr<absl::flat_hash_set<pid_t>> ParsePIDList(
+ absl::string_view data) const;
+
+ static int64_t next_id_;
+ int64_t id_;
+ const std::string cgroup_path_;
+};
+
+// Mounter is a utility for creating cgroupfs mounts. It automatically manages
+// the lifetime of created mounts.
+class Mounter {
+ public:
+ Mounter(TempPath root) : root_(std::move(root)) {}
+
+ PosixErrorOr<Cgroup> MountCgroupfs(std::string mopts);
+
+ PosixError Unmount(const Cgroup& c);
+
+ private:
+ // The destruction order of these members avoids errors during cleanup. We
+ // first unmount (by executing the mounts_ cleanups), then delete the
+ // mountpoint subdirs, then delete the root.
+ TempPath root_;
+ absl::flat_hash_map<int64_t, TempPath> mountpoints_;
+ absl::flat_hash_map<int64_t, Cleanup> mounts_;
+};
+
+// Represents a line from /proc/cgroups.
+struct CgroupsEntry {
+ std::string subsys_name;
+ uint32_t hierarchy;
+ uint64_t num_cgroups;
+ bool enabled;
+};
+
+// Returns a parsed representation of /proc/cgroups.
+PosixErrorOr<absl::flat_hash_map<std::string, CgroupsEntry>>
+ProcCgroupsEntries();
+
+// Represents a line from /proc/<pid>/cgroup.
+struct PIDCgroupEntry {
+ uint32_t hierarchy;
+ std::string controllers;
+ std::string path;
+};
+
+// Returns a parsed representation of /proc/<pid>/cgroup.
+PosixErrorOr<absl::flat_hash_map<std::string, PIDCgroupEntry>> +ProcPIDCgroupEntries(pid_t pid); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_UTIL_CGROUP_UTIL_H_ diff --git a/test/util/fs_util.cc b/test/util/fs_util.cc index 5f1ce0d8a..483ae848d 100644 --- a/test/util/fs_util.cc +++ b/test/util/fs_util.cc @@ -28,6 +28,8 @@ #include "absl/strings/str_cat.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/util/cleanup.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -366,6 +368,48 @@ PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath, return files; } +PosixError DirContains(absl::string_view path, + const std::vector<std::string>& expect, + const std::vector<std::string>& exclude) { + ASSIGN_OR_RETURN_ERRNO(auto listing, ListDir(path, false)); + + for (auto& expected_entry : expect) { + auto cursor = std::find(listing.begin(), listing.end(), expected_entry); + if (cursor == listing.end()) { + return PosixError(ENOENT, absl::StrFormat("Failed to find '%s' in '%s'", + expected_entry, path)); + } + } + for (auto& excluded_entry : exclude) { + auto cursor = std::find(listing.begin(), listing.end(), excluded_entry); + if (cursor != listing.end()) { + return PosixError(ENOENT, absl::StrCat("File '", excluded_entry, + "' found in path '", path, "'")); + } + } + return NoError(); +} + +PosixError EventuallyDirContains(absl::string_view path, + const std::vector<std::string>& expect, + const std::vector<std::string>& exclude) { + constexpr int kRetryCount = 100; + const absl::Duration kRetryDelay = absl::Milliseconds(100); + + for (int i = 0; i < kRetryCount; ++i) { + auto res = DirContains(path, expect, exclude); + if (res.ok()) { + return res; + } + if (i < kRetryCount - 1) { + // Sleep if this isn't the final iteration. + absl::SleepFor(kRetryDelay); + } + } + return PosixError(ETIMEDOUT, + "Timed out while waiting for directory to contain files "); +} + PosixError RecursivelyDelete(absl::string_view path, int* undeleted_dirs, int* undeleted_files) { ASSIGN_OR_RETURN_ERRNO(bool exists, Exists(path)); diff --git a/test/util/fs_util.h b/test/util/fs_util.h index 2190c3bca..bb2d1d3c8 100644 --- a/test/util/fs_util.h +++ b/test/util/fs_util.h @@ -129,6 +129,18 @@ PosixError WalkTree( PosixErrorOr<std::vector<std::string>> ListDir(absl::string_view abspath, bool skipdots); +// Check that a directory contains children nodes named in expect, and does not +// contain any children nodes named in exclude. +PosixError DirContains(absl::string_view path, + const std::vector<std::string>& expect, + const std::vector<std::string>& exclude); + +// Same as DirContains, but adds a retry. Suitable for checking a directory +// being modified asynchronously. +PosixError EventuallyDirContains(absl::string_view path, + const std::vector<std::string>& expect, + const std::vector<std::string>& exclude); + // Attempt to recursively delete a directory or file. Returns an error and // the number of undeleted directories and files. If either // undeleted_dirs or undeleted_files is nullptr then it will not be used. diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go index 8b4bff3b6..2b3c03fec 100644 --- a/tools/nogo/analyzers.go +++ b/tools/nogo/analyzers.go @@ -83,11 +83,6 @@ var AllAnalyzers = []*analysis.Analyzer{ checklocks.Analyzer, } -// EscapeAnalyzers is a list of escape-related analyzers. 
-var EscapeAnalyzers = []*analysis.Analyzer{ - checkescape.EscapeAnalyzer, -} - func register(all []*analysis.Analyzer) { // Register all fact types. // @@ -129,5 +124,4 @@ func init() { // Register lists. register(AllAnalyzers) - register(EscapeAnalyzers) } diff --git a/tools/nogo/check/main.go b/tools/nogo/check/main.go index 69bdfe502..4194770be 100644 --- a/tools/nogo/check/main.go +++ b/tools/nogo/check/main.go @@ -31,7 +31,6 @@ var ( stdlibFile = flag.String("stdlib", "", "stdlib configuration file (in JSON format)") findingsOutput = flag.String("findings", "", "output file (or stdout, if not specified)") factsOutput = flag.String("facts", "", "output file for facts (optional)") - escapesOutput = flag.String("escapes", "", "output file for escapes (optional)") ) func loadConfig(file string, config interface{}) interface{} { @@ -66,25 +65,13 @@ func main() { // Run the configuration. if *stdlibFile != "" { - // Perform basic analysis. + // Perform stdlib analysis. c := loadConfig(*stdlibFile, new(nogo.StdlibConfig)).(*nogo.StdlibConfig) findings, factData, err = nogo.CheckStdlib(c, nogo.AllAnalyzers) - } else if *packageFile != "" { - // Perform basic analysis. + // Perform standard analysis. c := loadConfig(*packageFile, new(nogo.PackageConfig)).(*nogo.PackageConfig) findings, factData, err = nogo.CheckPackage(c, nogo.AllAnalyzers, nil) - - // Do we need to do escape analysis? - if *escapesOutput != "" { - escapes, _, err := nogo.CheckPackage(c, nogo.EscapeAnalyzers, nil) - if err != nil { - log.Fatalf("error performing escape analysis: %v", err) - } - if err := nogo.WriteFindingsToFile(escapes, *escapesOutput); err != nil { - log.Fatalf("error writing escapes to %q: %v", *escapesOutput, err) - } - } } else { log.Fatalf("please provide at least one of package or stdlib!") } diff --git a/tools/nogo/defs.bzl b/tools/nogo/defs.bzl index 0c48a7a5a..cb407a736 100644 --- a/tools/nogo/defs.bzl +++ b/tools/nogo/defs.bzl @@ -174,7 +174,6 @@ NogoInfo = provider( fields = { "facts": "serialized package facts", "raw_findings": "raw package findings (if relevant)", - "escapes": "escape-only findings (if relevant)", "importpath": "package import path", "binaries": "package binary files", "srcs": "srcs (for go_test support)", @@ -281,7 +280,6 @@ def _nogo_aspect_impl(target, ctx): go_ctx = go_context(ctx, goos = nogo_target_info.goos, goarch = nogo_target_info.goarch) facts = ctx.actions.declare_file(target.label.name + ".facts") raw_findings = ctx.actions.declare_file(target.label.name + ".raw_findings") - escapes = ctx.actions.declare_file(target.label.name + ".escapes") config = struct( ImportPath = importpath, GoFiles = [src.path for src in srcs if src.path.endswith(".go")], @@ -298,7 +296,7 @@ def _nogo_aspect_impl(target, ctx): inputs.append(config_file) ctx.actions.run( inputs = inputs, - outputs = [facts, raw_findings, escapes], + outputs = [facts, raw_findings], tools = depset(go_ctx.runfiles.to_list() + ctx.files._nogo_objdump_tool), executable = ctx.files._nogo_check[0], mnemonic = "NogoAnalysis", @@ -309,7 +307,6 @@ def _nogo_aspect_impl(target, ctx): "-package=%s" % config_file.path, "-findings=%s" % raw_findings.path, "-facts=%s" % facts.path, - "-escapes=%s" % escapes.path, ], ) @@ -322,15 +319,16 @@ def _nogo_aspect_impl(target, ctx): all_raw_findings = [stdlib_info.raw_findings] + depset(all_raw_findings).to_list() + [raw_findings] # Return the package facts as output. 
- return [NogoInfo(
- facts = facts,
- raw_findings = all_raw_findings,
- escapes = escapes,
- importpath = importpath,
- binaries = binaries,
- srcs = srcs,
- deps = deps,
- )]
+ return [
+ NogoInfo(
+ facts = facts,
+ raw_findings = all_raw_findings,
+ importpath = importpath,
+ binaries = binaries,
+ srcs = srcs,
+ deps = deps,
+ ),
+ ]
 nogo_aspect = go_rule(
 aspect,
@@ -367,7 +365,6 @@ def _nogo_test_impl(ctx):
 if len(ctx.attr.deps) != 1:
 fail("nogo_test requires exactly one dep.")
 raw_findings = ctx.attr.deps[0][NogoInfo].raw_findings
- escapes = ctx.attr.deps[0][NogoInfo].escapes
 # Build a step that applies the configuration.
 config_srcs = ctx.attr.config[NogoConfigInfo].srcs
@@ -409,8 +406,6 @@ def _nogo_test_impl(ctx):
 # pays attention to the mnemonic above, so this must be
 # what is expected by the tooling.
 nogo_findings = depset([findings]),
- # Expose all escape analysis findings (see above).
- nogo_escapes = depset([escapes]),
 )]
 nogo_test = rule(
@@ -432,3 +427,18 @@ nogo_test = rule(
 },
 test = True,
)
+
+def _nogo_aspect_tricorder_impl(target, ctx):
+ if ctx.rule.kind != "nogo_test" or OutputGroupInfo not in target:
+ return []
+ if not hasattr(target[OutputGroupInfo], "nogo_findings"):
+ return []
+ return [
+ OutputGroupInfo(tricorder = target[OutputGroupInfo].nogo_findings),
+ ]
+
+# Trivial aspect that forwards the findings from a nogo_test rule to
+# go/tricorder, which reads from the `tricorder` output group.
+nogo_aspect_tricorder = aspect(
+ implementation = _nogo_aspect_tricorder_impl,
+)