diff options
213 files changed, 5133 insertions, 3295 deletions
diff --git a/.buildkite/pipeline.yaml b/.buildkite/pipeline.yaml index cb272aef6..aa2fd1f47 100644 --- a/.buildkite/pipeline.yaml +++ b/.buildkite/pipeline.yaml @@ -183,9 +183,13 @@ steps: - <<: *benchmarks label: ":metal: FFMPEG benchmarks" command: make benchmark-platforms BENCHMARKS_SUITE=ffmpeg BENCHMARKS_TARGETS=test/benchmarks/media:ffmpeg_test + # For fio, running with --test.benchtime=Xs scales the written/read + # bytes to several GB. This is not a problem for root/bind/volume mounts, + # but for tmpfs mounts, the size can grow to more memory than the machine + # has availabe. Fix the runs to 10GB written/read for the benchmark. - <<: *benchmarks label: ":floppy_disk: FIO benchmarks" - command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test + command: make benchmark-platforms BENCHMARKS_SUITE=fio BENCHMARKS_TARGETS=test/benchmarks/fs:fio_test BENCHMARKS_OPTIONS=--test.benchtime=10000x - <<: *benchmarks label: ":globe_with_meridians: HTTPD benchmarks" command: make benchmark-platforms BENCHMARKS_FILTER="Continuous" BENCHMARKS_SUITE=httpd BENCHMARKS_TARGETS=test/benchmarks/network:httpd_test diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b0381a563..b572dc94f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -16,6 +16,10 @@ jobs: default: runs-on: ubuntu-latest steps: + - name: Cancel previous + uses: styfle/cancel-workflow-action@0.7.0 + with: + access_token: ${{ github.token }} - uses: actions/checkout@v2 - run: make - run: make build OPTIONS="--build_tag_filters nogo" TARGETS="//..." diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 594dc7ffc..4c8b8ea5c 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -16,6 +16,10 @@ jobs: generate: runs-on: ubuntu-latest steps: + - name: Cancel previous + uses: styfle/cancel-workflow-action@0.7.0 + with: + access_token: ${{ github.token }} - id: setup run: | if ! [[ -z "${{ secrets.GO_TOKEN }}" ]]; then @@ -143,6 +143,7 @@ dev: $(RUNTIME_BIN) ## Installs a set of local runtimes. Requires sudo. @$(call configure_noreload,$(RUNTIME)-d,--net-raw --debug --strace --log-packets) @$(call configure_noreload,$(RUNTIME)-p,--net-raw --profile) @$(call configure_noreload,$(RUNTIME)-vfs2-d,--net-raw --debug --strace --log-packets --vfs2) + @$(call configure_noreload,$(RUNTIME)-vfs2-fuse-d,--net-raw --debug --strace --log-packets --vfs2 --fuse) @$(call reload_docker) .PHONY: dev @@ -325,6 +326,7 @@ containerd-tests: containerd-test-1.4.3 ## BENCHMARKS_FILTER - filter to be applied to the test suite. ## BENCHMARKS_OPTIONS - options to be passed to the test. ## BENCHMARKS_PROFILE - profile options to be passed to the test. +## BENCH_RUNTIME_ARGS - args to configure the runtime which runs the benchmarks. ## BENCHMARKS_PROJECT ?= gvisor-benchmarks BENCHMARKS_DATASET ?= kokoro @@ -338,6 +340,7 @@ BENCHMARKS_FILTER := . BENCHMARKS_OPTIONS := -test.benchtime=30s BENCHMARKS_ARGS := -test.v -test.bench=$(BENCHMARKS_FILTER) $(BENCHMARKS_OPTIONS) BENCHMARKS_PROFILE := -pprof-dir=/tmp/profile -pprof-cpu -pprof-heap -pprof-block -pprof-mutex +BENCH_RUNTIME_ARGS ?= --vfs2 init-benchmark-table: ## Initializes a BigQuery table with the benchmark schema. @$(call run,//tools/parsers:parser,init --project=$(BENCHMARKS_PROJECT) --dataset=$(BENCHMARKS_DATASET) --table=$(BENCHMARKS_TABLE)) @@ -358,13 +361,13 @@ run_benchmark = \ benchmark-platforms: load-benchmarks $(RUNTIME_BIN) ## Runs benchmarks for runc and all given platforms in BENCHMARK_PLATFORMS. @$(foreach PLATFORM,$(BENCHMARKS_PLATFORMS), \ - $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) --vfs2) && \ + $(call run_benchmark,$(PLATFORM),--platform=$(PLATFORM) $(BENCH_RUNTIME_ARGS)) && \ ) true @$(call run_benchmark,runc) .PHONY: benchmark-platforms run-benchmark: load-benchmarks $(RUNTIME_BIN) ## Runs single benchmark and optionally sends data to BigQuery. - @$(call run_benchmark,$(RUNTIME),) + @$(call run_benchmark,$(RUNTIME),$(BENCH_RUNTIME_ARGS)) .PHONY: run-benchmark ## @@ -303,8 +303,8 @@ go_repository( go_repository( name = "com_github_gofrs_flock", importpath = "github.com/gofrs/flock", - sum = "h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs=", - version = "v0.6.1-0.20180915234121-886344bea079", + sum = "h1:MSdYClljsF3PbENUUEx85nkWfJSGfzYI9yEBZOJz6CY=", + version = "v0.8.0", ) go_repository( @@ -20,7 +20,7 @@ require ( github.com/docker/docker v1.4.2-0.20191028175130-9e7d5ac5ea55 // indirect github.com/docker/go-connections v0.3.0 // indirect github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect - github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 // indirect + github.com/gofrs/flock v0.8.0 // indirect github.com/gogo/googleapis v1.4.0 // indirect github.com/gogo/protobuf v1.3.1 // indirect github.com/golang/mock v1.4.4 // indirect @@ -135,6 +135,8 @@ github.com/godbus/dbus/v5 v5.0.3 h1:ZqHaoEF7TBzh4jzPmqVhE/5A1z9of6orkAe5uHoAeME= github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079 h1:JFTFz3HZTGmgMz4E1TabNBNJljROSYgja1b4l50FNVs= github.com/gofrs/flock v0.6.1-0.20180915234121-886344bea079/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= +github.com/gofrs/flock v0.8.0 h1:MSdYClljsF3PbENUUEx85nkWfJSGfzYI9yEBZOJz6CY= +github.com/gofrs/flock v0.8.0/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/gogo/googleapis v1.4.0 h1:zgVt4UpGxcqVOw97aRGxT4svlcmdK35fynLNctY32zI= github.com/gogo/googleapis v1.4.0/go.mod h1:5YRNX2z1oM5gXdAkurHa942MDgEJyk02w4OecKY87+c= github.com/gogo/protobuf v1.2.2-0.20190723190241-65acae22fc9d/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= diff --git a/images/syzkaller/README.md b/images/syzkaller/README.md index 47e309422..7e500cab3 100644 --- a/images/syzkaller/README.md +++ b/images/syzkaller/README.md @@ -51,8 +51,8 @@ syzkaller repro in /tmp/syzkaller/repro. Now we can run syz-repro to reproduce a crash: ```bash -docker run --privileged -it --rm -v - /tmp/syzkaller:/tmp/syzkaller --entrypoint="" - gvisor.dev/images/syzkaller:latest ./bin/syz-repro -config +docker run --privileged -it --rm -v \ + /tmp/syzkaller:/tmp/syzkaller --entrypoint="" \ + gvisor.dev/images/syzkaller:latest ./bin/syz-repro -config \ /tmp/syzkaller/syzkaller.cfg /tmp/syzkaller/repro ``` @@ -46,8 +46,6 @@ global: - "(field|method|struct|type) .* should be .*" # Generated proto code sometimes duplicates imports with aliases. - "duplicate import" - # TODO(b/179817829): Upgrade to flock to v0.8.0. - - "flock.NewFlock is deprecated: Use New instead" internal: suppress: # We use ALL_CAPS for system definitions, diff --git a/pkg/abi/linux/ptrace_amd64.go b/pkg/abi/linux/ptrace_amd64.go index ed3881e27..50e22fe7e 100644 --- a/pkg/abi/linux/ptrace_amd64.go +++ b/pkg/abi/linux/ptrace_amd64.go @@ -50,3 +50,14 @@ type PtraceRegs struct { Fs uint64 Gs uint64 } + +// InstructionPointer returns the address of the next instruction to +// be executed. +func (p *PtraceRegs) InstructionPointer() uint64 { + return p.Rip +} + +// StackPointer returns the address of the Stack pointer. +func (p *PtraceRegs) StackPointer() uint64 { + return p.Rsp +} diff --git a/pkg/abi/linux/ptrace_arm64.go b/pkg/abi/linux/ptrace_arm64.go index 6147738b3..da36811d2 100644 --- a/pkg/abi/linux/ptrace_arm64.go +++ b/pkg/abi/linux/ptrace_arm64.go @@ -27,3 +27,14 @@ type PtraceRegs struct { Pc uint64 Pstate uint64 } + +// InstructionPointer returns the address of the next instruction to be +// executed. +func (p *PtraceRegs) InstructionPointer() uint64 { + return p.Pc +} + +// StackPointer returns the address of the Stack pointer. +func (p *PtraceRegs) StackPointer() uint64 { + return p.Sp +} diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 52061175f..bbe282c03 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -17,6 +17,7 @@ package proc import ( "fmt" "io" + "math" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -26,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/ramfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -498,6 +500,120 @@ func (f *ipForwardingFile) Write(ctx context.Context, _ *fs.File, src usermem.IO return n, f.stack.SetForwarding(ipv4.ProtocolNumber, *f.ipf.enabled) } +// portRangeInode implements fs.InodeOperations. It provides and allows +// modification of the range of ephemeral ports that IPv4 and IPv6 sockets +// choose from. +// +// +stateify savable +type portRangeInode struct { + fsutil.SimpleFileInode + + stack inet.Stack `state:"wait"` + + // start and end store the port range. We must save/restore this here, + // since a netstack instance is created on restore. + start *uint16 + end *uint16 +} + +func newPortRangeInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { + ipf := &portRangeInode{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), + stack: s, + } + sattr := fs.StableAttr{ + DeviceID: device.ProcDevice.DeviceID(), + InodeID: device.ProcDevice.NextIno(), + BlockSize: usermem.PageSize, + Type: fs.SpecialFile, + } + return fs.NewInode(ctx, ipf, msrc, sattr) +} + +// Truncate implements fs.InodeOperations.Truncate. Truncate is called when +// O_TRUNC is specified for any kind of existing Dirent but is not called via +// (f)truncate for proc files. +func (*portRangeInode) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// +stateify savable +type portRangeFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + inode *portRangeInode +} + +// GetFile implements fs.InodeOperations.GetFile. +func (in *portRangeInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + flags.Pread = true + flags.Pwrite = true + return fs.NewFile(ctx, dirent, flags, &portRangeFile{ + inode: in, + }), nil +} + +// Read implements fs.FileOperations.Read. +func (pf *portRangeFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + + if pf.inode.start == nil { + start, end := pf.inode.stack.PortRange() + pf.inode.start = &start + pf.inode.end = &end + } + + contents := fmt.Sprintf("%d %d\n", *pf.inode.start, *pf.inode.end) + n, err := dst.CopyOut(ctx, []byte(contents)) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +// +// Offset is ignored, multiple writes are not supported. +func (pf *portRangeFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + + // Only consider size of one memory page for input for performance + // reasons. + src = src.TakeFirst(usermem.PageSize - 1) + + ports := make([]int32, 2) + n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts) + if err != nil { + return 0, err + } + + // Port numbers must be uint16s. + if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 { + return 0, syserror.EINVAL + } + + if err := pf.inode.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil { + return 0, err + } + if pf.inode.start == nil { + pf.inode.start = new(uint16) + pf.inode.end = new(uint16) + } + *pf.inode.start = uint16(ports[0]) + *pf.inode.end = uint16(ports[1]) + return n, nil +} + func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { contents := map[string]*fs.Inode{ // Add tcp_sack. @@ -506,12 +622,15 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine // Add ip_forward. "ip_forward": newIPForwardingInode(ctx, msrc, s), + // Allow for configurable ephemeral port ranges. Note that this + // controls ports for both IPv4 and IPv6 sockets. + "ip_local_port_range": newPortRangeInode(ctx, msrc, s), + // The following files are simple stubs until they are // implemented in netstack, most of these files are // configuration related. We use the value closest to the // actual netstack behavior or any empty file, all of these // files will have mode 0444 (read-only for all users). - "ip_local_port_range": newStaticProcInode(ctx, msrc, []byte("16000 65535")), "ip_local_reserved_ports": newStaticProcInode(ctx, msrc, []byte("")), "ipfrag_time": newStaticProcInode(ctx, msrc, []byte("30")), "ip_nonlocal_bind": newStaticProcInode(ctx, msrc, []byte("0")), diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index d8c237753..e75954105 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -137,6 +137,11 @@ func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.Release(ctx) } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + // rootInode is the root directory inode for the devpts mounts. // // +stateify savable diff --git a/pkg/sentry/fsimpl/ext/filesystem.go b/pkg/sentry/fsimpl/ext/filesystem.go index 917f1873d..d4fc484a2 100644 --- a/pkg/sentry/fsimpl/ext/filesystem.go +++ b/pkg/sentry/fsimpl/ext/filesystem.go @@ -548,3 +548,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.mu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } + +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 204d8d143..fef857afb 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -47,19 +47,14 @@ type FilesystemType struct{} // +stateify savable type filesystemOptions struct { - // userID specifies the numeric uid of the mount owner. - // This option should not be specified by the filesystem owner. - // It is set by libfuse (or, if libfuse is not used, must be set - // by the filesystem itself). For more information, see man page - // for fuse(8) - userID uint32 - - // groupID specifies the numeric gid of the mount owner. - // This option should not be specified by the filesystem owner. - // It is set by libfuse (or, if libfuse is not used, must be set - // by the filesystem itself). For more information, see man page - // for fuse(8) - groupID uint32 + // mopts contains the raw, unparsed mount options passed to this filesystem. + mopts string + + // uid of the mount owner. + uid auth.KUID + + // gid of the mount owner. + gid auth.KGID // rootMode specifies the the file mode of the filesystem's root. rootMode linux.FileMode @@ -73,6 +68,19 @@ type filesystemOptions struct { // specified as "max_read" in fs parameters. // If not specified by user, use math.MaxUint32 as default value. maxRead uint32 + + // defaultPermissions is the default_permissions mount option. It instructs + // the kernel to perform a standard unix permission checks based on + // ownership and mode bits, instead of deferring the check to the server. + // + // Immutable after mount. + defaultPermissions bool + + // allowOther is the allow_other mount option. It allows processes that + // don't own the FUSE mount to call into it. + // + // Immutable after mount. + allowOther bool } // filesystem implements vfs.FilesystemImpl. @@ -108,18 +116,18 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } - var fsopts filesystemOptions + fsopts := filesystemOptions{mopts: opts.Data} mopts := vfs.GenericParseMountOptions(opts.Data) deviceDescriptorStr, ok := mopts["fd"] if !ok { - log.Warningf("%s.GetFilesystem: communication file descriptor N (obtained by opening /dev/fuse) must be specified as 'fd=N'", fsType.Name()) + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option fd missing") return nil, nil, syserror.EINVAL } delete(mopts, "fd") deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */) if err != nil { - log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err) + ctx.Debugf("fusefs.FilesystemType.GetFilesystem: invalid fd: %q (%v)", deviceDescriptorStr, err) return nil, nil, syserror.EINVAL } @@ -141,38 +149,54 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Parse and set all the other supported FUSE mount options. // TODO(gVisor.dev/issue/3229): Expand the supported mount options. - if userIDStr, ok := mopts["user_id"]; ok { + if uidStr, ok := mopts["user_id"]; ok { delete(mopts, "user_id") - userID, err := strconv.ParseUint(userIDStr, 10, 32) + uid, err := strconv.ParseUint(uidStr, 10, 32) if err != nil { - log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), userIDStr) + log.Warningf("%s.GetFilesystem: invalid user_id: user_id=%s", fsType.Name(), uidStr) return nil, nil, syserror.EINVAL } - fsopts.userID = uint32(userID) + kuid := creds.UserNamespace.MapToKUID(auth.UID(uid)) + if !kuid.Ok() { + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped uid: %d", uid) + return nil, nil, syserror.EINVAL + } + fsopts.uid = kuid + } else { + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option user_id missing") + return nil, nil, syserror.EINVAL } - if groupIDStr, ok := mopts["group_id"]; ok { + if gidStr, ok := mopts["group_id"]; ok { delete(mopts, "group_id") - groupID, err := strconv.ParseUint(groupIDStr, 10, 32) + gid, err := strconv.ParseUint(gidStr, 10, 32) if err != nil { - log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), groupIDStr) + log.Warningf("%s.GetFilesystem: invalid group_id: group_id=%s", fsType.Name(), gidStr) + return nil, nil, syserror.EINVAL + } + kgid := creds.UserNamespace.MapToKGID(auth.GID(gid)) + if !kgid.Ok() { + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: unmapped gid: %d", gid) return nil, nil, syserror.EINVAL } - fsopts.groupID = uint32(groupID) + fsopts.gid = kgid + } else { + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option group_id missing") + return nil, nil, syserror.EINVAL } - rootMode := linux.FileMode(0777) - modeStr, ok := mopts["rootmode"] - if ok { + if modeStr, ok := mopts["rootmode"]; ok { delete(mopts, "rootmode") mode, err := strconv.ParseUint(modeStr, 8, 32) if err != nil { log.Warningf("%s.GetFilesystem: invalid mode: %q", fsType.Name(), modeStr) return nil, nil, syserror.EINVAL } - rootMode = linux.FileMode(mode) + fsopts.rootMode = linux.FileMode(mode) + } else { + ctx.Warningf("fusefs.FilesystemType.GetFilesystem: mandatory mount option rootmode missing") + return nil, nil, syserror.EINVAL } - fsopts.rootMode = rootMode // Set the maxInFlightRequests option. fsopts.maxActiveRequests = maxActiveRequestsDefault @@ -192,6 +216,16 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts.maxRead = math.MaxUint32 } + if _, ok := mopts["default_permissions"]; ok { + delete(mopts, "default_permissions") + fsopts.defaultPermissions = true + } + + if _, ok := mopts["allow_other"]; ok { + delete(mopts, "allow_other") + fsopts.allowOther = true + } + // Check for unparsed options. if len(mopts) != 0 { log.Warningf("%s.GetFilesystem: unsupported or unknown options: %v", fsType.Name(), mopts) @@ -260,6 +294,11 @@ func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.Release(ctx) } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return fs.opts.mopts +} + // inode implements kernfs.Inode. // // +stateify savable @@ -318,6 +357,37 @@ func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FU return i } +// CheckPermissions implements kernfs.Inode.CheckPermissions. +func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { + // Since FUSE operations are ultimately backed by a userspace process (the + // fuse daemon), allowing a process to call into fusefs grants the daemon + // ptrace-like capabilities over the calling process. Because of this, by + // default FUSE only allows the mount owner to interact with the + // filesystem. This explicitly excludes setuid/setgid processes. + // + // This behaviour can be overriden with the 'allow_other' mount option. + // + // See fs/fuse/dir.c:fuse_allow_current_process() in Linux. + if !i.fs.opts.allowOther { + if creds.RealKUID != i.fs.opts.uid || + creds.EffectiveKUID != i.fs.opts.uid || + creds.SavedKUID != i.fs.opts.uid || + creds.RealKGID != i.fs.opts.gid || + creds.EffectiveKGID != i.fs.opts.gid || + creds.SavedKGID != i.fs.opts.gid { + return syserror.EACCES + } + } + + // By default, fusefs delegates all permission checks to the server. + // However, standard unix permission checks can be enabled with the + // default_permissions mount option. + if i.fs.opts.defaultPermissions { + return i.InodeAttrs.CheckPermissions(ctx, creds, ats) + } + return nil +} + // Open implements kernfs.Inode.Open. func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { isDir := i.InodeAttrs.Mode().IsDir() diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 8f95473b6..c34451269 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -15,7 +15,9 @@ package gofer import ( + "fmt" "math" + "strings" "sync" "sync/atomic" @@ -1608,3 +1610,58 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } + +type mopt struct { + key string + value interface{} +} + +func (m mopt) String() string { + if m.value == nil { + return fmt.Sprintf("%s", m.key) + } + return fmt.Sprintf("%s=%v", m.key, m.value) +} + +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + optsKV := []mopt{ + {moptTransport, transportModeFD}, // Only valid value, currently. + {moptReadFD, fs.opts.fd}, // Currently, read and write FD are the same. + {moptWriteFD, fs.opts.fd}, // Currently, read and write FD are the same. + {moptAname, fs.opts.aname}, + {moptDfltUID, fs.opts.dfltuid}, + {moptDfltGID, fs.opts.dfltgid}, + {moptMsize, fs.opts.msize}, + {moptVersion, fs.opts.version}, + {moptDentryCacheLimit, fs.opts.maxCachedDentries}, + } + + switch fs.opts.interop { + case InteropModeExclusive: + optsKV = append(optsKV, mopt{moptCache, cacheFSCache}) + case InteropModeWritethrough: + optsKV = append(optsKV, mopt{moptCache, cacheFSCacheWritethrough}) + case InteropModeShared: + if fs.opts.regularFilesUseSpecialFileFD { + optsKV = append(optsKV, mopt{moptCache, cacheNone}) + } else { + optsKV = append(optsKV, mopt{moptCache, cacheRemoteRevalidating}) + } + } + if fs.opts.forcePageCache { + optsKV = append(optsKV, mopt{moptForcePageCache, nil}) + } + if fs.opts.limitHostFDTranslation { + optsKV = append(optsKV, mopt{moptLimitHostFDTranslation, nil}) + } + if fs.opts.overlayfsStaleRead { + optsKV = append(optsKV, mopt{moptOverlayfsStaleRead, nil}) + } + + opts := make([]string, 0, len(optsKV)) + for _, opt := range optsKV { + opts = append(opts, opt.String()) + } + return strings.Join(opts, ",") +} diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 1508cbdf1..71569dc65 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -66,6 +66,34 @@ import ( // Name is the default filesystem name. const Name = "9p" +// Mount option names for goferfs. +const ( + moptTransport = "trans" + moptReadFD = "rfdno" + moptWriteFD = "wfdno" + moptAname = "aname" + moptDfltUID = "dfltuid" + moptDfltGID = "dfltgid" + moptMsize = "msize" + moptVersion = "version" + moptDentryCacheLimit = "dentry_cache_limit" + moptCache = "cache" + moptForcePageCache = "force_page_cache" + moptLimitHostFDTranslation = "limit_host_fd_translation" + moptOverlayfsStaleRead = "overlayfs_stale_read" +) + +// Valid values for the "cache" mount option. +const ( + cacheNone = "none" + cacheFSCache = "fscache" + cacheFSCacheWritethrough = "fscache_writethrough" + cacheRemoteRevalidating = "remote_revalidating" +) + +// Valid values for "trans" mount option. +const transportModeFD = "fd" + // FilesystemType implements vfs.FilesystemType. // // +stateify savable @@ -301,39 +329,39 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Get the attach name. fsopts.aname = "/" - if aname, ok := mopts["aname"]; ok { - delete(mopts, "aname") + if aname, ok := mopts[moptAname]; ok { + delete(mopts, moptAname) fsopts.aname = aname } // Parse the cache policy. For historical reasons, this defaults to the // least generally-applicable option, InteropModeExclusive. fsopts.interop = InteropModeExclusive - if cache, ok := mopts["cache"]; ok { - delete(mopts, "cache") + if cache, ok := mopts[moptCache]; ok { + delete(mopts, moptCache) switch cache { - case "fscache": + case cacheFSCache: fsopts.interop = InteropModeExclusive - case "fscache_writethrough": + case cacheFSCacheWritethrough: fsopts.interop = InteropModeWritethrough - case "none": + case cacheNone: fsopts.regularFilesUseSpecialFileFD = true fallthrough - case "remote_revalidating": + case cacheRemoteRevalidating: fsopts.interop = InteropModeShared default: - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: cache=%s", cache) + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid cache policy: %s=%s", moptCache, cache) return nil, nil, syserror.EINVAL } } // Parse the default UID and GID. fsopts.dfltuid = _V9FS_DEFUID - if dfltuidstr, ok := mopts["dfltuid"]; ok { - delete(mopts, "dfltuid") + if dfltuidstr, ok := mopts[moptDfltUID]; ok { + delete(mopts, moptDfltUID) dfltuid, err := strconv.ParseUint(dfltuidstr, 10, 32) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltuid=%s", dfltuidstr) + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltUID, dfltuidstr) return nil, nil, syserror.EINVAL } // In Linux, dfltuid is interpreted as a UID and is converted to a KUID @@ -342,11 +370,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts.dfltuid = auth.KUID(dfltuid) } fsopts.dfltgid = _V9FS_DEFGID - if dfltgidstr, ok := mopts["dfltgid"]; ok { - delete(mopts, "dfltgid") + if dfltgidstr, ok := mopts[moptDfltGID]; ok { + delete(mopts, moptDfltGID) dfltgid, err := strconv.ParseUint(dfltgidstr, 10, 32) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: dfltgid=%s", dfltgidstr) + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid default UID: %s=%s", moptDfltGID, dfltgidstr) return nil, nil, syserror.EINVAL } fsopts.dfltgid = auth.KGID(dfltgid) @@ -354,11 +382,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Parse the 9P message size. fsopts.msize = 1024 * 1024 // 1M, tested to give good enough performance up to 64M - if msizestr, ok := mopts["msize"]; ok { - delete(mopts, "msize") + if msizestr, ok := mopts[moptMsize]; ok { + delete(mopts, moptMsize) msize, err := strconv.ParseUint(msizestr, 10, 32) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: msize=%s", msizestr) + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid message size: %s=%s", moptMsize, msizestr) return nil, nil, syserror.EINVAL } fsopts.msize = uint32(msize) @@ -366,34 +394,34 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Parse the 9P protocol version. fsopts.version = p9.HighestVersionString() - if version, ok := mopts["version"]; ok { - delete(mopts, "version") + if version, ok := mopts[moptVersion]; ok { + delete(mopts, moptVersion) fsopts.version = version } // Parse the dentry cache limit. fsopts.maxCachedDentries = 1000 - if str, ok := mopts["dentry_cache_limit"]; ok { - delete(mopts, "dentry_cache_limit") + if str, ok := mopts[moptDentryCacheLimit]; ok { + delete(mopts, moptDentryCacheLimit) maxCachedDentries, err := strconv.ParseUint(str, 10, 64) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid dentry cache limit: %s=%s", moptDentryCacheLimit, str) return nil, nil, syserror.EINVAL } fsopts.maxCachedDentries = maxCachedDentries } // Handle simple flags. - if _, ok := mopts["force_page_cache"]; ok { - delete(mopts, "force_page_cache") + if _, ok := mopts[moptForcePageCache]; ok { + delete(mopts, moptForcePageCache) fsopts.forcePageCache = true } - if _, ok := mopts["limit_host_fd_translation"]; ok { - delete(mopts, "limit_host_fd_translation") + if _, ok := mopts[moptLimitHostFDTranslation]; ok { + delete(mopts, moptLimitHostFDTranslation) fsopts.limitHostFDTranslation = true } - if _, ok := mopts["overlayfs_stale_read"]; ok { - delete(mopts, "overlayfs_stale_read") + if _, ok := mopts[moptOverlayfsStaleRead]; ok { + delete(mopts, moptOverlayfsStaleRead) fsopts.overlayfsStaleRead = true } // fsopts.regularFilesUseSpecialFileFD can only be enabled by specifying @@ -469,34 +497,34 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) { // Check that the transport is "fd". - trans, ok := mopts["trans"] - if !ok || trans != "fd" { - ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'") + trans, ok := mopts[moptTransport] + if !ok || trans != transportModeFD { + ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as '%s=%s'", moptTransport, transportModeFD) return -1, syserror.EINVAL } - delete(mopts, "trans") + delete(mopts, moptTransport) // Check that read and write FDs are provided and identical. - rfdstr, ok := mopts["rfdno"] + rfdstr, ok := mopts[moptReadFD] if !ok { - ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'") + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as '%s=<file descriptor>'", moptReadFD) return -1, syserror.EINVAL } - delete(mopts, "rfdno") + delete(mopts, moptReadFD) rfd, err := strconv.Atoi(rfdstr) if err != nil { - ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr) + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: %s=%s", moptReadFD, rfdstr) return -1, syserror.EINVAL } - wfdstr, ok := mopts["wfdno"] + wfdstr, ok := mopts[moptWriteFD] if !ok { - ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'") + ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as '%s=<file descriptor>'", moptWriteFD) return -1, syserror.EINVAL } - delete(mopts, "wfdno") + delete(mopts, moptWriteFD) wfd, err := strconv.Atoi(wfdstr) if err != nil { - ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr) + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: %s=%s", moptWriteFD, wfdstr) return -1, syserror.EINVAL } if rfd != wfd { diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index ad5de80dc..b9cce4181 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -260,6 +260,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + // CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { var s unix.Stat_t diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index e63588e33..1cd3137e6 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -67,6 +67,11 @@ type filesystem struct { kernfs.Filesystem } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + type file struct { kernfs.DynamicBytesFile content string diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 917709d75..84e37f793 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -1764,3 +1764,15 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } + +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + // Return the mount options from the topmost layer. + var vd vfs.VirtualDentry + if fs.opts.UpperRoot.Ok() { + vd = fs.opts.UpperRoot + } else { + vd = fs.opts.LowerRoots[0] + } + return vd.Mount().Filesystem().Impl().MountOptions() +} diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index 429733c10..3f05e444e 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -80,6 +80,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + // inode implements kernfs.Inode. // // +stateify savable diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index 8716d0a3c..254a8b062 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -104,6 +104,11 @@ func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.Release(ctx) } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return fmt.Sprintf("dentry_cache_limit=%d", fs.MaxCachedDentries) +} + // dynamicInode is an overfitted interface for common Inodes with // dynamicByteSource types used in procfs. // diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index fd7823daa..fb274b78e 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -17,6 +17,7 @@ package proc import ( "bytes" "fmt" + "math" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -69,17 +70,17 @@ func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]kernfs.Inode{ "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ - "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), - "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), - "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), - "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), - "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), + "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), + "ip_local_port_range": fs.newInode(ctx, root, 0644, &portRange{stack: stack}), + "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), + "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), + "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), + "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the // value closest to the actual netstack behavior or any empty file, all // of these files will have mode 0444 (read-only for all users). - "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")), "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")), "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")), "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")), @@ -421,3 +422,68 @@ func (ipf *ipForwarding) Write(ctx context.Context, src usermem.IOSequence, offs } return n, nil } + +// portRange implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/ip_local_port_range. +// +// +stateify savable +type portRange struct { + kernfs.DynamicBytesFile + + stack inet.Stack `state:"wait"` + + // start and end store the port range. We must save/restore this here, + // since a netstack instance is created on restore. + start *uint16 + end *uint16 +} + +var _ vfs.WritableDynamicBytesSource = (*portRange)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. +func (pr *portRange) Generate(ctx context.Context, buf *bytes.Buffer) error { + if pr.start == nil { + start, end := pr.stack.PortRange() + pr.start = &start + pr.end = &end + } + _, err := fmt.Fprintf(buf, "%d %d\n", *pr.start, *pr.end) + return err +} + +// Write implements vfs.WritableDynamicBytesSource.Write. +func (pr *portRange) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit input size so as not to impact performance if input size is + // large. + src = src.TakeFirst(usermem.PageSize - 1) + + ports := make([]int32, 2) + n, err := usermem.CopyInt32StringsInVec(ctx, src.IO, src.Addrs, ports, src.Opts) + if err != nil { + return 0, err + } + + // Port numbers must be uint16s. + if ports[0] < 0 || ports[1] < 0 || ports[0] > math.MaxUint16 || ports[1] > math.MaxUint16 { + return 0, syserror.EINVAL + } + + if err := pr.stack.SetPortRange(uint16(ports[0]), uint16(ports[1])); err != nil { + return 0, err + } + if pr.start == nil { + pr.start = new(uint16) + pr.end = new(uint16) + } + *pr.start = uint16(ports[0]) + *pr.end = uint16(ports[1]) + return n, nil +} diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index fda1fa942..735756280 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -85,6 +85,11 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + // inode implements kernfs.Inode. // // +stateify savable diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index dbd9ebdda..1d9280dae 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -143,6 +143,11 @@ func (fs *filesystem) Release(ctx context.Context) { fs.Filesystem.Release(ctx) } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return fmt.Sprintf("dentry_cache_limit=%d", fs.MaxCachedDentries) +} + // dir implements kernfs.Inode. // // +stateify savable diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 4f675c21e..5fdca1d46 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -898,3 +898,8 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe d = d.parent } } + +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return fs.mopts +} diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index a01e413e0..8df81f589 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -70,6 +70,10 @@ type filesystem struct { // devMinor is the filesystem's minor device number. devMinor is immutable. devMinor uint32 + // mopts contains the tmpfs-specific mount options passed to this + // filesystem. Immutable. + mopts string + // mu serializes changes to the Dentry tree. mu sync.RWMutex `state:"nosave"` @@ -184,6 +188,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt mfp: mfp, clock: clock, devMinor: devMinor, + mopts: opts.Data, } fs.vfsfs.Init(vfsObj, newFSType, &fs) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 9057d2b4e..6cb1a23e0 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -590,6 +590,23 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, return nil, err } + // Clear the Merkle tree file if they are to be generated at runtime. + // TODO(b/182315468): Optimize the Merkle tree generate process to + // allow only updating certain files/directories. + if fs.allowRuntimeEnable { + childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: childMerkleVD, + Start: childMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_TRUNC, + Mode: 0644, + }) + if err != nil { + return nil, err + } + childMerkleFD.DecRef(ctx) + } + // The dentry needs to be cleaned up if any error occurs. IncRef will be // called if a verity child dentry is successfully created. defer childMerkleVD.DecRef(ctx) diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 374f71568..0d9b0ee2c 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -38,6 +38,7 @@ import ( "fmt" "math" "strconv" + "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -310,6 +311,24 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt d.DecRef(ctx) return nil, nil, alertIntegrityViolation("Failed to find root Merkle file") } + + // Clear the Merkle tree file if they are to be generated at runtime. + // TODO(b/182315468): Optimize the Merkle tree generate process to + // allow only updating certain files/directories. + if fs.allowRuntimeEnable { + lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_TRUNC, + Mode: 0644, + }) + if err != nil { + return nil, nil, err + } + lowerMerkleFD.DecRef(ctx) + } + d.lowerMerkleVD = lowerMerkleVD // Get metadata from the underlying file system. @@ -418,6 +437,11 @@ func (fs *filesystem) Release(ctx context.Context) { fs.lowerMount.DecRef(ctx) } +// MountOptions implements vfs.FilesystemImpl.MountOptions. +func (fs *filesystem) MountOptions() string { + return "" +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -750,6 +774,50 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return syserror.EPERM } +// IterDirents implements vfs.FileDescriptionImpl.IterDirents. +func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback) error { + if !fd.d.isDir() { + return syserror.ENOTDIR + } + fd.mu.Lock() + defer fd.mu.Unlock() + + var ds []vfs.Dirent + err := fd.lowerFD.IterDirents(ctx, vfs.IterDirentsCallbackFunc(func(dirent vfs.Dirent) error { + // Do not include the Merkle tree files. + if strings.Contains(dirent.Name, merklePrefix) || strings.Contains(dirent.Name, merkleRootPrefix) { + return nil + } + if fd.d.verityEnabled() { + // Verify that the child is expected. + if dirent.Name != "." && dirent.Name != ".." { + if _, ok := fd.d.childrenNames[dirent.Name]; !ok { + return alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name)) + } + } + } + ds = append(ds, dirent) + return nil + })) + + if err != nil { + return err + } + + // The result should contain all children plus "." and "..". + if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 { + return alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds))) + } + + for fd.off < int64(len(ds)) { + if err := cb.Handle(ds[fd.off]); err != nil { + return err + } + fd.off++ + } + return nil +} + // Seek implements vfs.FileDescriptionImpl.Seek. func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { fd.mu.Lock() diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index f31277d30..6b71bd3a9 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -93,6 +93,14 @@ type Stack interface { // SetForwarding enables or disables packet forwarding between NICs. SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error + + // PortRange returns the UDP and TCP inclusive range of ephemeral ports + // used in both IPv4 and IPv6. + PortRange() (uint16, uint16) + + // SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range + // (inclusive). + SetPortRange(start uint16, end uint16) error } // Interface contains information about a network interface. diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 9ebeba8a3..03e2608c2 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -164,3 +164,15 @@ func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable b s.IPForwarding = enable return nil } + +// PortRange implements inet.Stack.PortRange. +func (*TestStack) PortRange() (uint16, uint16) { + // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net(). + return 32768, 28232 +} + +// SetPortRange implements inet.Stack.SetPortRange. +func (*TestStack) SetPortRange(start uint16, end uint16) error { + // No-op. + return nil +} diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index e6323244c..5bcf92e14 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -504,3 +504,14 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES } + +// PortRange implements inet.Stack.PortRange. +func (*Stack) PortRange() (uint16, uint16) { + // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net(). + return 32768, 28232 +} + +// SetPortRange implements inet.Stack.SetPortRange. +func (*Stack) SetPortRange(start uint16, end uint16) error { + return syserror.EACCES +} diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 71c3bc034..b215067cf 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -478,3 +478,13 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) } return nil } + +// PortRange implements inet.Stack.PortRange. +func (s *Stack) PortRange() (uint16, uint16) { + return s.Stack.PortRange() +} + +// SetPortRange implements inet.Stack.SetPortRange. +func (s *Stack) SetPortRange(start uint16, end uint16) error { + return syserr.TranslateNetstackError(s.Stack.SetPortRange(start, end)).ToError() +} diff --git a/pkg/sentry/vfs/anonfs.go b/pkg/sentry/vfs/anonfs.go index 7ad0eaf86..3caf417ca 100644 --- a/pkg/sentry/vfs/anonfs.go +++ b/pkg/sentry/vfs/anonfs.go @@ -291,6 +291,11 @@ func (fs *anonFilesystem) PrependPath(ctx context.Context, vfsroot, vd VirtualDe return PrependPathSyntheticError{} } +// MountOptions implements FilesystemImpl.MountOptions. +func (fs *anonFilesystem) MountOptions() string { + return "" +} + // IncRef implements DentryImpl.IncRef. func (d *anonDentry) IncRef() { // no-op diff --git a/pkg/sentry/vfs/filesystem.go b/pkg/sentry/vfs/filesystem.go index 2c4b81e78..059939010 100644 --- a/pkg/sentry/vfs/filesystem.go +++ b/pkg/sentry/vfs/filesystem.go @@ -502,6 +502,15 @@ type FilesystemImpl interface { // // Preconditions: vd.Mount().Filesystem().Impl() == this FilesystemImpl. PrependPath(ctx context.Context, vfsroot, vd VirtualDentry, b *fspath.Builder) error + + // MountOptions returns mount options for the current filesystem. This + // should only return options specific to the filesystem (i.e. don't return + // "ro", "rw", etc). Options should be returned as a comma-separated string, + // similar to the input to the 5th argument to mount. + // + // If the implementation has no filesystem-specific options, it should + // return the empty string. + MountOptions() string } // PrependPathAtVFSRootError is returned by implementations of diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index bac9eb905..922f9e697 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -959,13 +959,17 @@ func manglePath(p string) string { // superBlockOpts returns the super block options string for the the mount at // the given path. func superBlockOpts(mountPath string, mnt *Mount) string { - // gVisor doesn't (yet) have a concept of super block options, so we - // use the ro/rw bit from the mount flag. + // Compose super block options by combining global mount flags with + // FS-specific mount options. opts := "rw" if mnt.ReadOnly() { opts = "ro" } + if mopts := mnt.fs.Impl().MountOptions(); mopts != "" { + opts += "," + mopts + } + // NOTE(b/147673608): If the mount is a cgroup, we also need to include // the cgroup name in the options. For now we just read that from the // path. diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 0b9139570..79e564de6 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -51,6 +51,7 @@ var ( ErrNotPermittedNet = New((&tcpip.ErrNotPermitted{}).String(), linux.EPERM) ErrBadBuffer = New((&tcpip.ErrBadBuffer{}).String(), linux.EFAULT) ErrMalformedHeader = New((&tcpip.ErrMalformedHeader{}).String(), linux.EINVAL) + ErrInvalidPortRange = New((&tcpip.ErrInvalidPortRange{}).String(), linux.EINVAL) ) // TranslateNetstackError converts an error from the tcpip package to a sentry @@ -135,6 +136,8 @@ func TranslateNetstackError(err tcpip.Error) *Error { return ErrBadBuffer case *tcpip.ErrMalformedHeader: return ErrMalformedHeader + case *tcpip.ErrInvalidPortRange: + return ErrInvalidPortRange default: panic(fmt.Sprintf("unknown error %T", err)) } diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 75d8e1f03..fc622b246 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -567,7 +567,7 @@ func TCPWindowLessThanEq(window uint16) TransportChecker { } // TCPFlags creates a checker that checks the tcp flags. -func TCPFlags(flags uint8) TransportChecker { +func TCPFlags(flags header.TCPFlags) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() @@ -576,15 +576,15 @@ func TCPFlags(flags uint8) TransportChecker { t.Fatalf("TCP header not found in h: %T", h) } - if f := tcp.Flags(); f != flags { - t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags) + if got := tcp.Flags(); got != flags { + t.Errorf("got tcp.Flags() = %s, want %s", got, flags) } } } // TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the // given mask, match the supplied flags. -func TCPFlagsMatch(flags, mask uint8) TransportChecker { +func TCPFlagsMatch(flags, mask header.TCPFlags) TransportChecker { return func(t *testing.T, h header.Transport) { t.Helper() @@ -593,8 +593,8 @@ func TCPFlagsMatch(flags, mask uint8) TransportChecker { t.Fatalf("TCP header not found in h: %T", h) } - if f := tcp.Flags(); (f & mask) != (flags & mask) { - t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask) + if got := tcp.Flags(); (got & mask) != (flags & mask) { + t.Errorf("got tcp.Flags() = %s, want %s, mask %s", got, flags, mask) } } } diff --git a/pkg/tcpip/errors.go b/pkg/tcpip/errors.go index 3b7cc52f3..5d478ac32 100644 --- a/pkg/tcpip/errors.go +++ b/pkg/tcpip/errors.go @@ -300,6 +300,19 @@ func (*ErrInvalidOptionValue) IgnoreStats() bool { } func (*ErrInvalidOptionValue) String() string { return "invalid option value specified" } +// ErrInvalidPortRange indicates an attempt to set an invalid port range. +// +// +stateify savable +type ErrInvalidPortRange struct{} + +func (*ErrInvalidPortRange) isError() {} + +// IgnoreStats implements Error. +func (*ErrInvalidPortRange) IgnoreStats() bool { + return true +} +func (*ErrInvalidPortRange) String() string { return "invalid port range" } + // ErrMalformedHeader indicates the operation encountered a malformed header. // // +stateify savable diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 4c6f808e5..adc835d30 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -45,9 +45,23 @@ const ( TCPMaxSACKBlocks = 4 ) +// TCPFlags is the dedicated type for TCP flags. +type TCPFlags uint8 + +// String implements Stringer.String. +func (f TCPFlags) String() string { + flagsStr := []byte("FSRPAU") + for i := range flagsStr { + if f&(1<<uint(i)) == 0 { + flagsStr[i] = ' ' + } + } + return string(flagsStr) +} + // Flags that may be set in a TCP segment. const ( - TCPFlagFin = 1 << iota + TCPFlagFin TCPFlags = 1 << iota TCPFlagSyn TCPFlagRst TCPFlagPsh @@ -94,7 +108,7 @@ type TCPFields struct { DataOffset uint8 // Flags is the "flags" field of a TCP packet. - Flags uint8 + Flags TCPFlags // WindowSize is the "window size" field of a TCP packet. WindowSize uint16 @@ -234,8 +248,8 @@ func (b TCP) Payload() []byte { } // Flags returns the flags field of the tcp header. -func (b TCP) Flags() uint8 { - return b[TCPFlagsOffset] +func (b TCP) Flags() TCPFlags { + return TCPFlags(b[TCPFlagsOffset]) } // WindowSize returns the "window size" field of the tcp header. @@ -319,10 +333,10 @@ func (b TCP) ParsedOptions() TCPOptions { return ParseTCPOptions(b.Options()) } -func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) { +func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) { binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq) binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack) - b[TCPFlagsOffset] = flags + b[TCPFlagsOffset] = uint8(flags) binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd) } @@ -338,7 +352,7 @@ func (b TCP) Encode(t *TCPFields) { // EncodePartial updates a subset of the fields of the tcp header. It is useful // in cases when similar segments are produced. -func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) { +func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) { // Add the total length and "flags" field contributions to the checksum. // We don't use the flags field directly from the header because it's a // one-byte field with an odd offset, so it would be accounted for diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go index 72563837b..96db8460f 100644 --- a/pkg/tcpip/header/tcp_test.go +++ b/pkg/tcpip/header/tcp_test.go @@ -146,3 +146,23 @@ func TestTCPParseOptions(t *testing.T) { } } } + +func TestTCPFlags(t *testing.T) { + for _, tt := range []struct { + flags header.TCPFlags + want string + }{ + {header.TCPFlagFin, "F "}, + {header.TCPFlagSyn, " S "}, + {header.TCPFlagRst, " R "}, + {header.TCPFlagPsh, " P "}, + {header.TCPFlagAck, " A "}, + {header.TCPFlagUrg, " U"}, + {header.TCPFlagSyn | header.TCPFlagAck, " S A "}, + {header.TCPFlagFin | header.TCPFlagAck, "F A "}, + } { + if got := tt.flags.String(); got != tt.want { + t.Errorf("got TCPFlags(%#b).String() = %s, want = %s", tt.flags, got, tt.want) + } + } +} diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 84189bba5..7aaee3d13 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -398,13 +398,7 @@ func logPacket(prefix string, dir direction, protocol tcpip.NetworkProtocolNumbe // Initialize the TCP flags. flags := tcp.Flags() - flagsStr := []byte("FSRPAU") - for i := range flagsStr { - if flags&(1<<uint(i)) == 0 { - flagsStr[i] = ' ' - } - } - details = fmt.Sprintf("flags:0x%02x (%s) seqnum: %d ack: %d win: %d xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) + details = fmt.Sprintf("flags: %s seqnum: %d ack: %d win: %d xsum:0x%x", flags, tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum()) if flags&header.TCPFlagSyn != 0 { details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0)) } else { diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 3fcdea119..ae0461a6d 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -232,7 +232,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) e.mu.Lock() - e.mu.dad.StopLocked(addr, false /* aborted */) + e.mu.dad.StopLocked(addr, &stack.DADDupAddrDetected{HolderLinkAddress: linkAddr}) e.mu.Unlock() // The solicited, override, and isRouter flags are not available for ARP; diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go index 6f89a6a16..0053646ee 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection.go @@ -126,9 +126,12 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet s.timer.Stop() delete(d.addresses, addr) - r := stack.DADResult{Resolved: dadDone, Err: err} + var res stack.DADResult = &stack.DADSucceeded{} + if err != nil { + res = &stack.DADError{Err: err} + } for _, h := range s.completionHandlers { - h(r) + h(res) } }), } @@ -142,7 +145,7 @@ func (d *DAD) CheckDuplicateAddressLocked(addr tcpip.Address, h stack.DADComplet // StopLocked stops a currently running DAD process. // // Precondition: d.protocolMU must be locked. -func (d *DAD) StopLocked(addr tcpip.Address, aborted bool) { +func (d *DAD) StopLocked(addr tcpip.Address, reason stack.DADResult) { s, ok := d.addresses[addr] if !ok { return @@ -152,14 +155,8 @@ func (d *DAD) StopLocked(addr tcpip.Address, aborted bool) { s.timer.Stop() delete(d.addresses, addr) - var err tcpip.Error - if aborted { - err = &tcpip.ErrAborted{} - } - - r := stack.DADResult{Resolved: false, Err: err} for _, h := range s.completionHandlers { - h(r) + h(reason) } } diff --git a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go index 18c357b56..e00aa4678 100644 --- a/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go +++ b/pkg/tcpip/network/internal/ip/duplicate_address_detection_test.go @@ -78,10 +78,10 @@ func (m *mockDADProtocol) checkDuplicateAddress(addr tcpip.Address, h stack.DADC return m.mu.dad.CheckDuplicateAddressLocked(addr, h) } -func (m *mockDADProtocol) stop(addr tcpip.Address, aborted bool) { +func (m *mockDADProtocol) stop(addr tcpip.Address, reason stack.DADResult) { m.mu.Lock() defer m.mu.Unlock() - m.mu.dad.StopLocked(addr, aborted) + m.mu.dad.StopLocked(addr, reason) } func (m *mockDADProtocol) setConfigs(c stack.DADConfigurations) { @@ -175,7 +175,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { } clock.Advance(delta) for i := 0; i < 2; i++ { - if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { t.Errorf("(i=%d) dad result mismatch (-want +got):\n%s", i, diff) } } @@ -189,7 +189,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { default: } clock.Advance(delta) - if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } @@ -202,7 +202,7 @@ func TestDADCheckDuplicateAddress(t *testing.T) { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } clock.Advance(dadConfig2Duration) - if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADSucceeded{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } @@ -241,19 +241,19 @@ func TestDADStop(t *testing.T) { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } - dad.stop(addr1, true /* aborted */) - if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: false, Err: &tcpip.ErrAborted{}}}, <-ch); diff != "" { + dad.stop(addr1, &stack.DADAborted{}) + if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADAborted{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } - dad.stop(addr2, false /* aborted */) - if diff := cmp.Diff(dadResult{Addr: addr2, R: stack.DADResult{Resolved: false, Err: nil}}, <-ch); diff != "" { + dad.stop(addr2, &stack.DADDupAddrDetected{}) + if diff := cmp.Diff(dadResult{Addr: addr2, R: &stack.DADDupAddrDetected{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } dadResolutionDuration := time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr3, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + if diff := cmp.Diff(dadResult{Addr: addr3, R: &stack.DADSucceeded{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } @@ -266,7 +266,7 @@ func TestDADStop(t *testing.T) { t.Errorf("dad check mismatch (-want +got):\n%s", diff) } clock.Advance(dadResolutionDuration) - if diff := cmp.Diff(dadResult{Addr: addr1, R: stack.DADResult{Resolved: true, Err: nil}}, <-ch); diff != "" { + if diff := cmp.Diff(dadResult{Addr: addr1, R: &stack.DADSucceeded{}}, <-ch); diff != "" { t.Errorf("dad result mismatch (-want +got):\n%s", diff) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 4b21ee79c..5e7f10f4b 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -32,12 +32,14 @@ go_test( "ipv4_test.go", ], deps = [ + "//pkg/sync", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/internal/testutil", diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index cabe274d6..8a2140ebe 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -899,10 +899,9 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { // Close cleans up resources associated with the endpoint. func (e *endpoint) Close() { e.mu.Lock() - defer e.mu.Unlock() - e.disableLocked() e.mu.addressableEndpointState.Cleanup() + e.mu.Unlock() e.protocol.forgetEndpoint(e.nic.ID()) } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 26d9696d7..cfed241bf 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -26,12 +26,14 @@ import ( "time" "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" @@ -2985,3 +2987,120 @@ func TestPacketQueing(t *testing.T) { }) } } + +// TestCloseLocking test that lock ordering is followed when closing an +// endpoint. +func TestCloseLocking(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + src = tcpip.Address("\x10\x00\x00\x01") + dst = tcpip.Address("\x10\x00\x00\x02") + + iterations = 1000 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + }) + + // Perform NAT so that the endoint tries to search for a sibling endpoint + // which ends up taking the protocol and endpoint lock (in that order). + table := stack.Table{ + Rules: []stack.Rule{ + {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.RedirectTarget{Port: 5, NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + {Target: &stack.ErrorTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, + }, + BuiltinChains: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 1, + stack.Forward: stack.HookUnset, + stack.Output: 2, + stack.Postrouting: 3, + }, + Underflows: [stack.NumHooks]int{ + stack.Prerouting: 0, + stack.Input: 1, + stack.Forward: stack.HookUnset, + stack.Output: 2, + stack.Postrouting: 3, + }, + } + if err := s.IPTables().ReplaceTable(stack.NATID, table, false /* ipv6 */); err != nil { + t.Fatalf("s.IPTables().ReplaceTable(...): %s", err) + } + + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + + if err := s.AddAddress(nicID1, ipv4.ProtocolNumber, src); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) failed: %s", nicID1, ipv4.ProtocolNumber, src, err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }}) + + var wq waiter.Queue + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatal(err) + } + defer ep.Close() + + addr := tcpip.FullAddress{NIC: nicID1, Addr: dst, Port: 53} + if err := ep.Connect(addr); err != nil { + t.Errorf("ep.Connect(%#v): %s", addr, err) + } + + var wg sync.WaitGroup + defer wg.Wait() + + // Writing packets should trigger NAT which requires the stack to search the + // protocol for network endpoints with the destination address. + // + // Creating and removing interfaces should modify the protocol and endpoint + // which requires taking the locks of each. + // + // We expect the protocol > endpoint lock ordering to be followed here. + wg.Add(2) + go func() { + defer wg.Done() + + data := []byte{1, 2, 3, 4} + + for i := 0; i < iterations; i++ { + var r bytes.Reader + r.Reset(data) + if n, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { + t.Errorf("ep.Write(_, _): %s", err) + return + } else if want := int64(len(data)); n != want { + t.Errorf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) + return + } + } + }() + go func() { + defer wg.Done() + + for i := 0; i < iterations; i++ { + if err := s.CreateNIC(nicID2, loopback.New()); err != nil { + t.Errorf("CreateNIC(%d, _): %s", nicID2, err) + return + } + if err := s.RemoveNIC(nicID2); err != nil { + t.Errorf("RemoveNIC(%d): %s", nicID2, err) + return + } + } + }() +} diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index e80e681da..6344a3e09 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -382,6 +382,10 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // stack know so it can handle such a scenario and do nothing further with // the NS. if srcAddr == header.IPv6Any { + // Since this is a DAD message we know the sender does not actually hold + // the target address so there is no "holder". + var holderLinkAddress tcpip.LinkAddress + // We would get an error if the address no longer exists or the address // is no longer tentative (DAD resolved between the call to // hasTentativeAddr and this point). Both of these are valid scenarios: @@ -393,7 +397,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) { + switch err := e.dupTentativeAddrDetected(targetAddr, holderLinkAddress); err.(type) { case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: default: panic(fmt.Sprintf("unexpected error handling duplicate tentative address: %s", err)) @@ -561,10 +565,24 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // 5, NDP messages cannot be fragmented. Also note that in the common case // NDP datagrams are very small and AsView() will not incur allocations. na := header.NDPNeighborAdvert(payload.AsView()) + + it, err := na.Options().Iter(false /* check */) + if err != nil { + // If we have a malformed NDP NA option, drop the packet. + received.invalid.Increment() + return + } + + targetLinkAddr, ok := getTargetLinkAddr(it) + if !ok { + received.invalid.Increment() + return + } + targetAddr := na.TargetAddress() e.dad.mu.Lock() - e.dad.mu.dad.StopLocked(targetAddr, false /* aborted */) + e.dad.mu.dad.StopLocked(targetAddr, &stack.DADDupAddrDetected{HolderLinkAddress: targetLinkAddr}) e.dad.mu.Unlock() if e.hasTentativeAddr(targetAddr) { @@ -584,7 +602,7 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // // TODO(gvisor.dev/issue/4046): Handle the scenario when a duplicate // address is detected for an assigned address. - switch err := e.dupTentativeAddrDetected(targetAddr); err.(type) { + switch err := e.dupTentativeAddrDetected(targetAddr, targetLinkAddr); err.(type) { case nil, *tcpip.ErrBadAddress, *tcpip.ErrInvalidEndpointState: return default: @@ -592,13 +610,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } } - it, err := na.Options().Iter(false /* check */) - if err != nil { - // If we have a malformed NDP NA option, drop the packet. - received.invalid.Increment() - return - } - // At this point we know that the target address is not tentative on the // NIC. However, the target address may still be assigned to the NIC but not // tentative (it could be permanent). Such a scenario is beyond the scope of @@ -608,11 +619,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r // TODO(b/143147598): Handle the scenario described above. Also inform the // netstack integration that a duplicate address was detected outside of // DAD. - targetLinkAddr, ok := getTargetLinkAddr(it) - if !ok { - received.invalid.Increment() - return - } // As per RFC 4861 section 7.1.2: // A node MUST silently discard any received Neighbor Advertisement diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 544717678..46b6cc41a 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -348,7 +348,7 @@ func (e *endpoint) hasTentativeAddr(addr tcpip.Address) bool { // dupTentativeAddrDetected removes the tentative address if it exists. If the // address was generated via SLAAC, an attempt is made to generate a new // address. -func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error { +func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address, holderLinkAddr tcpip.LinkAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() @@ -363,7 +363,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) tcpip.Error { // If the address is a SLAAC address, do not invalidate its SLAAC prefix as an // attempt will be made to generate a new address for it. - if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, true /* dadFailure */); err != nil { + if err := e.removePermanentEndpointLocked(addressEndpoint, false /* allowSLAACInvalidation */, &stack.DADDupAddrDetected{HolderLinkAddress: holderLinkAddr}); err != nil { return err } @@ -536,8 +536,20 @@ func (e *endpoint) disableLocked() { } e.mu.ndp.stopSolicitingRouters() + // Stop DAD for all the tentative unicast addresses. + e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { + if addressEndpoint.GetKind() != stack.PermanentTentative { + return true + } + + addr := addressEndpoint.AddressWithPrefix().Address + if header.IsV6UnicastAddress(addr) { + e.mu.ndp.stopDuplicateAddressDetection(addr, &stack.DADAborted{}) + } + + return true + }) e.mu.ndp.cleanupState(false /* hostOnly */) - e.stopDADForPermanentAddressesLocked() // The endpoint may have already left the multicast group. switch err := e.leaveGroupLocked(header.IPv6AllNodesMulticastAddress); err.(type) { @@ -555,25 +567,6 @@ func (e *endpoint) disableLocked() { } } -// stopDADForPermanentAddressesLocked stops DAD for all permaneent addresses. -// -// Precondition: e.mu must be write locked. -func (e *endpoint) stopDADForPermanentAddressesLocked() { - // Stop DAD for all the tentative unicast addresses. - e.mu.addressableEndpointState.ForEachEndpoint(func(addressEndpoint stack.AddressEndpoint) bool { - if addressEndpoint.GetKind() != stack.PermanentTentative { - return true - } - - addr := addressEndpoint.AddressWithPrefix().Address - if header.IsV6UnicastAddress(addr) { - e.mu.ndp.stopDuplicateAddressDetection(addr, false /* failed */) - } - - return true - }) -} - // DefaultTTL is the default hop limit for this endpoint. func (e *endpoint) DefaultTTL() uint8 { return e.protocol.DefaultTTL() @@ -1384,8 +1377,6 @@ func (e *endpoint) handlePacket(pkt *stack.PacketBuffer) { func (e *endpoint) Close() { e.mu.Lock() e.disableLocked() - e.mu.ndp.removeSLAACAddresses(false /* keepLinkLocal */) - e.stopDADForPermanentAddressesLocked() e.mu.addressableEndpointState.Cleanup() e.mu.Unlock() @@ -1451,14 +1442,14 @@ func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { return &tcpip.ErrBadLocalAddress{} } - return e.removePermanentEndpointLocked(addressEndpoint, true /* allowSLAACInvalidation */, false /* dadFailure */) + return e.removePermanentEndpointLocked(addressEndpoint, true /* allowSLAACInvalidation */, &stack.DADAborted{}) } // removePermanentEndpointLocked is like removePermanentAddressLocked except // it works with a stack.AddressEndpoint. // // Precondition: e.mu must be write locked. -func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation, dadFailure bool) tcpip.Error { +func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEndpoint, allowSLAACInvalidation bool, dadResult stack.DADResult) tcpip.Error { addr := addressEndpoint.AddressWithPrefix() // If we are removing an address generated via SLAAC, cleanup // its SLAAC resources and notify the integrator. @@ -1469,16 +1460,16 @@ func (e *endpoint) removePermanentEndpointLocked(addressEndpoint stack.AddressEn e.mu.ndp.cleanupTempSLAACAddrResourcesAndNotify(addr) } - return e.removePermanentEndpointInnerLocked(addressEndpoint, dadFailure) + return e.removePermanentEndpointInnerLocked(addressEndpoint, dadResult) } // removePermanentEndpointInnerLocked is like removePermanentEndpointLocked // except it does not cleanup SLAAC address state. // // Precondition: e.mu must be write locked. -func (e *endpoint) removePermanentEndpointInnerLocked(addressEndpoint stack.AddressEndpoint, dadFailure bool) tcpip.Error { +func (e *endpoint) removePermanentEndpointInnerLocked(addressEndpoint stack.AddressEndpoint, dadResult stack.DADResult) tcpip.Error { addr := addressEndpoint.AddressWithPrefix() - e.mu.ndp.stopDuplicateAddressDetection(addr.Address, dadFailure) + e.mu.ndp.stopDuplicateAddressDetection(addr.Address, dadResult) if err := e.mu.addressableEndpointState.RemovePermanentEndpoint(addressEndpoint); err != nil { return err diff --git a/pkg/tcpip/network/ipv6/ndp.go b/pkg/tcpip/network/ipv6/ndp.go index c22f60709..d9b728878 100644 --- a/pkg/tcpip/network/ipv6/ndp.go +++ b/pkg/tcpip/network/ipv6/ndp.go @@ -208,16 +208,12 @@ const ( // NDPDispatcher is the interface integrators of netstack must implement to // receive and handle NDP related events. type NDPDispatcher interface { - // OnDuplicateAddressDetectionStatus is called when the DAD process for an - // address (addr) on a NIC (with ID nicID) completes. resolved is set to true - // if DAD completed successfully (no duplicate addr detected); false otherwise - // (addr was detected to be a duplicate on the link the NIC is a part of, or - // it was stopped for some other reason, such as the address being removed). - // If an error occured during DAD, err is set and resolved must be ignored. + // OnDuplicateAddressDetectionResult is called when the DAD process for an + // address on a NIC completes. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) + OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) // OnDefaultRouterDiscovered is called when a new default router is // discovered. Implementations must return true if the newly discovered @@ -225,14 +221,14 @@ type NDPDispatcher interface { // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterDiscovered(nicID tcpip.NICID, addr tcpip.Address) bool + OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool // OnDefaultRouterInvalidated is called when a discovered default router that // was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnDefaultRouterInvalidated(nicID tcpip.NICID, addr tcpip.Address) + OnDefaultRouterInvalidated(tcpip.NICID, tcpip.Address) // OnOnLinkPrefixDiscovered is called when a new on-link prefix is discovered. // Implementations must return true if the newly discovered on-link prefix @@ -240,14 +236,14 @@ type NDPDispatcher interface { // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixDiscovered(nicID tcpip.NICID, prefix tcpip.Subnet) bool + OnOnLinkPrefixDiscovered(tcpip.NICID, tcpip.Subnet) bool // OnOnLinkPrefixInvalidated is called when a discovered on-link prefix that // was remembered is invalidated. // // This function is not permitted to block indefinitely. This function // is also not permitted to call into the stack. - OnOnLinkPrefixInvalidated(nicID tcpip.NICID, prefix tcpip.Subnet) + OnOnLinkPrefixInvalidated(tcpip.NICID, tcpip.Subnet) // OnAutoGenAddress is called when a new prefix with its autonomous address- // configuration flag set is received and SLAAC was performed. Implementations @@ -280,12 +276,12 @@ type NDPDispatcher interface { // It is up to the caller to use the DNS Servers only for their valid // lifetime. OnRecursiveDNSServerOption may be called for new or // already known DNS servers. If called with known DNS servers, their - // valid lifetimes must be refreshed to lifetime (it may be increased, - // decreased, or completely invalidated when lifetime = 0). + // valid lifetimes must be refreshed to the lifetime (it may be increased, + // decreased, or completely invalidated when the lifetime = 0). // // This function is not permitted to block indefinitely. It must not // call functions on the stack itself. - OnRecursiveDNSServerOption(nicID tcpip.NICID, addrs []tcpip.Address, lifetime time.Duration) + OnRecursiveDNSServerOption(tcpip.NICID, []tcpip.Address, time.Duration) // OnDNSSearchListOption is called when the stack learns of DNS search lists // through NDP. @@ -293,9 +289,9 @@ type NDPDispatcher interface { // It is up to the caller to use the domain names in the search list // for only their valid lifetime. OnDNSSearchListOption may be called // with new or already known domain names. If called with known domain - // names, their valid lifetimes must be refreshed to lifetime (it may - // be increased, decreased or completely invalidated when lifetime = 0. - OnDNSSearchListOption(nicID tcpip.NICID, domainNames []string, lifetime time.Duration) + // names, their valid lifetimes must be refreshed to the lifetime (it may + // be increased, decreased or completely invalidated when the lifetime = 0. + OnDNSSearchListOption(tcpip.NICID, []string, time.Duration) // OnDHCPv6Configuration is called with an updated configuration that is // available via DHCPv6 for the passed NIC. @@ -587,15 +583,25 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE panic(fmt.Sprintf("ndpdad: addr %s is no longer tentative on NIC(%d)", addr, ndp.ep.nic.ID())) } - if r.Resolved { + var dadSucceeded bool + switch r.(type) { + case *stack.DADAborted, *stack.DADError, *stack.DADDupAddrDetected: + dadSucceeded = false + case *stack.DADSucceeded: + dadSucceeded = true + default: + panic(fmt.Sprintf("unrecognized DAD result = %T", r)) + } + + if dadSucceeded { addressEndpoint.SetKind(stack.Permanent) } if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, r.Resolved, r.Err) + ndpDisp.OnDuplicateAddressDetectionResult(ndp.ep.nic.ID(), addr, r) } - if r.Resolved { + if dadSucceeded { if addressEndpoint.ConfigType() == stack.AddressConfigSlaac { // Reset the generation attempts counter as we are starting the // generation of a new address for the SLAAC prefix. @@ -616,7 +622,7 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // Consider DAD to have resolved even if no DAD messages were actually // transmitted. if ndpDisp := ndp.ep.protocol.options.NDPDisp; ndpDisp != nil { - ndpDisp.OnDuplicateAddressDetectionStatus(ndp.ep.nic.ID(), addr, true, nil) + ndpDisp.OnDuplicateAddressDetectionResult(ndp.ep.nic.ID(), addr, &stack.DADSucceeded{}) } ndp.ep.onAddressAssignedLocked(addr) @@ -633,8 +639,8 @@ func (ndp *ndpState) startDuplicateAddressDetection(addr tcpip.Address, addressE // of this function to handle such a scenario. // // The IPv6 endpoint that ndp belongs to MUST be locked. -func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address, failed bool) { - ndp.dad.StopLocked(addr, !failed) +func (ndp *ndpState) stopDuplicateAddressDetection(addr tcpip.Address, reason stack.DADResult) { + ndp.dad.StopLocked(addr, reason) } // handleRA handles a Router Advertisement message that arrived on the NIC @@ -1501,7 +1507,7 @@ func (ndp *ndpState) invalidateSLAACPrefix(prefix tcpip.Subnet, state slaacPrefi ndpDisp.OnAutoGenAddressInvalidated(ndp.ep.nic.ID(), addressEndpoint.AddressWithPrefix()) } - if err := ndp.ep.removePermanentEndpointInnerLocked(addressEndpoint, false /* dadFailure */); err != nil { + if err := ndp.ep.removePermanentEndpointInnerLocked(addressEndpoint, &stack.DADAborted{}); err != nil { panic(fmt.Sprintf("ndp: error removing stable SLAAC address %s: %s", addressEndpoint.AddressWithPrefix(), err)) } } @@ -1560,7 +1566,7 @@ func (ndp *ndpState) cleanupSLAACPrefixResources(prefix tcpip.Subnet, state slaa func (ndp *ndpState) invalidateTempSLAACAddr(tempAddrs map[tcpip.Address]tempSLAACAddrState, tempAddr tcpip.Address, tempAddrState tempSLAACAddrState) { ndp.cleanupTempSLAACAddrResourcesAndNotifyInner(tempAddrs, tempAddr, tempAddrState) - if err := ndp.ep.removePermanentEndpointInnerLocked(tempAddrState.addressEndpoint, false /* dadFailure */); err != nil { + if err := ndp.ep.removePermanentEndpointInnerLocked(tempAddrState.addressEndpoint, &stack.DADAborted{}); err != nil { panic(fmt.Sprintf("error removing temporary SLAAC address %s: %s", tempAddrState.addressEndpoint.AddressWithPrefix(), err)) } } diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index 0d53c260d..6e850fd46 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -90,7 +90,7 @@ type testNDPDispatcher struct { addr tcpip.Address } -func (*testNDPDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { +func (*testNDPDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } func (t *testNDPDispatcher) OnDefaultRouterDiscovered(_ tcpip.NICID, addr tcpip.Address) bool { @@ -1314,10 +1314,10 @@ func TestCheckDuplicateAddress(t *testing.T) { t.Fatalf("got s.CheckDuplicateAddress(%d, %d, %s, _) = %d, want = %d", nicID, ProtocolNumber, lladdr0, res, stack.DADAlreadyRunning) } - // Wait for DAD to resolve. + // Wait for DAD to complete. clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits) * dadConfigs.RetransmitTimer) for i := 0; i < dadRequestsMade; i++ { - if diff := cmp.Diff(stack.DADResult{Resolved: true}, <-ch); diff != "" { + if diff := cmp.Diff(&stack.DADSucceeded{}, <-ch); diff != "" { t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) } } diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD index 57abec5c9..210262703 100644 --- a/pkg/tcpip/ports/BUILD +++ b/pkg/tcpip/ports/BUILD @@ -4,7 +4,10 @@ package(licenses = ["notice"]) go_library( name = "ports", - srcs = ["ports.go"], + srcs = [ + "flags.go", + "ports.go", + ], visibility = ["//visibility:public"], deps = [ "//pkg/sync", diff --git a/pkg/tcpip/ports/flags.go b/pkg/tcpip/ports/flags.go new file mode 100644 index 000000000..a8d7bff25 --- /dev/null +++ b/pkg/tcpip/ports/flags.go @@ -0,0 +1,150 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ports + +// Flags represents the type of port reservation. +// +// +stateify savable +type Flags struct { + // MostRecent represents UDP SO_REUSEADDR. + MostRecent bool + + // LoadBalanced indicates SO_REUSEPORT. + // + // LoadBalanced takes precidence over MostRecent. + LoadBalanced bool + + // TupleOnly represents TCP SO_REUSEADDR. + TupleOnly bool +} + +// Bits converts the Flags to their bitset form. +func (f Flags) Bits() BitFlags { + var rf BitFlags + if f.MostRecent { + rf |= MostRecentFlag + } + if f.LoadBalanced { + rf |= LoadBalancedFlag + } + if f.TupleOnly { + rf |= TupleOnlyFlag + } + return rf +} + +// Effective returns the effective behavior of a flag config. +func (f Flags) Effective() Flags { + e := f + if e.LoadBalanced && e.MostRecent { + e.MostRecent = false + } + return e +} + +// BitFlags is a bitset representation of Flags. +type BitFlags uint32 + +const ( + // MostRecentFlag represents Flags.MostRecent. + MostRecentFlag BitFlags = 1 << iota + + // LoadBalancedFlag represents Flags.LoadBalanced. + LoadBalancedFlag + + // TupleOnlyFlag represents Flags.TupleOnly. + TupleOnlyFlag + + // nextFlag is the value that the next added flag will have. + // + // It is used to calculate FlagMask below. It is also the number of + // valid flag states. + nextFlag + + // FlagMask is a bit mask for BitFlags. + FlagMask = nextFlag - 1 + + // MultiBindFlagMask contains the flags that allow binding the same + // tuple multiple times. + MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag +) + +// ToFlags converts the bitset into a Flags struct. +func (f BitFlags) ToFlags() Flags { + return Flags{ + MostRecent: f&MostRecentFlag != 0, + LoadBalanced: f&LoadBalancedFlag != 0, + TupleOnly: f&TupleOnlyFlag != 0, + } +} + +// FlagCounter counts how many references each flag combination has. +type FlagCounter struct { + // refs stores the count for each possible flag combination, (0 though + // FlagMask). + refs [nextFlag]int +} + +// AddRef increases the reference count for a specific flag combination. +func (c *FlagCounter) AddRef(flags BitFlags) { + c.refs[flags]++ +} + +// DropRef decreases the reference count for a specific flag combination. +func (c *FlagCounter) DropRef(flags BitFlags) { + c.refs[flags]-- +} + +// TotalRefs calculates the total number of references for all flag +// combinations. +func (c FlagCounter) TotalRefs() int { + var total int + for _, r := range c.refs { + total += r + } + return total +} + +// FlagRefs returns the number of references with all specified flags. +func (c FlagCounter) FlagRefs(flags BitFlags) int { + var total int + for i, r := range c.refs { + if BitFlags(i)&flags == flags { + total += r + } + } + return total +} + +// AllRefsHave returns if all references have all specified flags. +func (c FlagCounter) AllRefsHave(flags BitFlags) bool { + for i, r := range c.refs { + if BitFlags(i)&flags != flags && r > 0 { + return false + } + } + return true +} + +// SharedFlags returns the set of flags shared by all references. +func (c FlagCounter) SharedFlags() BitFlags { + intersection := FlagMask + for i, r := range c.refs { + if r > 0 { + intersection &= BitFlags(i) + } + } + return intersection +} diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 11dbdbbcf..678199371 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package ports provides PortManager that manages allocating, reserving and releasing ports. +// Package ports provides PortManager that manages allocating, reserving and +// releasing ports. package ports import ( - "math" "math/rand" "sync/atomic" @@ -24,169 +24,44 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" ) -const ( - // FirstEphemeral is the first ephemeral port. - FirstEphemeral = 16000 +const anyIPAddress tcpip.Address = "" - // numEphemeralPorts it the mnumber of available ephemeral ports to - // Netstack. - numEphemeralPorts = math.MaxUint16 - FirstEphemeral + 1 +// Reservation describes a port reservation. +type Reservation struct { + // Networks is a list of network protocols to which the reservation + // applies. Can be IPv4, IPv6, or both. + Networks []tcpip.NetworkProtocolNumber - anyIPAddress tcpip.Address = "" -) - -type portDescriptor struct { - network tcpip.NetworkProtocolNumber - transport tcpip.TransportProtocolNumber - port uint16 -} - -// Flags represents the type of port reservation. -// -// +stateify savable -type Flags struct { - // MostRecent represents UDP SO_REUSEADDR. - MostRecent bool - - // LoadBalanced indicates SO_REUSEPORT. - // - // LoadBalanced takes precidence over MostRecent. - LoadBalanced bool - - // TupleOnly represents TCP SO_REUSEADDR. - TupleOnly bool -} - -// Bits converts the Flags to their bitset form. -func (f Flags) Bits() BitFlags { - var rf BitFlags - if f.MostRecent { - rf |= MostRecentFlag - } - if f.LoadBalanced { - rf |= LoadBalancedFlag - } - if f.TupleOnly { - rf |= TupleOnlyFlag - } - return rf -} - -// Effective returns the effective behavior of a flag config. -func (f Flags) Effective() Flags { - e := f - if e.LoadBalanced && e.MostRecent { - e.MostRecent = false - } - return e -} - -// PortManager manages allocating, reserving and releasing ports. -type PortManager struct { - mu sync.RWMutex - allocatedPorts map[portDescriptor]bindAddresses - - // hint is used to pick ports ephemeral ports in a stable order for - // a given port offset. - // - // hint must be accessed using the portHint/incPortHint helpers. - // TODO(gvisor.dev/issue/940): S/R this field. - hint uint32 -} - -// BitFlags is a bitset representation of Flags. -type BitFlags uint32 - -const ( - // MostRecentFlag represents Flags.MostRecent. - MostRecentFlag BitFlags = 1 << iota - - // LoadBalancedFlag represents Flags.LoadBalanced. - LoadBalancedFlag - - // TupleOnlyFlag represents Flags.TupleOnly. - TupleOnlyFlag - - // nextFlag is the value that the next added flag will have. - // - // It is used to calculate FlagMask below. It is also the number of - // valid flag states. - nextFlag - - // FlagMask is a bit mask for BitFlags. - FlagMask = nextFlag - 1 + // Transport is the transport protocol to which the reservation applies. + Transport tcpip.TransportProtocolNumber - // MultiBindFlagMask contains the flags that allow binding the same - // tuple multiple times. - MultiBindFlagMask = MostRecentFlag | LoadBalancedFlag -) - -// ToFlags converts the bitset into a Flags struct. -func (f BitFlags) ToFlags() Flags { - return Flags{ - MostRecent: f&MostRecentFlag != 0, - LoadBalanced: f&LoadBalancedFlag != 0, - TupleOnly: f&TupleOnlyFlag != 0, - } -} + // Addr is the address of the local endpoint. + Addr tcpip.Address -// FlagCounter counts how many references each flag combination has. -type FlagCounter struct { - // refs stores the count for each possible flag combination, (0 though - // FlagMask). - refs [nextFlag]int -} + // Port is the local port number. + Port uint16 -// AddRef increases the reference count for a specific flag combination. -func (c *FlagCounter) AddRef(flags BitFlags) { - c.refs[flags]++ -} + // Flags describe features of the reservation. + Flags Flags -// DropRef decreases the reference count for a specific flag combination. -func (c *FlagCounter) DropRef(flags BitFlags) { - c.refs[flags]-- -} + // BindToDevice is the NIC to which the reservation applies. + BindToDevice tcpip.NICID -// TotalRefs calculates the total number of references for all flag -// combinations. -func (c FlagCounter) TotalRefs() int { - var total int - for _, r := range c.refs { - total += r - } - return total + // Dest is the destination address. + Dest tcpip.FullAddress } -// FlagRefs returns the number of references with all specified flags. -func (c FlagCounter) FlagRefs(flags BitFlags) int { - var total int - for i, r := range c.refs { - if BitFlags(i)&flags == flags { - total += r - } - } - return total -} - -// AllRefsHave returns if all references have all specified flags. -func (c FlagCounter) AllRefsHave(flags BitFlags) bool { - for i, r := range c.refs { - if BitFlags(i)&flags != flags && r > 0 { - return false - } +func (rs Reservation) dst() destination { + return destination{ + rs.Dest.Addr, + rs.Dest.Port, } - return true } -// IntersectionRefs returns the set of flags shared by all references. -func (c FlagCounter) IntersectionRefs() BitFlags { - intersection := FlagMask - for i, r := range c.refs { - if r > 0 { - intersection &= BitFlags(i) - } - } - return intersection +type portDescriptor struct { + network tcpip.NetworkProtocolNumber + transport tcpip.TransportProtocolNumber + port uint16 } type destination struct { @@ -194,18 +69,14 @@ type destination struct { port uint16 } -func makeDestination(a tcpip.FullAddress) destination { - return destination{ - a.Addr, - a.Port, - } -} - -// portNode is never empty. When it has no elements, it is removed from the -// map that references it. -type portNode map[destination]FlagCounter +// destToCounter maps each destination to the FlagCounter that represents +// endpoints to that destination. +// +// destToCounter is never empty. When it has no elements, it is removed from +// the map that references it. +type destToCounter map[destination]FlagCounter -// intersectionRefs calculates the intersection of flag bit values which affect +// intersectionFlags calculates the intersection of flag bit values which affect // the specified destination. // // If no destinations are present, all flag values are returned as there are no @@ -213,20 +84,20 @@ type portNode map[destination]FlagCounter // // In addition to the intersection, the number of intersecting refs is // returned. -func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { +func (dc destToCounter) intersectionFlags(res Reservation) (BitFlags, int) { intersection := FlagMask var count int - for d, f := range p { - if d == dst { - intersection &= f.IntersectionRefs() + for dest, counter := range dc { + if dest == res.dst() { + intersection &= counter.SharedFlags() count++ continue } // Wildcard destinations affect all destinations for TupleOnly. - if d.addr == anyIPAddress || dst.addr == anyIPAddress { + if dest.addr == anyIPAddress || res.Dest.Addr == anyIPAddress { // Only bitwise and the TupleOnlyFlag. - intersection &= ((^TupleOnlyFlag) | f.IntersectionRefs()) + intersection &= ((^TupleOnlyFlag) | counter.SharedFlags()) count++ } } @@ -234,27 +105,29 @@ func (p portNode) intersectionRefs(dst destination) (BitFlags, int) { return intersection, count } -// deviceNode is never empty. When it has no elements, it is removed from the +// deviceToDest maps NICs to destinations for which there are port reservations. +// +// deviceToDest is never empty. When it has no elements, it is removed from the // map that references it. -type deviceNode map[tcpip.NICID]portNode +type deviceToDest map[tcpip.NICID]destToCounter -// isAvailable checks whether binding is possible by device. If not binding to a -// device, check against all FlagCounters. If binding to a specific device, check -// against the unspecified device and the provided device. +// isAvailable checks whether binding is possible by device. If not binding to +// a device, check against all FlagCounters. If binding to a specific device, +// check against the unspecified device and the provided device. // // If either of the port reuse flags is enabled on any of the nodes, all nodes // sharing a port must share at least one reuse flag. This matches Linux's // behavior. -func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - flagBits := flags.Bits() - if bindToDevice == 0 { +func (dd deviceToDest) isAvailable(res Reservation) bool { + flagBits := res.Flags.Bits() + if res.BindToDevice == 0 { intersection := FlagMask - for _, p := range d { - i, c := p.intersectionRefs(dst) - if c == 0 { + for _, dest := range dd { + flags, count := dest.intersectionFlags(res) + if count == 0 { continue } - intersection &= i + intersection &= flags if intersection&flagBits == 0 { // Can't bind because the (addr,port) was // previously bound without reuse. @@ -266,18 +139,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti intersection := FlagMask - if p, ok := d[0]; ok { - var c int - intersection, c = p.intersectionRefs(dst) - if c > 0 && intersection&flagBits == 0 { + if dests, ok := dd[0]; ok { + var count int + intersection, count = dests.intersectionFlags(res) + if count > 0 && intersection&flagBits == 0 { return false } } - if p, ok := d[bindToDevice]; ok { - i, c := p.intersectionRefs(dst) - intersection &= i - if c > 0 && intersection&flagBits == 0 { + if dests, ok := dd[res.BindToDevice]; ok { + flags, count := dests.intersectionFlags(res) + intersection &= flags + if count > 0 && intersection&flagBits == 0 { return false } } @@ -285,18 +158,18 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID, dst desti return true } -// bindAddresses is a set of IP addresses. -type bindAddresses map[tcpip.Address]deviceNode +// addrToDevice maps IP addresses to NICs that have port reservations. +type addrToDevice map[tcpip.Address]deviceToDest // isAvailable checks whether an IP address is available to bind to. If the // address is the "any" address, check all other addresses. Otherwise, just // check against the "any" address and the provided address. -func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - if addr == anyIPAddress { - // If binding to the "any" address then check that there are no conflicts - // with all addresses. - for _, d := range b { - if !d.isAvailable(flags, bindToDevice, dst) { +func (ad addrToDevice) isAvailable(res Reservation) bool { + if res.Addr == anyIPAddress { + // If binding to the "any" address then check that there are no + // conflicts with all addresses. + for _, devices := range ad { + if !devices.isAvailable(res) { return false } } @@ -304,15 +177,15 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice } // Check that there is no conflict with the "any" address. - if d, ok := b[anyIPAddress]; ok { - if !d.isAvailable(flags, bindToDevice, dst) { + if devices, ok := ad[anyIPAddress]; ok { + if !devices.isAvailable(res) { return false } } // Check that this is no conflict with the provided address. - if d, ok := b[addr]; ok { - if !d.isAvailable(flags, bindToDevice, dst) { + if devices, ok := ad[res.Addr]; ok { + if !devices.isAvailable(res) { return false } } @@ -320,50 +193,93 @@ func (b bindAddresses) isAvailable(addr tcpip.Address, flags Flags, bindToDevice return true } +// PortManager manages allocating, reserving and releasing ports. +type PortManager struct { + // mu protects allocatedPorts. + // LOCK ORDERING: mu > ephemeralMu. + mu sync.RWMutex + // allocatedPorts is a nesting of maps that ultimately map Reservations + // to FlagCounters describing whether the Reservation is valid and can + // be reused. + allocatedPorts map[portDescriptor]addrToDevice + + // ephemeralMu protects firstEphemeral and numEphemeral. + ephemeralMu sync.RWMutex + firstEphemeral uint16 + numEphemeral uint16 + + // hint is used to pick ports ephemeral ports in a stable order for + // a given port offset. + // + // hint must be accessed using the portHint/incPortHint helpers. + // TODO(gvisor.dev/issue/940): S/R this field. + hint uint32 +} + // NewPortManager creates new PortManager. func NewPortManager() *PortManager { - return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)} + return &PortManager{ + allocatedPorts: make(map[portDescriptor]addrToDevice), + // Match Linux's default ephemeral range. See: + // https://github.com/torvalds/linux/blob/e54937963fa249595824439dc839c948188dea83/net/ipv4/af_inet.c#L1842 + firstEphemeral: 32768, + numEphemeral: 28232, + } } +// PortTester indicates whether the passed in port is suitable. Returning an +// error causes the function to which the PortTester is passed to return that +// error. +type PortTester func(port uint16) (good bool, err tcpip.Error) + // PickEphemeralPort randomly chooses a starting point and iterates over all // possible ephemeral ports, allowing the caller to decide whether a given port // is suitable for its needs, and stopping when a port is found or an error // occurs. -func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - offset := uint32(rand.Int31n(numEphemeralPorts)) - return s.pickEphemeralPort(offset, numEphemeralPorts, testPort) +func (pm *PortManager) PickEphemeralPort(testPort PortTester) (port uint16, err tcpip.Error) { + pm.ephemeralMu.RLock() + firstEphemeral := pm.firstEphemeral + numEphemeral := pm.numEphemeral + pm.ephemeralMu.RUnlock() + + offset := uint16(rand.Int31n(int32(numEphemeral))) + return pickEphemeralPort(offset, firstEphemeral, numEphemeral, testPort) } -// portHint atomically reads and returns the s.hint value. -func (s *PortManager) portHint() uint32 { - return atomic.LoadUint32(&s.hint) +// portHint atomically reads and returns the pm.hint value. +func (pm *PortManager) portHint() uint16 { + return uint16(atomic.LoadUint32(&pm.hint)) } -// incPortHint atomically increments s.hint by 1. -func (s *PortManager) incPortHint() { - atomic.AddUint32(&s.hint, 1) +// incPortHint atomically increments pm.hint by 1. +func (pm *PortManager) incPortHint() { + atomic.AddUint32(&pm.hint, 1) } -// PickEphemeralPortStable starts at the specified offset + s.portHint and +// PickEphemeralPortStable starts at the specified offset + pm.portHint and // iterates over all ephemeral ports, allowing the caller to decide whether a // given port is suitable for its needs and stopping when a port is found or an // error occurs. -func (s *PortManager) PickEphemeralPortStable(offset uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - p, err := s.pickEphemeralPort(s.portHint()+offset, numEphemeralPorts, testPort) +func (pm *PortManager) PickEphemeralPortStable(offset uint16, testPort PortTester) (port uint16, err tcpip.Error) { + pm.ephemeralMu.RLock() + firstEphemeral := pm.firstEphemeral + numEphemeral := pm.numEphemeral + pm.ephemeralMu.RUnlock() + + p, err := pickEphemeralPort(pm.portHint()+offset, firstEphemeral, numEphemeral, testPort) if err == nil { - s.incPortHint() + pm.incPortHint() } return p, err - } // pickEphemeralPort starts at the offset specified from the FirstEphemeral port // and iterates over the number of ports specified by count and allows the // caller to decide whether a given port is suitable for its needs, and stopping // when a port is found or an error occurs. -func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p uint16) (bool, tcpip.Error)) (port uint16, err tcpip.Error) { - for i := uint32(0); i < count; i++ { - port = uint16(FirstEphemeral + (offset+i)%count) +func pickEphemeralPort(offset, first, count uint16, testPort PortTester) (port uint16, err tcpip.Error) { + for i := uint16(0); i < count; i++ { + port = first + (offset+i)%count ok, err := testPort(port) if err != nil { return 0, err @@ -377,144 +293,145 @@ func (s *PortManager) pickEphemeralPort(offset, count uint32, testPort func(p ui return 0, &tcpip.ErrNoPortAvailable{} } -// IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { - s.mu.Lock() - defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, makeDestination(dest)) -} - -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - for _, network := range networks { - desc := portDescriptor{network, transport, port} - if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, flags, bindToDevice, dst) { - return false - } - } - } - return true -} - // ReservePort marks a port/IP combination as reserved so that it cannot be // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. // -// An optional testPort closure can be passed in which if provided will be used -// to test if the picked port can be used. The function should return true if -// the port is safe to use, false otherwise. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress, testPort func(port uint16) bool) (reservedPort uint16, err tcpip.Error) { - s.mu.Lock() - defer s.mu.Unlock() - - dst := makeDestination(dest) +// An optional PortTester can be passed in which if provided will be used to +// test if the picked port can be used. The function should return true if the +// port is safe to use, false otherwise. +func (pm *PortManager) ReservePort(res Reservation, testPort PortTester) (reservedPort uint16, err tcpip.Error) { + pm.mu.Lock() + defer pm.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. - if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, flags, bindToDevice, dst) { + if res.Port != 0 { + if !pm.reserveSpecificPortLocked(res) { return 0, &tcpip.ErrPortInUse{} } - if testPort != nil && !testPort(port) { - s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, dst) - return 0, &tcpip.ErrPortInUse{} + if testPort != nil { + ok, err := testPort(res.Port) + if err != nil { + pm.releasePortLocked(res) + return 0, err + } + if !ok { + pm.releasePortLocked(res) + return 0, &tcpip.ErrPortInUse{} + } } - return port, nil + return res.Port, nil } // A port wasn't specified, so try to find one. - return s.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { - if !s.reserveSpecificPort(networks, transport, addr, p, flags, bindToDevice, dst) { + return pm.PickEphemeralPort(func(p uint16) (bool, tcpip.Error) { + res.Port = p + if !pm.reserveSpecificPortLocked(res) { return false, nil } - if testPort != nil && !testPort(p) { - s.releasePortLocked(networks, transport, addr, p, flags.Bits(), bindToDevice, dst) - return false, nil + if testPort != nil { + ok, err := testPort(p) + if err != nil { + pm.releasePortLocked(res) + return false, err + } + if !ok { + pm.releasePortLocked(res) + return false, nil + } } return true, nil }) } -// reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dst destination) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice, dst) { - return false +// reserveSpecificPortLocked tries to reserve the given port on all given +// protocols. +func (pm *PortManager) reserveSpecificPortLocked(res Reservation) bool { + // Make sure the port is available. + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + if addrs, ok := pm.allocatedPorts[desc]; ok { + if !addrs.isAvailable(res) { + return false + } + } } - flagBits := flags.Bits() - // Reserve port on all network protocols. - for _, network := range networks { - desc := portDescriptor{network, transport, port} - m, ok := s.allocatedPorts[desc] + flagBits := res.Flags.Bits() + dst := res.dst() + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + addrToDev, ok := pm.allocatedPorts[desc] if !ok { - m = make(bindAddresses) - s.allocatedPorts[desc] = m + addrToDev = make(addrToDevice) + pm.allocatedPorts[desc] = addrToDev } - d, ok := m[addr] + devToDest, ok := addrToDev[res.Addr] if !ok { - d = make(deviceNode) - m[addr] = d + devToDest = make(deviceToDest) + addrToDev[res.Addr] = devToDest } - p := d[bindToDevice] - if p == nil { - p = make(portNode) + destToCntr := devToDest[res.BindToDevice] + if destToCntr == nil { + destToCntr = make(destToCounter) } - n := p[dst] - n.AddRef(flagBits) - p[dst] = n - d[bindToDevice] = p + counter := destToCntr[dst] + counter.AddRef(flagBits) + destToCntr[dst] = counter + devToDest[res.BindToDevice] = destToCntr } return true } // ReserveTuple adds a port reservation for the tuple on all given protocol. -func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) bool { - flagBits := flags.Bits() - dst := makeDestination(dest) +func (pm *PortManager) ReserveTuple(res Reservation) bool { + flagBits := res.Flags.Bits() + dst := res.dst() - s.mu.Lock() - defer s.mu.Unlock() + pm.mu.Lock() + defer pm.mu.Unlock() // It is easier to undo the entire reservation, so if we find that the // tuple can't be fully added, finish and undo the whole thing. undo := false // Reserve port on all network protocols. - for _, network := range networks { - desc := portDescriptor{network, transport, port} - m, ok := s.allocatedPorts[desc] + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + addrToDev, ok := pm.allocatedPorts[desc] if !ok { - m = make(bindAddresses) - s.allocatedPorts[desc] = m + addrToDev = make(addrToDevice) + pm.allocatedPorts[desc] = addrToDev } - d, ok := m[addr] + devToDest, ok := addrToDev[res.Addr] if !ok { - d = make(deviceNode) - m[addr] = d + devToDest = make(deviceToDest) + addrToDev[res.Addr] = devToDest } - p := d[bindToDevice] - if p == nil { - p = make(portNode) + destToCntr := devToDest[res.BindToDevice] + if destToCntr == nil { + destToCntr = make(destToCounter) } - n := p[dst] - if n.TotalRefs() != 0 && n.IntersectionRefs()&flagBits == 0 { + counter := destToCntr[dst] + if counter.TotalRefs() != 0 && counter.SharedFlags()&flagBits == 0 { // Tuple already exists. undo = true } - n.AddRef(flagBits) - p[dst] = n - d[bindToDevice] = p + counter.AddRef(flagBits) + destToCntr[dst] = counter + devToDest[res.BindToDevice] = destToCntr } if undo { // releasePortLocked decrements the counts (rather than setting // them to zero), so it will undo the incorrect incrementing // above. - s.releasePortLocked(networks, transport, addr, port, flagBits, bindToDevice, dst) + pm.releasePortLocked(res) return false } @@ -523,47 +440,71 @@ func (s *PortManager) ReserveTuple(networks []tcpip.NetworkProtocolNumber, trans // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags Flags, bindToDevice tcpip.NICID, dest tcpip.FullAddress) { - s.mu.Lock() - defer s.mu.Unlock() +func (pm *PortManager) ReleasePort(res Reservation) { + pm.mu.Lock() + defer pm.mu.Unlock() - s.releasePortLocked(networks, transport, addr, port, flags.Bits(), bindToDevice, makeDestination(dest)) + pm.releasePortLocked(res) } -func (s *PortManager) releasePortLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, flags BitFlags, bindToDevice tcpip.NICID, dst destination) { - for _, network := range networks { - desc := portDescriptor{network, transport, port} - if m, ok := s.allocatedPorts[desc]; ok { - d, ok := m[addr] - if !ok { - continue - } - p, ok := d[bindToDevice] - if !ok { - continue - } - n, ok := p[dst] - if !ok { - continue - } - n.DropRef(flags) - if n.TotalRefs() > 0 { - p[dst] = n - continue - } - delete(p, dst) - if len(p) > 0 { - continue - } - delete(d, bindToDevice) - if len(d) > 0 { - continue - } - delete(m, addr) - if len(m) > 0 { - continue - } - delete(s.allocatedPorts, desc) +func (pm *PortManager) releasePortLocked(res Reservation) { + dst := res.dst() + for _, network := range res.Networks { + desc := portDescriptor{network, res.Transport, res.Port} + addrToDev, ok := pm.allocatedPorts[desc] + if !ok { + continue } + devToDest, ok := addrToDev[res.Addr] + if !ok { + continue + } + destToCounter, ok := devToDest[res.BindToDevice] + if !ok { + continue + } + counter, ok := destToCounter[dst] + if !ok { + continue + } + counter.DropRef(res.Flags.Bits()) + if counter.TotalRefs() > 0 { + destToCounter[dst] = counter + continue + } + delete(destToCounter, dst) + if len(destToCounter) > 0 { + continue + } + delete(devToDest, res.BindToDevice) + if len(devToDest) > 0 { + continue + } + delete(addrToDev, res.Addr) + if len(addrToDev) > 0 { + continue + } + delete(pm.allocatedPorts, desc) + } +} + +// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in +// both IPv4 and IPv6. +func (pm *PortManager) PortRange() (uint16, uint16) { + pm.ephemeralMu.RLock() + defer pm.ephemeralMu.RUnlock() + return pm.firstEphemeral, pm.firstEphemeral + pm.numEphemeral - 1 +} + +// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range +// (inclusive). +func (pm *PortManager) SetPortRange(start uint16, end uint16) tcpip.Error { + if start > end { + return &tcpip.ErrInvalidPortRange{} } + pm.ephemeralMu.Lock() + defer pm.ephemeralMu.Unlock() + pm.firstEphemeral = start + pm.numEphemeral = end - start + 1 + return nil } diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index e70fbb72b..0f43dc8f8 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -329,16 +329,35 @@ func TestPortReservation(t *testing.T) { net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} for _, test := range test.actions { + first, _ := pm.PortRange() if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest) + portRes := Reservation{ + Networks: net, + Transport: fakeTransNumber, + Addr: test.ip, + Port: test.port, + Flags: test.flags, + BindToDevice: test.device, + Dest: test.dest, + } + pm.ReleasePort(portRes) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.flags, test.device, test.dest, nil /* testPort */) + portRes := Reservation{ + Networks: net, + Transport: fakeTransNumber, + Addr: test.ip, + Port: test.port, + Flags: test.flags, + BindToDevice: test.device, + Dest: test.dest, + } + gotPort, err := pm.ReservePort(portRes, nil /* testPort */) if diff := cmp.Diff(test.want, err); diff != "" { - t.Fatalf("unexpected error from ReservePort(.., .., %s, %d, %+v, %d, %v), (-want, +got):\n%s", test.ip, test.port, test.flags, test.device, test.dest, diff) + t.Fatalf("unexpected error from ReservePort(%+v, _), (-want, +got):\n%s", portRes, diff) } - if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { - t.Fatalf("ReservePort(.., .., .., 0, ..) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) + if test.port == 0 && (gotPort == 0 || gotPort < first) { + t.Fatalf("ReservePort(%+v, _) = %d, want port number >= %d to be picked", portRes, gotPort, first) } } }) @@ -346,6 +365,11 @@ func TestPortReservation(t *testing.T) { } func TestPickEphemeralPort(t *testing.T) { + const ( + firstEphemeral = 32000 + numEphemeralPorts = 1000 + ) + for _, test := range []struct { name string f func(port uint16) (bool, tcpip.Error) @@ -369,17 +393,17 @@ func TestPickEphemeralPort(t *testing.T) { { name: "only-port-16042-available", f: func(port uint16) (bool, tcpip.Error) { - if port == FirstEphemeral+42 { + if port == firstEphemeral+42 { return true, nil } return false, nil }, - wantPort: FirstEphemeral + 42, + wantPort: firstEphemeral + 42, }, { name: "only-port-under-16000-available", f: func(port uint16) (bool, tcpip.Error) { - if port < FirstEphemeral { + if port < firstEphemeral { return true, nil } return false, nil @@ -389,6 +413,9 @@ func TestPickEphemeralPort(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() + if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { + t.Fatalf("failed to set ephemeral port range: %s", err) + } port, err := pm.PickEphemeralPort(test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) @@ -401,6 +428,11 @@ func TestPickEphemeralPort(t *testing.T) { } func TestPickEphemeralPortStable(t *testing.T) { + const ( + firstEphemeral = 32000 + numEphemeralPorts = 1000 + ) + for _, test := range []struct { name string f func(port uint16) (bool, tcpip.Error) @@ -424,17 +456,17 @@ func TestPickEphemeralPortStable(t *testing.T) { { name: "only-port-16042-available", f: func(port uint16) (bool, tcpip.Error) { - if port == FirstEphemeral+42 { + if port == firstEphemeral+42 { return true, nil } return false, nil }, - wantPort: FirstEphemeral + 42, + wantPort: firstEphemeral + 42, }, { name: "only-port-under-16000-available", f: func(port uint16) (bool, tcpip.Error) { - if port < FirstEphemeral { + if port < firstEphemeral { return true, nil } return false, nil @@ -444,7 +476,10 @@ func TestPickEphemeralPortStable(t *testing.T) { } { t.Run(test.name, func(t *testing.T) { pm := NewPortManager() - portOffset := uint32(rand.Int31n(int32(numEphemeralPorts))) + if err := pm.SetPortRange(firstEphemeral, firstEphemeral+numEphemeralPorts); err != nil { + t.Fatalf("failed to set ephemeral port range: %s", err) + } + portOffset := uint16(rand.Int31n(int32(numEphemeralPorts))) port, err := pm.PickEphemeralPortStable(portOffset, test.f) if diff := cmp.Diff(test.wantErr, err); diff != "" { t.Fatalf("unexpected error from PickEphemeralPort(..), (-want, +got):\n%s", diff) diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 78a4cb072..47796a6ba 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -99,12 +99,11 @@ func prefixSubnetAddr(offset uint8, linkAddr tcpip.LinkAddress) (tcpip.AddressWi } // ndpDADEvent is a set of parameters that was passed to -// ndpDispatcher.OnDuplicateAddressDetectionStatus. +// ndpDispatcher.OnDuplicateAddressDetectionResult. type ndpDADEvent struct { - nicID tcpip.NICID - addr tcpip.Address - resolved bool - err tcpip.Error + nicID tcpip.NICID + addr tcpip.Address + res stack.DADResult } type ndpRouterEvent struct { @@ -173,14 +172,13 @@ type ndpDispatcher struct { dhcpv6ConfigurationC chan ndpDHCPv6Event } -// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionStatus. -func (n *ndpDispatcher) OnDuplicateAddressDetectionStatus(nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) { +// Implements ipv6.NDPDispatcher.OnDuplicateAddressDetectionResult. +func (n *ndpDispatcher) OnDuplicateAddressDetectionResult(nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) { if n.dadC != nil { n.dadC <- ndpDADEvent{ nicID, addr, - resolved, - err, + res, } } } @@ -311,8 +309,8 @@ func (l *channelLinkWithHeaderLength) MaxHeaderLength() uint16 { // Check e to make sure that the event is for addr on nic with ID 1, and the // resolved flag set to resolved with the specified err. -func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, resolved bool, err tcpip.Error) string { - return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, resolved: resolved, err: err}, e, cmp.AllowUnexported(e)) +func checkDADEvent(e ndpDADEvent, nicID tcpip.NICID, addr tcpip.Address, res stack.DADResult) string { + return cmp.Diff(ndpDADEvent{nicID: nicID, addr: addr, res: res}, e, cmp.AllowUnexported(e)) } // TestDADDisabled tests that an address successfully resolves immediately @@ -344,8 +342,8 @@ func TestDADDisabled(t *testing.T) { // DAD on it. select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -491,8 +489,8 @@ func TestDADResolve(t *testing.T) { case <-time.After(defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil { @@ -598,9 +596,10 @@ func TestDADFail(t *testing.T) { const nicID = 1 tests := []struct { - name string - rxPkt func(e *channel.Endpoint, tgt tcpip.Address) - getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + name string + rxPkt func(e *channel.Endpoint, tgt tcpip.Address) + getStat func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter + expectedHolderLinkAddress tcpip.LinkAddress }{ { name: "RxSolicit", @@ -608,6 +607,7 @@ func TestDADFail(t *testing.T) { getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborSolicit }, + expectedHolderLinkAddress: "", }, { name: "RxAdvert", @@ -642,6 +642,7 @@ func TestDADFail(t *testing.T) { getStat: func(s tcpip.ICMPv6ReceivedPacketStats) *tcpip.StatCounter { return s.NeighborAdvert }, + expectedHolderLinkAddress: linkAddr1, }, } @@ -691,8 +692,8 @@ func TestDADFail(t *testing.T) { // something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADDupAddrDetected{HolderLinkAddress: test.expectedHolderLinkAddress}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil { @@ -790,8 +791,8 @@ func TestDADStop(t *testing.T) { // time + extra 1s buffer, something is wrong. t.Fatal("timed out waiting for DAD failure") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr1, false, &tcpip.ErrAborted{}); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr1, &stack.DADAborted{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } @@ -852,8 +853,8 @@ func TestSetNDPConfigurations(t *testing.T) { expectDADEvent := func(nicID tcpip.NICID, addr tcpip.Address) { select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatalf("expected DAD event for %s", addr) @@ -944,8 +945,8 @@ func TestSetNDPConfigurations(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID1, addr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID1, addr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil { @@ -1963,8 +1964,8 @@ func TestAutoGenTempAddr(t *testing.T) { select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -2169,8 +2170,8 @@ func TestNoAutoGenTempAddrForLinkLocal(t *testing.T) { } select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, llAddr1, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, llAddr1, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(time.Duration(test.dupAddrTransmits)*test.retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -2257,8 +2258,8 @@ func TestNoAutoGenTempAddrWithoutStableAddr(t *testing.T) { // address to be generated. select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -2723,8 +2724,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { t.Helper() clock.Advance(dupAddrTransmits * retransmitTimer) - if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(<-ndpDisp.dadC, nicID, addr, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } } @@ -2754,8 +2755,8 @@ func TestMixedSLAACAddrConflictRegen(t *testing.T) { rxNDPSolicit(e, addr.Address) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -3853,26 +3854,26 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { } } - expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool, err tcpip.Error) { + expectDADEvent := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, resolved, err); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, res); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") } } - expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, resolved bool) { + expectDADEventAsync := func(t *testing.T, ndpDisp *ndpDispatcher, addr tcpip.Address, res stack.DADResult) { t.Helper() select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr, resolved, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr, res); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") @@ -3929,7 +3930,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // generated. e.InjectInbound(header.IPv6ProtocolNumber, raBufWithPI(llAddr2, 0, prefix, true, true, 100, 100)) expectAutoGenAddrEvent(t, ndpDisp, stableAddrForTempAddrTest, newAddr) - expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, true) + expectDADEventAsync(t, ndpDisp, stableAddrForTempAddrTest.Address, &stack.DADSucceeded{}) // The stable address will be assigned throughout the test. return []tcpip.AddressWithPrefix{stableAddrForTempAddrTest} @@ -4004,7 +4005,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { // Simulate a DAD conflict. rxNDPSolicit(e, addr.Address) expectAutoGenAddrEvent(t, &ndpDisp, addr, invalidatedAddr) - expectDADEvent(t, &ndpDisp, addr.Address, false, nil) + expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADDupAddrDetected{}) // Attempting to add the address manually should not fail if the // address's state was cleaned up when DAD failed. @@ -4014,7 +4015,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if err := s.RemoveAddress(nicID, addr.Address); err != nil { t.Fatalf("RemoveAddress(%d, %s) = %s", nicID, addr.Address, err) } - expectDADEvent(t, &ndpDisp, addr.Address, false, &tcpip.ErrAborted{}) + expectDADEvent(t, &ndpDisp, addr.Address, &stack.DADAborted{}) } // Should not have any new addresses assigned to the NIC. @@ -4027,7 +4028,7 @@ func TestAutoGenAddrInResponseToDADConflicts(t *testing.T) { if maxRetries+1 > numFailures { addr := addrType.addrGenFn(numFailures, tempIIDHistory[:]) expectAutoGenAddrEventAsync(t, &ndpDisp, addr, newAddr) - expectDADEventAsync(t, &ndpDisp, addr.Address, true) + expectDADEventAsync(t, &ndpDisp, addr.Address, &stack.DADSucceeded{}) if mismatch := addressCheck(s.NICInfo()[nicID].ProtocolAddresses, append(stableAddrs, addr), nil); mismatch != "" { t.Fatal(mismatch) } @@ -4144,8 +4145,8 @@ func TestAutoGenAddrWithEUI64IIDNoDADRetries(t *testing.T) { expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -4243,8 +4244,8 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, invalidatedAddr) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, false, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADDupAddrDetected{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } default: t.Fatal("expected DAD event") @@ -4255,8 +4256,8 @@ func TestAutoGenAddrContinuesLifetimesAfterRetry(t *testing.T) { expectAutoGenAddrEvent(addr, newAddr) select { case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.Address, true, nil); diff != "" { - t.Errorf("dad event mismatch (-want +got):\n%s", diff) + if diff := checkDADEvent(e, nicID, addr.Address, &stack.DADSucceeded{}); diff != "" { + t.Errorf("DAD event mismatch (-want +got):\n%s", diff) } case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD event") diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 43e9e4beb..85f0f471a 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -852,18 +852,46 @@ type InjectableLinkEndpoint interface { InjectOutbound(dest tcpip.Address, packet []byte) tcpip.Error } -// DADResult is the result of a duplicate address detection process. -type DADResult struct { - // Resolved is true when DAD completed without detecting a duplicate address - // on the link. - // - // Ignored when Err is non-nil. - Resolved bool +// DADResult is a marker interface for the result of a duplicate address +// detection process. +type DADResult interface { + isDADResult() +} + +var _ DADResult = (*DADSucceeded)(nil) + +// DADSucceeded indicates DAD completed without finding any duplicate addresses. +type DADSucceeded struct{} - // Err is an error encountered while performing DAD. +func (*DADSucceeded) isDADResult() {} + +var _ DADResult = (*DADError)(nil) + +// DADError indicates DAD hit an error. +type DADError struct { Err tcpip.Error } +func (*DADError) isDADResult() {} + +var _ DADResult = (*DADAborted)(nil) + +// DADAborted indicates DAD was aborted. +type DADAborted struct{} + +func (*DADAborted) isDADResult() {} + +var _ DADResult = (*DADDupAddrDetected)(nil) + +// DADDupAddrDetected indicates DAD detected a duplicate address. +type DADDupAddrDetected struct { + // HolderLinkAddress is the link address of the node that holds the duplicate + // address. + HolderLinkAddress tcpip.LinkAddress +} + +func (*DADDupAddrDetected) isDADResult() {} + // DADCompletionHandler is a handler for DAD completion. type DADCompletionHandler func(DADResult) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index de94ddfda..53370c354 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -813,6 +813,18 @@ func (s *Stack) Forwarding(protocolNum tcpip.NetworkProtocolNumber) bool { return forwardingProtocol.Forwarding() } +// PortRange returns the UDP and TCP inclusive range of ephemeral ports used in +// both IPv4 and IPv6. +func (s *Stack) PortRange() (uint16, uint16) { + return s.PortManager.PortRange() +} + +// SetPortRange sets the UDP and TCP IPv4 and IPv6 ephemeral port range +// (inclusive). +func (s *Stack) SetPortRange(start uint16, end uint16) tcpip.Error { + return s.PortManager.SetPortRange(start, end) +} + // SetRouteTable assigns the route table to be used by this stack. It // specifies which NIC to use for given destination address ranges. // diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f45cf5fdf..880219007 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -2605,7 +2605,7 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) { // means something is wrong. t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, linkLocalAddr, true, nil); diff != "" { + if diff := checkDADEvent(e, nicID, linkLocalAddr, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } @@ -3289,7 +3289,7 @@ func TestDoDADWhenNICEnabled(t *testing.T) { case <-time.After(dadTransmits*retransmitTimer + defaultAsyncPositiveEventTimeout): t.Fatal("timed out waiting for DAD resolution") case e := <-ndpDisp.dadC: - if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, true, nil); diff != "" { + if diff := checkDADEvent(e, nicID, addr.AddressWithPrefix.Address, &stack.DADSucceeded{}); diff != "" { t.Errorf("dad event mismatch (-want +got):\n%s", diff) } } diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index e799f9290..e188efccb 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -359,7 +359,7 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[0] } - if mpep.flags.IntersectionRefs().ToFlags().Effective().MostRecent { + if mpep.flags.SharedFlags().ToFlags().Effective().MostRecent { return mpep.endpoints[len(mpep.endpoints)-1] } @@ -410,7 +410,7 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, flags p if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 { return &tcpip.ErrPortInUse{} } } @@ -429,7 +429,7 @@ func (ep *multiPortEndpoint) singleCheckEndpoint(flags ports.Flags) tcpip.Error if len(ep.endpoints) != 0 { // If it was previously bound, we need to check if we can bind again. - if ep.flags.TotalRefs() > 0 && bits&ep.flags.IntersectionRefs() == 0 { + if ep.flags.TotalRefs() > 0 && bits&ep.flags.SharedFlags() == 0 { return &tcpip.ErrPortInUse{} } } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 405c74c65..095623789 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -1167,53 +1167,53 @@ func TestDAD(t *testing.T) { } tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - dadNetProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedResolved bool + name string + netProto tcpip.NetworkProtocolNumber + dadNetProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedResult stack.DADResult }{ { - name: "IPv4 own address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, - expectedResolved: true, + name: "IPv4 own address", + netProto: ipv4.ProtocolNumber, + dadNetProto: arp.ProtocolNumber, + remoteAddr: utils.Ipv4Addr1.AddressWithPrefix.Address, + expectedResult: &stack.DADSucceeded{}, }, { - name: "IPv6 own address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, - expectedResolved: true, + name: "IPv6 own address", + netProto: ipv6.ProtocolNumber, + dadNetProto: ipv6.ProtocolNumber, + remoteAddr: utils.Ipv6Addr1.AddressWithPrefix.Address, + expectedResult: &stack.DADSucceeded{}, }, { - name: "IPv4 duplicate address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, - expectedResolved: false, + name: "IPv4 duplicate address", + netProto: ipv4.ProtocolNumber, + dadNetProto: arp.ProtocolNumber, + remoteAddr: utils.Ipv4Addr2.AddressWithPrefix.Address, + expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2}, }, { - name: "IPv6 duplicate address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, - expectedResolved: false, + name: "IPv6 duplicate address", + netProto: ipv6.ProtocolNumber, + dadNetProto: ipv6.ProtocolNumber, + remoteAddr: utils.Ipv6Addr2.AddressWithPrefix.Address, + expectedResult: &stack.DADDupAddrDetected{HolderLinkAddress: utils.LinkAddr2}, }, { - name: "IPv4 no duplicate address", - netProto: ipv4.ProtocolNumber, - dadNetProto: arp.ProtocolNumber, - remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, - expectedResolved: true, + name: "IPv4 no duplicate address", + netProto: ipv4.ProtocolNumber, + dadNetProto: arp.ProtocolNumber, + remoteAddr: utils.Ipv4Addr3.AddressWithPrefix.Address, + expectedResult: &stack.DADSucceeded{}, }, { - name: "IPv6 no duplicate address", - netProto: ipv6.ProtocolNumber, - dadNetProto: ipv6.ProtocolNumber, - remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, - expectedResolved: true, + name: "IPv6 no duplicate address", + netProto: ipv6.ProtocolNumber, + dadNetProto: ipv6.ProtocolNumber, + remoteAddr: utils.Ipv6Addr3.AddressWithPrefix.Address, + expectedResult: &stack.DADSucceeded{}, }, } @@ -1260,7 +1260,7 @@ func TestDAD(t *testing.T) { } expectResults := 1 - if test.expectedResolved { + if _, ok := test.expectedResult.(*stack.DADSucceeded); ok { const delta = time.Nanosecond clock.Advance(time.Duration(dadConfigs.DupAddrDetectTransmits)*dadConfigs.RetransmitTimer - delta) select { @@ -1285,7 +1285,7 @@ func TestDAD(t *testing.T) { } for i := 0; i < expectResults; i++ { - if diff := cmp.Diff(stack.DADResult{Resolved: test.expectedResolved}, <-ch); diff != "" { + if diff := cmp.Diff(test.expectedResult, <-ch); diff != "" { t.Errorf("(i=%d) DAD result mismatch (-want +got):\n%s", i, diff) } } diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index c56155ea2..80afc2825 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -38,7 +38,7 @@ var _ ipv6.NDPDispatcher = (*ndpDispatcher)(nil) type ndpDispatcher struct{} -func (*ndpDispatcher) OnDuplicateAddressDetectionStatus(tcpip.NICID, tcpip.Address, bool, tcpip.Error) { +func (*ndpDispatcher) OnDuplicateAddressDetectionResult(tcpip.NICID, tcpip.Address, stack.DADResult) { } func (*ndpDispatcher) OnDefaultRouterDiscovered(tcpip.NICID, tcpip.Address) bool { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 09e9d027d..06c63e74a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -26,6 +26,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// TODO(https://gvisor.dev/issues/5623): Unit test this package. + // +stateify savable type icmpPacket struct { icmpPacketEntry @@ -414,6 +416,11 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return &tcpip.ErrInvalidEndpointState{} } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest + icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) @@ -422,7 +429,14 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) + + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { + r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() + return err + } + + sentStat.Increment() + return nil } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { @@ -444,6 +458,10 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return &tcpip.ErrInvalidEndpointState{} } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest pkt.Data().AppendView(data) dataRange := pkt.Data().AsRange() @@ -458,7 +476,13 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro if ttl == 0 { ttl = r.DefaultTTL() } - return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt) + + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { + r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() + } + + sentStat.Increment() + return nil } // checkV4MappedLocked determines the effective network protocol and converts diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index fcdd032c5..a69d6624d 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -105,7 +105,6 @@ go_test( "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", - "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp/testing/context", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 842c1622b..3b574837c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -432,15 +433,16 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // * e.mu is held. func (e *endpoint) reserveTupleLocked() bool { dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} - if !e.stack.ReserveTuple( - e.effectiveNetProtos, - ProtocolNumber, - e.ID.LocalAddress, - e.ID.LocalPort, - e.boundPortFlags, - e.boundBindToDevice, - dest, - ) { + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + Flags: e.boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: dest, + } + if !e.stack.ReserveTuple(portRes) { return false } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index d1e452421..3404af6bb 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -68,7 +68,7 @@ type handshake struct { ep *endpoint state handshakeState active bool - flags uint8 + flags header.TCPFlags ackNum seqnum.Value // iss is the initial send sequence number, as defined in RFC 793. @@ -606,7 +606,7 @@ func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer func (bt *backoffTimer) reset() tcpip.Error { bt.timeout *= 2 - if bt.timeout > MaxRTO { + if bt.timeout > bt.maxTimeout { return &tcpip.ErrTimeout{} } bt.t.Reset(bt.timeout) @@ -700,7 +700,7 @@ type tcpFields struct { id stack.TransportEndpointID ttl uint8 tos uint8 - flags byte + flags header.TCPFlags seq seqnum.Value ack seqnum.Value rcvWnd seqnum.Size @@ -877,7 +877,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { } // sendRaw sends a TCP segment to the endpoint's peer. -func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { +func (e *endpoint) sendRaw(data buffer.VectorisedView, flags header.TCPFlags, seq, ack seqnum.Value, rcvWnd seqnum.Size) tcpip.Error { var sackBlocks []header.SACKBlock if e.EndpointState() == StateEstablished && e.rcv.pendingRcvdSegments.Len() > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 687b9f459..129f36d11 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1097,7 +1097,16 @@ func (e *endpoint) closeNoShutdownLocked() { e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + Flags: e.boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: e.boundDest, + } + e.stack.ReleasePort(portRes) e.isPortReserved = false e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} @@ -1172,7 +1181,16 @@ func (e *endpoint) cleanupLocked() { } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + Flags: e.boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: e.boundDest, + } + e.stack.ReleasePort(portRes) e.isPortReserved = false } e.boundBindToDevice = 0 @@ -2220,7 +2238,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp portBuf := make([]byte, 2) binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) h.Write(portBuf) - portOffset := h.Sum32() + portOffset := uint16(h.Sum32()) var twReuse tcpip.TCPTimeWaitReuseOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &twReuse); err != nil { @@ -2242,7 +2260,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp if sameAddr && p == e.ID.RemotePort { return false, nil } - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: p, + Flags: e.portFlags, + BindToDevice: bindToDevice, + Dest: addr, + } + if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { if _, ok := err.(*tcpip.ErrPortInUse); !ok || !reuse { return false, nil } @@ -2280,7 +2307,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp tcpEP.notifyProtocolGoroutine(notifyAbort) tcpEP.UnlockUser() // Now try and Reserve again if it fails then we skip. - if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr, nil /* testPort */); err != nil { + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: p, + Flags: e.portFlags, + BindToDevice: bindToDevice, + Dest: addr, + } + if _, err := e.stack.ReservePort(portRes, nil /* testPort */); err != nil { return false, nil } } @@ -2288,7 +2324,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp id := e.ID id.LocalPort = p if err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.portFlags, bindToDevice); err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, bindToDevice, addr) + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: p, + Flags: e.portFlags, + BindToDevice: bindToDevice, + Dest: addr, + } + e.stack.ReleasePort(portRes) if _, ok := err.(*tcpip.ErrPortInUse); ok { return false, nil } @@ -2604,7 +2649,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { } bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, bindToDevice, tcpip.FullAddress{}, func(p uint16) bool { + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: addr.Addr, + Port: addr.Port, + Flags: e.portFlags, + BindToDevice: bindToDevice, + Dest: tcpip.FullAddress{}, + } + port, err := e.stack.ReservePort(portRes, func(p uint16) (bool, tcpip.Error) { id := e.ID id.LocalPort = p // CheckRegisterTransportEndpoint should only return an error if there is a @@ -2616,9 +2670,9 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err tcpip.Error) { // address/port. Hence this will only return an error if there is a matching // listening endpoint. if err := e.stack.CheckRegisterTransportEndpoint(netProtos, ProtocolNumber, id, e.portFlags, bindToDevice); err != nil { - return false + return false, nil } - return true + return true, nil }) if err != nil { return err diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 7c15690a3..a53d76917 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -208,7 +209,16 @@ func (e *endpoint) Resume(s *stack.Stack) { if err != nil { panic("unable to parse BindAddr: " + err.String()) } - if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok { + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: addr.Addr, + Port: addr.Port, + Flags: e.boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: e.boundDest, + } + if ok := e.stack.ReserveTuple(portRes); !ok { panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) } e.isPortReserved = true diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 04012cd40..2a4667906 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -226,7 +226,7 @@ func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) tcpip.Error // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) - flags := byte(header.TCPFlagRst) + flags := header.TCPFlagRst // As per RFC 793 page 35 (Reset Generation) // 1. If the connection does not exist (CLOSED) then a reset is sent // in response to any incoming segment except another reset. In diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 744382100..8edd6775b 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -62,7 +62,7 @@ type segment struct { views [8]buffer.View `state:"nosave"` sequenceNumber seqnum.Value ackNumber seqnum.Value - flags uint8 + flags header.TCPFlags window seqnum.Size // csum is only populated for received segments. csum uint16 @@ -141,12 +141,12 @@ func (s *segment) clone() *segment { } // flagIsSet checks if at least one flag in flags is set in s.flags. -func (s *segment) flagIsSet(flags uint8) bool { +func (s *segment) flagIsSet(flags header.TCPFlags) bool { return s.flags&flags != 0 } // flagsAreSet checks if all flags in flags are set in s.flags. -func (s *segment) flagsAreSet(flags uint8) bool { +func (s *segment) flagsAreSet(flags header.TCPFlags) bool { return s.flags&flags == flags } diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 83c8deb0e..18817029d 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1613,7 +1613,7 @@ func (s *sender) sendSegment(seg *segment) tcpip.Error { // sendSegmentFromView sends a new segment containing the given payload, flags // and sequence number. -func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) tcpip.Error { +func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags header.TCPFlags, seq seqnum.Value) tcpip.Error { s.lastSendTime = time.Now() if seq == s.rttMeasureSeqNum { s.rttMeasureTime = s.lastSendTime diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 0128c1f7e..fd499a47b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -33,7 +33,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" - "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -1373,7 +1372,7 @@ func TestTOSV4(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), ) @@ -1421,7 +1420,7 @@ func TestTrafficClassV6(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), checker.TOS(tos, 0), ) @@ -2202,7 +2201,7 @@ func TestSimpleSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2242,7 +2241,7 @@ func TestZeroWindowSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2264,7 +2263,7 @@ func TestZeroWindowSend(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2311,7 +2310,7 @@ func TestScaledWindowConnect(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2342,7 +2341,7 @@ func TestNonScaledWindowConnect(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2415,7 +2414,7 @@ func TestScaledWindowAccept(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0x5fff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2488,7 +2487,7 @@ func TestNonScaledWindowAccept(t *testing.T) { checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(0xffff), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2666,7 +2665,7 @@ func TestSegmentMerging(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+uint32(i)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) } @@ -2689,7 +2688,7 @@ func TestSegmentMerging(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+11), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2738,7 +2737,7 @@ func TestDelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2786,7 +2785,7 @@ func TestUndelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2809,7 +2808,7 @@ func TestUndelay(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2872,7 +2871,7 @@ func TestMSSNotDelayed(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(seq)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -2923,7 +2922,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3438,7 +3437,7 @@ func TestMaxRTO(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) const numRetransmits = 2 @@ -3447,7 +3446,7 @@ func TestMaxRTO(t *testing.T) { checker.IPv4(t, c.GetPacket(), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() { @@ -3490,7 +3489,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { checker.FragmentFlags(0), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): {}} @@ -3502,7 +3501,7 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { checker.FragmentFlags(0), checker.TCP( checker.DstPort(context.TestPort), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) id := header.IPv4(pkt).ID() @@ -3633,7 +3632,7 @@ func TestFinWithNoPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3710,7 +3709,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3729,7 +3728,7 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3796,7 +3795,7 @@ func TestFinWithPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3822,7 +3821,7 @@ func TestFinWithPendingData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3886,7 +3885,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -3907,7 +3906,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -3923,7 +3922,7 @@ func TestFinWithPartialAck(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(791), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) next += uint32(len(view)) @@ -4033,7 +4032,7 @@ func scaledSendWindow(t *testing.T, scale uint8) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -4783,7 +4782,8 @@ func TestConnectAvoidsBoundPorts(t *testing.T) { t.Fatalf("unknown address type: '%s'", candidateAddressType) } - for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ { + start, end := s.PortRange() + for i := start; i <= end; i++ { if makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil { t.Fatalf("Bind(%d) failed: %s", i, err) } @@ -4844,7 +4844,7 @@ func TestPathMTUDiscovery(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(seqNum), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) seqNum += uint32(size) @@ -5129,7 +5129,7 @@ func TestKeepalive(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -7174,7 +7174,7 @@ func TestTCPCloseWithData(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -7274,7 +7274,7 @@ func TestTCPUserTimeout(t *testing.T) { checker.DstPort(context.TestPort), checker.TCPSeqNum(next), checker.TCPAckNum(790), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 5a9745ad7..cb4f82903 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -170,7 +170,7 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), checker.TCPTimestampChecker(true, 0, tsVal+1), ), ) @@ -231,7 +231,7 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd checker.TCPSeqNum(uint32(c.IRS)+1), checker.TCPAckNum(790), checker.TCPWindow(wndSize), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), checker.TCPTimestampChecker(false, 0, 0), ), ) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index b1cb9a324..2f1c1011d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -101,7 +101,7 @@ type Headers struct { AckNum seqnum.Value // Flags are the TCP flags in the TCP header. - Flags int + Flags header.TCPFlags // RcvWnd is the window to be advertised in the ReceiveWindow field of // the TCP header. @@ -452,7 +452,7 @@ func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcp SeqNum: uint32(h.SeqNum), AckNum: uint32(h.AckNum), DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), - Flags: uint8(h.Flags), + Flags: h.Flags, WindowSize: uint16(h.RcvWnd), }) @@ -544,7 +544,7 @@ func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, op checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -571,7 +571,7 @@ func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int checker.DstPort(TestPort), checker.TCPSeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), checker.TCPAckNum(uint32(seqnum.Value(TestInitialSequenceNumber).Add(1))), - checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPFlagsMatch(header.TCPFlagAck, ^header.TCPFlagPsh), ), ) @@ -650,7 +650,7 @@ func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcp SeqNum: uint32(h.SeqNum), AckNum: uint32(h.AckNum), DataOffset: header.TCPMinimumSize, - Flags: uint8(h.Flags), + Flags: h.Flags, WindowSize: uint16(h.RcvWnd), }) @@ -780,7 +780,7 @@ type RawEndpoint struct { C *Context SrcPort uint16 DstPort uint16 - Flags int + Flags header.TCPFlags NextSeqNum seqnum.Value AckNum seqnum.Value WndSize seqnum.Size diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go index 5e271b7ca..6c5ddc3c7 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go @@ -465,7 +465,7 @@ func TestIgnoreBadResetOnSynSent(t *testing.T) { // Receive a RST with a bad ACK, it should not cause the connection to // be reset. acks := []uint32{1234, 1236, 1000, 5000} - flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} + flags := []header.TCPFlags{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} for _, a := range acks { for _, f := range flags { tcp.Encode(&header.TCPFields{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index b519afed1..c0f566459 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -245,7 +245,16 @@ func (e *endpoint) Close() { switch e.EndpointState() { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + Flags: e.boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: tcpip.FullAddress{}, + } + e.stack.ReleasePort(portRes) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} } @@ -920,7 +929,16 @@ func (e *endpoint) Disconnect() tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) + portRes := ports.Reservation{ + Networks: e.effectiveNetProtos, + Transport: ProtocolNumber, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + Flags: boundPortFlags, + BindToDevice: e.boundBindToDevice, + Dest: tcpip.FullAddress{}, + } + e.stack.ReleasePort(portRes) e.boundPortFlags = ports.Flags{} } e.setEndpointState(StateInitial) @@ -1072,7 +1090,16 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, bindToDevice, tcpip.FullAddress{}, nil /* testPort */) + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: id.LocalAddress, + Port: id.LocalPort, + Flags: e.portFlags, + BindToDevice: bindToDevice, + Dest: tcpip.FullAddress{}, + } + port, err := e.stack.ReservePort(portRes, nil /* testPort */) if err != nil { return id, bindToDevice, err } @@ -1082,7 +1109,16 @@ func (e *endpoint) registerWithStack(netProtos []tcpip.NetworkProtocolNumber, id err := e.stack.RegisterTransportEndpoint(netProtos, ProtocolNumber, id, e, e.boundPortFlags, bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, bindToDevice, tcpip.FullAddress{}) + portRes := ports.Reservation{ + Networks: netProtos, + Transport: ProtocolNumber, + Addr: id.LocalAddress, + Port: id.LocalPort, + Flags: e.boundPortFlags, + BindToDevice: bindToDevice, + Dest: tcpip.FullAddress{}, + } + e.stack.ReleasePort(portRes) e.boundPortFlags = ports.Flags{} } return id, bindToDevice, err diff --git a/runsc/boot/compat.go b/runsc/boot/compat.go index a3a76b609..28e82e117 100644 --- a/runsc/boot/compat.go +++ b/runsc/boot/compat.go @@ -17,8 +17,8 @@ package boot import ( "fmt" "os" - "syscall" + "golang.org/x/sys/unix" "google.golang.org/protobuf/proto" "gvisor.dev/gvisor/pkg/eventchannel" "gvisor.dev/gvisor/pkg/log" @@ -93,19 +93,19 @@ func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) { tr := c.trackers[sysnr] if tr == nil { switch sysnr { - case syscall.SYS_PRCTL: + case unix.SYS_PRCTL: // args: cmd, ... tr = newArgsTracker(0) - case syscall.SYS_IOCTL, syscall.SYS_EPOLL_CTL, syscall.SYS_SHMCTL, syscall.SYS_FUTEX, syscall.SYS_FALLOCATE: + case unix.SYS_IOCTL, unix.SYS_EPOLL_CTL, unix.SYS_SHMCTL, unix.SYS_FUTEX, unix.SYS_FALLOCATE: // args: fd/addr, cmd, ... tr = newArgsTracker(1) - case syscall.SYS_GETSOCKOPT, syscall.SYS_SETSOCKOPT: + case unix.SYS_GETSOCKOPT, unix.SYS_SETSOCKOPT: // args: fd, level, name, ... tr = newArgsTracker(1, 2) - case syscall.SYS_SEMCTL: + case unix.SYS_SEMCTL: // args: semid, semnum, cmd, ... tr = newArgsTracker(2) @@ -131,7 +131,7 @@ func (c *compatEmitter) emitUnimplementedSyscall(us *spb.UnimplementedSyscall) { } func (c *compatEmitter) emitUncaughtSignal(msg *ucspb.UncaughtSignal) { - sig := syscall.Signal(msg.SignalNumber) + sig := unix.Signal(msg.SignalNumber) c.sink.Infof( "Uncaught signal: %q (%d), PID: %d, TID: %d, fault addr: %#x", sig, msg.SignalNumber, msg.Pid, msg.Tid, msg.FaultAddr) diff --git a/runsc/boot/compat_amd64.go b/runsc/boot/compat_amd64.go index 8eb76b2ba..7e13ff87c 100644 --- a/runsc/boot/compat_amd64.go +++ b/runsc/boot/compat_amd64.go @@ -16,8 +16,8 @@ package boot import ( "fmt" - "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/sentry/arch" rpb "gvisor.dev/gvisor/pkg/sentry/arch/registers_go_proto" @@ -92,7 +92,7 @@ func syscallNum(regs *rpb.Registers) uint64 { func newArchArgsTracker(sysnr uint64) syscallTracker { switch sysnr { - case syscall.SYS_ARCH_PRCTL: + case unix.SYS_ARCH_PRCTL: // args: cmd, ... return newArgsTracker(0) } diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index 5e849cb37..1cd5fba5c 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -18,9 +18,9 @@ import ( "errors" "fmt" "os" - "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/control/server" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/log" @@ -366,7 +366,7 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { case 2: // The device file is donated to the platform. // Can't take ownership away from os.File. dup them to get a new FD. - fd, err := syscall.Dup(int(o.Files[1].Fd())) + fd, err := unix.Dup(int(o.Files[1].Fd())) if err != nil { return fmt.Errorf("failed to dup file: %v", err) } diff --git a/runsc/boot/filter/config.go b/runsc/boot/filter/config.go index 2a8c916d5..49b503f99 100644 --- a/runsc/boot/filter/config.go +++ b/runsc/boot/filter/config.go @@ -16,7 +16,6 @@ package filter import ( "os" - "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" @@ -26,19 +25,19 @@ import ( // allowedSyscalls is the set of syscalls executed by the Sentry to the host OS. var allowedSyscalls = seccomp.SyscallRules{ - syscall.SYS_CLOCK_GETTIME: {}, - syscall.SYS_CLOSE: {}, - syscall.SYS_DUP: {}, - syscall.SYS_DUP3: []seccomp.Rule{ + unix.SYS_CLOCK_GETTIME: {}, + unix.SYS_CLOSE: {}, + unix.SYS_DUP: {}, + unix.SYS_DUP3: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.O_CLOEXEC), + seccomp.EqualTo(unix.O_CLOEXEC), }, }, - syscall.SYS_EPOLL_CREATE1: {}, - syscall.SYS_EPOLL_CTL: {}, - syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{ + unix.SYS_EPOLL_CREATE1: {}, + unix.SYS_EPOLL_CTL: {}, + unix.SYS_EPOLL_PWAIT: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, @@ -47,34 +46,34 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.EqualTo(0), }, }, - syscall.SYS_EVENTFD2: []seccomp.Rule{ + unix.SYS_EVENTFD2: []seccomp.Rule{ { seccomp.EqualTo(0), seccomp.EqualTo(0), }, }, - syscall.SYS_EXIT: {}, - syscall.SYS_EXIT_GROUP: {}, - syscall.SYS_FALLOCATE: {}, - syscall.SYS_FCHMOD: {}, - syscall.SYS_FCNTL: []seccomp.Rule{ + unix.SYS_EXIT: {}, + unix.SYS_EXIT_GROUP: {}, + unix.SYS_FALLOCATE: {}, + unix.SYS_FCHMOD: {}, + unix.SYS_FCNTL: []seccomp.Rule{ { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.F_GETFL), + seccomp.EqualTo(unix.F_GETFL), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.F_SETFL), + seccomp.EqualTo(unix.F_SETFL), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.F_GETFD), + seccomp.EqualTo(unix.F_GETFD), }, }, - syscall.SYS_FSTAT: {}, - syscall.SYS_FSYNC: {}, - syscall.SYS_FTRUNCATE: {}, - syscall.SYS_FUTEX: []seccomp.Rule{ + unix.SYS_FSTAT: {}, + unix.SYS_FSYNC: {}, + unix.SYS_FTRUNCATE: {}, + unix.SYS_FUTEX: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), @@ -109,35 +108,35 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.EqualTo(0), }, }, - syscall.SYS_GETPID: {}, + unix.SYS_GETPID: {}, unix.SYS_GETRANDOM: {}, - syscall.SYS_GETSOCKOPT: []seccomp.Rule{ + unix.SYS_GETSOCKOPT: []seccomp.Rule{ { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_DOMAIN), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_DOMAIN), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_TYPE), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_TYPE), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_ERROR), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_ERROR), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_SNDBUF), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_SNDBUF), }, }, - syscall.SYS_GETTID: {}, - syscall.SYS_GETTIMEOFDAY: {}, + unix.SYS_GETTID: {}, + unix.SYS_GETTIMEOFDAY: {}, // SYS_IOCTL is needed for terminal support, but we only allow // setting/getting termios and winsize. - syscall.SYS_IOCTL: []seccomp.Rule{ + unix.SYS_IOCTL: []seccomp.Rule{ { seccomp.MatchAny{}, /* fd */ seccomp.EqualTo(linux.TCGETS), @@ -169,94 +168,94 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.MatchAny{}, /* winsize struct */ }, }, - syscall.SYS_LSEEK: {}, - syscall.SYS_MADVISE: {}, + unix.SYS_LSEEK: {}, + unix.SYS_MADVISE: {}, unix.SYS_MEMBARRIER: []seccomp.Rule{ { seccomp.EqualTo(linux.MEMBARRIER_CMD_GLOBAL), seccomp.EqualTo(0), }, }, - syscall.SYS_MINCORE: {}, + unix.SYS_MINCORE: {}, // Used by the Go runtime as a temporarily workaround for a Linux // 5.2-5.4 bug. // // See src/runtime/os_linux_x86.go. // // TODO(b/148688965): Remove once this is gone from Go. - syscall.SYS_MLOCK: []seccomp.Rule{ + unix.SYS_MLOCK: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.EqualTo(4096), }, }, - syscall.SYS_MMAP: []seccomp.Rule{ + unix.SYS_MMAP: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MAP_SHARED), + seccomp.EqualTo(unix.MAP_SHARED), }, { seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MAP_PRIVATE), + seccomp.EqualTo(unix.MAP_PRIVATE), }, { seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS), + seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS), }, { seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_STACK), + seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS | unix.MAP_STACK), }, { seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_NORESERVE), + seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS | unix.MAP_NORESERVE), }, { seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.PROT_WRITE | syscall.PROT_READ), - seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED), + seccomp.EqualTo(unix.PROT_WRITE | unix.PROT_READ), + seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS | unix.MAP_FIXED), }, }, - syscall.SYS_MPROTECT: {}, - syscall.SYS_MUNMAP: {}, - syscall.SYS_NANOSLEEP: {}, - syscall.SYS_PPOLL: {}, - syscall.SYS_PREAD64: {}, - syscall.SYS_PREADV: {}, - unix.SYS_PREADV2: {}, - syscall.SYS_PWRITE64: {}, - syscall.SYS_PWRITEV: {}, - unix.SYS_PWRITEV2: {}, - syscall.SYS_READ: {}, - syscall.SYS_RECVMSG: []seccomp.Rule{ + unix.SYS_MPROTECT: {}, + unix.SYS_MUNMAP: {}, + unix.SYS_NANOSLEEP: {}, + unix.SYS_PPOLL: {}, + unix.SYS_PREAD64: {}, + unix.SYS_PREADV: {}, + unix.SYS_PREADV2: {}, + unix.SYS_PWRITE64: {}, + unix.SYS_PWRITEV: {}, + unix.SYS_PWRITEV2: {}, + unix.SYS_READ: {}, + unix.SYS_RECVMSG: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC), + seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_TRUNC), }, { seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK), + seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_TRUNC | unix.MSG_PEEK), }, }, - syscall.SYS_RECVMMSG: []seccomp.Rule{ + unix.SYS_RECVMMSG: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.EqualTo(fdbased.MaxMsgsPerRecv), - seccomp.EqualTo(syscall.MSG_DONTWAIT), + seccomp.EqualTo(unix.MSG_DONTWAIT), seccomp.EqualTo(0), }, }, @@ -265,34 +264,34 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MSG_DONTWAIT), + seccomp.EqualTo(unix.MSG_DONTWAIT), seccomp.EqualTo(0), }, }, - syscall.SYS_RESTART_SYSCALL: {}, - syscall.SYS_RT_SIGACTION: {}, - syscall.SYS_RT_SIGPROCMASK: {}, - syscall.SYS_RT_SIGRETURN: {}, - syscall.SYS_SCHED_YIELD: {}, - syscall.SYS_SENDMSG: []seccomp.Rule{ + unix.SYS_RESTART_SYSCALL: {}, + unix.SYS_RT_SIGACTION: {}, + unix.SYS_RT_SIGPROCMASK: {}, + unix.SYS_RT_SIGRETURN: {}, + unix.SYS_SCHED_YIELD: {}, + unix.SYS_SENDMSG: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL), + seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_NOSIGNAL), }, }, - syscall.SYS_SETITIMER: {}, - syscall.SYS_SHUTDOWN: []seccomp.Rule{ + unix.SYS_SETITIMER: {}, + unix.SYS_SHUTDOWN: []seccomp.Rule{ // Used by fs/host to shutdown host sockets. - {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RD)}, - {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_WR)}, + {seccomp.MatchAny{}, seccomp.EqualTo(unix.SHUT_RD)}, + {seccomp.MatchAny{}, seccomp.EqualTo(unix.SHUT_WR)}, // Used by unet to shutdown connections. - {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RDWR)}, + {seccomp.MatchAny{}, seccomp.EqualTo(unix.SHUT_RDWR)}, }, - syscall.SYS_SIGALTSTACK: {}, - unix.SYS_STATX: {}, - syscall.SYS_SYNC_FILE_RANGE: {}, - syscall.SYS_TEE: []seccomp.Rule{ + unix.SYS_SIGALTSTACK: {}, + unix.SYS_STATX: {}, + unix.SYS_SYNC_FILE_RANGE: {}, + unix.SYS_TEE: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, @@ -300,12 +299,12 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.EqualTo(unix.SPLICE_F_NONBLOCK), /* flags */ }, }, - syscall.SYS_TGKILL: []seccomp.Rule{ + unix.SYS_TGKILL: []seccomp.Rule{ { seccomp.EqualTo(uint64(os.Getpid())), }, }, - syscall.SYS_UTIMENSAT: []seccomp.Rule{ + unix.SYS_UTIMENSAT: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.EqualTo(0), /* null pathname */ @@ -313,9 +312,9 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.EqualTo(0), /* flags */ }, }, - syscall.SYS_WRITE: {}, + unix.SYS_WRITE: {}, // For rawfile.NonBlockingWriteIovec. - syscall.SYS_WRITEV: []seccomp.Rule{ + unix.SYS_WRITEV: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, @@ -327,313 +326,313 @@ var allowedSyscalls = seccomp.SyscallRules{ // hostInetFilters contains syscalls that are needed by sentry/socket/hostinet. func hostInetFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ - syscall.SYS_ACCEPT4: []seccomp.Rule{ + unix.SYS_ACCEPT4: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC), }, }, - syscall.SYS_BIND: {}, - syscall.SYS_CONNECT: {}, - syscall.SYS_GETPEERNAME: {}, - syscall.SYS_GETSOCKNAME: {}, - syscall.SYS_GETSOCKOPT: []seccomp.Rule{ + unix.SYS_BIND: {}, + unix.SYS_CONNECT: {}, + unix.SYS_GETPEERNAME: {}, + unix.SYS_GETSOCKNAME: {}, + unix.SYS_GETSOCKOPT: []seccomp.Rule{ { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IP), - seccomp.EqualTo(syscall.IP_TOS), + seccomp.EqualTo(unix.SOL_IP), + seccomp.EqualTo(unix.IP_TOS), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IP), - seccomp.EqualTo(syscall.IP_RECVTOS), + seccomp.EqualTo(unix.SOL_IP), + seccomp.EqualTo(unix.IP_RECVTOS), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IP), - seccomp.EqualTo(syscall.IP_PKTINFO), + seccomp.EqualTo(unix.SOL_IP), + seccomp.EqualTo(unix.IP_PKTINFO), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IP), - seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR), + seccomp.EqualTo(unix.SOL_IP), + seccomp.EqualTo(unix.IP_RECVORIGDSTADDR), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IP), - seccomp.EqualTo(syscall.IP_RECVERR), + seccomp.EqualTo(unix.SOL_IP), + seccomp.EqualTo(unix.IP_RECVERR), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), - seccomp.EqualTo(syscall.IPV6_TCLASS), + seccomp.EqualTo(unix.SOL_IPV6), + seccomp.EqualTo(unix.IPV6_TCLASS), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), - seccomp.EqualTo(syscall.IPV6_RECVTCLASS), + seccomp.EqualTo(unix.SOL_IPV6), + seccomp.EqualTo(unix.IPV6_RECVTCLASS), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), - seccomp.EqualTo(syscall.IPV6_RECVERR), + seccomp.EqualTo(unix.SOL_IPV6), + seccomp.EqualTo(unix.IPV6_RECVERR), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), - seccomp.EqualTo(syscall.IPV6_V6ONLY), + seccomp.EqualTo(unix.SOL_IPV6), + seccomp.EqualTo(unix.IPV6_V6ONLY), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(unix.SOL_IPV6), seccomp.EqualTo(linux.IPV6_RECVORIGDSTADDR), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_ERROR), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_ERROR), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_KEEPALIVE), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_KEEPALIVE), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_SNDBUF), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_SNDBUF), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_RCVBUF), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_RCVBUF), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_REUSEADDR), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_REUSEADDR), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_TYPE), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_TYPE), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_LINGER), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_LINGER), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_TIMESTAMP), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_TIMESTAMP), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_TCP), - seccomp.EqualTo(syscall.TCP_NODELAY), + seccomp.EqualTo(unix.SOL_TCP), + seccomp.EqualTo(unix.TCP_NODELAY), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_TCP), - seccomp.EqualTo(syscall.TCP_INFO), + seccomp.EqualTo(unix.SOL_TCP), + seccomp.EqualTo(unix.TCP_INFO), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_TCP), + seccomp.EqualTo(unix.SOL_TCP), seccomp.EqualTo(linux.TCP_INQ), }, }, - syscall.SYS_IOCTL: []seccomp.Rule{ + unix.SYS_IOCTL: []seccomp.Rule{ { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.TIOCOUTQ), + seccomp.EqualTo(unix.TIOCOUTQ), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.TIOCINQ), + seccomp.EqualTo(unix.TIOCINQ), }, }, - syscall.SYS_LISTEN: {}, - syscall.SYS_READV: {}, - syscall.SYS_RECVFROM: {}, - syscall.SYS_RECVMSG: {}, - syscall.SYS_SENDMSG: {}, - syscall.SYS_SENDTO: {}, - syscall.SYS_SETSOCKOPT: []seccomp.Rule{ + unix.SYS_LISTEN: {}, + unix.SYS_READV: {}, + unix.SYS_RECVFROM: {}, + unix.SYS_RECVMSG: {}, + unix.SYS_SENDMSG: {}, + unix.SYS_SENDTO: {}, + unix.SYS_SETSOCKOPT: []seccomp.Rule{ { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_SNDBUF), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_SNDBUF), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_RCVBUF), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_RCVBUF), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_REUSEADDR), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_REUSEADDR), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_TIMESTAMP), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_TIMESTAMP), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_TCP), - seccomp.EqualTo(syscall.TCP_NODELAY), + seccomp.EqualTo(unix.SOL_TCP), + seccomp.EqualTo(unix.TCP_NODELAY), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_TCP), + seccomp.EqualTo(unix.SOL_TCP), seccomp.EqualTo(linux.TCP_INQ), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IP), - seccomp.EqualTo(syscall.IP_TOS), + seccomp.EqualTo(unix.SOL_IP), + seccomp.EqualTo(unix.IP_TOS), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IP), - seccomp.EqualTo(syscall.IP_RECVTOS), + seccomp.EqualTo(unix.SOL_IP), + seccomp.EqualTo(unix.IP_RECVTOS), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IP), - seccomp.EqualTo(syscall.IP_PKTINFO), + seccomp.EqualTo(unix.SOL_IP), + seccomp.EqualTo(unix.IP_PKTINFO), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IP), - seccomp.EqualTo(syscall.IP_RECVORIGDSTADDR), + seccomp.EqualTo(unix.SOL_IP), + seccomp.EqualTo(unix.IP_RECVORIGDSTADDR), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IP), - seccomp.EqualTo(syscall.IP_RECVERR), + seccomp.EqualTo(unix.SOL_IP), + seccomp.EqualTo(unix.IP_RECVERR), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), - seccomp.EqualTo(syscall.IPV6_TCLASS), + seccomp.EqualTo(unix.SOL_IPV6), + seccomp.EqualTo(unix.IPV6_TCLASS), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), - seccomp.EqualTo(syscall.IPV6_RECVTCLASS), + seccomp.EqualTo(unix.SOL_IPV6), + seccomp.EqualTo(unix.IPV6_RECVTCLASS), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), + seccomp.EqualTo(unix.SOL_IPV6), seccomp.EqualTo(linux.IPV6_RECVORIGDSTADDR), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), - seccomp.EqualTo(syscall.IPV6_RECVERR), + seccomp.EqualTo(unix.SOL_IPV6), + seccomp.EqualTo(unix.IPV6_RECVERR), seccomp.MatchAny{}, seccomp.EqualTo(4), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_IPV6), - seccomp.EqualTo(syscall.IPV6_V6ONLY), + seccomp.EqualTo(unix.SOL_IPV6), + seccomp.EqualTo(unix.IPV6_V6ONLY), seccomp.MatchAny{}, seccomp.EqualTo(4), }, }, - syscall.SYS_SHUTDOWN: []seccomp.Rule{ + unix.SYS_SHUTDOWN: []seccomp.Rule{ { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SHUT_RD), + seccomp.EqualTo(unix.SHUT_RD), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SHUT_WR), + seccomp.EqualTo(unix.SHUT_WR), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SHUT_RDWR), + seccomp.EqualTo(unix.SHUT_RDWR), }, }, - syscall.SYS_SOCKET: []seccomp.Rule{ + unix.SYS_SOCKET: []seccomp.Rule{ { - seccomp.EqualTo(syscall.AF_INET), - seccomp.EqualTo(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(unix.AF_INET), + seccomp.EqualTo(unix.SOCK_STREAM | unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC), seccomp.EqualTo(0), }, { - seccomp.EqualTo(syscall.AF_INET), - seccomp.EqualTo(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(unix.AF_INET), + seccomp.EqualTo(unix.SOCK_DGRAM | unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC), seccomp.EqualTo(0), }, { - seccomp.EqualTo(syscall.AF_INET6), - seccomp.EqualTo(syscall.SOCK_STREAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(unix.AF_INET6), + seccomp.EqualTo(unix.SOCK_STREAM | unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC), seccomp.EqualTo(0), }, { - seccomp.EqualTo(syscall.AF_INET6), - seccomp.EqualTo(syscall.SOCK_DGRAM | syscall.SOCK_NONBLOCK | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(unix.AF_INET6), + seccomp.EqualTo(unix.SOCK_DGRAM | unix.SOCK_NONBLOCK | unix.SOCK_CLOEXEC), seccomp.EqualTo(0), }, }, - syscall.SYS_WRITEV: {}, + unix.SYS_WRITEV: {}, } } func controlServerFilters(fd int) seccomp.SyscallRules { return seccomp.SyscallRules{ - syscall.SYS_ACCEPT: []seccomp.Rule{ + unix.SYS_ACCEPT: []seccomp.Rule{ { seccomp.EqualTo(fd), }, }, - syscall.SYS_LISTEN: []seccomp.Rule{ + unix.SYS_LISTEN: []seccomp.Rule{ { seccomp.EqualTo(fd), seccomp.EqualTo(16 /* unet.backlog */), }, }, - syscall.SYS_GETSOCKOPT: []seccomp.Rule{ + unix.SYS_GETSOCKOPT: []seccomp.Rule{ { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.SOL_SOCKET), - seccomp.EqualTo(syscall.SO_PEERCRED), + seccomp.EqualTo(unix.SOL_SOCKET), + seccomp.EqualTo(unix.SO_PEERCRED), }, }, } diff --git a/runsc/boot/filter/config_amd64.go b/runsc/boot/filter/config_amd64.go index cea5613b8..42cb8ed3a 100644 --- a/runsc/boot/filter/config_amd64.go +++ b/runsc/boot/filter/config_amd64.go @@ -17,30 +17,29 @@ package filter import ( - "syscall" - + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/seccomp" ) func init() { - allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{ + allowedSyscalls[unix.SYS_ARCH_PRCTL] = []seccomp.Rule{ // TODO(b/168828518): No longer used in Go 1.16+. {seccomp.EqualTo(linux.ARCH_SET_FS)}, } - allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{ + allowedSyscalls[unix.SYS_CLONE] = []seccomp.Rule{ // parent_tidptr and child_tidptr are always 0 because neither // CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used. { seccomp.EqualTo( - syscall.CLONE_VM | - syscall.CLONE_FS | - syscall.CLONE_FILES | - syscall.CLONE_SETTLS | - syscall.CLONE_SIGHAND | - syscall.CLONE_SYSVSEM | - syscall.CLONE_THREAD), + unix.CLONE_VM | + unix.CLONE_FS | + unix.CLONE_FILES | + unix.CLONE_SETTLS | + unix.CLONE_SIGHAND | + unix.CLONE_SYSVSEM | + unix.CLONE_THREAD), seccomp.MatchAny{}, // newsp seccomp.EqualTo(0), // parent_tidptr seccomp.EqualTo(0), // child_tidptr @@ -49,12 +48,12 @@ func init() { { // TODO(b/168828518): No longer used in Go 1.16+ (on amd64). seccomp.EqualTo( - syscall.CLONE_VM | - syscall.CLONE_FS | - syscall.CLONE_FILES | - syscall.CLONE_SIGHAND | - syscall.CLONE_SYSVSEM | - syscall.CLONE_THREAD), + unix.CLONE_VM | + unix.CLONE_FS | + unix.CLONE_FILES | + unix.CLONE_SIGHAND | + unix.CLONE_SYSVSEM | + unix.CLONE_THREAD), seccomp.MatchAny{}, // newsp seccomp.EqualTo(0), // parent_tidptr seccomp.EqualTo(0), // child_tidptr diff --git a/runsc/boot/filter/config_arm64.go b/runsc/boot/filter/config_arm64.go index 37313f97f..f162f87ff 100644 --- a/runsc/boot/filter/config_arm64.go +++ b/runsc/boot/filter/config_arm64.go @@ -17,21 +17,20 @@ package filter import ( - "syscall" - + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/seccomp" ) func init() { - allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{ + allowedSyscalls[unix.SYS_CLONE] = []seccomp.Rule{ { seccomp.EqualTo( - syscall.CLONE_VM | - syscall.CLONE_FS | - syscall.CLONE_FILES | - syscall.CLONE_SIGHAND | - syscall.CLONE_SYSVSEM | - syscall.CLONE_THREAD), + unix.CLONE_VM | + unix.CLONE_FS | + unix.CLONE_FILES | + unix.CLONE_SIGHAND | + unix.CLONE_SYSVSEM | + unix.CLONE_THREAD), seccomp.MatchAny{}, // newsp // These arguments are left uninitialized by the Go // runtime, so they may be anything (and are unused by diff --git a/runsc/boot/filter/config_profile.go b/runsc/boot/filter/config_profile.go index 7b8669595..89b66a6da 100644 --- a/runsc/boot/filter/config_profile.go +++ b/runsc/boot/filter/config_profile.go @@ -15,19 +15,18 @@ package filter import ( - "syscall" - + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/seccomp" ) // profileFilters returns extra syscalls made by runtime/pprof package. func profileFilters() seccomp.SyscallRules { return seccomp.SyscallRules{ - syscall.SYS_OPENAT: []seccomp.Rule{ + unix.SYS_OPENAT: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.O_RDONLY | syscall.O_LARGEFILE | syscall.O_CLOEXEC), + seccomp.EqualTo(unix.O_RDONLY | unix.O_LARGEFILE | unix.O_CLOEXEC), }, }, } diff --git a/runsc/boot/filter/extra_filters_msan.go b/runsc/boot/filter/extra_filters_msan.go index 209e646a7..41baa78cd 100644 --- a/runsc/boot/filter/extra_filters_msan.go +++ b/runsc/boot/filter/extra_filters_msan.go @@ -17,8 +17,7 @@ package filter import ( - "syscall" - + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/seccomp" ) @@ -26,9 +25,9 @@ import ( func instrumentationFilters() seccomp.SyscallRules { Report("MSAN is enabled: syscall filters less restrictive!") return seccomp.SyscallRules{ - syscall.SYS_CLONE: {}, - syscall.SYS_MMAP: {}, - syscall.SYS_SCHED_GETAFFINITY: {}, - syscall.SYS_SET_ROBUST_LIST: {}, + unix.SYS_CLONE: {}, + unix.SYS_MMAP: {}, + unix.SYS_SCHED_GETAFFINITY: {}, + unix.SYS_SET_ROBUST_LIST: {}, } } diff --git a/runsc/boot/filter/extra_filters_race.go b/runsc/boot/filter/extra_filters_race.go index 5b99eb8cd..79b2104f0 100644 --- a/runsc/boot/filter/extra_filters_race.go +++ b/runsc/boot/filter/extra_filters_race.go @@ -17,8 +17,7 @@ package filter import ( - "syscall" - + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/seccomp" ) @@ -26,17 +25,17 @@ import ( func instrumentationFilters() seccomp.SyscallRules { Report("TSAN is enabled: syscall filters less restrictive!") return seccomp.SyscallRules{ - syscall.SYS_BRK: {}, - syscall.SYS_CLOCK_NANOSLEEP: {}, - syscall.SYS_CLONE: {}, - syscall.SYS_FUTEX: {}, - syscall.SYS_MMAP: {}, - syscall.SYS_MUNLOCK: {}, - syscall.SYS_NANOSLEEP: {}, - syscall.SYS_OPEN: {}, - syscall.SYS_OPENAT: {}, - syscall.SYS_SET_ROBUST_LIST: {}, + unix.SYS_BRK: {}, + unix.SYS_CLOCK_NANOSLEEP: {}, + unix.SYS_CLONE: {}, + unix.SYS_FUTEX: {}, + unix.SYS_MMAP: {}, + unix.SYS_MUNLOCK: {}, + unix.SYS_NANOSLEEP: {}, + unix.SYS_OPEN: {}, + unix.SYS_OPENAT: {}, + unix.SYS_SET_ROBUST_LIST: {}, // Used within glibc's malloc. - syscall.SYS_TIME: {}, + unix.SYS_TIME: {}, } } diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index 2b0d2cd51..77f632bb9 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -20,9 +20,9 @@ import ( "sort" "strconv" "strings" - "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fd" @@ -312,11 +312,11 @@ func setupContainerFS(ctx context.Context, conf *config.Config, mntr *containerM } func adjustDirentCache(k *kernel.Kernel) error { - var hl syscall.Rlimit - if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &hl); err != nil { + var hl unix.Rlimit + if err := unix.Getrlimit(unix.RLIMIT_NOFILE, &hl); err != nil { return fmt.Errorf("getting RLIMIT_NOFILE: %v", err) } - if int64(hl.Cur) != syscall.RLIM_INFINITY { + if hl.Cur != unix.RLIM_INFINITY { newSize := hl.Cur / 2 if newSize < gofer.DefaultDirentCacheSize { log.Infof("Setting gofer dirent cache size to %d", newSize) @@ -844,10 +844,10 @@ func (c *containerMounter) mountSubmount(ctx context.Context, conf *config.Confi // than simply printed to the logs for the 'runsc boot' command. // // We check the error message string rather than type because the - // actual error types (syscall.EIO, syscall.EPIPE) are lost by file system + // actual error types (unix.EIO, unix.EPIPE) are lost by file system // implementation (e.g. p9). // TODO(gvisor.dev/issue/1765): Remove message when bug is resolved. - if strings.Contains(err.Error(), syscall.EIO.Error()) || strings.Contains(err.Error(), syscall.EPIPE.Error()) { + if strings.Contains(err.Error(), unix.EIO.Error()) || strings.Contains(err.Error(), unix.EPIPE.Error()) { return fmt.Errorf("%v: %s", err, specutils.FaqErrorMsg("memlock", "you may be encountering a Linux kernel bug")) } return err diff --git a/runsc/boot/limits.go b/runsc/boot/limits.go index ce62236e5..3d2b3506d 100644 --- a/runsc/boot/limits.go +++ b/runsc/boot/limits.go @@ -16,9 +16,9 @@ package boot import ( "fmt" - "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sync" @@ -104,9 +104,9 @@ func (d *defs) initDefaults() error { // Read host limits that directly affect the sandbox and adjust the defaults // based on them. - for _, res := range []int{syscall.RLIMIT_FSIZE, syscall.RLIMIT_NOFILE} { - var hl syscall.Rlimit - if err := syscall.Getrlimit(res, &hl); err != nil { + for _, res := range []int{unix.RLIMIT_FSIZE, unix.RLIMIT_NOFILE} { + var hl unix.Rlimit + if err := unix.Getrlimit(res, &hl); err != nil { return err } diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index b77b4762e..3121ca6eb 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -19,7 +19,6 @@ import ( "math/rand" "os" "reflect" - "syscall" "testing" "time" @@ -78,7 +77,7 @@ func testSpec() *specs.Spec { // sandbox side of the connection, and a function that when called will stop the // gofer. func startGofer(root string) (int, func(), error) { - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0) + fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0) if err != nil { return 0, nil, err } @@ -86,8 +85,8 @@ func startGofer(root string) (int, func(), error) { socket, err := unet.NewSocket(goferEnd) if err != nil { - syscall.Close(sandboxEnd) - syscall.Close(goferEnd) + unix.Close(sandboxEnd) + unix.Close(goferEnd) return 0, nil, fmt.Errorf("error creating server on FD %d: %v", goferEnd, err) } at, err := fsgofer.NewAttachPoint(root, fsgofer.Config{ROMount: true}) diff --git a/runsc/boot/network.go b/runsc/boot/network.go index 3d3a813df..7e627e4c6 100644 --- a/runsc/boot/network.go +++ b/runsc/boot/network.go @@ -19,8 +19,8 @@ import ( "net" "runtime" "strings" - "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" @@ -195,7 +195,7 @@ func (n *Network) CreateLinksAndRoutes(args *CreateLinksAndRoutesArgs, _ *struct for j := 0; j < link.NumChannels; j++ { // Copy the underlying FD. oldFD := args.FilePayload.Files[fdOffset].Fd() - newFD, err := syscall.Dup(int(oldFD)) + newFD, err := unix.Dup(int(oldFD)) if err != nil { return fmt.Errorf("failed to dup FD %v: %v", oldFD, err) } diff --git a/runsc/cgroup/cgroup.go b/runsc/cgroup/cgroup.go index ac9e4e3a8..438b7ef3e 100644 --- a/runsc/cgroup/cgroup.go +++ b/runsc/cgroup/cgroup.go @@ -27,7 +27,6 @@ import ( "path/filepath" "strconv" "strings" - "syscall" "time" "github.com/cenkalti/backoff" @@ -111,7 +110,7 @@ func setValue(path, name, data string) error { err := ioutil.WriteFile(fullpath, []byte(data), 0700) if err == nil { return nil - } else if !errors.Is(err, syscall.EINTR) { + } else if !errors.Is(err, unix.EINTR) { return err } } @@ -161,7 +160,7 @@ func fillFromAncestor(path string) (string, error) { err := ioutil.WriteFile(path, []byte(val), 0700) if err == nil { break - } else if !errors.Is(err, syscall.EINTR) { + } else if !errors.Is(err, unix.EINTR) { return "", err } } @@ -337,7 +336,7 @@ func (c *Cgroup) Install(res *specs.LinuxResources) error { c.Own[key] = true if err := os.MkdirAll(path, 0755); err != nil { - if cfg.optional && errors.Is(err, syscall.EROFS) { + if cfg.optional && errors.Is(err, unix.EROFS) { log.Infof("Skipping cgroup %q", key) continue } @@ -370,7 +369,7 @@ func (c *Cgroup) Uninstall() error { defer cancel() b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx) fn := func() error { - err := syscall.Rmdir(path) + err := unix.Rmdir(path) if os.IsNotExist(err) { return nil } diff --git a/runsc/cli/BUILD b/runsc/cli/BUILD index 32cce2a18..f1e3cce68 100644 --- a/runsc/cli/BUILD +++ b/runsc/cli/BUILD @@ -18,5 +18,6 @@ go_library( "//runsc/flag", "//runsc/specutils", "@com_github_google_subcommands//:go_default_library", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/cli/main.go b/runsc/cli/main.go index bf6928941..a3c515f4b 100644 --- a/runsc/cli/main.go +++ b/runsc/cli/main.go @@ -23,10 +23,10 @@ import ( "os" "os/signal" "runtime" - "syscall" "time" "github.com/google/subcommands" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/platform" @@ -198,7 +198,7 @@ func Main(version string) { // want with them. Since Docker and Containerd both eat boot's stderr, we // dup our stderr to the provided log FD so that panics will appear in the // logs, rather than just disappear. - if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil { + if err := unix.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil { cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err) } } else if conf.AlsoLogToStderr { @@ -227,11 +227,11 @@ func Main(version string) { // SIGTERM is sent to all processes if a test exceeds its // timeout and this case is handled by syscall_test_runner. log.Warningf("Block the TERM signal. This is only safe in tests!") - signal.Ignore(syscall.SIGTERM) + signal.Ignore(unix.SIGTERM) } // Call the subcommand and pass in the configuration. - var ws syscall.WaitStatus + var ws unix.WaitStatus subcmdCode := subcommands.Execute(context.Background(), conf, &ws) if subcmdCode == subcommands.ExitSuccess { log.Infof("Exiting with status: %v", ws) diff --git a/runsc/cmd/BUILD b/runsc/cmd/BUILD index e3e289da3..2c3b4058b 100644 --- a/runsc/cmd/BUILD +++ b/runsc/cmd/BUILD @@ -77,6 +77,7 @@ go_test( "delete_test.go", "exec_test.go", "gofer_test.go", + "mitigate_test.go", ], data = [ "//runsc", @@ -91,6 +92,8 @@ go_test( "//pkg/urpc", "//runsc/config", "//runsc/container", + "//runsc/mitigate", + "//runsc/mitigate/mock", "//runsc/specutils", "@com_github_google_go_cmp//cmp:go_default_library", "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", diff --git a/runsc/cmd/boot.go b/runsc/cmd/boot.go index 2c92e3067..a14249641 100644 --- a/runsc/cmd/boot.go +++ b/runsc/cmd/boot.go @@ -19,7 +19,6 @@ import ( "os" "runtime/debug" "strings" - "syscall" "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -259,8 +258,8 @@ func (b *Boot) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) ws := l.WaitExit() log.Infof("application exiting with %+v", ws) - waitStatus := args[1].(*syscall.WaitStatus) - *waitStatus = syscall.WaitStatus(ws.Status()) + waitStatus := args[1].(*unix.WaitStatus) + *waitStatus = unix.WaitStatus(ws.Status()) l.Destroy() return subcommands.ExitSuccess } diff --git a/runsc/cmd/checkpoint.go b/runsc/cmd/checkpoint.go index 124198239..a9dbe86de 100644 --- a/runsc/cmd/checkpoint.go +++ b/runsc/cmd/checkpoint.go @@ -18,9 +18,9 @@ import ( "context" "os" "path/filepath" - "syscall" "github.com/google/subcommands" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" @@ -73,7 +73,7 @@ func (c *Checkpoint) Execute(_ context.Context, f *flag.FlagSet, args ...interfa id := f.Arg(0) conf := args[0].(*config.Config) - waitStatus := args[1].(*syscall.WaitStatus) + waitStatus := args[1].(*unix.WaitStatus) cont, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { diff --git a/runsc/cmd/chroot.go b/runsc/cmd/chroot.go index 189244765..e988247da 100644 --- a/runsc/cmd/chroot.go +++ b/runsc/cmd/chroot.go @@ -18,8 +18,8 @@ import ( "fmt" "os" "path/filepath" - "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/runsc/specutils" ) @@ -49,11 +49,11 @@ func pivotRoot(root string) error { // will be moved to "/" too. The parent mount of the old_root will be // new_root, so after umounting the old_root, we will see only // the new_root in "/". - if err := syscall.PivotRoot(".", "."); err != nil { + if err := unix.PivotRoot(".", "."); err != nil { return fmt.Errorf("pivot_root failed, make sure that the root mount has a parent: %v", err) } - if err := syscall.Unmount(".", syscall.MNT_DETACH); err != nil { + if err := unix.Unmount(".", unix.MNT_DETACH); err != nil { return fmt.Errorf("error umounting the old root file system: %v", err) } return nil @@ -70,26 +70,26 @@ func setUpChroot(pidns bool) error { // Convert all shared mounts into slave to be sure that nothing will be // propagated outside of our namespace. - if err := syscall.Mount("", "/", "", syscall.MS_SLAVE|syscall.MS_REC, ""); err != nil { + if err := unix.Mount("", "/", "", unix.MS_SLAVE|unix.MS_REC, ""); err != nil { return fmt.Errorf("error converting mounts: %v", err) } - if err := syscall.Mount("runsc-root", chroot, "tmpfs", syscall.MS_NOSUID|syscall.MS_NODEV|syscall.MS_NOEXEC, ""); err != nil { + if err := unix.Mount("runsc-root", chroot, "tmpfs", unix.MS_NOSUID|unix.MS_NODEV|unix.MS_NOEXEC, ""); err != nil { return fmt.Errorf("error mounting tmpfs in choot: %v", err) } if pidns { - flags := uint32(syscall.MS_NOSUID | syscall.MS_NODEV | syscall.MS_NOEXEC | syscall.MS_RDONLY) + flags := uint32(unix.MS_NOSUID | unix.MS_NODEV | unix.MS_NOEXEC | unix.MS_RDONLY) if err := mountInChroot(chroot, "proc", "/proc", "proc", flags); err != nil { return fmt.Errorf("error mounting proc in chroot: %v", err) } } else { - if err := mountInChroot(chroot, "/proc", "/proc", "bind", syscall.MS_BIND|syscall.MS_RDONLY|syscall.MS_REC); err != nil { + if err := mountInChroot(chroot, "/proc", "/proc", "bind", unix.MS_BIND|unix.MS_RDONLY|unix.MS_REC); err != nil { return fmt.Errorf("error mounting proc in chroot: %v", err) } } - if err := syscall.Mount("", chroot, "", syscall.MS_REMOUNT|syscall.MS_RDONLY|syscall.MS_BIND, ""); err != nil { + if err := unix.Mount("", chroot, "", unix.MS_REMOUNT|unix.MS_RDONLY|unix.MS_BIND, ""); err != nil { return fmt.Errorf("error remounting chroot in read-only: %v", err) } diff --git a/runsc/cmd/cmd.go b/runsc/cmd/cmd.go index f1a4887ef..4dd55cc33 100644 --- a/runsc/cmd/cmd.go +++ b/runsc/cmd/cmd.go @@ -19,9 +19,9 @@ import ( "fmt" "runtime" "strconv" - "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/runsc/specutils" ) @@ -71,7 +71,7 @@ func setCapsAndCallSelf(args []string, caps *specs.LinuxCapabilities) error { binPath := specutils.ExePath log.Infof("Execve %q again, bye!", binPath) - err := syscall.Exec(binPath, args, []string{}) + err := unix.Exec(binPath, args, []string{}) return fmt.Errorf("error executing %s: %v", binPath, err) } @@ -83,16 +83,16 @@ func callSelfAsNobody(args []string) error { const nobody = 65534 - if _, _, err := syscall.RawSyscall(syscall.SYS_SETGID, uintptr(nobody), 0, 0); err != 0 { + if _, _, err := unix.RawSyscall(unix.SYS_SETGID, uintptr(nobody), 0, 0); err != 0 { return fmt.Errorf("error setting uid: %v", err) } - if _, _, err := syscall.RawSyscall(syscall.SYS_SETUID, uintptr(nobody), 0, 0); err != 0 { + if _, _, err := unix.RawSyscall(unix.SYS_SETUID, uintptr(nobody), 0, 0); err != 0 { return fmt.Errorf("error setting gid: %v", err) } binPath := specutils.ExePath log.Infof("Execve %q again, bye!", binPath) - err := syscall.Exec(binPath, args, []string{}) + err := unix.Exec(binPath, args, []string{}) return fmt.Errorf("error executing %s: %v", binPath, err) } diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go index b84142b0d..6212ffb2e 100644 --- a/runsc/cmd/debug.go +++ b/runsc/cmd/debug.go @@ -21,10 +21,10 @@ import ( "strconv" "strings" "sync" - "syscall" "time" "github.com/google/subcommands" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/runsc/config" @@ -135,7 +135,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Perform synchronous actions. if d.signal > 0 { log.Infof("Sending signal %d to process: %d", d.signal, c.Sandbox.Pid) - if err := syscall.Kill(c.Sandbox.Pid, syscall.Signal(d.signal)); err != nil { + if err := unix.Kill(c.Sandbox.Pid, unix.Signal(d.signal)); err != nil { return Errorf("failed to send signal %d to processs %d", d.signal, c.Sandbox.Pid) } } @@ -317,7 +317,7 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) wg.Wait() }() signals := make(chan os.Signal, 1) - signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) + signal.Notify(signals, unix.SIGTERM, unix.SIGINT) select { case <-readyChan: break // Safe to proceed. diff --git a/runsc/cmd/do.go b/runsc/cmd/do.go index 8a8d9f752..22c1dfeb8 100644 --- a/runsc/cmd/do.go +++ b/runsc/cmd/do.go @@ -26,10 +26,10 @@ import ( "path/filepath" "strconv" "strings" - "syscall" "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" @@ -86,7 +86,7 @@ func (c *Do) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) su } conf := args[0].(*config.Config) - waitStatus := args[1].(*syscall.WaitStatus) + waitStatus := args[1].(*unix.WaitStatus) if conf.Rootless { if err := specutils.MaybeRunAsRoot(); err != nil { @@ -225,7 +225,7 @@ func resolvePath(path string) (string, error) { return "", fmt.Errorf("resolving %q: %v", path, err) } path = filepath.Clean(path) - if err := syscall.Access(path, 0); err != nil { + if err := unix.Access(path, 0); err != nil { return "", fmt.Errorf("unable to access %q: %v", path, err) } return path, nil diff --git a/runsc/cmd/exec.go b/runsc/cmd/exec.go index e9726401a..242d474b8 100644 --- a/runsc/cmd/exec.go +++ b/runsc/cmd/exec.go @@ -24,11 +24,11 @@ import ( "path/filepath" "strconv" "strings" - "syscall" "time" "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -110,7 +110,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) if err != nil { Fatalf("parsing process spec: %v", err) } - waitStatus := args[1].(*syscall.WaitStatus) + waitStatus := args[1].(*unix.WaitStatus) c, err := container.Load(conf.RootDir, container.FullID{ContainerID: id}, container.LoadOpts{}) if err != nil { @@ -149,7 +149,7 @@ func (ex *Exec) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return ex.exec(c, e, waitStatus) } -func (ex *Exec) exec(c *container.Container, e *control.ExecArgs, waitStatus *syscall.WaitStatus) subcommands.ExitStatus { +func (ex *Exec) exec(c *container.Container, e *control.ExecArgs, waitStatus *unix.WaitStatus) subcommands.ExitStatus { // Start the new process and get its pid. pid, err := c.Execute(e) if err != nil { @@ -189,7 +189,7 @@ func (ex *Exec) exec(c *container.Container, e *control.ExecArgs, waitStatus *sy return subcommands.ExitSuccess } -func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.ExitStatus { +func (ex *Exec) execChildAndWait(waitStatus *unix.WaitStatus) subcommands.ExitStatus { var args []string for _, a := range os.Args[1:] { if !strings.Contains(a, "detach") { @@ -233,7 +233,7 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi cmd.Stdin = tty cmd.Stdout = tty cmd.Stderr = tty - cmd.SysProcAttr = &syscall.SysProcAttr{ + cmd.SysProcAttr = &unix.SysProcAttr{ Setsid: true, Setctty: true, // The Ctty FD must be the FD in the child process's FD @@ -263,7 +263,7 @@ func (ex *Exec) execChildAndWait(waitStatus *syscall.WaitStatus) subcommands.Exi } return pid == cmd.Process.Pid, nil } - if pe, ok := err.(*os.PathError); !ok || pe.Err != syscall.ENOENT { + if pe, ok := err.(*os.PathError); !ok || pe.Err != unix.ENOENT { return false, err } // No file yet, continue to wait... diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 371fcc0ae..639b2219c 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -21,7 +21,6 @@ import ( "os" "path/filepath" "strings" - "syscall" "github.com/google/subcommands" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -149,16 +148,16 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // fsgofer should run with a umask of 0, because we want to preserve file // modes exactly as sent by the sandbox, which will have applied its own umask. - syscall.Umask(0) + unix.Umask(0) if err := fsgofer.OpenProcSelfFD(); err != nil { Fatalf("failed to open /proc/self/fd: %v", err) } - if err := syscall.Chroot(root); err != nil { + if err := unix.Chroot(root); err != nil { Fatalf("failed to chroot to %q: %v", root, err) } - if err := syscall.Chdir("/"); err != nil { + if err := unix.Chdir("/"); err != nil { Fatalf("changing working dir: %v", err) } log.Infof("Process chroot'd to %q", root) @@ -166,7 +165,8 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) // Start with root mount, then add any other additional mount as needed. ats := make([]p9.Attacher, 0, len(spec.Mounts)+1) ap, err := fsgofer.NewAttachPoint("/", fsgofer.Config{ - ROMount: spec.Root.Readonly || conf.Overlay, + ROMount: spec.Root.Readonly || conf.Overlay, + EnableXattr: conf.Verity, }) if err != nil { Fatalf("creating attach point: %v", err) @@ -178,8 +178,9 @@ func (g *Gofer) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) for _, m := range spec.Mounts { if specutils.Is9PMount(m) { cfg := fsgofer.Config{ - ROMount: isReadonlyMount(m.Options) || conf.Overlay, - HostUDS: conf.FSGoferHostUDS, + ROMount: isReadonlyMount(m.Options) || conf.Overlay, + HostUDS: conf.FSGoferHostUDS, + EnableXattr: conf.Verity, } ap, err := fsgofer.NewAttachPoint(m.Destination, cfg) if err != nil { @@ -262,7 +263,7 @@ func isReadonlyMount(opts []string) bool { func setupRootFS(spec *specs.Spec, conf *config.Config) error { // Convert all shared mounts into slaves to be sure that nothing will be // propagated outside of our namespace. - if err := syscall.Mount("", "/", "", syscall.MS_SLAVE|syscall.MS_REC, ""); err != nil { + if err := unix.Mount("", "/", "", unix.MS_SLAVE|unix.MS_REC, ""); err != nil { Fatalf("error converting mounts: %v", err) } @@ -274,30 +275,30 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error { // // We need a directory to construct a new root and we know that // runsc can't start without /proc, so we can use it for this. - flags := uintptr(syscall.MS_NOSUID | syscall.MS_NODEV | syscall.MS_NOEXEC) - if err := syscall.Mount("runsc-root", "/proc", "tmpfs", flags, ""); err != nil { + flags := uintptr(unix.MS_NOSUID | unix.MS_NODEV | unix.MS_NOEXEC) + if err := unix.Mount("runsc-root", "/proc", "tmpfs", flags, ""); err != nil { Fatalf("error mounting tmpfs: %v", err) } // Prepare tree structure for pivot_root(2). os.Mkdir("/proc/proc", 0755) os.Mkdir("/proc/root", 0755) - if err := syscall.Mount("runsc-proc", "/proc/proc", "proc", flags|syscall.MS_RDONLY, ""); err != nil { + if err := unix.Mount("runsc-proc", "/proc/proc", "proc", flags|unix.MS_RDONLY, ""); err != nil { Fatalf("error mounting proc: %v", err) } root = "/proc/root" } // Mount root path followed by submounts. - if err := syscall.Mount(spec.Root.Path, root, "bind", syscall.MS_BIND|syscall.MS_REC, ""); err != nil { + if err := unix.Mount(spec.Root.Path, root, "bind", unix.MS_BIND|unix.MS_REC, ""); err != nil { return fmt.Errorf("mounting root on root (%q) err: %v", root, err) } - flags := uint32(syscall.MS_SLAVE | syscall.MS_REC) + flags := uint32(unix.MS_SLAVE | unix.MS_REC) if spec.Linux != nil && spec.Linux.RootfsPropagation != "" { flags = specutils.PropOptionsToFlags([]string{spec.Linux.RootfsPropagation}) } - if err := syscall.Mount("", root, "", uintptr(flags), ""); err != nil { + if err := unix.Mount("", root, "", uintptr(flags), ""); err != nil { return fmt.Errorf("mounting root (%q) with flags: %#x, err: %v", root, flags, err) } @@ -323,8 +324,8 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error { // If root is a mount point but not read-only, we can change mount options // to make it read-only for extra safety. log.Infof("Remounting root as readonly: %q", root) - flags := uintptr(syscall.MS_BIND | syscall.MS_REMOUNT | syscall.MS_RDONLY | syscall.MS_REC) - if err := syscall.Mount(root, root, "bind", flags, ""); err != nil { + flags := uintptr(unix.MS_BIND | unix.MS_REMOUNT | unix.MS_RDONLY | unix.MS_REC) + if err := unix.Mount(root, root, "bind", flags, ""); err != nil { return fmt.Errorf("remounting root as read-only with source: %q, target: %q, flags: %#x, err: %v", root, root, flags, err) } } @@ -354,10 +355,10 @@ func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error { return fmt.Errorf("resolving symlinks to %q: %v", m.Destination, err) } - flags := specutils.OptionsToFlags(m.Options) | syscall.MS_BIND + flags := specutils.OptionsToFlags(m.Options) | unix.MS_BIND if conf.Overlay { // Force mount read-only if writes are not going to be sent to it. - flags |= syscall.MS_RDONLY + flags |= unix.MS_RDONLY } log.Infof("Mounting src: %q, dst: %q, flags: %#x", m.Source, dst, flags) @@ -368,7 +369,7 @@ func setupMounts(conf *config.Config, mounts []specs.Mount, root string) error { // Set propagation options that cannot be set together with other options. flags = specutils.PropOptionsToFlags(m.Options) if flags != 0 { - if err := syscall.Mount("", dst, "", uintptr(flags), ""); err != nil { + if err := unix.Mount("", dst, "", uintptr(flags), ""); err != nil { return fmt.Errorf("mount dst: %q, flags: %#x, err: %v", dst, flags, err) } } @@ -469,8 +470,8 @@ func adjustMountOptions(conf *config.Config, path string, opts []string) ([]stri copy(rv, opts) if conf.OverlayfsStaleRead { - statfs := syscall.Statfs_t{} - if err := syscall.Statfs(path, &statfs); err != nil { + statfs := unix.Statfs_t{} + if err := unix.Statfs(path, &statfs); err != nil { return nil, err } if statfs.Type == unix.OVERLAYFS_SUPER_MAGIC { diff --git a/runsc/cmd/kill.go b/runsc/cmd/kill.go index e0df39266..239fc7ac2 100644 --- a/runsc/cmd/kill.go +++ b/runsc/cmd/kill.go @@ -19,7 +19,6 @@ import ( "fmt" "strconv" "strings" - "syscall" "github.com/google/subcommands" "golang.org/x/sys/unix" @@ -99,10 +98,10 @@ func (k *Kill) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return subcommands.ExitSuccess } -func parseSignal(s string) (syscall.Signal, error) { +func parseSignal(s string) (unix.Signal, error) { n, err := strconv.Atoi(s) if err == nil { - sig := syscall.Signal(n) + sig := unix.Signal(n) for _, msig := range signalMap { if sig == msig { return sig, nil @@ -116,7 +115,7 @@ func parseSignal(s string) (syscall.Signal, error) { return -1, fmt.Errorf("unknown signal %q", s) } -var signalMap = map[string]syscall.Signal{ +var signalMap = map[string]unix.Signal{ "ABRT": unix.SIGABRT, "ALRM": unix.SIGALRM, "BUS": unix.SIGBUS, diff --git a/runsc/cmd/mitigate.go b/runsc/cmd/mitigate.go index 822af1917..fddf0e0dd 100644 --- a/runsc/cmd/mitigate.go +++ b/runsc/cmd/mitigate.go @@ -16,6 +16,8 @@ package cmd import ( "context" + "fmt" + "io/ioutil" "github.com/google/subcommands" "gvisor.dev/gvisor/pkg/log" @@ -23,9 +25,23 @@ import ( "gvisor.dev/gvisor/runsc/mitigate" ) +const ( + // cpuInfo is the path used to parse CPU info. + cpuInfo = "/proc/cpuinfo" + // allPossibleCPUs is the path used to enable CPUs. + allPossibleCPUs = "/sys/devices/system/cpu/possible" +) + // Mitigate implements subcommands.Command for the "mitigate" command. type Mitigate struct { - mitigate mitigate.Mitigate + // Run the command without changing the underlying system. + dryRun bool + // Reverse mitigate by turning on all CPU cores. + reverse bool + // Path to file to read to create CPUSet. + path string + // Callback to check if a given thread is vulnerable. + vulnerable func(other mitigate.Thread) bool } // Name implements subcommands.command.name. @@ -38,14 +54,19 @@ func (*Mitigate) Synopsis() string { return "mitigate mitigates the underlying system against side channel attacks" } -// Usage implements subcommands.Command.Usage. -func (m *Mitigate) Usage() string { - return m.mitigate.Usage() +// Usage implments Usage for cmd.Mitigate. +func (m Mitigate) Usage() string { + return `mitigate [flags] + +mitigate mitigates a system to the "MDS" vulnerability by implementing a manual shutdown of SMT. The command checks /proc/cpuinfo for cpus having the MDS vulnerability, and if found, shutdown all but one CPU per hyperthread pair via /sys/devices/system/cpu/cpu{N}/online. CPUs can be restored by writing "2" to each file in /sys/devices/system/cpu/cpu{N}/online or performing a system reboot. + +The command can be reversed with --reverse, which reads the total CPUs from /sys/devices/system/cpu/possible and enables all with /sys/devices/system/cpu/cpu{N}/online.` } -// SetFlags implements subcommands.Command.SetFlags. +// SetFlags sets flags for the command Mitigate. func (m *Mitigate) SetFlags(f *flag.FlagSet) { - m.mitigate.SetFlags(f) + f.BoolVar(&m.dryRun, "dryrun", false, "run the command without changing system") + f.BoolVar(&m.reverse, "reverse", false, "reverse mitigate by enabling all CPUs") } // Execute implements subcommands.Command.Execute. @@ -55,10 +76,97 @@ func (m *Mitigate) Execute(_ context.Context, f *flag.FlagSet, args ...interface return subcommands.ExitUsageError } - if err := m.mitigate.Execute(); err != nil { + m.path = cpuInfo + if m.reverse { + m.path = allPossibleCPUs + } + + m.vulnerable = func(other mitigate.Thread) bool { + return other.IsVulnerable() + } + + if _, err := m.doExecute(); err != nil { log.Warningf("Execute failed: %v", err) return subcommands.ExitFailure } return subcommands.ExitSuccess } + +// Execute executes the Mitigate command. +func (m *Mitigate) doExecute() (mitigate.CPUSet, error) { + if m.dryRun { + log.Infof("Running with DryRun. No cpu settings will be changed.") + } + if m.reverse { + data, err := ioutil.ReadFile(m.path) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %v", m.path, err) + } + + set, err := m.doReverse(data) + if err != nil { + return nil, fmt.Errorf("reverse operation failed: %v", err) + } + return set, nil + } + + data, err := ioutil.ReadFile(m.path) + if err != nil { + return nil, fmt.Errorf("failed to read %s: %v", m.path, err) + } + set, err := m.doMitigate(data) + if err != nil { + return nil, fmt.Errorf("mitigate operation failed: %v", err) + } + return set, nil +} + +func (m *Mitigate) doMitigate(data []byte) (mitigate.CPUSet, error) { + set, err := mitigate.NewCPUSet(data, m.vulnerable) + if err != nil { + return nil, err + } + + log.Infof("Mitigate found the following CPUs...") + log.Infof("%s", set) + + disableList := set.GetShutdownList() + log.Infof("Disabling threads on thread pairs.") + for _, t := range disableList { + log.Infof("Disable thread: %s", t) + if m.dryRun { + continue + } + if err := t.Disable(); err != nil { + return nil, fmt.Errorf("error disabling thread: %s err: %v", t, err) + } + } + log.Infof("Shutdown successful.") + return set, nil +} + +func (m *Mitigate) doReverse(data []byte) (mitigate.CPUSet, error) { + set, err := mitigate.NewCPUSetFromPossible(data) + if err != nil { + return nil, err + } + + log.Infof("Reverse mitigate found the following CPUs...") + log.Infof("%s", set) + + enableList := set.GetRemainingList() + + log.Infof("Enabling all CPUs...") + for _, t := range enableList { + log.Infof("Enabling thread: %s", t) + if m.dryRun { + continue + } + if err := t.Enable(); err != nil { + return nil, fmt.Errorf("error enabling thread: %s err: %v", t, err) + } + } + log.Infof("Enable successful.") + return set, nil +} diff --git a/runsc/cmd/mitigate_test.go b/runsc/cmd/mitigate_test.go new file mode 100644 index 000000000..163fece42 --- /dev/null +++ b/runsc/cmd/mitigate_test.go @@ -0,0 +1,169 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "fmt" + "io/ioutil" + "os" + "strings" + "testing" + + "gvisor.dev/gvisor/runsc/mitigate" + "gvisor.dev/gvisor/runsc/mitigate/mock" +) + +type executeTestCase struct { + name string + mitigateData string + mitigateError error + mitigateCPU int + reverseData string + reverseError error + reverseCPU int +} + +func TestExecute(t *testing.T) { + + partial := `processor : 1 +vendor_id : AuthenticAMD +cpu family : 23 +model : 49 +model name : AMD EPYC 7B12 +physical id : 0 +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass +power management: +` + + for _, tc := range []executeTestCase{ + { + name: "CascadeLake4", + mitigateData: mock.CascadeLake4.MakeCPUString(), + mitigateCPU: 2, + reverseData: mock.CascadeLake4.MakeSysPossibleString(), + reverseCPU: 4, + }, + { + name: "Empty", + mitigateData: "", + mitigateError: fmt.Errorf(`mitigate operation failed: no cpus found for: ""`), + reverseData: "", + reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from possible: ""`), + }, + { + name: "Partial", + mitigateData: `processor : 0 +vendor_id : AuthenticAMD +cpu family : 23 +model : 49 +model name : AMD EPYC 7B12 +physical id : 0 +core id : 0 +cpu cores : 1 +bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass +power management::84 + +` + partial, + mitigateError: fmt.Errorf(`mitigate operation failed: failed to match key "core id": %q`, partial), + reverseData: "1-", + reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from possible: %q`, "1-"), + }, + } { + t.Run(tc.name, func(t *testing.T) { + m := &Mitigate{ + dryRun: true, + vulnerable: func(other mitigate.Thread) bool { + return other.IsVulnerable() + }, + } + m.doExecuteTest(t, "Mitigate", tc.mitigateData, tc.mitigateCPU, tc.mitigateError) + + m.reverse = true + m.doExecuteTest(t, "Reverse", tc.reverseData, tc.reverseCPU, tc.reverseError) + }) + } +} + +func TestExecuteSmoke(t *testing.T) { + smokeMitigate, err := ioutil.ReadFile(cpuInfo) + if err != nil { + t.Fatalf("Failed to read %s: %v", cpuInfo, err) + } + + m := &Mitigate{ + dryRun: true, + vulnerable: func(other mitigate.Thread) bool { + return other.IsVulnerable() + }, + } + + m.doExecuteTest(t, "Mitigate", string(smokeMitigate), 0, nil) + + smokeReverse, err := ioutil.ReadFile(allPossibleCPUs) + if err != nil { + t.Fatalf("Failed to read %s: %v", allPossibleCPUs, err) + } + + m.reverse = true + m.doExecuteTest(t, "Reverse", string(smokeReverse), 0, nil) +} + +// doExecuteTest runs Execute with the mitigate operation and reverse operation. +func (m *Mitigate) doExecuteTest(t *testing.T, name, data string, want int, wantErr error) { + t.Run(name, func(t *testing.T) { + file, err := ioutil.TempFile("", "outfile.txt") + if err != nil { + t.Fatalf("Failed to create tmpfile: %v", err) + } + defer os.Remove(file.Name()) + + if _, err := file.WriteString(data); err != nil { + t.Fatalf("Failed to write to file: %v", err) + } + + // Set fields for mitigate and dryrun to keep test hermetic. + m.path = file.Name() + + set, err := m.doExecute() + if err = checkErr(wantErr, err); err != nil { + t.Fatalf("Mitigate error mismatch: %v", err) + } + + // case where test should end in error or we don't care + // about how many cpus are returned. + if wantErr != nil || want < 1 { + return + } + got := len(set.GetRemainingList()) + if want != got { + t.Fatalf("Failed wrong number of remaining CPUs: want %d, got %d", want, got) + } + + }) +} + +// checkErr checks error for equality. +func checkErr(want, got error) error { + switch { + case want == nil && got == nil: + case want != nil && got == nil: + fallthrough + case want == nil && got != nil: + fallthrough + case want.Error() != strings.Trim(got.Error(), " "): + return fmt.Errorf("got: %v want: %v", got, want) + } + return nil +} diff --git a/runsc/cmd/restore.go b/runsc/cmd/restore.go index 096ec814c..b21f05921 100644 --- a/runsc/cmd/restore.go +++ b/runsc/cmd/restore.go @@ -17,9 +17,9 @@ package cmd import ( "context" "path/filepath" - "syscall" "github.com/google/subcommands" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" @@ -78,7 +78,7 @@ func (r *Restore) Execute(_ context.Context, f *flag.FlagSet, args ...interface{ id := f.Arg(0) conf := args[0].(*config.Config) - waitStatus := args[1].(*syscall.WaitStatus) + waitStatus := args[1].(*unix.WaitStatus) if conf.Rootless { return Errorf("Rootless mode not supported with %q", r.Name()) diff --git a/runsc/cmd/run.go b/runsc/cmd/run.go index c48cbe4cd..722181aff 100644 --- a/runsc/cmd/run.go +++ b/runsc/cmd/run.go @@ -16,9 +16,9 @@ package cmd import ( "context" - "syscall" "github.com/google/subcommands" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" @@ -65,7 +65,7 @@ func (r *Run) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) s id := f.Arg(0) conf := args[0].(*config.Config) - waitStatus := args[1].(*syscall.WaitStatus) + waitStatus := args[1].(*unix.WaitStatus) if conf.Rootless { return Errorf("Rootless mode not supported with %q", r.Name()) diff --git a/runsc/cmd/wait.go b/runsc/cmd/wait.go index 5d55422c7..d7a783b88 100644 --- a/runsc/cmd/wait.go +++ b/runsc/cmd/wait.go @@ -18,9 +18,9 @@ import ( "context" "encoding/json" "os" - "syscall" "github.com/google/subcommands" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/runsc/config" "gvisor.dev/gvisor/runsc/container" "gvisor.dev/gvisor/runsc/flag" @@ -77,7 +77,7 @@ func (wt *Wait) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) Fatalf("loading container: %v", err) } - var waitStatus syscall.WaitStatus + var waitStatus unix.WaitStatus switch { // Wait on the whole container. case wt.rootPID == unsetPID && wt.pid == unsetPID: @@ -119,7 +119,7 @@ type waitResult struct { // exitStatus returns the correct exit status for a process based on if it // was signaled or exited cleanly. -func exitStatus(status syscall.WaitStatus) int { +func exitStatus(status unix.WaitStatus) int { if status.Signaled() { return 128 + int(status.Signal()) } diff --git a/runsc/config/config.go b/runsc/config/config.go index e9fd7708f..34ef48825 100644 --- a/runsc/config/config.go +++ b/runsc/config/config.go @@ -64,6 +64,9 @@ type Config struct { // Overlay is whether to wrap the root filesystem in an overlay. Overlay bool `flag:"overlay"` + // Verity is whether there's one or more verity file system to mount. + Verity bool `flag:"verity"` + // FSGoferHostUDS enables the gofer to mount a host UDS. FSGoferHostUDS bool `flag:"fsgofer-host-uds"` diff --git a/runsc/config/flags.go b/runsc/config/flags.go index 7e738dfdf..adbee506c 100644 --- a/runsc/config/flags.go +++ b/runsc/config/flags.go @@ -69,6 +69,7 @@ func RegisterFlags() { // Flags that control sandbox runtime behavior: FS related. flag.Var(fileAccessTypePtr(FileAccessExclusive), "file-access", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") + flag.Bool("verity", false, "specifies whether a verity file system will be mounted.") flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem") flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") flag.Bool("vfs2", false, "enables VFSv2. This uses the new VFS layer that is faster than the previous one.") diff --git a/runsc/container/BUILD b/runsc/container/BUILD index 8793c8916..3620dc8c3 100644 --- a/runsc/container/BUILD +++ b/runsc/container/BUILD @@ -30,6 +30,7 @@ go_library( "@com_github_cenkalti_backoff//:go_default_library", "@com_github_gofrs_flock//:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/container/console_test.go b/runsc/container/console_test.go index 7a3d5a523..79b056fce 100644 --- a/runsc/container/console_test.go +++ b/runsc/container/console_test.go @@ -21,7 +21,6 @@ import ( "math/rand" "os" "path/filepath" - "syscall" "testing" "time" @@ -320,7 +319,7 @@ func TestJobControlSignalExec(t *testing.T) { // Send a SIGTERM to the foreground process for the exec PID. Note that // although we pass in the PID of "bash", it should actually terminate // "sleep", since that is the foreground process. - if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.SIGTERM, true /* fgProcess */); err != nil { + if err := c.Sandbox.SignalProcess(c.ID, pid, unix.SIGTERM, true /* fgProcess */); err != nil { t.Fatalf("error signaling container: %v", err) } @@ -340,7 +339,7 @@ func TestJobControlSignalExec(t *testing.T) { // Send a SIGKILL to the foreground process again. This time "bash" // should be killed. We use SIGKILL instead of SIGTERM or SIGINT // because bash ignores those. - if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.SIGKILL, true /* fgProcess */); err != nil { + if err := c.Sandbox.SignalProcess(c.ID, pid, unix.SIGKILL, true /* fgProcess */); err != nil { t.Fatalf("error signaling container: %v", err) } expectedPL = expectedPL[:1] @@ -356,7 +355,7 @@ func TestJobControlSignalExec(t *testing.T) { if !ws.Signaled() { t.Error("ws.Signaled() got false, want true") } - if got, want := ws.Signal(), syscall.SIGKILL; got != want { + if got, want := ws.Signal(), unix.SIGKILL; got != want { t.Errorf("ws.Signal() got %v, want %v", got, want) } } @@ -423,7 +422,7 @@ func TestJobControlSignalRootContainer(t *testing.T) { // very early, otherwise it might exit before we have a chance to call // Wait. var ( - ws syscall.WaitStatus + ws unix.WaitStatus wg sync.WaitGroup ) wg.Add(1) @@ -459,7 +458,7 @@ func TestJobControlSignalRootContainer(t *testing.T) { // Send a SIGTERM to the foreground process. We pass PID=0, indicating // that the root process should be killed. However, by setting // fgProcess=true, the signal should actually be sent to sleep. - if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, syscall.SIGTERM, true /* fgProcess */); err != nil { + if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, unix.SIGTERM, true /* fgProcess */); err != nil { t.Fatalf("error signaling container: %v", err) } @@ -479,7 +478,7 @@ func TestJobControlSignalRootContainer(t *testing.T) { // Send a SIGKILL to the foreground process again. This time "bash" // should be killed. We use SIGKILL instead of SIGTERM or SIGINT // because bash ignores those. - if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, syscall.SIGKILL, true /* fgProcess */); err != nil { + if err := c.Sandbox.SignalProcess(c.ID, 0 /* PID */, unix.SIGKILL, true /* fgProcess */); err != nil { t.Fatalf("error signaling container: %v", err) } @@ -488,7 +487,7 @@ func TestJobControlSignalRootContainer(t *testing.T) { if !ws.Signaled() { t.Error("ws.Signaled() got false, want true") } - if got, want := ws.Signal(), syscall.SIGKILL; got != want { + if got, want := ws.Signal(), unix.SIGKILL; got != want { t.Errorf("ws.Signal() got %v, want %v", got, want) } } diff --git a/runsc/container/container.go b/runsc/container/container.go index 40812efb8..f9d83c118 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -30,6 +30,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/log" @@ -244,7 +245,7 @@ func New(conf *config.Config, args Args) (*Container, error) { // If there is cgroup config, install it before creating sandbox process. if err := cg.Install(args.Spec.Linux.Resources); err != nil { switch { - case errors.Is(err, syscall.EACCES) && conf.Rootless: + case errors.Is(err, unix.EACCES) && conf.Rootless: log.Warningf("Skipping cgroup configuration in rootless mode: %v", err) cg = nil default: @@ -447,7 +448,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *config.Config, restoreFile s } // Run is a helper that calls Create + Start + Wait. -func Run(conf *config.Config, args Args) (syscall.WaitStatus, error) { +func Run(conf *config.Config, args Args) (unix.WaitStatus, error) { log.Debugf("Run container, cid: %s, rootDir: %q", args.ID, conf.RootDir) c, err := New(conf, args) if err != nil { @@ -517,7 +518,7 @@ func (c *Container) SandboxPid() int { // Wait waits for the container to exit, and returns its WaitStatus. // Call to wait on a stopped container is needed to retrieve the exit status // and wait returns immediately. -func (c *Container) Wait() (syscall.WaitStatus, error) { +func (c *Container) Wait() (unix.WaitStatus, error) { log.Debugf("Wait on container, cid: %s", c.ID) ws, err := c.Sandbox.Wait(c.ID) if err == nil { @@ -529,7 +530,7 @@ func (c *Container) Wait() (syscall.WaitStatus, error) { // WaitRootPID waits for process 'pid' in the sandbox's PID namespace and // returns its WaitStatus. -func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) { +func (c *Container) WaitRootPID(pid int32) (unix.WaitStatus, error) { log.Debugf("Wait on process %d in sandbox, cid: %s", pid, c.Sandbox.ID) if !c.IsSandboxRunning() { return 0, fmt.Errorf("sandbox is not running") @@ -539,7 +540,7 @@ func (c *Container) WaitRootPID(pid int32) (syscall.WaitStatus, error) { // WaitPID waits for process 'pid' in the container's PID namespace and returns // its WaitStatus. -func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) { +func (c *Container) WaitPID(pid int32) (unix.WaitStatus, error) { log.Debugf("Wait on process %d in container, cid: %s", pid, c.ID) if !c.IsSandboxRunning() { return 0, fmt.Errorf("sandbox is not running") @@ -551,7 +552,7 @@ func (c *Container) WaitPID(pid int32) (syscall.WaitStatus, error) { // is SIGKILL, then waits for all processes to exit before returning. // SignalContainer returns an error if the container is already stopped. // TODO(b/113680494): Distinguish different error types. -func (c *Container) SignalContainer(sig syscall.Signal, all bool) error { +func (c *Container) SignalContainer(sig unix.Signal, all bool) error { log.Debugf("Signal container, cid: %s, signal: %v (%d)", c.ID, sig, sig) // Signaling container in Stopped state is allowed. When all=false, // an error will be returned anyway; when all=true, this allows @@ -568,7 +569,7 @@ func (c *Container) SignalContainer(sig syscall.Signal, all bool) error { } // SignalProcess sends sig to a specific process in the container. -func (c *Container) SignalProcess(sig syscall.Signal, pid int32) error { +func (c *Container) SignalProcess(sig unix.Signal, pid int32) error { log.Debugf("Signal process %d in container, cid: %s, signal: %v (%d)", pid, c.ID, sig, sig) if err := c.requireStatus("signal a process inside", Running); err != nil { return err @@ -586,7 +587,7 @@ func (c *Container) ForwardSignals(pid int32, fgProcess bool) func() { log.Debugf("Forwarding all signals to container, cid: %s, PIDPID: %d, fgProcess: %t", c.ID, pid, fgProcess) stop := sighandling.StartSignalForwarding(func(sig linux.Signal) { log.Debugf("Forwarding signal %d to container, cid: %s, PID: %d, fgProcess: %t", sig, c.ID, pid, fgProcess) - if err := c.Sandbox.SignalProcess(c.ID, pid, syscall.Signal(sig), fgProcess); err != nil { + if err := c.Sandbox.SignalProcess(c.ID, pid, unix.Signal(sig), fgProcess); err != nil { log.Warningf("error forwarding signal %d to container %q: %v", sig, c.ID, err) } }) @@ -768,9 +769,9 @@ func (c *Container) stop() error { // Try killing gofer if it does not exit with container. if c.GoferPid != 0 { log.Debugf("Killing gofer for container, cid: %s, PID: %d", c.ID, c.GoferPid) - if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil { + if err := unix.Kill(c.GoferPid, unix.SIGKILL); err != nil { // The gofer may already be stopped, log the error. - log.Warningf("Error sending signal %d to gofer %d: %v", syscall.SIGKILL, c.GoferPid, err) + log.Warningf("Error sending signal %d to gofer %d: %v", unix.SIGKILL, c.GoferPid, err) } } @@ -793,7 +794,7 @@ func (c *Container) waitForStopped() error { b := backoff.WithContext(backoff.NewConstantBackOff(100*time.Millisecond), ctx) op := func() error { if c.IsSandboxRunning() { - if err := c.SignalContainer(syscall.Signal(0), false); err == nil { + if err := c.SignalContainer(unix.Signal(0), false); err == nil { return fmt.Errorf("container is still running") } } @@ -803,7 +804,7 @@ func (c *Container) waitForStopped() error { if c.goferIsChild { // The gofer process is a child of the current process, // so we can wait it and collect its zombie. - wpid, err := syscall.Wait4(int(c.GoferPid), nil, syscall.WNOHANG, nil) + wpid, err := unix.Wait4(int(c.GoferPid), nil, unix.WNOHANG, nil) if err != nil { return fmt.Errorf("error waiting the gofer process: %v", err) } @@ -811,7 +812,7 @@ func (c *Container) waitForStopped() error { return fmt.Errorf("gofer is still running") } - } else if err := syscall.Kill(c.GoferPid, 0); err == nil { + } else if err := unix.Kill(c.GoferPid, 0); err == nil { return fmt.Errorf("gofer is still running") } c.GoferPid = 0 @@ -892,7 +893,7 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bu sandEnds := make([]*os.File, 0, mountCount) for i := 0; i < mountCount; i++ { - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM|syscall.SOCK_CLOEXEC, 0) + fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM|unix.SOCK_CLOEXEC, 0) if err != nil { return nil, nil, err } @@ -914,8 +915,8 @@ func (c *Container) createGoferProcess(spec *specs.Spec, conf *config.Config, bu if attached { // The gofer is attached to the lifetime of this process, so it // should synchronously die when this process dies. - cmd.SysProcAttr = &syscall.SysProcAttr{ - Pdeathsig: syscall.SIGKILL, + cmd.SysProcAttr = &unix.SysProcAttr{ + Pdeathsig: unix.SIGKILL, } } @@ -1113,7 +1114,7 @@ func setOOMScoreAdj(pid int, scoreAdj int) error { } defer f.Close() if _, err := f.WriteString(strconv.Itoa(scoreAdj)); err != nil { - if errors.Is(err, syscall.ESRCH) { + if errors.Is(err, unix.ESRCH) { log.Warningf("Process (%d) exited while setting oom_score_adj", pid) return nil } diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 862d9444d..5a0c468a4 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -27,12 +27,12 @@ import ( "reflect" "strconv" "strings" - "syscall" "testing" "time" "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/log" @@ -103,7 +103,7 @@ func waitForProcessCount(cont *Container, want int) error { func blockUntilWaitable(pid int) error { _, _, err := specutils.RetryEintr(func() (uintptr, uintptr, error) { var err error - _, _, err1 := syscall.Syscall6(syscall.SYS_WAITID, 1, uintptr(pid), 0, syscall.WEXITED|syscall.WNOWAIT, 0, 0) + _, _, err1 := unix.Syscall6(unix.SYS_WAITID, 1, uintptr(pid), 0, unix.WEXITED|unix.WNOWAIT, 0, 0) if err1 != 0 { err = err1 } @@ -468,7 +468,7 @@ func TestLifecycle(t *testing.T) { if err != nil { ch <- err } - if got, want := ws.Signal(), syscall.SIGTERM; got != want { + if got, want := ws.Signal(), unix.SIGTERM; got != want { ch <- fmt.Errorf("got signal %v, want %v", got, want) } ch <- nil @@ -479,8 +479,8 @@ func TestLifecycle(t *testing.T) { time.Sleep(time.Second) // Send the container a SIGTERM which will cause it to stop. - if err := c.SignalContainer(syscall.SIGTERM, false); err != nil { - t.Fatalf("error sending signal %v to container: %v", syscall.SIGTERM, err) + if err := c.SignalContainer(unix.SIGTERM, false); err != nil { + t.Fatalf("error sending signal %v to container: %v", unix.SIGTERM, err) } // Wait for it to die. @@ -815,11 +815,11 @@ func TestExec(t *testing.T) { t.Run("nonexist", func(t *testing.T) { // b/179114837 found by Syzkaller that causes nil pointer panic when // trying to dec-ref an unix socket FD. - fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + fds, err := unix.Socketpair(unix.AF_UNIX, unix.SOCK_STREAM, 0) if err != nil { t.Fatal(err) } - defer syscall.Close(fds[0]) + defer unix.Close(fds[0]) _, err = cont.executeSync(&control.ExecArgs{ Argv: []string{"/nonexist"}, @@ -956,7 +956,7 @@ func TestKillPid(t *testing.T) { pid = int32(p.PID) } } - if err := cont.SignalProcess(syscall.SIGKILL, pid); err != nil { + if err := cont.SignalProcess(unix.SIGKILL, pid); err != nil { t.Fatalf("failed to signal process %d: %v", pid, err) } @@ -1601,12 +1601,12 @@ func TestReadonlyRoot(t *testing.T) { } // Read mounts to check that root is readonly. - out, err := executeCombinedOutput(c, "/bin/sh", "-c", "mount | grep ' / '") + out, err := executeCombinedOutput(c, "/bin/sh", "-c", "mount | grep ' / ' | grep -o -e '(.*)'") if err != nil { t.Fatalf("exec failed: %v", err) } - t.Logf("root mount: %q", out) - if !strings.Contains(string(out), "(ro)") { + t.Logf("root mount options: %q", out) + if !strings.Contains(string(out), "ro") { t.Errorf("root not mounted readonly: %q", out) } @@ -1615,7 +1615,7 @@ func TestReadonlyRoot(t *testing.T) { if err != nil { t.Fatalf("touch file in ro mount: %v", err) } - if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM { + if !ws.Exited() || unix.Errno(ws.ExitStatus()) != unix.EPERM { t.Fatalf("wrong waitStatus: %v", ws) } }) @@ -1659,13 +1659,13 @@ func TestReadonlyMount(t *testing.T) { } // Read mounts to check that volume is readonly. - cmd := fmt.Sprintf("mount | grep ' %s '", dir) + cmd := fmt.Sprintf("mount | grep ' %s ' | grep -o -e '(.*)'", dir) out, err := executeCombinedOutput(c, "/bin/sh", "-c", cmd) if err != nil { t.Fatalf("exec failed, err: %v", err) } - t.Logf("mount: %q", out) - if !strings.Contains(string(out), "(ro)") { + t.Logf("mount options: %q", out) + if !strings.Contains(string(out), "ro") { t.Errorf("volume not mounted readonly: %q", out) } @@ -1674,7 +1674,7 @@ func TestReadonlyMount(t *testing.T) { if err != nil { t.Fatalf("touch file in ro mount: %v", err) } - if !ws.Exited() || syscall.Errno(ws.ExitStatus()) != syscall.EPERM { + if !ws.Exited() || unix.Errno(ws.ExitStatus()) != unix.EPERM { t.Fatalf("wrong WaitStatus: %v", ws) } }) @@ -1750,8 +1750,8 @@ func TestUIDMap(t *testing.T) { if !ws.Exited() || ws.ExitStatus() != 0 { t.Fatalf("container failed, waitStatus: %v", ws) } - st := syscall.Stat_t{} - if err := syscall.Stat(testFile, &st); err != nil { + st := unix.Stat_t{} + if err := unix.Stat(testFile, &st); err != nil { t.Fatalf("error stat /testfile: %v", err) } @@ -1880,7 +1880,7 @@ func doGoferExitTest(t *testing.T, vfs2 bool) { } err = blockUntilWaitable(c.GoferPid) - if err != nil && err != syscall.ECHILD { + if err != nil && err != unix.ECHILD { t.Errorf("error waiting for gofer to exit: %v", err) } } @@ -1929,7 +1929,7 @@ func TestUserLog(t *testing.T) { } // sched_rr_get_interval - not implemented in gvisor. - num := strconv.Itoa(syscall.SYS_SCHED_RR_GET_INTERVAL) + num := strconv.Itoa(unix.SYS_SCHED_RR_GET_INTERVAL) spec := testutil.NewSpecWithArgs(app, "syscall", "--syscall="+num) conf := testutil.TestConfig(t) _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) @@ -2159,10 +2159,10 @@ func TestMountPropagation(t *testing.T) { f.Close() // Setup src as a shared mount. - if err := syscall.Mount(src, src, "bind", syscall.MS_BIND, ""); err != nil { + if err := unix.Mount(src, src, "bind", unix.MS_BIND, ""); err != nil { t.Fatalf("mount(%q, %q, MS_BIND): %v", dir, srcMnt, err) } - if err := syscall.Mount("", src, "", syscall.MS_SHARED, ""); err != nil { + if err := unix.Mount("", src, "", unix.MS_SHARED, ""); err != nil { t.Fatalf("mount(%q, MS_SHARED): %v", srcMnt, err) } @@ -2209,7 +2209,7 @@ func TestMountPropagation(t *testing.T) { // After the container is started, mount dir inside source and check what // happens to both destinations. - if err := syscall.Mount(dir, srcMnt, "bind", syscall.MS_BIND, ""); err != nil { + if err := unix.Mount(dir, srcMnt, "bind", unix.MS_BIND, ""); err != nil { t.Fatalf("mount(%q, %q, MS_BIND): %v", dir, srcMnt, err) } @@ -2449,7 +2449,7 @@ func TestCreateWithCorruptedStateFile(t *testing.T) { } } -func execute(cont *Container, name string, arg ...string) (syscall.WaitStatus, error) { +func execute(cont *Container, name string, arg ...string) (unix.WaitStatus, error) { args := &control.ExecArgs{ Filename: name, Argv: append([]string{name}, arg...), @@ -2483,7 +2483,7 @@ func executeCombinedOutput(cont *Container, name string, arg ...string) ([]byte, } // executeSync synchronously executes a new process. -func (c *Container) executeSync(args *control.ExecArgs) (syscall.WaitStatus, error) { +func (c *Container) executeSync(args *control.ExecArgs) (unix.WaitStatus, error) { pid, err := c.Execute(args) if err != nil { return 0, fmt.Errorf("error executing: %v", err) diff --git a/runsc/container/multi_container_test.go b/runsc/container/multi_container_test.go index b434cdb23..0f0a223ce 100644 --- a/runsc/container/multi_container_test.go +++ b/runsc/container/multi_container_test.go @@ -22,11 +22,11 @@ import ( "path" "path/filepath" "strings" - "syscall" "testing" "time" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -403,7 +403,7 @@ func TestMultiPIDNSKill(t *testing.T) { t.Logf("Container %q procs: %s", c.ID, procListToString(procs)) pidToKill := procs[processes-1].PID t.Logf("PID to kill: %d", pidToKill) - if err := c.SignalProcess(syscall.SIGKILL, int32(pidToKill)); err != nil { + if err := c.SignalProcess(unix.SIGKILL, int32(pidToKill)); err != nil { t.Errorf("container.SignalProcess: %v", err) } // Wait for the process to get killed. @@ -432,7 +432,7 @@ func TestMultiPIDNSKill(t *testing.T) { pidToKill = procs[len(procs)-1].PID t.Logf("PID that should not be killed: %d", pidToKill) - err = c.SignalProcess(syscall.SIGKILL, int32(pidToKill)) + err = c.SignalProcess(unix.SIGKILL, int32(pidToKill)) if err == nil { t.Fatalf("killing another container's process should fail") } @@ -640,7 +640,7 @@ func TestMultiContainerSignal(t *testing.T) { } // Kill process 2. - if err := containers[1].SignalContainer(syscall.SIGKILL, false); err != nil { + if err := containers[1].SignalContainer(unix.SIGKILL, false); err != nil { t.Errorf("failed to kill process 2: %v", err) } @@ -660,10 +660,10 @@ func TestMultiContainerSignal(t *testing.T) { t.Errorf("failed to destroy container: %v", err) } _, _, err = specutils.RetryEintr(func() (uintptr, uintptr, error) { - cpid, err := syscall.Wait4(goferPid, nil, 0, nil) + cpid, err := unix.Wait4(goferPid, nil, 0, nil) return uintptr(cpid), 0, err }) - if err != syscall.ECHILD { + if err != unix.ECHILD { t.Errorf("error waiting for gofer to exit: %v", err) } // Make sure process 1 is still running. @@ -673,28 +673,28 @@ func TestMultiContainerSignal(t *testing.T) { // Now that process 2 is gone, ensure we get an error trying to // signal it again. - if err := containers[1].SignalContainer(syscall.SIGKILL, false); err == nil { + if err := containers[1].SignalContainer(unix.SIGKILL, false); err == nil { t.Errorf("container %q shouldn't exist, but we were able to signal it", containers[1].ID) } // Kill process 1. - if err := containers[0].SignalContainer(syscall.SIGKILL, false); err != nil { + if err := containers[0].SignalContainer(unix.SIGKILL, false); err != nil { t.Errorf("failed to kill process 1: %v", err) } // Ensure that container's gofer and sandbox process are no more. err = blockUntilWaitable(containers[0].GoferPid) - if err != nil && err != syscall.ECHILD { + if err != nil && err != unix.ECHILD { t.Errorf("error waiting for gofer to exit: %v", err) } err = blockUntilWaitable(containers[0].Sandbox.Pid) - if err != nil && err != syscall.ECHILD { + if err != nil && err != unix.ECHILD { t.Errorf("error waiting for sandbox to exit: %v", err) } // The sentry should be gone, so signaling should yield an error. - if err := containers[0].SignalContainer(syscall.SIGKILL, false); err == nil { + if err := containers[0].SignalContainer(unix.SIGKILL, false); err == nil { t.Errorf("sandbox %q shouldn't exist, but we were able to signal it", containers[0].Sandbox.ID) } @@ -893,7 +893,7 @@ func TestMultiContainerKillAll(t *testing.T) { if tc.killContainer { // First kill the init process to make the container be stopped with // processes still running inside. - containers[1].SignalContainer(syscall.SIGKILL, false) + containers[1].SignalContainer(unix.SIGKILL, false) op := func() error { c, err := Load(conf.RootDir, FullID{ContainerID: ids[1]}, LoadOpts{}) if err != nil { @@ -914,7 +914,7 @@ func TestMultiContainerKillAll(t *testing.T) { t.Fatalf("failed to load child container %q: %v", c.ID, err) } // Kill'Em All - if err := c.SignalContainer(syscall.SIGKILL, true); err != nil { + if err := c.SignalContainer(unix.SIGKILL, true); err != nil { t.Fatalf("failed to send SIGKILL to container %q: %v", c.ID, err) } @@ -1640,8 +1640,8 @@ func TestMultiContainerGoferKilled(t *testing.T) { } // Kill container's gofer. - if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil { - t.Fatalf("syscall.Kill(%d, SIGKILL)=%v", c.GoferPid, err) + if err := unix.Kill(c.GoferPid, unix.SIGKILL); err != nil { + t.Fatalf("unix.Kill(%d, SIGKILL)=%v", c.GoferPid, err) } // Wait until container stops. @@ -1672,8 +1672,8 @@ func TestMultiContainerGoferKilled(t *testing.T) { // Kill root container's gofer to bring entire sandbox down. c = containers[0] - if err := syscall.Kill(c.GoferPid, syscall.SIGKILL); err != nil { - t.Fatalf("syscall.Kill(%d, SIGKILL)=%v", c.GoferPid, err) + if err := unix.Kill(c.GoferPid, unix.SIGKILL); err != nil { + t.Fatalf("unix.Kill(%d, SIGKILL)=%v", c.GoferPid, err) } // Wait until sandbox stops. waitForProcessList will loop until sandbox exits diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go index c46322ba4..0399903a0 100644 --- a/runsc/container/state_file.go +++ b/runsc/container/state_file.go @@ -22,9 +22,9 @@ import ( "path/filepath" "regexp" "strings" - "syscall" "github.com/gofrs/flock" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sync" ) @@ -89,7 +89,7 @@ func Load(rootDir string, id FullID, opts LoadOpts) (*Container, error) { c.changeStatus(Stopped) } case Running: - if err := c.SignalContainer(syscall.Signal(0), false); err != nil { + if err := c.SignalContainer(unix.Signal(0), false); err != nil { c.changeStatus(Stopped) } } @@ -245,7 +245,7 @@ type StateFile struct { // lock globally locks all locking operations for the container. func (s *StateFile) lock() error { s.once.Do(func() { - s.flock = flock.NewFlock(s.lockPath()) + s.flock = flock.New(s.lockPath()) }) if err := s.flock.Lock(); err != nil { diff --git a/runsc/fsgofer/filter/config.go b/runsc/fsgofer/filter/config.go index d1af539cb..fd72414ce 100644 --- a/runsc/fsgofer/filter/config.go +++ b/runsc/fsgofer/filter/config.go @@ -16,7 +16,6 @@ package filter import ( "os" - "syscall" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" @@ -25,12 +24,12 @@ import ( // allowedSyscalls is the set of syscalls executed by the gofer. var allowedSyscalls = seccomp.SyscallRules{ - syscall.SYS_ACCEPT: {}, - syscall.SYS_CLOCK_GETTIME: {}, - syscall.SYS_CLOSE: {}, - syscall.SYS_DUP: {}, - syscall.SYS_EPOLL_CTL: {}, - syscall.SYS_EPOLL_PWAIT: []seccomp.Rule{ + unix.SYS_ACCEPT: {}, + unix.SYS_CLOCK_GETTIME: {}, + unix.SYS_CLOSE: {}, + unix.SYS_DUP: {}, + unix.SYS_EPOLL_CTL: {}, + unix.SYS_EPOLL_PWAIT: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, @@ -39,34 +38,34 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.EqualTo(0), }, }, - syscall.SYS_EVENTFD2: []seccomp.Rule{ + unix.SYS_EVENTFD2: []seccomp.Rule{ { seccomp.EqualTo(0), seccomp.EqualTo(0), }, }, - syscall.SYS_EXIT: {}, - syscall.SYS_EXIT_GROUP: {}, - syscall.SYS_FALLOCATE: []seccomp.Rule{ + unix.SYS_EXIT: {}, + unix.SYS_EXIT_GROUP: {}, + unix.SYS_FALLOCATE: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.EqualTo(0), }, }, - syscall.SYS_FCHMOD: {}, - syscall.SYS_FCHOWNAT: {}, - syscall.SYS_FCNTL: []seccomp.Rule{ + unix.SYS_FCHMOD: {}, + unix.SYS_FCHOWNAT: {}, + unix.SYS_FCNTL: []seccomp.Rule{ { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.F_GETFL), + seccomp.EqualTo(unix.F_GETFL), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.F_SETFL), + seccomp.EqualTo(unix.F_SETFL), }, { seccomp.MatchAny{}, - seccomp.EqualTo(syscall.F_GETFD), + seccomp.EqualTo(unix.F_GETFD), }, // Used by flipcall.PacketWindowAllocator.Init(). { @@ -74,11 +73,11 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.EqualTo(unix.F_ADD_SEALS), }, }, - syscall.SYS_FSTAT: {}, - syscall.SYS_FSTATFS: {}, - syscall.SYS_FSYNC: {}, - syscall.SYS_FTRUNCATE: {}, - syscall.SYS_FUTEX: { + unix.SYS_FSTAT: {}, + unix.SYS_FSTATFS: {}, + unix.SYS_FSYNC: {}, + unix.SYS_FTRUNCATE: {}, + unix.SYS_FUTEX: { seccomp.Rule{ seccomp.MatchAny{}, seccomp.EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG), @@ -116,78 +115,78 @@ var allowedSyscalls = seccomp.SyscallRules{ seccomp.EqualTo(0), }, }, - syscall.SYS_GETDENTS64: {}, - syscall.SYS_GETPID: {}, - unix.SYS_GETRANDOM: {}, - syscall.SYS_GETTID: {}, - syscall.SYS_GETTIMEOFDAY: {}, - syscall.SYS_LINKAT: {}, - syscall.SYS_LSEEK: {}, - syscall.SYS_MADVISE: {}, - unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init(). - syscall.SYS_MKDIRAT: {}, - syscall.SYS_MKNODAT: {}, + unix.SYS_GETDENTS64: {}, + unix.SYS_GETPID: {}, + unix.SYS_GETRANDOM: {}, + unix.SYS_GETTID: {}, + unix.SYS_GETTIMEOFDAY: {}, + unix.SYS_LINKAT: {}, + unix.SYS_LSEEK: {}, + unix.SYS_MADVISE: {}, + unix.SYS_MEMFD_CREATE: {}, /// Used by flipcall.PacketWindowAllocator.Init(). + unix.SYS_MKDIRAT: {}, + unix.SYS_MKNODAT: {}, // Used by the Go runtime as a temporarily workaround for a Linux // 5.2-5.4 bug. // // See src/runtime/os_linux_x86.go. // // TODO(b/148688965): Remove once this is gone from Go. - syscall.SYS_MLOCK: []seccomp.Rule{ + unix.SYS_MLOCK: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.EqualTo(4096), }, }, - syscall.SYS_MMAP: []seccomp.Rule{ + unix.SYS_MMAP: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MAP_SHARED), + seccomp.EqualTo(unix.MAP_SHARED), }, { seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS), + seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS), }, { seccomp.MatchAny{}, seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MAP_PRIVATE | syscall.MAP_ANONYMOUS | syscall.MAP_FIXED), + seccomp.EqualTo(unix.MAP_PRIVATE | unix.MAP_ANONYMOUS | unix.MAP_FIXED), }, }, - syscall.SYS_MPROTECT: {}, - syscall.SYS_MUNMAP: {}, - syscall.SYS_NANOSLEEP: {}, - syscall.SYS_OPENAT: {}, - syscall.SYS_PPOLL: {}, - syscall.SYS_PREAD64: {}, - syscall.SYS_PWRITE64: {}, - syscall.SYS_READ: {}, - syscall.SYS_READLINKAT: {}, - syscall.SYS_RECVMSG: []seccomp.Rule{ + unix.SYS_MPROTECT: {}, + unix.SYS_MUNMAP: {}, + unix.SYS_NANOSLEEP: {}, + unix.SYS_OPENAT: {}, + unix.SYS_PPOLL: {}, + unix.SYS_PREAD64: {}, + unix.SYS_PWRITE64: {}, + unix.SYS_READ: {}, + unix.SYS_READLINKAT: {}, + unix.SYS_RECVMSG: []seccomp.Rule{ { seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC), + seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_TRUNC), }, { seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_TRUNC | syscall.MSG_PEEK), + seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_TRUNC | unix.MSG_PEEK), }, }, - syscall.SYS_RENAMEAT: {}, - syscall.SYS_RESTART_SYSCALL: {}, + unix.SYS_RENAMEAT: {}, + unix.SYS_RESTART_SYSCALL: {}, // May be used by the runtime during panic(). - syscall.SYS_RT_SIGACTION: {}, - syscall.SYS_RT_SIGPROCMASK: {}, - syscall.SYS_RT_SIGRETURN: {}, - syscall.SYS_SCHED_YIELD: {}, - syscall.SYS_SENDMSG: []seccomp.Rule{ + unix.SYS_RT_SIGACTION: {}, + unix.SYS_RT_SIGPROCMASK: {}, + unix.SYS_RT_SIGRETURN: {}, + unix.SYS_SCHED_YIELD: {}, + unix.SYS_SENDMSG: []seccomp.Rule{ // Used by fdchannel.Endpoint.SendFD(). { seccomp.MatchAny{}, @@ -198,51 +197,51 @@ var allowedSyscalls = seccomp.SyscallRules{ { seccomp.MatchAny{}, seccomp.MatchAny{}, - seccomp.EqualTo(syscall.MSG_DONTWAIT | syscall.MSG_NOSIGNAL), + seccomp.EqualTo(unix.MSG_DONTWAIT | unix.MSG_NOSIGNAL), }, }, - syscall.SYS_SHUTDOWN: []seccomp.Rule{ - {seccomp.MatchAny{}, seccomp.EqualTo(syscall.SHUT_RDWR)}, + unix.SYS_SHUTDOWN: []seccomp.Rule{ + {seccomp.MatchAny{}, seccomp.EqualTo(unix.SHUT_RDWR)}, }, - syscall.SYS_SIGALTSTACK: {}, + unix.SYS_SIGALTSTACK: {}, // Used by fdchannel.NewConnectedSockets(). - syscall.SYS_SOCKETPAIR: { + unix.SYS_SOCKETPAIR: { { - seccomp.EqualTo(syscall.AF_UNIX), - seccomp.EqualTo(syscall.SOCK_SEQPACKET | syscall.SOCK_CLOEXEC), + seccomp.EqualTo(unix.AF_UNIX), + seccomp.EqualTo(unix.SOCK_SEQPACKET | unix.SOCK_CLOEXEC), seccomp.EqualTo(0), }, }, - syscall.SYS_SYMLINKAT: {}, - syscall.SYS_TGKILL: []seccomp.Rule{ + unix.SYS_SYMLINKAT: {}, + unix.SYS_TGKILL: []seccomp.Rule{ { seccomp.EqualTo(uint64(os.Getpid())), }, }, - syscall.SYS_UNLINKAT: {}, - syscall.SYS_UTIMENSAT: {}, - syscall.SYS_WRITE: {}, + unix.SYS_UNLINKAT: {}, + unix.SYS_UTIMENSAT: {}, + unix.SYS_WRITE: {}, } var udsSyscalls = seccomp.SyscallRules{ - syscall.SYS_SOCKET: []seccomp.Rule{ + unix.SYS_SOCKET: []seccomp.Rule{ { - seccomp.EqualTo(syscall.AF_UNIX), - seccomp.EqualTo(syscall.SOCK_STREAM), + seccomp.EqualTo(unix.AF_UNIX), + seccomp.EqualTo(unix.SOCK_STREAM), seccomp.EqualTo(0), }, { - seccomp.EqualTo(syscall.AF_UNIX), - seccomp.EqualTo(syscall.SOCK_DGRAM), + seccomp.EqualTo(unix.AF_UNIX), + seccomp.EqualTo(unix.SOCK_DGRAM), seccomp.EqualTo(0), }, { - seccomp.EqualTo(syscall.AF_UNIX), - seccomp.EqualTo(syscall.SOCK_SEQPACKET), + seccomp.EqualTo(unix.AF_UNIX), + seccomp.EqualTo(unix.SOCK_SEQPACKET), seccomp.EqualTo(0), }, }, - syscall.SYS_CONNECT: []seccomp.Rule{ + unix.SYS_CONNECT: []seccomp.Rule{ { seccomp.MatchAny{}, }, diff --git a/runsc/fsgofer/filter/config_amd64.go b/runsc/fsgofer/filter/config_amd64.go index 686753d96..2d0151dcc 100644 --- a/runsc/fsgofer/filter/config_amd64.go +++ b/runsc/fsgofer/filter/config_amd64.go @@ -17,30 +17,29 @@ package filter import ( - "syscall" - + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/seccomp" ) func init() { - allowedSyscalls[syscall.SYS_ARCH_PRCTL] = []seccomp.Rule{ + allowedSyscalls[unix.SYS_ARCH_PRCTL] = []seccomp.Rule{ // TODO(b/168828518): No longer used in Go 1.16+. {seccomp.EqualTo(linux.ARCH_SET_FS)}, } - allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{ + allowedSyscalls[unix.SYS_CLONE] = []seccomp.Rule{ // parent_tidptr and child_tidptr are always 0 because neither // CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used. { seccomp.EqualTo( - syscall.CLONE_VM | - syscall.CLONE_FS | - syscall.CLONE_FILES | - syscall.CLONE_SETTLS | - syscall.CLONE_SIGHAND | - syscall.CLONE_SYSVSEM | - syscall.CLONE_THREAD), + unix.CLONE_VM | + unix.CLONE_FS | + unix.CLONE_FILES | + unix.CLONE_SETTLS | + unix.CLONE_SIGHAND | + unix.CLONE_SYSVSEM | + unix.CLONE_THREAD), seccomp.MatchAny{}, // newsp seccomp.EqualTo(0), // parent_tidptr seccomp.EqualTo(0), // child_tidptr @@ -49,12 +48,12 @@ func init() { { // TODO(b/168828518): No longer used in Go 1.16+ (on amd64). seccomp.EqualTo( - syscall.CLONE_VM | - syscall.CLONE_FS | - syscall.CLONE_FILES | - syscall.CLONE_SIGHAND | - syscall.CLONE_SYSVSEM | - syscall.CLONE_THREAD), + unix.CLONE_VM | + unix.CLONE_FS | + unix.CLONE_FILES | + unix.CLONE_SIGHAND | + unix.CLONE_SYSVSEM | + unix.CLONE_THREAD), seccomp.MatchAny{}, // newsp seccomp.EqualTo(0), // parent_tidptr seccomp.EqualTo(0), // child_tidptr @@ -62,5 +61,5 @@ func init() { }, } - allowedSyscalls[syscall.SYS_NEWFSTATAT] = []seccomp.Rule{} + allowedSyscalls[unix.SYS_NEWFSTATAT] = []seccomp.Rule{} } diff --git a/runsc/fsgofer/filter/config_arm64.go b/runsc/fsgofer/filter/config_arm64.go index ff0cf77a0..7d458c02d 100644 --- a/runsc/fsgofer/filter/config_arm64.go +++ b/runsc/fsgofer/filter/config_arm64.go @@ -17,23 +17,22 @@ package filter import ( - "syscall" - + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/seccomp" ) func init() { - allowedSyscalls[syscall.SYS_CLONE] = []seccomp.Rule{ + allowedSyscalls[unix.SYS_CLONE] = []seccomp.Rule{ // parent_tidptr and child_tidptr are always 0 because neither // CLONE_PARENT_SETTID nor CLONE_CHILD_SETTID are used. { seccomp.EqualTo( - syscall.CLONE_VM | - syscall.CLONE_FS | - syscall.CLONE_FILES | - syscall.CLONE_SIGHAND | - syscall.CLONE_SYSVSEM | - syscall.CLONE_THREAD), + unix.CLONE_VM | + unix.CLONE_FS | + unix.CLONE_FILES | + unix.CLONE_SIGHAND | + unix.CLONE_SYSVSEM | + unix.CLONE_THREAD), seccomp.MatchAny{}, // newsp // These arguments are left uninitialized by the Go // runtime, so they may be anything (and are unused by @@ -44,5 +43,5 @@ func init() { }, } - allowedSyscalls[syscall.SYS_FSTATAT] = []seccomp.Rule{} + allowedSyscalls[unix.SYS_FSTATAT] = []seccomp.Rule{} } diff --git a/runsc/fsgofer/filter/extra_filters_msan.go b/runsc/fsgofer/filter/extra_filters_msan.go index 8c6179c8f..d768ed0bb 100644 --- a/runsc/fsgofer/filter/extra_filters_msan.go +++ b/runsc/fsgofer/filter/extra_filters_msan.go @@ -17,8 +17,7 @@ package filter import ( - "syscall" - + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/seccomp" ) @@ -27,7 +26,7 @@ import ( func instrumentationFilters() seccomp.SyscallRules { log.Warningf("*** SECCOMP WARNING: MSAN is enabled: syscall filters less restrictive!") return seccomp.SyscallRules{ - syscall.SYS_SCHED_GETAFFINITY: {}, - syscall.SYS_SET_ROBUST_LIST: {}, + unix.SYS_SCHED_GETAFFINITY: {}, + unix.SYS_SET_ROBUST_LIST: {}, } } diff --git a/runsc/fsgofer/filter/extra_filters_race.go b/runsc/fsgofer/filter/extra_filters_race.go index cbd5c487e..9e75c025d 100644 --- a/runsc/fsgofer/filter/extra_filters_race.go +++ b/runsc/fsgofer/filter/extra_filters_race.go @@ -17,8 +17,7 @@ package filter import ( - "syscall" - + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/seccomp" ) @@ -27,18 +26,18 @@ import ( func instrumentationFilters() seccomp.SyscallRules { log.Warningf("*** SECCOMP WARNING: TSAN is enabled: syscall filters less restrictive!") return seccomp.SyscallRules{ - syscall.SYS_BRK: {}, - syscall.SYS_CLOCK_NANOSLEEP: {}, - syscall.SYS_CLONE: {}, - syscall.SYS_FUTEX: {}, - syscall.SYS_MADVISE: {}, - syscall.SYS_MMAP: {}, - syscall.SYS_MUNLOCK: {}, - syscall.SYS_NANOSLEEP: {}, - syscall.SYS_OPEN: {}, - syscall.SYS_OPENAT: {}, - syscall.SYS_SET_ROBUST_LIST: {}, + unix.SYS_BRK: {}, + unix.SYS_CLOCK_NANOSLEEP: {}, + unix.SYS_CLONE: {}, + unix.SYS_FUTEX: {}, + unix.SYS_MADVISE: {}, + unix.SYS_MMAP: {}, + unix.SYS_MUNLOCK: {}, + unix.SYS_NANOSLEEP: {}, + unix.SYS_OPEN: {}, + unix.SYS_OPENAT: {}, + unix.SYS_SET_ROBUST_LIST: {}, // Used within glibc's malloc. - syscall.SYS_TIME: {}, + unix.SYS_TIME: {}, } } diff --git a/runsc/fsgofer/fsgofer.go b/runsc/fsgofer/fsgofer.go index cfa3796b1..1e80a634d 100644 --- a/runsc/fsgofer/fsgofer.go +++ b/runsc/fsgofer/fsgofer.go @@ -66,6 +66,9 @@ type Config struct { // HostUDS signals whether the gofer can mount a host's UDS. HostUDS bool + + // enableXattr allows Get/SetXattr for the mounted file systems. + EnableXattr bool } type attachPoint struct { @@ -795,12 +798,22 @@ func (l *localFile) SetAttr(valid p9.SetAttrMask, attr p9.SetAttr) error { return err } -func (*localFile) GetXattr(string, uint64) (string, error) { - return "", unix.EOPNOTSUPP +func (l *localFile) GetXattr(name string, size uint64) (string, error) { + if !l.attachPoint.conf.EnableXattr { + return "", unix.EOPNOTSUPP + } + buffer := make([]byte, size) + if _, err := unix.Fgetxattr(l.file.FD(), name, buffer); err != nil { + return "", err + } + return string(buffer), nil } -func (*localFile) SetXattr(string, string, uint32) error { - return unix.EOPNOTSUPP +func (l *localFile) SetXattr(name string, value string, flags uint32) error { + if !l.attachPoint.conf.EnableXattr { + return unix.EOPNOTSUPP + } + return unix.Fsetxattr(l.file.FD(), name, []byte(value), int(flags)) } func (*localFile) ListXattr(uint64) (map[string]struct{}, error) { diff --git a/runsc/fsgofer/fsgofer_test.go b/runsc/fsgofer/fsgofer_test.go index 99ea9bd32..a5f09f88f 100644 --- a/runsc/fsgofer/fsgofer_test.go +++ b/runsc/fsgofer/fsgofer_test.go @@ -565,6 +565,38 @@ func TestSetAttrOwner(t *testing.T) { }) } +func SetGetXattr(l *localFile, name string, value string) error { + if err := l.SetXattr(name, value, 0 /* flags */); err != nil { + return err + } + ret, err := l.GetXattr(name, uint64(len(value))) + if err != nil { + return err + } + if ret != value { + return fmt.Errorf("Got value %s, want %s", ret, value) + } + return nil +} + +func TestSetGetXattr(t *testing.T) { + xattrConfs := []Config{{ROMount: false, EnableXattr: false}, {ROMount: false, EnableXattr: true}} + runCustom(t, []uint32{unix.S_IFREG}, xattrConfs, func(t *testing.T, s state) { + name := "user.test" + value := "tmp" + err := SetGetXattr(s.file, name, value) + if s.conf.EnableXattr { + if err != nil { + t.Fatalf("%v: SetGetXattr failed, err: %v", s, err) + } + } else { + if err == nil { + t.Fatalf("%v: SetGetXattr should have failed", s) + } + } + }) +} + func TestLink(t *testing.T) { if !specutils.HasCapabilities(capability.CAP_DAC_READ_SEARCH) { t.Skipf("Link test requires CAP_DAC_READ_SEARCH, running as %d", os.Getuid()) diff --git a/runsc/mitigate/BUILD b/runsc/mitigate/BUILD index 561854e66..1238890fc 100644 --- a/runsc/mitigate/BUILD +++ b/runsc/mitigate/BUILD @@ -4,28 +4,20 @@ package(licenses = ["notice"]) go_library( name = "mitigate", - srcs = [ - "cpu.go", - "mitigate.go", - "mitigate_conf.go", - ], + srcs = ["mitigate.go"], visibility = [ "//runsc:__subpackages__", ], - deps = [ - "//pkg/log", - "//runsc/flag", - "@in_gopkg_yaml_v2//:go_default_library", - ], + deps = ["@in_gopkg_yaml_v2//:go_default_library"], ) go_test( name = "mitigate_test", size = "small", - srcs = [ - "cpu_test.go", - "mitigate_test.go", - ], + srcs = ["mitigate_test.go"], library = ":mitigate", - deps = ["@com_github_google_go_cmp//cmp:go_default_library"], + deps = [ + "//runsc/mitigate/mock", + "@com_github_google_go_cmp//cmp:go_default_library", + ], ) diff --git a/runsc/mitigate/cpu.go b/runsc/mitigate/cpu.go deleted file mode 100644 index 4b2aa351f..000000000 --- a/runsc/mitigate/cpu.go +++ /dev/null @@ -1,423 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package mitigate - -import ( - "fmt" - "io/ioutil" - "regexp" - "strconv" - "strings" -) - -const ( - // mds is the only bug we care about. - mds = "mds" - - // Constants for parsing /proc/cpuinfo. - processorKey = "processor" - vendorIDKey = "vendor_id" - cpuFamilyKey = "cpu family" - modelKey = "model" - physicalIDKey = "physical id" - coreIDKey = "core id" - bugsKey = "bugs" - - // Path to shutdown a CPU. - cpuOnlineTemplate = "/sys/devices/system/cpu/cpu%d/online" -) - -// cpuSet contains a map of all CPUs on the system, mapped -// by Physical ID and CoreIDs. threads with the same -// Core and Physical ID are Hyperthread pairs. -type cpuSet map[cpuID]*threadGroup - -// newCPUSet creates a CPUSet from data read from /proc/cpuinfo. -func newCPUSet(data []byte, vulnerable func(thread) bool) (cpuSet, error) { - processors, err := getThreads(string(data)) - if err != nil { - return nil, err - } - - set := make(cpuSet) - for _, p := range processors { - // Each ID is of the form physicalID:coreID. Hyperthread pairs - // have identical physical and core IDs. We need to match - // Hyperthread pairs so that we can shutdown all but one per - // pair. - core, ok := set[p.id] - if !ok { - core = &threadGroup{} - set[p.id] = core - } - core.isVulnerable = core.isVulnerable || vulnerable(p) - core.threads = append(core.threads, p) - } - return set, nil -} - -// newCPUSetFromPossible makes a cpuSet data read from -// /sys/devices/system/cpu/possible. This is used in enable operations -// where the caller simply wants to enable all CPUS. -func newCPUSetFromPossible(data []byte) (cpuSet, error) { - threads, err := getThreadsFromPossible(data) - if err != nil { - return nil, err - } - - // We don't care if a CPU is vulnerable or not, we just - // want to return a list of all CPUs on the host. - set := cpuSet{ - threads[0].id: &threadGroup{ - threads: threads, - isVulnerable: false, - }, - } - return set, nil -} - -// String implements the String method for CPUSet. -func (c cpuSet) String() string { - ret := "" - for _, tg := range c { - ret += fmt.Sprintf("%s\n", tg) - } - return ret -} - -// getRemainingList returns the list of threads that will remain active -// after mitigation. -func (c cpuSet) getRemainingList() []thread { - threads := make([]thread, 0, len(c)) - for _, core := range c { - // If we're vulnerable, take only one thread from the pair. - if core.isVulnerable { - threads = append(threads, core.threads[0]) - continue - } - // Otherwise don't shutdown anything. - threads = append(threads, core.threads...) - } - return threads -} - -// getShutdownList returns the list of threads that will be shutdown on -// mitigation. -func (c cpuSet) getShutdownList() []thread { - threads := make([]thread, 0) - for _, core := range c { - // Only if we're vulnerable do shutdown anything. In this case, - // shutdown all but the first entry. - if core.isVulnerable && len(core.threads) > 1 { - threads = append(threads, core.threads[1:]...) - } - } - return threads -} - -// threadGroup represents Hyperthread pairs on the same physical/core ID. -type threadGroup struct { - threads []thread - isVulnerable bool -} - -// String implements the String method for threadGroup. -func (c threadGroup) String() string { - ret := fmt.Sprintf("ThreadGroup:\nIsVulnerable: %t\n", c.isVulnerable) - for _, processor := range c.threads { - ret += fmt.Sprintf("%s\n", processor) - } - return ret -} - -// getThreads returns threads structs from reading /proc/cpuinfo. -func getThreads(data string) ([]thread, error) { - // Each processor entry should start with the - // processor key. Find the beginings of each. - r := buildRegex(processorKey, `\d+`) - indices := r.FindAllStringIndex(data, -1) - if len(indices) < 1 { - return nil, fmt.Errorf("no cpus found for: %q", data) - } - - // Add the ending index for last entry. - indices = append(indices, []int{len(data), -1}) - - // Valid cpus are now defined by strings in between - // indexes (e.g. data[index[i], index[i+1]]). - // There should be len(indicies) - 1 CPUs - // since the last index is the end of the string. - cpus := make([]thread, 0, len(indices)) - // Find each string that represents a CPU. These begin "processor". - for i := 1; i < len(indices); i++ { - start := indices[i-1][0] - end := indices[i][0] - // Parse the CPU entry, which should be between start/end. - c, err := newThread(data[start:end]) - if err != nil { - return nil, err - } - cpus = append(cpus, c) - } - return cpus, nil -} - -// getThreadsFromPossible makes threads from data read from /sys/devices/system/cpu/possible. -func getThreadsFromPossible(data []byte) ([]thread, error) { - possibleRegex := regexp.MustCompile(`(?m)^(\d+)(-(\d+))?$`) - matches := possibleRegex.FindStringSubmatch(string(data)) - if len(matches) != 4 { - return nil, fmt.Errorf("mismatch regex from %s: %q", allPossibleCPUs, string(data)) - } - - // If matches[3] is empty, we only have one cpu entry. - if matches[3] == "" { - matches[3] = matches[1] - } - - begin, err := strconv.ParseInt(matches[1], 10, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse begin: %v", err) - } - end, err := strconv.ParseInt(matches[3], 10, 64) - if err != nil { - return nil, fmt.Errorf("failed to parse end: %v", err) - } - if begin > end || begin < 0 || end < 0 { - return nil, fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", begin, end) - } - - ret := make([]thread, 0, end-begin) - for i := begin; i <= end; i++ { - ret = append(ret, thread{ - processorNumber: i, - id: cpuID{ - physicalID: 0, // we don't care about id for enable ops. - coreID: 0, - }, - }) - } - - return ret, nil -} - -// cpuID for each thread is defined by the physical and -// core IDs. If equal, two threads are Hyperthread pairs. -type cpuID struct { - physicalID int64 - coreID int64 -} - -// type cpu represents pertinent info about a cpu. -type thread struct { - processorNumber int64 // the processor number of this CPU. - vendorID string // the vendorID of CPU (e.g. AuthenticAMD). - cpuFamily int64 // CPU family number (e.g. 6 for CascadeLake/Skylake). - model int64 // CPU model number (e.g. 85 for CascadeLake/Skylake). - id cpuID // id for this thread - bugs map[string]struct{} // map of vulnerabilities parsed from the 'bugs' field. -} - -// newThread parses a CPU from a single cpu entry from /proc/cpuinfo. -func newThread(data string) (thread, error) { - empty := thread{} - processor, err := parseProcessor(data) - if err != nil { - return empty, err - } - - vendorID, err := parseVendorID(data) - if err != nil { - return empty, err - } - - cpuFamily, err := parseCPUFamily(data) - if err != nil { - return empty, err - } - - model, err := parseModel(data) - if err != nil { - return empty, err - } - - physicalID, err := parsePhysicalID(data) - if err != nil { - return empty, err - } - - coreID, err := parseCoreID(data) - if err != nil { - return empty, err - } - - bugs, err := parseBugs(data) - if err != nil { - return empty, err - } - - return thread{ - processorNumber: processor, - vendorID: vendorID, - cpuFamily: cpuFamily, - model: model, - id: cpuID{ - physicalID: physicalID, - coreID: coreID, - }, - bugs: bugs, - }, nil -} - -// String implements the String method for thread. -func (t thread) String() string { - template := `CPU: %d -CPU ID: %+v -Vendor: %s -Family/Model: %d/%d -Bugs: %s -` - bugs := make([]string, 0) - for bug := range t.bugs { - bugs = append(bugs, bug) - } - - return fmt.Sprintf(template, t.processorNumber, t.id, t.vendorID, t.cpuFamily, t.model, strings.Join(bugs, ",")) -} - -// enable turns on the CPU by writing 1 to /sys/devices/cpu/cpu{N}/online. -func (t thread) enable() error { - cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) - return ioutil.WriteFile(cpuPath, []byte{'1'}, 0644) -} - -// disable turns off the CPU by writing 0 to /sys/devices/cpu/cpu{N}/online. -func (t thread) disable() error { - cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) - return ioutil.WriteFile(cpuPath, []byte{'0'}, 0644) -} - -// isVulnerable checks if a CPU is vulnerable to mds. -func (t thread) isVulnerable() bool { - _, ok := t.bugs[mds] - return ok -} - -// isActive checks if a CPU is active from /sys/devices/system/cpu/cpu{N}/online -// If the file does not exist (ioutil returns in error), we assume the CPU is on. -func (t thread) isActive() bool { - cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) - data, err := ioutil.ReadFile(cpuPath) - if err != nil { - return true - } - return len(data) > 0 && data[0] != '0' -} - -// similarTo checks family/model/bugs fields for equality of two -// processors. -func (t thread) similarTo(other thread) bool { - if t.vendorID != other.vendorID { - return false - } - - if other.cpuFamily != t.cpuFamily { - return false - } - - if other.model != t.model { - return false - } - - if len(other.bugs) != len(t.bugs) { - return false - } - - for bug := range t.bugs { - if _, ok := other.bugs[bug]; !ok { - return false - } - } - return true -} - -// parseProcessor grabs the processor field from /proc/cpuinfo output. -func parseProcessor(data string) (int64, error) { - return parseIntegerResult(data, processorKey) -} - -// parseVendorID grabs the vendor_id field from /proc/cpuinfo output. -func parseVendorID(data string) (string, error) { - return parseRegex(data, vendorIDKey, `[\w\d]+`) -} - -// parseCPUFamily grabs the cpu family field from /proc/cpuinfo output. -func parseCPUFamily(data string) (int64, error) { - return parseIntegerResult(data, cpuFamilyKey) -} - -// parseModel grabs the model field from /proc/cpuinfo output. -func parseModel(data string) (int64, error) { - return parseIntegerResult(data, modelKey) -} - -// parsePhysicalID parses the physical id field. -func parsePhysicalID(data string) (int64, error) { - return parseIntegerResult(data, physicalIDKey) -} - -// parseCoreID parses the core id field. -func parseCoreID(data string) (int64, error) { - return parseIntegerResult(data, coreIDKey) -} - -// parseBugs grabs the bugs field from /proc/cpuinfo output. -func parseBugs(data string) (map[string]struct{}, error) { - result, err := parseRegex(data, bugsKey, `[\d\w\s]*`) - if err != nil { - return nil, err - } - bugs := strings.Split(result, " ") - ret := make(map[string]struct{}, len(bugs)) - for _, bug := range bugs { - ret[bug] = struct{}{} - } - return ret, nil -} - -// parseIntegerResult parses fields expecting an integer. -func parseIntegerResult(data, key string) (int64, error) { - result, err := parseRegex(data, key, `\d+`) - if err != nil { - return 0, err - } - return strconv.ParseInt(result, 0, 64) -} - -// buildRegex builds a regex for parsing each CPU field. -func buildRegex(key, match string) *regexp.Regexp { - reg := fmt.Sprintf(`(?m)^%s\s*:\s*(.*)$`, key) - return regexp.MustCompile(reg) -} - -// parseRegex parses data with key inserted into a standard regex template. -func parseRegex(data, key, match string) (string, error) { - r := buildRegex(key, match) - matches := r.FindStringSubmatch(data) - if len(matches) < 2 { - return "", fmt.Errorf("failed to match key %q: %q", key, data) - } - return matches[1], nil -} diff --git a/runsc/mitigate/cpu_test.go b/runsc/mitigate/cpu_test.go deleted file mode 100644 index 374333465..000000000 --- a/runsc/mitigate/cpu_test.go +++ /dev/null @@ -1,605 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package mitigate - -import ( - "fmt" - "io/ioutil" - "strings" - "testing" -) - -// mockCPU represents data from CPUs that will be mitigated. -type mockCPU struct { - name string - vendorID string - family int - model int - modelName string - bugs string - physicalCores int - cores int - threadsPerCore int -} - -var cascadeLake4 = mockCPU{ - name: "CascadeLake", - vendorID: "GenuineIntel", - family: 6, - model: 85, - modelName: "Intel(R) Xeon(R) CPU", - bugs: "spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa", - physicalCores: 1, - cores: 2, - threadsPerCore: 2, -} - -var haswell2 = mockCPU{ - name: "Haswell", - vendorID: "GenuineIntel", - family: 6, - model: 63, - modelName: "Intel(R) Xeon(R) CPU", - bugs: "cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs", - physicalCores: 1, - cores: 1, - threadsPerCore: 2, -} - -var haswell2core = mockCPU{ - name: "Haswell2Physical", - vendorID: "GenuineIntel", - family: 6, - model: 63, - modelName: "Intel(R) Xeon(R) CPU", - bugs: "cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs", - physicalCores: 2, - cores: 1, - threadsPerCore: 1, -} - -var amd8 = mockCPU{ - name: "AMD", - vendorID: "AuthenticAMD", - family: 23, - model: 49, - modelName: "AMD EPYC 7B12", - bugs: "sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass", - physicalCores: 4, - cores: 1, - threadsPerCore: 2, -} - -// makeCPUString makes a string formated like /proc/cpuinfo for each cpuTestCase -func (tc mockCPU) makeCPUString() string { - template := `processor : %d -vendor_id : %s -cpu family : %d -model : %d -model name : %s -physical id : %d -core id : %d -cpu cores : %d -bugs : %s -` - ret := `` - for i := 0; i < tc.physicalCores; i++ { - for j := 0; j < tc.cores; j++ { - for k := 0; k < tc.threadsPerCore; k++ { - processorNum := (i*tc.cores+j)*tc.threadsPerCore + k - ret += fmt.Sprintf(template, - processorNum, /*processor*/ - tc.vendorID, /*vendor_id*/ - tc.family, /*cpu family*/ - tc.model, /*model*/ - tc.modelName, /*model name*/ - i, /*physical id*/ - j, /*core id*/ - tc.cores*tc.physicalCores, /*cpu cores*/ - tc.bugs /*bugs*/) - } - } - } - return ret -} - -func (tc mockCPU) makeSysPossibleString() string { - max := tc.physicalCores * tc.cores * tc.threadsPerCore - if max == 1 { - return "0" - } - return fmt.Sprintf("0-%d", max-1) -} - -// TestMockCPUSet tests mock cpu test cases against the cpuSet functions. -func TestMockCPUSet(t *testing.T) { - for _, tc := range []struct { - testCase mockCPU - isVulnerable bool - }{ - { - testCase: amd8, - isVulnerable: false, - }, - { - testCase: haswell2, - isVulnerable: true, - }, - { - testCase: haswell2core, - isVulnerable: true, - }, - - { - testCase: cascadeLake4, - isVulnerable: true, - }, - } { - t.Run(tc.testCase.name, func(t *testing.T) { - data := tc.testCase.makeCPUString() - vulnerable := func(t thread) bool { - return t.isVulnerable() - } - set, err := newCPUSet([]byte(data), vulnerable) - if err != nil { - t.Fatalf("Failed to ") - } - remaining := set.getRemainingList() - // In the non-vulnerable case, no cores should be shutdown so all should remain. - want := tc.testCase.physicalCores * tc.testCase.cores * tc.testCase.threadsPerCore - if tc.isVulnerable { - want = tc.testCase.physicalCores * tc.testCase.cores - } - - if want != len(remaining) { - t.Fatalf("Failed to shutdown the correct number of cores: want: %d got: %d", want, len(remaining)) - } - - if !tc.isVulnerable { - return - } - - // If the set is vulnerable, we expect only 1 thread per hyperthread pair. - for _, r := range remaining { - if _, ok := set[r.id]; !ok { - t.Fatalf("Entry %+v not in map, there must be two entries in the same thread group.", r) - } - delete(set, r.id) - } - - possible := tc.testCase.makeSysPossibleString() - set, err = newCPUSetFromPossible([]byte(possible)) - if err != nil { - t.Fatalf("Failed to make cpuSet: %v", err) - } - - want = tc.testCase.physicalCores * tc.testCase.cores * tc.testCase.threadsPerCore - got := len(set.getRemainingList()) - if got != want { - t.Fatalf("Returned the wrong number of CPUs want: %d got: %d", want, got) - } - }) - } -} - -// TestGetCPU tests basic parsing of single CPU strings from reading -// /proc/cpuinfo. -func TestGetCPU(t *testing.T) { - data := `processor : 0 -vendor_id : GenuineIntel -cpu family : 6 -model : 85 -physical id: 0 -core id : 0 -bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa itlb_multihit -` - want := thread{ - processorNumber: 0, - vendorID: "GenuineIntel", - cpuFamily: 6, - model: 85, - id: cpuID{ - physicalID: 0, - coreID: 0, - }, - bugs: map[string]struct{}{ - "cpu_meltdown": struct{}{}, - "spectre_v1": struct{}{}, - "spectre_v2": struct{}{}, - "spec_store_bypass": struct{}{}, - "l1tf": struct{}{}, - "mds": struct{}{}, - "swapgs": struct{}{}, - "taa": struct{}{}, - "itlb_multihit": struct{}{}, - }, - } - - got, err := newThread(data) - if err != nil { - t.Fatalf("getCpu failed with error: %v", err) - } - - if !want.similarTo(got) { - t.Fatalf("Failed cpus not similar: got: %+v, want: %+v", got, want) - } - - if !got.isVulnerable() { - t.Fatalf("Failed: cpu should be vulnerable.") - } -} - -func TestInvalid(t *testing.T) { - result, err := getThreads(`something not a processor`) - if err == nil { - t.Fatalf("getCPU set didn't return an error: %+v", result) - } - - if !strings.Contains(err.Error(), "no cpus") { - t.Fatalf("Incorrect error returned: %v", err) - } -} - -// TestCPUSet tests getting the right number of CPUs from -// parsing full output of /proc/cpuinfo. -func TestCPUSet(t *testing.T) { - data := `processor : 0 -vendor_id : GenuineIntel -cpu family : 6 -model : 63 -model name : Intel(R) Xeon(R) CPU @ 2.30GHz -stepping : 0 -microcode : 0x1 -cpu MHz : 2299.998 -cache size : 46080 KB -physical id : 0 -siblings : 2 -core id : 0 -cpu cores : 1 -apicid : 0 -initial apicid : 0 -fpu : yes -fpu_exception : yes -cpuid level : 13 -wp : yes -flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities -bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs -bogomips : 4599.99 -clflush size : 64 -cache_alignment : 64 -address sizes : 46 bits physical, 48 bits virtual -power management: - -processor : 1 -vendor_id : GenuineIntel -cpu family : 6 -model : 63 -model name : Intel(R) Xeon(R) CPU @ 2.30GHz -stepping : 0 -microcode : 0x1 -cpu MHz : 2299.998 -cache size : 46080 KB -physical id : 0 -siblings : 2 -core id : 0 -cpu cores : 1 -apicid : 1 -initial apicid : 1 -fpu : yes -fpu_exception : yes -cpuid level : 13 -wp : yes -flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities -bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs -bogomips : 4599.99 -clflush size : 64 -cache_alignment : 64 -address sizes : 46 bits physical, 48 bits virtual -power management: -` - cpuSet, err := getThreads(data) - if err != nil { - t.Fatalf("getCPUSet failed: %v", err) - } - - wantCPULen := 2 - if len(cpuSet) != wantCPULen { - t.Fatalf("Num CPU mismatch: want: %d, got: %d", wantCPULen, len(cpuSet)) - } - - wantCPU := thread{ - vendorID: "GenuineIntel", - cpuFamily: 6, - model: 63, - bugs: map[string]struct{}{ - "cpu_meltdown": struct{}{}, - "spectre_v1": struct{}{}, - "spectre_v2": struct{}{}, - "spec_store_bypass": struct{}{}, - "l1tf": struct{}{}, - "mds": struct{}{}, - "swapgs": struct{}{}, - }, - } - - for _, c := range cpuSet { - if !wantCPU.similarTo(c) { - t.Fatalf("Failed cpus not equal: got: %+v, want: %+v", c, wantCPU) - } - } -} - -// TestReadFile is a smoke test for parsing methods. -func TestReadFile(t *testing.T) { - data, err := ioutil.ReadFile("/proc/cpuinfo") - if err != nil { - t.Fatalf("Failed to read cpuinfo: %v", err) - } - - vulnerable := func(t thread) bool { - return t.isVulnerable() - } - - set, err := newCPUSet(data, vulnerable) - if err != nil { - t.Fatalf("Failed to parse CPU data %v\n%s", err, data) - } - - if len(set) < 1 { - t.Fatalf("Failed to parse any CPUs: %d", len(set)) - } - - t.Log(set) -} - -// TestVulnerable tests if the isVulnerable method is correct -// among known CPUs in GCP. -func TestVulnerable(t *testing.T) { - const haswell = `processor : 0 -vendor_id : GenuineIntel -cpu family : 6 -model : 63 -model name : Intel(R) Xeon(R) CPU @ 2.30GHz -stepping : 0 -microcode : 0x1 -cpu MHz : 2299.998 -cache size : 46080 KB -physical id : 0 -siblings : 4 -core id : 0 -cpu cores : 2 -apicid : 0 -initial apicid : 0 -fpu : yes -fpu_exception : yes -cpuid level : 13 -wp : yes -flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities -bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs -bogomips : 4599.99 -clflush size : 64 -cache_alignment : 64 -address sizes : 46 bits physical, 48 bits virtual -power management:` - - const skylake = `processor : 0 -vendor_id : GenuineIntel -cpu family : 6 -model : 85 -model name : Intel(R) Xeon(R) CPU @ 2.00GHz -stepping : 3 -microcode : 0x1 -cpu MHz : 2000.180 -cache size : 39424 KB -physical id : 0 -siblings : 2 -core id : 0 -cpu cores : 1 -apicid : 0 -initial apicid : 0 -fpu : yes -fpu_exception : yes -cpuid level : 13 -wp : yes -flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities -bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa -bogomips : 4000.36 -clflush size : 64 -cache_alignment : 64 -address sizes : 46 bits physical, 48 bits virtual -power management:` - - const cascade = `processor : 0 -vendor_id : GenuineIntel -cpu family : 6 -model : 85 -model name : Intel(R) Xeon(R) CPU -stepping : 7 -microcode : 0x1 -cpu MHz : 2800.198 -cache size : 33792 KB -physical id : 0 -siblings : 2 -core id : 0 -cpu cores : 1 -apicid : 0 -initial apicid : 0 -fpu : yes -fpu_exception : yes -cpuid level : 13 -wp : yes -flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 - ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmu -lqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowpr -efetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid r -tm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves a -rat avx512_vnni md_clear arch_capabilities -bugs : spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa -bogomips : 5600.39 -clflush size : 64 -cache_alignment : 64 -address sizes : 46 bits physical, 48 bits virtual -power management:` - - const amd = `processor : 0 -vendor_id : AuthenticAMD -cpu family : 23 -model : 49 -model name : AMD EPYC 7B12 -stepping : 0 -microcode : 0x1000065 -cpu MHz : 2250.000 -cache size : 512 KB -physical id : 0 -siblings : 2 -core id : 0 -cpu cores : 1 -apicid : 0 -initial apicid : 0 -fpu : yes -fpu_exception : yes -cpuid level : 13 -wp : yes -flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid -bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass -bogomips : 4500.00 -TLB size : 3072 4K pages -clflush size : 64 -cache_alignment : 64 -address sizes : 48 bits physical, 48 bits virtual -power management:` - - for _, tc := range []struct { - name string - cpuString string - vulnerable bool - }{ - { - name: "haswell", - cpuString: haswell, - vulnerable: true, - }, { - name: "skylake", - cpuString: skylake, - vulnerable: true, - }, { - name: "amd", - cpuString: amd, - vulnerable: false, - }, - } { - t.Run(tc.name, func(t *testing.T) { - set, err := getThreads(tc.cpuString) - if err != nil { - t.Fatalf("Failed to getCPUSet:%v\n %s", err, tc.cpuString) - } - - if len(set) < 1 { - t.Fatalf("Returned empty cpu set: %v", set) - } - - for _, c := range set { - got := func() bool { - return c.isVulnerable() - }() - - if got != tc.vulnerable { - t.Fatalf("Mismatch vulnerable for cpu %+s: got %t want: %t", tc.name, tc.vulnerable, got) - } - } - }) - } -} - -func TestReverse(t *testing.T) { - const noParse = "-1-" - for _, tc := range []struct { - name string - output string - wantErr error - wantCount int - }{ - { - name: "base", - output: "0-7", - wantErr: nil, - wantCount: 8, - }, - { - name: "huge", - output: "0-111", - wantErr: nil, - wantCount: 112, - }, - { - name: "not zero", - output: "50-53", - wantErr: nil, - wantCount: 4, - }, - { - name: "small", - output: "0", - wantErr: nil, - wantCount: 1, - }, - { - name: "invalid order", - output: "10-6", - wantErr: fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", 10, 6), - }, - { - name: "no parse", - output: noParse, - wantErr: fmt.Errorf(`mismatch regex from /sys/devices/system/cpu/possible: %q`, noParse), - }, - } { - t.Run(tc.name, func(t *testing.T) { - threads, err := getThreadsFromPossible([]byte(tc.output)) - - switch { - case tc.wantErr == nil: - if err != nil { - t.Fatalf("Wanted nil err, got: %v", err) - } - case err == nil: - t.Fatalf("Want error: %v got: %v", tc.wantErr, err) - default: - if tc.wantErr.Error() != err.Error() { - t.Fatalf("Want error: %v got error: %v", tc.wantErr, err) - } - } - - if len(threads) != tc.wantCount { - t.Fatalf("Want count: %d got: %d", tc.wantCount, len(threads)) - } - }) - } -} - -func TestReverseSmoke(t *testing.T) { - data, err := ioutil.ReadFile(allPossibleCPUs) - if err != nil { - t.Fatalf("Failed to read from possible: %v", err) - } - threads, err := getThreadsFromPossible(data) - if err != nil { - t.Fatalf("Could not parse possible output: %v", err) - } - - if len(threads) <= 0 { - t.Fatalf("Didn't get any CPU cores: %d", len(threads)) - } -} diff --git a/runsc/mitigate/mitigate.go b/runsc/mitigate/mitigate.go index 91de623e3..24f67414c 100644 --- a/runsc/mitigate/mitigate.go +++ b/runsc/mitigate/mitigate.go @@ -14,121 +14,440 @@ // Package mitigate provides libraries for the mitigate command. The // mitigate command mitigates side channel attacks such as MDS. Mitigate -// shuts down CPUs via /sys/devices/system/cpu/cpu{N}/online. In addition, -// the mitigate also handles computing available CPU in kubernetes kube_config -// files. +// shuts down CPUs via /sys/devices/system/cpu/cpu{N}/online. package mitigate import ( "fmt" "io/ioutil" - - "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/runsc/flag" + "os" + "regexp" + "sort" + "strconv" + "strings" ) const ( - cpuInfo = "/proc/cpuinfo" - allPossibleCPUs = "/sys/devices/system/cpu/possible" + // mds is the only bug we care about. + mds = "mds" + + // Constants for parsing /proc/cpuinfo. + processorKey = "processor" + vendorIDKey = "vendor_id" + cpuFamilyKey = "cpu family" + modelKey = "model" + physicalIDKey = "physical id" + coreIDKey = "core id" + bugsKey = "bugs" + + // Path to shutdown a CPU. + cpuOnlineTemplate = "/sys/devices/system/cpu/cpu%d/online" ) -// Mitigate handles high level mitigate operations provided to runsc. -type Mitigate struct { - dryRun bool // Run the command without changing the underlying system. - reverse bool // Reverse mitigate by turning on all CPU cores. - other mitigate // Struct holds extra mitigate logic. - path string // path to read for each operation (e.g. /proc/cpuinfo). +// CPUSet contains a map of all CPUs on the system, mapped +// by Physical ID and CoreIDs. threads with the same +// Core and Physical ID are Hyperthread pairs. +type CPUSet map[threadID]*ThreadGroup + +// NewCPUSet creates a CPUSet from data read from /proc/cpuinfo. +func NewCPUSet(data []byte, vulnerable func(Thread) bool) (CPUSet, error) { + processors, err := getThreads(string(data)) + if err != nil { + return nil, err + } + + set := make(CPUSet) + for _, p := range processors { + // Each ID is of the form physicalID:coreID. Hyperthread pairs + // have identical physical and core IDs. We need to match + // Hyperthread pairs so that we can shutdown all but one per + // pair. + core, ok := set[p.id] + if !ok { + core = &ThreadGroup{} + set[p.id] = core + } + core.isVulnerable = core.isVulnerable || vulnerable(p) + core.threads = append(core.threads, p) + } + + // We need to make sure we shutdown the lowest number processor per + // thread group. + for _, tg := range set { + sort.Slice(tg.threads, func(i, j int) bool { + return tg.threads[i].processorNumber < tg.threads[j].processorNumber + }) + } + return set, nil } -// Usage implments Usage for cmd.Mitigate. -func (m Mitigate) Usage() string { - usageString := `mitigate [flags] +// NewCPUSetFromPossible makes a cpuSet data read from +// /sys/devices/system/cpu/possible. This is used in enable operations +// where the caller simply wants to enable all CPUS. +func NewCPUSetFromPossible(data []byte) (CPUSet, error) { + threads, err := GetThreadsFromPossible(data) + if err != nil { + return nil, err + } + + // We don't care if a CPU is vulnerable or not, we just + // want to return a list of all CPUs on the host. + set := CPUSet{ + threads[0].id: &ThreadGroup{ + threads: threads, + isVulnerable: false, + }, + } + return set, nil +} -Mitigate mitigates a system to the "MDS" vulnerability by implementing a manual shutdown of SMT. The command checks /proc/cpuinfo for cpus having the MDS vulnerability, and if found, shutdown all but one CPU per hyperthread pair via /sys/devices/system/cpu/cpu{N}/online. CPUs can be restored by writing "2" to each file in /sys/devices/system/cpu/cpu{N}/online or performing a system reboot. +// String implements the String method for CPUSet. +func (c CPUSet) String() string { + ret := "" + for _, tg := range c { + ret += fmt.Sprintf("%s\n", tg) + } + return ret +} -The command can be reversed with --reverse, which reads the total CPUs from /sys/devices/system/cpu/possible and enables all with /sys/devices/system/cpu/cpu{N}/online. -` - return usageString + m.other.usage() +// GetRemainingList returns the list of threads that will remain active +// after mitigation. +func (c CPUSet) GetRemainingList() []Thread { + threads := make([]Thread, 0, len(c)) + for _, core := range c { + // If we're vulnerable, take only one thread from the pair. + if core.isVulnerable { + threads = append(threads, core.threads[0]) + continue + } + // Otherwise don't shutdown anything. + threads = append(threads, core.threads...) + } + return threads } -// SetFlags sets flags for the command Mitigate. -func (m Mitigate) SetFlags(f *flag.FlagSet) { - f.BoolVar(&m.dryRun, "dryrun", false, "run the command without changing system") - f.BoolVar(&m.reverse, "reverse", false, "reverse mitigate by enabling all CPUs") - m.other.setFlags(f) - m.path = cpuInfo - if m.reverse { - m.path = allPossibleCPUs +// GetShutdownList returns the list of threads that will be shutdown on +// mitigation. +func (c CPUSet) GetShutdownList() []Thread { + threads := make([]Thread, 0) + for _, core := range c { + // Only if we're vulnerable do shutdown anything. In this case, + // shutdown all but the first entry. + if core.isVulnerable && len(core.threads) > 1 { + threads = append(threads, core.threads[1:]...) + } } + return threads } -// Execute executes the Mitigate command. -func (m Mitigate) Execute() error { - data, err := ioutil.ReadFile(m.path) - if err != nil { - return fmt.Errorf("failed to read %s: %v", m.path, err) +// ThreadGroup represents Hyperthread pairs on the same physical/core ID. +type ThreadGroup struct { + threads []Thread + isVulnerable bool +} + +// String implements the String method for threadGroup. +func (c ThreadGroup) String() string { + ret := fmt.Sprintf("ThreadGroup:\nIsVulnerable: %t\n", c.isVulnerable) + for _, processor := range c.threads { + ret += fmt.Sprintf("%s\n", processor) } + return ret +} - if m.reverse { - err := m.doReverse(data) +// getThreads returns threads structs from reading /proc/cpuinfo. +func getThreads(data string) ([]Thread, error) { + // Each processor entry should start with the + // processor key. Find the beginings of each. + r := buildRegex(processorKey, `\d+`) + indices := r.FindAllStringIndex(data, -1) + if len(indices) < 1 { + return nil, fmt.Errorf("no cpus found for: %q", data) + } + + // Add the ending index for last entry. + indices = append(indices, []int{len(data), -1}) + + // Valid cpus are now defined by strings in between + // indexes (e.g. data[index[i], index[i+1]]). + // There should be len(indicies) - 1 CPUs + // since the last index is the end of the string. + cpus := make([]Thread, 0, len(indices)) + // Find each string that represents a CPU. These begin "processor". + for i := 1; i < len(indices); i++ { + start := indices[i-1][0] + end := indices[i][0] + // Parse the CPU entry, which should be between start/end. + c, err := newThread(data[start:end]) if err != nil { - return fmt.Errorf("reverse operation failed: %v", err) + return nil, err } - return nil + cpus = append(cpus, c) + } + return cpus, nil +} + +// GetThreadsFromPossible makes threads from data read from /sys/devices/system/cpu/possible. +func GetThreadsFromPossible(data []byte) ([]Thread, error) { + possibleRegex := regexp.MustCompile(`(?m)^(\d+)(-(\d+))?$`) + matches := possibleRegex.FindStringSubmatch(string(data)) + if len(matches) != 4 { + return nil, fmt.Errorf("mismatch regex from possible: %q", string(data)) + } + + // If matches[3] is empty, we only have one cpu entry. + if matches[3] == "" { + matches[3] = matches[1] } - set, err := m.doMitigate(data) + begin, err := strconv.ParseInt(matches[1], 10, 64) if err != nil { - return fmt.Errorf("mitigate operation failed: %v", err) + return nil, fmt.Errorf("failed to parse begin: %v", err) } - return m.other.execute(set, m.dryRun) + end, err := strconv.ParseInt(matches[3], 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse end: %v", err) + } + if begin > end || begin < 0 || end < 0 { + return nil, fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", begin, end) + } + + ret := make([]Thread, 0, end-begin) + for i := begin; i <= end; i++ { + ret = append(ret, Thread{ + processorNumber: i, + id: threadID{ + physicalID: 0, // we don't care about id for enable ops. + coreID: 0, + }, + }) + } + + return ret, nil +} + +// threadID for each thread is defined by the physical and +// core IDs. If equal, two threads are Hyperthread pairs. +type threadID struct { + physicalID int64 + coreID int64 } -func (m Mitigate) doMitigate(data []byte) (cpuSet, error) { - set, err := newCPUSet(data, m.other.vulnerable) +// Thread represents pertinent info about a single hyperthread in a pair. +type Thread struct { + processorNumber int64 // the processor number of this CPU. + vendorID string // the vendorID of CPU (e.g. AuthenticAMD). + cpuFamily int64 // CPU family number (e.g. 6 for CascadeLake/Skylake). + model int64 // CPU model number (e.g. 85 for CascadeLake/Skylake). + id threadID // id for this thread + bugs map[string]struct{} // map of vulnerabilities parsed from the 'bugs' field. +} + +// newThread parses a CPU from a single cpu entry from /proc/cpuinfo. +func newThread(data string) (Thread, error) { + empty := Thread{} + processor, err := parseProcessor(data) if err != nil { - return nil, err + return empty, err } - log.Infof("Mitigate found the following CPUs...") - log.Infof("%s", set) + vendorID, err := parseVendorID(data) + if err != nil { + return empty, err + } - disableList := set.getShutdownList() - log.Infof("Disabling threads on thread pairs.") - for _, t := range disableList { - log.Infof("Disable thread: %s", t) - if m.dryRun { - continue - } - if err := t.disable(); err != nil { - return nil, fmt.Errorf("error disabling thread: %s err: %v", t, err) - } + cpuFamily, err := parseCPUFamily(data) + if err != nil { + return empty, err } - log.Infof("Shutdown successful.") - return set, nil + + model, err := parseModel(data) + if err != nil { + return empty, err + } + + physicalID, err := parsePhysicalID(data) + if err != nil { + return empty, err + } + + coreID, err := parseCoreID(data) + if err != nil { + return empty, err + } + + bugs, err := parseBugs(data) + if err != nil { + return empty, err + } + + return Thread{ + processorNumber: processor, + vendorID: vendorID, + cpuFamily: cpuFamily, + model: model, + id: threadID{ + physicalID: physicalID, + coreID: coreID, + }, + bugs: bugs, + }, nil +} + +// String implements the String method for thread. +func (t Thread) String() string { + template := `CPU: %d +CPU ID: %+v +Vendor: %s +Family/Model: %d/%d +Bugs: %s +` + bugs := make([]string, 0) + for bug := range t.bugs { + bugs = append(bugs, bug) + } + + return fmt.Sprintf(template, t.processorNumber, t.id, t.vendorID, t.cpuFamily, t.model, strings.Join(bugs, ",")) +} + +// Enable turns on the CPU by writing 1 to /sys/devices/cpu/cpu{N}/online. +func (t Thread) Enable() error { + // Linux ensures that "cpu0" is always online. + if t.processorNumber == 0 { + return nil + } + cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) + f, err := os.OpenFile(cpuPath, os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + return fmt.Errorf("failed to open file %s: %v", cpuPath, err) + } + if _, err = f.Write([]byte{'1'}); err != nil { + return fmt.Errorf("failed to write '1' to %s: %v", cpuPath, err) + } + return nil +} + +// Disable turns off the CPU by writing 0 to /sys/devices/cpu/cpu{N}/online. +func (t Thread) Disable() error { + // The core labeled "cpu0" can never be taken offline via this method. + // Linux will return EPERM if the user even creates a file at the /sys + // path above. + if t.processorNumber == 0 { + return fmt.Errorf("invalid shutdown operation: cpu0 cannot be disabled") + } + cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) + return ioutil.WriteFile(cpuPath, []byte{'0'}, 0644) } -func (m Mitigate) doReverse(data []byte) error { - set, err := newCPUSetFromPossible(data) +// IsVulnerable checks if a CPU is vulnerable to mds. +func (t Thread) IsVulnerable() bool { + _, ok := t.bugs[mds] + return ok +} + +// isActive checks if a CPU is active from /sys/devices/system/cpu/cpu{N}/online +// If the file does not exist (ioutil returns in error), we assume the CPU is on. +func (t Thread) isActive() bool { + cpuPath := fmt.Sprintf(cpuOnlineTemplate, t.processorNumber) + data, err := ioutil.ReadFile(cpuPath) if err != nil { - return err + return true } + return len(data) > 0 && data[0] != '0' +} - log.Infof("Reverse mitigate found the following CPUs...") - log.Infof("%s", set) +// SimilarTo checks family/model/bugs fields for equality of two +// processors. +func (t Thread) SimilarTo(other Thread) bool { + if t.vendorID != other.vendorID { + return false + } - enableList := set.getRemainingList() + if other.cpuFamily != t.cpuFamily { + return false + } - log.Infof("Enabling all CPUs...") - for _, t := range enableList { - log.Infof("Enabling thread: %s", t) - if m.dryRun { - continue - } - if err := t.enable(); err != nil { - return fmt.Errorf("error enabling thread: %s err: %v", t, err) + if other.model != t.model { + return false + } + + if len(other.bugs) != len(t.bugs) { + return false + } + + for bug := range t.bugs { + if _, ok := other.bugs[bug]; !ok { + return false } } - log.Infof("Enable successful.") - return nil + return true +} + +// parseProcessor grabs the processor field from /proc/cpuinfo output. +func parseProcessor(data string) (int64, error) { + return parseIntegerResult(data, processorKey) +} + +// parseVendorID grabs the vendor_id field from /proc/cpuinfo output. +func parseVendorID(data string) (string, error) { + return parseRegex(data, vendorIDKey, `[\w\d]+`) +} + +// parseCPUFamily grabs the cpu family field from /proc/cpuinfo output. +func parseCPUFamily(data string) (int64, error) { + return parseIntegerResult(data, cpuFamilyKey) +} + +// parseModel grabs the model field from /proc/cpuinfo output. +func parseModel(data string) (int64, error) { + return parseIntegerResult(data, modelKey) +} + +// parsePhysicalID parses the physical id field. +func parsePhysicalID(data string) (int64, error) { + return parseIntegerResult(data, physicalIDKey) +} + +// parseCoreID parses the core id field. +func parseCoreID(data string) (int64, error) { + return parseIntegerResult(data, coreIDKey) +} + +// parseBugs grabs the bugs field from /proc/cpuinfo output. +func parseBugs(data string) (map[string]struct{}, error) { + result, err := parseRegex(data, bugsKey, `[\d\w\s]*`) + if err != nil { + return nil, err + } + bugs := strings.Split(result, " ") + ret := make(map[string]struct{}, len(bugs)) + for _, bug := range bugs { + ret[bug] = struct{}{} + } + return ret, nil +} + +// parseIntegerResult parses fields expecting an integer. +func parseIntegerResult(data, key string) (int64, error) { + result, err := parseRegex(data, key, `\d+`) + if err != nil { + return 0, err + } + return strconv.ParseInt(result, 0, 64) +} + +// buildRegex builds a regex for parsing each CPU field. +func buildRegex(key, match string) *regexp.Regexp { + reg := fmt.Sprintf(`(?m)^%s\s*:\s*(.*)$`, key) + return regexp.MustCompile(reg) +} + +// parseRegex parses data with key inserted into a standard regex template. +func parseRegex(data, key, match string) (string, error) { + r := buildRegex(key, match) + matches := r.FindStringSubmatch(data) + if len(matches) < 2 { + return "", fmt.Errorf("failed to match key %q: %q", key, data) + } + return matches[1], nil } diff --git a/runsc/mitigate/mitigate_test.go b/runsc/mitigate/mitigate_test.go index b3a9a9b18..fbd8eb886 100644 --- a/runsc/mitigate/mitigate_test.go +++ b/runsc/mitigate/mitigate_test.go @@ -17,138 +17,519 @@ package mitigate import ( "fmt" "io/ioutil" - "os" "strings" "testing" + + "gvisor.dev/gvisor/runsc/mitigate/mock" ) -type executeTestCase struct { - name string - mitigateData string - mitigateError error - reverseData string - reverseError error +// TestMockCPUSet tests mock cpu test cases against the cpuSet functions. +func TestMockCPUSet(t *testing.T) { + for _, tc := range []struct { + testCase mock.CPU + isVulnerable bool + }{ + { + testCase: mock.AMD8, + isVulnerable: false, + }, + { + testCase: mock.Haswell2, + isVulnerable: true, + }, + { + testCase: mock.Haswell2core, + isVulnerable: true, + }, + { + testCase: mock.CascadeLake2, + isVulnerable: true, + }, + { + testCase: mock.CascadeLake4, + isVulnerable: true, + }, + } { + t.Run(tc.testCase.Name, func(t *testing.T) { + data := tc.testCase.MakeCPUString() + vulnerable := func(t Thread) bool { + return t.IsVulnerable() + } + set, err := NewCPUSet([]byte(data), vulnerable) + if err != nil { + t.Fatalf("Failed to create cpuSet: %v", err) + } + + for _, tg := range set { + if err := checkSorted(tg.threads); err != nil { + t.Fatalf("Failed to sort cpuSet: %v", err) + } + } + + remaining := set.GetRemainingList() + // In the non-vulnerable case, no cores should be shutdown so all should remain. + want := tc.testCase.PhysicalCores * tc.testCase.Cores * tc.testCase.ThreadsPerCore + if tc.isVulnerable { + want = tc.testCase.PhysicalCores * tc.testCase.Cores + } + + if want != len(remaining) { + t.Fatalf("Failed to shutdown the correct number of cores: want: %d got: %d", want, len(remaining)) + } + + if !tc.isVulnerable { + return + } + + // If the set is vulnerable, we expect only 1 thread per hyperthread pair. + for _, r := range remaining { + if _, ok := set[r.id]; !ok { + t.Fatalf("Entry %+v not in map, there must be two entries in the same thread group.", r) + } + delete(set, r.id) + } + + possible := tc.testCase.MakeSysPossibleString() + set, err = NewCPUSetFromPossible([]byte(possible)) + if err != nil { + t.Fatalf("Failed to make cpuSet: %v", err) + } + + want = tc.testCase.PhysicalCores * tc.testCase.Cores * tc.testCase.ThreadsPerCore + got := len(set.GetRemainingList()) + if got != want { + t.Fatalf("Returned the wrong number of CPUs want: %d got: %d", want, got) + } + }) + } } -func TestExecute(t *testing.T) { +// TestGetCPU tests basic parsing of single CPU strings from reading +// /proc/cpuinfo. +func TestGetCPU(t *testing.T) { + data := `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 85 +physical id: 0 +core id : 0 +bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa itlb_multihit +` + want := Thread{ + processorNumber: 0, + vendorID: "GenuineIntel", + cpuFamily: 6, + model: 85, + id: threadID{ + physicalID: 0, + coreID: 0, + }, + bugs: map[string]struct{}{ + "cpu_meltdown": struct{}{}, + "spectre_v1": struct{}{}, + "spectre_v2": struct{}{}, + "spec_store_bypass": struct{}{}, + "l1tf": struct{}{}, + "mds": struct{}{}, + "swapgs": struct{}{}, + "taa": struct{}{}, + "itlb_multihit": struct{}{}, + }, + } - partial := `processor : 1 -vendor_id : AuthenticAMD -cpu family : 23 -model : 49 -model name : AMD EPYC 7B12 -physical id : 0 -bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass + got, err := newThread(data) + if err != nil { + t.Fatalf("getCpu failed with error: %v", err) + } + + if !want.SimilarTo(got) { + t.Fatalf("Failed cpus not similar: got: %+v, want: %+v", got, want) + } + + if !got.IsVulnerable() { + t.Fatalf("Failed: cpu should be vulnerable.") + } +} + +func TestInvalid(t *testing.T) { + result, err := getThreads(`something not a processor`) + if err == nil { + t.Fatalf("getCPU set didn't return an error: %+v", result) + } + + if !strings.Contains(err.Error(), "no cpus") { + t.Fatalf("Incorrect error returned: %v", err) + } +} + +// TestCPUSet tests getting the right number of CPUs from +// parsing full output of /proc/cpuinfo. +func TestCPUSet(t *testing.T) { + data := `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 63 +model name : Intel(R) Xeon(R) CPU @ 2.30GHz +stepping : 0 +microcode : 0x1 +cpu MHz : 2299.998 +cache size : 46080 KB +physical id : 0 +siblings : 2 +core id : 0 +cpu cores : 1 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities +bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs +bogomips : 4599.99 +clflush size : 64 +cache_alignment : 64 +address sizes : 46 bits physical, 48 bits virtual +power management: + +processor : 1 +vendor_id : GenuineIntel +cpu family : 6 +model : 63 +model name : Intel(R) Xeon(R) CPU @ 2.30GHz +stepping : 0 +microcode : 0x1 +cpu MHz : 2299.998 +cache size : 46080 KB +physical id : 0 +siblings : 2 +core id : 0 +cpu cores : 1 +apicid : 1 +initial apicid : 1 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities +bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs +bogomips : 4599.99 +clflush size : 64 +cache_alignment : 64 +address sizes : 46 bits physical, 48 bits virtual power management: ` + cpuSet, err := getThreads(data) + if err != nil { + t.Fatalf("getCPUSet failed: %v", err) + } - for _, tc := range []executeTestCase{ - { - name: "CascadeLake4", - mitigateData: cascadeLake4.makeCPUString(), - reverseData: cascadeLake4.makeSysPossibleString(), - }, - { - name: "Empty", - mitigateData: "", - mitigateError: fmt.Errorf(`mitigate operation failed: no cpus found for: ""`), - reverseData: "", - reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from %s: ""`, allPossibleCPUs), + wantCPULen := 2 + if len(cpuSet) != wantCPULen { + t.Fatalf("Num CPU mismatch: want: %d, got: %d", wantCPULen, len(cpuSet)) + } + + wantCPU := Thread{ + vendorID: "GenuineIntel", + cpuFamily: 6, + model: 63, + bugs: map[string]struct{}{ + "cpu_meltdown": struct{}{}, + "spectre_v1": struct{}{}, + "spectre_v2": struct{}{}, + "spec_store_bypass": struct{}{}, + "l1tf": struct{}{}, + "mds": struct{}{}, + "swapgs": struct{}{}, }, - { - name: "Partial", - mitigateData: `processor : 0 + } + + for _, c := range cpuSet { + if !wantCPU.SimilarTo(c) { + t.Fatalf("Failed cpus not equal: got: %+v, want: %+v", c, wantCPU) + } + } +} + +// TestReadFile is a smoke test for parsing methods. +func TestReadFile(t *testing.T) { + data, err := ioutil.ReadFile("/proc/cpuinfo") + if err != nil { + t.Fatalf("Failed to read cpuinfo: %v", err) + } + + vulnerable := func(t Thread) bool { + return t.IsVulnerable() + } + + set, err := NewCPUSet(data, vulnerable) + if err != nil { + t.Fatalf("Failed to parse CPU data %v\n%s", err, data) + } + + for _, tg := range set { + if err := checkSorted(tg.threads); err != nil { + t.Fatalf("Failed to sort cpuSet: %v", err) + } + } + + if len(set) < 1 { + t.Fatalf("Failed to parse any CPUs: %d", len(set)) + } + + t.Log(set) +} + +// TestVulnerable tests if the isVulnerable method is correct +// among known CPUs in GCP. +func TestVulnerable(t *testing.T) { + const haswell = `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 63 +model name : Intel(R) Xeon(R) CPU @ 2.30GHz +stepping : 0 +microcode : 0x1 +cpu MHz : 2299.998 +cache size : 46080 KB +physical id : 0 +siblings : 4 +core id : 0 +cpu cores : 2 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities +bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs +bogomips : 4599.99 +clflush size : 64 +cache_alignment : 64 +address sizes : 46 bits physical, 48 bits virtual +power management:` + + const skylake = `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 85 +model name : Intel(R) Xeon(R) CPU @ 2.00GHz +stepping : 3 +microcode : 0x1 +cpu MHz : 2000.180 +cache size : 39424 KB +physical id : 0 +siblings : 2 +core id : 0 +cpu cores : 1 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities +bugs : cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa +bogomips : 4000.36 +clflush size : 64 +cache_alignment : 64 +address sizes : 46 bits physical, 48 bits virtual +power management:` + + const cascade = `processor : 0 +vendor_id : GenuineIntel +cpu family : 6 +model : 85 +model name : Intel(R) Xeon(R) CPU +stepping : 7 +microcode : 0x1 +cpu MHz : 2800.198 +cache size : 33792 KB +physical id : 0 +siblings : 2 +core id : 0 +cpu cores : 1 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 + ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmu +lqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowpr +efetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid r +tm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves a +rat avx512_vnni md_clear arch_capabilities +bugs : spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa +bogomips : 5600.39 +clflush size : 64 +cache_alignment : 64 +address sizes : 46 bits physical, 48 bits virtual +power management:` + + const amd = `processor : 0 vendor_id : AuthenticAMD cpu family : 23 model : 49 model name : AMD EPYC 7B12 +stepping : 0 +microcode : 0x1000065 +cpu MHz : 2250.000 +cache size : 512 KB physical id : 0 +siblings : 2 core id : 0 cpu cores : 1 +apicid : 0 +initial apicid : 0 +fpu : yes +fpu_exception : yes +cpuid level : 13 +wp : yes +flags : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw topoext ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr arat npt nrip_save umip rdpid bugs : sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass -power management: +bogomips : 4500.00 +TLB size : 3072 4K pages +clflush size : 64 +cache_alignment : 64 +address sizes : 48 bits physical, 48 bits virtual +power management:` -` + partial, - mitigateError: fmt.Errorf(`mitigate operation failed: failed to match key "core id": %q`, partial), - reverseData: "1-", - reverseError: fmt.Errorf(`reverse operation failed: mismatch regex from %s: %q`, allPossibleCPUs, "1-"), + for _, tc := range []struct { + name string + cpuString string + vulnerable bool + }{ + { + name: "haswell", + cpuString: haswell, + vulnerable: true, + }, { + name: "skylake", + cpuString: skylake, + vulnerable: true, + }, { + name: "amd", + cpuString: amd, + vulnerable: false, }, } { - doExecuteTest(t, Mitigate{}, tc) + t.Run(tc.name, func(t *testing.T) { + set, err := getThreads(tc.cpuString) + if err != nil { + t.Fatalf("Failed to getCPUSet:%v\n %s", err, tc.cpuString) + } + + if len(set) < 1 { + t.Fatalf("Returned empty cpu set: %v", set) + } + + for _, c := range set { + got := func() bool { + return c.IsVulnerable() + }() + + if got != tc.vulnerable { + t.Fatalf("Mismatch vulnerable for cpu %+s: got %t want: %t", tc.name, tc.vulnerable, got) + } + } + }) } } -func TestExecuteSmoke(t *testing.T) { - smokeMitigate, err := ioutil.ReadFile(cpuInfo) +func TestReverse(t *testing.T) { + const noParse = "-1-" + for _, tc := range []struct { + name string + output string + wantErr error + wantCount int + }{ + { + name: "base", + output: "0-7", + wantErr: nil, + wantCount: 8, + }, + { + name: "huge", + output: "0-111", + wantErr: nil, + wantCount: 112, + }, + { + name: "not zero", + output: "50-53", + wantErr: nil, + wantCount: 4, + }, + { + name: "small", + output: "0", + wantErr: nil, + wantCount: 1, + }, + { + name: "invalid order", + output: "10-6", + wantErr: fmt.Errorf("invalid cpu bounds from possible: begin: %d end: %d", 10, 6), + }, + { + name: "no parse", + output: noParse, + wantErr: fmt.Errorf(`mismatch regex from possible: %q`, noParse), + }, + } { + t.Run(tc.name, func(t *testing.T) { + threads, err := GetThreadsFromPossible([]byte(tc.output)) + + switch { + case tc.wantErr == nil: + if err != nil { + t.Fatalf("Wanted nil err, got: %v", err) + } + case err == nil: + t.Fatalf("Want error: %v got: %v", tc.wantErr, err) + default: + if tc.wantErr.Error() != err.Error() { + t.Fatalf("Want error: %v got error: %v", tc.wantErr, err) + } + } + + if len(threads) != tc.wantCount { + t.Fatalf("Want count: %d got: %d", tc.wantCount, len(threads)) + } + }) + } +} + +func TestReverseSmoke(t *testing.T) { + data, err := ioutil.ReadFile("/sys/devices/system/cpu/possible") if err != nil { - t.Fatalf("Failed to read %s: %v", cpuInfo, err) + t.Fatalf("Failed to read from possible: %v", err) } - smokeReverse, err := ioutil.ReadFile(allPossibleCPUs) + threads, err := GetThreadsFromPossible(data) if err != nil { - t.Fatalf("Failed to read %s: %v", allPossibleCPUs, err) + t.Fatalf("Could not parse possible output: %v", err) } - doExecuteTest(t, Mitigate{}, executeTestCase{ - name: "SmokeTest", - mitigateData: string(smokeMitigate), - reverseData: string(smokeReverse), - }) + if len(threads) <= 0 { + t.Fatalf("Didn't get any CPU cores: %d", len(threads)) + } } -// doExecuteTest runs Execute with the mitigate operation and reverse operation. -func doExecuteTest(t *testing.T, m Mitigate, tc executeTestCase) { - t.Run("Mitigate"+tc.name, func(t *testing.T) { - m.dryRun = true - file, err := ioutil.TempFile("", "outfile.txt") - if err != nil { - t.Fatalf("Failed to create tmpfile: %v", err) - } - defer os.Remove(file.Name()) - - if _, err := file.WriteString(tc.mitigateData); err != nil { - t.Fatalf("Failed to write to file: %v", err) - } - - m.path = file.Name() - - got := m.Execute() - if err = checkErr(tc.mitigateError, got); err != nil { - t.Fatalf("Mitigate error mismatch: %v", err) - } - }) - t.Run("Reverse"+tc.name, func(t *testing.T) { - m.dryRun = true - m.reverse = true - - file, err := ioutil.TempFile("", "outfile.txt") - if err != nil { - t.Fatalf("Failed to create tmpfile: %v", err) - } - defer os.Remove(file.Name()) - - if _, err := file.WriteString(tc.reverseData); err != nil { - t.Fatalf("Failed to write to file: %v", err) - } - - m.path = file.Name() - got := m.Execute() - if err = checkErr(tc.reverseError, got); err != nil { - t.Fatalf("Mitigate error mismatch: %v", err) +func checkSorted(threads []Thread) error { + if len(threads) < 2 { + return nil + } + last := threads[0].processorNumber + for _, t := range threads[1:] { + if last >= t.processorNumber { + return fmt.Errorf("threads out of order: thread %d before %d", t.processorNumber, last) } - }) - -} - -// checkErr checks error for equality. -func checkErr(want, got error) error { - switch { - case want == nil && got == nil: - case want != nil && got == nil: - fallthrough - case want == nil && got != nil: - fallthrough - case want.Error() != strings.Trim(got.Error(), " "): - return fmt.Errorf("got: %v want: %v", got, want) + last = t.processorNumber } return nil } diff --git a/runsc/mitigate/mock/BUILD b/runsc/mitigate/mock/BUILD new file mode 100644 index 000000000..5019ff9ee --- /dev/null +++ b/runsc/mitigate/mock/BUILD @@ -0,0 +1,11 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "mock", + srcs = ["mock.go"], + visibility = [ + "//runsc:__subpackages__", + ], +) diff --git a/runsc/mitigate/mock/mock.go b/runsc/mitigate/mock/mock.go new file mode 100644 index 000000000..2db718cb9 --- /dev/null +++ b/runsc/mitigate/mock/mock.go @@ -0,0 +1,141 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mock contains mock CPUs for mitigate tests. +package mock + +import "fmt" + +// CPU represents data from CPUs that will be mitigated. +type CPU struct { + Name string + VendorID string + Family int + Model int + ModelName string + Bugs string + PhysicalCores int + Cores int + ThreadsPerCore int +} + +// CascadeLake2 is a two core Intel CascadeLake machine. +var CascadeLake2 = CPU{ + Name: "CascadeLake", + VendorID: "GenuineIntel", + Family: 6, + Model: 85, + ModelName: "Intel(R) Xeon(R) CPU", + Bugs: "spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa", + PhysicalCores: 1, + Cores: 1, + ThreadsPerCore: 2, +} + +// CascadeLake4 is a four core Intel CascadeLake machine. +var CascadeLake4 = CPU{ + Name: "CascadeLake", + VendorID: "GenuineIntel", + Family: 6, + Model: 85, + ModelName: "Intel(R) Xeon(R) CPU", + Bugs: "spectre_v1 spectre_v2 spec_store_bypass mds swapgs taa", + PhysicalCores: 1, + Cores: 2, + ThreadsPerCore: 2, +} + +// Haswell2 is a two core Intel Haswell machine. +var Haswell2 = CPU{ + Name: "Haswell", + VendorID: "GenuineIntel", + Family: 6, + Model: 63, + ModelName: "Intel(R) Xeon(R) CPU", + Bugs: "cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs", + PhysicalCores: 1, + Cores: 1, + ThreadsPerCore: 2, +} + +// Haswell2core is a 2 core Intel Haswell machine with no hyperthread pairs. +var Haswell2core = CPU{ + Name: "Haswell2Physical", + VendorID: "GenuineIntel", + Family: 6, + Model: 63, + ModelName: "Intel(R) Xeon(R) CPU", + Bugs: "cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs", + PhysicalCores: 2, + Cores: 1, + ThreadsPerCore: 1, +} + +// AMD8 is an eight core AMD machine. +var AMD8 = CPU{ + Name: "AMD", + VendorID: "AuthenticAMD", + Family: 23, + Model: 49, + ModelName: "AMD EPYC 7B12", + Bugs: "sysret_ss_attrs spectre_v1 spectre_v2 spec_store_bypass", + PhysicalCores: 4, + Cores: 1, + ThreadsPerCore: 2, +} + +// MakeCPUString makes a string formated like /proc/cpuinfo for each cpuTestCase +func (tc CPU) MakeCPUString() string { + template := `processor : %d +vendor_id : %s +cpu family : %d +model : %d +model name : %s +physical id : %d +core id : %d +cpu cores : %d +bugs : %s + +` + + ret := `` + for i := 0; i < tc.PhysicalCores; i++ { + for j := 0; j < tc.Cores; j++ { + for k := 0; k < tc.ThreadsPerCore; k++ { + processorNum := (i*tc.Cores+j)*tc.ThreadsPerCore + k + ret += fmt.Sprintf(template, + processorNum, /*processor*/ + tc.VendorID, /*vendor_id*/ + tc.Family, /*cpu family*/ + tc.Model, /*model*/ + tc.ModelName, /*model name*/ + i, /*physical id*/ + j, /*core id*/ + tc.Cores*tc.PhysicalCores, /*cpu cores*/ + tc.Bugs, /*bugs*/ + ) + } + } + } + return ret +} + +// MakeSysPossibleString makes a string representing a the contents of /sys/devices/system/cpu/possible. +func (tc CPU) MakeSysPossibleString() string { + max := tc.PhysicalCores * tc.Cores * tc.ThreadsPerCore + if max == 1 { + return "0" + } + return fmt.Sprintf("0-%d", max-1) +} diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index 9e429f7d5..f69558021 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -21,7 +21,6 @@ import ( "path/filepath" "runtime" "strconv" - "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/vishvananda/netlink" @@ -102,11 +101,11 @@ func joinNetNS(nsPath string) (func(), error) { // isRootNS determines whether we are running in the root net namespace. // /proc/sys/net/core/rmem_default only exists in root network namespace. func isRootNS() (bool, error) { - err := syscall.Access("/proc/sys/net/core/rmem_default", syscall.F_OK) + err := unix.Access("/proc/sys/net/core/rmem_default", unix.F_OK) switch err { case nil: return true, nil - case syscall.ENOENT: + case unix.ENOENT: return false, nil default: return false, fmt.Errorf("failed to access /proc/sys/net/core/rmem_default: %v", err) @@ -270,17 +269,17 @@ type socketEntry struct { func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (*socketEntry, error) { // Create the socket. const protocol = 0x0300 // htons(ETH_P_ALL) - fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) + fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, protocol) if err != nil { return nil, fmt.Errorf("unable to create raw socket: %v", err) } deviceFile := os.NewFile(uintptr(fd), "raw-device-fd") // Bind to the appropriate device. - ll := syscall.SockaddrLinklayer{ + ll := unix.SockaddrLinklayer{ Protocol: protocol, Ifindex: iface.Index, } - if err := syscall.Bind(fd, &ll); err != nil { + if err := unix.Bind(fd, &ll); err != nil { return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) } @@ -291,7 +290,7 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) ( return nil, fmt.Errorf("getting GSO for interface %q: %v", iface.Name, err) } if gso { - if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { + if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) } gsoMaxSize = ifaceLink.Attrs().GSOMaxSize @@ -307,18 +306,18 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) ( // incurring packet drops. const bufSize = 4 << 20 // 4MB. - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, bufSize); err != nil { - syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize) - sz, _ := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF) + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, bufSize); err != nil { + unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, bufSize) + sz, _ := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF) if sz < bufSize { log.Warningf("Failed to increase rcv buffer to %d on SOCK_RAW on %s. Current buffer %d: %v", bufSize, iface.Name, sz, err) } } - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, bufSize); err != nil { - syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize) - sz, _ := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF) + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, bufSize); err != nil { + unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUF, bufSize) + sz, _ := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUF) if sz < bufSize { log.Warningf("Failed to increase snd buffer to %d on SOCK_RAW on %s. Curent buffer %d: %v", bufSize, iface.Name, sz, err) } diff --git a/runsc/sandbox/network_unsafe.go b/runsc/sandbox/network_unsafe.go index 2a2a0fb7e..1b808a8a0 100644 --- a/runsc/sandbox/network_unsafe.go +++ b/runsc/sandbox/network_unsafe.go @@ -15,7 +15,6 @@ package sandbox import ( - "syscall" "unsafe" "golang.org/x/sys/unix" @@ -48,7 +47,7 @@ func isGSOEnabled(fd int, intf string) (bool, error) { ifrData: &val, } - if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), unix.SIOCETHTOOL, uintptr(unsafe.Pointer(&ifr))); err != 0 { + if _, _, err := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.SIOCETHTOOL, uintptr(unsafe.Pointer(&ifr))); err != 0 { return false, err } diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index 7fe65c7ba..450f92645 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ -30,6 +30,7 @@ import ( "github.com/cenkalti/backoff" specs "github.com/opencontainers/runtime-spec/specs-go" "github.com/syndtr/gocapability/capability" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/cleanup" "gvisor.dev/gvisor/pkg/control/client" "gvisor.dev/gvisor/pkg/control/server" @@ -83,7 +84,7 @@ type Sandbox struct { // child==true and the sandbox was waited on. This field allows for multiple // threads to wait on sandbox and get the exit code, since Linux will return // WaitStatus to one of the waiters only. - status syscall.WaitStatus + status unix.WaitStatus } // Args is used to configure a new sandbox. @@ -383,7 +384,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn binPath := specutils.ExePath cmd := exec.Command(binPath, conf.ToFlags()...) - cmd.SysProcAttr = &syscall.SysProcAttr{} + cmd.SysProcAttr = &unix.SysProcAttr{} // Open the log files to pass to the sandbox as FDs. // @@ -739,7 +740,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn if args.Attached { // Kill sandbox if parent process exits in attached mode. - cmd.SysProcAttr.Pdeathsig = syscall.SIGKILL + cmd.SysProcAttr.Pdeathsig = unix.SIGKILL // Tells boot that any process it creates must have pdeathsig set. cmd.Args = append(cmd.Args, "--attached") } @@ -762,7 +763,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn // // NOTE: The error message is checked because error types are lost over // rpc calls. - if strings.Contains(err.Error(), syscall.EACCES.Error()) { + if strings.Contains(err.Error(), unix.EACCES.Error()) { if permsErr := checkBinaryPermissions(conf); permsErr != nil { return fmt.Errorf("%v: %v", err, permsErr) } @@ -782,7 +783,7 @@ func (s *Sandbox) createSandboxProcess(conf *config.Config, args *Args, startSyn } // Wait waits for the containerized process to exit, and returns its WaitStatus. -func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) { +func (s *Sandbox) Wait(cid string) (unix.WaitStatus, error) { log.Debugf("Waiting for container %q in sandbox %q", cid, s.ID) if conn, err := s.sandboxConnect(); err != nil { @@ -790,14 +791,14 @@ func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) { // There is nothing we can do for subcontainers. For the init container, we // can try to get the sandbox exit code. if !s.IsRootContainer(cid) { - return syscall.WaitStatus(0), err + return unix.WaitStatus(0), err } log.Warningf("Wait on container %q failed: %v. Will try waiting on the sandbox process instead.", cid, err) } else { defer conn.Close() // Try the Wait RPC to the sandbox. - var ws syscall.WaitStatus + var ws unix.WaitStatus err = conn.Call(boot.ContainerWait, &cid, &ws) if err == nil { // It worked! @@ -805,7 +806,7 @@ func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) { } // See comment above. if !s.IsRootContainer(cid) { - return syscall.WaitStatus(0), err + return unix.WaitStatus(0), err } // The sandbox may have exited after we connected, but before @@ -817,10 +818,10 @@ func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) { // The best we can do is ask Linux what the sandbox exit status was, since in // most cases that will be the same as the container exit status. if err := s.waitForStopped(); err != nil { - return syscall.WaitStatus(0), err + return unix.WaitStatus(0), err } if !s.child { - return syscall.WaitStatus(0), fmt.Errorf("sandbox no longer running and its exit status is unavailable") + return unix.WaitStatus(0), fmt.Errorf("sandbox no longer running and its exit status is unavailable") } s.statusMu.Lock() @@ -830,9 +831,9 @@ func (s *Sandbox) Wait(cid string) (syscall.WaitStatus, error) { // WaitPID waits for process 'pid' in the container's sandbox and returns its // WaitStatus. -func (s *Sandbox) WaitPID(cid string, pid int32) (syscall.WaitStatus, error) { +func (s *Sandbox) WaitPID(cid string, pid int32) (unix.WaitStatus, error) { log.Debugf("Waiting for PID %d in sandbox %q", pid, s.ID) - var ws syscall.WaitStatus + var ws unix.WaitStatus conn, err := s.sandboxConnect() if err != nil { return ws, err @@ -861,7 +862,7 @@ func (s *Sandbox) destroy() error { log.Debugf("Destroy sandbox %q", s.ID) if s.Pid != 0 { log.Debugf("Killing sandbox %q", s.ID) - if err := syscall.Kill(s.Pid, syscall.SIGKILL); err != nil && err != syscall.ESRCH { + if err := unix.Kill(s.Pid, unix.SIGKILL); err != nil && err != unix.ESRCH { return fmt.Errorf("killing sandbox %q PID %q: %v", s.ID, s.Pid, err) } if err := s.waitForStopped(); err != nil { @@ -875,7 +876,7 @@ func (s *Sandbox) destroy() error { // SignalContainer sends the signal to a container in the sandbox. If all is // true and signal is SIGKILL, then waits for all processes to exit before // returning. -func (s *Sandbox) SignalContainer(cid string, sig syscall.Signal, all bool) error { +func (s *Sandbox) SignalContainer(cid string, sig unix.Signal, all bool) error { log.Debugf("Signal sandbox %q", s.ID) conn, err := s.sandboxConnect() if err != nil { @@ -903,7 +904,7 @@ func (s *Sandbox) SignalContainer(cid string, sig syscall.Signal, all bool) erro // fgProcess is true, then the signal is sent to the foreground process group // in the same session that PID belongs to. This is only valid if the process // is attached to a host TTY. -func (s *Sandbox) SignalProcess(cid string, pid int32, sig syscall.Signal, fgProcess bool) error { +func (s *Sandbox) SignalProcess(cid string, pid int32, sig unix.Signal, fgProcess bool) error { log.Debugf("Signal sandbox %q", s.ID) conn, err := s.sandboxConnect() if err != nil { @@ -984,7 +985,7 @@ func (s *Sandbox) Resume(cid string) error { func (s *Sandbox) IsRunning() bool { if s.Pid != 0 { // Send a signal 0 to the sandbox process. - if err := syscall.Kill(s.Pid, 0); err == nil { + if err := unix.Kill(s.Pid, 0); err == nil { // Succeeded, process is running. return true } @@ -1147,7 +1148,7 @@ func (s *Sandbox) waitForStopped() error { } // The sandbox process is a child of the current process, // so we can wait it and collect its zombie. - wpid, err := syscall.Wait4(int(s.Pid), &s.status, syscall.WNOHANG, nil) + wpid, err := unix.Wait4(int(s.Pid), &s.status, unix.WNOHANG, nil) if err != nil { return fmt.Errorf("error waiting the sandbox process: %v", err) } diff --git a/runsc/specutils/fs.go b/runsc/specutils/fs.go index 138aa4dd1..b62504a8c 100644 --- a/runsc/specutils/fs.go +++ b/runsc/specutils/fs.go @@ -18,9 +18,9 @@ import ( "fmt" "math/bits" "path" - "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" ) type mapping struct { @@ -31,48 +31,48 @@ type mapping struct { // optionsMap maps mount propagation-related OCI filesystem options to mount(2) // syscall flags. var optionsMap = map[string]mapping{ - "acl": {set: true, val: syscall.MS_POSIXACL}, - "async": {set: false, val: syscall.MS_SYNCHRONOUS}, - "atime": {set: false, val: syscall.MS_NOATIME}, - "bind": {set: true, val: syscall.MS_BIND}, + "acl": {set: true, val: unix.MS_POSIXACL}, + "async": {set: false, val: unix.MS_SYNCHRONOUS}, + "atime": {set: false, val: unix.MS_NOATIME}, + "bind": {set: true, val: unix.MS_BIND}, "defaults": {set: true, val: 0}, - "dev": {set: false, val: syscall.MS_NODEV}, - "diratime": {set: false, val: syscall.MS_NODIRATIME}, - "dirsync": {set: true, val: syscall.MS_DIRSYNC}, - "exec": {set: false, val: syscall.MS_NOEXEC}, - "noexec": {set: true, val: syscall.MS_NOEXEC}, - "iversion": {set: true, val: syscall.MS_I_VERSION}, - "loud": {set: false, val: syscall.MS_SILENT}, - "mand": {set: true, val: syscall.MS_MANDLOCK}, - "noacl": {set: false, val: syscall.MS_POSIXACL}, - "noatime": {set: true, val: syscall.MS_NOATIME}, - "nodev": {set: true, val: syscall.MS_NODEV}, - "nodiratime": {set: true, val: syscall.MS_NODIRATIME}, - "noiversion": {set: false, val: syscall.MS_I_VERSION}, - "nomand": {set: false, val: syscall.MS_MANDLOCK}, - "norelatime": {set: false, val: syscall.MS_RELATIME}, - "nostrictatime": {set: false, val: syscall.MS_STRICTATIME}, - "nosuid": {set: true, val: syscall.MS_NOSUID}, - "rbind": {set: true, val: syscall.MS_BIND | syscall.MS_REC}, - "relatime": {set: true, val: syscall.MS_RELATIME}, - "remount": {set: true, val: syscall.MS_REMOUNT}, - "ro": {set: true, val: syscall.MS_RDONLY}, - "rw": {set: false, val: syscall.MS_RDONLY}, - "silent": {set: true, val: syscall.MS_SILENT}, - "strictatime": {set: true, val: syscall.MS_STRICTATIME}, - "suid": {set: false, val: syscall.MS_NOSUID}, - "sync": {set: true, val: syscall.MS_SYNCHRONOUS}, + "dev": {set: false, val: unix.MS_NODEV}, + "diratime": {set: false, val: unix.MS_NODIRATIME}, + "dirsync": {set: true, val: unix.MS_DIRSYNC}, + "exec": {set: false, val: unix.MS_NOEXEC}, + "noexec": {set: true, val: unix.MS_NOEXEC}, + "iversion": {set: true, val: unix.MS_I_VERSION}, + "loud": {set: false, val: unix.MS_SILENT}, + "mand": {set: true, val: unix.MS_MANDLOCK}, + "noacl": {set: false, val: unix.MS_POSIXACL}, + "noatime": {set: true, val: unix.MS_NOATIME}, + "nodev": {set: true, val: unix.MS_NODEV}, + "nodiratime": {set: true, val: unix.MS_NODIRATIME}, + "noiversion": {set: false, val: unix.MS_I_VERSION}, + "nomand": {set: false, val: unix.MS_MANDLOCK}, + "norelatime": {set: false, val: unix.MS_RELATIME}, + "nostrictatime": {set: false, val: unix.MS_STRICTATIME}, + "nosuid": {set: true, val: unix.MS_NOSUID}, + "rbind": {set: true, val: unix.MS_BIND | unix.MS_REC}, + "relatime": {set: true, val: unix.MS_RELATIME}, + "remount": {set: true, val: unix.MS_REMOUNT}, + "ro": {set: true, val: unix.MS_RDONLY}, + "rw": {set: false, val: unix.MS_RDONLY}, + "silent": {set: true, val: unix.MS_SILENT}, + "strictatime": {set: true, val: unix.MS_STRICTATIME}, + "suid": {set: false, val: unix.MS_NOSUID}, + "sync": {set: true, val: unix.MS_SYNCHRONOUS}, } // propOptionsMap is similar to optionsMap, but it lists propagation options // that cannot be used together with other flags. var propOptionsMap = map[string]mapping{ - "private": {set: true, val: syscall.MS_PRIVATE}, - "rprivate": {set: true, val: syscall.MS_PRIVATE | syscall.MS_REC}, - "slave": {set: true, val: syscall.MS_SLAVE}, - "rslave": {set: true, val: syscall.MS_SLAVE | syscall.MS_REC}, - "unbindable": {set: true, val: syscall.MS_UNBINDABLE}, - "runbindable": {set: true, val: syscall.MS_UNBINDABLE | syscall.MS_REC}, + "private": {set: true, val: unix.MS_PRIVATE}, + "rprivate": {set: true, val: unix.MS_PRIVATE | unix.MS_REC}, + "slave": {set: true, val: unix.MS_SLAVE}, + "rslave": {set: true, val: unix.MS_SLAVE | unix.MS_REC}, + "unbindable": {set: true, val: unix.MS_UNBINDABLE}, + "runbindable": {set: true, val: unix.MS_UNBINDABLE | unix.MS_REC}, } // invalidOptions list options not allowed. @@ -139,7 +139,7 @@ func ValidateMountOptions(opts []string) error { // correct. func validateRootfsPropagation(opt string) error { flags := PropOptionsToFlags([]string{opt}) - if flags&(syscall.MS_SLAVE|syscall.MS_PRIVATE) == 0 { + if flags&(unix.MS_SLAVE|unix.MS_PRIVATE) == 0 { return fmt.Errorf("root mount propagation option must specify private or slave: %q", opt) } return validatePropagation(opt) @@ -147,7 +147,7 @@ func validateRootfsPropagation(opt string) error { func validatePropagation(opt string) error { flags := PropOptionsToFlags([]string{opt}) - exclusive := flags & (syscall.MS_SLAVE | syscall.MS_PRIVATE | syscall.MS_SHARED | syscall.MS_UNBINDABLE) + exclusive := flags & (unix.MS_SLAVE | unix.MS_PRIVATE | unix.MS_SHARED | unix.MS_UNBINDABLE) if bits.OnesCount32(exclusive) > 1 { return fmt.Errorf("mount propagation options are mutually exclusive: %q", opt) } diff --git a/runsc/specutils/namespace.go b/runsc/specutils/namespace.go index 23001d67c..69d7ba5c4 100644 --- a/runsc/specutils/namespace.go +++ b/runsc/specutils/namespace.go @@ -109,7 +109,7 @@ func FilterNS(filter []specs.LinuxNamespaceType, s *specs.Spec) []specs.LinuxNam // setNS sets the namespace of the given type. It must be called with // OSThreadLocked. func setNS(fd, nsType uintptr) error { - if _, _, err := syscall.RawSyscall(unix.SYS_SETNS, fd, nsType, 0); err != 0 { + if _, _, err := unix.RawSyscall(unix.SYS_SETNS, fd, nsType, 0); err != 0 { return err } return nil @@ -158,7 +158,7 @@ func StartInNS(cmd *exec.Cmd, nss []specs.LinuxNamespace) error { defer runtime.UnlockOSThread() if cmd.SysProcAttr == nil { - cmd.SysProcAttr = &syscall.SysProcAttr{} + cmd.SysProcAttr = &unix.SysProcAttr{} } for _, ns := range nss { @@ -185,7 +185,7 @@ func SetUIDGIDMappings(cmd *exec.Cmd, s *specs.Spec) { return } if cmd.SysProcAttr == nil { - cmd.SysProcAttr = &syscall.SysProcAttr{} + cmd.SysProcAttr = &unix.SysProcAttr{} } for _, idMap := range s.Linux.UIDMappings { log.Infof("Mapping host uid %d to container uid %d (size=%d)", idMap.HostID, idMap.ContainerID, idMap.Size) @@ -241,8 +241,8 @@ func MaybeRunAsRoot() error { cmd := exec.Command("/proc/self/exe", os.Args[1:]...) - cmd.SysProcAttr = &syscall.SysProcAttr{ - Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS, + cmd.SysProcAttr = &unix.SysProcAttr{ + Cloneflags: unix.CLONE_NEWUSER | unix.CLONE_NEWNS, // Set current user/group as root inside the namespace. Since we may not // have CAP_SETUID/CAP_SETGID, just map root to the current user/group. UidMappings: []syscall.SysProcIDMap{ @@ -255,7 +255,7 @@ func MaybeRunAsRoot() error { GidMappingsEnableSetgroups: false, // Make sure child is killed when the parent terminates. - Pdeathsig: syscall.SIGKILL, + Pdeathsig: unix.SIGKILL, } cmd.Env = os.Environ() diff --git a/runsc/specutils/seccomp/BUILD b/runsc/specutils/seccomp/BUILD index 3520f2d6d..e9e647d82 100644 --- a/runsc/specutils/seccomp/BUILD +++ b/runsc/specutils/seccomp/BUILD @@ -18,6 +18,7 @@ go_library( "//pkg/sentry/kernel", "//pkg/sentry/syscalls/linux", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", ], ) @@ -30,5 +31,6 @@ go_test( "//pkg/binary", "//pkg/bpf", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/runsc/specutils/seccomp/seccomp.go b/runsc/specutils/seccomp/seccomp.go index 5932f7a41..0ef7a4d54 100644 --- a/runsc/specutils/seccomp/seccomp.go +++ b/runsc/specutils/seccomp/seccomp.go @@ -18,9 +18,9 @@ package seccomp import ( "fmt" - "syscall" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bpf" "gvisor.dev/gvisor/pkg/log" @@ -33,9 +33,9 @@ var ( killThreadAction = linux.SECCOMP_RET_KILL_THREAD trapAction = linux.SECCOMP_RET_TRAP // runc always returns EPERM as the errorcode for SECCOMP_RET_ERRNO - errnoAction = linux.SECCOMP_RET_ERRNO.WithReturnCode(uint16(syscall.EPERM)) + errnoAction = linux.SECCOMP_RET_ERRNO.WithReturnCode(uint16(unix.EPERM)) // runc always returns EPERM as the errorcode for SECCOMP_RET_TRACE - traceAction = linux.SECCOMP_RET_TRACE.WithReturnCode(uint16(syscall.EPERM)) + traceAction = linux.SECCOMP_RET_TRACE.WithReturnCode(uint16(unix.EPERM)) allowAction = linux.SECCOMP_RET_ALLOW ) diff --git a/runsc/specutils/seccomp/seccomp_test.go b/runsc/specutils/seccomp/seccomp_test.go index 850c237ba..11a6c8daa 100644 --- a/runsc/specutils/seccomp/seccomp_test.go +++ b/runsc/specutils/seccomp/seccomp_test.go @@ -16,10 +16,10 @@ package seccomp import ( "fmt" - "syscall" "testing" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/bpf" ) @@ -184,7 +184,7 @@ var ( Args: []specs.LinuxSeccompArg{ { Index: 0, - Value: syscall.CLONE_FS, + Value: unix.CLONE_FS, Op: specs.OpEqualTo, }, }, @@ -192,7 +192,7 @@ var ( }, }, }, - input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}), + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{unix.CLONE_FS}), expected: uint32(errnoAction), }, { @@ -207,12 +207,12 @@ var ( Args: []specs.LinuxSeccompArg{ { Index: 0, - Value: syscall.CLONE_FS, + Value: unix.CLONE_FS, Op: specs.OpEqualTo, }, { Index: 0, - Value: syscall.CLONE_VM, + Value: unix.CLONE_VM, Op: specs.OpEqualTo, }, }, @@ -220,7 +220,7 @@ var ( }, }, }, - input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}), + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{unix.CLONE_FS}), expected: uint32(errnoAction), }, { @@ -235,12 +235,12 @@ var ( Args: []specs.LinuxSeccompArg{ { Index: 1, - Value: syscall.SOL_SOCKET, + Value: unix.SOL_SOCKET, Op: specs.OpEqualTo, }, { Index: 2, - Value: syscall.SO_PEERCRED, + Value: unix.SO_PEERCRED, Op: specs.OpEqualTo, }, }, @@ -248,7 +248,7 @@ var ( }, }, }, - input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, syscall.SOL_SOCKET, syscall.SO_PEERCRED}), + input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, unix.SOL_SOCKET, unix.SO_PEERCRED}), expected: uint32(errnoAction), }, { @@ -263,12 +263,12 @@ var ( Args: []specs.LinuxSeccompArg{ { Index: 1, - Value: syscall.SOL_SOCKET, + Value: unix.SOL_SOCKET, Op: specs.OpEqualTo, }, { Index: 2, - Value: syscall.SO_PEERCRED, + Value: unix.SO_PEERCRED, Op: specs.OpEqualTo, }, }, @@ -276,7 +276,7 @@ var ( }, }, }, - input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, syscall.SOL_SOCKET}), + input: testInput(nativeArchAuditNo, "getsockopt", &[6]uint64{0, unix.SOL_SOCKET}), expected: uint32(allowAction), }, { @@ -291,7 +291,7 @@ var ( Args: []specs.LinuxSeccompArg{ { Index: 0, - Value: syscall.CLONE_FS, + Value: unix.CLONE_FS, Op: specs.OpEqualTo, }, }, @@ -299,7 +299,7 @@ var ( }, }, }, - input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_VM}), + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{unix.CLONE_VM}), expected: uint32(allowAction), }, { @@ -314,8 +314,8 @@ var ( Args: []specs.LinuxSeccompArg{ { Index: 0, - Value: syscall.CLONE_FS, - ValueTwo: syscall.CLONE_FS, + Value: unix.CLONE_FS, + ValueTwo: unix.CLONE_FS, Op: specs.OpMaskedEqual, }, }, @@ -323,7 +323,7 @@ var ( }, }, }, - input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS | syscall.CLONE_VM}), + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{unix.CLONE_FS | unix.CLONE_VM}), expected: uint32(errnoAction), }, { @@ -338,8 +338,8 @@ var ( Args: []specs.LinuxSeccompArg{ { Index: 0, - Value: syscall.CLONE_FS | syscall.CLONE_VM, - ValueTwo: syscall.CLONE_FS | syscall.CLONE_VM, + Value: unix.CLONE_FS | unix.CLONE_VM, + ValueTwo: unix.CLONE_FS | unix.CLONE_VM, Op: specs.OpMaskedEqual, }, }, @@ -347,7 +347,7 @@ var ( }, }, }, - input: testInput(nativeArchAuditNo, "clone", &[6]uint64{syscall.CLONE_FS}), + input: testInput(nativeArchAuditNo, "clone", &[6]uint64{unix.CLONE_FS}), expected: uint32(allowAction), }, { diff --git a/runsc/specutils/specutils.go b/runsc/specutils/specutils.go index ea55bbc7d..5ba38bfe4 100644 --- a/runsc/specutils/specutils.go +++ b/runsc/specutils/specutils.go @@ -26,12 +26,12 @@ import ( "path/filepath" "strconv" "strings" - "syscall" "time" "github.com/cenkalti/backoff" "github.com/mohae/deepcopy" specs "github.com/opencontainers/runtime-spec/specs-go" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/log" @@ -375,9 +375,9 @@ func WaitForReady(pid int, timeout time.Duration, ready func() (bool, error)) er // Check if the process is still running. // If the process is alive, child is 0 because of the NOHANG option. // If the process has terminated, child equals the process id. - var ws syscall.WaitStatus - var ru syscall.Rusage - child, err := syscall.Wait4(pid, &ws, syscall.WNOHANG, &ru) + var ws unix.WaitStatus + var ru unix.Rusage + child, err := unix.Wait4(pid, &ws, unix.WNOHANG, &ru) if err != nil { return backoff.Permanent(fmt.Errorf("error waiting for process: %v", err)) } else if child == pid { @@ -437,7 +437,7 @@ func Mount(src, dst, typ string, flags uint32) error { return fmt.Errorf("mkdir(%q) failed: %v", parent, err) } // Create the destination file if it does not exist. - f, err := os.OpenFile(dst, syscall.O_CREAT, 0777) + f, err := os.OpenFile(dst, unix.O_CREAT, 0777) if err != nil { return fmt.Errorf("open(%q) failed: %v", dst, err) } @@ -445,7 +445,7 @@ func Mount(src, dst, typ string, flags uint32) error { } // Do the mount. - if err := syscall.Mount(src, dst, typ, uintptr(flags), ""); err != nil { + if err := unix.Mount(src, dst, typ, uintptr(flags), ""); err != nil { return fmt.Errorf("mount(%q, %q, %d) failed: %v", src, dst, flags, err) } return nil @@ -466,7 +466,7 @@ func ContainsStr(strs []string, str string) bool { func RetryEintr(f func() (uintptr, uintptr, error)) (uintptr, uintptr, error) { for { r1, r2, err := f() - if err != syscall.EINTR { + if err != unix.EINTR { return r1, r2, err } } diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD index b4f967441..c94caab60 100644 --- a/test/benchmarks/fs/BUILD +++ b/test/benchmarks/fs/BUILD @@ -11,6 +11,7 @@ benchmark_test( "//pkg/test/dockerutil", "//test/benchmarks/harness", "//test/benchmarks/tools", + "@com_github_docker_docker//api/types/mount:go_default_library", ], ) diff --git a/test/benchmarks/fs/bazel_test.go b/test/benchmarks/fs/bazel_test.go index 8baeff0db..7ced963f6 100644 --- a/test/benchmarks/fs/bazel_test.go +++ b/test/benchmarks/fs/bazel_test.go @@ -25,6 +25,13 @@ import ( "gvisor.dev/gvisor/test/benchmarks/tools" ) +// Dimensions here are clean/dirty cache (do or don't drop caches) +// and if the mount on which we are compiling is a tmpfs/bind mount. +type benchmark struct { + clearCache bool // clearCache drops caches before running. + fstype string // type of filesystem to use. +} + // Note: CleanCache versions of this test require running with root permissions. func BenchmarkBuildABSL(b *testing.B) { runBuildBenchmark(b, "benchmarks/absl", "/abseil-cpp", "absl/base/...") @@ -45,17 +52,18 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { } defer machine.CleanUp() - // Dimensions here are clean/dirty cache (do or don't drop caches) - // and if the mount on which we are compiling is a tmpfs/bind mount. - benchmarks := []struct { - clearCache bool // clearCache drops caches before running. - tmpfs bool // tmpfs will run compilation on a tmpfs. - }{ - {clearCache: true, tmpfs: false}, - {clearCache: false, tmpfs: false}, - {clearCache: true, tmpfs: true}, - {clearCache: false, tmpfs: true}, + benchmarks := make([]benchmark, 0, 6) + for _, filesys := range []string{harness.BindFS, harness.TmpFS, harness.RootFS} { + benchmarks = append(benchmarks, benchmark{ + clearCache: true, + fstype: filesys, + }) + benchmarks = append(benchmarks, benchmark{ + clearCache: false, + fstype: filesys, + }) } + for _, bm := range benchmarks { pageCache := tools.Parameter{ Name: "page_cache", @@ -67,10 +75,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { filesystem := tools.Parameter{ Name: "filesystem", - Value: "bind", - } - if bm.tmpfs { - filesystem.Value = "tmpfs" + Value: bm.fstype, } name, err := tools.ParametersToName(pageCache, filesystem) if err != nil { @@ -83,21 +88,25 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { container := machine.GetContainer(ctx, b) defer container.CleanUp(ctx) + mts, prefix, cleanup, err := harness.MakeMount(machine, bm.fstype) + if err != nil { + b.Fatalf("Failed to make mount: %v", err) + } + defer cleanup() + + runOpts := dockerutil.RunOpts{ + Image: image, + Mounts: mts, + } + // Start a container and sleep. - if err := container.Spawn(ctx, dockerutil.RunOpts{ - Image: image, - }, "sleep", fmt.Sprintf("%d", 1000000)); err != nil { + if err := container.Spawn(ctx, runOpts, "sleep", fmt.Sprintf("%d", 1000000)); err != nil { b.Fatalf("run failed with: %v", err) } - // If we are running on a tmpfs, copy to /tmp which is a tmpfs. - prefix := "" - if bm.tmpfs { - if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, - "cp", "-r", workdir, "/tmp/."); err != nil { - b.Fatalf("failed to copy directory: %v (%s)", err, out) - } - prefix = "/tmp" + if out, err := container.Exec(ctx, dockerutil.ExecOpts{}, + "cp", "-rf", workdir, prefix+"/."); err != nil { + b.Fatalf("failed to copy directory: %v (%s)", err, out) } b.ResetTimer() @@ -118,7 +127,7 @@ func runBuildBenchmark(b *testing.B, image, workdir, target string) { WorkDir: prefix + workdir, }, "bazel", "build", "-c", "opt", target) if err != nil { - b.Fatalf("build failed with: %v", err) + b.Fatalf("build failed with: %v logs: %s", err, got) } b.StopTimer() diff --git a/test/benchmarks/fs/fio_test.go b/test/benchmarks/fs/fio_test.go index cc2d1cbbc..f783a2b33 100644 --- a/test/benchmarks/fs/fio_test.go +++ b/test/benchmarks/fs/fio_test.go @@ -21,7 +21,6 @@ import ( "strings" "testing" - "github.com/docker/docker/api/types/mount" "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/test/benchmarks/harness" "gvisor.dev/gvisor/test/benchmarks/tools" @@ -70,7 +69,7 @@ func BenchmarkFio(b *testing.B) { } defer machine.CleanUp() - for _, fsType := range []mount.Type{mount.TypeBind, mount.TypeTmpfs} { + for _, fsType := range []string{harness.BindFS, harness.TmpFS, harness.RootFS} { for _, tc := range testCases { operation := tools.Parameter{ Name: "operation", @@ -82,7 +81,7 @@ func BenchmarkFio(b *testing.B) { } filesystem := tools.Parameter{ Name: "filesystem", - Value: string(fsType), + Value: fsType, } name, err := tools.ParametersToName(operation, blockSize, filesystem) if err != nil { @@ -95,13 +94,7 @@ func BenchmarkFio(b *testing.B) { container := machine.GetContainer(ctx, b) defer container.CleanUp(ctx) - // Directory and filename inside container where fio will read/write. - outdir := "/data" - outfile := filepath.Join(outdir, "test.txt") - - // Make the required mount and grab a cleanup for bind mounts - // as they are backed by a temp directory (mktemp). - mnt, mountCleanup, err := makeMount(machine, fsType, outdir) + mnts, outdir, mountCleanup, err := harness.MakeMount(machine, fsType) if err != nil { b.Fatalf("failed to make mount: %v", err) } @@ -109,12 +102,9 @@ func BenchmarkFio(b *testing.B) { // Start the container with the mount. if err := container.Spawn( - ctx, - dockerutil.RunOpts{ - Image: "benchmarks/fio", - Mounts: []mount.Mount{ - mnt, - }, + ctx, dockerutil.RunOpts{ + Image: "benchmarks/fio", + Mounts: mnts, }, // Sleep on the order of b.N. "sleep", fmt.Sprintf("%d", 1000*b.N), @@ -122,6 +112,9 @@ func BenchmarkFio(b *testing.B) { b.Fatalf("failed to start fio container with: %v", err) } + // Directory and filename inside container where fio will read/write. + outfile := filepath.Join(outdir, "test.txt") + // For reads, we need a file to read so make one inside the container. if strings.Contains(tc.Test, "read") { fallocateCmd := fmt.Sprintf("fallocate -l %dK %s", tc.Size, outfile) @@ -135,6 +128,7 @@ func BenchmarkFio(b *testing.B) { if err := harness.DropCaches(machine); err != nil { b.Skipf("failed to drop caches with %v. You probably need root.", err) } + cmd := tc.MakeCmd(outfile) if err := harness.DropCaches(machine); err != nil { @@ -154,39 +148,6 @@ func BenchmarkFio(b *testing.B) { } } -// makeMount makes a mount and cleanup based on the requested type. Bind -// and volume mounts are backed by a temp directory made with mktemp. -// tmpfs mounts require no such backing and are just made. -// It is up to the caller to call the returned cleanup. -func makeMount(machine harness.Machine, mountType mount.Type, target string) (mount.Mount, func(), error) { - switch mountType { - case mount.TypeVolume, mount.TypeBind: - dir, err := machine.RunCommand("mktemp", "-d") - if err != nil { - return mount.Mount{}, func() {}, fmt.Errorf("failed to create tempdir: %v", err) - } - dir = strings.TrimSuffix(dir, "\n") - - out, err := machine.RunCommand("chmod", "777", dir) - if err != nil { - machine.RunCommand("rm", "-rf", dir) - return mount.Mount{}, func() {}, fmt.Errorf("failed modify directory: %v %s", err, out) - } - return mount.Mount{ - Target: target, - Source: dir, - Type: mount.TypeBind, - }, func() { machine.RunCommand("rm", "-rf", dir) }, nil - case mount.TypeTmpfs: - return mount.Mount{ - Target: target, - Type: mount.TypeTmpfs, - }, func() {}, nil - default: - return mount.Mount{}, func() {}, fmt.Errorf("illegal mount time not supported: %v", mountType) - } -} - // TestMain is the main method for package fs. func TestMain(m *testing.M) { harness.Init() diff --git a/test/benchmarks/harness/BUILD b/test/benchmarks/harness/BUILD index c2e316709..116610938 100644 --- a/test/benchmarks/harness/BUILD +++ b/test/benchmarks/harness/BUILD @@ -14,5 +14,6 @@ go_library( deps = [ "//pkg/test/dockerutil", "//pkg/test/testutil", + "@com_github_docker_docker//api/types/mount:go_default_library", ], ) diff --git a/test/benchmarks/harness/util.go b/test/benchmarks/harness/util.go index aeac7ebff..36abe1069 100644 --- a/test/benchmarks/harness/util.go +++ b/test/benchmarks/harness/util.go @@ -18,8 +18,10 @@ import ( "context" "fmt" "net" + "strings" "testing" + "github.com/docker/docker/api/types/mount" "gvisor.dev/gvisor/pkg/test/dockerutil" "gvisor.dev/gvisor/pkg/test/testutil" ) @@ -55,3 +57,53 @@ func DebugLog(b *testing.B, msg string, args ...interface{}) { b.Logf(msg, args...) } } + +const ( + // BindFS indicates a bind mount should be created. + BindFS = "bindfs" + // TmpFS indicates a tmpfs mount should be created. + TmpFS = "tmpfs" + // RootFS indicates no mount should be created and the root mount should be used. + RootFS = "rootfs" +) + +// MakeMount makes a mount and cleanup based on the requested type. Bind +// and volume mounts are backed by a temp directory made with mktemp. +// tmpfs mounts require no such backing and are just made. +// rootfs mounts do not make a mount, but instead return a target direectory at root. +// It is up to the caller to call the returned cleanup. +func MakeMount(machine Machine, fsType string) ([]mount.Mount, string, func(), error) { + mounts := make([]mount.Mount, 0, 1) + switch fsType { + case BindFS: + dir, err := machine.RunCommand("mktemp", "-d") + if err != nil { + return mounts, "", func() {}, fmt.Errorf("failed to create tempdir: %v", err) + } + dir = strings.TrimSuffix(dir, "\n") + + out, err := machine.RunCommand("chmod", "777", dir) + if err != nil { + machine.RunCommand("rm", "-rf", dir) + return mounts, "", func() {}, fmt.Errorf("failed modify directory: %v %s", err, out) + } + target := "/data" + mounts = append(mounts, mount.Mount{ + Target: target, + Source: dir, + Type: mount.TypeBind, + }) + return mounts, target, func() { machine.RunCommand("rm", "-rf", dir) }, nil + case RootFS: + return mounts, "/", func() {}, nil + case TmpFS: + target := "/data" + mounts = append(mounts, mount.Mount{ + Target: target, + Type: mount.TypeTmpfs, + }) + return mounts, target, func() {}, nil + default: + return mounts, "", func() {}, fmt.Errorf("illegal mount type not supported: %v", fsType) + } +} diff --git a/test/benchmarks/tcp/tcp_proxy.go b/test/benchmarks/tcp/tcp_proxy.go index 780e4f7ae..be74e4d4a 100644 --- a/test/benchmarks/tcp/tcp_proxy.go +++ b/test/benchmarks/tcp/tcp_proxy.go @@ -29,7 +29,6 @@ import ( "runtime" "runtime/pprof" "strconv" - "syscall" "time" "golang.org/x/sys/unix" @@ -112,33 +111,33 @@ func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) { const protocol = 0x0300 // htons(ETH_P_ALL) fds := make([]int, numChannels) for i := range fds { - fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol) + fd, err := unix.Socket(unix.AF_PACKET, unix.SOCK_RAW, protocol) if err != nil { return nil, fmt.Errorf("unable to create raw socket: %v", err) } // Bind to the appropriate device. - ll := syscall.SockaddrLinklayer{ + ll := unix.SockaddrLinklayer{ Protocol: protocol, Ifindex: iface.Index, - Pkttype: syscall.PACKET_HOST, + Pkttype: unix.PACKET_HOST, } - if err := syscall.Bind(fd, &ll); err != nil { + if err := unix.Bind(fd, &ll); err != nil { return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err) } // RAW Sockets by default have a very small SO_RCVBUF of 256KB, // up it to at least 4MB to reduce packet drops. - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil { + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, bufSize); err != nil { return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err) } - if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil { + if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_SNDBUF, bufSize); err != nil { return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err) } if !*swgso && *gso != 0 { - if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { + if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil { return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err) } } @@ -403,7 +402,7 @@ func main() { log.Printf("client=%v, server=%v, ready.", *client, *server) sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM) + signal.Notify(sigs, unix.SIGTERM) go func() { <-sigs if *cpuprofile != "" { diff --git a/test/fuse/linux/mount_test.cc b/test/fuse/linux/mount_test.cc index 8a5478116..276f842ea 100644 --- a/test/fuse/linux/mount_test.cc +++ b/test/fuse/linux/mount_test.cc @@ -15,6 +15,7 @@ #include <errno.h> #include <fcntl.h> #include <sys/mount.h> +#include <unistd.h> #include "gtest/gtest.h" #include "test/util/mount_util.h" @@ -29,7 +30,9 @@ namespace { TEST(FuseMount, Success) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/fuse", O_WRONLY)); - std::string mopts = absl::StrCat("fd=", std::to_string(fd.get())); + std::string mopts = + absl::StrFormat("fd=%d,user_id=%d,group_id=%d,rootmode=0777", fd.get(), + getuid(), getgid()); const auto dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); diff --git a/test/iptables/BUILD b/test/iptables/BUILD index ae4bba847..94d4ca2d4 100644 --- a/test/iptables/BUILD +++ b/test/iptables/BUILD @@ -18,6 +18,7 @@ go_library( "//pkg/binary", "//pkg/test/testutil", "//pkg/usermem", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/test/iptables/iptables_unsafe.go b/test/iptables/iptables_unsafe.go index bd85a8fea..dd1a1c082 100644 --- a/test/iptables/iptables_unsafe.go +++ b/test/iptables/iptables_unsafe.go @@ -16,12 +16,13 @@ package iptables import ( "fmt" - "syscall" "unsafe" + + "golang.org/x/sys/unix" ) type originalDstError struct { - errno syscall.Errno + errno unix.Errno } func (e originalDstError) Error() string { @@ -32,27 +33,27 @@ func (e originalDstError) Error() string { // getsockopt. const SO_ORIGINAL_DST = 80 -func originalDestination4(connfd int) (syscall.RawSockaddrInet4, error) { - var addr syscall.RawSockaddrInet4 - var addrLen uint32 = syscall.SizeofSockaddrInet4 - if errno := originalDestination(connfd, syscall.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 { - return syscall.RawSockaddrInet4{}, originalDstError{errno} +func originalDestination4(connfd int) (unix.RawSockaddrInet4, error) { + var addr unix.RawSockaddrInet4 + var addrLen uint32 = unix.SizeofSockaddrInet4 + if errno := originalDestination(connfd, unix.SOL_IP, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return unix.RawSockaddrInet4{}, originalDstError{errno} } return addr, nil } -func originalDestination6(connfd int) (syscall.RawSockaddrInet6, error) { - var addr syscall.RawSockaddrInet6 - var addrLen uint32 = syscall.SizeofSockaddrInet6 - if errno := originalDestination(connfd, syscall.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 { - return syscall.RawSockaddrInet6{}, originalDstError{errno} +func originalDestination6(connfd int) (unix.RawSockaddrInet6, error) { + var addr unix.RawSockaddrInet6 + var addrLen uint32 = unix.SizeofSockaddrInet6 + if errno := originalDestination(connfd, unix.SOL_IPV6, unsafe.Pointer(&addr), &addrLen); errno != 0 { + return unix.RawSockaddrInet6{}, originalDstError{errno} } return addr, nil } -func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) syscall.Errno { - _, _, errno := syscall.Syscall6( - syscall.SYS_GETSOCKOPT, +func originalDestination(connfd int, level uintptr, optval unsafe.Pointer, optlen *uint32) unix.Errno { + _, _, errno := unix.Syscall6( + unix.SYS_GETSOCKOPT, uintptr(connfd), level, SO_ORIGINAL_DST, diff --git a/test/iptables/nat.go b/test/iptables/nat.go index 7f1d6d7ad..70d8a1832 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -19,8 +19,8 @@ import ( "errors" "fmt" "net" - "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/usermem" ) @@ -584,33 +584,33 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net. // traditional syscalls. // Create the listening socket, bind, listen, and accept. - family := syscall.AF_INET + family := unix.AF_INET if ipv6 { - family = syscall.AF_INET6 + family = unix.AF_INET6 } - sockfd, err := syscall.Socket(family, syscall.SOCK_STREAM, 0) + sockfd, err := unix.Socket(family, unix.SOCK_STREAM, 0) if err != nil { return err } - defer syscall.Close(sockfd) + defer unix.Close(sockfd) - var bindAddr syscall.Sockaddr + var bindAddr unix.Sockaddr if ipv6 { - bindAddr = &syscall.SockaddrInet6{ + bindAddr = &unix.SockaddrInet6{ Port: acceptPort, Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any } } else { - bindAddr = &syscall.SockaddrInet4{ + bindAddr = &unix.SockaddrInet4{ Port: acceptPort, Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY } } - if err := syscall.Bind(sockfd, bindAddr); err != nil { + if err := unix.Bind(sockfd, bindAddr); err != nil { return err } - if err := syscall.Listen(sockfd, 1); err != nil { + if err := unix.Listen(sockfd, 1); err != nil { return err } @@ -619,8 +619,8 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net. errCh := make(chan error) go func() { for { - connFD, _, err := syscall.Accept(sockfd) - if errors.Is(err, syscall.EINTR) { + connFD, _, err := unix.Accept(sockfd) + if errors.Is(err, unix.EINTR) { continue } if err != nil { @@ -641,7 +641,7 @@ func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net. return err case connFD = <-connCh: } - defer syscall.Close(connFD) + defer unix.Close(connFD) // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST // indicates the packet was sent to originalDst:dropPort. @@ -764,35 +764,35 @@ func recvWithRECVORIGDSTADDR(ctx context.Context, ipv6 bool, expectedDst *net.IP // Create the listening socket. var ( - family = syscall.AF_INET - level = syscall.SOL_IP - option = syscall.IP_RECVORIGDSTADDR - bindAddr syscall.Sockaddr = &syscall.SockaddrInet4{ + family = unix.AF_INET + level = unix.SOL_IP + option = unix.IP_RECVORIGDSTADDR + bindAddr unix.Sockaddr = &unix.SockaddrInet4{ Port: int(port), Addr: [4]byte{0, 0, 0, 0}, // INADDR_ANY } ) if ipv6 { - family = syscall.AF_INET6 - level = syscall.SOL_IPV6 + family = unix.AF_INET6 + level = unix.SOL_IPV6 option = 74 // IPV6_RECVORIGDSTADDR, which is missing from the syscall package. - bindAddr = &syscall.SockaddrInet6{ + bindAddr = &unix.SockaddrInet6{ Port: int(port), Addr: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, // in6addr_any } } - sockfd, err := syscall.Socket(family, syscall.SOCK_DGRAM, 0) + sockfd, err := unix.Socket(family, unix.SOCK_DGRAM, 0) if err != nil { - return fmt.Errorf("failed Socket(%d, %d, 0): %w", family, syscall.SOCK_DGRAM, err) + return fmt.Errorf("failed Socket(%d, %d, 0): %w", family, unix.SOCK_DGRAM, err) } - defer syscall.Close(sockfd) + defer unix.Close(sockfd) - if err := syscall.Bind(sockfd, bindAddr); err != nil { + if err := unix.Bind(sockfd, bindAddr); err != nil { return fmt.Errorf("failed Bind(%d, %+v): %v", sockfd, bindAddr, err) } // Enable IP_RECVORIGDSTADDR. - if err := syscall.SetsockoptInt(sockfd, level, option, 1); err != nil { + if err := unix.SetsockoptInt(sockfd, level, option, 1); err != nil { return fmt.Errorf("failed SetsockoptByte(%d, %d, %d, 1): %v", sockfd, level, option, err) } @@ -837,41 +837,41 @@ func recvWithRECVORIGDSTADDR(ctx context.Context, ipv6 bool, expectedDst *net.IP // Verify that the address has the post-NAT port and address. if ipv6 { - return addrMatches6(addr.(syscall.RawSockaddrInet6), localAddrs, redirectPort) + return addrMatches6(addr.(unix.RawSockaddrInet6), localAddrs, redirectPort) } - return addrMatches4(addr.(syscall.RawSockaddrInet4), localAddrs, redirectPort) + return addrMatches4(addr.(unix.RawSockaddrInet4), localAddrs, redirectPort) } -func recvOrigDstAddr4(sockfd int) (syscall.RawSockaddrInet4, error) { - buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet4) +func recvOrigDstAddr4(sockfd int) (unix.RawSockaddrInet4, error) { + buf, err := recvOrigDstAddr(sockfd, unix.SOL_IP, unix.SizeofSockaddrInet4) if err != nil { - return syscall.RawSockaddrInet4{}, err + return unix.RawSockaddrInet4{}, err } - var addr syscall.RawSockaddrInet4 + var addr unix.RawSockaddrInet4 binary.Unmarshal(buf, usermem.ByteOrder, &addr) return addr, nil } -func recvOrigDstAddr6(sockfd int) (syscall.RawSockaddrInet6, error) { - buf, err := recvOrigDstAddr(sockfd, syscall.SOL_IP, syscall.SizeofSockaddrInet6) +func recvOrigDstAddr6(sockfd int) (unix.RawSockaddrInet6, error) { + buf, err := recvOrigDstAddr(sockfd, unix.SOL_IP, unix.SizeofSockaddrInet6) if err != nil { - return syscall.RawSockaddrInet6{}, err + return unix.RawSockaddrInet6{}, err } - var addr syscall.RawSockaddrInet6 + var addr unix.RawSockaddrInet6 binary.Unmarshal(buf, usermem.ByteOrder, &addr) return addr, nil } func recvOrigDstAddr(sockfd int, level uintptr, addrSize int) ([]byte, error) { buf := make([]byte, 64) - oob := make([]byte, syscall.CmsgSpace(addrSize)) + oob := make([]byte, unix.CmsgSpace(addrSize)) for { - _, oobn, _, _, err := syscall.Recvmsg( + _, oobn, _, _, err := unix.Recvmsg( sockfd, buf, // Message buffer. oob, // Out-of-band buffer. 0) // Flags. - if errors.Is(err, syscall.EINTR) { + if errors.Is(err, unix.EINTR) { continue } if err != nil { @@ -880,7 +880,7 @@ func recvOrigDstAddr(sockfd int, level uintptr, addrSize int) ([]byte, error) { oob = oob[:oobn] // Parse out the control message. - msgs, err := syscall.ParseSocketControlMessage(oob) + msgs, err := unix.ParseSocketControlMessage(oob) if err != nil { return nil, fmt.Errorf("failed to parse control message: %w", err) } @@ -888,10 +888,10 @@ func recvOrigDstAddr(sockfd int, level uintptr, addrSize int) ([]byte, error) { } } -func addrMatches4(got syscall.RawSockaddrInet4, wantAddrs []net.IP, port uint16) error { +func addrMatches4(got unix.RawSockaddrInet4, wantAddrs []net.IP, port uint16) error { for _, wantAddr := range wantAddrs { - want := syscall.RawSockaddrInet4{ - Family: syscall.AF_INET, + want := unix.RawSockaddrInet4{ + Family: unix.AF_INET, Port: htons(port), } copy(want.Addr[:], wantAddr.To4()) @@ -902,10 +902,10 @@ func addrMatches4(got syscall.RawSockaddrInet4, wantAddrs []net.IP, port uint16) return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs) } -func addrMatches6(got syscall.RawSockaddrInet6, wantAddrs []net.IP, port uint16) error { +func addrMatches6(got unix.RawSockaddrInet6, wantAddrs []net.IP, port uint16) error { for _, wantAddr := range wantAddrs { - want := syscall.RawSockaddrInet6{ - Family: syscall.AF_INET6, + want := unix.RawSockaddrInet6{ + Family: unix.AF_INET6, Port: htons(port), } copy(want.Addr[:], wantAddr.To16()) diff --git a/test/packetimpact/dut/posix_server.cc b/test/packetimpact/dut/posix_server.cc index 0d93b806e..ea83bbe72 100644 --- a/test/packetimpact/dut/posix_server.cc +++ b/test/packetimpact/dut/posix_server.cc @@ -395,7 +395,8 @@ class PosixImpl final : public posix_server::Posix::Service { ::grpc::Status Shutdown(grpc::ServerContext *context, const ::posix_server::ShutdownRequest *request, ::posix_server::ShutdownResponse *response) override { - if (shutdown(request->fd(), request->how()) < 0) { + response->set_ret(shutdown(request->fd(), request->how())); + if (response->ret() < 0) { response->set_errno_(errno); } return ::grpc::Status::OK; diff --git a/test/packetimpact/proto/posix_server.proto b/test/packetimpact/proto/posix_server.proto index 521f03465..175a65336 100644 --- a/test/packetimpact/proto/posix_server.proto +++ b/test/packetimpact/proto/posix_server.proto @@ -214,7 +214,8 @@ message ShutdownRequest { } message ShutdownResponse { - int32 errno_ = 1; // "errno" may fail to compile in c++. + int32 ret = 1; + int32 errno_ = 2; // "errno" may fail to compile in c++. } message RecvRequest { diff --git a/test/packetimpact/runner/defs.bzl b/test/packetimpact/runner/defs.bzl index 8ce5edf2b..567f64c41 100644 --- a/test/packetimpact/runner/defs.bzl +++ b/test/packetimpact/runner/defs.bzl @@ -203,6 +203,11 @@ ALL_TESTS = [ name = "tcp_outside_the_window", ), PacketimpactTestInfo( + name = "tcp_outside_the_window_closing", + # TODO(b/181625316): Fix netstack then merge into tcp_outside_the_window. + expect_netstack_failure = True, + ), + PacketimpactTestInfo( name = "tcp_noaccept_close_rst", ), PacketimpactTestInfo( @@ -212,6 +217,11 @@ ALL_TESTS = [ name = "tcp_unacc_seq_ack", ), PacketimpactTestInfo( + name = "tcp_unacc_seq_ack_closing", + # TODO(b/181625316): Fix netstack then merge into tcp_unacc_seq_ack. + expect_netstack_failure = True, + ), + PacketimpactTestInfo( name = "tcp_paws_mechanism", # TODO(b/156682000): Fix netstack then remove the line below. expect_netstack_failure = True, @@ -277,6 +287,11 @@ ALL_TESTS = [ PacketimpactTestInfo( name = "tcp_info", ), + PacketimpactTestInfo( + name = "tcp_fin_retransmission", + # TODO(b/181625316): Fix netstack then remove the line below. + expect_netstack_failure = True, + ), ] def validate_all_tests(): diff --git a/test/packetimpact/runner/dut.go b/test/packetimpact/runner/dut.go index 3da265b78..1064ca976 100644 --- a/test/packetimpact/runner/dut.go +++ b/test/packetimpact/runner/dut.go @@ -109,6 +109,7 @@ type dutInfo struct { dut DUT ctrlNet, testNet *dockerutil.Network netInfo *testbench.DUTTestNet + uname *testbench.DUTUname } // setUpDUT will set up one DUT and return information for setting up the @@ -182,6 +183,10 @@ func setUpDUT(ctx context.Context, t *testing.T, id int, mkDevice func(*dockerut POSIXServerIP: AddressInSubnet(DUTAddr, *ctrlNet.Subnet), POSIXServerPort: CtrlPort, } + info.uname, err = dut.Uname(ctx) + if err != nil { + return dutInfo{}, fmt.Errorf("failed to get uname information on DUT: %w", err) + } return info, nil } @@ -195,7 +200,7 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co dutInfoChan := make(chan dutInfo, numDUTs) errChan := make(chan error, numDUTs) var dockerNetworks []*dockerutil.Network - var dutTestNets []*testbench.DUTTestNet + var dutInfos []*testbench.DUTInfo var duts []DUT setUpCtx, cancelSetup := context.WithCancel(ctx) @@ -214,7 +219,10 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co select { case info := <-dutInfoChan: dockerNetworks = append(dockerNetworks, info.ctrlNet, info.testNet) - dutTestNets = append(dutTestNets, info.netInfo) + dutInfos = append(dutInfos, &testbench.DUTInfo{ + Net: info.netInfo, + Uname: info.uname, + }) duts = append(duts, info.dut) case err := <-errChan: t.Fatal(err) @@ -241,28 +249,29 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co testbenchContainer, testbenchAddr, dockerNetworks, + nil, /* sysctls */ "tail", "-f", "/dev/null", ); err != nil { t.Fatalf("cannot start testbench container: %s", err) } - for i := range dutTestNets { - name, info, err := deviceByIP(ctx, testbenchContainer, dutTestNets[i].LocalIPv4) + for i := range dutInfos { + name, info, err := deviceByIP(ctx, testbenchContainer, dutInfos[i].Net.LocalIPv4) if err != nil { - t.Fatalf("failed to get the device name associated with %s: %s", dutTestNets[i].LocalIPv4, err) + t.Fatalf("failed to get the device name associated with %s: %s", dutInfos[i].Net.LocalIPv4, err) } - dutTestNets[i].LocalDevName = name - dutTestNets[i].LocalDevID = info.ID - dutTestNets[i].LocalMAC = info.MAC + dutInfos[i].Net.LocalDevName = name + dutInfos[i].Net.LocalDevID = info.ID + dutInfos[i].Net.LocalMAC = info.MAC localIPv6, err := getOrAssignIPv6Addr(ctx, testbenchContainer, name) if err != nil { t.Fatalf("failed to get IPV6 address on %s: %s", testbenchContainer.Name, err) } - dutTestNets[i].LocalIPv6 = localIPv6 + dutInfos[i].Net.LocalIPv6 = localIPv6 } - dutTestNetsBytes, err := json.Marshal(dutTestNets) + dutInfosBytes, err := json.Marshal(dutInfos) if err != nil { - t.Fatalf("failed to marshal %v into json: %s", dutTestNets, err) + t.Fatalf("failed to marshal %v into json: %s", dutInfos, err) } baseSnifferArgs := []string{ @@ -296,7 +305,8 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co "-n", } } - for _, n := range dutTestNets { + for _, info := range dutInfos { + n := info.Net snifferArgs := append(baseSnifferArgs, "-i", n.LocalDevName) if !tshark { snifferArgs = append( @@ -351,7 +361,7 @@ func TestWithDUT(ctx context.Context, t *testing.T, mkDevice func(*dockerutil.Co testArgs = append(testArgs, extraTestArgs...) testArgs = append(testArgs, fmt.Sprintf("--native=%t", native), - "--dut_test_nets_json", string(dutTestNetsBytes), + "--dut_infos_json", string(dutInfosBytes), ) testbenchLogs, err := testbenchContainer.Exec(ctx, dockerutil.ExecOpts{}, testArgs...) if (err != nil) != expectFailure { @@ -388,6 +398,10 @@ type DUT interface { // The t parameter is supposed to be used for t.Cleanup. Don't use it for // t.Fatal/FailNow functions. Prepare(ctx context.Context, t *testing.T, runOpts dockerutil.RunOpts, ctrlNet, testNet *dockerutil.Network) (net.IP, net.HardwareAddr, uint32, string, error) + + // Uname gathers information of DUT using command uname. + Uname(ctx context.Context) (*testbench.DUTUname, error) + // Logs retrieves the logs from the dut. Logs(ctx context.Context) (string, error) } @@ -415,6 +429,10 @@ func (dut *DockerDUT) Prepare(ctx context.Context, _ *testing.T, runOpts dockeru dut.c, DUTAddr, []*dockerutil.Network{ctrlNet, testNet}, + map[string]string{ + // This enables creating ICMP sockets on Linux. + "net.ipv4.ping_group_range": "0 0", + }, containerPosixServerBinary, "--ip=0.0.0.0", fmt.Sprintf("--port=%d", CtrlPort), @@ -440,6 +458,38 @@ func (dut *DockerDUT) Prepare(ctx context.Context, _ *testing.T, runOpts dockeru return remoteIPv6, dutDeviceInfo.MAC, dutDeviceInfo.ID, testNetDev, nil } +// Uname implements DUT.Uname. +func (dut *DockerDUT) Uname(ctx context.Context) (*testbench.DUTUname, error) { + machine, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "uname", "-m") + if err != nil { + return nil, err + } + kernelRelease, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "uname", "-r") + if err != nil { + return nil, err + } + kernelVersion, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "uname", "-v") + if err != nil { + return nil, err + } + kernelName, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "uname", "-s") + if err != nil { + return nil, err + } + // TODO(gvisor.dev/issues/5586): -o is not supported on macOS. + operatingSystem, err := dut.c.Exec(ctx, dockerutil.ExecOpts{}, "uname", "-o") + if err != nil { + return nil, err + } + return &testbench.DUTUname{ + Machine: strings.TrimRight(machine, "\n"), + KernelName: strings.TrimRight(kernelName, "\n"), + KernelRelease: strings.TrimRight(kernelRelease, "\n"), + KernelVersion: strings.TrimRight(kernelVersion, "\n"), + OperatingSystem: strings.TrimRight(operatingSystem, "\n"), + }, nil +} + // Logs implements DUT.Logs. func (dut *DockerDUT) Logs(ctx context.Context) (string, error) { logs, err := dut.c.Logs(ctx) @@ -545,11 +595,14 @@ func createDockerNetwork(ctx context.Context, n *dockerutil.Network) error { // StartContainer will create a container instance from runOpts, connect it // with the specified docker networks and start executing the specified cmd. -func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerutil.Container, containerAddr net.IP, ns []*dockerutil.Network, cmd ...string) error { +func StartContainer(ctx context.Context, runOpts dockerutil.RunOpts, c *dockerutil.Container, containerAddr net.IP, ns []*dockerutil.Network, sysctls map[string]string, cmd ...string) error { conf, hostconf, netconf := c.ConfigsFrom(runOpts, cmd...) _ = netconf hostconf.AutoRemove = true hostconf.Sysctls = map[string]string{"net.ipv6.conf.all.disable_ipv6": "0"} + for k, v := range sysctls { + hostconf.Sysctls[k] = v + } if err := c.CreateFrom(ctx, runOpts.Image, conf, hostconf, nil); err != nil { return fmt.Errorf("unable to create container %s: %w", c.Name, err) diff --git a/test/packetimpact/testbench/BUILD b/test/packetimpact/testbench/BUILD index 983c2c030..43b4c7ca1 100644 --- a/test/packetimpact/testbench/BUILD +++ b/test/packetimpact/testbench/BUILD @@ -1,7 +1,6 @@ load("//tools:defs.bzl", "go_library", "go_test") package( - default_visibility = ["//test/packetimpact:__subpackages__"], licenses = ["notice"], ) @@ -15,6 +14,7 @@ go_library( "rawsockets.go", "testbench.go", ], + visibility = ["//test/packetimpact:__subpackages__"], deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/test/packetimpact/testbench/connections.go b/test/packetimpact/testbench/connections.go index 15e1a51de..8ad9040ff 100644 --- a/test/packetimpact/testbench/connections.go +++ b/test/packetimpact/testbench/connections.go @@ -677,17 +677,17 @@ func (conn *TCPIPv4) Connect(t *testing.T) { t.Helper() // Send the SYN. - conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn)}) + conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagSyn)}) // Wait for the SYN-ACK. - synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + synAck, err := conn.Expect(t, TCP{Flags: TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("didn't get synack during handshake: %s", err) } conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck // Send an ACK. - conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) + conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagAck)}) } // ConnectWithOptions performs a TCP 3-way handshake with given TCP options. @@ -696,17 +696,17 @@ func (conn *TCPIPv4) ConnectWithOptions(t *testing.T, options []byte) { t.Helper() // Send the SYN. - conn.Send(t, TCP{Flags: Uint8(header.TCPFlagSyn), Options: options}) + conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagSyn), Options: options}) // Wait for the SYN-ACK. - synAck, err := conn.Expect(t, TCP{Flags: Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + synAck, err := conn.Expect(t, TCP{Flags: TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("didn't get synack during handshake: %s", err) } conn.layerStates[len(conn.layerStates)-1].(*tcpState).synAck = synAck // Send an ACK. - conn.Send(t, TCP{Flags: Uint8(header.TCPFlagAck)}) + conn.Send(t, TCP{Flags: TCPFlags(header.TCPFlagAck)}) } // ExpectData is a convenient method that expects a Layer and the Layer after @@ -823,6 +823,27 @@ func (conn *TCPIPv4) LocalAddr(t *testing.T) *unix.SockaddrInet4 { return sa } +// GenerateOTWSeqSegment generates a segment with +// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only +// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the +// receiver. +func GenerateOTWSeqSegment(t *testing.T, conn *TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) TCP { + t.Helper() + lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) + otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) + return TCP{SeqNum: Uint32(otwSeq), Flags: TCPFlags(header.TCPFlagAck)} +} + +// GenerateUnaccACKSegment generates a segment with +// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable +// when seqNumOffset is 0, otherwise an ACK is expected from the receiver. +func GenerateUnaccACKSegment(t *testing.T, conn *TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) TCP { + t.Helper() + lastAcceptable := conn.RemoteSeqNum(t) + unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) + return TCP{AckNum: Uint32(unaccAck), Flags: TCPFlags(header.TCPFlagAck)} +} + // IPv4Conn maintains the state for all the layers in a IPv4 connection. type IPv4Conn struct { Connection diff --git a/test/packetimpact/testbench/dut.go b/test/packetimpact/testbench/dut.go index be5121d98..eabdc8cb3 100644 --- a/test/packetimpact/testbench/dut.go +++ b/test/packetimpact/testbench/dut.go @@ -19,7 +19,6 @@ import ( "encoding/binary" "fmt" "net" - "syscall" "testing" "time" @@ -35,24 +34,26 @@ type DUT struct { conn *grpc.ClientConn posixServer POSIXClient Net *DUTTestNet + Uname *DUTUname } // NewDUT creates a new connection with the DUT over gRPC. func NewDUT(t *testing.T) DUT { t.Helper() - n := GetDUTTestNet() - dut := n.ConnectToDUT(t) + info := getDUTInfo() + dut := info.ConnectToDUT(t) t.Cleanup(func() { dut.TearDownConnection() - dut.Net.Release() + info.release() }) return dut } // ConnectToDUT connects to DUT through gRPC. -func (n *DUTTestNet) ConnectToDUT(t *testing.T) DUT { +func (info *DUTInfo) ConnectToDUT(t *testing.T) DUT { t.Helper() + n := info.Net posixServerAddress := net.JoinHostPort(n.POSIXServerIP.String(), fmt.Sprintf("%d", n.POSIXServerPort)) conn, err := grpc.Dial(posixServerAddress, grpc.WithInsecure(), grpc.WithKeepaliveParams(keepalive.ClientParameters{Timeout: RPCKeepalive})) if err != nil { @@ -63,6 +64,7 @@ func (n *DUTTestNet) ConnectToDUT(t *testing.T) DUT { conn: conn, posixServer: posixServer, Net: n, + Uname: info.Uname, } } @@ -196,7 +198,7 @@ func (dut *DUT) AcceptWithErrno(ctx context.Context, t *testing.T, sockfd int32) if err != nil { t.Fatalf("failed to call Accept: %s", err) } - return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) + return resp.GetFd(), dut.protoToSockaddr(t, resp.GetAddr()), unix.Errno(resp.GetErrno_()) } // Bind calls bind on the DUT and causes a fatal test failure if it doesn't @@ -225,7 +227,7 @@ func (dut *DUT) BindWithErrno(ctx context.Context, t *testing.T, fd int32, sa un if err != nil { t.Fatalf("failed to call Bind: %s", err) } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), unix.Errno(resp.GetErrno_()) } // Close calls close on the DUT and causes a fatal test failure if it doesn't @@ -253,7 +255,7 @@ func (dut *DUT) CloseWithErrno(ctx context.Context, t *testing.T, fd int32) (int if err != nil { t.Fatalf("failed to call Close: %s", err) } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), unix.Errno(resp.GetErrno_()) } // Connect calls connect on the DUT and causes a fatal test failure if it @@ -267,7 +269,7 @@ func (dut *DUT) Connect(t *testing.T, fd int32, sa unix.Sockaddr) { ret, err := dut.ConnectWithErrno(ctx, t, fd, sa) // Ignore 'operation in progress' error that can be returned when the socket // is non-blocking. - if err != syscall.Errno(unix.EINPROGRESS) && ret != 0 { + if err != unix.EINPROGRESS && ret != 0 { t.Fatalf("failed to connect socket: %s", err) } } @@ -284,7 +286,7 @@ func (dut *DUT) ConnectWithErrno(ctx context.Context, t *testing.T, fd int32, sa if err != nil { t.Fatalf("failed to call Connect: %s", err) } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), unix.Errno(resp.GetErrno_()) } // GetSockName calls getsockname on the DUT and causes a fatal test failure if @@ -313,7 +315,7 @@ func (dut *DUT) GetSockNameWithErrno(ctx context.Context, t *testing.T, sockfd i if err != nil { t.Fatalf("failed to call Bind: %s", err) } - return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), dut.protoToSockaddr(t, resp.GetAddr()), unix.Errno(resp.GetErrno_()) } func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, optname, optlen int32, typ pb.GetSockOptRequest_SockOptType) (int32, *pb.SockOptVal, error) { @@ -334,7 +336,7 @@ func (dut *DUT) getSockOpt(ctx context.Context, t *testing.T, sockfd, level, opt if optval == nil { t.Fatalf("GetSockOpt response does not contain a value") } - return resp.GetRet(), optval, syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), optval, unix.Errno(resp.GetErrno_()) } // GetSockOpt calls getsockopt on the DUT and causes a fatal test failure if it @@ -452,7 +454,7 @@ func (dut *DUT) ListenWithErrno(ctx context.Context, t *testing.T, sockfd, backl if err != nil { t.Fatalf("failed to call Listen: %s", err) } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), unix.Errno(resp.GetErrno_()) } // PollOne calls poll on the DUT and asserts that the expected event must be @@ -519,7 +521,7 @@ func (dut *DUT) PollWithErrno(ctx context.Context, t *testing.T, pfds []unix.Pol Revents: int16(protoPfd.GetEvents()), }) } - return resp.GetRet(), result, syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), result, unix.Errno(resp.GetErrno_()) } // Send calls send on the DUT and causes a fatal test failure if it doesn't @@ -550,7 +552,7 @@ func (dut *DUT) SendWithErrno(ctx context.Context, t *testing.T, sockfd int32, b if err != nil { t.Fatalf("failed to call Send: %s", err) } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), unix.Errno(resp.GetErrno_()) } // SendTo calls sendto on the DUT and causes a fatal test failure if it doesn't @@ -582,7 +584,7 @@ func (dut *DUT) SendToWithErrno(ctx context.Context, t *testing.T, sockfd int32, if err != nil { t.Fatalf("failed to call SendTo: %s", err) } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), unix.Errno(resp.GetErrno_()) } // SetNonBlocking will set O_NONBLOCK flag for fd if nonblocking @@ -602,7 +604,7 @@ func (dut *DUT) SetNonBlocking(t *testing.T, fd int32, nonblocking bool) { t.Fatalf("failed to call SetNonblocking: %s", err) } if resp.GetRet() == -1 { - t.Fatalf("fcntl(%d, %s) failed: %s", fd, resp.GetCmd(), syscall.Errno(resp.GetErrno_())) + t.Fatalf("fcntl(%d, %s) failed: %s", fd, resp.GetCmd(), unix.Errno(resp.GetErrno_())) } } @@ -619,7 +621,7 @@ func (dut *DUT) setSockOpt(ctx context.Context, t *testing.T, sockfd, level, opt if err != nil { t.Fatalf("failed to call SetSockOpt: %s", err) } - return resp.GetRet(), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), unix.Errno(resp.GetErrno_()) } // SetSockOpt calls setsockopt on the DUT and causes a fatal test failure if it @@ -720,7 +722,7 @@ func (dut *DUT) SocketWithErrno(t *testing.T, domain, typ, proto int32) (int32, if err != nil { t.Fatalf("failed to call Socket: %s", err) } - return resp.GetFd(), syscall.Errno(resp.GetErrno_()) + return resp.GetFd(), unix.Errno(resp.GetErrno_()) } // Recv calls recv on the DUT and causes a fatal test failure if it doesn't @@ -751,7 +753,7 @@ func (dut *DUT) RecvWithErrno(ctx context.Context, t *testing.T, sockfd, len, fl if err != nil { t.Fatalf("failed to call Recv: %s", err) } - return resp.GetRet(), resp.GetBuf(), syscall.Errno(resp.GetErrno_()) + return resp.GetRet(), resp.GetBuf(), unix.Errno(resp.GetErrno_()) } // SetSockLingerOption sets SO_LINGER socket option on the DUT. @@ -771,16 +773,19 @@ func (dut *DUT) SetSockLingerOption(t *testing.T, sockfd int32, timeout time.Dur // Shutdown calls shutdown on the DUT and causes a fatal test failure if it // doesn't succeed. If more control over the timeout or error handling is // needed, use ShutdownWithErrno. -func (dut *DUT) Shutdown(t *testing.T, fd, how int32) error { +func (dut *DUT) Shutdown(t *testing.T, fd, how int32) { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), RPCTimeout) defer cancel() - return dut.ShutdownWithErrno(ctx, t, fd, how) + ret, err := dut.ShutdownWithErrno(ctx, t, fd, how) + if ret != 0 { + t.Fatalf("failed to shutdown(%d, %d): %s", fd, how, err) + } } // ShutdownWithErrno calls shutdown on the DUT. -func (dut *DUT) ShutdownWithErrno(ctx context.Context, t *testing.T, fd, how int32) error { +func (dut *DUT) ShutdownWithErrno(ctx context.Context, t *testing.T, fd, how int32) (int32, error) { t.Helper() req := &pb.ShutdownRequest{ @@ -791,5 +796,8 @@ func (dut *DUT) ShutdownWithErrno(ctx context.Context, t *testing.T, fd, how int if err != nil { t.Fatalf("failed to call Shutdown: %s", err) } - return syscall.Errno(resp.GetErrno_()) + if resp.GetErrno_() == 0 { + return resp.GetRet(), nil + } + return resp.GetRet(), unix.Errno(resp.GetErrno_()) } diff --git a/test/packetimpact/testbench/layers.go b/test/packetimpact/testbench/layers.go index 64a7a171a..2311f7686 100644 --- a/test/packetimpact/testbench/layers.go +++ b/test/packetimpact/testbench/layers.go @@ -407,6 +407,12 @@ func Uint8(v uint8) *uint8 { return &v } +// TCPFlags is a helper routine that allocates a new +// header.TCPFlags value to store v and returns a pointer to it. +func TCPFlags(v header.TCPFlags) *header.TCPFlags { + return &v +} + // Address is a helper routine that allocates a new tcpip.Address value to // store v and returns a pointer to it. func Address(v tcpip.Address) *tcpip.Address { @@ -1030,7 +1036,7 @@ type TCP struct { SeqNum *uint32 AckNum *uint32 DataOffset *uint8 - Flags *uint8 + Flags *header.TCPFlags WindowSize *uint16 Checksum *uint16 UrgentPointer *uint16 @@ -1063,7 +1069,7 @@ func (l *TCP) ToBytes() ([]byte, error) { h.SetDataOffset(uint8(l.length())) } if l.Flags != nil { - h.SetFlags(*l.Flags) + h.SetFlags(uint8(*l.Flags)) } if l.WindowSize != nil { h.SetWindowSize(*l.WindowSize) @@ -1157,7 +1163,7 @@ func parseTCP(b []byte) (Layer, layerParser) { SeqNum: Uint32(h.SequenceNumber()), AckNum: Uint32(h.AckNumber()), DataOffset: Uint8(h.DataOffset()), - Flags: Uint8(h.Flags()), + Flags: TCPFlags(h.Flags()), WindowSize: Uint16(h.WindowSize()), Checksum: Uint16(h.Checksum()), UrgentPointer: Uint16(h.UrgentPointer()), diff --git a/test/packetimpact/testbench/layers_test.go b/test/packetimpact/testbench/layers_test.go index eca0780b5..614a5de1e 100644 --- a/test/packetimpact/testbench/layers_test.go +++ b/test/packetimpact/testbench/layers_test.go @@ -178,7 +178,7 @@ func TestLayerStringFormat(t *testing.T) { SeqNum: Uint32(3452155723), AckNum: Uint32(2596996163), DataOffset: Uint8(5), - Flags: Uint8(20), + Flags: TCPFlags(header.TCPFlagRst | header.TCPFlagAck), WindowSize: Uint16(64240), Checksum: Uint16(0x2e2b), }, @@ -188,7 +188,7 @@ func TestLayerStringFormat(t *testing.T) { "SeqNum:3452155723 " + "AckNum:2596996163 " + "DataOffset:5 " + - "Flags:20 " + + "Flags: R A " + "WindowSize:64240 " + "Checksum:11819" + "}", @@ -436,7 +436,7 @@ func TestTCPOptions(t *testing.T) { DstPort: Uint16(54321), SeqNum: Uint32(0), AckNum: Uint32(0), - Flags: Uint8(header.TCPFlagSyn), + Flags: TCPFlags(header.TCPFlagSyn), WindowSize: Uint16(8192), Checksum: Uint16(0xf51c), UrgentPointer: Uint16(0), @@ -480,7 +480,7 @@ func TestTCPOptions(t *testing.T) { DstPort: Uint16(54321), SeqNum: Uint32(0), AckNum: Uint32(0), - Flags: Uint8(header.TCPFlagSyn), + Flags: TCPFlags(header.TCPFlagSyn), WindowSize: Uint16(8192), Checksum: Uint16(0xe521), UrgentPointer: Uint16(0), diff --git a/test/packetimpact/testbench/testbench.go b/test/packetimpact/testbench/testbench.go index 891897d55..a73c07e64 100644 --- a/test/packetimpact/testbench/testbench.go +++ b/test/packetimpact/testbench/testbench.go @@ -34,14 +34,29 @@ var ( // RPCTimeout is the gRPC timeout. RPCTimeout = 100 * time.Millisecond - // dutTestNetsJSON is the json string that describes all the test networks to + // dutInfosJSON is the json string that describes information about all the // duts available to use. - dutTestNetsJSON string - // dutTestNets is the pool among which the testbench can choose a DUT to work + dutInfosJSON string + // dutInfo is the pool among which the testbench can choose a DUT to work // with. - dutTestNets chan *DUTTestNet + dutInfo chan *DUTInfo ) +// DUTInfo has both network and uname information about the DUT. +type DUTInfo struct { + Uname *DUTUname + Net *DUTTestNet +} + +// DUTUname contains information about the DUT from uname. +type DUTUname struct { + Machine string + KernelName string + KernelRelease string + KernelVersion string + OperatingSystem string +} + // DUTTestNet describes the test network setup on dut and how the testbench // should connect with an existing DUT. type DUTTestNet struct { @@ -86,7 +101,7 @@ func registerFlags(fs *flag.FlagSet) { fs.BoolVar(&Native, "native", Native, "whether the test is running natively") fs.DurationVar(&RPCTimeout, "rpc_timeout", RPCTimeout, "gRPC timeout") fs.DurationVar(&RPCKeepalive, "rpc_keepalive", RPCKeepalive, "gRPC keepalive") - fs.StringVar(&dutTestNetsJSON, "dut_test_nets_json", dutTestNetsJSON, "path to the dut test nets json file") + fs.StringVar(&dutInfosJSON, "dut_infos_json", dutInfosJSON, "json that describes the DUTs") } // Initialize initializes the testbench, it parse the flags and sets up the @@ -94,27 +109,27 @@ func registerFlags(fs *flag.FlagSet) { func Initialize(fs *flag.FlagSet) { registerFlags(fs) flag.Parse() - if err := loadDUTTestNets(); err != nil { + if err := loadDUTInfos(); err != nil { panic(err) } } -// loadDUTTestNets loads available DUT test networks from the json file, it +// loadDUTInfos loads available DUT test infos from the json file, it // must be called after flag.Parse(). -func loadDUTTestNets() error { - var parsedTestNets []DUTTestNet - if err := json.Unmarshal([]byte(dutTestNetsJSON), &parsedTestNets); err != nil { +func loadDUTInfos() error { + var dutInfos []DUTInfo + if err := json.Unmarshal([]byte(dutInfosJSON), &dutInfos); err != nil { return fmt.Errorf("failed to unmarshal JSON: %w", err) } - if got, want := len(parsedTestNets), 1; got < want { + if got, want := len(dutInfos), 1; got < want { return fmt.Errorf("got %d DUTs, the test requires at least %d DUTs", got, want) } // Using a buffered channel as semaphore - dutTestNets = make(chan *DUTTestNet, len(parsedTestNets)) - for i := range parsedTestNets { - parsedTestNets[i].LocalIPv4 = parsedTestNets[i].LocalIPv4.To4() - parsedTestNets[i].RemoteIPv4 = parsedTestNets[i].RemoteIPv4.To4() - dutTestNets <- &parsedTestNets[i] + dutInfo = make(chan *DUTInfo, len(dutInfos)) + for i := range dutInfos { + dutInfos[i].Net.LocalIPv4 = dutInfos[i].Net.LocalIPv4.To4() + dutInfos[i].Net.RemoteIPv4 = dutInfos[i].Net.RemoteIPv4.To4() + dutInfo <- &dutInfos[i] } return nil } @@ -130,14 +145,13 @@ func GenerateRandomPayload(t *testing.T, n int) []byte { return buf } -// GetDUTTestNet gets a usable DUTTestNet, the function will block until any -// becomes available. -func GetDUTTestNet() *DUTTestNet { - return <-dutTestNets +// getDUTInfo returns information about an available DUT from the pool. If no +// DUT is readily available, getDUTInfo blocks until one becomes available. +func getDUTInfo() *DUTInfo { + return <-dutInfo } -// Release releases the DUTTestNet back to the pool so that some other test -// can use. -func (n *DUTTestNet) Release() { - dutTestNets <- n +// release returns the DUTInfo back to the pool. +func (info *DUTInfo) release() { + dutInfo <- info } diff --git a/test/packetimpact/tests/BUILD b/test/packetimpact/tests/BUILD index 301cf4980..d5cb0ae06 100644 --- a/test/packetimpact/tests/BUILD +++ b/test/packetimpact/tests/BUILD @@ -124,6 +124,17 @@ packetimpact_testbench( ) packetimpact_testbench( + name = "tcp_outside_the_window_closing", + srcs = ["tcp_outside_the_window_closing_test.go"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_testbench( name = "tcp_noaccept_close_rst", srcs = ["tcp_noaccept_close_rst_test.go"], deps = [ @@ -155,6 +166,17 @@ packetimpact_testbench( ) packetimpact_testbench( + name = "tcp_unacc_seq_ack_closing", + srcs = ["tcp_unacc_seq_ack_closing_test.go"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + +packetimpact_testbench( name = "tcp_paws_mechanism", srcs = ["tcp_paws_mechanism_test.go"], deps = [ @@ -375,6 +397,16 @@ packetimpact_testbench( ], ) +packetimpact_testbench( + name = "tcp_fin_retransmission", + srcs = ["tcp_fin_retransmission_test.go"], + deps = [ + "//pkg/tcpip/header", + "//test/packetimpact/testbench", + "@org_golang_x_sys//unix:go_default_library", + ], +) + validate_all_tests() [packetimpact_go_test( diff --git a/test/packetimpact/tests/fin_wait2_timeout_test.go b/test/packetimpact/tests/fin_wait2_timeout_test.go index 11f0fcd1e..cff8ca51d 100644 --- a/test/packetimpact/tests/fin_wait2_timeout_test.go +++ b/test/packetimpact/tests/fin_wait2_timeout_test.go @@ -51,21 +51,21 @@ func TestFinWait2Timeout(t *testing.T) { } dut.Close(t, acceptFd) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected a FIN-ACK within 1 second but got none: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) time.Sleep(5 * time.Second) conn.Drain(t) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) if tt.linger2 { - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, time.Second); err != nil { t.Fatalf("expected a RST packet within a second but got none: %s", err) } } else { - if got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil { + if got, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, 10*time.Second); got != nil || err == nil { t.Fatalf("expected no RST packets within ten seconds but got one: %s", got) } } diff --git a/test/packetimpact/tests/ipv4_id_uniqueness_test.go b/test/packetimpact/tests/ipv4_id_uniqueness_test.go index a63b41366..2b69ceecb 100644 --- a/test/packetimpact/tests/ipv4_id_uniqueness_test.go +++ b/test/packetimpact/tests/ipv4_id_uniqueness_test.go @@ -100,7 +100,7 @@ func TestIPv4RetransmitIdentificationUniqueness(t *testing.T) { // Let the DUT estimate RTO with RTT from the DATA-ACK. // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which // we can skip sending this ACK. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) dut.Send(t, remoteFD, tc.payload, 0) expectTCP := &testbench.TCP{SeqNum: testbench.Uint32(uint32(*conn.RemoteSeqNum(t)))} diff --git a/test/packetimpact/tests/tcp_cork_mss_test.go b/test/packetimpact/tests/tcp_cork_mss_test.go index a7ba5035e..1db3c9883 100644 --- a/test/packetimpact/tests/tcp_cork_mss_test.go +++ b/test/packetimpact/tests/tcp_cork_mss_test.go @@ -60,24 +60,24 @@ func TestTCPCorkMSS(t *testing.T) { // Expect the segments to be coalesced and sent and capped to MSS. expectedPayload := testbench.Payload{Bytes: expectedData[:mss]} - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) // Expect the coalesced segment to be split and transmitted. expectedPayload = testbench.Payload{Bytes: expectedData[mss:]} - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Check for segments to *not* be held up because of TCP_CORK when // the current send window is less than MSS. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(2 * len(sampleData)))}) dut.Send(t, acceptFD, sampleData, 0) dut.Send(t, acceptFD, sampleData, 0) expectedPayload = testbench.Payload{Bytes: append(sampleData, sampleData...)} - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, &expectedPayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) } diff --git a/test/packetimpact/tests/tcp_fin_retransmission_test.go b/test/packetimpact/tests/tcp_fin_retransmission_test.go new file mode 100644 index 000000000..500f7a783 --- /dev/null +++ b/test/packetimpact/tests/tcp_fin_retransmission_test.go @@ -0,0 +1,87 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_fin_retransmission_test + +import ( + "flag" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +// TestTCPClosingFinRetransmission tests that TCP implementation should retransmit +// FIN segment in CLOSING state. +func TestTCPClosingFinRetransmission(t *testing.T) { + for _, tt := range []struct { + description string + flags header.TCPFlags + }{ + {"CLOSING", header.TCPFlagAck | header.TCPFlagFin}, + {"FIN_WAIT_1", header.TCPFlagAck}, + } { + t.Run(tt.description, func(t *testing.T) { + dut := testbench.NewDUT(t) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) + + // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK. + // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which + // we can skip the next block of code. + sampleData := []byte("Sample Data") + if got, want := dut.Send(t, acceptFD, sampleData, 0), len(sampleData); int(got) != want { + t.Fatalf("got dut.Send(t, %d, %s, 0) = %d, want %d", acceptFD, sampleData, got, want) + } + if _, err := conn.ExpectData(t, &testbench.TCP{}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { + t.Fatalf("expected payload was not received: %s", err) + } + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + + dut.Shutdown(t, acceptFD, unix.SHUT_WR) + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected FINACK from DUT, but got none: %s", err) + } + + // Do not ack the FIN from DUT so that we can test for retransmission. + seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1) + conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(tt.flags)}) + + if tt.flags&header.TCPFlagFin != 0 { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected an ACK to our FIN, but got none: %s", err) + } + } + + if _, err := conn.Expect(t, testbench.TCP{ + SeqNum: seqNumForTheirFIN, + Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck), + }, time.Second); err != nil { + t.Errorf("expected retransmission of FIN from the DUT: %s", err) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_handshake_window_size_test.go b/test/packetimpact/tests/tcp_handshake_window_size_test.go index 5d1266f3c..668e0275c 100644 --- a/test/packetimpact/tests/tcp_handshake_window_size_test.go +++ b/test/packetimpact/tests/tcp_handshake_window_size_test.go @@ -38,8 +38,8 @@ func TestTCPHandshakeWindowSize(t *testing.T) { defer conn.Close(t) // Start handshake with zero window size. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))}) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn), WindowSize: testbench.Uint16(uint16(0))}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK: %s", err) } // Update the advertised window size to a non-zero value with the ACK that @@ -47,7 +47,7 @@ func TestTCPHandshakeWindowSize(t *testing.T) { // // Set the window size with MSB set and expect the dut to treat it as // an unsigned value. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(uint16(1 << 15))}) acceptFd, _ := dut.Accept(t, listenFD) defer dut.Close(t, acceptFd) @@ -59,7 +59,7 @@ func TestTCPHandshakeWindowSize(t *testing.T) { // expect the dut to honor the recently advertised non-zero window // and actually send out the data instead of probing for zero window. dut.Send(t, acceptFd, sampleData, 0) - if _, err := conn.ExpectNextData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil { + if _, err := conn.ExpectNextData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } } diff --git a/test/packetimpact/tests/tcp_info_test.go b/test/packetimpact/tests/tcp_info_test.go index 69275e54b..3fc2c7fe5 100644 --- a/test/packetimpact/tests/tcp_info_test.go +++ b/test/packetimpact/tests/tcp_info_test.go @@ -51,7 +51,7 @@ func TestTCPInfo(t *testing.T) { if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) info := linux.TCPInfo{} infoBytes := dut.GetSockOpt(t, acceptFD, unix.SOL_TCP, unix.TCP_INFO, int32(linux.SizeOfTCPInfo)) diff --git a/test/packetimpact/tests/tcp_linger_test.go b/test/packetimpact/tests/tcp_linger_test.go index bc4b64388..88942904d 100644 --- a/test/packetimpact/tests/tcp_linger_test.go +++ b/test/packetimpact/tests/tcp_linger_test.go @@ -17,7 +17,6 @@ package tcp_linger_test import ( "context" "flag" - "syscall" "testing" "time" @@ -58,10 +57,10 @@ func TestTCPLingerZeroTimeout(t *testing.T) { dut.Close(t, acceptFD) // If the linger timeout is set to zero, the DUT should send a RST. - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { t.Errorf("expected RST-ACK packet within a second but got none: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) } // TestTCPLingerOff tests when SO_LINGER is not set. DUT should send FIN-ACK @@ -75,10 +74,10 @@ func TestTCPLingerOff(t *testing.T) { dut.Close(t, acceptFD) // If SO_LINGER is not set, DUT should send a FIN-ACK. - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) } // TestTCPLingerNonZeroTimeout tests when SO_LINGER is set with non-zero timeout. @@ -115,10 +114,10 @@ func TestTCPLingerNonZeroTimeout(t *testing.T) { t.Errorf("expected close to return within a second, but returned later") } - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) }) } } @@ -166,10 +165,10 @@ func TestTCPLingerSendNonZeroTimeout(t *testing.T) { t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) } - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) }) } } @@ -183,19 +182,19 @@ func TestTCPLingerShutdownZeroTimeout(t *testing.T) { defer closeAll(t, dut, listenFD, conn) dut.SetSockLingerOption(t, acceptFD, 0, true) - dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR) + dut.Shutdown(t, acceptFD, unix.SHUT_RDWR) dut.Close(t, acceptFD) // Shutdown will send FIN-ACK with read/write option. - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) } // If the linger timeout is set to zero, the DUT should send a RST. - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}, time.Second); err != nil { t.Errorf("expected RST-ACK packet within a second but got none: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) } // TestTCPLingerShutdownSendNonZeroTimeout tests SO_LINGER with shutdown() and @@ -220,7 +219,7 @@ func TestTCPLingerShutdownSendNonZeroTimeout(t *testing.T) { sampleData := []byte("Sample Data") dut.Send(t, acceptFD, sampleData, 0) - dut.Shutdown(t, acceptFD, syscall.SHUT_RDWR) + dut.Shutdown(t, acceptFD, unix.SHUT_RDWR) // Increase timeout as Close will take longer time to // return when SO_LINGER is set with non-zero timeout. @@ -243,10 +242,10 @@ func TestTCPLingerShutdownSendNonZeroTimeout(t *testing.T) { t.Fatalf("expected a packet with payload %v: %s", samplePayload, err) } - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { t.Errorf("expected FIN-ACK packet within a second but got none: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) }) } } diff --git a/test/packetimpact/tests/tcp_network_unreachable_test.go b/test/packetimpact/tests/tcp_network_unreachable_test.go index 53dc903e4..5168450ad 100644 --- a/test/packetimpact/tests/tcp_network_unreachable_test.go +++ b/test/packetimpact/tests/tcp_network_unreachable_test.go @@ -17,7 +17,6 @@ package tcp_synsent_reset_test import ( "context" "flag" - "syscall" "testing" "time" @@ -46,12 +45,12 @@ func TestTCPSynSentUnreachable(t *testing.T) { defer cancel() sa := unix.SockaddrInet4{Port: int(port)} copy(sa.Addr[:], dut.Net.LocalIPv4) - if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != unix.EINPROGRESS { t.Errorf("got connect() = %v, want EINPROGRESS", err) } // Get the SYN. - tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, nil, time.Second) if err != nil { t.Fatalf("expected SYN: %s", err) } @@ -100,12 +99,12 @@ func TestTCPSynSentUnreachable6(t *testing.T) { ZoneId: dut.Net.RemoteDevID, } copy(sa.Addr[:], dut.Net.LocalIPv6) - if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != syscall.Errno(unix.EINPROGRESS) { + if _, err := dut.ConnectWithErrno(ctx, t, clientFD, &sa); err != unix.EINPROGRESS { t.Errorf("got connect() = %v, want EINPROGRESS", err) } // Get the SYN. - tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second) + tcpLayers, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, nil, time.Second) if err != nil { t.Fatalf("expected SYN: %s", err) } @@ -156,7 +155,7 @@ func getConnectError(t *testing.T, dut *testbench.DUT, fd int32) error { // failure). dut.PollOne(t, fd, unix.POLLOUT, 10*time.Second) if errno := dut.GetSockOptInt(t, fd, unix.SOL_SOCKET, unix.SO_ERROR); errno != 0 { - return syscall.Errno(errno) + return unix.Errno(errno) } return nil } diff --git a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go index d2871df08..14eb7d93b 100644 --- a/test/packetimpact/tests/tcp_noaccept_close_rst_test.go +++ b/test/packetimpact/tests/tcp_noaccept_close_rst_test.go @@ -40,7 +40,7 @@ func TestTcpNoAcceptCloseReset(t *testing.T) { // it will only respond RST instead of RST+ACK. dut.PollOne(t, listenFd, unix.POLLIN, time.Second) dut.Close(t, listenFd) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}, 1*time.Second); err != nil { t.Fatalf("expected a RST-ACK packet but got none: %s", err) } } diff --git a/test/packetimpact/tests/tcp_outside_the_window_closing_test.go b/test/packetimpact/tests/tcp_outside_the_window_closing_test.go new file mode 100644 index 000000000..1097746c7 --- /dev/null +++ b/test/packetimpact/tests/tcp_outside_the_window_closing_test.go @@ -0,0 +1,86 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_outside_the_window_closing_test + +import ( + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +// TestAckOTWSeqInClosing tests that the DUT should send an ACK with +// the right ACK number when receiving a packet with OTW Seq number +// in CLOSING state. https://tools.ietf.org/html/rfc793#page-69 +func TestAckOTWSeqInClosing(t *testing.T) { + for seqNumOffset := seqnum.Size(0); seqNumOffset < 3; seqNumOffset++ { + for _, tt := range []struct { + description string + flags header.TCPFlags + payloads testbench.Layers + }{ + {"SYN", header.TCPFlagSyn, nil}, + {"SYNACK", header.TCPFlagSyn | header.TCPFlagAck, nil}, + {"ACK", header.TCPFlagAck, nil}, + {"FINACK", header.TCPFlagFin | header.TCPFlagAck, nil}, + {"Data", header.TCPFlagAck, []testbench.Layer{&testbench.Payload{Bytes: []byte("abc123")}}}, + } { + t.Run(fmt.Sprintf("%s%d", tt.description, seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1) + defer dut.Close(t, listenFD) + conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + defer dut.Close(t, acceptFD) + + dut.Shutdown(t, acceptFD, unix.SHUT_WR) + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second); err != nil { + t.Fatalf("expected FINACK from DUT, but got none: %s", err) + } + + // Do not ack the FIN from DUT so that the TCP state on DUT is CLOSING instead of CLOSED. + seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1) + conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}) + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected an ACK to our FIN, but got none: %s", err) + } + + windowSize := seqnum.Size(*conn.SynAck(t).WindowSize) + seqNumOffset + conn.SendFrameStateless(t, conn.CreateFrame(t, testbench.Layers{&testbench.TCP{ + SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))), + AckNum: seqNumForTheirFIN, + Flags: testbench.TCPFlags(tt.flags), + }}, tt.payloads...)) + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected an ACK but got none: %s", err) + } + }) + } + } +} diff --git a/test/packetimpact/tests/tcp_outside_the_window_test.go b/test/packetimpact/tests/tcp_outside_the_window_test.go index 8909a348e..7cd7ff703 100644 --- a/test/packetimpact/tests/tcp_outside_the_window_test.go +++ b/test/packetimpact/tests/tcp_outside_the_window_test.go @@ -37,7 +37,7 @@ func init() { func TestTCPOutsideTheWindow(t *testing.T) { for _, tt := range []struct { description string - tcpFlags uint8 + tcpFlags header.TCPFlags payload []testbench.Layer seqNumOffset seqnum.Size expectACK bool @@ -76,11 +76,11 @@ func TestTCPOutsideTheWindow(t *testing.T) { // to the AckNum. localSeqNum := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) conn.Send(t, testbench.TCP{ - Flags: testbench.Uint8(tt.tcpFlags), + Flags: testbench.TCPFlags(tt.tcpFlags), SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))), }, tt.payload...) - timeout := 3 * time.Second - gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + timeout := time.Second + gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: localSeqNum}, timeout) if tt.expectACK && err != nil { t.Fatalf("expected an ACK packet within %s but got none: %s", timeout, err) } @@ -93,11 +93,11 @@ func TestTCPOutsideTheWindow(t *testing.T) { // has passed since the last ACK was sent. t.Logf("sending another segment") conn.Send(t, testbench.TCP{ - Flags: testbench.Uint8(tt.tcpFlags), + Flags: testbench.TCPFlags(tt.tcpFlags), SeqNum: testbench.Uint32(uint32(conn.LocalSeqNum(t).Add(windowSize))), }, tt.payload...) timeout := 3 * time.Second - gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: localSeqNum}, timeout) + gotACK, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: localSeqNum}, timeout) if err == nil { t.Fatalf("expected no ACK packet but got one: %s", gotACK) } diff --git a/test/packetimpact/tests/tcp_paws_mechanism_test.go b/test/packetimpact/tests/tcp_paws_mechanism_test.go index 24d9ef4ec..9054955ea 100644 --- a/test/packetimpact/tests/tcp_paws_mechanism_test.go +++ b/test/packetimpact/tests/tcp_paws_mechanism_test.go @@ -38,8 +38,8 @@ func TestPAWSMechanism(t *testing.T) { options := make([]byte, header.TCPOptionTSLength) header.EncodeTSOption(currentTS(), 0, options) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn), Options: options}) - synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn), Options: options}) + synAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("didn't get synack during handshake: %s", err) } @@ -49,7 +49,7 @@ func TestPAWSMechanism(t *testing.T) { } tsecr := parsedSynOpts.TSVal header.EncodeTSOption(currentTS(), tsecr, options) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), Options: options}) acceptFD, _ := dut.Accept(t, listenFD) defer dut.Close(t, acceptFD) @@ -60,9 +60,9 @@ func TestPAWSMechanism(t *testing.T) { // every time we send one, it should not cause any flakiness because timestamps // only need to be non-decreasing. time.Sleep(3 * time.Millisecond) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) - gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected an ACK but got none: %s", err) } @@ -85,9 +85,9 @@ func TestPAWSMechanism(t *testing.T) { // 3ms here is chosen arbitrarily and this time.Sleep() should not cause flakiness // due to the exact same reasoning discussed above. time.Sleep(3 * time.Millisecond) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), Options: options}, &testbench.Payload{Bytes: sampleData}) - gotTCP, err = conn.Expect(t, testbench.TCP{AckNum: lastAckNum, Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + gotTCP, err = conn.Expect(t, testbench.TCP{AckNum: lastAckNum, Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected segment with AckNum %d but got none: %s", lastAckNum, err) } diff --git a/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go b/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go index 7dd1c326a..1c8b72ebe 100644 --- a/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go +++ b/test/packetimpact/tests/tcp_queue_send_recv_in_syn_sent_test.go @@ -21,7 +21,6 @@ import ( "errors" "flag" "sync" - "syscall" "testing" "time" @@ -45,10 +44,10 @@ func TestQueueSendInSynSentHandshake(t *testing.T) { sampleData := []byte("Sample Data") dut.SetNonBlocking(t, socket, true) - if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, unix.EINPROGRESS) { t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) } - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, time.Second); err != nil { t.Fatalf("expected a SYN from DUT, but got none: %s", err) } @@ -85,19 +84,19 @@ func TestQueueSendInSynSentHandshake(t *testing.T) { time.Sleep(100 * time.Millisecond) // Bring the connection to Established. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}) // Expect the data from the DUT's enqueued send request. // // On Linux, this can be piggybacked with the ACK completing the // handshake. On gVisor, getting such a piggyback is a bit more // complicated because the actual data enqueuing occurs in the // callers of endpoint Write. - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected an ACK from DUT, but got none: %s", err) } } @@ -113,15 +112,15 @@ func TestQueueRecvInSynSentHandshake(t *testing.T) { sampleData := []byte("Sample Data") dut.SetNonBlocking(t, socket, true) - if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, unix.EINPROGRESS) { t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) } - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, time.Second); err != nil { t.Fatalf("expected a SYN from DUT, but got none: %s", err) } - if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) { - t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err) + if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != unix.EWOULDBLOCK { + t.Fatalf("expected error %s, got %s", unix.EWOULDBLOCK, err) } // Test blocking read. @@ -160,14 +159,14 @@ func TestQueueRecvInSynSentHandshake(t *testing.T) { time.Sleep(100 * time.Millisecond) // Bring the connection to Established. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected an ACK from DUT, but got none: %s", err) } // Send sample payload so that DUT can recv. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagPsh | header.TCPFlagAck)}, &testbench.Payload{Bytes: sampleData}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected an ACK from DUT, but got none: %s", err) } } @@ -183,10 +182,10 @@ func TestQueueSendInSynSentRST(t *testing.T) { sampleData := []byte("Sample Data") dut.SetNonBlocking(t, socket, true) - if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, unix.EINPROGRESS) { t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) } - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, time.Second); err != nil { t.Fatalf("expected a SYN from DUT, but got none: %s", err) } @@ -207,8 +206,8 @@ func TestQueueSendInSynSentRST(t *testing.T) { // Issue SEND call in SYN-SENT, this should be queued for // process until the connection is established. n, err := dut.SendWithErrno(ctx, t, socket, sampleData, 0) - if err != syscall.Errno(unix.ECONNREFUSED) { - t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err) + if err != unix.ECONNREFUSED { + t.Errorf("expected error %s, got %s", unix.ECONNREFUSED, err) } if n != -1 { t.Errorf("expected return value %d, got %d", -1, n) @@ -224,7 +223,7 @@ func TestQueueSendInSynSentRST(t *testing.T) { // request and the system actually being blocked. time.Sleep(100 * time.Millisecond) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}) } // TestQueueRecvInSynSentRST tests recv behavior when the TCP state @@ -238,15 +237,15 @@ func TestQueueRecvInSynSentRST(t *testing.T) { sampleData := []byte("Sample Data") dut.SetNonBlocking(t, socket, true) - if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, syscall.EINPROGRESS) { + if _, err := dut.ConnectWithErrno(context.Background(), t, socket, conn.LocalAddr(t)); !errors.Is(err, unix.EINPROGRESS) { t.Fatalf("failed to bring DUT to SYN-SENT, got: %s, want EINPROGRESS", err) } - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, time.Second); err != nil { t.Fatalf("expected a SYN from DUT, but got none: %s", err) } - if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != syscall.Errno(unix.EWOULDBLOCK) { - t.Fatalf("expected error %s, got %s", syscall.Errno(unix.EWOULDBLOCK), err) + if _, _, err := dut.RecvWithErrno(context.Background(), t, socket, int32(len(sampleData)), 0); err != unix.EWOULDBLOCK { + t.Fatalf("expected error %s, got %s", unix.EWOULDBLOCK, err) } // Test blocking read. @@ -266,8 +265,8 @@ func TestQueueRecvInSynSentRST(t *testing.T) { // Issue RECEIVE call in SYN-SENT, this should be queued for // process until the connection is established. n, _, err := dut.RecvWithErrno(ctx, t, socket, int32(len(sampleData)), 0) - if err != syscall.Errno(unix.ECONNREFUSED) { - t.Errorf("expected error %s, got %s", syscall.Errno(unix.ECONNREFUSED), err) + if err != unix.ECONNREFUSED { + t.Errorf("expected error %s, got %s", unix.ECONNREFUSED, err) } if n != -1 { t.Errorf("expected return value %d, got %d", -1, n) @@ -283,5 +282,5 @@ func TestQueueRecvInSynSentRST(t *testing.T) { // request and the system actually being blocked. time.Sleep(100 * time.Millisecond) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}) } diff --git a/test/packetimpact/tests/tcp_rack_test.go b/test/packetimpact/tests/tcp_rack_test.go index ef902c54d..0a5b0f12b 100644 --- a/test/packetimpact/tests/tcp_rack_test.go +++ b/test/packetimpact/tests/tcp_rack_test.go @@ -97,7 +97,7 @@ func sendAndReceive(t *testing.T, dut testbench.DUT, conn testbench.TCPIPv4, num if sendACK { time.Sleep(simulatedRTT) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(sn))}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(sn))}) } } return lastSent @@ -149,7 +149,7 @@ func TestRACKTLPLost(t *testing.T) { // Cumulative ACK for #[1-5] packets. ackNum := seqNum1.Add(seqnum.Size(6 * payloadSize)) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(ackNum))}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(ackNum))}) // Probe Timeout (PTO) should be two times RTT. Check that the last // packet is retransmitted after probe timeout. @@ -194,7 +194,7 @@ func TestRACKWithSACK(t *testing.T) { sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ start, end, }}, sackBlock[sbOff:]) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) rtt, _ := getRTTAndRTO(t, dut, acceptFd) timeout := 2 * rtt @@ -206,7 +206,7 @@ func TestRACKWithSACK(t *testing.T) { time.Sleep(simulatedRTT) // ACK for #1 packet. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(end))}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(end))}) // RACK considers transmission times of the packets to mark them lost. // As the 3rd packet was sent before the retransmitted 1st packet, RACK @@ -243,7 +243,7 @@ func TestRACKWithoutReorder(t *testing.T) { start, end, }}, sackBlock[sbOff:]) time.Sleep(simulatedRTT) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) // RACK marks #1 and #2 packets as lost and retransmits both after // RTT + reorderWindow. The reorderWindow initially will be a small @@ -289,7 +289,7 @@ func TestRACKWithReorder(t *testing.T) { sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ start, end, }}, sackBlock[sbOff:]) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) } // Send a DSACK block indicating both original and retransmitted @@ -304,7 +304,7 @@ func TestRACKWithReorder(t *testing.T) { dbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ start, end, }}, dsackBlock[dbOff:]) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1 + numPkts*payloadSize)), Options: dsackBlock[:dbOff]}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1 + numPkts*payloadSize)), Options: dsackBlock[:dbOff]}) seqNum1.UpdateForward(seqnum.Size(numPkts * payloadSize)) sendTime := time.Now() @@ -321,7 +321,7 @@ func TestRACKWithReorder(t *testing.T) { sbOff += header.EncodeSACKBlocks([]header.SACKBlock{{ start, end, }}, sackBlock[sbOff:]) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) // Expect the retransmission of #1 packet after RTT+ReorderWindow. if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, time.Second); err != nil { @@ -361,7 +361,7 @@ func TestRACKWithLostRetransmission(t *testing.T) { start, end, }}, sackBlock[sbOff:]) time.Sleep(simulatedRTT) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock[:sbOff]}) // RACK marks #1 packet as lost and retransmits it after // RTT + reorderWindow. The reorderWindow is bounded between a small @@ -394,7 +394,7 @@ func TestRACKWithLostRetransmission(t *testing.T) { start, end, }}, sackBlock1[sbOff1:]) time.Sleep(simulatedRTT) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock1[:sbOff1]}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: testbench.Uint32(uint32(seqNum1)), Options: sackBlock1[:sbOff1]}) // Expect re-retransmission of #1 packet without entering an RTO. if _, err := conn.Expect(t, testbench.TCP{SeqNum: testbench.Uint32(uint32(seqNum1))}, timeout); err != nil { diff --git a/test/packetimpact/tests/tcp_rcv_buf_space_test.go b/test/packetimpact/tests/tcp_rcv_buf_space_test.go index d6ad5cda6..f121d44eb 100644 --- a/test/packetimpact/tests/tcp_rcv_buf_space_test.go +++ b/test/packetimpact/tests/tcp_rcv_buf_space_test.go @@ -17,7 +17,6 @@ package tcp_rcv_buf_space_test import ( "context" "flag" - "syscall" "testing" "golang.org/x/sys/unix" @@ -61,7 +60,7 @@ func TestReduceRecvBuf(t *testing.T) { payloadBytes = l } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, []testbench.Layer{&testbench.Payload{Bytes: payload[:payloadBytes]}}...) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, []testbench.Layer{&testbench.Payload{Bytes: payload[:payloadBytes]}}...) payload = payload[payloadBytes:] } @@ -73,7 +72,7 @@ func TestReduceRecvBuf(t *testing.T) { // Second read should return EAGAIN as the last segment should have been // dropped due to it exceeding the receive buffer space available in the // socket. - if ret, got, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(len(sampleData)), syscall.MSG_DONTWAIT); got != nil || ret != -1 || err != syscall.EAGAIN { + if ret, got, err := dut.RecvWithErrno(context.Background(), t, acceptFd, int32(len(sampleData)), unix.MSG_DONTWAIT); got != nil || ret != -1 || err != unix.EAGAIN { t.Fatalf("expected no packets but got: %s", got) } } diff --git a/test/packetimpact/tests/tcp_retransmits_test.go b/test/packetimpact/tests/tcp_retransmits_test.go index ba79fbf55..3dc8f63ab 100644 --- a/test/packetimpact/tests/tcp_retransmits_test.go +++ b/test/packetimpact/tests/tcp_retransmits_test.go @@ -15,6 +15,7 @@ package tcp_retransmits_test import ( + "bytes" "flag" "testing" "time" @@ -59,38 +60,43 @@ func TestRetransmits(t *testing.T) { sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} + // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK. + // This is to reduce the test run-time from the default initial RTO of 1s. + // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which + // we can skip this data send/recv which is solely to estimate RTO. dut.Send(t, acceptFd, sampleData, 0) if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - // Give a chance for the dut to estimate RTO with RTT from the DATA-ACK. - // TODO(gvisor.dev/issue/2685) Estimate RTO during handshake, after which - // we can skip sending this ACK. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + // Wait for the DUT to receive the data, thus ensuring that the stack has + // estimated RTO before we query RTO via TCP_INFO. + if got := dut.Recv(t, acceptFd, int32(len(sampleData)), 0); !bytes.Equal(got, sampleData) { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %s, want %s", acceptFd, len(sampleData), got, sampleData) + } const timeoutCorrection = time.Second - const diffCorrection = time.Millisecond + const diffCorrection = 200 * time.Millisecond rto := getRTO(t, dut, acceptFd) - timeout := rto + timeoutCorrection - startTime := time.Now() dut.Send(t, acceptFd, sampleData, 0) seq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) - if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, timeout); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, rto+timeoutCorrection); err != nil { t.Fatalf("expected payload was not received: %s", err) } // Expect retransmits of the same segment. for i := 0; i < 5; i++ { - if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, timeout); err != nil { - t.Fatalf("expected payload was not received within %d loop %d err %s", timeout, i, err) + startTime := time.Now() + rto = getRTO(t, dut, acceptFd) + if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: seq}, samplePayload, rto+timeoutCorrection); err != nil { + t.Fatalf("expected payload was not received within %s loop %d err %s", rto+timeoutCorrection, i, err) } - if diff := time.Since(startTime); diff+diffCorrection < rto { - t.Fatalf("retransmit came sooner got: %d want: >= %d probe %d", diff, rto, i) + t.Fatalf("retransmit came sooner got: %s want: >= %s probe %d", diff+diffCorrection, rto, i) } - startTime = time.Now() - rto = getRTO(t, dut, acceptFd) - timeout = rto + timeoutCorrection } } diff --git a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go index 418393796..64b7288fb 100644 --- a/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go +++ b/test/packetimpact/tests/tcp_send_window_sizes_piggyback_test.go @@ -71,7 +71,7 @@ func TestSendWindowSizesPiggyback(t *testing.T) { dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_NODELAY, 1) - expectedTCP := testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)} + expectedTCP := testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)} dut.Send(t, acceptFd, sampleData, 0) expectedPayload := testbench.Payload{Bytes: tt.expectedPayload1} @@ -90,7 +90,7 @@ func TestSendWindowSizesPiggyback(t *testing.T) { // Send ACK for the previous segment along with data for the dut to // receive and ACK back. Sending this ACK would make room for the dut // to transmit any enqueued segment. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(tt.windowSize)}, &testbench.Payload{Bytes: sampleData}) // Expect the dut to piggyback the ACK for received data along with // the segment enqueued for transmit. diff --git a/test/packetimpact/tests/tcp_synrcvd_reset_test.go b/test/packetimpact/tests/tcp_synrcvd_reset_test.go index 32271d7b2..3346d43c4 100644 --- a/test/packetimpact/tests/tcp_synrcvd_reset_test.go +++ b/test/packetimpact/tests/tcp_synrcvd_reset_test.go @@ -37,11 +37,11 @@ func TestTCPSynRcvdReset(t *testing.T) { defer conn.Close(t) // Expect dut connection to have transitioned to SYN-RCVD state. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}) // Expect the connection to have transitioned SYN-RCVD to CLOSED. // @@ -49,8 +49,8 @@ func TestTCPSynRcvdReset(t *testing.T) { // CLOSED. We cannot use TCP_INFO to lookup the state as this is a passive // DUT connection. for i := 0; i < 5; i++ { - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Logf("retransmit%d ACK as we did not get the expected RST, %s", i, err) continue } diff --git a/test/packetimpact/tests/tcp_synsent_reset_test.go b/test/packetimpact/tests/tcp_synsent_reset_test.go index 2c8bb101b..cccb0abc6 100644 --- a/test/packetimpact/tests/tcp_synsent_reset_test.go +++ b/test/packetimpact/tests/tcp_synsent_reset_test.go @@ -42,7 +42,7 @@ func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, uint16, copy(sa.Addr[:], dut.Net.LocalIPv4) // Bring the dut to SYN-SENT state with a non-blocking connect. dut.Connect(t, clientFD, &sa) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}, nil, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}, nil, time.Second); err != nil { t.Fatalf("expected SYN\n") } @@ -53,11 +53,11 @@ func dutSynSentState(t *testing.T) (*testbench.DUT, *testbench.TCPIPv4, uint16, func TestTCPSynSentReset(t *testing.T) { _, conn, _, _ := dutSynSentState(t) defer conn.Close(t) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst | header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst | header.TCPFlagAck)}) // Expect the connection to have closed. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } @@ -73,15 +73,15 @@ func TestTCPSynSentRcvdReset(t *testing.T) { // Initiate new SYN connection with the same port pair // (simultaneous open case), expect the dut connection to move to // SYN-RCVD state - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn)}) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagSyn | header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected SYN-ACK %s\n", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}) // Expect the connection to have transitioned SYN-RCVD to CLOSED. // TODO(gvisor.dev/issue/478): Check for TCP_INFO on the dut side. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } diff --git a/test/packetimpact/tests/tcp_timewait_reset_test.go b/test/packetimpact/tests/tcp_timewait_reset_test.go index d1d2fb83d..89037f0a4 100644 --- a/test/packetimpact/tests/tcp_timewait_reset_test.go +++ b/test/packetimpact/tests/tcp_timewait_reset_test.go @@ -42,26 +42,26 @@ func TestTimeWaitReset(t *testing.T) { // Trigger active close. dut.Close(t, acceptFD) - _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) + _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected a FIN: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) // Send a FIN, DUT should transition to TIME_WAIT from FIN_WAIT2. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected an ACK for our FIN: %s", err) } // Send a RST, the DUT should transition to CLOSED from TIME_WAIT. // This is the default Linux behavior, it can be changed to ignore RSTs via // sysctl net.ipv4.tcp_rfc1337. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) // The DUT should reply with RST to our ACK as the state should have // transitioned to CLOSED. - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, time.Second); err != nil { t.Fatalf("expected a RST: %s", err) } } diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go new file mode 100644 index 000000000..a208210ac --- /dev/null +++ b/test/packetimpact/tests/tcp_unacc_seq_ack_closing_test.go @@ -0,0 +1,94 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_unacc_seq_ack_closing_test + +import ( + "flag" + "fmt" + "testing" + "time" + + "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/test/packetimpact/testbench" +) + +func init() { + testbench.Initialize(flag.CommandLine) +} + +func TestSimultaneousCloseUnaccSeqAck(t *testing.T) { + for _, tt := range []struct { + description string + makeTestingTCP func(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset, windowSize seqnum.Size) testbench.TCP + seqNumOffset seqnum.Size + expectAck bool + }{ + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, expectAck: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, expectAck: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, expectAck: false}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, expectAck: true}, + } { + t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { + dut := testbench.NewDUT(t) + listenFD, remotePort := dut.CreateListener(t, unix.SOCK_STREAM, unix.IPPROTO_TCP, 1 /*backlog*/) + defer dut.Close(t, listenFD) + conn := dut.Net.NewTCPIPv4(t, testbench.TCP{DstPort: &remotePort}, testbench.TCP{SrcPort: &remotePort}) + defer conn.Close(t) + + conn.Connect(t) + acceptFD, _ := dut.Accept(t, listenFD) + + // Trigger active close. + dut.Shutdown(t, acceptFD, unix.SHUT_WR) + + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) + if err != nil { + t.Fatalf("expected a FIN: %s", err) + } + // Do not ack the FIN from DUT so that we get to CLOSING. + seqNumForTheirFIN := testbench.Uint32(uint32(*conn.RemoteSeqNum(t)) - 1) + conn.Send(t, testbench.TCP{AckNum: seqNumForTheirFIN, Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}) + + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { + t.Errorf("expected an ACK to our FIN, but got none: %s", err) + } + + sampleData := []byte("Sample Data") + samplePayload := &testbench.Payload{Bytes: sampleData} + + origSeq := uint32(*conn.LocalSeqNum(t)) + // Send a segment with OTW Seq / unacc ACK. + tcp := tt.makeTestingTCP(t, &conn, tt.seqNumOffset, seqnum.Size(*gotTCP.WindowSize)) + if tt.description == "OTWSeq" { + // If we generate an OTW Seq segment, make sure we don't acknowledge their FIN so that + // we stay in CLOSING. + tcp.AckNum = seqNumForTheirFIN + } + conn.Send(t, tcp, samplePayload) + + got, err := conn.Expect(t, testbench.TCP{AckNum: testbench.Uint32(origSeq), Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) + if tt.expectAck && err != nil { + t.Errorf("expected an ack in CLOSING state, but got none: %s", err) + } + if !tt.expectAck && got != nil { + t.Errorf("expected no ack in CLOSING state, but got one: %s", got) + } + }) + } +} diff --git a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go index ea962c818..ce0a26171 100644 --- a/test/packetimpact/tests/tcp_unacc_seq_ack_test.go +++ b/test/packetimpact/tests/tcp_unacc_seq_ack_test.go @@ -17,7 +17,6 @@ package tcp_unacc_seq_ack_test import ( "flag" "fmt" - "syscall" "testing" "time" @@ -39,12 +38,12 @@ func TestEstablishedUnaccSeqAck(t *testing.T) { expectAck bool restoreSeq bool }{ - {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: true, restoreSeq: true}, - {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true, restoreSeq: true}, - {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true, restoreSeq: true}, - {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: true, restoreSeq: false}, - {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: false, restoreSeq: true}, - {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: false, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, expectAck: true, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, expectAck: true, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, expectAck: true, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, expectAck: true, restoreSeq: false}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, expectAck: false, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, expectAck: false, restoreSeq: true}, } { t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { dut := testbench.NewDUT(t) @@ -59,8 +58,8 @@ func TestEstablishedUnaccSeqAck(t *testing.T) { sampleData := []byte("Sample Data") samplePayload := &testbench.Payload{Bytes: sampleData} - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected ack %s", err) } @@ -74,7 +73,7 @@ func TestEstablishedUnaccSeqAck(t *testing.T) { // ACK matches the TCP layer state. *conn.LocalSeqNum(t) = origSeq } - gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) if tt.expectAck && err != nil { t.Fatalf("expected an ack but got none: %s", err) } @@ -92,12 +91,12 @@ func TestPassiveCloseUnaccSeqAck(t *testing.T) { seqNumOffset seqnum.Size expectAck bool }{ - {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, expectAck: false}, - {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, expectAck: true}, - {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, expectAck: true}, - {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, expectAck: false}, - {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, expectAck: true}, - {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, expectAck: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, expectAck: false}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, expectAck: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, expectAck: false}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, expectAck: true}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, expectAck: true}, } { t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { dut := testbench.NewDUT(t) @@ -110,8 +109,8 @@ func TestPassiveCloseUnaccSeqAck(t *testing.T) { acceptFD, _ := dut.Accept(t, listenFD) // Send a FIN to DUT to intiate the passive close. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}) - gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagFin)}) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected an ACK for our fin and DUT should enter CLOSE_WAIT: %s", err) } @@ -122,7 +121,7 @@ func TestPassiveCloseUnaccSeqAck(t *testing.T) { // Send a segment with OTW Seq / unacc ACK. conn.Send(t, tt.makeTestingTCP(t, &conn, tt.seqNumOffset, windowSize), samplePayload) - gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + gotAck, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) if tt.expectAck && err != nil { t.Errorf("expected an ack but got none: %s", err) } @@ -132,14 +131,14 @@ func TestPassiveCloseUnaccSeqAck(t *testing.T) { // Now let's verify DUT is indeed in CLOSE_WAIT dut.Close(t, acceptFD) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagFin)}, time.Second); err != nil { t.Fatalf("expected DUT to send a FIN: %s", err) } // Ack the FIN from DUT - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) // Send some extra data to DUT - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, samplePayload) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, samplePayload) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, time.Second); err != nil { t.Fatalf("expected DUT to send an RST: %s", err) } }) @@ -153,12 +152,12 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) { seqNumOffset seqnum.Size restoreSeq bool }{ - {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 0, restoreSeq: true}, - {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 1, restoreSeq: true}, - {description: "OTWSeq", makeTestingTCP: generateOTWSeqSegment, seqNumOffset: 2, restoreSeq: true}, - {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 0, restoreSeq: false}, - {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 1, restoreSeq: true}, - {description: "UnaccAck", makeTestingTCP: generateUnaccACKSegment, seqNumOffset: 2, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 0, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 1, restoreSeq: true}, + {description: "OTWSeq", makeTestingTCP: testbench.GenerateOTWSeqSegment, seqNumOffset: 2, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 0, restoreSeq: false}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 1, restoreSeq: true}, + {description: "UnaccAck", makeTestingTCP: testbench.GenerateUnaccACKSegment, seqNumOffset: 2, restoreSeq: true}, } { t.Run(fmt.Sprintf("%s:offset=%d", tt.description, tt.seqNumOffset), func(t *testing.T) { dut := testbench.NewDUT(t) @@ -171,14 +170,14 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) { acceptFD, _ := dut.Accept(t, listenFD) // Trigger active close. - dut.Shutdown(t, acceptFD, syscall.SHUT_WR) + dut.Shutdown(t, acceptFD, unix.SHUT_WR) // Get to FIN_WAIT2 - gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected a FIN: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) sendUnaccSeqAck := func(state string) { t.Helper() @@ -193,7 +192,7 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) { // incoming ACK matches the TCP layer state. *conn.LocalSeqNum(t) = origSeq } - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { t.Errorf("expected an ack in %s state, but got none: %s", state, err) } } @@ -201,8 +200,8 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) { sendUnaccSeqAck("FIN_WAIT2") // Send a FIN to DUT to get to TIME_WAIT - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagFin | header.TCPFlagAck)}) - if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagFin | header.TCPFlagAck)}) + if _, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second); err != nil { t.Fatalf("expected an ACK for our fin and DUT should enter TIME_WAIT: %s", err) } @@ -210,22 +209,3 @@ func TestActiveCloseUnaccpSeqAck(t *testing.T) { }) } } - -// generateOTWSeqSegment generates an segment with -// seqnum = RCV.NXT + RCV.WND + seqNumOffset, the generated segment is only -// acceptable when seqNumOffset is 0, otherwise an ACK is expected from the -// receiver. -func generateOTWSeqSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.LocalSeqNum(t).Add(windowSize) - otwSeq := uint32(lastAcceptable.Add(seqNumOffset)) - return testbench.TCP{SeqNum: testbench.Uint32(otwSeq), Flags: testbench.Uint8(header.TCPFlagAck)} -} - -// generateUnaccACKSegment generates an segment with -// acknum = SND.NXT + seqNumOffset, the generated segment is only acceptable -// when seqNumOffset is 0, otherwise an ACK is expected from the receiver. -func generateUnaccACKSegment(t *testing.T, conn *testbench.TCPIPv4, seqNumOffset seqnum.Size, windowSize seqnum.Size) testbench.TCP { - lastAcceptable := conn.RemoteSeqNum(t) - unaccAck := uint32(lastAcceptable.Add(seqNumOffset)) - return testbench.TCP{AckNum: testbench.Uint32(unaccAck), Flags: testbench.Uint8(header.TCPFlagAck)} -} diff --git a/test/packetimpact/tests/tcp_user_timeout_test.go b/test/packetimpact/tests/tcp_user_timeout_test.go index b16e65366..ef38bd738 100644 --- a/test/packetimpact/tests/tcp_user_timeout_test.go +++ b/test/packetimpact/tests/tcp_user_timeout_test.go @@ -35,7 +35,7 @@ func sendPayload(t *testing.T, conn *testbench.TCPIPv4, dut *testbench.DUT, fd i } conn.Drain(t) dut.Send(t, fd, sampleData, 0) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, &testbench.Payload{Bytes: sampleData}, time.Second); err != nil { t.Fatalf("expected data but got none: %w", err) } } @@ -79,14 +79,14 @@ func TestTCPUserTimeout(t *testing.T) { time.Sleep(tt.sendDelay) conn.Drain(t) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) // If TCP_USER_TIMEOUT was set and the above delay was longer than the // TCP_USER_TIMEOUT then the DUT should send a RST in response to the // testbench's packet. expectRST := tt.userTimeout != 0 && tt.sendDelay > tt.userTimeout expectTimeout := 5 * time.Second - got, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, expectTimeout) + got, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, expectTimeout) if expectRST && err != nil { t.Errorf("expected RST packet within %s but got none: %s", expectTimeout, err) } diff --git a/test/packetimpact/tests/tcp_window_shrink_test.go b/test/packetimpact/tests/tcp_window_shrink_test.go index 093484721..0d65a2ea2 100644 --- a/test/packetimpact/tests/tcp_window_shrink_test.go +++ b/test/packetimpact/tests/tcp_window_shrink_test.go @@ -48,7 +48,7 @@ func TestWindowShrink(t *testing.T) { if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) dut.Send(t, acceptFd, sampleData, 0) dut.Send(t, acceptFd, sampleData, 0) @@ -59,7 +59,7 @@ func TestWindowShrink(t *testing.T) { t.Fatalf("expected payload was not received: %s", err) } // We close our receiving window here - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) dut.Send(t, acceptFd, []byte("Sample Data"), 0) // Note: There is another kind of zero-window probing which Windows uses (by sending one diff --git a/test/packetimpact/tests/tcp_zero_receive_window_test.go b/test/packetimpact/tests/tcp_zero_receive_window_test.go index d06690705..d73495454 100644 --- a/test/packetimpact/tests/tcp_zero_receive_window_test.go +++ b/test/packetimpact/tests/tcp_zero_receive_window_test.go @@ -49,8 +49,8 @@ func TestZeroReceiveWindow(t *testing.T) { // Expect the DUT to eventually advertise zero receive window. // The test would timeout otherwise. for readOnce := false; ; { - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -100,8 +100,8 @@ func TestNonZeroReceiveWindow(t *testing.T) { // we sent. Once we have received ACKs with non-zero receive windows, we break // the loop. for { - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, time.Second) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + gotTCP, err := conn.Expect(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, time.Second) if err != nil { t.Fatalf("expected packet was not received: %s", err) } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go index d094c10eb..22b17a39e 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_retransmit_test.go @@ -15,6 +15,7 @@ package tcp_zero_window_probe_retransmit_test import ( + "bytes" "flag" "testing" "time" @@ -51,18 +52,23 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { - t.Fatalf("expected packet was not received: %s", err) - } // Check for the dut to keep the connection alive as long as the zero window // probes are acknowledged. Check if the zero window probes are sent at // exponentially increasing intervals. The timeout intervals are function // of the recorded first zero probe transmission duration. // - // Advertize zero receive window again. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + // Advertize zero receive window along with a payload. + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh), WindowSize: testbench.Uint16(0)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, nil, time.Second); err != nil { + t.Fatalf("expected packet was not received: %s", err) + } + // Wait for the payload to be received by the DUT, which is also an + // indication of receive of the peer window advertisement. + if got := dut.Recv(t, acceptFd, int32(len(sampleData)), 0); !bytes.Equal(got, sampleData) { + t.Fatalf("got dut.Recv(t, %d, %d, 0) = %s, want %s", acceptFd, len(sampleData), got, sampleData) + } + probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) ackProbe := testbench.Uint32(uint32(*conn.RemoteSeqNum(t))) @@ -98,10 +104,10 @@ func TestZeroWindowProbeRetransmit(t *testing.T) { } prev = got // Acknowledge the zero-window probes from the dut. - conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) } // Advertize non-zero window. - conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.TCPFlags(header.TCPFlagAck)}) // Expect the dut to recover and transmit data. if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) diff --git a/test/packetimpact/tests/tcp_zero_window_probe_test.go b/test/packetimpact/tests/tcp_zero_window_probe_test.go index 650a569cc..8b90fcbe9 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_test.go @@ -53,8 +53,8 @@ func TestZeroWindowProbe(t *testing.T) { t.Fatalf("expected payload was not received: %s", err) } sendTime := time.Now().Sub(start) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -62,7 +62,7 @@ func TestZeroWindowProbe(t *testing.T) { // probe to be sent. // // Advertize zero window to the dut. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Expected sequence number of the zero window probe. probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) @@ -93,7 +93,7 @@ func TestZeroWindowProbe(t *testing.T) { // and sends out the sample payload after the send window opens. // // Advertize non-zero window to the dut and ack the zero window probe. - conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.Uint8(header.TCPFlagAck)}) + conn.Send(t, testbench.TCP{AckNum: ackProbe, Flags: testbench.TCPFlags(header.TCPFlagAck)}) // Expect the dut to recover and transmit data. if _, err := conn.ExpectData(t, &testbench.TCP{SeqNum: ackProbe}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) @@ -104,8 +104,8 @@ func TestZeroWindowProbe(t *testing.T) { // Basically with sequence number to one byte behind the unacknowledged // sequence number. p := testbench.Uint32(uint32(*conn.LocalSeqNum(t))) - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))}) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), SeqNum: testbench.Uint32(uint32(*conn.LocalSeqNum(t) - 1))}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), AckNum: p}, nil, time.Second); err != nil { t.Fatalf("expected a packet with ack number: %d: %s", p, err) } } diff --git a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go index 079fea68c..1ce4d22b7 100644 --- a/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go +++ b/test/packetimpact/tests/tcp_zero_window_probe_usertimeout_test.go @@ -51,8 +51,8 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) { if _, err := conn.ExpectData(t, &testbench.TCP{}, samplePayload, time.Second); err != nil { t.Fatalf("expected payload was not received: %s", err) } - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck | header.TCPFlagPsh)}, samplePayload) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}, nil, time.Second); err != nil { t.Fatalf("expected packet was not received: %s", err) } @@ -60,7 +60,7 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) { // probe to be sent. // // Advertize zero window to the dut. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Expected sequence number of the zero window probe. probeSeq := testbench.Uint32(uint32(*conn.RemoteSeqNum(t) - 1)) @@ -81,7 +81,7 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) { // Reduce the retransmit timeout. dut.SetSockOptInt(t, acceptFd, unix.IPPROTO_TCP, unix.TCP_USER_TIMEOUT, int32(startProbeDuration.Milliseconds())) // Advertize zero window again. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck), WindowSize: testbench.Uint16(0)}) // Ask the dut to send out data that would trigger zero window probe retransmissions. dut.Send(t, acceptFd, sampleData, 0) @@ -90,8 +90,8 @@ func TestZeroWindowProbeUserTimeout(t *testing.T) { // Expect the connection to have timed out and closed which would cause the dut // to reply with a RST to the ACK we send. - conn.Send(t, testbench.TCP{Flags: testbench.Uint8(header.TCPFlagAck)}) - if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.Uint8(header.TCPFlagRst)}, nil, time.Second); err != nil { + conn.Send(t, testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagAck)}) + if _, err := conn.ExpectData(t, &testbench.TCP{Flags: testbench.TCPFlags(header.TCPFlagRst)}, nil, time.Second); err != nil { t.Fatalf("expected a TCP RST") } } diff --git a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go index 52c6f9d91..f63cfcc9a 100644 --- a/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go +++ b/test/packetimpact/tests/udp_discard_mcast_source_addr_test.go @@ -19,7 +19,6 @@ import ( "flag" "fmt" "net" - "syscall" "testing" "golang.org/x/sys/unix" @@ -56,7 +55,7 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV4(t *testing.T) { ) ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) - if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + if errno != unix.EAGAIN || errno != unix.EWOULDBLOCK { t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) } }) @@ -86,7 +85,7 @@ func TestDiscardsUDPPacketsWithMcastSourceAddressV6(t *testing.T) { &testbench.Payload{Bytes: []byte("test payload")}, ) ret, payload, errno := dut.RecvWithErrno(context.Background(), t, remoteFD, 100, 0) - if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + if errno != unix.EAGAIN || errno != unix.EWOULDBLOCK { t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, payload, errno) } }) diff --git a/test/packetimpact/tests/udp_icmp_error_propagation_test.go b/test/packetimpact/tests/udp_icmp_error_propagation_test.go index 3fca8c7a3..3159d5b89 100644 --- a/test/packetimpact/tests/udp_icmp_error_propagation_test.go +++ b/test/packetimpact/tests/udp_icmp_error_propagation_test.go @@ -20,7 +20,6 @@ import ( "fmt" "net" "sync" - "syscall" "testing" "time" @@ -86,16 +85,16 @@ type testData struct { remotePort uint16 cleanFD int32 cleanPort uint16 - wantErrno syscall.Errno + wantErrno unix.Errno } // wantErrno computes the errno to expect given the connection mode of a UDP // socket and the ICMP error it will receive. -func wantErrno(c connectionMode, icmpErr icmpError) syscall.Errno { +func wantErrno(c connectionMode, icmpErr icmpError) unix.Errno { if c && icmpErr == portUnreachable { - return syscall.Errno(unix.ECONNREFUSED) + return unix.ECONNREFUSED } - return syscall.Errno(0) + return unix.Errno(0) } // sendICMPError sends an ICMP error message in response to a UDP datagram. @@ -123,7 +122,7 @@ func sendICMPError(t *testing.T, conn *testbench.UDPIPv4, icmpErr icmpError, udp conn.SendFrameStateless(t, layers) } -// testRecv tests observing the ICMP error through the recv syscall. A packet +// testRecv tests observing the ICMP error through the recv unix. A packet // is sent to the DUT, and if wantErrno is non-zero, then the first recv should // fail and the second should succeed. Otherwise if wantErrno is zero then the // first recv should succeed immediately. @@ -136,7 +135,7 @@ func testRecv(ctx context.Context, t *testing.T, d testData) { d.conn.Send(t, testbench.UDP{}) - if d.wantErrno != syscall.Errno(0) { + if d.wantErrno != unix.Errno(0) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() ret, _, err := d.dut.RecvWithErrno(ctx, t, d.remoteFD, 100, 0) @@ -162,7 +161,7 @@ func testSendTo(ctx context.Context, t *testing.T, d testData) { t.Fatalf("did not receive UDP packet from clean socket on DUT: %s", err) } - if d.wantErrno != syscall.Errno(0) { + if d.wantErrno != unix.Errno(0) { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() ret, err := d.dut.SendToWithErrno(ctx, t, d.remoteFD, nil, 0, d.conn.LocalAddr(t)) @@ -183,11 +182,11 @@ func testSendTo(ctx context.Context, t *testing.T, d testData) { func testSockOpt(_ context.Context, t *testing.T, d testData) { // Check that there's no pending error on the clean socket. - if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != syscall.Errno(0) { + if errno := unix.Errno(d.dut.GetSockOptInt(t, d.cleanFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != unix.Errno(0) { t.Fatalf("unexpected error (%[1]d) %[1]v on clean socket", errno) } - if errno := syscall.Errno(d.dut.GetSockOptInt(t, d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno { + if errno := unix.Errno(d.dut.GetSockOptInt(t, d.remoteFD, unix.SOL_SOCKET, unix.SO_ERROR)); errno != d.wantErrno { t.Fatalf("SO_ERROR sockopt after ICMP error is (%[1]d) %[1]v, expected (%[2]d) %[2]v", errno, d.wantErrno) } @@ -310,7 +309,7 @@ func TestICMPErrorDuringUDPRecv(t *testing.T) { go func() { defer wg.Done() - if wantErrno != syscall.Errno(0) { + if wantErrno != unix.Errno(0) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() diff --git a/test/packetimpact/tests/udp_send_recv_dgram_test.go b/test/packetimpact/tests/udp_send_recv_dgram_test.go index 894d156cf..230b012c7 100644 --- a/test/packetimpact/tests/udp_send_recv_dgram_test.go +++ b/test/packetimpact/tests/udp_send_recv_dgram_test.go @@ -19,7 +19,6 @@ import ( "flag" "fmt" "net" - "syscall" "testing" "time" @@ -241,7 +240,7 @@ func TestUDP(t *testing.T) { }, ) ret, recvPayload, errno := dut.RecvWithErrno(context.Background(), t, socketFD, 100, 0) - if errno != syscall.EAGAIN || errno != syscall.EWOULDBLOCK { + if errno != unix.EAGAIN || errno != unix.EWOULDBLOCK { t.Errorf("Recv got unexpected result, ret=%d, payload=%q, errno=%s", ret, recvPayload, errno) } } diff --git a/test/perf/BUILD b/test/perf/BUILD index e25f090ae..ed899ac22 100644 --- a/test/perf/BUILD +++ b/test/perf/BUILD @@ -1,4 +1,3 @@ -load("//tools:defs.bzl", "more_shards") load("//test/runner:defs.bzl", "syscall_test") package(licenses = ["notice"]) @@ -38,7 +37,6 @@ syscall_test( syscall_test( size = "enormous", debug = False, - shard_count = more_shards, tags = ["nogotsan"], test = "//test/perf/linux:getdents_benchmark", ) diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go index 38e57d62f..2ad5f58ef 100644 --- a/test/runner/gtest/gtest.go +++ b/test/runner/gtest/gtest.go @@ -35,6 +35,39 @@ var ( filterBenchmarkFlag = "--benchmark_filter" ) +// BuildTestArgs builds arguments to be passed to the test binary to execute +// only the test cases in `indices`. +func BuildTestArgs(indices []int, testCases []TestCase) []string { + var testFilter, benchFilter string + for _, tci := range indices { + tc := testCases[tci] + if tc.all { + // No argument will make all tests run. + return nil + } + if tc.benchmark { + if len(benchFilter) > 0 { + benchFilter += "|" + } + benchFilter += "^" + tc.Name + "$" + } else { + if len(testFilter) > 0 { + testFilter += ":" + } + testFilter += tc.FullName() + } + } + + var args []string + if len(testFilter) > 0 { + args = append(args, fmt.Sprintf("%s=%s", filterTestFlag, testFilter)) + } + if len(benchFilter) > 0 { + args = append(args, fmt.Sprintf("%s=%s", filterBenchmarkFlag, benchFilter)) + } + return args +} + // TestCase is a single gtest test case. type TestCase struct { // Suite is the suite for this test. @@ -59,22 +92,6 @@ func (tc TestCase) FullName() string { return fmt.Sprintf("%s.%s", tc.Suite, tc.Name) } -// Args returns arguments to be passed when invoking the test. -func (tc TestCase) Args() []string { - if tc.all { - return []string{} // No arguments. - } - if tc.benchmark { - return []string{ - fmt.Sprintf("%s=^%s$", filterBenchmarkFlag, tc.Name), - fmt.Sprintf("%s=", filterTestFlag), - } - } - return []string{ - fmt.Sprintf("%s=%s", filterTestFlag, tc.FullName()), - } -} - // ParseTestCases calls a gtest test binary to list its test and returns a // slice with the name and suite of each test. // @@ -90,6 +107,7 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes // We failed to list tests with the given flags. Just // return something that will run the binary with no // flags, which should execute all tests. + fmt.Printf("failed to get test list: %v\n", err) return []TestCase{ { Suite: "Default", diff --git a/test/runner/runner.go b/test/runner/runner.go index e72c59200..a8a134fe2 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -26,7 +26,6 @@ import ( "path/filepath" "strings" "syscall" - "testing" "time" specs "github.com/opencontainers/runtime-spec/specs-go" @@ -57,13 +56,82 @@ var ( leakCheck = flag.Bool("leak-check", false, "check for reference leaks") ) +func main() { + flag.Parse() + if flag.NArg() != 1 { + fatalf("test must be provided") + } + + log.SetLevel(log.Info) + if *debug { + log.SetLevel(log.Debug) + } + + if *platform != "native" && *runscPath == "" { + if err := testutil.ConfigureExePath(); err != nil { + panic(err.Error()) + } + *runscPath = specutils.ExePath + } + + // Make sure stdout and stderr are opened with O_APPEND, otherwise logs + // from outside the sandbox can (and will) stomp on logs from inside + // the sandbox. + for _, f := range []*os.File{os.Stdout, os.Stderr} { + flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0) + if err != nil { + fatalf("error getting file flags for %v: %v", f, err) + } + if flags&unix.O_APPEND == 0 { + flags |= unix.O_APPEND + if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil { + fatalf("error setting file flags for %v: %v", f, err) + } + } + } + + // Resolve the absolute path for the binary. + testBin, err := filepath.Abs(flag.Args()[0]) + if err != nil { + fatalf("Abs(%q) failed: %v", flag.Args()[0], err) + } + + // Get all test cases in each binary. + testCases, err := gtest.ParseTestCases(testBin, true) + if err != nil { + fatalf("ParseTestCases(%q) failed: %v", testBin, err) + } + + // Get subset of tests corresponding to shard. + indices, err := testutil.TestIndicesForShard(len(testCases)) + if err != nil { + fatalf("TestsForShard() failed: %v", err) + } + if len(indices) == 0 { + log.Warningf("No tests to run in this shard") + return + } + args := gtest.BuildTestArgs(indices, testCases) + + switch *platform { + case "native": + if err := runTestCaseNative(testBin, args); err != nil { + fatalf(err.Error()) + } + default: + if err := runTestCaseRunsc(testBin, args); err != nil { + fatalf(err.Error()) + } + } +} + // runTestCaseNative runs the test case directly on the host machine. -func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { +func runTestCaseNative(testBin string, args []string) error { // These tests might be running in parallel, so make sure they have a // unique test temp dir. tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") if err != nil { - t.Fatalf("could not create temp dir: %v", err) + return fmt.Errorf("could not create temp dir: %v", err) } defer os.RemoveAll(tmpDir) @@ -84,12 +152,12 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { } // Remove shard env variables so that the gunit binary does not try to // interpret them. - env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) + env = filterEnv(env, "TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS") if *addUDSTree { socketDir, cleanup, err := uds.CreateSocketTree("/tmp") if err != nil { - t.Fatalf("failed to create socket tree: %v", err) + return fmt.Errorf("failed to create socket tree: %v", err) } defer cleanup() @@ -99,24 +167,25 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { env = append(env, "TEST_UDS_ATTACH_TREE="+socketDir) } - cmd := exec.Command(testBin, tc.Args()...) + cmd := exec.Command(testBin, args...) cmd.Env = env cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - cmd.SysProcAttr = &syscall.SysProcAttr{} + cmd.SysProcAttr = &unix.SysProcAttr{} if specutils.HasCapabilities(capability.CAP_SYS_ADMIN) { - cmd.SysProcAttr.Cloneflags |= syscall.CLONE_NEWUTS + cmd.SysProcAttr.Cloneflags |= unix.CLONE_NEWUTS } if specutils.HasCapabilities(capability.CAP_NET_ADMIN) { - cmd.SysProcAttr.Cloneflags |= syscall.CLONE_NEWNET + cmd.SysProcAttr.Cloneflags |= unix.CLONE_NEWNET } if err := cmd.Run(); err != nil { ws := err.(*exec.ExitError).Sys().(syscall.WaitStatus) - t.Errorf("test %q exited with status %d, want 0", tc.FullName(), ws.ExitStatus()) + return fmt.Errorf("test exited with status %d, want 0", ws.ExitStatus()) } + return nil } // runRunsc runs spec in runsc in a standard test configuration. @@ -124,7 +193,7 @@ func runTestCaseNative(testBin string, tc gtest.TestCase, t *testing.T) { // runsc logs will be saved to a path in TEST_UNDECLARED_OUTPUTS_DIR. // // Returns an error if the sandboxed application exits non-zero. -func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { +func runRunsc(spec *specs.Spec) error { bundleDir, cleanup, err := testutil.SetupBundleDir(spec) if err != nil { return fmt.Errorf("SetupBundleDir failed: %v", err) @@ -137,9 +206,8 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { } defer cleanup() - name := tc.FullName() id := testutil.RandomContainerID() - log.Infof("Running test %q in container %q", name, id) + log.Infof("Running test in container %q", id) specutils.LogSpec(spec) args := []string{ @@ -148,7 +216,7 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { "-log-format=text", "-TESTONLY-unsafe-nonroot=true", "-net-raw=true", - fmt.Sprintf("-panic-signal=%d", syscall.SIGTERM), + fmt.Sprintf("-panic-signal=%d", unix.SIGTERM), "-watchdog-action=panic", "-platform", *platform, "-file-access", *fileAccess, @@ -175,13 +243,8 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { args = append(args, "-ref-leak-mode=log-names") } - testLogDir := "" - if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { - // Create log directory dedicated for this test. - testLogDir = filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1)) - if err := os.MkdirAll(testLogDir, 0755); err != nil { - return fmt.Errorf("could not create test dir: %v", err) - } + testLogDir := os.Getenv("TEST_UNDECLARED_OUTPUTS_DIR") + if len(testLogDir) > 0 { debugLogDir, err := ioutil.TempDir(testLogDir, "runsc") if err != nil { return fmt.Errorf("could not create temp dir: %v", err) @@ -200,8 +263,8 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { // as root inside that namespace to get it. rArgs := append(args, "run", "--bundle", bundleDir, id) cmd := exec.Command(*runscPath, rArgs...) - cmd.SysProcAttr = &syscall.SysProcAttr{ - Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNS, + cmd.SysProcAttr = &unix.SysProcAttr{ + Cloneflags: unix.CLONE_NEWUSER | unix.CLONE_NEWNS, // Set current user/group as root inside the namespace. UidMappings: []syscall.SysProcIDMap{ {ContainerID: 0, HostID: os.Getuid(), Size: 1}, @@ -219,14 +282,14 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { cmd.Stderr = os.Stderr sig := make(chan os.Signal, 1) defer close(sig) - signal.Notify(sig, syscall.SIGTERM) + signal.Notify(sig, unix.SIGTERM) defer signal.Stop(sig) go func() { s, ok := <-sig if !ok { return } - log.Warningf("%s: Got signal: %v", name, s) + log.Warningf("Got signal: %v", s) done := make(chan bool, 1) dArgs := append([]string{}, args...) dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) @@ -247,7 +310,7 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { log.Warningf("Send SIGTERM to the sandbox process") dArgs = append(args, "debug", - fmt.Sprintf("--signal=%d", syscall.SIGTERM), + fmt.Sprintf("--signal=%d", unix.SIGTERM), id) signal := exec.Command(*runscPath, dArgs...) signal.Stdout = os.Stdout @@ -259,7 +322,7 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { if err == nil && len(testLogDir) > 0 { // If the test passed, then we erase the log directory. This speeds up // uploading logs in continuous integration & saves on disk space. - os.RemoveAll(testLogDir) + _ = os.RemoveAll(testLogDir) } return err @@ -314,10 +377,10 @@ func setupUDSTree(spec *specs.Spec) (cleanup func(), err error) { } // runsTestCaseRunsc runs the test case in runsc. -func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { +func runTestCaseRunsc(testBin string, args []string) error { // Run a new container with the test executable and filter for the // given test suite and name. - spec := testutil.NewSpecWithArgs(append([]string{testBin}, tc.Args()...)...) + spec := testutil.NewSpecWithArgs(append([]string{testBin}, args...)...) // Mark the root as writeable, as some tests attempt to // write to the rootfs, and expect EACCES, not EROFS. @@ -343,12 +406,12 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // users, so make sure it is world-accessible. tmpDir, err := ioutil.TempDir(testutil.TmpDir(), "") if err != nil { - t.Fatalf("could not create temp dir: %v", err) + return fmt.Errorf("could not create temp dir: %v", err) } defer os.RemoveAll(tmpDir) if err := os.Chmod(tmpDir, 0777); err != nil { - t.Fatalf("could not chmod temp dir: %v", err) + return fmt.Errorf("could not chmod temp dir: %v", err) } // "/tmp" is not replaced with a tmpfs mount inside the sandbox @@ -368,13 +431,12 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Set environment variables that indicate we are running in gVisor with // the given platform, network, and filesystem stack. - platformVar := "TEST_ON_GVISOR" - networkVar := "GVISOR_NETWORK" - env := append(os.Environ(), platformVar+"="+*platform, networkVar+"="+*network) - vfsVar := "GVISOR_VFS" + env := []string{"TEST_ON_GVISOR=" + *platform, "GVISOR_NETWORK=" + *network} + env = append(env, os.Environ()...) + const vfsVar = "GVISOR_VFS" if *vfs2 { env = append(env, vfsVar+"=VFS2") - fuseVar := "FUSE_ENABLED" + const fuseVar = "FUSE_ENABLED" if *fuse { env = append(env, fuseVar+"=TRUE") } else { @@ -386,11 +448,11 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { // Remove shard env variables so that the gunit binary does not try to // interpret them. - env = filterEnv(env, []string{"TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS"}) + env = filterEnv(env, "TEST_SHARD_INDEX", "TEST_TOTAL_SHARDS", "GTEST_SHARD_INDEX", "GTEST_TOTAL_SHARDS") // Set TEST_TMPDIR to /tmp, as some of the syscall tests require it to // be backed by tmpfs. - env = filterEnv(env, []string{"TEST_TMPDIR"}) + env = filterEnv(env, "TEST_TMPDIR") env = append(env, fmt.Sprintf("TEST_TMPDIR=%s", testTmpDir)) spec.Process.Env = env @@ -398,18 +460,19 @@ func runTestCaseRunsc(testBin string, tc gtest.TestCase, t *testing.T) { if *addUDSTree { cleanup, err := setupUDSTree(spec) if err != nil { - t.Fatalf("error creating UDS tree: %v", err) + return fmt.Errorf("error creating UDS tree: %v", err) } defer cleanup() } - if err := runRunsc(tc, spec); err != nil { - t.Errorf("test %q failed with error %v, want nil", tc.FullName(), err) + if err := runRunsc(spec); err != nil { + return fmt.Errorf("test failed with error %v, want nil", err) } + return nil } // filterEnv returns an environment with the excluded variables removed. -func filterEnv(env, exclude []string) []string { +func filterEnv(env []string, exclude ...string) []string { var out []string for _, kv := range env { ok := true @@ -430,82 +493,3 @@ func fatalf(s string, args ...interface{}) { fmt.Fprintf(os.Stderr, s+"\n", args...) os.Exit(1) } - -func matchString(a, b string) (bool, error) { - return a == b, nil -} - -func main() { - flag.Parse() - if flag.NArg() != 1 { - fatalf("test must be provided") - } - testBin := flag.Args()[0] // Only argument. - - log.SetLevel(log.Info) - if *debug { - log.SetLevel(log.Debug) - } - - if *platform != "native" && *runscPath == "" { - if err := testutil.ConfigureExePath(); err != nil { - panic(err.Error()) - } - *runscPath = specutils.ExePath - } - - // Make sure stdout and stderr are opened with O_APPEND, otherwise logs - // from outside the sandbox can (and will) stomp on logs from inside - // the sandbox. - for _, f := range []*os.File{os.Stdout, os.Stderr} { - flags, err := unix.FcntlInt(f.Fd(), unix.F_GETFL, 0) - if err != nil { - fatalf("error getting file flags for %v: %v", f, err) - } - if flags&unix.O_APPEND == 0 { - flags |= unix.O_APPEND - if _, err := unix.FcntlInt(f.Fd(), unix.F_SETFL, flags); err != nil { - fatalf("error setting file flags for %v: %v", f, err) - } - } - } - - // Get all test cases in each binary. - testCases, err := gtest.ParseTestCases(testBin, true) - if err != nil { - fatalf("ParseTestCases(%q) failed: %v", testBin, err) - } - - // Get subset of tests corresponding to shard. - indices, err := testutil.TestIndicesForShard(len(testCases)) - if err != nil { - fatalf("TestsForShard() failed: %v", err) - } - - // Resolve the absolute path for the binary. - testBin, err = filepath.Abs(testBin) - if err != nil { - fatalf("Abs() failed: %v", err) - } - - // Run the tests. - var tests []testing.InternalTest - for _, tci := range indices { - // Capture tc. - tc := testCases[tci] - tests = append(tests, testing.InternalTest{ - Name: fmt.Sprintf("%s_%s", tc.Suite, tc.Name), - F: func(t *testing.T) { - if *platform == "native" { - // Run the test case on host. - runTestCaseNative(testBin, tc, t) - } else { - // Run the test case in runsc. - runTestCaseRunsc(testBin, tc, t) - } - }, - }) - } - - testing.Main(matchString, tests, nil, nil) -} diff --git a/test/runtimes/proctor/BUILD b/test/runtimes/proctor/BUILD index fdc6d3173..b4a9b12de 100644 --- a/test/runtimes/proctor/BUILD +++ b/test/runtimes/proctor/BUILD @@ -7,5 +7,8 @@ go_binary( srcs = ["main.go"], pure = True, visibility = ["//test/runtimes:__pkg__"], - deps = ["//test/runtimes/proctor/lib"], + deps = [ + "//test/runtimes/proctor/lib", + "@org_golang_x_sys//unix:go_default_library", + ], ) diff --git a/test/runtimes/proctor/lib/BUILD b/test/runtimes/proctor/lib/BUILD index 0c8367dfe..f834f1b5a 100644 --- a/test/runtimes/proctor/lib/BUILD +++ b/test/runtimes/proctor/lib/BUILD @@ -13,6 +13,7 @@ go_library( "python.go", ], visibility = ["//test/runtimes/proctor:__pkg__"], + deps = ["@org_golang_x_sys//unix:go_default_library"], ) go_test( diff --git a/test/runtimes/proctor/lib/lib.go b/test/runtimes/proctor/lib/lib.go index f2ba82498..36c60088a 100644 --- a/test/runtimes/proctor/lib/lib.go +++ b/test/runtimes/proctor/lib/lib.go @@ -22,7 +22,8 @@ import ( "os/signal" "path/filepath" "regexp" - "syscall" + + "golang.org/x/sys/unix" ) // TestRunner is an interface that must be implemented for each runtime @@ -59,7 +60,7 @@ func TestRunnerForRuntime(runtime string) (TestRunner, error) { func PauseAndReap() { // Get notified of any new children. ch := make(chan os.Signal, 1) - signal.Notify(ch, syscall.SIGCHLD) + signal.Notify(ch, unix.SIGCHLD) for { if _, ok := <-ch; !ok { @@ -69,7 +70,7 @@ func PauseAndReap() { // Reap the child. for { - if cpid, _ := syscall.Wait4(-1, nil, syscall.WNOHANG, nil); cpid < 1 { + if cpid, _ := unix.Wait4(-1, nil, unix.WNOHANG, nil); cpid < 1 { break } } diff --git a/test/runtimes/proctor/main.go b/test/runtimes/proctor/main.go index 81cb68381..8c076a499 100644 --- a/test/runtimes/proctor/main.go +++ b/test/runtimes/proctor/main.go @@ -22,8 +22,8 @@ import ( "log" "os" "strings" - "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/test/runtimes/proctor/lib" ) @@ -42,14 +42,14 @@ func setNumFilesLimit() error { // timeout if the NOFILE limit is too high. On gVisor, syscalls are // slower so these tests will need even more time to pass. const nofile = 32768 - rLimit := syscall.Rlimit{} - err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit) + rLimit := unix.Rlimit{} + err := unix.Getrlimit(unix.RLIMIT_NOFILE, &rLimit) if err != nil { return fmt.Errorf("failed to get RLIMIT_NOFILE: %v", err) } if rLimit.Cur > nofile { rLimit.Cur = nofile - err := syscall.Setrlimit(syscall.RLIMIT_NOFILE, &rLimit) + err := unix.Setrlimit(unix.RLIMIT_NOFILE, &rLimit) if err != nil { return fmt.Errorf("failed to set RLIMIT_NOFILE: %v", err) } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 9adb1cea3..ef299799e 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -65,14 +65,8 @@ syscall_test( syscall_test( size = "large", - # Produce too many logs in the debug mode. - debug = False, shard_count = most_shards, - # Takes too long for TSAN. Since this is kind of a stress test that doesn't - # involve much concurrency, TSAN's usefulness here is limited anyway. - tags = ["nogotsan"], test = "//test/syscalls/linux:socket_stress_test", - vfs2 = False, ) syscall_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 5371f825c..5399d8106 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2330,13 +2330,15 @@ cc_binary( ], linkstatic = 1, deps = [ + gtest, ":ip_socket_test_util", ":socket_test_util", - "@com_google_absl//absl/strings", - gtest, + "//test/util:file_descriptor", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", ], ) diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index e508ce27f..61a421788 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -2162,7 +2162,13 @@ class BlockingChild { return tid_; } - void Join() { Stop(); } + void Join() { + { + absl::MutexLock ml(&mu_); + stop_ = true; + } + thread_.Join(); + } private: void Start() { @@ -2172,11 +2178,6 @@ class BlockingChild { mu_.Await(absl::Condition(&stop_)); } - void Stop() { - absl::MutexLock ml(&mu_); - stop_ = true; - } - mutable absl::Mutex mu_; bool stop_ ABSL_GUARDED_BY(mu_) = false; pid_t tid_; @@ -2190,16 +2191,18 @@ class BlockingChild { TEST(ProcTask, NewThreadAppears) { auto initial = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/task", false)); BlockingChild child1; - EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task", - TaskFiles(initial, {child1.Tid()}))); + // Use Eventually* in case a proc from ealier test is still tearing down. + EXPECT_NO_ERRNO(EventuallyDirContainsExactly( + "/proc/self/task", TaskFiles(initial, {child1.Tid()}))); } TEST(ProcTask, KilledThreadsDisappear) { auto initial = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/proc/self/task/", false)); BlockingChild child1; - EXPECT_NO_ERRNO(DirContainsExactly("/proc/self/task", - TaskFiles(initial, {child1.Tid()}))); + // Use Eventually* in case a proc from ealier test is still tearing down. + EXPECT_NO_ERRNO(EventuallyDirContainsExactly( + "/proc/self/task", TaskFiles(initial, {child1.Tid()}))); // Stat child1's task file. Regression test for b/32097707. struct stat statbuf; diff --git a/test/syscalls/linux/proc_net.cc b/test/syscalls/linux/proc_net.cc index 73140b2e9..20f1dc305 100644 --- a/test/syscalls/linux/proc_net.cc +++ b/test/syscalls/linux/proc_net.cc @@ -40,6 +40,7 @@ namespace { constexpr const char kProcNet[] = "/proc/net"; constexpr const char kIpForward[] = "/proc/sys/net/ipv4/ip_forward"; +constexpr const char kRangeFile[] = "/proc/sys/net/ipv4/ip_local_port_range"; TEST(ProcNetSymlinkTarget, FileMode) { struct stat s; @@ -562,6 +563,42 @@ TEST(ProcSysNetIpv4IpForward, CanReadAndWrite) { EXPECT_EQ(buf, to_write); } +TEST(ProcSysNetPortRange, CanReadAndWrite) { + int min; + int max; + std::string rangefile = ASSERT_NO_ERRNO_AND_VALUE(GetContents(kRangeFile)); + ASSERT_EQ(rangefile.back(), '\n'); + rangefile.pop_back(); + std::vector<std::string> range = + absl::StrSplit(rangefile, absl::ByAnyChar("\t ")); + ASSERT_GT(range.size(), 1); + ASSERT_TRUE(absl::SimpleAtoi(range.front(), &min)); + ASSERT_TRUE(absl::SimpleAtoi(range.back(), &max)); + EXPECT_LE(min, max); + + // If the file isn't writable, there's nothing else to do here. + if (access(kRangeFile, W_OK)) { + return; + } + + constexpr int kSize = 77; + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(kRangeFile, O_WRONLY | O_TRUNC, 0)); + max = min + kSize; + const std::string small_range = absl::StrFormat("%d %d", min, max); + ASSERT_THAT(write(fd.get(), small_range.c_str(), small_range.size()), + SyscallSucceedsWithValue(small_range.size())); + + rangefile = ASSERT_NO_ERRNO_AND_VALUE(GetContents(kRangeFile)); + ASSERT_EQ(rangefile.back(), '\n'); + rangefile.pop_back(); + range = absl::StrSplit(rangefile, absl::ByAnyChar("\t ")); + ASSERT_GT(range.size(), 1); + ASSERT_TRUE(absl::SimpleAtoi(range.front(), &min)); + ASSERT_TRUE(absl::SimpleAtoi(range.back(), &max)); + EXPECT_EQ(min + kSize, max); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index 294b9f6fd..8d15c491e 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -1255,8 +1255,11 @@ TEST_F(PtyTest, PartialBadBuffer) { // Read from the replica into bad_buffer. ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), size)); - EXPECT_THAT(ReadFd(replica_.get(), bad_buffer, size), - SyscallFailsWithErrno(EFAULT)); + // Before Linux 3b830a9c this returned EFAULT, but after that commit it + // returns EAGAIN. + EXPECT_THAT( + ReadFd(replica_.get(), bad_buffer, size), + AnyOf(SyscallFailsWithErrno(EFAULT), SyscallFailsWithErrno(EAGAIN))); EXPECT_THAT(munmap(addr, 2 * kPageSize), SyscallSucceeds()) << addr; } diff --git a/test/syscalls/linux/socket_generic_stress.cc b/test/syscalls/linux/socket_generic_stress.cc index 679586530..c35aa2183 100644 --- a/test/syscalls/linux/socket_generic_stress.cc +++ b/test/syscalls/linux/socket_generic_stress.cc @@ -17,29 +17,72 @@ #include <sys/ioctl.h> #include <sys/socket.h> #include <sys/un.h> +#include <unistd.h> #include <array> #include <string> #include "gtest/gtest.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "test/syscalls/linux/ip_socket_test_util.h" #include "test/syscalls/linux/socket_test_util.h" +#include "test/util/file_descriptor.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" namespace gvisor { namespace testing { +constexpr char kRangeFile[] = "/proc/sys/net/ipv4/ip_local_port_range"; + +PosixErrorOr<int> NumPorts() { + int min = 0; + int max = 1 << 16; + + // Read the ephemeral range from /proc. + ASSIGN_OR_RETURN_ERRNO(std::string rangefile, GetContents(kRangeFile)); + const std::string err_msg = + absl::StrFormat("%s has invalid content: %s", kRangeFile, rangefile); + if (rangefile.back() != '\n') { + return PosixError(EINVAL, err_msg); + } + rangefile.pop_back(); + std::vector<std::string> range = + absl::StrSplit(rangefile, absl::ByAnyChar("\t ")); + if (range.size() < 2 || !absl::SimpleAtoi(range.front(), &min) || + !absl::SimpleAtoi(range.back(), &max)) { + return PosixError(EINVAL, err_msg); + } + + // If we can open as writable, limit the range. + if (!access(kRangeFile, W_OK)) { + ASSIGN_OR_RETURN_ERRNO(FileDescriptor fd, + Open(kRangeFile, O_WRONLY | O_TRUNC, 0)); + max = min + 50; + const std::string small_range = absl::StrFormat("%d %d", min, max); + int n = write(fd.get(), small_range.c_str(), small_range.size()); + if (n < 0) { + return PosixError( + errno, + absl::StrFormat("write(%d [%s], \"%s\", %d)", fd.get(), kRangeFile, + small_range.c_str(), small_range.size())); + } + } + return max - min; +} + // Test fixture for tests that apply to pairs of connected sockets. using ConnectStressTest = SocketPairTest; -TEST_P(ConnectStressTest, Reset65kTimes) { - // TODO(b/165912341): These are too slow on KVM platform with nested virt. - SKIP_IF(GvisorPlatform() == Platform::kKVM); - - for (int i = 0; i < 1 << 16; ++i) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(ConnectStressTest, Reset) { + const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + for (int i = 0; i < nports * 2; i++) { + const std::unique_ptr<SocketPair> sockets = + ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); // Send some data to ensure that the connection gets reset and the port gets // released immediately. This avoids either end entering TIME-WAIT. @@ -57,6 +100,24 @@ TEST_P(ConnectStressTest, Reset65kTimes) { } } +// Tests that opening too many connections -- without closing them -- does lead +// to port exhaustion. +TEST_P(ConnectStressTest, TooManyOpen) { + const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + int err_num = 0; + std::vector<std::unique_ptr<SocketPair>> sockets = + std::vector<std::unique_ptr<SocketPair>>(nports); + for (int i = 0; i < nports * 2; i++) { + PosixErrorOr<std::unique_ptr<SocketPair>> socks = NewSocketPair(); + if (!socks.ok()) { + err_num = socks.error().errno_value(); + break; + } + sockets.push_back(std::move(socks).ValueOrDie()); + } + ASSERT_EQ(err_num, EADDRINUSE); +} + INSTANTIATE_TEST_SUITE_P( AllConnectedSockets, ConnectStressTest, ::testing::Values(IPv6UDPBidirectionalBindSocketPair(0), @@ -73,14 +134,40 @@ INSTANTIATE_TEST_SUITE_P( // Test fixture for tests that apply to pairs of connected sockets created with // a persistent listener (if applicable). -using PersistentListenerConnectStressTest = SocketPairTest; +class PersistentListenerConnectStressTest : public SocketPairTest { + protected: + PersistentListenerConnectStressTest() : slept_{false} {} -TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) { - // TODO(b/165912341): These are too slow on KVM platform with nested virt. - SKIP_IF(GvisorPlatform() == Platform::kKVM); + // NewSocketSleep is the same as NewSocketPair, but will sleep once (over the + // lifetime of the fixture) and retry if creation fails due to EADDRNOTAVAIL. + PosixErrorOr<std::unique_ptr<SocketPair>> NewSocketSleep() { + // We can't reuse a connection too close in time to its last use, as TCP + // uses the timestamp difference to disambiguate connections. With a + // sufficiently small port range, we'll cycle through too quickly, and TCP + // won't allow for connection reuse. Thus, we sleep the first time + // encountering EADDRINUSE to allow for that difference (1 second in + // gVisor). + PosixErrorOr<std::unique_ptr<SocketPair>> socks = NewSocketPair(); + if (socks.ok()) { + return socks; + } + if (!slept_ && socks.error().errno_value() == EADDRNOTAVAIL) { + absl::SleepFor(absl::Milliseconds(1500)); + slept_ = true; + return NewSocketPair(); + } + return socks; + } - for (int i = 0; i < 1 << 16; ++i) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + private: + bool slept_; +}; + +TEST_P(PersistentListenerConnectStressTest, ShutdownCloseFirst) { + const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + for (int i = 0; i < nports * 2; i++) { + std::unique_ptr<SocketPair> sockets = + ASSERT_NO_ERRNO_AND_VALUE(NewSocketSleep()); ASSERT_THAT(shutdown(sockets->first_fd(), SHUT_RDWR), SyscallSucceeds()); if (GetParam().type == SOCK_STREAM) { // Poll the other FD to make sure that we see the FIN from the other @@ -97,12 +184,11 @@ TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseFirst) { } } -TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) { - // TODO(b/165912341): These are too slow on KVM platform with nested virt. - SKIP_IF(GvisorPlatform() == Platform::kKVM); - - for (int i = 0; i < 1 << 16; ++i) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(PersistentListenerConnectStressTest, ShutdownCloseSecond) { + const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + for (int i = 0; i < nports * 2; i++) { + const std::unique_ptr<SocketPair> sockets = + ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); ASSERT_THAT(shutdown(sockets->second_fd(), SHUT_RDWR), SyscallSucceeds()); if (GetParam().type == SOCK_STREAM) { // Poll the other FD to make sure that we see the FIN from the other @@ -119,12 +205,11 @@ TEST_P(PersistentListenerConnectStressTest, 65kTimesShutdownCloseSecond) { } } -TEST_P(PersistentListenerConnectStressTest, 65kTimesClose) { - // TODO(b/165912341): These are too slow on KVM platform with nested virt. - SKIP_IF(GvisorPlatform() == Platform::kKVM); - - for (int i = 0; i < 1 << 16; ++i) { - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); +TEST_P(PersistentListenerConnectStressTest, Close) { + const int nports = ASSERT_NO_ERRNO_AND_VALUE(NumPorts()); + for (int i = 0; i < nports * 2; i++) { + std::unique_ptr<SocketPair> sockets = + ASSERT_NO_ERRNO_AND_VALUE(NewSocketSleep()); } } @@ -149,7 +234,8 @@ TEST_P(DataTransferStressTest, BigDataTransfer) { // TODO(b/165912341): These are too slow on KVM platform with nested virt. SKIP_IF(GvisorPlatform() == Platform::kKVM); - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + const std::unique_ptr<SocketPair> sockets = + ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); int client_fd = sockets->first_fd(); int server_fd = sockets->second_fd(); diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 344a5a22c..54b45b075 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -705,12 +705,6 @@ TEST_P(SocketInetLoopbackTest, TCPFinWait2Test_NoRandomSave) { ds.reset(); - if (!IsRunningOnGvisor()) { - ASSERT_THAT( - bind(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_bound_addr), - conn_addrlen), - SyscallSucceeds()); - } ASSERT_THAT(RetryEINTR(connect)(conn_fd2.get(), reinterpret_cast<sockaddr*>(&conn_addr), conn_addrlen), diff --git a/test/uds/BUILD b/test/uds/BUILD index 51e2c7ce8..a8f49b50c 100644 --- a/test/uds/BUILD +++ b/test/uds/BUILD @@ -12,5 +12,6 @@ go_library( deps = [ "//pkg/log", "//pkg/unet", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/test/uds/uds.go b/test/uds/uds.go index b714c61b0..02a4a7dee 100644 --- a/test/uds/uds.go +++ b/test/uds/uds.go @@ -21,8 +21,8 @@ import ( "io/ioutil" "os" "path/filepath" - "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/unet" ) @@ -31,16 +31,16 @@ import ( // // Only works for stream, seqpacket sockets. func createEchoSocket(path string, protocol int) (cleanup func(), err error) { - fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) + fd, err := unix.Socket(unix.AF_UNIX, protocol, 0) if err != nil { return nil, fmt.Errorf("error creating echo(%d) socket: %v", protocol, err) } - if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { + if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil { return nil, fmt.Errorf("error binding echo(%d) socket: %v", protocol, err) } - if err := syscall.Listen(fd, 0); err != nil { + if err := unix.Listen(fd, 0); err != nil { return nil, fmt.Errorf("error listening echo(%d) socket: %v", protocol, err) } @@ -97,17 +97,17 @@ func createEchoSocket(path string, protocol int) (cleanup func(), err error) { // // Only relevant for stream, seqpacket sockets. func createNonListeningSocket(path string, protocol int) (cleanup func(), err error) { - fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) + fd, err := unix.Socket(unix.AF_UNIX, protocol, 0) if err != nil { return nil, fmt.Errorf("error creating nonlistening(%d) socket: %v", protocol, err) } - if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { + if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil { return nil, fmt.Errorf("error binding nonlistening(%d) socket: %v", protocol, err) } cleanup = func() { - if err := syscall.Close(fd); err != nil { + if err := unix.Close(fd); err != nil { log.Warningf("Failed to close nonlistening(%d) socket: %v", protocol, err) } } @@ -119,12 +119,12 @@ func createNonListeningSocket(path string, protocol int) (cleanup func(), err er // // Only works for dgram sockets. func createNullSocket(path string, protocol int) (cleanup func(), err error) { - fd, err := syscall.Socket(syscall.AF_UNIX, protocol, 0) + fd, err := unix.Socket(unix.AF_UNIX, protocol, 0) if err != nil { return nil, fmt.Errorf("error creating null(%d) socket: %v", protocol, err) } - if err := syscall.Bind(fd, &syscall.SockaddrUnix{Name: path}); err != nil { + if err := unix.Bind(fd, &unix.SockaddrUnix{Name: path}); err != nil { return nil, fmt.Errorf("error binding null(%d) socket: %v", protocol, err) } @@ -174,7 +174,7 @@ func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) { sockets map[string]socketCreator }{ { - protocol: syscall.SOCK_STREAM, + protocol: unix.SOCK_STREAM, name: "stream", sockets: map[string]socketCreator{ "echo": createEchoSocket, @@ -182,7 +182,7 @@ func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) { }, }, { - protocol: syscall.SOCK_SEQPACKET, + protocol: unix.SOCK_SEQPACKET, name: "seqpacket", sockets: map[string]socketCreator{ "echo": createEchoSocket, @@ -190,7 +190,7 @@ func CreateSocketTree(baseDir string) (dir string, cleanup func(), err error) { }, }, { - protocol: syscall.SOCK_DGRAM, + protocol: unix.SOCK_DGRAM, name: "dgram", sockets: map[string]socketCreator{ "null": createNullSocket, diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index abd6f69ea..634abd1af 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -427,7 +427,7 @@ func (g *Generator) generateOne(t *marshallableType, fset *token.FileSet) *inter // implementations type t. func (g *Generator) generateOneTestSuite(t *marshallableType) *testGenerator { i := newTestGenerator(t.spec, t.recv) - i.emitTests(t.slice, t.dynamic) + i.emitTests(t.slice) return i } @@ -488,7 +488,11 @@ func (g *Generator) Run() error { panic(fmt.Sprintf("Generated code for '%s' referenced a non-existent import with local name '%s'. Either go-marshal needs to add an import to the generated file, or a package in an input source file has a package name differ from the final component of its path, which go-marshal doesn't know how to detect; use an import alias to work around this limitation.", impl.typeName(), name)) } } - ts = append(ts, g.generateOneTestSuite(t)) + // Do not generate tests for dynamic types because they inherently + // violate some go_marshal requirements. + if !t.dynamic { + ts = append(ts, g.generateOneTestSuite(t)) + } } } diff --git a/tools/go_marshal/gomarshal/generator_tests.go b/tools/go_marshal/gomarshal/generator_tests.go index ca3e15c16..6cf00843f 100644 --- a/tools/go_marshal/gomarshal/generator_tests.go +++ b/tools/go_marshal/gomarshal/generator_tests.go @@ -216,16 +216,12 @@ func (g *testGenerator) emitTestSizeBytesOnTypedNilPtr() { }) } -func (g *testGenerator) emitTests(slice *sliceAPI, isDynamic bool) { +func (g *testGenerator) emitTests(slice *sliceAPI) { g.emitTestNonZeroSize() g.emitTestSuspectAlignment() - if !isDynamic { - // Do not test these for dynamic structs because they violate some - // assumptions that these tests make. - g.emitTestMarshalUnmarshalPreservesData() - g.emitTestWriteToUnmarshalPreservesData() - g.emitTestSizeBytesOnTypedNilPtr() - } + g.emitTestMarshalUnmarshalPreservesData() + g.emitTestWriteToUnmarshalPreservesData() + g.emitTestSizeBytesOnTypedNilPtr() if slice != nil { g.emitTestMarshalUnmarshalSlicePreservesData(slice) diff --git a/tools/make_apt.sh b/tools/make_apt.sh index 68f6973ec..935c4db2d 100755 --- a/tools/make_apt.sh +++ b/tools/make_apt.sh @@ -107,7 +107,9 @@ for pkg in "$@"; do cp -a -L "$(dirname "${pkg}")/${name}.deb" "${destdir}" cp -a -L "$(dirname "${pkg}")/${name}.changes" "${destdir}" chmod 0644 "${destdir}"/"${name}".* + # Sign a package only if it isn't signed yet. # We use [*] here to expand the gpg_opts array into a single shell-word. + dpkg-sig -g "${gpg_opts[*]}" --verify "${destdir}/${name}.deb" || dpkg-sig -g "${gpg_opts[*]}" --sign builder "${destdir}/${name}.deb" done diff --git a/tools/verity/BUILD b/tools/verity/BUILD new file mode 100644 index 000000000..77d16359c --- /dev/null +++ b/tools/verity/BUILD @@ -0,0 +1,15 @@ +load("//tools:defs.bzl", "go_binary") + +licenses(["notice"]) + +go_binary( + name = "measure_tool", + srcs = [ + "measure_tool.go", + "measure_tool_unsafe.go", + ], + pure = True, + deps = [ + "//pkg/abi/linux", + ], +) diff --git a/tools/verity/measure_tool.go b/tools/verity/measure_tool.go new file mode 100644 index 000000000..0d314ae70 --- /dev/null +++ b/tools/verity/measure_tool.go @@ -0,0 +1,87 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This binary can be used to run a measurement of the verity file system, +// generate the corresponding Merkle tree files, and return the root hash. +package main + +import ( + "flag" + "io/ioutil" + "log" + "os" + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" +) + +var path = flag.String("path", "", "path to the verity file system.") + +const maxDigestSize = 64 + +type digest struct { + metadata linux.DigestMetadata + digest [maxDigestSize]byte +} + +func main() { + flag.Parse() + if *path == "" { + log.Fatalf("no path provided") + } + if err := enableDir(*path); err != nil { + log.Fatalf("Failed to enable file system %s: %v", *path, err) + } + // Print the root hash of the file system to stdout. + if err := measure(*path); err != nil { + log.Fatalf("Failed to measure file system %s: %v", *path, err) + } +} + +// enableDir enables verity features on all the files and sub-directories within +// path. +func enableDir(path string) error { + files, err := ioutil.ReadDir(path) + if err != nil { + return err + } + for _, file := range files { + if file.IsDir() { + // For directories, first enable its children. + if err := enableDir(path + "/" + file.Name()); err != nil { + return err + } + } else if file.Mode().IsRegular() { + // For regular files, open and enable verity feature. + f, err := os.Open(path + "/" + file.Name()) + if err != nil { + return err + } + var p uintptr + if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f.Fd()), uintptr(linux.FS_IOC_ENABLE_VERITY), p); err != 0 { + return err + } + } + } + // Once all children are enabled, enable the parent directory. + f, err := os.Open(path) + if err != nil { + return err + } + var p uintptr + if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f.Fd()), uintptr(linux.FS_IOC_ENABLE_VERITY), p); err != 0 { + return err + } + return nil +} diff --git a/runsc/mitigate/mitigate_conf.go b/tools/verity/measure_tool_unsafe.go index ee326324b..d4079be9e 100644 --- a/runsc/mitigate/mitigate_conf.go +++ b/tools/verity/measure_tool_unsafe.go @@ -11,27 +11,29 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - -package mitigate +package main import ( - "gvisor.dev/gvisor/runsc/flag" -) - -type mitigate struct { -} + "encoding/hex" + "fmt" + "os" + "syscall" + "unsafe" -// usage returns the usage string portion for the mitigate. -func (m mitigate) usage() string { return "" } - -// setFlags sets additional flags for the Mitigate command. -func (m mitigate) setFlags(f *flag.FlagSet) {} - -// execute performs additional parts of Execute for Mitigate. -func (m mitigate) execute(set cpuSet, dryrun bool) error { - return nil -} + "gvisor.dev/gvisor/pkg/abi/linux" +) -func (m mitigate) vulnerable(other thread) bool { - return other.isVulnerable() +// measure prints the hash of path to stdout. +func measure(path string) error { + f, err := os.Open(path) + if err != nil { + return err + } + var digest digest + digest.metadata.DigestSize = maxDigestSize + if _, _, err := syscall.Syscall(syscall.SYS_IOCTL, uintptr(f.Fd()), uintptr(linux.FS_IOC_MEASURE_VERITY), uintptr(unsafe.Pointer(&digest))); err != 0 { + return err + } + fmt.Fprintf(os.Stdout, "%s\n", hex.EncodeToString(digest.digest[:digest.metadata.DigestSize])) + return err } |