Diffstat (limited to 'pkg/sentry')
115 files changed, 3700 insertions, 902 deletions
diff --git a/pkg/sentry/arch/fpu/fpu_amd64.go b/pkg/sentry/arch/fpu/fpu_amd64.go
index 1e9625bee..f0ba26736 100644
--- a/pkg/sentry/arch/fpu/fpu_amd64.go
+++ b/pkg/sentry/arch/fpu/fpu_amd64.go
@@ -219,6 +219,11 @@ func (s *State) PtraceSetXstateRegs(src io.Reader, maxlen int, featureSet *cpuid
 	return copy(*s, f), nil
 }
 
+// SetMXCSR sets the MXCSR control/status register in the state.
+func (s *State) SetMXCSR(mxcsr uint32) {
+	hostarch.ByteOrder.PutUint32((*s)[mxcsrOffset:], mxcsr)
+}
+
 // BytePointer returns a pointer to the first byte of the state.
 //
 //go:nosplit
diff --git a/pkg/sentry/devices/memdev/zero.go b/pkg/sentry/devices/memdev/zero.go
index 1929e41cd..49c53452a 100644
--- a/pkg/sentry/devices/memdev/zero.go
+++ b/pkg/sentry/devices/memdev/zero.go
@@ -93,6 +93,7 @@ func (fd *zeroFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) erro
 	// "/dev/zero (deleted)".
 	opts.Offset = 0
 	opts.MappingIdentity = &fd.vfsfd
+	opts.SentryOwnedContent = true
 	opts.MappingIdentity.IncRef()
 	return nil
 }
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 0b3d0617f..46a2dc47d 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -384,8 +384,16 @@ func (c *ConnectedEndpoint) CloseUnread() {}
 
 // SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
 func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
-	// gVisor does not permit setting of SO_SNDBUF for host backed unix domain
-	// sockets.
+	// gVisor does not permit setting of SO_SNDBUF for host backed unix
+	// domain sockets.
+	return atomic.LoadInt64(&c.sndbuf)
+}
+
+// SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize.
+func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
+	// gVisor does not permit setting of SO_RCVBUF for host backed unix
+	// domain sockets. The receive buffer has no effect for unix sockets,
+	// so we claim it is the same as the send buffer.
 	return atomic.LoadInt64(&c.sndbuf)
 }
diff --git a/pkg/sentry/fsimpl/cgroupfs/BUILD b/pkg/sentry/fsimpl/cgroupfs/BUILD
new file mode 100644
index 000000000..37efb641a
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/BUILD
@@ -0,0 +1,48 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+licenses(["notice"])
+
+go_template_instance(
+    name = "dir_refs",
+    out = "dir_refs.go",
+    package = "cgroupfs",
+    prefix = "dir",
+    template = "//pkg/refsvfs2:refs_template",
+    types = {
+        "T": "dir",
+    },
+)
+
+go_library(
+    name = "cgroupfs",
+    srcs = [
+        "base.go",
+        "cgroupfs.go",
+        "cpu.go",
+        "cpuacct.go",
+        "cpuset.go",
+        "dir_refs.go",
+        "job.go",
+        "memory.go",
+    ],
+    visibility = ["//pkg/sentry:internal"],
+    deps = [
+        "//pkg/abi/linux",
+        "//pkg/context",
+        "//pkg/coverage",
+        "//pkg/log",
+        "//pkg/refs",
+        "//pkg/refsvfs2",
+        "//pkg/sentry/arch",
+        "//pkg/sentry/fsimpl/kernfs",
+        "//pkg/sentry/kernel",
+        "//pkg/sentry/kernel/auth",
+        "//pkg/sentry/memmap",
+        "//pkg/sentry/usage",
+        "//pkg/sentry/vfs",
+        "//pkg/sync",
+        "//pkg/syserror",
+        "//pkg/usermem",
+    ],
+)
diff --git a/pkg/sentry/fsimpl/cgroupfs/base.go b/pkg/sentry/fsimpl/cgroupfs/base.go
new file mode 100644
index 000000000..0f54888d8
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/base.go
@@ -0,0 +1,261 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"bytes"
+	"fmt"
+	"sort"
+	"strconv"
+	"sync/atomic"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/syserror"
+	"gvisor.dev/gvisor/pkg/usermem"
+)
+
+// controllerCommon implements kernel.CgroupController.
+//
+// Must call init before use.
+//
+// +stateify savable
+type controllerCommon struct {
+	ty kernel.CgroupControllerType
+	fs *filesystem
+}
+
+func (c *controllerCommon) init(ty kernel.CgroupControllerType, fs *filesystem) {
+	c.ty = ty
+	c.fs = fs
+}
+
+// Type implements kernel.CgroupController.Type.
+func (c *controllerCommon) Type() kernel.CgroupControllerType {
+	return kernel.CgroupControllerType(c.ty)
+}
+
+// HierarchyID implements kernel.CgroupController.HierarchyID.
+func (c *controllerCommon) HierarchyID() uint32 {
+	return c.fs.hierarchyID
+}
+
+// NumCgroups implements kernel.CgroupController.NumCgroups.
+func (c *controllerCommon) NumCgroups() uint64 {
+	return atomic.LoadUint64(&c.fs.numCgroups)
+}
+
+// Enabled implements kernel.CgroupController.Enabled.
+//
+// Controllers are currently always enabled.
+func (c *controllerCommon) Enabled() bool {
+	return true
+}
+
+// Filesystem implements kernel.CgroupController.Filesystem.
+func (c *controllerCommon) Filesystem() *vfs.Filesystem {
+	return c.fs.VFSFilesystem()
+}
+
+// RootCgroup implements kernel.CgroupController.RootCgroup.
+func (c *controllerCommon) RootCgroup() kernel.Cgroup {
+	return c.fs.rootCgroup()
+}
+
+// controller is an interface for common functionality related to all cgroups.
+// It is an extension of the public cgroup interface, containing cgroup
+// functionality private to cgroupfs.
+type controller interface {
+	kernel.CgroupController
+
+	// AddControlFiles should extend the contents map with inodes representing
+	// control files defined by this controller.
+	AddControlFiles(ctx context.Context, creds *auth.Credentials, c *cgroupInode, contents map[string]kernfs.Inode)
+}
+
+// cgroupInode implements kernel.CgroupImpl and kernfs.Inode.
+//
+// +stateify savable
+type cgroupInode struct {
+	dir
+	fs *filesystem
+
+	// ts is the list of tasks in this cgroup. The kernel is responsible for
+	// removing tasks from this list before they're destroyed, so any tasks on
+	// this list are always valid.
+	//
+	// ts, and cgroup membership in general is protected by fs.tasksMu.
+	ts map[*kernel.Task]struct{}
+}
+
+var _ kernel.CgroupImpl = (*cgroupInode)(nil)
+
+func (fs *filesystem) newCgroupInode(ctx context.Context, creds *auth.Credentials) kernfs.Inode {
+	c := &cgroupInode{
+		fs: fs,
+		ts: make(map[*kernel.Task]struct{}),
+	}
+
+	contents := make(map[string]kernfs.Inode)
+	contents["cgroup.procs"] = fs.newControllerFile(ctx, creds, &cgroupProcsData{c})
+	contents["tasks"] = fs.newControllerFile(ctx, creds, &tasksData{c})
+
+	for _, ctl := range fs.controllers {
+		ctl.AddControlFiles(ctx, creds, c, contents)
+	}
+
+	c.dir.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|linux.FileMode(0555))
+	c.dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{})
+	c.dir.InitRefs()
+	c.dir.IncLinks(c.dir.OrderedChildren.Populate(contents))
+
+	atomic.AddUint64(&fs.numCgroups, 1)
+
+	return c
+}
+
+func (c *cgroupInode) HierarchyID() uint32 {
+	return c.fs.hierarchyID
+}
+
+// Controllers implements kernel.CgroupImpl.Controllers.
+func (c *cgroupInode) Controllers() []kernel.CgroupController {
+	return c.fs.kcontrollers
+}
+
+// Enter implements kernel.CgroupImpl.Enter.
+func (c *cgroupInode) Enter(t *kernel.Task) {
+	c.fs.tasksMu.Lock()
+	c.ts[t] = struct{}{}
+	c.fs.tasksMu.Unlock()
+}
+
+// Leave implements kernel.CgroupImpl.Leave.
+func (c *cgroupInode) Leave(t *kernel.Task) {
+	c.fs.tasksMu.Lock()
+	delete(c.ts, t)
+	c.fs.tasksMu.Unlock()
+}
+
+func sortTIDs(tids []kernel.ThreadID) {
+	sort.Slice(tids, func(i, j int) bool { return tids[i] < tids[j] })
+}
+
+// +stateify savable
+type cgroupProcsData struct {
+	*cgroupInode
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cgroupProcsData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	t := kernel.TaskFromContext(ctx)
+	currPidns := t.ThreadGroup().PIDNamespace()
+
+	pgids := make(map[kernel.ThreadID]struct{})
+
+	d.fs.tasksMu.RLock()
+	defer d.fs.tasksMu.RUnlock()
+
+	for task := range d.ts {
+		// Map dedups pgid, since iterating over all tasks produces multiple
+		// entries for the group leaders.
+		if pgid := currPidns.IDOfThreadGroup(task.ThreadGroup()); pgid != 0 {
+			pgids[pgid] = struct{}{}
+		}
+	}
+
+	pgidList := make([]kernel.ThreadID, 0, len(pgids))
+	for pgid := range pgids {
+		pgidList = append(pgidList, pgid)
+	}
+	sortTIDs(pgidList)
+
+	for _, pgid := range pgidList {
+		fmt.Fprintf(buf, "%d\n", pgid)
+	}
+
+	return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *cgroupProcsData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+	// TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
+	return src.NumBytes(), nil
+}
+
+// +stateify savable
+type tasksData struct {
+	*cgroupInode
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *tasksData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	t := kernel.TaskFromContext(ctx)
+	currPidns := t.ThreadGroup().PIDNamespace()
+
+	var pids []kernel.ThreadID
+
+	d.fs.tasksMu.RLock()
+	defer d.fs.tasksMu.RUnlock()
+
+	for task := range d.ts {
+		if pid := currPidns.IDOfTask(task); pid != 0 {
+			pids = append(pids, pid)
+		}
+	}
+	sortTIDs(pids)
+
+	for _, pid := range pids {
+		fmt.Fprintf(buf, "%d\n", pid)
+	}
+
+	return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *tasksData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+	// TODO(b/183137098): Payload is the pid for a process to add to this cgroup.
+	return src.NumBytes(), nil
+}
+
+// parseInt64FromString interprets src as a string encoding an int64 value,
+// and returns the parsed value.
+func parseInt64FromString(ctx context.Context, src usermem.IOSequence, offset int64) (val, len int64, err error) {
+	const maxInt64StrLen = 20 // i.e. len(fmt.Sprintf("%d", math.MinInt64)) == 20
+
+	t := kernel.TaskFromContext(ctx)
+	src = src.DropFirst64(offset)
+
+	buf := t.CopyScratchBuffer(maxInt64StrLen)
+	n, err := src.CopyIn(ctx, buf)
+	if err != nil {
+		return 0, int64(n), err
+	}
+	buf = buf[:n]
+
+	val, err = strconv.ParseInt(string(buf), 10, 64)
+	if err != nil {
+		// Note: This also handles zero-len writes if offset is beyond the end
+		// of src, or src is empty.
+		ctx.Warningf("cgroupfs.parseInt64FromString: failed to parse %q: %v", string(buf), err)
+		return 0, int64(n), syserror.EINVAL
+	}
+
+	return val, int64(n), nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
new file mode 100644
index 000000000..bd3e69757
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -0,0 +1,425 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package cgroupfs implements cgroupfs.
+//
+// A cgroup is a collection of tasks on the system, organized into a tree-like
+// structure similar to a filesystem directory tree. In fact, each cgroup is
+// represented by a directory on cgroupfs, and is manipulated through control
+// files in the directory.
+//
+// All cgroups on a system are organized into hierarchies. A hierarchy is a
+// distinct tree of cgroups, with a common set of controllers. One or more
+// cgroupfs mounts may point to each hierarchy. These mounts provide a common
+// view into the same tree of cgroups.
+//
+// A controller (also known as a "resource controller", or a cgroup "subsystem")
+// determines the behaviour of each cgroup.
+//
+// In addition to cgroupfs, the kernel has a cgroup registry that tracks
+// system-wide state related to cgroups such as active hierarchies and the
+// controllers associated with them.
+//
+// Since cgroupfs doesn't allow hardlinks, there is a unique mapping between
+// cgroupfs dentries and inodes.
+//
+// # Synchronization
+//
+// Cgroup hierarchy creation and destruction is protected by the
+// kernel.CgroupRegistry.mu. Once created, a hierarchy's set of controllers,
+// the filesystem associated with it, and the root cgroup for the hierarchy
+// are immutable.
+//
+// Membership of tasks within cgroups is protected by
+// cgroupfs.filesystem.tasksMu. Tasks also maintain a set of all cgroups
+// they're in, and this list is protected by Task.mu.
+//
+// Lock order:
+//
+// kernel.CgroupRegistry.mu
+//   cgroupfs.filesystem.mu
+//     Task.mu
+//       cgroupfs.filesystem.tasksMu.
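[Editorial aside, not part of the commit] The controller interface in base.go above, together with the file helpers defined in cgroupfs.go below, is the entire per-controller extension surface. As a minimal sketch of how another controller could be added, mirroring the cpu/cpuset controllers later in this diff; the "pids" name, the pidsController type, and its default value are hypothetical, not part of this change:

    // pids.go (hypothetical): a minimal read-only controller sketch.
    // +stateify savable
    type pidsController struct {
    	controllerCommon

    	// max is an illustrative, unenforced limit exposed via pids.max.
    	max int64
    }

    var _ controller = (*pidsController)(nil)

    func newPidsController(fs *filesystem) *pidsController {
    	c := &pidsController{max: math.MaxInt64} // Assumed default; requires importing "math".
    	c.controllerCommon.init(kernel.CgroupControllerType("pids"), fs)
    	return c
    }

    // AddControlFiles implements controller.AddControlFiles.
    func (c *pidsController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
    	contents["pids.max"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.max))
    }

A real controller would also need a case in GetFilesystem's controller switch and an entry in SupportedMountOptions (see cgroupfs.go below).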
+package cgroupfs
+
+import (
+	"fmt"
+	"sort"
+	"strconv"
+	"strings"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/vfs"
+	"gvisor.dev/gvisor/pkg/sync"
+	"gvisor.dev/gvisor/pkg/syserror"
+)
+
+const (
+	// Name is the default filesystem name.
+	Name                     = "cgroup"
+	readonlyFileMode         = linux.FileMode(0444)
+	writableFileMode         = linux.FileMode(0644)
+	defaultMaxCachedDentries = uint64(1000)
+)
+
+const (
+	controllerCPU     = kernel.CgroupControllerType("cpu")
+	controllerCPUAcct = kernel.CgroupControllerType("cpuacct")
+	controllerCPUSet  = kernel.CgroupControllerType("cpuset")
+	controllerJob     = kernel.CgroupControllerType("job")
+	controllerMemory  = kernel.CgroupControllerType("memory")
+)
+
+var allControllers = []kernel.CgroupControllerType{
+	controllerCPU,
+	controllerCPUAcct,
+	controllerCPUSet,
+	controllerJob,
+	controllerMemory,
+}
+
+// SupportedMountOptions is the set of supported mount options for cgroupfs.
+var SupportedMountOptions = []string{"all", "cpu", "cpuacct", "cpuset", "job", "memory"}
+
+// FilesystemType implements vfs.FilesystemType.
+//
+// +stateify savable
+type FilesystemType struct{}
+
+// InternalData contains internal data passed in to the cgroupfs mount via
+// vfs.GetFilesystemOptions.InternalData.
+//
+// +stateify savable
+type InternalData struct {
+	DefaultControlValues map[string]int64
+}
+
+// filesystem implements vfs.FilesystemImpl.
+//
+// +stateify savable
+type filesystem struct {
+	kernfs.Filesystem
+	devMinor uint32
+
+	// hierarchyID is the id the cgroup registry assigns to this hierarchy. Has
+	// the value kernel.InvalidCgroupHierarchyID until the FS is fully
+	// initialized.
+	//
+	// hierarchyID is immutable after initialization.
+	hierarchyID uint32
+
+	// controllers and kcontrollers are both the list of controllers attached to
+	// this cgroupfs. Both lists are the same set of controllers, but typecast
+	// to different interfaces for convenience. Both must stay in sync, and are
+	// immutable.
+	controllers  []controller
+	kcontrollers []kernel.CgroupController
+
+	numCgroups uint64 // Protected by atomic ops.
+
+	root *kernfs.Dentry
+
+	// tasksMu serializes task membership changes across all cgroups within a
+	// filesystem.
+	tasksMu sync.RWMutex `state:"nosave"`
+}
+
+// Name implements vfs.FilesystemType.Name.
+func (FilesystemType) Name() string {
+	return Name
+}
+
+// Release implements vfs.FilesystemType.Release.
+func (FilesystemType) Release(ctx context.Context) {}
+
+// GetFilesystem implements vfs.FilesystemType.GetFilesystem.
+func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) {
+	devMinor, err := vfsObj.GetAnonBlockDevMinor()
+	if err != nil {
+		return nil, nil, err
+	}
+
+	mopts := vfs.GenericParseMountOptions(opts.Data)
+	maxCachedDentries := defaultMaxCachedDentries
+	if str, ok := mopts["dentry_cache_limit"]; ok {
+		delete(mopts, "dentry_cache_limit")
+		maxCachedDentries, err = strconv.ParseUint(str, 10, 64)
+		if err != nil {
+			ctx.Warningf("cgroupfs.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str)
+			return nil, nil, syserror.EINVAL
+		}
+	}
+
+	var wantControllers []kernel.CgroupControllerType
+	if _, ok := mopts["cpu"]; ok {
+		delete(mopts, "cpu")
+		wantControllers = append(wantControllers, controllerCPU)
+	}
+	if _, ok := mopts["cpuacct"]; ok {
+		delete(mopts, "cpuacct")
+		wantControllers = append(wantControllers, controllerCPUAcct)
+	}
+	if _, ok := mopts["cpuset"]; ok {
+		delete(mopts, "cpuset")
+		wantControllers = append(wantControllers, controllerCPUSet)
+	}
+	if _, ok := mopts["job"]; ok {
+		delete(mopts, "job")
+		wantControllers = append(wantControllers, controllerJob)
+	}
+	if _, ok := mopts["memory"]; ok {
+		delete(mopts, "memory")
+		wantControllers = append(wantControllers, controllerMemory)
+	}
+	if _, ok := mopts["all"]; ok {
+		if len(wantControllers) > 0 {
+			ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: other controllers specified with all: %v", wantControllers)
+			return nil, nil, syserror.EINVAL
+		}
+
+		delete(mopts, "all")
+		wantControllers = allControllers
+	}
+
+	if len(wantControllers) == 0 {
+		// Specifying no controllers implies all controllers.
+		wantControllers = allControllers
+	}
+
+	if len(mopts) != 0 {
+		ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: unknown options: %v", mopts)
+		return nil, nil, syserror.EINVAL
+	}
+
+	k := kernel.KernelFromContext(ctx)
+	r := k.CgroupRegistry()
+
+	// "It is not possible to mount the same controller against multiple
+	// cgroup hierarchies. For example, it is not possible to mount both
+	// the cpu and cpuacct controllers against one hierarchy, and to mount
+	// the cpu controller alone against another hierarchy." - man cgroups(7)
+	//
+	// Is there a hierarchy available with all the controllers we want? If so,
+	// this mount is a view into the same hierarchy.
+	//
+	// Note: we're guaranteed to have at least one requested controller, since
+	// no explicit controller name implies all controllers.
+	if vfsfs := r.FindHierarchy(wantControllers); vfsfs != nil {
+		fs := vfsfs.Impl().(*filesystem)
+		ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: mounting new view to hierarchy %v", fs.hierarchyID)
+		fs.root.IncRef()
+		return vfsfs, fs.root.VFSDentry(), nil
+	}
+
+	// No existing hierarchy with exactly the requested controllers was found.
+	// Make a new one. Note that it's possible this mount creation is
+	// unsatisfiable, if one or more of the requested controllers are already
+	// on existing hierarchies. We'll find out about such collisions when we
+	// try to register the new hierarchy later.
+	fs := &filesystem{
+		devMinor: devMinor,
+	}
+	fs.MaxCachedDentries = maxCachedDentries
+	fs.VFSFilesystem().Init(vfsObj, &fsType, fs)
+
+	var defaults map[string]int64
+	if opts.InternalData != nil {
+		defaults = opts.InternalData.(*InternalData).DefaultControlValues
+		ctx.Debugf("cgroupfs.FilesystemType.GetFilesystem: default control values: %v", defaults)
+	}
+
+	for _, ty := range wantControllers {
+		var c controller
+		switch ty {
+		case controllerCPU:
+			c = newCPUController(fs, defaults)
+		case controllerCPUAcct:
+			c = newCPUAcctController(fs)
+		case controllerCPUSet:
+			c = newCPUSetController(fs)
+		case controllerJob:
+			c = newJobController(fs)
+		case controllerMemory:
+			c = newMemoryController(fs, defaults)
+		default:
+			panic(fmt.Sprintf("Unreachable: unknown cgroup controller %q", ty))
+		}
+		fs.controllers = append(fs.controllers, c)
+	}
+
+	if len(defaults) != 0 {
+		// Internal data is always provided at sentry startup and unused values
+		// indicate a problem with the sandbox config. Fail fast.
+		panic(fmt.Sprintf("cgroupfs.FilesystemType.GetFilesystem: unknown internal mount data: %v", defaults))
+	}
+
+	// Controllers usually appear in alphabetical order when displayed. Sort it
+	// here now, so it never needs to be sorted elsewhere.
+	sort.Slice(fs.controllers, func(i, j int) bool { return fs.controllers[i].Type() < fs.controllers[j].Type() })
+	fs.kcontrollers = make([]kernel.CgroupController, 0, len(fs.controllers))
+	for _, c := range fs.controllers {
+		fs.kcontrollers = append(fs.kcontrollers, c)
+	}
+
+	root := fs.newCgroupInode(ctx, creds)
+	var rootD kernfs.Dentry
+	rootD.InitRoot(&fs.Filesystem, root)
+	fs.root = &rootD
+
+	// Register controllers. The registry may be modified concurrently, so if we
+	// get an error, we raced with someone else who registered the same
+	// controllers first.
+	hid, err := r.Register(fs.kcontrollers)
+	if err != nil {
+		ctx.Infof("cgroupfs.FilesystemType.GetFilesystem: failed to register new hierarchy with controllers %v: %v", wantControllers, err)
+		rootD.DecRef(ctx)
+		fs.VFSFilesystem().DecRef(ctx)
+		return nil, nil, syserror.EBUSY
+	}
+	fs.hierarchyID = hid
+
+	// Move all existing tasks to the root of the new hierarchy.
+	k.PopulateNewCgroupHierarchy(fs.rootCgroup())
+
+	return fs.VFSFilesystem(), rootD.VFSDentry(), nil
+}
+
+func (fs *filesystem) rootCgroup() kernel.Cgroup {
+	return kernel.Cgroup{
+		Dentry:     fs.root,
+		CgroupImpl: fs.root.Inode().(kernel.CgroupImpl),
+	}
+}
+
+// Release implements vfs.FilesystemImpl.Release.
+func (fs *filesystem) Release(ctx context.Context) {
+	k := kernel.KernelFromContext(ctx)
+	r := k.CgroupRegistry()
+
+	if fs.hierarchyID != kernel.InvalidCgroupHierarchyID {
+		k.ReleaseCgroupHierarchy(fs.hierarchyID)
+		r.Unregister(fs.hierarchyID)
+	}
+
+	fs.Filesystem.VFSFilesystem().VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor)
+	fs.Filesystem.Release(ctx)
+}
+
+// MountOptions implements vfs.FilesystemImpl.MountOptions.
+func (fs *filesystem) MountOptions() string {
+	var cnames []string
+	for _, c := range fs.controllers {
+		cnames = append(cnames, string(c.Type()))
+	}
+	return strings.Join(cnames, ",")
+}
+
+// +stateify savable
+type implStatFS struct{}
+
+// StatFS implements kernfs.Inode.StatFS.
+func (*implStatFS) StatFS(context.Context, *vfs.Filesystem) (linux.Statfs, error) {
+	return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
+}
+
+// dir implements kernfs.Inode for a generic cgroup resource controller
+// directory. Specific controllers extend this to add their own functionality.
+//
+// +stateify savable
+type dir struct {
+	dirRefs
+	kernfs.InodeAlwaysValid
+	kernfs.InodeAttrs
+	kernfs.InodeNotSymlink
+	kernfs.InodeDirectoryNoNewChildren // TODO(b/183137098): Implement mkdir.
+	kernfs.OrderedChildren
+	implStatFS
+
+	locks vfs.FileLocks
+}
+
+// Keep implements kernfs.Inode.Keep.
+func (*dir) Keep() bool {
+	return true
+}
+
+// SetStat implements kernfs.Inode.SetStat not allowing inode attributes to be changed.
+func (*dir) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error {
+	return syserror.EPERM
+}
+
+// Open implements kernfs.Inode.Open.
+func (d *dir) Open(ctx context.Context, rp *vfs.ResolvingPath, kd *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+	fd, err := kernfs.NewGenericDirectoryFD(rp.Mount(), kd, &d.OrderedChildren, &d.locks, &opts, kernfs.GenericDirectoryFDOptions{
+		SeekEnd: kernfs.SeekEndStaticEntries,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return fd.VFSFileDescription(), nil
+}
+
+// DecRef implements kernfs.Inode.DecRef.
+func (d *dir) DecRef(ctx context.Context) {
+	d.dirRefs.DecRef(func() { d.Destroy(ctx) })
+}
+
+// StatFS implements kernfs.Inode.StatFS.
+func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) {
+	return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
+}
+
+// controllerFile represents a generic control file that appears within a
+// cgroup directory.
+//
+// +stateify savable
+type controllerFile struct {
+	kernfs.DynamicBytesFile
+}
+
+func (fs *filesystem) newControllerFile(ctx context.Context, creds *auth.Credentials, data vfs.DynamicBytesSource) kernfs.Inode {
+	f := &controllerFile{}
+	f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, readonlyFileMode)
+	return f
+}
+
+func (fs *filesystem) newControllerWritableFile(ctx context.Context, creds *auth.Credentials, data vfs.WritableDynamicBytesSource) kernfs.Inode {
+	f := &controllerFile{}
+	f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), data, writableFileMode)
+	return f
+}
+
+// staticControllerFile represents a generic control file that appears within
+// a cgroup directory which always returns the same data when read.
+// staticControllerFiles are not writable.
+//
+// +stateify savable
+type staticControllerFile struct {
+	kernfs.DynamicBytesFile
+	vfs.StaticData
+}
+
+// Note: We let the caller provide the mode so that static files may be used to
+// fake both readable and writable control files. However, static files are
+// effectively readonly, as attempting to write to them will return EIO
+// regardless of the mode.
+func (fs *filesystem) newStaticControllerFile(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, data string) kernfs.Inode {
+	f := &staticControllerFile{StaticData: vfs.StaticData{Data: data}}
+	f.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), f, mode)
+	return f
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpu.go b/pkg/sentry/fsimpl/cgroupfs/cpu.go
new file mode 100644
index 000000000..24d86a277
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpu.go
@@ -0,0 +1,70 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// +stateify savable
+type cpuController struct {
+	controllerCommon
+
+	// CFS bandwidth control parameters, values in microseconds.
+	cfsPeriod int64
+	cfsQuota  int64
+
+	// CPU shares, values should be (number of cores * 1024).
+	shares int64
+}
+
+var _ controller = (*cpuController)(nil)
+
+func newCPUController(fs *filesystem, defaults map[string]int64) *cpuController {
+	// Default values for controller parameters from Linux.
+	c := &cpuController{
+		cfsPeriod: 100000,
+		cfsQuota:  -1,
+		shares:    1024,
+	}
+
+	if val, ok := defaults["cpu.cfs_period_us"]; ok {
+		c.cfsPeriod = val
+		delete(defaults, "cpu.cfs_period_us")
+	}
+	if val, ok := defaults["cpu.cfs_quota_us"]; ok {
+		c.cfsQuota = val
+		delete(defaults, "cpu.cfs_quota_us")
+	}
+	if val, ok := defaults["cpu.shares"]; ok {
+		c.shares = val
+		delete(defaults, "cpu.shares")
+	}
+
+	c.controllerCommon.init(controllerCPU, fs)
+	return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpuController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+	contents["cpu.cfs_period_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsPeriod))
+	contents["cpu.cfs_quota_us"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.cfsQuota))
+	contents["cpu.shares"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.shares))
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuacct.go b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
new file mode 100644
index 000000000..d4104a00e
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuacct.go
@@ -0,0 +1,114 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"bytes"
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// +stateify savable
+type cpuacctController struct {
+	controllerCommon
+}
+
+var _ controller = (*cpuacctController)(nil)
+
+func newCPUAcctController(fs *filesystem) *cpuacctController {
+	c := &cpuacctController{}
+	c.controllerCommon.init(controllerCPUAcct, fs)
+	return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpuacctController) AddControlFiles(ctx context.Context, creds *auth.Credentials, cg *cgroupInode, contents map[string]kernfs.Inode) {
+	cpuacctCG := &cpuacctCgroup{cg}
+	contents["cpuacct.stat"] = c.fs.newControllerFile(ctx, creds, &cpuacctStatData{cpuacctCG})
+	contents["cpuacct.usage"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageData{cpuacctCG})
+	contents["cpuacct.usage_user"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageUserData{cpuacctCG})
+	contents["cpuacct.usage_sys"] = c.fs.newControllerFile(ctx, creds, &cpuacctUsageSysData{cpuacctCG})
+}
+
+// +stateify savable
+type cpuacctCgroup struct {
+	*cgroupInode
+}
+
+func (c *cpuacctCgroup) collectCPUStats() usage.CPUStats {
+	var cs usage.CPUStats
+	c.fs.tasksMu.RLock()
+	// Note: This isn't very accurate, since the tasks are potentially
+	// still running as we accumulate their stats.
+	for t := range c.ts {
+		cs.Accumulate(t.CPUStats())
+	}
+	c.fs.tasksMu.RUnlock()
+	return cs
+}
+
+// +stateify savable
+type cpuacctStatData struct {
+	*cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctStatData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	cs := d.collectCPUStats()
+	fmt.Fprintf(buf, "user %d\n", linux.ClockTFromDuration(cs.UserTime))
+	fmt.Fprintf(buf, "system %d\n", linux.ClockTFromDuration(cs.SysTime))
+	return nil
+}
+
+// +stateify savable
+type cpuacctUsageData struct {
+	*cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	cs := d.collectCPUStats()
+	fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds()+cs.SysTime.Nanoseconds())
+	return nil
+}
+
+// +stateify savable
+type cpuacctUsageUserData struct {
+	*cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageUserData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	cs := d.collectCPUStats()
+	fmt.Fprintf(buf, "%d\n", cs.UserTime.Nanoseconds())
+	return nil
+}
+
+// +stateify savable
+type cpuacctUsageSysData struct {
+	*cpuacctCgroup
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *cpuacctUsageSysData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	cs := d.collectCPUStats()
+	fmt.Fprintf(buf, "%d\n", cs.SysTime.Nanoseconds())
+	return nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cpuset.go b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
new file mode 100644
index 000000000..ac547f8e2
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/cpuset.go
@@ -0,0 +1,39 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+)
+
+// +stateify savable
+type cpusetController struct {
+	controllerCommon
+}
+
+var _ controller = (*cpusetController)(nil)
+
+func newCPUSetController(fs *filesystem) *cpusetController {
+	c := &cpusetController{}
+	c.controllerCommon.init(controllerCPUSet, fs)
+	return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *cpusetController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+	// This controller is currently intentionally empty.
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/job.go b/pkg/sentry/fsimpl/cgroupfs/job.go
new file mode 100644
index 000000000..48919c338
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/job.go
@@ -0,0 +1,64 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"bytes"
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/usermem"
+)
+
+// +stateify savable
+type jobController struct {
+	controllerCommon
+	id int64
+}
+
+var _ controller = (*jobController)(nil)
+
+func newJobController(fs *filesystem) *jobController {
+	c := &jobController{}
+	c.controllerCommon.init(controllerJob, fs)
+	return c
+}
+
+func (c *jobController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+	contents["job.id"] = c.fs.newControllerWritableFile(ctx, creds, &jobIDData{c: c})
+}
+
+// +stateify savable
+type jobIDData struct {
+	c *jobController
+}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *jobIDData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	fmt.Fprintf(buf, "%d\n", d.c.id)
+	return nil
+}
+
+// Write implements vfs.WritableDynamicBytesSource.Write.
+func (d *jobIDData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) {
+	val, n, err := parseInt64FromString(ctx, src, offset)
+	if err != nil {
+		return n, err
+	}
+	d.c.id = val
+	return n, nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/memory.go b/pkg/sentry/fsimpl/cgroupfs/memory.go
new file mode 100644
index 000000000..485c98376
--- /dev/null
+++ b/pkg/sentry/fsimpl/cgroupfs/memory.go
@@ -0,0 +1,74 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package cgroupfs
+
+import (
+	"bytes"
+	"fmt"
+	"math"
+
+	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
+	"gvisor.dev/gvisor/pkg/sentry/kernel"
+	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+	"gvisor.dev/gvisor/pkg/sentry/usage"
+)
+
+// +stateify savable
+type memoryController struct {
+	controllerCommon
+
+	limitBytes int64
+}
+
+var _ controller = (*memoryController)(nil)
+
+func newMemoryController(fs *filesystem, defaults map[string]int64) *memoryController {
+	c := &memoryController{
+		// Linux sets this to (PAGE_COUNTER_MAX * PAGE_SIZE) by default, which
+		// is ~ 2**63 on a 64-bit system. So essentially, infinity. The exact
+		// value isn't very important.
+		limitBytes: math.MaxInt64,
+	}
+	if val, ok := defaults["memory.limit_in_bytes"]; ok {
+		c.limitBytes = val
+		delete(defaults, "memory.limit_in_bytes")
+	}
+	c.controllerCommon.init(controllerMemory, fs)
+	return c
+}
+
+// AddControlFiles implements controller.AddControlFiles.
+func (c *memoryController) AddControlFiles(ctx context.Context, creds *auth.Credentials, _ *cgroupInode, contents map[string]kernfs.Inode) {
+	contents["memory.usage_in_bytes"] = c.fs.newControllerFile(ctx, creds, &memoryUsageInBytesData{})
+	contents["memory.limit_in_bytes"] = c.fs.newStaticControllerFile(ctx, creds, linux.FileMode(0644), fmt.Sprintf("%d\n", c.limitBytes))
+}
+
+// +stateify savable
+type memoryUsageInBytesData struct{}
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *memoryUsageInBytesData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	// TODO(b/183151557): This is a giant hack, we're using system-wide
+	// accounting since we know there is only one cgroup.
+	k := kernel.KernelFromContext(ctx)
+	mf := k.MemoryFile()
+	mf.UpdateUsage()
+	_, totalBytes := usage.MemoryAccounting.Copy()
+
+	fmt.Fprintf(buf, "%d\n", totalBytes)
+	return nil
+}
diff --git a/pkg/sentry/fsimpl/ext/BUILD b/pkg/sentry/fsimpl/ext/BUILD
index 7b1eec3da..2dbc6bfd5 100644
--- a/pkg/sentry/fsimpl/ext/BUILD
+++ b/pkg/sentry/fsimpl/ext/BUILD
@@ -46,7 +46,6 @@ go_library(
     visibility = ["//pkg/sentry:internal"],
     deps = [
         "//pkg/abi/linux",
-        "//pkg/binary",
         "//pkg/context",
         "//pkg/fd",
         "//pkg/fspath",
diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD
index 6d5258a9b..52879f871 100644
--- a/pkg/sentry/fsimpl/gofer/BUILD
+++ b/pkg/sentry/fsimpl/gofer/BUILD
@@ -38,6 +38,7 @@ go_library(
         "host_named_pipe.go",
         "p9file.go",
         "regular_file.go",
+        "revalidate.go",
         "save_restore.go",
         "socket.go",
         "special_file.go",
diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go
index 43c3c5a2d..97ce80853 100644
--- a/pkg/sentry/fsimpl/gofer/filesystem.go
+++ b/pkg/sentry/fsimpl/gofer/filesystem.go
@@ -117,6 +117,17 @@ func appendDentry(ds *[]*dentry, d *dentry) *[]*dentry {
 	return ds
 }
 
+// Precondition: !parent.isSynthetic() && !child.isSynthetic().
+func appendNewChildDentry(ds **[]*dentry, parent *dentry, child *dentry) {
+	// The new child was added to parent and took a ref on the parent (hence
+	// parent can be removed from cache). A new child has 0 refs for now. So
+	// checkCachingLocked() should be called on both. Call it first on the parent
+	// as it may create space in the cache for child to be inserted - hence
+	// avoiding a cache eviction.
+	*ds = appendDentry(*ds, parent)
+	*ds = appendDentry(*ds, child)
+}
+
 // Preconditions: ds != nil.
 func putDentrySlice(ds *[]*dentry) {
 	// Allow dentries to be GC'd.
@@ -141,21 +152,8 @@ func (fs *filesystem) renameMuRUnlockAndCheckCaching(ctx context.Context, dsp **
 		return
 	}
 	ds := **dsp
-	// Only go through calling dentry.checkCachingLocked() (which requires
-	// re-locking renameMu) if we actually have any dentries with zero refs.
-	checkAny := false
-	for i := range ds {
-		if atomic.LoadInt64(&ds[i].refs) == 0 {
-			checkAny = true
-			break
-		}
-	}
-	if checkAny {
-		fs.renameMu.Lock()
-		for _, d := range ds {
-			d.checkCachingLocked(ctx)
-		}
-		fs.renameMu.Unlock()
+	for _, d := range ds {
+		d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
 	}
 	putDentrySlice(*dsp)
 }
@@ -166,7 +164,7 @@ func (fs *filesystem) renameMuUnlockAndCheckCaching(ctx context.Context, ds **[]
 		return
 	}
 	for _, d := range **ds {
-		d.checkCachingLocked(ctx)
+		d.checkCachingLocked(ctx, true /* renameMuWriteLocked */)
 	}
 	fs.renameMu.Unlock()
 	putDentrySlice(*ds)
@@ -182,165 +180,96 @@
 // * fs.renameMu must be locked.
 // * d.dirMu must be locked.
 // * !rp.Done().
-// * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up
-//   to date.
+// * If !d.cachedMetadataAuthoritative(), then d and all children that are
+//   part of rp must have been revalidated.
 //
 // Postconditions: The returned dentry's cached metadata is up to date.
-func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) {
+func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, bool, error) {
 	if !d.isDir() {
-		return nil, syserror.ENOTDIR
+		return nil, false, syserror.ENOTDIR
 	}
 	if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil {
-		return nil, err
+		return nil, false, err
 	}
+	followedSymlink := false
afterSymlink:
 	name := rp.Component()
 	if name == "." {
 		rp.Advance()
-		return d, nil
+		return d, followedSymlink, nil
 	}
 	if name == ".." {
 		if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
-			return nil, err
+			return nil, false, err
 		} else if isRoot || d.parent == nil {
 			rp.Advance()
-			return d, nil
-		}
-		// We must assume that d.parent is correct, because if d has been moved
-		// elsewhere in the remote filesystem so that its parent has changed,
-		// we have no way of determining its new parent's location in the
-		// filesystem.
-		//
-		// Call rp.CheckMount() before updating d.parent's metadata, since if
-		// we traverse to another mount then d.parent's metadata is irrelevant.
-		if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
-			return nil, err
+			return d, followedSymlink, nil
 		}
-		if d != d.parent && !d.cachedMetadataAuthoritative() {
-			if err := d.parent.updateFromGetattr(ctx); err != nil {
-				return nil, err
-			}
+		if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+			return nil, false, err
 		}
 		rp.Advance()
-		return d.parent, nil
+		return d.parent, followedSymlink, nil
 	}
-	child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), d, name, ds)
+	child, err := fs.getChildLocked(ctx, d, name, ds)
 	if err != nil {
-		return nil, err
-	}
-	if child == nil {
-		return nil, syserror.ENOENT
+		return nil, false, err
 	}
 	if err := rp.CheckMount(ctx, &child.vfsd); err != nil {
-		return nil, err
+		return nil, false, err
 	}
 	if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() {
 		target, err := child.readlink(ctx, rp.Mount())
 		if err != nil {
-			return nil, err
+			return nil, false, err
 		}
 		if err := rp.HandleSymlink(target); err != nil {
-			return nil, err
+			return nil, false, err
 		}
+		followedSymlink = true
 		goto afterSymlink // don't check the current directory again
 	}
 	rp.Advance()
-	return child, nil
+	return child, followedSymlink, nil
 }
 
 // getChildLocked returns a dentry representing the child of parent with the
-// given name. If no such child exists, getChildLocked returns (nil, nil).
+// given name. Returns ENOENT if the child doesn't exist.
 //
 // Preconditions:
 // * fs.renameMu must be locked.
 // * parent.dirMu must be locked.
 // * parent.isDir().
 // * name is not "." or "..".
-//
-// Postconditions: If getChildLocked returns a non-nil dentry, its cached
-// metadata is up to date.
-func (fs *filesystem) getChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
+// * dentry at name has been revalidated
+func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) {
 	if len(name) > maxFilenameLen {
 		return nil, syserror.ENAMETOOLONG
 	}
-	child, ok := parent.children[name]
-	if (ok && fs.opts.interop != InteropModeShared) || parent.isSynthetic() {
-		// Whether child is nil or not, it is cached information that is
-		// assumed to be correct.
+	if child, ok := parent.children[name]; ok || parent.isSynthetic() {
+		if child == nil {
+			return nil, syserror.ENOENT
+		}
 		return child, nil
 	}
-	// We either don't have cached information or need to verify that it's
-	// still correct, either of which requires a remote lookup. Check if this
-	// name is valid before performing the lookup.
-	return fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, ds)
-}
-
-// Preconditions: Same as getChildLocked, plus:
-// * !parent.isSynthetic().
-func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, child *dentry, ds **[]*dentry) (*dentry, error) {
-	if child != nil {
-		// Need to lock child.metadataMu because we might be updating child
-		// metadata. We need to hold the lock *before* getting metadata from the
-		// server and release it after updating local metadata.
-		child.metadataMu.Lock()
-	}
 	qid, file, attrMask, attr, err := parent.file.walkGetAttrOne(ctx, name)
-	if err != nil && err != syserror.ENOENT {
-		if child != nil {
-			child.metadataMu.Unlock()
+	if err != nil {
+		if err == syserror.ENOENT {
+			parent.cacheNegativeLookupLocked(name)
 		}
 		return nil, err
 	}
-	if child != nil {
-		if !file.isNil() && qid.Path == child.qidPath {
-			// The file at this path hasn't changed. Just update cached metadata.
-			file.close(ctx)
-			child.updateFromP9AttrsLocked(attrMask, &attr)
-			child.metadataMu.Unlock()
-			return child, nil
-		}
-		child.metadataMu.Unlock()
-		if file.isNil() && child.isSynthetic() {
-			// We have a synthetic file, and no remote file has arisen to
-			// replace it.
-			return child, nil
-		}
-		// The file at this path has changed or no longer exists. Mark the
-		// dentry invalidated, and re-evaluate its caching status (i.e. if it
-		// has 0 references, drop it). Wait to update parent.children until we
-		// know what to replace the existing dentry with (i.e. one of the
-		// returns below), to avoid a redundant map access.
-		vfsObj.InvalidateDentry(ctx, &child.vfsd)
-		if child.isSynthetic() {
-			// Normally we don't mark invalidated dentries as deleted since
-			// they may still exist (but at a different path), and also for
-			// consistency with Linux. However, synthetic files are guaranteed
-			// to become unreachable if their dentries are invalidated, so
-			// treat their invalidation as deletion.
-			child.setDeleted()
-			parent.syntheticChildren--
-			child.decRefNoCaching()
-			parent.dirents = nil
-		}
-		*ds = appendDentry(*ds, child)
-	}
-	if file.isNil() {
-		// No file exists at this path now. Cache the negative lookup if
-		// allowed.
-		parent.cacheNegativeLookupLocked(name)
-		return nil, nil
-	}
+
 	// Create a new dentry representing the file.
-	child, err = fs.newDentry(ctx, file, qid, attrMask, &attr)
+	child, err := fs.newDentry(ctx, file, qid, attrMask, &attr)
 	if err != nil {
 		file.close(ctx)
 		delete(parent.children, name)
 		return nil, err
 	}
 	parent.cacheNewChildLocked(child, name)
-	// For now, child has 0 references, so our caller should call
-	// child.checkCachingLocked().
-	*ds = appendDentry(*ds, child)
+	appendNewChildDentry(ds, parent, child)
 	return child, nil
 }
 
@@ -355,14 +284,22 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir
 // * If !d.cachedMetadataAuthoritative(), then d's cached metadata must be up
 //   to date.
 func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) {
+	if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil {
+		return nil, err
+	}
 	for !rp.Final() {
 		d.dirMu.Lock()
-		next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+		next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
 		d.dirMu.Unlock()
 		if err != nil {
 			return nil, err
 		}
 		d = next
+		if followedSymlink {
+			if err := fs.revalidateParentDir(ctx, rp, d, ds); err != nil {
+				return nil, err
+			}
+		}
 	}
 	if !d.isDir() {
 		return nil, syserror.ENOTDIR
@@ -375,20 +312,22 @@
 // Preconditions: fs.renameMu must be locked.
 func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, ds **[]*dentry) (*dentry, error) {
 	d := rp.Start().Impl().(*dentry)
-	if !d.cachedMetadataAuthoritative() {
-		// Get updated metadata for rp.Start() as required by fs.stepLocked().
-		if err := d.updateFromGetattr(ctx); err != nil {
-			return nil, err
-		}
+	if err := fs.revalidatePath(ctx, rp, d, ds); err != nil {
+		return nil, err
 	}
 	for !rp.Done() {
 		d.dirMu.Lock()
-		next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
+		next, followedSymlink, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds)
 		d.dirMu.Unlock()
 		if err != nil {
 			return nil, err
 		}
 		d = next
+		if followedSymlink {
+			if err := fs.revalidatePath(ctx, rp, d, ds); err != nil {
+				return nil, err
+			}
+		}
 	}
 	if rp.MustBeDir() && !d.isDir() {
 		return nil, syserror.ENOTDIR
@@ -408,13 +347,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir
 	fs.renameMu.RLock()
 	defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
 	start := rp.Start().Impl().(*dentry)
-	if !start.cachedMetadataAuthoritative() {
-		// Get updated metadata for start as required by
-		// fs.walkParentDirLocked().
-		if err := start.updateFromGetattr(ctx); err != nil {
-			return err
-		}
-	}
 	parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
 	if err != nil {
 		return err
@@ -432,25 +364,47 @@
 	if parent.isDeleted() {
 		return syserror.ENOENT
 	}
+	if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, name, &ds); err != nil {
+		return err
+	}
 	parent.dirMu.Lock()
 	defer parent.dirMu.Unlock()
-	child, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), parent, name, &ds)
-	switch {
-	case err != nil && err != syserror.ENOENT:
-		return err
-	case child != nil:
+	if len(name) > maxFilenameLen {
+		return syserror.ENAMETOOLONG
+	}
+	// Check for existence only if caching information is available. Otherwise,
+	// don't check for existence just yet. We will check for existence if the
+	// checks for writability fail below. Existence check is done by the creation
+	// RPCs themselves.
+	if child, ok := parent.children[name]; ok && child != nil {
 		return syserror.EEXIST
 	}
+	checkExistence := func() error {
+		if child, err := fs.getChildLocked(ctx, parent, name, &ds); err != nil && err != syserror.ENOENT {
+			return err
+		} else if child != nil {
+			return syserror.EEXIST
+		}
+		return nil
+	}
 	mnt := rp.Mount()
 	if err := mnt.CheckBeginWrite(); err != nil {
+		// Existence check takes precedence.
+		if existenceErr := checkExistence(); existenceErr != nil {
+			return existenceErr
+		}
 		return err
 	}
 	defer mnt.EndWrite()
 	if err := parent.checkPermissions(rp.Credentials(), vfs.MayWrite); err != nil {
+		// Existence check takes precedence.
+		if existenceErr := checkExistence(); existenceErr != nil {
+			return existenceErr
+		}
 		return err
 	}
 	if !dir && rp.MustBeDir() {
@@ -500,13 +454,6 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b
 	fs.renameMu.RLock()
 	defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
 	start := rp.Start().Impl().(*dentry)
-	if !start.cachedMetadataAuthoritative() {
-		// Get updated metadata for start as required by
-		// fs.walkParentDirLocked().
-		if err := start.updateFromGetattr(ctx); err != nil {
-			return err
-		}
-	}
 	parent, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
 	if err != nil {
 		return err
@@ -532,33 +479,32 @@
 			return syserror.EISDIR
 		}
 	}
+	vfsObj := rp.VirtualFilesystem()
+	if err := fs.revalidateOne(ctx, vfsObj, parent, rp.Component(), &ds); err != nil {
+		return err
+	}
+
 	mntns := vfs.MountNamespaceFromContext(ctx)
 	defer mntns.DecRef(ctx)
+
 	parent.dirMu.Lock()
 	defer parent.dirMu.Unlock()
-	child, ok := parent.children[name]
-	if ok && child == nil {
-		return syserror.ENOENT
-	}
-
-	sticky := atomic.LoadUint32(&parent.mode)&linux.ModeSticky != 0
-	if sticky {
-		if !ok {
-			// If the sticky bit is set, we need to retrieve the child to determine
-			// whether removing it is allowed.
-			child, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
-			if err != nil {
-				return err
-			}
-		} else if child != nil && !child.cachedMetadataAuthoritative() {
-			// Make sure the dentry representing the file at name is up to date
-			// before examining its metadata.
-			child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
-			if err != nil {
-				return err
-			}
+	// Load child if sticky bit is set because we need to determine whether
+	// deletion is allowed.
+	var child *dentry
+	if atomic.LoadUint32(&parent.mode)&linux.ModeSticky == 0 {
+		var ok bool
+		child, ok = parent.children[name]
+		if ok && child == nil {
+			// Hit a negative cached entry, child doesn't exist.
+			return syserror.ENOENT
+		}
+	} else {
+		child, _, err = fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds)
+		if err != nil {
+			return err
 		}
 		if err := parent.mayDelete(rp.Credentials(), child); err != nil {
 			return err
 		}
 	}
 
 	// If a child dentry exists, prepare to delete it. This should fail if it is
 	// a mount point. We detect mount points by speculatively calling
-	// PrepareDeleteDentry, which fails if child is a mount point. However, we
-	// may need to revalidate the file in this case to make sure that it has not
-	// been deleted or replaced on the remote fs, in which case the mount point
-	// will have disappeared. If calling PrepareDeleteDentry fails again on the
-	// up-to-date dentry, we can be sure that it is a mount point.
+	// PrepareDeleteDentry, which fails if child is a mount point.
 	//
 	// Also note that if child is nil, then it can't be a mount point.
 	if child != nil {
@@ -586,23 +528,7 @@
 		child.dirMu.Lock()
 		defer child.dirMu.Unlock()
 		if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
-			// We can skip revalidation in several cases:
-			// - We are not in InteropModeShared
-			// - The parent directory is synthetic, in which case the child must also
-			//   be synthetic
-			// - We already updated the child during the sticky bit check above
-			if parent.cachedMetadataAuthoritative() || sticky {
-				return err
-			}
-			child, err = fs.revalidateChildLocked(ctx, vfsObj, parent, name, child, &ds)
-			if err != nil {
-				return err
-			}
-			if child != nil {
-				if err := vfsObj.PrepareDeleteDentry(mntns, &child.vfsd); err != nil {
-					return err
-				}
-			}
+			return err
 		}
 	}
 	flags := uint32(0)
@@ -723,6 +649,8 @@ func (fs *filesystem) GetDentryAt(ctx context.Context, rp *vfs.ResolvingPath, op
 		}
 	}
 	d.IncRef()
+	// Call d.checkCachingLocked() so it can be removed from the cache if needed.
+	ds = appendDentry(ds, d)
 	return &d.vfsd, nil
 }
 
@@ -732,18 +660,13 @@ func (fs *filesystem) GetParentDentryAt(ctx context.Context, rp *vfs.ResolvingPa
 	fs.renameMu.RLock()
 	defer fs.renameMuRUnlockAndCheckCaching(ctx, &ds)
 	start := rp.Start().Impl().(*dentry)
-	if !start.cachedMetadataAuthoritative() {
-		// Get updated metadata for start as required by
-		// fs.walkParentDirLocked().
-		if err := start.updateFromGetattr(ctx); err != nil {
-			return nil, err
-		}
-	}
 	d, err := fs.walkParentDirLocked(ctx, rp, start, &ds)
 	if err != nil {
 		return nil, err
 	}
 	d.IncRef()
+	// Call d.checkCachingLocked() so it can be removed from the cache if needed.
+	ds = appendDentry(ds, d)
 	return &d.vfsd, nil
 }
 
@@ -782,7 +705,7 @@
 // MkdirAt implements vfs.FilesystemImpl.MkdirAt.
 func (fs *filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.MkdirOptions) error {
 	creds := rp.Credentials()
-	return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, _ **[]*dentry) error {
+	return fs.doCreateAt(ctx, rp, true /* dir */, func(parent *dentry, name string, ds **[]*dentry) error {
 		// If the parent is a setgid directory, use the parent's GID
 		// rather than the caller's and enable setgid.
 		kgid := creds.EffectiveKGID
@@ -802,6 +725,7 @@
 				kuid: creds.EffectiveKUID,
 				kgid: creds.EffectiveKGID,
 			})
+			*ds = appendDentry(*ds, parent)
 		}
 		if fs.opts.interop != InteropModeShared {
 			parent.incLinks()
@@ -836,7 +760,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v
 		// to creating a synthetic one, i.e. one that is kept entirely in memory.
 
 		// Check that we're not overriding an existing file with a synthetic one.
-		_, err = fs.stepLocked(ctx, rp, parent, true, ds)
+		_, _, err = fs.stepLocked(ctx, rp, parent, true, ds)
 		switch {
 		case err == nil:
 			// Step succeeded, another file exists.
@@ -855,6 +779,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v kgid: creds.EffectiveKGID, endpoint: opts.Endpoint, }) + *ds = appendDentry(*ds, parent) return nil case linux.S_IFIFO: parent.createSyntheticChildLocked(&createSyntheticOpts{ @@ -864,6 +789,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v kgid: creds.EffectiveKGID, pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize), }) + *ds = appendDentry(*ds, parent) return nil } // Retain error from gofer if synthetic file cannot be created internally. @@ -895,12 +821,6 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf defer unlock() start := rp.Start().Impl().(*dentry) - if !start.cachedMetadataAuthoritative() { - // Get updated metadata for start as required by fs.stepLocked(). - if err := start.updateFromGetattr(ctx); err != nil { - return nil, err - } - } if rp.Done() { // Reject attempts to open mount root directory with O_CREAT. if mayCreate && rp.MustBeDir() { @@ -909,9 +829,17 @@ func (fs *filesystem) OpenAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if mustCreate { return nil, syserror.EEXIST } + if !start.cachedMetadataAuthoritative() { + // Refresh dentry's attributes before opening. + if err := start.updateFromGetattr(ctx); err != nil { + return nil, err + } + } start.IncRef() defer start.DecRef(ctx) unlock() + // start is intentionally not added to ds (which would remove it from the + // cache) because doing so regresses performance in practice. return start.open(ctx, rp, &opts) } @@ -928,9 +856,12 @@ afterTrailingSymlink: if mayCreate && rp.MustBeDir() { return nil, syserror.EISDIR } + if err := fs.revalidateOne(ctx, rp.VirtualFilesystem(), parent, rp.Component(), &ds); err != nil { + return nil, err + } // Determine whether or not we need to create a file. parent.dirMu.Lock() - child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + child, _, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) if err == syserror.ENOENT && mayCreate { if parent.isSynthetic() { parent.dirMu.Unlock() @@ -965,6 +896,8 @@ afterTrailingSymlink: child.IncRef() defer child.DecRef(ctx) unlock() + // child is intentionally not added to ds (which would remove it from the + // cache) because doing so regresses performance in practice. return child.open(ctx, rp, &opts) } @@ -1188,7 +1121,6 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } - *ds = appendDentry(*ds, child) // Incorporate the fid that was opened by lcreate. useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { @@ -1212,6 +1144,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving } // Insert the dentry into the tree. 
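Several hunks in this file either append dentries to ds or pointedly avoid doing so. The idiom, sketched below with an illustrative method name but the package's real types: dentries touched while renameMu is held are queued, and their caching status is re-evaluated only after the lock is dropped, keeping eviction work out of the filesystem-wide critical section.

```go
// checkCachingAfterUnlock mirrors the renameMuRUnlockAndCheckCaching idiom
// used by these methods: release the lock first, then let each queued
// dentry decide whether it should be destroyed, cached, or left alone.
func (fs *filesystem) checkCachingAfterUnlock(ctx context.Context, ds *[]*dentry) {
	fs.renameMu.RUnlock()
	for _, d := range *ds {
		d.checkCachingLocked(ctx, false /* renameMuWriteLocked */)
	}
}
```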
d.cacheNewChildLocked(child, name) + appendNewChildDentry(ds, d, child) if d.cachedMetadataAuthoritative() { d.touchCMtime() d.dirents = nil @@ -1296,18 +1229,23 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err := oldParent.checkPermissions(creds, vfs.MayWrite|vfs.MayExec); err != nil { return err } + vfsObj := rp.VirtualFilesystem() + if err := fs.revalidateOne(ctx, vfsObj, newParent, newName, &ds); err != nil { + return err + } + if err := fs.revalidateOne(ctx, vfsObj, oldParent, oldName, &ds); err != nil { + return err + } + // We need a dentry representing the renamed file since, if it's a // directory, we need to check for write permission on it. oldParent.dirMu.Lock() defer oldParent.dirMu.Unlock() - renamed, err := fs.getChildLocked(ctx, vfsObj, oldParent, oldName, &ds) + renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) if err != nil { return err } - if renamed == nil { - return syserror.ENOENT - } if err := oldParent.mayDelete(creds, renamed); err != nil { return err } @@ -1336,8 +1274,8 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if newParent.isDeleted() { return syserror.ENOENT } - replaced, err := fs.getChildLocked(ctx, rp.VirtualFilesystem(), newParent, newName, &ds) - if err != nil { + replaced, err := fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil && err != syserror.ENOENT { return err } var replacedVFSD *vfs.Dentry @@ -1401,8 +1339,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // parent isn't actually changing. if oldParent != newParent { oldParent.decRefNoCaching() - ds = appendDentry(ds, oldParent) newParent.IncRef() + ds = appendDentry(ds, newParent) + ds = appendDentry(ds, oldParent) if renamed.isSynthetic() { oldParent.syntheticChildren-- newParent.syntheticChildren++ @@ -1546,6 +1485,7 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath if d.isSocket() { if !d.isSynthetic() { d.IncRef() + ds = appendDentry(ds, d) return &endpoint{ dentry: d, path: opts.Addr, diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index a0c05231a..21692d2ac 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -18,21 +18,23 @@ // Lock order: // regularFileFD/directoryFD.mu // filesystem.renameMu -// dentry.dirMu -// filesystem.syncMu -// dentry.metadataMu -// *** "memmap.Mappable locks" below this point -// dentry.mapsMu -// *** "memmap.Mappable locks taken by Translate" below this point -// dentry.handleMu -// dentry.dataMu -// filesystem.inoMu +// dentry.cachingMu +// filesystem.cacheMu +// dentry.dirMu +// filesystem.syncMu +// dentry.metadataMu +// *** "memmap.Mappable locks" below this point +// dentry.mapsMu +// *** "memmap.Mappable locks taken by Translate" below this point +// dentry.handleMu +// dentry.dataMu +// filesystem.inoMu // specialFileFD.mu // specialFileFD.bufMu // -// Locking dentry.dirMu in multiple dentries requires that either ancestor -// dentries are locked before descendant dentries, or that filesystem.renameMu -// is locked for writing. +// Locking dentry.dirMu and dentry.metadataMu in multiple dentries requires that +// either ancestor dentries are locked before descendant dentries, or that +// filesystem.renameMu is locked for writing. package gofer import ( @@ -140,7 +142,8 @@ type filesystem struct { // cachedDentries contains all dentries with 0 references. 
(Due to race // conditions, it may also contain dentries with non-zero references.) // cachedDentriesLen is the number of dentries in cachedDentries. These fields - // are protected by renameMu. + // are protected by cacheMu. + cacheMu sync.Mutex `state:"nosave"` cachedDentries dentryList cachedDentriesLen uint64 @@ -620,11 +623,11 @@ func (fs *filesystem) Release(ctx context.Context) { // the reference count on every synthetic dentry. Synthetic dentries have one // reference for existence that should be dropped during filesystem.Release. // -// Precondition: d.fs.renameMu is locked. +// Precondition: d.fs.renameMu is locked for writing. func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) { if d.isSynthetic() { d.decRefNoCaching() - d.checkCachingLocked(ctx) + d.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } if d.isDir() { var children []*dentry @@ -682,9 +685,13 @@ type dentry struct { // deleted. deleted is accessed using atomic memory operations. deleted uint32 + // cachingMu is used to synchronize concurrent dentry caching attempts on + // this dentry. + cachingMu sync.Mutex `state:"nosave"` + // If cached is true, dentryEntry links dentry into // filesystem.cachedDentries. cached and dentryEntry are protected by - // filesystem.renameMu. + // cachingMu. cached bool dentryEntry @@ -980,36 +987,63 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { } // Preconditions: !d.isSynthetic(). +// Preconditions: d.metadataMu is locked. +func (d *dentry) refreshSizeLocked(ctx context.Context) error { + d.handleMu.RLock() + + if d.writeFD < 0 { + d.handleMu.RUnlock() + // Ask the gofer if we don't have a host FD. + return d.updateFromGetattrLocked(ctx) + } + + var stat unix.Statx_t + err := unix.Statx(int(d.writeFD), "", unix.AT_EMPTY_PATH, unix.STATX_SIZE, &stat) + d.handleMu.RUnlock() // must be released before updateSizeLocked() + if err != nil { + return err + } + d.updateSizeLocked(stat.Size) + return nil +} + +// Preconditions: !d.isSynthetic(). func (d *dentry) updateFromGetattr(ctx context.Context) error { - // Use d.readFile or d.writeFile, which represent 9P fids that have been + // d.metadataMu must be locked *before* we getAttr so that we do not end up + // updating stale attributes in d.updateFromP9AttrsLocked(). + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + return d.updateFromGetattrLocked(ctx) +} + +// Preconditions: +// * !d.isSynthetic(). +// * d.metadataMu is locked. +func (d *dentry) updateFromGetattrLocked(ctx context.Context) error { + // Use d.readFile or d.writeFile, which represent 9P FIDs that have been // opened, in preference to d.file, which represents a 9P fid that has not. // This may be significantly more efficient in some implementations. Prefer // d.writeFile over d.readFile since some filesystem implementations may // update a writable handle's metadata after writes to that handle, without // making metadata updates immediately visible to read-only handles // representing the same file. - var ( - file p9file - handleMuRLocked bool - ) - // d.metadataMu must be locked *before* we getAttr so that we do not end up - // updating stale attributes in d.updateFromP9AttrsLocked(). 
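refreshSizeLocked above skips the 9P round trip whenever a writable host FD is available, issuing statx(2) against the FD itself. The same host call in isolation, using golang.org/x/sys/unix:

```go
package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

// fdSize fetches only the size of an already-open file: AT_EMPTY_PATH makes
// statx operate on the FD rather than on a path, and STATX_SIZE restricts
// the fields the kernel must fill in.
func fdSize(fd int) (uint64, error) {
	var stat unix.Statx_t
	if err := unix.Statx(fd, "", unix.AT_EMPTY_PATH, unix.STATX_SIZE, &stat); err != nil {
		return 0, err
	}
	return stat.Size, nil
}

func main() {
	sz, err := fdSize(0) // stdin, purely for demonstration
	fmt.Println(sz, err)
}
```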
- d.metadataMu.Lock() - defer d.metadataMu.Unlock() d.handleMu.RLock() - if !d.writeFile.isNil() { + handleMuRLocked := true + var file p9file + switch { + case !d.writeFile.isNil(): file = d.writeFile - handleMuRLocked = true - } else if !d.readFile.isNil() { + case !d.readFile.isNil(): file = d.readFile - handleMuRLocked = true - } else { + default: file = d.file d.handleMu.RUnlock() + handleMuRLocked = false } + _, attrMask, attr, err := file.getAttr(ctx, dentryAttrMask()) if handleMuRLocked { - d.handleMu.RUnlock() + d.handleMu.RUnlock() // must be released before updateFromP9AttrsLocked() } if err != nil { return err @@ -1104,24 +1138,27 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs defer d.metadataMu.Unlock() // As with Linux, if the UID, GID, or file size is changing, we have to - // clear permission bits. Note that when set, clearSGID causes - // permissions to be updated, but does not modify stat.Mask, as - // modification would cause an extra inotify flag to be set. - clearSGID := stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid) || - stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid) || + // clear permission bits. Note that when set, clearSGID may cause + // permissions to be updated. + clearSGID := (stat.Mask&linux.STATX_UID != 0 && stat.UID != atomic.LoadUint32(&d.uid)) || + (stat.Mask&linux.STATX_GID != 0 && stat.GID != atomic.LoadUint32(&d.gid)) || stat.Mask&linux.STATX_SIZE != 0 if clearSGID { if stat.Mask&linux.STATX_MODE != 0 { stat.Mode = uint16(vfs.ClearSUIDAndSGID(uint32(stat.Mode))) } else { - stat.Mode = uint16(vfs.ClearSUIDAndSGID(atomic.LoadUint32(&d.mode))) + oldMode := atomic.LoadUint32(&d.mode) + if updatedMode := vfs.ClearSUIDAndSGID(oldMode); updatedMode != oldMode { + stat.Mode = uint16(updatedMode) + stat.Mask |= linux.STATX_MODE + } } } if !d.isSynthetic() { if stat.Mask != 0 { if err := d.file.setAttr(ctx, p9.SetAttrMask{ - Permissions: stat.Mask&linux.STATX_MODE != 0 || clearSGID, + Permissions: stat.Mask&linux.STATX_MODE != 0, UID: stat.Mask&linux.STATX_UID != 0, GID: stat.Mask&linux.STATX_GID != 0, Size: stat.Mask&linux.STATX_SIZE != 0, @@ -1156,7 +1193,7 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } } - if stat.Mask&linux.STATX_MODE != 0 || clearSGID { + if stat.Mask&linux.STATX_MODE != 0 { atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) } if stat.Mask&linux.STATX_UID != 0 { @@ -1312,9 +1349,7 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { if d.decRefNoCaching() == 0 { - d.fs.renameMu.Lock() - d.checkCachingLocked(ctx) - d.fs.renameMu.Unlock() + d.checkCachingLocked(ctx, false /* renameMuWriteLocked */) } } @@ -1374,15 +1409,16 @@ func (d *dentry) Watches() *vfs.Watches { // // If no watches are left on this dentry and it has no references, cache it. func (d *dentry) OnZeroWatches(ctx context.Context) { - if atomic.LoadInt64(&d.refs) == 0 { - d.fs.renameMu.Lock() - d.checkCachingLocked(ctx) - d.fs.renameMu.Unlock() - } + d.checkCachingLocked(ctx, false /* renameMuWriteLocked */) } -// checkCachingLocked should be called after d's reference count becomes 0 or it -// becomes disowned. +// checkCachingLocked should be called after d's reference count becomes 0 or +// it becomes disowned. 
+// +// For performance, checkCachingLocked can also be called after d's reference +// count becomes non-zero, so that d can be removed from the LRU cache. This +// may help in reducing the size of the cache and hence reduce evictions. Note +// that this is not necessary for correctness. // // It may be called on a destroyed dentry. For example, // renameMu[R]UnlockAndCheckCaching may call checkCachingLocked multiple times @@ -1390,33 +1426,46 @@ func (d *dentry) OnZeroWatches(ctx context.Context) { // operation. One of the calls may destroy the dentry, so subsequent calls will // do nothing. // -// Preconditions: d.fs.renameMu must be locked for writing; it may be -// temporarily unlocked. -func (d *dentry) checkCachingLocked(ctx context.Context) { - // Dentries with a non-zero reference count must be retained. (The only way - // to obtain a reference on a dentry with zero references is via path - // resolution, which requires renameMu, so if d.refs is zero then it will - // remain zero while we hold renameMu for writing.) +// Preconditions: d.fs.renameMu must be locked for writing if +// renameMuWriteLocked is true; it may be temporarily unlocked. +func (d *dentry) checkCachingLocked(ctx context.Context, renameMuWriteLocked bool) { + d.cachingMu.Lock() refs := atomic.LoadInt64(&d.refs) if refs == -1 { // Dentry has already been destroyed. + d.cachingMu.Unlock() return } if refs > 0 { - // This isn't strictly necessary (fs.cachedDentries is permitted to - // contain dentries with non-zero refs, which are skipped by - // fs.evictCachedDentryLocked() upon reaching the end of the LRU), but - // since we are already holding fs.renameMu for writing we may as well. + // fs.cachedDentries is permitted to contain dentries with non-zero refs, + // which are skipped by fs.evictCachedDentryLocked() upon reaching the end + // of the LRU. But it is still beneficial to remove d from the cache as we + // are already holding d.cachingMu. Keeping a cleaner cache also reduces + // the number of evictions (which is expensive as it acquires fs.renameMu). d.removeFromCacheLocked() + d.cachingMu.Unlock() return } // Deleted and invalidated dentries with zero references are no longer // reachable by path resolution and should be dropped immediately. if d.vfsd.IsDead() { + d.removeFromCacheLocked() + d.cachingMu.Unlock() + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu for writing as needed by d.destroyLocked(). + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + // Now that renameMu is locked for writing, no more refs can be taken on + // d because path resolution requires renameMu for reading at least. + if atomic.LoadInt64(&d.refs) != 0 { + // Destroy d only if its ref is still 0. If not, either someone took a + // ref on it or it got destroyed before fs.renameMu could be acquired. + return + } + } if d.isDeleted() { d.watches.HandleDeletion(ctx) } - d.removeFromCacheLocked() d.destroyLocked(ctx) return } @@ -1426,24 +1475,36 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { // d.watches cannot concurrently transition from zero to non-zero, because // adding a watch requires holding a reference on d. if d.watches.Size() > 0 { - // As in the refs > 0 case, this is not strictly necessary. + // As in the refs > 0 case, removing d is beneficial. d.removeFromCacheLocked() + d.cachingMu.Unlock() return } if atomic.LoadInt32(&d.fs.released) != 0 { + d.cachingMu.Unlock() + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu to access d.parent. 
Lock it for writing as + // needed by d.destroyLocked() later. + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + } if d.parent != nil { d.parent.dirMu.Lock() delete(d.parent.children, d.name) d.parent.dirMu.Unlock() } d.destroyLocked(ctx) + return } + d.fs.cacheMu.Lock() // If d is already cached, just move it to the front of the LRU. if d.cached { d.fs.cachedDentries.Remove(d) d.fs.cachedDentries.PushFront(d) + d.fs.cacheMu.Unlock() + d.cachingMu.Unlock() return } // Cache the dentry, then evict the least recently used cached dentry if @@ -1451,18 +1512,28 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { d.fs.cachedDentries.PushFront(d) d.fs.cachedDentriesLen++ d.cached = true - if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries { + shouldEvict := d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries + d.fs.cacheMu.Unlock() + d.cachingMu.Unlock() + + if shouldEvict { + if !renameMuWriteLocked { + // Need to lock d.fs.renameMu for writing as needed by + // d.evictCachedDentryLocked(). + d.fs.renameMu.Lock() + defer d.fs.renameMu.Unlock() + } d.fs.evictCachedDentryLocked(ctx) - // Whether or not victim was destroyed, we brought fs.cachedDentriesLen - // back down to fs.opts.maxCachedDentries, so we don't loop. } } -// Preconditions: d.fs.renameMu must be locked for writing. +// Preconditions: d.cachingMu must be locked. func (d *dentry) removeFromCacheLocked() { if d.cached { + d.fs.cacheMu.Lock() d.fs.cachedDentries.Remove(d) d.fs.cachedDentriesLen-- + d.fs.cacheMu.Unlock() d.cached = false } } @@ -1477,28 +1548,43 @@ func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { // Preconditions: // * fs.renameMu must be locked for writing; it may be temporarily unlocked. -// * fs.cachedDentriesLen != 0. func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { + fs.cacheMu.Lock() victim := fs.cachedDentries.Back() + fs.cacheMu.Unlock() + if victim == nil { + // fs.cachedDentries may have become empty between when it was checked and + // when we locked fs.cacheMu. + return + } + + victim.cachingMu.Lock() victim.removeFromCacheLocked() // victim.refs or victim.watches.Size() may have become non-zero from an // earlier path resolution since it was inserted into fs.cachedDentries. - if atomic.LoadInt64(&victim.refs) == 0 && victim.watches.Size() == 0 { - if victim.parent != nil { - victim.parent.dirMu.Lock() - if !victim.vfsd.IsDead() { - // Note that victim can't be a mount point (in any mount - // namespace), since VFS holds references on mount points. - fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) - delete(victim.parent.children, victim.name) - // We're only deleting the dentry, not the file it - // represents, so we don't need to update - // victimParent.dirents etc. - } - victim.parent.dirMu.Unlock() + if atomic.LoadInt64(&victim.refs) != 0 || victim.watches.Size() != 0 { + victim.cachingMu.Unlock() + return + } + if victim.parent != nil { + victim.parent.dirMu.Lock() + if !victim.vfsd.IsDead() { + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) + delete(victim.parent.children, victim.name) + // We're only deleting the dentry, not the file it + // represents, so we don't need to update + // victimParent.dirents etc. } - victim.destroyLocked(ctx) + victim.parent.dirMu.Unlock() } + // Safe to unlock cachingMu now that victim.vfsd.IsDead(). 
Henceforth any + // concurrent caching attempts on victim will attempt to destroy it and so + // will try to acquire fs.renameMu (which we have already acquired). Hence, + // fs.renameMu will synchronize the destroy attempts. + victim.cachingMu.Unlock() + victim.destroyLocked(ctx) } // destroyLocked destroys the dentry. @@ -1584,7 +1670,7 @@ func (d *dentry) destroyLocked(ctx context.Context) { // Drop the reference held by d on its parent without recursively locking // d.fs.renameMu. if d.parent != nil && d.parent.decRefNoCaching() == 0 { - d.parent.checkCachingLocked(ctx) + d.parent.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } refsvfs2.Unregister(d) } diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index 76f08e252..806392d50 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -55,7 +55,7 @@ func TestDestroyIdempotent(t *testing.T) { fs.renameMu.Lock() defer fs.renameMu.Unlock() - child.checkCachingLocked(ctx) + child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) if got := atomic.LoadInt64(&child.refs); got != -1 { t.Fatalf("child.refs=%d, want: -1", got) } @@ -63,6 +63,6 @@ func TestDestroyIdempotent(t *testing.T) { if got := atomic.LoadInt64(&parent.refs); got != -1 { t.Fatalf("parent.refs=%d, want: -1", got) } - child.checkCachingLocked(ctx) - child.checkCachingLocked(ctx) + child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) + child.checkCachingLocked(ctx, true /* renameMuWriteLocked */) } diff --git a/pkg/sentry/fsimpl/gofer/p9file.go b/pkg/sentry/fsimpl/gofer/p9file.go index 21b4a96fe..b0a429d42 100644 --- a/pkg/sentry/fsimpl/gofer/p9file.go +++ b/pkg/sentry/fsimpl/gofer/p9file.go @@ -238,3 +238,10 @@ func (f p9file) connect(ctx context.Context, flags p9.ConnectFlags) (*fd.FD, err ctx.UninterruptibleSleepFinish(false) return fdobj, err } + +func (f p9file) multiGetAttr(ctx context.Context, names []string) ([]p9.FullStat, error) { + ctx.UninterruptibleSleepStart(false) + stats, err := f.file.MultiGetAttr(names) + ctx.UninterruptibleSleepFinish(false) + return stats, err +} diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 47563538c..f0e7bbaf7 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -204,18 +204,19 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off } d := fd.dentry() + + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + // If the fd was opened with O_APPEND, make sure the file size is updated. // There is a possible race here if size is modified externally after // metadata cache is updated. if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { - if err := d.updateFromGetattr(ctx); err != nil { + if err := d.refreshSizeLocked(ctx); err != nil { return 0, offset, err } } - d.metadataMu.Lock() - defer d.metadataMu.Unlock() - // Set offset to file size if the fd was opened with O_APPEND. if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 { // Holding d.metadataMu is sufficient for reading d.size. @@ -701,6 +702,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt } // After this point, d may be used as a memmap.Mappable. 
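The new p9file.multiGetAttr in the hunk above repeats the bracket every p9file wrapper uses: the task is marked uninterruptibly sleeping for the duration of the blocking 9P RPC so sentry scheduling accounting stays correct. A generic sketch of that bracket, assuming only gVisor's context API:

```go
import "gvisor.dev/gvisor/pkg/context"

// sleepAround runs a blocking remote call with the task flagged as being in
// an uninterruptible sleep, exactly as the p9file wrappers do.
func sleepAround(ctx context.Context, rpc func() error) error {
	ctx.UninterruptibleSleepStart(false)
	err := rpc()
	ctx.UninterruptibleSleepFinish(false)
	return err
}
```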
d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init) + opts.SentryOwnedContent = d.fs.opts.forcePageCache return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts) } diff --git a/pkg/sentry/fsimpl/gofer/revalidate.go b/pkg/sentry/fsimpl/gofer/revalidate.go new file mode 100644 index 000000000..8f81f0822 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/revalidate.go @@ -0,0 +1,386 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" +) + +type errPartialRevalidation struct{} + +// Error implements error.Error. +func (errPartialRevalidation) Error() string { + return "partial revalidation" +} + +type errRevalidationStepDone struct{} + +// Error implements error.Error. +func (errRevalidationStepDone) Error() string { + return "stop revalidation" +} + +// revalidatePath checks cached dentries for external modification. File +// attributes are refreshed and cache is invalidated in case the dentry has been +// deleted, or a new file/directory created in its place. +// +// Revalidation stops at symlinks and mount points. The caller is responsible +// for revalidating again after symlinks are resolved and after changing to +// different mounts. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidatePath(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error { + // Revalidation is done even if start is synthetic in case the path is + // something like: ../non_synthetic_file. + if fs.opts.interop != InteropModeShared { + return nil + } + + // Copy resolving path to walk the path for revalidation. + rp := rpOrig.Copy() + err := fs.revalidate(ctx, rp, start, rp.Done, ds) + rp.Release(ctx) + return err +} + +// revalidateParentDir does the same as revalidatePath, but stops at the parent. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidateParentDir(ctx context.Context, rpOrig *vfs.ResolvingPath, start *dentry, ds **[]*dentry) error { + // Revalidation is done even if start is synthetic in case the path is + // something like: ../non_synthetic_file and parent is non synthetic. + if fs.opts.interop != InteropModeShared { + return nil + } + + // Copy resolving path to walk the path for revalidation. + rp := rpOrig.Copy() + err := fs.revalidate(ctx, rp, start, rp.Final, ds) + rp.Release(ctx) + return err +} + +// revalidateOne does the same as revalidatePath, but checks a single dentry. +// +// Preconditions: +// * fs.renameMu must be locked. +func (fs *filesystem) revalidateOne(ctx context.Context, vfsObj *vfs.VirtualFilesystem, parent *dentry, name string, ds **[]*dentry) error { + // Skip revalidation for interop mode different than InteropModeShared or + // if the parent is synthetic (child must be synthetic too, but it cannot be + // replaced without first replacing the parent). 
+	if parent.cachedMetadataAuthoritative() {
+		return nil
+	}
+
+	parent.dirMu.Lock()
+	child, ok := parent.children[name]
+	parent.dirMu.Unlock()
+	if !ok {
+		return nil
+	}
+
+	state := makeRevalidateState(parent)
+	defer state.release()
+
+	state.add(name, child)
+	return fs.revalidateHelper(ctx, vfsObj, state, ds)
+}
+
+// revalidate revalidates path components in rp until done returns true, or
+// until a mount point or symlink is reached. It may send multiple MultiGetAttr
+// calls to the gofer to handle ".." in the path.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * InteropModeShared is in effect.
+func (fs *filesystem) revalidate(ctx context.Context, rp *vfs.ResolvingPath, start *dentry, done func() bool, ds **[]*dentry) error {
+	state := makeRevalidateState(start)
+	defer state.release()
+
+	// Skip synthetic dentries because the start dentry cannot be replaced in case
+	// it has been created in the remote file system.
+	if !start.isSynthetic() {
+		state.add("", start)
+	}
+
+done:
+	for cur := start; !done(); {
+		var err error
+		cur, err = fs.revalidateStep(ctx, rp, cur, state)
+		if err != nil {
+			switch err.(type) {
+			case errPartialRevalidation:
+				if err := fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds); err != nil {
+					return err
+				}
+
+				// Reset state to release any remaining locks and restart from where
+				// stepping stopped.
+				state.reset()
+				state.start = cur
+
+				// Skip synthetic dentries because the start dentry cannot be replaced in
+				// case it has been created in the remote file system.
+				if !cur.isSynthetic() {
+					state.add("", cur)
+				}
+
+			case errRevalidationStepDone:
+				break done
+
+			default:
+				return err
+			}
+		}
+	}
+	return fs.revalidateHelper(ctx, rp.VirtualFilesystem(), state, ds)
+}
+
+// revalidateStep walks one element of the path and updates revalidateState
+// with the dentry if needed. It may also stop the stepping or ask for a
+// partial revalidation. Partial revalidation requires the caller to revalidate
+// the current revalidateState, release all locks, and resume stepping.
+// In case a symlink is hit, revalidation stops and the caller is responsible
+// for calling revalidate again after the symlink is resolved. Revalidation may
+// also stop for other reasons, like hitting a child not in the cache.
+//
+// Returns:
+// * (dentry, nil): step worked, continue stepping.
+// * (dentry, errPartialRevalidation): revalidation should be done with the
+//     state gathered so far. Then continue stepping with the remainder of the
+//     path, starting at `dentry`.
+// * (nil, errRevalidationStepDone): revalidation doesn't need to step any
+//     further. It hit a symlink, a mount point, or an uncached dentry.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * !rp.Done().
+// * InteropModeShared is in effect (assumes no negative dentries).
+func (fs *filesystem) revalidateStep(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, state *revalidateState) (*dentry, error) {
+	switch name := rp.Component(); name {
+	case ".":
+		// Do nothing.
+
+	case "..":
+		// Partial revalidation is required when ".." is hit because metadata locks
+		// can only be acquired from parent to child to avoid deadlocks.
+		if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil {
+			return nil, errRevalidationStepDone{}
+		} else if isRoot || d.parent == nil {
+			rp.Advance()
+			return d, errPartialRevalidation{}
+		}
+		// We must assume that d.parent is correct, because if d has been moved
+		// elsewhere in the remote filesystem so that its parent has changed,
+		// we have no way of determining its new parent's location in the
+		// filesystem.
+		//
+		// Call rp.CheckMount() before updating d.parent's metadata, since if
+		// we traverse to another mount then d.parent's metadata is irrelevant.
+		if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil {
+			return nil, errRevalidationStepDone{}
+		}
+		rp.Advance()
+		return d.parent, errPartialRevalidation{}
+
+	default:
+		d.dirMu.Lock()
+		child, ok := d.children[name]
+		d.dirMu.Unlock()
+		if !ok {
+			// child is not cached, no need to validate any further.
+			return nil, errRevalidationStepDone{}
+		}
+
+		state.add(name, child)
+
+		// Symlink must be resolved before continuing with revalidation.
+		if child.isSymlink() {
+			return nil, errRevalidationStepDone{}
+		}
+
+		d = child
+	}
+
+	rp.Advance()
+	return d, nil
+}
+
+// revalidateHelper calls the gofer to stat all dentries in `state`. It will
+// update or invalidate dentries in the cache based on the result.
+//
+// Preconditions:
+// * fs.renameMu must be locked.
+// * InteropModeShared is in effect.
+func (fs *filesystem) revalidateHelper(ctx context.Context, vfsObj *vfs.VirtualFilesystem, state *revalidateState, ds **[]*dentry) error {
+	if len(state.names) == 0 {
+		return nil
+	}
+	// Lock metadata on all dentries *before* getting attributes for them.
+	state.lockAllMetadata()
+	stats, err := state.start.file.multiGetAttr(ctx, state.names)
+	if err != nil {
+		return err
+	}
+
+	i := -1
+	for d := state.popFront(); d != nil; d = state.popFront() {
+		i++
+		found := i < len(stats)
+		if i == 0 && len(state.names[0]) == 0 {
+			if found && !d.isSynthetic() {
+				// First dentry is where the search is starting, just update attributes
+				// since it cannot be replaced.
+				d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr)
+			}
+			d.metadataMu.Unlock()
+			continue
+		}
+
+		// Note that synthetic dentries will always fail the comparison check
+		// below.
+		if !found || d.qidPath != stats[i].QID.Path {
+			d.metadataMu.Unlock()
+			if !found && d.isSynthetic() {
+				// We have a synthetic file, and no remote file has arisen to replace
+				// it.
+				return nil
+			}
+			// The file at this path has changed or no longer exists. Mark the
+			// dentry invalidated, and re-evaluate its caching status (i.e. if it
+			// has 0 references, drop it). The dentry will be reloaded next time it's
+			// accessed.
+			vfsObj.InvalidateDentry(ctx, &d.vfsd)
+
+			name := state.names[i]
+			d.parent.dirMu.Lock()
+
+			if d.isSynthetic() {
+				// Normally we don't mark invalidated dentries as deleted since
+				// they may still exist (but at a different path), and also for
+				// consistency with Linux. However, synthetic files are guaranteed
+				// to become unreachable if their dentries are invalidated, so
+				// treat their invalidation as deletion.
+				d.setDeleted()
+				d.decRefNoCaching()
+				*ds = appendDentry(*ds, d)
+
+				d.parent.syntheticChildren--
+				d.parent.dirents = nil
+			}
+
+			// Since the dirMu was released and reacquired, re-check that the
+			// parent's child with this name is still the same. Do not touch it if
+			// it has been replaced with a different one.
+			if child := d.parent.children[name]; child == d {
+				// Invalidate dentry so it gets reloaded next time it's accessed.
+				delete(d.parent.children, name)
+			}
+			d.parent.dirMu.Unlock()
+
+			return nil
+		}
+
+		// The file at this path hasn't changed. Just update cached metadata.
+		d.updateFromP9AttrsLocked(stats[i].Valid, &stats[i].Attr)
+		d.metadataMu.Unlock()
+	}
+
+	return nil
+}
+
+// revalidateStatePool caches revalidateState instances to save array
+// allocations for dentries and names.
+var revalidateStatePool = sync.Pool{
+	New: func() interface{} {
+		return &revalidateState{}
+	},
+}
+
+// revalidateState keeps state related to a revalidation request. It keeps track
+// of the {name, dentry} list being revalidated, as well as metadata locks on the
+// dentries. The list must be in ancestry order; in other words, entry `n` must
+// be a child of entry `n-1`.
+type revalidateState struct {
+	// start is the dentry where to start the attributes search.
+	start *dentry
+
+	// List of names of entries to refresh attributes. names must have the same
+	// length as dentries. They are kept in separate slices because names is
+	// used to call File.MultiGetAttr().
+	names []string
+
+	// dentries is the list of dentries that correspond to the names above.
+	// dentry.metadataMu is acquired as each dentry is added to this list.
+	dentries []*dentry
+
+	// locked indicates if the metadata lock has been acquired on dentries.
+	locked bool
+}
+
+func makeRevalidateState(start *dentry) *revalidateState {
+	r := revalidateStatePool.Get().(*revalidateState)
+	r.start = start
+	return r
+}
+
+// release must be called after the caller is done with this object. It releases
+// all metadata locks and resources.
+func (r *revalidateState) release() {
+	r.reset()
+	revalidateStatePool.Put(r)
+}
+
+// Preconditions:
+// * d is a descendant of all dentries in r.dentries.
+func (r *revalidateState) add(name string, d *dentry) {
+	r.names = append(r.names, name)
+	r.dentries = append(r.dentries, d)
+}
+
+func (r *revalidateState) lockAllMetadata() {
+	for _, d := range r.dentries {
+		d.metadataMu.Lock()
+	}
+	r.locked = true
+}
+
+func (r *revalidateState) popFront() *dentry {
+	if len(r.dentries) == 0 {
+		return nil
+	}
+	d := r.dentries[0]
+	r.dentries = r.dentries[1:]
+	return d
+}
+
+// reset releases all metadata locks and resets all fields to allow this
+// instance to be reused.
+func (r *revalidateState) reset() {
+	if r.locked {
+		// Unlock any remaining dentries.
+		for _, d := range r.dentries {
+			d.metadataMu.Unlock()
+		}
+		r.locked = false
+	}
+	r.start = nil
+	r.names = r.names[:0]
+	r.dentries = r.dentries[:0]
+}
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 3b90375b6..a81f550b1 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -460,6 +460,9 @@ func (i *inode) DecRef(ctx context.Context) {
 		if err := unix.Close(i.hostFD); err != nil {
 			log.Warningf("failed to close host fd %d: %v", i.hostFD, err)
 		}
+		// We can't rely on fdnotifier when closing the fd, because the event may race
+		// with fdnotifier.RemoveFD. Instead, notify the queue explicitly.
+		i.queue.Notify(waiter.EventHUp | waiter.ReadableEvents | waiter.WritableEvents)
 	})
 }
diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go
index 31301c715..c502d8e99 100644
--- a/pkg/sentry/fsimpl/host/save_restore.go
+++ b/pkg/sentry/fsimpl/host/save_restore.go
@@ -68,3 +68,10 @@ func (i *inode) afterLoad() {
 		}
 	}
 }
+
+// afterLoad is invoked by stateify.
+func (c *ConnectedEndpoint) afterLoad() {
+	if err := c.initFromOptions(); err != nil {
+		panic(fmt.Sprintf("initFromOptions failed: %v", err))
+	}
+}
diff --git a/pkg/sentry/fsimpl/host/socket.go b/pkg/sentry/fsimpl/host/socket.go
index 60e237ac7..ca85f5601 100644
--- a/pkg/sentry/fsimpl/host/socket.go
+++ b/pkg/sentry/fsimpl/host/socket.go
@@ -39,7 +39,7 @@ import (
 func newEndpoint(ctx context.Context, hostFD int, queue *waiter.Queue) (transport.Endpoint, error) {
 	// Set up an external transport.Endpoint using the host fd.
 	addr := fmt.Sprintf("hostfd:[%d]", hostFD)
-	e, err := NewConnectedEndpoint(ctx, hostFD, addr, true /* saveable */)
+	e, err := NewConnectedEndpoint(hostFD, addr)
 	if err != nil {
 		return nil, err.ToError()
 	}
@@ -86,7 +86,10 @@ type ConnectedEndpoint struct {
 // for restoring them.
 func (c *ConnectedEndpoint) init() *syserr.Error {
 	c.InitRefs()
+	return c.initFromOptions()
+}
 
+func (c *ConnectedEndpoint) initFromOptions() *syserr.Error {
 	family, err := unix.GetsockoptInt(c.fd, unix.SOL_SOCKET, unix.SO_DOMAIN)
 	if err != nil {
 		return syserr.FromError(err)
@@ -123,7 +126,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error {
 // The caller is responsible for calling Init(). Additionally, Release needs to
 // be called twice because ConnectedEndpoint is both a transport.Receiver and
 // transport.ConnectedEndpoint.
-func NewConnectedEndpoint(ctx context.Context, hostFD int, addr string, saveable bool) (*ConnectedEndpoint, *syserr.Error) {
+func NewConnectedEndpoint(hostFD int, addr string) (*ConnectedEndpoint, *syserr.Error) {
 	e := ConnectedEndpoint{
 		fd:   hostFD,
 		addr: addr,
@@ -330,8 +333,16 @@ func (c *ConnectedEndpoint) CloseUnread() {}
 
 // SetSendBufferSize implements transport.ConnectedEndpoint.SetSendBufferSize.
 func (c *ConnectedEndpoint) SetSendBufferSize(v int64) (newSz int64) {
-	// gVisor does not permit setting of SO_SNDBUF for host backed unix domain
-	// sockets.
+	// gVisor does not permit setting of SO_SNDBUF for host backed unix
+	// domain sockets.
+	return atomic.LoadInt64(&c.sndbuf)
+}
+
+// SetReceiveBufferSize implements transport.ConnectedEndpoint.SetReceiveBufferSize.
+func (c *ConnectedEndpoint) SetReceiveBufferSize(v int64) (newSz int64) {
+	// gVisor does not permit setting of SO_RCVBUF for host backed unix
+	// domain sockets. Receive buffer does not have any effect for unix
+	// sockets and we claim to be the same as send buffer.
 	return atomic.LoadInt64(&c.sndbuf)
 }
diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
index 65054b0ea..84b1c3745 100644
--- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
+++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go
@@ -25,8 +25,10 @@ import (
 	"gvisor.dev/gvisor/pkg/usermem"
 )
 
-// DynamicBytesFile implements kernfs.Inode and represents a read-only
-// file whose contents are backed by a vfs.DynamicBytesSource.
+// DynamicBytesFile implements kernfs.Inode and represents a read-only file
+// whose contents are backed by a vfs.DynamicBytesSource. If data additionally
+// implements vfs.WritableDynamicBytesSource, the file also supports dispatching
+// writes to the implementer, but note that this will not update the source data.
 //
 // Must be instantiated with NewDynamicBytesFile or initialized with Init
 // before first use.
@@ -40,7 +42,9 @@ type DynamicBytesFile struct {
 	InodeNotSymlink
 
 	locks vfs.FileLocks
-	data  vfs.DynamicBytesSource
+	// data can additionally implement vfs.WritableDynamicBytesSource to support
+	// writes.
+ data vfs.DynamicBytesSource } var _ Inode = (*DynamicBytesFile)(nil) diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index badca4d9f..f50b0fb08 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -612,16 +612,24 @@ afterTrailingSymlink: // ReadlinkAt implements vfs.FilesystemImpl.ReadlinkAt. func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (string, error) { - fs.mu.RLock() defer fs.processDeferredDecRefs(ctx) - defer fs.mu.RUnlock() + + fs.mu.RLock() d, err := fs.walkExistingLocked(ctx, rp) if err != nil { + fs.mu.RUnlock() return "", err } if !d.isSymlink() { + fs.mu.RUnlock() return "", syserror.EINVAL } + + // Inode.Readlink() cannot be called holding fs locks. + d.IncRef() + defer d.DecRef(ctx) + fs.mu.RUnlock() + return d.inode.Readlink(ctx, rp.Mount()) } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 565d723f0..6f699c9cd 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -61,6 +61,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -508,6 +509,15 @@ func (d *Dentry) Inode() Inode { return d.inode } +// FSLocalPath returns an absolute path to d, relative to the root of its +// filesystem. +func (d *Dentry) FSLocalPath() string { + var b fspath.Builder + _ = genericPrependPath(vfs.VirtualDentry{}, nil, d, &b) + b.PrependByte('/') + return b.String() +} + // The Inode interface maps filesystem-level operations that operate on paths to // equivalent operations on specific filesystem nodes. // @@ -524,6 +534,9 @@ func (d *Dentry) Inode() Inode { // - Checking that dentries passed to methods are of the appropriate file type. // - Checking permissions. // +// Inode functions may be called holding filesystem wide locks and are not +// allowed to call vfs functions that may reenter, unless otherwise noted. +// // Specific responsibilities of implementations are documented below. type Inode interface { // Methods related to reference counting. A generic implementation is @@ -670,6 +683,9 @@ type inodeDirectory interface { type inodeSymlink interface { // Readlink returns the target of a symbolic link. If an inode is not a // symlink, the implementation should return EINVAL. + // + // Readlink is called with no kernfs locks held, so it may reenter if needed + // to resolve symlink targets. 
Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) // Getlink returns the target of a symbolic link, as used by path diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index 254a8b062..ce8f55b1f 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -86,13 +86,13 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF procfs.MaxCachedDentries = maxCachedDentries procfs.VFSFilesystem().Init(vfsObj, &ft, procfs) - var cgroups map[string]string + var fakeCgroupControllers map[string]string if opts.InternalData != nil { data := opts.InternalData.(*InternalData) - cgroups = data.Cgroups + fakeCgroupControllers = data.Cgroups } - inode := procfs.newTasksInode(ctx, k, pidns, cgroups) + inode := procfs.newTasksInode(ctx, k, pidns, fakeCgroupControllers) var dentry kernfs.Dentry dentry.InitRoot(&procfs.Filesystem, inode) return procfs.VFSFilesystem(), dentry.VFSDentry(), nil diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index fea138f93..d05cc1508 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -47,7 +47,7 @@ type taskInode struct { var _ kernfs.Inode = (*taskInode)(nil) -func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, cgroupControllers map[string]string) (kernfs.Inode, error) { +func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, isThreadGroup bool, fakeCgroupControllers map[string]string) (kernfs.Inode, error) { if task.ExitState() == kernel.TaskExitDead { return nil, syserror.ESRCH } @@ -82,10 +82,12 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns "uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}), } if isThreadGroup { - contents["task"] = fs.newSubtasks(ctx, task, pidns, cgroupControllers) + contents["task"] = fs.newSubtasks(ctx, task, pidns, fakeCgroupControllers) } - if len(cgroupControllers) > 0 { - contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newCgroupData(cgroupControllers)) + if len(fakeCgroupControllers) > 0 { + contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, newFakeCgroupData(fakeCgroupControllers)) + } else { + contents["cgroup"] = fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskCgroupData{task: task}) } taskInode := &taskInode{task: task} @@ -226,11 +228,14 @@ func newIO(t *kernel.Task, isThreadGroup bool) *ioData { return &ioData{ioUsage: t} } -// newCgroupData creates inode that shows cgroup information. -// From man 7 cgroups: "For each cgroup hierarchy of which the process is a -// member, there is one entry containing three colon-separated fields: -// hierarchy-ID:controller-list:cgroup-path" -func newCgroupData(controllers map[string]string) dynamicInode { +// newFakeCgroupData creates an inode that shows fake cgroup +// information passed in as mount options. From man 7 cgroups: "For +// each cgroup hierarchy of which the process is a member, there is +// one entry containing three colon-separated fields: +// hierarchy-ID:controller-list:cgroup-path" +// +// TODO(b/182488796): Remove once all users adopt cgroupfs. 
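Both the fake generator above and the real cgroupfs-backed one emit the /proc/[pid]/cgroup record format quoted from man 7 cgroups: one hierarchy-ID:controller-list:cgroup-path line per hierarchy. Shown standalone:

```go
package main

import (
	"bytes"
	"fmt"
)

// writeCgroupRecord emits one /proc/[pid]/cgroup line, e.g. "3:cpu,cpuacct:/".
func writeCgroupRecord(buf *bytes.Buffer, hierarchyID int, controllers, path string) {
	fmt.Fprintf(buf, "%d:%s:%s\n", hierarchyID, controllers, path)
}

func main() {
	var buf bytes.Buffer
	writeCgroupRecord(&buf, 3, "cpu,cpuacct", "/")
	fmt.Print(buf.String()) // 3:cpu,cpuacct:/
}
```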
+func newFakeCgroupData(controllers map[string]string) dynamicInode {
 	var buf bytes.Buffer
 	// The hierarchy ids must be positive integers (for cgroup v1), but the
diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go
index 02bf74dbc..4718fac7a 100644
--- a/pkg/sentry/fsimpl/proc/task_fds.go
+++ b/pkg/sentry/fsimpl/proc/task_fds.go
@@ -221,6 +221,8 @@ func (s *fdSymlink) Readlink(ctx context.Context, _ *vfs.Mount) (string, error)
 	defer file.DecRef(ctx)
 	root := vfs.RootFromContext(ctx)
 	defer root.DecRef(ctx)
+
+	// Note: it's safe to reenter kernfs from Readlink if needed to resolve path.
 	return s.task.Kernel().VFS().PathnameWithDeleted(ctx, root, file.VirtualDentry())
 }
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 85909d551..b294dfd6a 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -1100,3 +1100,32 @@ func (fd *namespaceFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) err
 func (fd *namespaceFD) Release(ctx context.Context) {
 	fd.inode.DecRef(ctx)
 }
+
+// taskCgroupData generates data for /proc/[pid]/cgroup.
+//
+// +stateify savable
+type taskCgroupData struct {
+	dynamicBytesFileSetAttr
+	task *kernel.Task
+}
+
+var _ dynamicInode = (*taskCgroupData)(nil)
+
+// Generate implements vfs.DynamicBytesSource.Generate.
+func (d *taskCgroupData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+	// When a task is exiting on Linux, its cgroup set is cleared and
+	// reset to the initial cgroup set, which is essentially the set of root
+	// cgroups. Because of this, the /proc/<pid>/cgroup file is always readable
+	// on Linux throughout a task's lifetime.
+	//
+	// The sentry removes tasks from cgroups during the exit process, but
+	// doesn't move them into an initial cgroup set, so partway through task
+	// exit this file would show a task is in no cgroups, which is incorrect.
+	// Instead, once a task has left its cgroups, we return an error.
+	if d.task.ExitState() >= kernel.TaskExitInitiated {
+		return syserror.ESRCH
+	}
+
+	d.task.GenerateProcTaskCgroup(buf)
+	return nil
+}
diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go
index fdc580610..cf905fae4 100644
--- a/pkg/sentry/fsimpl/proc/tasks.go
+++ b/pkg/sentry/fsimpl/proc/tasks.go
@@ -54,17 +54,18 @@ type tasksInode struct {
 	// '/proc/self' and '/proc/thread-self' have custom directory offsets in
 	// Linux. So handle them outside of OrderedChildren.
 
-	// cgroupControllers is a map of controller name to directory in the
+	// fakeCgroupControllers is a map of controller name to directory in the
 	// cgroup hierarchy. These controllers are immutable and will be listed
 	// in /proc/pid/cgroup if not nil.
- cgroupControllers map[string]string + fakeCgroupControllers map[string]string } var _ kernfs.Inode = (*tasksInode)(nil) -func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { +func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, fakeCgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ + "cmdline": fs.newInode(ctx, root, 0444, &cmdLineData{}), "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}), "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}), @@ -76,11 +77,16 @@ func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}), "version": fs.newInode(ctx, root, 0444, &versionData{}), } + // If fakeCgroupControllers are provided, don't create a cgroupfs backed + // /proc/cgroup as it will not match the fake controllers. + if len(fakeCgroupControllers) == 0 { + contents["cgroups"] = fs.newInode(ctx, root, 0444, &cgroupsData{}) + } inode := &tasksInode{ - pidns: pidns, - fs: fs, - cgroupControllers: cgroupControllers, + pidns: pidns, + fs: fs, + fakeCgroupControllers: fakeCgroupControllers, } inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.InitRefs() @@ -118,7 +124,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err return nil, syserror.ENOENT } - return i.fs.newTaskInode(ctx, task, i.pidns, true, i.cgroupControllers) + return i.fs.newTaskInode(ctx, task, i.pidns, true, i.fakeCgroupControllers) } // IterDirents implements kernfs.inodeDirectory.IterDirents. diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index f0029cda6..045ed7a2d 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -336,15 +336,6 @@ var _ dynamicInode = (*versionData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { - k := kernel.KernelFromContext(ctx) - init := k.GlobalInit() - if init == nil { - // Attempted to read before the init Task is created. This can - // only occur during startup, which should never need to read - // this file. - panic("Attempted to read version before initial Task is available") - } - // /proc/version takes the form: // // "SYSNAME version RELEASE (COMPILE_USER@COMPILE_HOST) @@ -364,7 +355,7 @@ func (*versionData) Generate(ctx context.Context, buf *bytes.Buffer) error { // FIXME(mpratt): Using Version from the init task SyscallTable // disregards the different version a task may have (e.g., in a uts // namespace). - ver := init.Leader().SyscallTable().Version + ver := kernelVersion(ctx) fmt.Fprintf(buf, "%s version %s %s\n", ver.Sysname, ver.Release, ver.Version) return nil } @@ -384,3 +375,47 @@ func (d *filesystemsData) Generate(ctx context.Context, buf *bytes.Buffer) error k.VFS().GenerateProcFilesystems(buf) return nil } + +// cgroupsData backs /proc/cgroups. +// +// +stateify savable +type cgroupsData struct { + dynamicBytesFileSetAttr +} + +var _ dynamicInode = (*cgroupsData)(nil) + +// Generate implements vfs.DynamicBytesSource.Generate. 
+func (*cgroupsData) Generate(ctx context.Context, buf *bytes.Buffer) error { + r := kernel.KernelFromContext(ctx).CgroupRegistry() + r.GenerateProcCgroups(buf) + return nil +} + +// cmdLineData backs /proc/cmdline. +// +// +stateify savable +type cmdLineData struct { + dynamicBytesFileSetAttr +} + +var _ dynamicInode = (*cmdLineData)(nil) + +// Generate implements vfs.DynamicByteSource.Generate. +func (*cmdLineData) Generate(ctx context.Context, buf *bytes.Buffer) error { + fmt.Fprintf(buf, "BOOT_IMAGE=/vmlinuz-%s-gvisor quiet\n", kernelVersion(ctx).Release) + return nil +} + +// kernelVersion returns the kernel version. +func kernelVersion(ctx context.Context) kernel.Version { + k := kernel.KernelFromContext(ctx) + init := k.GlobalInit() + if init == nil { + // Attempted to read before the init Task is created. This can + // only occur during startup, which should never need to read + // this file. + panic("Attempted to read version before initial Task is available") + } + return init.Leader().SyscallTable().Version +} diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index d6f076cd6..e534fbca8 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -47,6 +47,7 @@ var ( var ( tasksStaticFiles = map[string]testutil.DirentType{ + "cmdline": linux.DT_REG, "cpuinfo": linux.DT_REG, "filesystems": linux.DT_REG, "loadavg": linux.DT_REG, diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 1d9280dae..14eb10dcd 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -122,11 +122,11 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs } func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode { - // If kcov is available, set up /sys/kernel/debug/kcov. Technically, debugfs - // should be mounted at debug/, but for our purposes, it is sufficient to - // keep it in sys. + // Set up /sys/kernel/debug/kcov. Technically, debugfs should be + // mounted at debug/, but for our purposes, it is sufficient to keep it + // in sys. 
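The cmdLineData source added above fabricates a plausible boot line from the advertised kernel release rather than exposing any host state. Its output shape, in isolation (the "4.4.0" release is only an example value):

```go
package main

import "fmt"

// fakeCmdline mirrors what cmdLineData.Generate writes: a minimal
// BOOT_IMAGE entry derived from the release string.
func fakeCmdline(release string) string {
	return fmt.Sprintf("BOOT_IMAGE=/vmlinuz-%s-gvisor quiet\n", release)
}

func main() {
	fmt.Print(fakeCmdline("4.4.0")) // BOOT_IMAGE=/vmlinuz-4.4.0-gvisor quiet
}
```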
var children map[string]kernfs.Inode - if coverage.KcovAvailable() { + if coverage.KcovSupported() { log.Debugf("Set up /sys/kernel/debug/kcov") children = map[string]kernfs.Inode{ "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{ diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index b3f9d1010..c766164c7 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -17,6 +17,7 @@ go_library( "//pkg/fspath", "//pkg/hostarch", "//pkg/memutil", + "//pkg/metric", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/tmpfs", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 807e4f44a..33e52ce64 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/memutil" + "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -62,6 +63,8 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("creating platform: %v", err) } + metric.CreateSentryMetrics() + kernel.VFS2Enabled = true k := &kernel.Kernel{ Platform: plat, diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index cd849e87e..c45bddff6 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -488,6 +488,7 @@ func (fd *regularFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { file := fd.inode().impl.(*regularFile) + opts.SentryOwnedContent = true return vfs.GenericConfigureMMap(&fd.vfsfd, file, opts) } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index 2da251233..d473a922d 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -18,10 +18,12 @@ go_library( "//pkg/marshal/primitive", "//pkg/merkletree", "//pkg/refsvfs2", + "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/sync", diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 6cb1a23e0..3582d14c9 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -168,10 +168,6 @@ afterSymlink: // Preconditions: // * fs.renameMu must be locked. // * d.dirMu must be locked. -// -// TODO(b/166474175): Investigate all possible errors returned in this -// function, and make sure we differentiate all errors that indicate unexpected -// modifications to the file system from the ones that are not harmful. func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() @@ -200,7 +196,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. 
if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { return nil, err @@ -209,7 +205,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // unexpected modifications to the file system. offset, err := strconv.Atoi(off) if err != nil { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) } // Open parent Merkle tree file to read and verify child's hash. @@ -223,12 +219,14 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { return nil, err } + defer parentMerkleFD.DecRef(ctx) + // dataSize is the size of raw data for the Merkle tree. For a file, // dataSize is the size of the whole file. For a directory, dataSize is // the size of all its children's hashes. @@ -241,7 +239,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return nil, err @@ -251,7 +249,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi // unexpected modifications to the file system. 
parentSize, err := strconv.Atoi(dataSize) if err != nil { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := FileReadWriteSeeker{ @@ -264,7 +262,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi Start: parent.lowerVD, }, &vfs.StatOptions{}) if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -276,16 +274,15 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi var buf bytes.Buffer parent.hashMu.RLock() _, err = merkletree.Verify(&merkletree.VerifyParams{ - Out: &buf, - File: &fdReader, - Tree: &fdReader, - Size: int64(parentSize), - Name: parent.name, - Mode: uint32(parentStat.Mode), - UID: parentStat.UID, - GID: parentStat.GID, - Children: parent.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + Children: parent.childrenNames, HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: int64(offset), ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), @@ -294,7 +291,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi }) parent.hashMu.RUnlock() if err != nil && err != io.EOF { - return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child hash when it's verified the first time. 
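(Aside: every alert path above funnels into merkletree.Verify with one parameter shape. Below is a condensed, hypothetical call for reference; the reader values and sizes are placeholders, error handling is elided, and only fields visible in these hunks are used.)

var out bytes.Buffer
if _, err := merkletree.Verify(&merkletree.VerifyParams{
	Out:            &out,        // verified bytes are written here
	File:           &dataReader, // io.ReaderAt over the raw data
	Tree:           &treeReader, // io.ReaderAt over the Merkle tree file
	Size:           int64(size), // parsed from the merkleSizeXattr xattr
	Name:           name,
	Mode:           mode,
	UID:            uid,
	GID:            gid,
	Children:       childrenNames,
	HashAlgorithms: alg, // fs.alg.toLinuxHashAlg() in this file
	ReadOffset:     0,
	ReadSize:       int64(readSize),
	Expected:       expectedHash, // the hash cached on the dentry
	DataAndTreeInSameFile: false,
}); err != nil && err != io.EOF {
	// Callers above surface this through fs.alertIntegrityViolation.
}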
@@ -331,19 +328,21 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) } if err != nil { return err } + defer fd.DecRef(ctx) + merkleSize, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ Name: merkleSizeXattr, Size: sizeOfStringInt32, }) if err == syserror.ENODATA { - return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return err @@ -351,7 +350,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry size, err := strconv.Atoi(merkleSize) if err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } if d.isDir() && len(d.childrenNames) == 0 { @@ -361,14 +360,14 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry }) if err == syserror.ENODATA { - return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenOffsetXattr, childPath, err)) } if err != nil { return err } childrenOffset, err := strconv.Atoi(childrenOffString) if err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) } childrenSizeString, err := fd.GetXattr(ctx, &vfs.GetXattrOptions{ @@ -377,23 +376,23 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry }) if err == syserror.ENODATA { - return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", childrenSizeXattr, childPath, err)) } if err != nil { return err } childrenSize, err := strconv.Atoi(childrenSizeString) if err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) } childrenNames := make([]byte, childrenSize) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(childrenOffset), vfs.ReadOptions{}); err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to read children map for %s: %v", childPath, err)) } if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil { - return alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames of %s: %v", childPath, err)) } } @@ -405,15 +404,14 @@ 
func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry var buf bytes.Buffer d.hashMu.RLock() params := &merkletree.VerifyParams{ - Out: &buf, - Tree: &fdReader, - Size: int64(size), - Name: d.name, - Mode: uint32(stat.Mode), - UID: stat.UID, - GID: stat.GID, - Children: d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + Children: d.childrenNames, HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: 0, // Set read size to 0 so only the metadata is verified. @@ -438,7 +436,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry } if _, err := merkletree.Verify(params); err != nil && err != io.EOF { - return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) + return fs.alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) } d.mode = uint32(stat.Mode) d.uid = stat.UID @@ -471,7 +469,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // The file was previously accessed. If the // file does not exist now, it indicates an // unexpected modification to the file system. - return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) } if err != nil { return nil, err @@ -483,7 +481,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // does not exist now, it indicates an unexpected // modification to the file system. if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) } if err != nil { return nil, err @@ -553,8 +551,8 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, } childVD, err := parent.getLowerAt(ctx, vfsObj, name) - if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name)) + if parent.verityEnabled() && err == syserror.ENOENT { + return nil, fs.alertIntegrityViolation(fmt.Sprintf("file %s expected but not found", parentPath+"/"+name)) } if err != nil { return nil, err @@ -565,30 +563,31 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, defer childVD.DecRef(ctx) childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) - if err == syserror.ENOENT { - if !fs.allowRuntimeEnable { - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name)) - } - childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(merklePrefix + name), - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT, - Mode: 0644, - }) - if err != nil { - return nil, err - } - childMerkleFD.DecRef(ctx) - childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) - if err != nil { + if err != nil { + if err == syserror.ENOENT { + if parent.verityEnabled() { + return nil, fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath+"/"+name)) + } + childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, 
&vfs.PathOperation{ + Root: parent.lowerVD, + Start: parent.lowerVD, + Path: fspath.Parse(merklePrefix + name), + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR | linux.O_CREAT, + Mode: 0644, + }) + if err != nil { + return nil, err + } + childMerkleFD.DecRef(ctx) + childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + if err != nil { + return nil, err + } + } else { + return nil, err + } + } - if err != nil && err != syserror.ENOENT { - return nil, err - } // Clear the Merkle tree files if they are to be generated at runtime. // TODO(b/182315468): Optimize the Merkle tree generate process to @@ -632,8 +631,6 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, childVD.IncRef() childMerkleVD.IncRef() - parent.IncRef() - child.parent = parent child.name = name child.mode = uint32(stat.Mode) @@ -657,6 +654,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, } } + parent.IncRef() + child.parent = parent + return child, nil } @@ -855,7 +855,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // missing, it indicates an unexpected modification to the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) } return nil, err } @@ -878,7 +878,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -903,7 +903,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf }) if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -921,7 +921,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf if err != nil { if err == syserror.ENOENT { parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) - return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + return nil, d.fs.alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) } return nil, err } @@ -985,8 +985,6 @@ func (fs *filesystem) SetStatAt(ctx context.Context, rp *vfs.ResolvingPath, opts } // StatAt implements vfs.FilesystemImpl.StatAt. -// TODO(b/170157489): Investigate whether stats other than Mode/UID/GID should -// be verified.
func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vfs.StatOptions) (linux.Statx, error) { var ds *[]*dentry fs.renameMu.RLock() diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index a7d92a878..31d34ef60 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -34,6 +34,8 @@ package verity import ( + "bytes" + "encoding/hex" "encoding/json" "fmt" "math" @@ -44,19 +46,20 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/merkletree" "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" - - "gvisor.dev/gvisor/pkg/hostarch" ) const ( @@ -95,14 +98,18 @@ const ( ) var ( - // action specifies the action towards detected violation. - action ViolationAction - // verityMu synchronizes concurrent operations that enable verity and perform // verification checks. verityMu sync.RWMutex ) +// Mount option names for verityfs. +const ( + moptLowerPath = "lower_path" + moptRootHash = "root_hash" + moptRootName = "root_name" +) + // HashAlgorithm is a type specifying the algorithm used to hash the file // content. type HashAlgorithm int @@ -169,6 +176,12 @@ type filesystem struct { // system. alg HashAlgorithm + // action specifies the action towards detected violation. + action ViolationAction + + // opts is the string mount options passed in opts.Data. + opts string + // renameMu synchronizes renaming with non-renaming operations in order // to ensure consistent lock ordering between dentry.dirMu in different // dentries. @@ -191,9 +204,6 @@ type filesystem struct { // // +stateify savable type InternalFilesystemOptions struct { - // RootMerkleFileName is the name of the verity root Merkle tree file. - RootMerkleFileName string - // LowerName is the name of the filesystem wrapped by verity fs. LowerName string @@ -201,9 +211,6 @@ type InternalFilesystemOptions struct { // system. Alg HashAlgorithm - // RootHash is the root hash of the overall verity file system. - RootHash []byte - // AllowRuntimeEnable specifies whether the verity file system allows // enabling verification for files (i.e. building Merkle trees) during // runtime. @@ -228,8 +235,8 @@ func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means // unexpected modification to the file system is detected. In ErrorOnViolation // mode, it returns EIO, otherwise it panics. -func alertIntegrityViolation(msg string) error { - if action == ErrorOnViolation { +func (fs *filesystem) alertIntegrityViolation(msg string) error { + if fs.action == ErrorOnViolation { return syserror.EIO } panic(msg) @@ -237,28 +244,99 @@ func alertIntegrityViolation(msg string) error { // GetFilesystem implements vfs.FilesystemType.GetFilesystem.
func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { + mopts := vfs.GenericParseMountOptions(opts.Data) + var rootHash []byte + if encodedRootHash, ok := mopts[moptRootHash]; ok { + delete(mopts, moptRootHash) + hash, err := hex.DecodeString(encodedRootHash) + if err != nil { + ctx.Warningf("verity.FilesystemType.GetFilesystem: Failed to decode root hash: %v", err) + return nil, nil, syserror.EINVAL + } + rootHash = hash + } + var lowerPathname string + if path, ok := mopts[moptLowerPath]; ok { + delete(mopts, moptLowerPath) + lowerPathname = path + } + rootName := "root" + if root, ok := mopts[moptRootName]; ok { + delete(mopts, moptRootName) + rootName = root + } + + // Check for unparsed options. + if len(mopts) != 0 { + ctx.Warningf("verity.FilesystemType.GetFilesystem: unknown options: %v", mopts) + return nil, nil, syserror.EINVAL + } + + // Handle internal options. iopts, ok := opts.InternalData.(InternalFilesystemOptions) - if !ok { + if len(lowerPathname) == 0 && !ok { ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") return nil, nil, syserror.EINVAL } - action = iopts.Action - - // Mount the lower file system. The lower file system is wrapped inside - // verity, and should not be exposed or connected. - mopts := &vfs.MountOptions{ - GetFilesystemOptions: iopts.LowerGetFSOptions, - InternalMount: true, + if len(lowerPathname) != 0 { + if ok { + ctx.Warningf("verity.FilesystemType.GetFilesystem: unexpected verity configs with specified lower path") + return nil, nil, syserror.EINVAL + } + iopts = InternalFilesystemOptions{ + AllowRuntimeEnable: len(rootHash) == 0, + Action: ErrorOnViolation, + } } - mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mopts) - if err != nil { - return nil, nil, err + + var lowerMount *vfs.Mount + var mountedLowerVD vfs.VirtualDentry + // Use an existing mount if lowerPath is provided. + if len(lowerPathname) != 0 { + vfsroot := vfs.RootFromContext(ctx) + if vfsroot.Ok() { + defer vfsroot.DecRef(ctx) + } + lowerPath := fspath.Parse(lowerPathname) + if !lowerPath.Absolute { + ctx.Infof("verity.FilesystemType.GetFilesystem: lower_path %q must be absolute", lowerPathname) + return nil, nil, syserror.EINVAL + } + var err error + mountedLowerVD, err = vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ + Root: vfsroot, + Start: vfsroot, + Path: lowerPath, + FollowFinalSymlink: true, + }, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + ctx.Infof("verity.FilesystemType.GetFilesystem: failed to resolve lower_path %q: %v", lowerPathname, err) + return nil, nil, err + } + lowerMount = mountedLowerVD.Mount() + defer mountedLowerVD.DecRef(ctx) + } else { + // Mount the lower file system. The lower file system is wrapped inside + // verity, and should not be exposed or connected. 
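(Aside: end to end, the new options travel in opts.Data. Below is a hypothetical mount in the style of the updated verity_test.go further down; the digest is fabricated. Per the parsing above, omitting root_hash leaves runtime enabling on, while a malformed hash or a relative lower_path fails with EINVAL.)

mntns, err := vfsObj.NewMountNamespace(ctx, creds, "", "verity", &vfs.MountOptions{
	GetFilesystemOptions: vfs.GetFilesystemOptions{
		// lower_path resolves against the caller's VFS root, so no
		// InternalData is required on this path.
		Data: "lower_path=/verity/lower,root_name=root,root_hash=b6a837b2e0cdf1a2",
	},
})
if err != nil {
	// Handle EINVAL etc.; on success mntns owns the verity mount.
}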
+ mountOpts := &vfs.MountOptions{ + GetFilesystemOptions: iopts.LowerGetFSOptions, + InternalMount: true, + } + mnt, err := vfsObj.MountDisconnected(ctx, creds, "", iopts.LowerName, mountOpts) + if err != nil { + return nil, nil, err + } + lowerMount = mnt } fs := &filesystem{ creds: creds.Fork(), alg: iopts.Alg, - lowerMount: mnt, + lowerMount: lowerMount, + action: iopts.Action, + opts: opts.Data, allowRuntimeEnable: iopts.AllowRuntimeEnable, } fs.vfsfs.Init(vfsObj, &fstype, fs) @@ -266,11 +344,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // Construct the root dentry. d := fs.newDentry() d.refs = 1 - lowerVD := vfs.MakeVirtualDentry(mnt, mnt.Root()) + lowerVD := vfs.MakeVirtualDentry(lowerMount, lowerMount.Root()) lowerVD.IncRef() d.lowerVD = lowerVD - rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName + rootMerkleName := merkleRootPrefix + rootName lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, @@ -311,7 +389,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // the root Merkle file, or it's never generated. fs.vfsfs.DecRef(ctx) d.DecRef(ctx) - return nil, nil, alertIntegrityViolation("Failed to find root Merkle file") + return nil, nil, fs.alertIntegrityViolation("Failed to find root Merkle file") } // Clear the Merkle tree files if they are to be generated at runtime. @@ -350,9 +428,15 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt d.mode = uint32(stat.Mode) d.uid = stat.UID d.gid = stat.GID - d.hash = make([]byte, len(iopts.RootHash)) d.childrenNames = make(map[string]struct{}) + d.hashMu.Lock() + d.hash = make([]byte, len(rootHash)) + copy(d.hash, rootHash) + d.hashMu.Unlock() + + fs.rootDentry = d + if !d.isDir() { ctx.Warningf("verity root must be a directory") return nil, nil, syserror.EINVAL @@ -368,7 +452,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Size: sizeOfStringInt32, }) if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenOffsetXattr, err)) } if err != nil { return nil, nil, err @@ -376,7 +460,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt off, err := strconv.Atoi(offString) if err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenOffsetXattr, err)) } sizeString, err := vfsObj.GetXattrAt(ctx, creds, &vfs.PathOperation{ @@ -387,14 +471,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Root: lowerMerkleVD, Start: lowerMerkleVD, }, &vfs.GetXattrOptions{ Name: childrenSizeXattr, Size: sizeOfStringInt32, }) if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", childrenSizeXattr, err)) } if err != nil { return nil, nil, err } size, err := strconv.Atoi(sizeString) if err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", childrenSizeXattr, err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v",
childrenSizeXattr, err)) } lowerMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ @@ -404,19 +488,21 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to open root Merkle file: %v", err)) } if err != nil { return nil, nil, err } + defer lowerMerkleFD.DecRef(ctx) + childrenNames := make([]byte, size) if _, err := lowerMerkleFD.PRead(ctx, usermem.BytesIOSequence(childrenNames), int64(off), vfs.ReadOptions{}); err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to read root children map: %v", err)) } if err := json.Unmarshal(childrenNames, &d.childrenNames); err != nil { - return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err)) + return nil, nil, fs.alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err)) } if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil { @@ -424,13 +510,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } } - d.hashMu.Lock() - copy(d.hash, iopts.RootHash) - d.hashMu.Unlock() d.vfsd.Init(d) - fs.rootDentry = d - return &fs.vfsfs, &d.vfsd, nil } @@ -441,7 +522,7 @@ func (fs *filesystem) Release(ctx context.Context) { // MountOptions implements vfs.FilesystemImpl.MountOptions. func (fs *filesystem) MountOptions() string { - return "" + return fs.opts } // dentry implements vfs.DentryImpl. @@ -722,6 +803,10 @@ type fileDescription struct { // underlying file system. lowerFD *vfs.FileDescription + // lowerMappable is the memmap.Mappable corresponding to this file in the + // underlying file system. + lowerMappable memmap.Mappable + // merkleReader is the read-only FileDescription corresponding to the // Merkle tree file in the underlying file system. merkleReader *vfs.FileDescription @@ -755,7 +840,6 @@ func (fd *fileDescription) Release(ctx context.Context) { // Stat implements vfs.FileDescriptionImpl.Stat. func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { - // TODO(b/162788573): Add integrity check for metadata. stat, err := fd.lowerFD.Stat(ctx, opts) if err != nil { return linux.Statx{}, err @@ -794,7 +878,7 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa // Verify that the child is expected. if dirent.Name != "." && dirent.Name != ".." { if _, ok := fd.d.childrenNames[dirent.Name]; !ok { - return alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name)) + return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children %s", dirent.Name)) } } } @@ -808,7 +892,7 @@ func (fd *fileDescription) IterDirents(ctx context.Context, cb vfs.IterDirentsCa // The result should contain all children plus "." and "..". 
if fd.d.verityEnabled() && len(ds) != len(fd.d.childrenNames)+2 { - return alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds))) + return fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Unexpected children number %d", len(ds))) } for fd.off < int64(len(ds)) { @@ -875,10 +959,9 @@ func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, ui } params := &merkletree.GenerateParams{ - TreeReader: &merkleReader, - TreeWriter: &merkleWriter, - Children: fd.d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. + TreeReader: &merkleReader, + TreeWriter: &merkleWriter, + Children: fd.d.childrenNames, HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), Name: fd.d.name, Mode: uint32(stat.Mode), @@ -980,7 +1063,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { // or directory other than the root, the parent Merkle tree file should // have also been initialized. if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { - return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") + return 0, fd.d.fs.alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } hash, dataSize, err := fd.generateMerkleLocked(ctx) @@ -1053,7 +1136,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest hosta if fd.d.fs.allowRuntimeEnable { return 0, syserror.ENODATA } - return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found") + return 0, fd.d.fs.alertIntegrityViolation("Ioctl measureVerity: no hash found") } // The first part of VerityDigest is the metadata. @@ -1107,8 +1190,6 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. case linux.FS_IOC_GETFLAGS: return fd.verityFlags(ctx, args[2].Pointer()) default: - // TODO(b/169682228): Investigate which ioctl commands should - // be allowed. return 0, syserror.ENOSYS } } @@ -1143,7 +1224,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. if err == syserror.ENODATA { - return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { return 0, err @@ -1153,7 +1234,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // unexpected modifications to the file system. size, err := strconv.Atoi(dataSize) if err != nil { - return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } dataReader := FileReadWriteSeeker{ @@ -1168,16 +1249,15 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of fd.d.hashMu.RLock() n, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: dst.Writer(ctx), - File: &dataReader, - Tree: &merkleReader, - Size: int64(size), - Name: fd.d.name, - Mode: fd.d.mode, - UID: fd.d.uid, - GID: fd.d.gid, - Children: fd.d.childrenNames, - //TODO(b/156980949): Support passing other hash algorithms. 
+ Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + Children: fd.d.childrenNames, HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), ReadOffset: offset, ReadSize: dst.NumBytes(), @@ -1186,7 +1266,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of }) fd.d.hashMu.RUnlock() if err != nil { - return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) + return 0, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } return n, err } @@ -1201,6 +1281,24 @@ func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, op return 0, syserror.EROFS } +// ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. +func (fd *fileDescription) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpts) error { + if err := fd.lowerFD.ConfigureMMap(ctx, opts); err != nil { + return err + } + fd.lowerMappable = opts.Mappable + if opts.MappingIdentity != nil { + opts.MappingIdentity.DecRef(ctx) + opts.MappingIdentity = nil + } + + // Check if mmap is allowed on the lower filesystem. + if !opts.SentryOwnedContent { + return syserror.ENODEV + } + return vfs.GenericConfigureMMap(&fd.vfsfd, fd, opts) +} + // LockBSD implements vfs.FileDescriptionImpl.LockBSD. func (fd *fileDescription) LockBSD(ctx context.Context, uid fslock.UniqueID, ownerPID int32, t fslock.LockType, block fslock.Blocker) error { return fd.lowerFD.LockBSD(ctx, ownerPID, t, block) @@ -1226,6 +1324,115 @@ func (fd *fileDescription) TestPOSIX(ctx context.Context, uid fslock.UniqueID, t return fd.lowerFD.TestPOSIX(ctx, uid, t, r) } +// Translate implements memmap.Mappable.Translate. +func (fd *fileDescription) Translate(ctx context.Context, required, optional memmap.MappableRange, at hostarch.AccessType) ([]memmap.Translation, error) { + ts, err := fd.lowerMappable.Translate(ctx, required, optional, at) + if err != nil { + return nil, err + } + + // dataSize is the size of the whole file. + dataSize, err := fd.merkleReader.GetXattr(ctx, &vfs.GetXattrOptions{ + Name: merkleSizeXattr, + Size: sizeOfStringInt32, + }) + + // The Merkle tree file for the child should have been created and + // contains the expected xattrs. If the xattr does not exist, it + // indicates unexpected modifications to the file system. + if err == syserror.ENODATA { + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + } + if err != nil { + return nil, err + } + + // The dataSize xattr should be an integer. If it's not, it indicates + // unexpected modifications to the file system. + size, err := strconv.Atoi(dataSize) + if err != nil { + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + } + + merkleReader := FileReadWriteSeeker{ + FD: fd.merkleReader, + Ctx: ctx, + } + + for _, t := range ts { + // Content integrity relies on sentry owning the backing data. MapInternal is guaranteed + // to fetch sentry owned memory because we disallow verity mmaps otherwise. 
+ ims, err := t.File.MapInternal(memmap.FileRange{t.Offset, t.Offset + t.Source.Length()}, hostarch.Read) + if err != nil { + return nil, err + } + dataReader := mmapReadSeeker{ims, t.Source.Start} + var buf bytes.Buffer + _, err = merkletree.Verify(&merkletree.VerifyParams{ + Out: &buf, + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), + ReadOffset: int64(t.Source.Start), + ReadSize: int64(t.Source.Length()), + Expected: fd.d.hash, + DataAndTreeInSameFile: false, + }) + if err != nil { + return nil, fd.d.fs.alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) + } + } + return ts, err +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (fd *fileDescription) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) error { + return fd.lowerMappable.AddMapping(ctx, ms, ar, offset, writable) +} + +// RemoveMapping implements memmap.Mappable.RemoveMapping. +func (fd *fileDescription) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar hostarch.AddrRange, offset uint64, writable bool) { + fd.lowerMappable.RemoveMapping(ctx, ms, ar, offset, writable) +} + +// CopyMapping implements memmap.Mappable.CopyMapping. +func (fd *fileDescription) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR hostarch.AddrRange, offset uint64, writable bool) error { + return fd.lowerMappable.CopyMapping(ctx, ms, srcAR, dstAR, offset, writable) +} + +// InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. +func (fd *fileDescription) InvalidateUnsavable(context.Context) error { + return nil +} + +// mmapReadSeeker is a helper struct used by fileDescription.Translate to pass +// a safemem.BlockSeq pointing to the mapped region as io.ReaderAt. +type mmapReadSeeker struct { + safemem.BlockSeq + Offset uint64 +} + +// ReadAt implements io.ReaderAt.ReadAt. off is the offset into the mapped file. +func (r *mmapReadSeeker) ReadAt(p []byte, off int64) (int, error) { + bs := r.BlockSeq + // Adjust the offset into the mapped file to get the offset into the internally + // mapped region. + readOffset := off - int64(r.Offset) + if readOffset < 0 { + return 0, syserror.EINVAL + } + bs.DropFirst64(uint64(readOffset)) + view := bs.TakeFirst64(uint64(len(p))) + dst := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(p)) + n, err := safemem.CopySeq(dst, view) + return int(n), err +} + // FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as // io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. 
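(Aside, before the existing FileReadWriteSeeker declaration continues: the offset arithmetic in mmapReadSeeker.ReadAt above is easiest to see with fabricated numbers. If the BlockSeq covers file bytes [4096, 8192), a read at file offset 4100 must skip the first 4 bytes of the sequence.)

backing := make([]byte, 4096) // stands in for the internally mapped region
bs := safemem.BlockSeqOf(safemem.BlockFromSafeSlice(backing))
r := mmapReadSeeker{BlockSeq: bs, Offset: 4096} // mapping starts at file offset 4096

p := make([]byte, 16)
n, err := r.ReadAt(p, 4100) // readOffset = 4100 - 4096 = 4, so p receives backing[4:20]
// A read below the mapped range (off < 4096) yields EINVAL, per the check above.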
type FileReadWriteSeeker struct { diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 57bd65202..5c78a0019 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -89,10 +89,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, AllowUserMount: true, }) + data := "root_name=" + rootMerkleFilename mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ + Data: data, InternalData: InternalFilesystemOptions{ - RootMerkleFileName: rootMerkleFilename, LowerName: "tmpfs", Alg: hashAlg, AllowRuntimeEnable: true, diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index e9eb89378..a1ec6daab 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -141,6 +141,7 @@ go_library( srcs = [ "abstract_socket_namespace.go", "aio.go", + "cgroup.go", "context.go", "fd_table.go", "fd_table_refs.go", @@ -178,6 +179,7 @@ go_library( "task.go", "task_acct.go", "task_block.go", + "task_cgroup.go", "task_clone.go", "task_context.go", "task_exec.go", @@ -241,6 +243,7 @@ go_library( "//pkg/sentry/fs/lock", "//pkg/sentry/fs/timerfd", "//pkg/sentry/fsbridge", + "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/fsimpl/pipefs", "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/fsimpl/timerfd", diff --git a/pkg/sentry/kernel/cgroup.go b/pkg/sentry/kernel/cgroup.go new file mode 100644 index 000000000..1f1c63f37 --- /dev/null +++ b/pkg/sentry/kernel/cgroup.go @@ -0,0 +1,281 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernel + +import ( + "bytes" + "fmt" + "sort" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" +) + +// InvalidCgroupHierarchyID indicates an uninitialized hierarchy ID. +const InvalidCgroupHierarchyID uint32 = 0 + +// CgroupControllerType is the name of a cgroup controller. +type CgroupControllerType string + +// CgroupController is the common interface to cgroup controllers available to +// the entire sentry. The controllers themselves are defined by cgroupfs. +// +// Callers of this interface are often unable to access the synchronization needed to +// ensure returned values remain valid. Some of the values returned from this +// interface are thus snapshots in time, and may become stale. This is ok for +// many callers like procfs. +type CgroupController interface { + // Returns the type of this cgroup controller (ex "memory", "cpu"). Returned + // value is valid for the lifetime of the controller. + Type() CgroupControllerType + + // HierarchyID returns the ID of the hierarchy this cgroup controller is + // attached to. Returned value is valid for the lifetime of the controller. + HierarchyID() uint32 + + // Filesystem returns the filesystem this controller is attached to.
+ // Returned value is valid for the lifetime of the controller. + Filesystem() *vfs.Filesystem + + // RootCgroup returns the root cgroup for this controller. Returned value is + // valid for the lifetime of the controller. + RootCgroup() Cgroup + + // NumCgroups returns the number of cgroups managed by this controller. + // Returned value is a snapshot in time. + NumCgroups() uint64 + + // Enabled returns whether this controller is enabled. Returned value is a + // snapshot in time. + Enabled() bool +} + +// Cgroup represents a named pointer to a cgroup in cgroupfs. When a task enters +// a cgroup, it holds a reference on the underlying dentry pointing to the +// cgroup. +// +// +stateify savable +type Cgroup struct { + *kernfs.Dentry + CgroupImpl +} + +func (c *Cgroup) decRef() { + c.Dentry.DecRef(context.Background()) +} + +// Path returns the path of c, relative to its hierarchy root. +func (c *Cgroup) Path() string { + return c.FSLocalPath() +} + +// HierarchyID returns the id of the hierarchy that contains this cgroup. +func (c *Cgroup) HierarchyID() uint32 { + // Note: a cgroup is guaranteed to have at least one controller. + return c.Controllers()[0].HierarchyID() +} + +// CgroupImpl is the common interface to cgroups. +type CgroupImpl interface { + Controllers() []CgroupController + Enter(t *Task) + Leave(t *Task) +} + +// hierarchy represents a cgroupfs filesystem instance, with a unique set of +// controllers attached to it. Multiple cgroupfs mounts may reference the same +// hierarchy. +// +// +stateify savable +type hierarchy struct { + id uint32 + // These are a subset of the controllers in CgroupRegistry.controllers, + // grouped here by hierarchy for convenient lookup. + controllers map[CgroupControllerType]CgroupController + // fs is not owned by hierarchy. The FS is responsible for unregistering the + // hierarchy on destruction, which removes this association. + fs *vfs.Filesystem +} + +func (h *hierarchy) match(ctypes []CgroupControllerType) bool { + if len(ctypes) != len(h.controllers) { + return false + } + for _, ty := range ctypes { + if _, ok := h.controllers[ty]; !ok { + return false + } + } + return true +} + +// CgroupRegistry tracks the active set of cgroup controllers on the system. +// +// +stateify savable +type CgroupRegistry struct { + // lastHierarchyID is the id of the last allocated cgroup hierarchy. Valid + // ids are from 1 to math.MaxUint32. Must be accessed through atomic ops. + // + lastHierarchyID uint32 + + mu sync.Mutex `state:"nosave"` + + // controllers is the set of currently known cgroup controllers on the + // system. Protected by mu. + // + // +checklocks:mu + controllers map[CgroupControllerType]CgroupController + + // hierarchies is the active set of cgroup hierarchies. Protected by mu. + // + // +checklocks:mu + hierarchies map[uint32]hierarchy +} + +func newCgroupRegistry() *CgroupRegistry { + return &CgroupRegistry{ + controllers: make(map[CgroupControllerType]CgroupController), + hierarchies: make(map[uint32]hierarchy), + } +} + +// nextHierarchyID returns a newly allocated, unique hierarchy ID. +func (r *CgroupRegistry) nextHierarchyID() (uint32, error) { + if hid := atomic.AddUint32(&r.lastHierarchyID, 1); hid != 0 { + return hid, nil + } + return InvalidCgroupHierarchyID, fmt.Errorf("cgroup hierarchy ID overflow") +} + +// FindHierarchy returns a cgroup filesystem containing exactly the set of +// controllers named in ctypes. If no such FS is found, FindHierarchy returns +// nil.
FindHierarchy takes a reference on the returned FS, which is transferred +// to the caller. +func (r *CgroupRegistry) FindHierarchy(ctypes []CgroupControllerType) *vfs.Filesystem { + r.mu.Lock() + defer r.mu.Unlock() + + for _, h := range r.hierarchies { + if h.match(ctypes) { + h.fs.IncRef() + return h.fs + } + } + + return nil +} + +// Register registers the provided set of controllers with the registry as a new +// hierarchy. If any controller is already registered, the function returns an +// error without modifying the registry. The hierarchy can later be referenced +// by the returned id. +func (r *CgroupRegistry) Register(cs []CgroupController) (uint32, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if len(cs) == 0 { + return InvalidCgroupHierarchyID, fmt.Errorf("can't register hierarchy with no controllers") + } + + for _, c := range cs { + if _, ok := r.controllers[c.Type()]; ok { + return InvalidCgroupHierarchyID, fmt.Errorf("controllers may only be mounted on a single hierarchy") + } + } + + hid, err := r.nextHierarchyID() + if err != nil { + return hid, err + } + + h := hierarchy{ + id: hid, + controllers: make(map[CgroupControllerType]CgroupController), + fs: cs[0].Filesystem(), + } + for _, c := range cs { + n := c.Type() + r.controllers[n] = c + h.controllers[n] = c + } + r.hierarchies[hid] = h + return hid, nil +} + +// Unregister removes a previously registered hierarchy from the registry. If +// the hierarchy was not previously registered, Unregister is a no-op. +func (r *CgroupRegistry) Unregister(hid uint32) { + r.mu.Lock() + defer r.mu.Unlock() + + if h, ok := r.hierarchies[hid]; ok { + for name, _ := range h.controllers { + delete(r.controllers, name) + } + delete(r.hierarchies, hid) + } +} + +// computeInitialGroups takes a reference on each of the returned cgroups. The +// caller takes ownership of this returned reference. +func (r *CgroupRegistry) computeInitialGroups(inherit map[Cgroup]struct{}) map[Cgroup]struct{} { + r.mu.Lock() + defer r.mu.Unlock() + + ctlSet := make(map[CgroupControllerType]CgroupController) + cgset := make(map[Cgroup]struct{}) + + // Remember controllers from the inherited cgroups set... + for cg, _ := range inherit { + cg.IncRef() // Ref transferred to caller. + for _, ctl := range cg.Controllers() { + ctlSet[ctl.Type()] = ctl + cgset[cg] = struct{}{} + } + } + + // ... and add the root cgroups of all the missing controllers. + for name, ctl := range r.controllers { + if _, ok := ctlSet[name]; !ok { + cg := ctl.RootCgroup() + cg.IncRef() // Ref transferred to caller. + cgset[cg] = struct{}{} + } + } + return cgset +} + +// GenerateProcCgroups writes the contents of /proc/cgroups to buf. +func (r *CgroupRegistry) GenerateProcCgroups(buf *bytes.Buffer) { + r.mu.Lock() + entries := make([]string, 0, len(r.controllers)) + for _, c := range r.controllers { + en := 0 + if c.Enabled() { + en = 1 + } + entries = append(entries, fmt.Sprintf("%s\t%d\t%d\t%d\n", c.Type(), c.HierarchyID(), c.NumCgroups(), en)) + } + r.mu.Unlock() + + sort.Strings(entries) + fmt.Fprint(buf, "#subsys_name\thierarchy\tnum_cgroups\tenabled\n") + for _, e := range entries { + fmt.Fprint(buf, e) + } +} diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 43065b45a..e6e9da898 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -294,6 +294,11 @@ type Kernel struct { // YAMAPtraceScope is the current level of YAMA ptrace restrictions.
YAMAPtraceScope int32 + + // cgroupRegistry contains the set of active cgroup controllers on the + // system. It is controlled by cgroupfs. Nil if cgroupfs is unavailable on + // the system. + cgroupRegistry *CgroupRegistry } // InitKernelArgs holds arguments to Init. @@ -438,6 +443,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.socketMount = socketMount k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord) + + k.cgroupRegistry = newCgroupRegistry() } return nil } @@ -1815,6 +1822,11 @@ func (k *Kernel) SocketMount() *vfs.Mount { return k.socketMount } +// CgroupRegistry returns the cgroup registry. +func (k *Kernel) CgroupRegistry() *CgroupRegistry { + return k.cgroupRegistry +} + // Release releases resources owned by k. // // Precondition: This should only be called after the kernel is fully @@ -1831,3 +1843,43 @@ func (k *Kernel) Release() { k.timekeeper.Destroy() k.vdso.Release(ctx) } + +// PopulateNewCgroupHierarchy moves all tasks into a newly created cgroup +// hierarchy. +// +// Precondition: root must be a new cgroup with no tasks. This implies the +// controllers for root are also new and currently manage no task, which in turn +// implies the new cgroup can be populated without migrating tasks between +// cgroups. +func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) { + k.tasks.mu.RLock() + k.tasks.forEachTaskLocked(func(t *Task) { + if t.exitState != TaskExitNone { + return + } + t.mu.Lock() + t.enterCgroupLocked(root) + t.mu.Unlock() + }) + k.tasks.mu.RUnlock() +} + +// ReleaseCgroupHierarchy moves all tasks out of all cgroups belonging to the +// hierarchy with the provided id. This is intended for use during hierarchy +// teardown, as otherwise the tasks would be orphaned w.r.t. some controllers. +func (k *Kernel) ReleaseCgroupHierarchy(hid uint32) { + k.tasks.mu.RLock() + k.tasks.forEachTaskLocked(func(t *Task) { + if t.exitState != TaskExitNone { + return + } + t.mu.Lock() + for cg, _ := range t.cgroups { + if cg.HierarchyID() == hid { + t.leaveCgroupLocked(cg) + } + } + t.mu.Unlock() + }) + k.tasks.mu.RUnlock() +} diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index 399985039..be1371855 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -587,6 +587,12 @@ type Task struct { // // kcov is exclusive to the task goroutine. kcov *Kcov + + // cgroups is the set of cgroups this task belongs to. This may be empty if + // no cgroup controllers are enabled. Protected by mu. + // + // +checklocks:mu + cgroups map[Cgroup]struct{} } func (t *Task) savePtraceTracer() *Task { diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go new file mode 100644 index 000000000..25d2504fa --- /dev/null +++ b/pkg/sentry/kernel/task_cgroup.go @@ -0,0 +1,138 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
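(Aside, before task_cgroup.go continues below: putting cgroup.go and the Kernel hooks above together, a hypothetical cgroupfs mount would drive the registry roughly as follows. The controller values and local names are fabricated; the signatures are the ones introduced in this diff.)

registry := k.CgroupRegistry()
ctypes := []kernel.CgroupControllerType{"cpu", "cpuacct"}

if vfsfs := registry.FindHierarchy(ctypes); vfsfs != nil {
	// Exactly this controller set is already mounted somewhere; reuse it.
	// FindHierarchy transferred a ref on vfsfs to us.
} else {
	// controllers is assumed to be the []kernel.CgroupController set
	// backing a fresh cgroupfs instance.
	hid, err := registry.Register(controllers)
	if err != nil {
		// A controller may live on only one hierarchy at a time.
	}
	k.PopulateNewCgroupHierarchy(rootCgroup) // all existing tasks join the new hierarchy
	_ = hid                                  // retained for k.ReleaseCgroupHierarchy(hid) at teardown
}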
+ +package kernel + +import ( + "bytes" + "fmt" + "sort" + "strings" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/syserror" +) + +// EnterInitialCgroups moves t into an initial set of cgroups. +// +// Precondition: t isn't in any cgroups yet, t.cgroups is empty. +// +// +checklocksignore parent.mu is conditionally acquired. +func (t *Task) EnterInitialCgroups(parent *Task) { + var inherit map[Cgroup]struct{} + if parent != nil { + parent.mu.Lock() + defer parent.mu.Unlock() + inherit = parent.cgroups + } + joinSet := t.k.cgroupRegistry.computeInitialGroups(inherit) + + t.mu.Lock() + defer t.mu.Unlock() + // Transfer ownership of joinSet refs to the task's cgset. + t.cgroups = joinSet + for c, _ := range t.cgroups { + // Since t isn't in any cgroup yet, we can skip the check against + // existing cgroups. + c.Enter(t) + } +} + +// EnterCgroup moves t into c. +func (t *Task) EnterCgroup(c Cgroup) error { + newControllers := make(map[CgroupControllerType]struct{}) + for _, ctl := range c.Controllers() { + newControllers[ctl.Type()] = struct{}{} + } + + t.mu.Lock() + defer t.mu.Unlock() + + for oldCG, _ := range t.cgroups { + for _, oldCtl := range oldCG.Controllers() { + if _, ok := newControllers[oldCtl.Type()]; ok { + // Already in a cgroup with the same controller as one of the + // new ones. Requires migration between cgroups. + // + // TODO(b/183137098): Implement cgroup migration. + log.Warningf("Cgroup migration is not implemented") + return syserror.EBUSY + } + } + } + + // No migration required. + t.enterCgroupLocked(c) + + return nil +} + +// +checklocks:t.mu +func (t *Task) enterCgroupLocked(c Cgroup) { + c.IncRef() + t.cgroups[c] = struct{}{} + c.Enter(t) +} + +// LeaveCgroups removes t from all its cgroups. +func (t *Task) LeaveCgroups() { + t.mu.Lock() + defer t.mu.Unlock() + for c, _ := range t.cgroups { + t.leaveCgroupLocked(c) + } +} + +// +checklocks:t.mu +func (t *Task) leaveCgroupLocked(c Cgroup) { + c.Leave(t) + delete(t.cgroups, c) + c.decRef() +} + +// taskCgroupEntry represents a line in /proc/<pid>/cgroup, and is used to +// format a cgroup for display. +type taskCgroupEntry struct { + hierarchyID uint32 + controllers string + path string +} + +// GenerateProcTaskCgroup writes the contents of /proc/<pid>/cgroup for t to buf. +func (t *Task) GenerateProcTaskCgroup(buf *bytes.Buffer) { + t.mu.Lock() + defer t.mu.Unlock() + + cgEntries := make([]taskCgroupEntry, 0, len(t.cgroups)) + for c, _ := range t.cgroups { + ctls := c.Controllers() + ctlNames := make([]string, 0, len(ctls)) + for _, ctl := range ctls { + ctlNames = append(ctlNames, string(ctl.Type())) + } + + cgEntries = append(cgEntries, taskCgroupEntry{ + // Note: We're guaranteed to have at least one controller, and all + // controllers are guaranteed to be on the same hierarchy. + hierarchyID: ctls[0].HierarchyID(), + controllers: strings.Join(ctlNames, ","), + path: c.Path(), + }) + } + + sort.Slice(cgEntries, func(i, j int) bool { return cgEntries[i].hierarchyID > cgEntries[j].hierarchyID }) + for _, cgE := range cgEntries { + fmt.Fprintf(buf, "%d:%s:%s\n", cgE.hierarchyID, cgE.controllers, cgE.path) + } +} diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index ad59e4f60..b1af1a7ef 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -275,6 +275,10 @@ func (*runExitMain) execute(t *Task) taskRunState { t.fsContext.DecRef(t) t.fdTable.DecRef(t) + // Detach task from all cgroups.
This must happen before potentially the + // last ref to the cgroupfs mount is dropped below. + t.LeaveCgroups() + t.mu.Lock() if t.mountNamespaceVFS2 != nil { t.mountNamespaceVFS2.DecRef(t) diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index fc18b6253..32031cd70 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -151,6 +151,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { rseqSignature: cfg.RSeqSignature, futexWaiter: futex.NewWaiter(), containerID: cfg.ContainerID, + cgroups: make(map[Cgroup]struct{}), } t.creds.Store(cfg.Credentials) t.endStopCond.L = &t.tg.signalHandlers.mu @@ -189,6 +190,10 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { t.parent.children[t] = struct{}{} } + if VFS2Enabled { + t.EnterInitialCgroups(t.parent) + } + if tg.leader == nil { // New thread group. tg.leader = t diff --git a/pkg/sentry/kernel/task_syscall.go b/pkg/sentry/kernel/task_syscall.go index 2c658d001..601fc0d3a 100644 --- a/pkg/sentry/kernel/task_syscall.go +++ b/pkg/sentry/kernel/task_syscall.go @@ -30,8 +30,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -var vsyscallCount = metric.MustCreateNewUint64Metric("/kernel/vsyscall_count", false /* sync */, "Number of times vsyscalls were invoked by the application") - // SyscallRestartBlock represents the restart block for a syscall restartable // with a custom function. It encapsulates the state required to restart a // syscall across a S/R. @@ -284,7 +282,7 @@ func (*runSyscallExit) execute(t *Task) taskRunState { // indicated by an execution fault at address addr. doVsyscall returns the // task's next run state. func (t *Task) doVsyscall(addr hostarch.Addr, sysno uintptr) taskRunState { - vsyscallCount.Increment() + metric.WeirdnessMetric.Increment("vsyscall_count") // Grab the caller up front, to make sure there's a sensible stack. caller := t.Arch().Native(uintptr(0)) diff --git a/pkg/sentry/kernel/threads.go b/pkg/sentry/kernel/threads.go index 09d070ec8..77ad62445 100644 --- a/pkg/sentry/kernel/threads.go +++ b/pkg/sentry/kernel/threads.go @@ -114,6 +114,15 @@ func (ts *TaskSet) forEachThreadGroupLocked(f func(tg *ThreadGroup)) { } } +// forEachTaskLocked applies f to each Task in ts. +// +// Preconditions: ts.mu must be locked (for reading or writing). +func (ts *TaskSet) forEachTaskLocked(f func(t *Task)) { + for t := range ts.Root.tids { + f(t) + } +} + // A PIDNamespace represents a PID namespace, a bimap between thread IDs and // tasks. See the pid_namespaces(7) man page for further details. // diff --git a/pkg/sentry/loader/BUILD b/pkg/sentry/loader/BUILD index ecb6603a1..4c65215fa 100644 --- a/pkg/sentry/loader/BUILD +++ b/pkg/sentry/loader/BUILD @@ -11,11 +11,12 @@ go_library( "vdso.go", "vdso_state.go", ], + marshal = True, + marshal_debug = True, visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi", "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/cpuid", "//pkg/hostarch", diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index e92d9fdc3..8fc3e2a79 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/hostarch" @@ -47,10 +46,10 @@ const ( var ( // header64Size is the size of elf.Header64. 
-	header64Size = int(binary.Size(elf.Header64{}))
+	header64Size = (*linux.ElfHeader64)(nil).SizeBytes()
 
	// prog64Size is the size of elf.Prog64.
-	prog64Size = int(binary.Size(elf.Prog64{}))
+	prog64Size = (*linux.ElfProg64)(nil).SizeBytes()
 )
 
 func progFlagsAsPerms(f elf.ProgFlag) hostarch.AccessType {
@@ -136,7 +135,6 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
 		log.Infof("Unsupported ELF endianness: %v", endian)
 		return elfInfo{}, syserror.ENOEXEC
 	}
-	byteOrder := binary.LittleEndian
 
 	if version := elf.Version(ident[elf.EI_VERSION]); version != elf.EV_CURRENT {
 		log.Infof("Unsupported ELF version: %v", version)
@@ -145,7 +143,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
 	// EI_OSABI is ignored by Linux, which is the only OS supported.
 	os := abi.Linux
 
-	var hdr elf.Header64
+	var hdr linux.ElfHeader64
 	hdrBuf := make([]byte, header64Size)
 	_, err = f.ReadFull(ctx, usermem.BytesIOSequence(hdrBuf), 0)
 	if err != nil {
@@ -156,7 +154,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
 		}
 		return elfInfo{}, err
 	}
-	binary.Unmarshal(hdrBuf, byteOrder, &hdr)
+	hdr.UnmarshalUnsafe(hdrBuf)
 
 	// We support amd64 and arm64.
 	var a arch.Arch
@@ -213,8 +211,8 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
 
 	phdrs := make([]elf.ProgHeader, hdr.Phnum)
 	for i := range phdrs {
-		var prog64 elf.Prog64
-		binary.Unmarshal(phdrBuf[:prog64Size], byteOrder, &prog64)
+		var prog64 linux.ElfProg64
+		prog64.UnmarshalUnsafe(phdrBuf[:prog64Size])
 		phdrBuf = phdrBuf[prog64Size:]
 		phdrs[i] = elf.ProgHeader{
 			Type: elf.ProgType(prog64.Type),
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 72868646a..610686ea0 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -375,6 +375,11 @@ type MMapOpts struct {
 	//
 	// If Force is true, Unmap and Fixed must be true.
 	Force bool
+
+	// SentryOwnedContent indicates that the sentry exclusively controls the
+	// underlying memory backing the mapping, and thus the memory content is
+	// guaranteed not to be modified outside the sentry's purview.
+	SentryOwnedContent bool
 }
 
 // File represents a host file that may be mapped into a platform.AddressSpace.
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index f04898dc1..b307832fd 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -65,6 +65,7 @@ go_test(
     name = "kvm_test",
     srcs = [
         "kvm_amd64_test.go",
+        "kvm_amd64_test.s",
         "kvm_arm64_test.go",
         "kvm_test.go",
         "virtual_map_test.go",
diff --git a/pkg/sentry/platform/kvm/bluepill.go b/pkg/sentry/platform/kvm/bluepill.go
index fd1131638..bb9967b9f 100644
--- a/pkg/sentry/platform/kvm/bluepill.go
+++ b/pkg/sentry/platform/kvm/bluepill.go
@@ -16,7 +16,6 @@ package kvm
 
 import (
 	"fmt"
-	"reflect"
 
 	"golang.org/x/sys/unix"
 	"gvisor.dev/gvisor/pkg/ring0"
@@ -36,6 +35,14 @@ func sighandler()
 // dieArchSetup and the assembly implementation for dieTrampoline.
 func dieTrampoline()
 
+// Return the start addresses of the functions above.
+//
+// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal
+// wrapper function rather than the function itself. We must reference from
+// assembly to get the ABI0 (i.e., primary) address.
+func addrOfSighandler() uintptr
+func addrOfDieTrampoline() uintptr
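The addrOf* declarations above are easiest to understand in isolation. Under the Go 1.17 register ABI, taking reflect.ValueOf(fn).Pointer() on an assembly function may yield an ABIInternal wrapper rather than the real ABI0 entry point, so the address is materialized in assembly instead, where a symbol reference cannot be wrapped. A minimal sketch of the pattern (package, file, and symbol names here are invented for illustration):

// stub.go
package stub

// handler is implemented in stub_amd64.s. Its address must not be taken
// from Go, since the funcvalue may point at an ABI wrapper.
func handler()

// addrOfHandler returns the ABI0 start address of handler.
func addrOfHandler() uintptr

// stub_amd64.s
#include "textflag.h"

TEXT ·handler(SB),NOSPLIT,$0
	RET

// func addrOfHandler() uintptr
TEXT ·addrOfHandler(SB), $0-8
	MOVQ $·handler(SB), AX
	MOVQ AX, ret+0(FP)
	RET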
 var (
 	// bounceSignal is the signal used for bouncing KVM.
 	//
@@ -87,10 +94,10 @@ func (c *vCPU) die(context *arch.SignalContext64, msg string) {
 
 func init() {
 	// Install the handler.
-	if err := safecopy.ReplaceSignalHandler(bluepillSignal, reflect.ValueOf(sighandler).Pointer(), &savedHandler); err != nil {
+	if err := safecopy.ReplaceSignalHandler(bluepillSignal, addrOfSighandler(), &savedHandler); err != nil {
 		panic(fmt.Sprintf("Unable to set handler for signal %d: %v", bluepillSignal, err))
 	}
 
 	// Extract the address for the trampoline.
-	dieTrampolineAddr = reflect.ValueOf(dieTrampoline).Pointer()
+	dieTrampolineAddr = addrOfDieTrampoline()
 }
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64.s b/pkg/sentry/platform/kvm/bluepill_amd64.s
index 025ea93b5..953024600 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64.s
+++ b/pkg/sentry/platform/kvm/bluepill_amd64.s
@@ -81,8 +81,20 @@ fallback:
 	MOVQ ·savedHandler(SB), AX
 	JMP AX
 
+// func addrOfSighandler() uintptr
+TEXT ·addrOfSighandler(SB), $0-8
+	MOVQ $·sighandler(SB), AX
+	MOVQ AX, ret+0(FP)
+	RET
+
 // dieTrampoline: see bluepill.go, bluepill_amd64_unsafe.go for documentation.
 TEXT ·dieTrampoline(SB),NOSPLIT,$0
 	PUSHQ BX // First argument (vCPU).
 	PUSHQ AX // Fake the old RIP as caller.
 	JMP ·dieHandler(SB)
+
+// func addrOfDieTrampoline() uintptr
+TEXT ·addrOfDieTrampoline(SB), $0-8
+	MOVQ $·dieTrampoline(SB), AX
+	MOVQ AX, ret+0(FP)
+	RET
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.s b/pkg/sentry/platform/kvm/bluepill_arm64.s
index 09c7e88e5..308f2a951 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.s
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.s
@@ -92,6 +92,12 @@ fallback:
 	MOVD ·savedHandler(SB), R7
 	B (R7)
 
+// func addrOfSighandler() uintptr
+TEXT ·addrOfSighandler(SB), $0-8
+	MOVD $·sighandler(SB), R0
+	MOVD R0, ret+0(FP)
+	RET
+
 // dieTrampoline: see bluepill.go, bluepill_arm64_unsafe.go for documentation.
 TEXT ·dieTrampoline(SB),NOSPLIT,$0
 	// R0: Fake the old PC as caller
@@ -99,3 +105,9 @@ TEXT ·dieTrampoline(SB),NOSPLIT,$0
 	MOVD.P R1, 8(RSP) // R1: First argument (vCPU)
 	MOVD.P R0, 8(RSP) // R0: Fake the old PC as caller
 	B ·dieHandler(SB)
+
+// func addrOfDieTrampoline() uintptr
+TEXT ·addrOfDieTrampoline(SB), $0-8
+	MOVD $·dieTrampoline(SB), R0
+	MOVD R0, ret+0(FP)
+	RET
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.go b/pkg/sentry/platform/kvm/kvm_amd64_test.go
index e44e995a0..b8dd1e4a5 100644
--- a/pkg/sentry/platform/kvm/kvm_amd64_test.go
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.go
@@ -49,3 +49,40 @@ func TestSegments(t *testing.T) {
 		return false
 	})
 }
+
+// stmxcsr reads the MXCSR control and status register.
+func stmxcsr(addr *uint32)
+
+func TestMXCSR(t *testing.T) {
+	applicationTest(t, true, testutil.SyscallLoop, func(c *vCPU, regs *arch.Registers, pt *pagetables.PageTables) bool {
+		var si arch.SignalInfo
+		switchOpts := ring0.SwitchOpts{
+			Registers:          regs,
+			FloatingPointState: &dummyFPState,
+			PageTables:         pt,
+			FullRestore:        true,
+		}
+
+		const mxcsrControlMask = uint32(0x1f80)
+		mxcsrBefore := uint32(0)
+		mxcsrAfter := uint32(0)
+		stmxcsr(&mxcsrBefore)
+		if mxcsrBefore == 0 {
+			// The Go runtime sets MXCSR to 0x1f80 and never changes
+			// its control configuration.
+			panic("mxcsr is zero")
+		}
+		switchOpts.FloatingPointState.SetMXCSR(0)
+		if _, err := c.SwitchToUser(
+			switchOpts, &si); err == platform.ErrContextInterrupt {
+			return true // Retry.
+		} else if err != nil {
+			t.Errorf("application syscall failed: %v", err)
+		}
+		stmxcsr(&mxcsrAfter)
+		if mxcsrAfter&mxcsrControlMask != mxcsrBefore&mxcsrControlMask {
+			t.Errorf("mxcsr = %x (expected %x)", mxcsrAfter, mxcsrBefore)
+		}
+		return false
+	})
+}
diff --git a/pkg/sentry/platform/kvm/kvm_amd64_test.s b/pkg/sentry/platform/kvm/kvm_amd64_test.s
new file mode 100644
index 000000000..8e9079867
--- /dev/null
+++ b/pkg/sentry/platform/kvm/kvm_amd64_test.s
@@ -0,0 +1,21 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "textflag.h"
+
+// stmxcsr reads the MXCSR control and status register.
+TEXT ·stmxcsr(SB),NOSPLIT,$0-8
+	MOVQ addr+0(FP), SI
+	STMXCSR (SI)
+	RET
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 2492d57be..eb2dcccac 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -66,6 +66,7 @@ const (
 	_KVM_CAP_ARM_VM_IPA_SIZE = 0xa5
 	_KVM_CAP_VCPU_EVENTS = 0x29
 	_KVM_CAP_ARM_INJECT_SERROR_ESR = 0x9e
+	_KVM_CAP_TSC_CONTROL = 0x3c
 )
 
 // KVM limits.
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 6d90eaefa..1b5d5f66e 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -67,6 +67,9 @@ type machine struct {
 	// maxSlots is the maximum number of memory slots supported by the machine.
 	maxSlots int
 
+	// tscControl indicates whether the CPU supports TSC scaling.
+	tscControl bool
+
 	// usedSlots is the set of used physical addresses (sorted).
 	usedSlots []uintptr
 
@@ -212,6 +215,11 @@ func newMachine(vm int) (*machine, error) {
 	log.Debugf("The maximum number of slots is %d.", m.maxSlots)
 	m.usedSlots = make([]uintptr, m.maxSlots)
 
+	// Check for TSC scaling support.
+	hasTSCControl, _, errno := unix.RawSyscall(unix.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_TSC_CONTROL)
+	m.tscControl = errno == 0 && hasTSCControl == 1
+	log.Debugf("TSC scaling support: %t.", m.tscControl)
+
 	// Create the upper shared pagetables and kernel(sentry) pagetables.
 	m.upperSharedPageTables = pagetables.New(newAllocator())
 	m.mapUpperHalf(m.upperSharedPageTables)
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index d1b2c9c92..9a2337654 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -216,6 +216,11 @@ func (c *vCPU) setSystemTime() error {
 	// capabilities as it is emulated in KVM. We don't actually use this
 	// capability, but it means that this method should be robust to
 	// different hardware configurations.
+
+	// If TSC scaling is not supported, fall back to legacy mode.
+	if !c.machine.tscControl {
+		return c.setSystemTimeLegacy()
+	}
 	rawFreq, err := c.getTSCFreq()
 	if err != nil {
 		return c.setSystemTimeLegacy()
 	}
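The tscControl probe above uses the generic KVM extension check: the KVM_CHECK_EXTENSION ioctl returns a positive value when a capability is present. A rough standalone sketch of the same call shape (the ioctl number 0xae03 is the upstream KVM_CHECK_EXTENSION value; error handling and fd setup are elided):

package main

import (
	"fmt"

	"golang.org/x/sys/unix"
)

const (
	kvmCheckExtension = 0xae03 // KVM_CHECK_EXTENSION, from the Linux KVM ABI.
	kvmCapTSCControl  = 0x3c   // KVM_CAP_TSC_CONTROL, as in kvm_const.go above.
)

// hasKVMCap reports whether the KVM fd advertises the given capability.
// Some capabilities return values greater than 1 (they encode a limit);
// the TSC-control probe above accepts exactly 1.
func hasKVMCap(fd uintptr, capability uintptr) bool {
	r, _, errno := unix.RawSyscall(unix.SYS_IOCTL, fd, kvmCheckExtension, capability)
	return errno == 0 && r == 1
}

func main() {
	// In the sentry this would be machine.fd; 0 here only shows the call shape.
	fmt.Println(hasKVMCap(0, kvmCapTSCControl))
}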
@@ -349,6 +354,10 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo)
 	// allocations occur.
 	entersyscall()
 	bluepill(c)
+	// The root table physical page has to be mapped so that we do not fault
+	// in iret or sysret after switching into a user address space. sysret
+	// and iret are in the upper half, which is global and already mapped.
+	switchOpts.PageTables.PrefaultRootTable()
 	prefaultFloatingPointState(switchOpts.FloatingPointState)
 	vector = c.CPU.SwitchToUser(switchOpts)
 	exitsyscall()
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 242bee833..8926b1d9f 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -56,7 +56,7 @@ const (
 	// Beyond a relatively small number, there are likely few performance
 	// benefits, since the TLB has likely long since lost any translations
 	// from more than a few PCIDs past.
-	poolPCIDs = 8
+	poolPCIDs = 128
 )
 
 func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index 1dd184586..92edc992b 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -232,7 +232,7 @@ func (c *vCPU) setSystemTime() error {
 		}
 		// Is this past minIterations and within ~10% of minimum?
 		upperThreshold := (((minimum << 3) + minimum) >> 3)
-		if iter >= minIterations && ( current <= upperThreshold || minimum < 50 ) {
+		if iter >= minIterations && (current <= upperThreshold || minimum < 50) {
 			// Try to set the TSC
 			if err := c.setTSC(end + (minimum / 2)); err != nil {
 				return err
diff --git a/pkg/sentry/platform/ptrace/stub_amd64.s b/pkg/sentry/platform/ptrace/stub_amd64.s
index 16f9c523e..d5c3f901f 100644
--- a/pkg/sentry/platform/ptrace/stub_amd64.s
+++ b/pkg/sentry/platform/ptrace/stub_amd64.s
@@ -109,6 +109,12 @@ parent_dead:
 	SYSCALL
 	HLT
 
+// func addrOfStub() uintptr
+TEXT ·addrOfStub(SB), $0-8
+	MOVQ $·stub(SB), AX
+	MOVQ AX, ret+0(FP)
+	RET
+
 // stubCall calls the stub function at the given address with the given PPID.
 //
 // This is a distinct function because stub, above, may be mapped at any
diff --git a/pkg/sentry/platform/ptrace/stub_arm64.s b/pkg/sentry/platform/ptrace/stub_arm64.s
index 6162df02a..4664cd4ad 100644
--- a/pkg/sentry/platform/ptrace/stub_arm64.s
+++ b/pkg/sentry/platform/ptrace/stub_arm64.s
@@ -102,6 +102,12 @@ parent_dead:
 	SVC
 	HLT
 
+// func addrOfStub() uintptr
+TEXT ·addrOfStub(SB), $0-8
+	MOVD $·stub(SB), R0
+	MOVD R0, ret+0(FP)
+	RET
+
 // stubCall calls the stub function at the given address with the given PPID.
 //
 // This is a distinct function because stub, above, may be mapped at any
diff --git a/pkg/sentry/platform/ptrace/stub_unsafe.go b/pkg/sentry/platform/ptrace/stub_unsafe.go
index 5c9b7784f..1fbdea898 100644
--- a/pkg/sentry/platform/ptrace/stub_unsafe.go
+++ b/pkg/sentry/platform/ptrace/stub_unsafe.go
@@ -26,6 +26,13 @@ import (
 // stub is defined in arch-specific assembly.
 func stub()
 
+// addrOfStub returns the start address of stub.
+//
+// In Go 1.17+, Go references to assembly functions resolve to an ABIInternal
+// wrapper function rather than the function itself. We must reference from
+// assembly to get the ABI0 (i.e., primary) address.
+func addrOfStub() uintptr
+
 // stubCall calls the stub at the given address with the given pid.
 func stubCall(addr, pid uintptr)
 
@@ -41,7 +48,7 @@ func unsafeSlice(addr uintptr, length int) (slice []byte) {
 // stubInit initializes the stub.
 func stubInit() {
 	// Grab the existing stub.
- stubBegin := reflect.ValueOf(stub).Pointer() + stubBegin := addrOfStub() stubLen := int(safecopy.FindEndAddress(stubBegin) - stubBegin) stubSlice := unsafeSlice(stubBegin, stubLen) mapLen := uintptr(stubLen) diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 080859125..7ee89a735 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -8,7 +8,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/hostarch", "//pkg/marshal", diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 0e0e82365..2029e7cf4 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -14,9 +14,11 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/context", "//pkg/hostarch", + "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 45a05cd63..235b9c306 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -18,9 +18,11 @@ package control import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -193,7 +195,7 @@ func putUint32(buf []byte, n uint32) []byte { // putCmsg writes a control message header and as much data as will fit into // the unused capacity of a buffer. func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) { - space := binary.AlignDown(cap(buf)-len(buf), 4) + space := bits.AlignDown(cap(buf)-len(buf), 4) // We can't write to space that doesn't exist, so if we are going to align // the available space, we must align down. @@ -230,7 +232,7 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([ return alignSlice(buf, align), flags } -func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte { +func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data marshal.Marshallable) []byte { if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader { return buf } @@ -241,8 +243,7 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf buf = putUint32(buf, msgType) hdrBuf := buf - - buf = binary.Marshal(buf, hostarch.ByteOrder, data) + buf = append(buf, marshal.Marshal(data)...) // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { @@ -288,7 +289,7 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int // alignSlice extends a slice's length (up to the capacity) to align it. func alignSlice(buf []byte, align uint) []byte { - aligned := binary.AlignUp(len(buf), align) + aligned := bits.AlignUp(len(buf), align) if aligned > cap(buf) { // Linux allows unaligned data if there isn't room for alignment. // Since there isn't room for alignment, there isn't room for any @@ -300,12 +301,13 @@ func alignSlice(buf []byte, align uint) []byte { // PackTimestamp packs a SO_TIMESTAMP socket control message. 
 func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
+	timestampP := linux.NsecToTimeval(timestamp)
 	return putCmsgStruct(
 		buf,
 		linux.SOL_SOCKET,
 		linux.SO_TIMESTAMP,
 		t.Arch().Width(),
-		linux.NsecToTimeval(timestamp),
+		&timestampP,
 	)
 }
 
@@ -316,7 +318,7 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
 		linux.SOL_TCP,
 		linux.TCP_INQ,
 		t.Arch().Width(),
-		inq,
+		primitive.AllocateInt32(inq),
 	)
 }
 
@@ -327,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
 		linux.SOL_IP,
 		linux.IP_TOS,
 		t.Arch().Width(),
-		tos,
+		primitive.AllocateUint8(tos),
 	)
 }
 
@@ -338,7 +340,7 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
 		linux.SOL_IPV6,
 		linux.IPV6_TCLASS,
 		t.Arch().Width(),
-		tClass,
+		primitive.AllocateUint32(tClass),
 	)
 }
 
@@ -423,7 +425,7 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
 
 // cmsgSpace is equivalent to CMSG_SPACE in Linux.
 func cmsgSpace(t *kernel.Task, dataLen int) int {
-	return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
+	return linux.SizeOfControlMessageHeader + bits.AlignUp(dataLen, t.Arch().Width())
 }
 
 // CmsgsSpace returns the number of bytes needed to fit the control messages
@@ -475,7 +477,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
 	}
 
 	var h linux.ControlMessageHeader
-	binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h)
+	h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader])
 
 	if h.Length < uint64(linux.SizeOfControlMessageHeader) {
 		return socket.ControlMessages{}, syserror.EINVAL
@@ -491,7 +493,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
 	case linux.SOL_SOCKET:
 		switch h.Type {
 		case linux.SCM_RIGHTS:
-			rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
+			rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight)
 			numRights := rightsSize / linux.SizeOfControlMessageRight
 
 			if len(fds)+numRights > linux.SCM_MAX_FD {
@@ -502,7 +504,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
 				fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
 			}
 
-			i += binary.AlignUp(length, width)
+			i += bits.AlignUp(length, width)
 
 		case linux.SCM_CREDENTIALS:
 			if length < linux.SizeOfControlMessageCredentials {
@@ -510,23 +512,23 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
 			}
 
 			var creds linux.ControlMessageCredentials
-			binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds)
+			creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials])
 			scmCreds, err := NewSCMCredentials(t, creds)
 			if err != nil {
 				return socket.ControlMessages{}, err
 			}
 			cmsgs.Unix.Credentials = scmCreds
-			i += binary.AlignUp(length, width)
+			i += bits.AlignUp(length, width)
 
 		case linux.SO_TIMESTAMP:
 			if length < linux.SizeOfTimeval {
 				return socket.ControlMessages{}, syserror.EINVAL
 			}
 			var ts linux.Timeval
-			binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &ts)
+			ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
 			cmsgs.IP.Timestamp = ts.ToNsecCapped()
 			cmsgs.IP.HasTimestamp = true
-			i += binary.AlignUp(length, width)
+			i += bits.AlignUp(length, width)
 
 		default:
 			// Unknown message type.
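Since several hunks here swap binary.AlignUp for bits.AlignUp, it may help to spell out the arithmetic once: both round an integer up to a multiple of a power-of-two alignment, which is exactly the CMSG_SPACE computation in cmsgSpace above. A self-contained sketch (the 16-byte header constant is the usual 64-bit cmsghdr size, assumed here for illustration):

package main

import "fmt"

// alignUp rounds length up to a multiple of align, which must be a power of
// two. This matches the contract of bits.AlignUp as used in cmsgSpace above.
func alignUp(length int, align uint) int {
	return (length + int(align) - 1) &^ (int(align) - 1)
}

func main() {
	const cmsgHdrSize = 16 // sizeof(struct cmsghdr) on 64-bit Linux.
	// CMSG_SPACE-style sizing: header plus data rounded up to the task's
	// word width (8 bytes for a 64-bit task).
	for _, dataLen := range []int{4, 12, 16} {
		fmt.Println(cmsgHdrSize + alignUp(dataLen, 8))
	}
	// Output: 24, 32, 32.
}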
@@ -539,8 +541,10 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } cmsgs.IP.HasTOS = true - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &cmsgs.IP.TOS) - i += binary.AlignUp(length, width) + var tos primitive.Uint8 + tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS]) + cmsgs.IP.TOS = uint8(tos) + i += bits.AlignUp(length, width) case linux.IP_PKTINFO: if length < linux.SizeOfControlMessageIPPacketInfo { @@ -549,19 +553,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) cmsgs.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo) + packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo]) cmsgs.IP.PacketInfo = packetInfo - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet if length < addr.SizeBytes() { return socket.ControlMessages{}, syserror.EINVAL } - binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) cmsgs.IP.OriginalDstAddress = &addr - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IP_RECVERR: var errCmsg linux.SockErrCMsgIPv4 @@ -571,7 +575,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) cmsgs.IP.SockErr = &errCmsg - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, syserror.EINVAL @@ -583,17 +587,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, syserror.EINVAL } cmsgs.IP.HasTClass = true - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &cmsgs.IP.TClass) - i += binary.AlignUp(length, width) + var tclass primitive.Uint32 + tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass]) + cmsgs.IP.TClass = uint32(tclass) + i += bits.AlignUp(length, width) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 if length < addr.SizeBytes() { return socket.ControlMessages{}, syserror.EINVAL } - binary.Unmarshal(buf[i:i+addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) cmsgs.IP.OriginalDstAddress = &addr - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) case linux.IPV6_RECVERR: var errCmsg linux.SockErrCMsgIPv6 @@ -603,7 +609,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) cmsgs.IP.SockErr = &errCmsg - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, syserror.EINVAL diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a5c2155a2..2e3064565 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -17,7 +17,6 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/fdnotifier", "//pkg/hostarch", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index a784e23b5..52ae4bc9c 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ 
b/pkg/sentry/socket/hostinet/socket.go @@ -19,7 +19,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/hostarch" @@ -528,24 +527,28 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.SO_TIMESTAMP: controlMessages.IP.HasTimestamp = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], hostarch.ByteOrder, &controlMessages.IP.Timestamp) + ts := linux.Timeval{} + ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) + controlMessages.IP.Timestamp = ts.ToNsecCapped() } case linux.SOL_IP: switch unixCmsg.Header.Type { case linux.IP_TOS: controlMessages.IP.HasTOS = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], hostarch.ByteOrder, &controlMessages.IP.TOS) + var tos primitive.Uint8 + tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()]) + controlMessages.IP.TOS = uint8(tos) case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], hostarch.ByteOrder, &packetInfo) + packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()]) controlMessages.IP.PacketInfo = packetInfo case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet - binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) controlMessages.IP.OriginalDstAddress = &addr case unix.IP_RECVERR: @@ -558,11 +561,13 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], hostarch.ByteOrder, &controlMessages.IP.TClass) + var tclass primitive.Uint32 + tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()]) + controlMessages.IP.TClass = uint32(tclass) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 - binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) controlMessages.IP.OriginalDstAddress = &addr case unix.IPV6_RECVERR: @@ -575,7 +580,9 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s switch unixCmsg.Header.Type { case linux.TCP_INQ: controlMessages.IP.HasInq = true - binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], hostarch.ByteOrder, &controlMessages.IP.Inq) + var inq primitive.Int32 + inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq]) + controlMessages.IP.Inq = int32(inq) } } } @@ -689,7 +696,7 @@ func (s *socketOpsCommon) State() uint32 { return 0 } - binary.Unmarshal(buf, hostarch.ByteOrder, &info) + info.UnmarshalUnsafe(buf[:info.SizeBytes()]) return uint32(info.State) } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 26e8ae17a..393a1ab3a 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -15,6 +15,7 @@ package hostinet import ( + "encoding/binary" "fmt" "io" "io/ioutil" @@ -26,10 +27,10 @@ import ( "syscall" "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal/primitive" 
"gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" @@ -147,8 +148,8 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli if len(link.Data) < unix.SizeofIfInfomsg { return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg) } - var ifinfo unix.IfInfomsg - binary.Unmarshal(link.Data[:unix.SizeofIfInfomsg], hostarch.ByteOrder, &ifinfo) + var ifinfo linux.InterfaceInfoMessage + ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()]) inetIF := inet.Interface{ DeviceType: ifinfo.Type, Flags: ifinfo.Flags, @@ -178,11 +179,11 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli if len(addr.Data) < unix.SizeofIfAddrmsg { return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg) } - var ifaddr unix.IfAddrmsg - binary.Unmarshal(addr.Data[:unix.SizeofIfAddrmsg], hostarch.ByteOrder, &ifaddr) + var ifaddr linux.InterfaceAddrMessage + ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()]) inetAddr := inet.InterfaceAddr{ Family: ifaddr.Family, - PrefixLen: ifaddr.Prefixlen, + PrefixLen: ifaddr.PrefixLen, Flags: ifaddr.Flags, } attrs, err := syscall.ParseNetlinkRouteAttr(&addr) @@ -210,13 +211,13 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) continue } - var ifRoute unix.RtMsg - binary.Unmarshal(routeMsg.Data[:unix.SizeofRtMsg], hostarch.ByteOrder, &ifRoute) + var ifRoute linux.RouteMessage + ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()]) inetRoute := inet.Route{ Family: ifRoute.Family, - DstLen: ifRoute.Dst_len, - SrcLen: ifRoute.Src_len, - TOS: ifRoute.Tos, + DstLen: ifRoute.DstLen, + SrcLen: ifRoute.SrcLen, + TOS: ifRoute.TOS, Table: ifRoute.Table, Protocol: ifRoute.Protocol, Scope: ifRoute.Scope, @@ -245,7 +246,9 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) if len(attr.Value) != expected { return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected) } - binary.Unmarshal(attr.Value, hostarch.ByteOrder, &inetRoute.OutputInterface) + var outputIF primitive.Int32 + outputIF.UnmarshalUnsafe(attr.Value) + inetRoute.OutputInterface = int32(outputIF) } } diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 4381dfa06..61b2c9755 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -14,14 +14,16 @@ go_library( "tcp_matcher.go", "udp_matcher.go", ], + marshal = True, # This target depends on netstack and should only be used by epsocket, # which is allowed to depend on netstack. 
visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/hostarch", "//pkg/log", + "//pkg/marshal", "//pkg/sentry/kernel", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 4bd305a44..6fc7781ad 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -79,7 +78,7 @@ func marshalEntryMatch(name string, data []byte) []byte { nflog("marshaling matcher %q", name) // We have to pad this struct size to a multiple of 8 bytes. - size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8) + size := bits.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8) matcher := linux.KernelXTEntryMatch{ XTEntryMatch: linux.XTEntryMatch{ MatchSize: uint16(size), @@ -88,9 +87,11 @@ func marshalEntryMatch(name string, data []byte) []byte { } copy(matcher.Name[:], name) - buf := make([]byte, 0, size) - buf = binary.Marshal(buf, hostarch.ByteOrder, matcher) - return append(buf, make([]byte, size-len(buf))...) + buf := make([]byte, size) + entryLen := matcher.XTEntryMatch.SizeBytes() + matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen]) + copy(buf[entryLen:], matcher.Data) + return buf } func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index 1fc4cb651..cb78ef60b 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -18,8 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -141,10 +139,9 @@ func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } var entry linux.IPTEntry - buf := optVal[:linux.SizeOfIPTEntry] - binary.Unmarshal(buf, hostarch.ByteOrder, &entry) + entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[linux.SizeOfIPTEntry:] + optVal = optVal[entry.SizeBytes():] if entry.TargetOffset < linux.SizeOfIPTEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 67a52b628..5cb7fe4aa 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -18,8 +18,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -144,10 +142,9 @@ func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, return nil, syserr.ErrInvalidArgument } var entry linux.IP6TEntry - buf := optVal[:linux.SizeOfIP6TEntry] - binary.Unmarshal(buf, hostarch.ByteOrder, &entry) + entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[linux.SizeOfIP6TEntry:] + optVal = optVal[entry.SizeBytes():] if entry.TargetOffset < linux.SizeOfIP6TEntry { nflog("entry has too-small target 
offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 5200e08ed..f42d73178 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -22,7 +22,6 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -121,7 +120,7 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } - if binary.Size(entries) > uintptr(outLen) { + if entries.SizeBytes() > outLen { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } @@ -146,7 +145,7 @@ func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr hostarch.Addr, outLe nflog("couldn't read entries: %v", err) return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } - if binary.Size(entries) > uintptr(outLen) { + if entries.SizeBytes() > outLen { nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } @@ -179,7 +178,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] - binary.Unmarshal(replaceBuf, hostarch.ByteOrder, &replace) + replace.UnmarshalBytes(replaceBuf) // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table @@ -274,10 +273,10 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { } // TODO(gvisor.dev/issue/170): Support other chains. - // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, - // make sure all other chains point to ACCEPT rules. + // Since we don't support FORWARD, yet, make sure all other chains point to + // ACCEPT rules. for hook, ruleIdx := range table.BuiltinChains { - if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting { + if hook := stack.Hook(hook); hook == stack.Forward { if ruleIdx == stack.HookUnset { continue } @@ -309,8 +308,8 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal)) } var match linux.XTEntryMatch - buf := optVal[:linux.SizeOfXTEntryMatch] - binary.Unmarshal(buf, hostarch.ByteOrder, &match) + buf := optVal[:match.SizeBytes()] + match.UnmarshalUnsafe(buf) nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) // Check some invariants. 
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index b2cc6be20..60845cab3 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -59,8 +58,8 @@ func (ownerMarshaler) marshal(mr matcher) []byte { } } - buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo) - return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, hostarch.ByteOrder, iptOwnerInfo)) + buf := marshal.Marshal(&iptOwnerInfo) + return marshalEntryMatch(matcherNameOwner, buf) } // unmarshal implements matchMaker.unmarshal. @@ -72,7 +71,7 @@ func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack. // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo - binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo]) nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 80f8c6430..e94aceb92 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,11 +15,12 @@ package netfilter import ( + "encoding/binary" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -35,6 +36,11 @@ const ErrorTargetName = "ERROR" // change the destination port and/or IP for packets. const RedirectTargetName = "REDIRECT" +// SNATTargetName is used to mark targets as SNAT targets. SNAT targets should +// be reached for only NAT table. These targets will change the source port +// and/or IP for packets. +const SNATTargetName = "SNAT" + func init() { // Standard targets include ACCEPT, DROP, RETURN, and JUMP. registerTargetMaker(&standardTargetMaker{ @@ -59,6 +65,13 @@ func init() { registerTargetMaker(&nfNATTargetMaker{ NetworkProtocol: header.IPv6ProtocolNumber, }) + + registerTargetMaker(&snatTargetMakerV4{ + NetworkProtocol: header.IPv4ProtocolNumber, + }) + registerTargetMaker(&snatTargetMakerV6{ + NetworkProtocol: header.IPv6ProtocolNumber, + }) } // The stack package provides some basic, useful targets for us. 
The following @@ -131,6 +144,17 @@ func (rt *redirectTarget) id() targetID { } } +type snatTarget struct { + stack.SNATTarget +} + +func (st *snatTarget) id() targetID { + return targetID{ + name: SNATTargetName, + networkProtocol: st.NetworkProtocol, + } +} + type standardTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } @@ -166,8 +190,7 @@ func (*standardTargetMaker) marshal(target target) []byte { Verdict: verdict, } - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -176,8 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } var standardTarget linux.XTStandardTarget - buf = buf[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &standardTarget) + standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()]) if standardTarget.Verdict < 0 { // A Verdict < 0 indicates a non-jump verdict. @@ -222,8 +244,7 @@ func (*errorTargetMaker) marshal(target target) []byte { copy(xt.Name[:], errorName) copy(xt.Target.Name[:], ErrorTargetName) - ret := make([]byte, 0, linux.SizeOfXTErrorTarget) - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -233,7 +254,7 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var errTgt linux.XTErrorTarget buf = buf[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &errTgt) + errTgt.UnmarshalUnsafe(buf) // Error targets are used in 2 cases: // * An actual error case. These rules have an error named @@ -276,12 +297,11 @@ func (*redirectTargetMaker) marshal(target target) []byte { } copy(xt.Target.Name[:], RedirectTargetName) - ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) xt.NfRange.RangeSize = 1 xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED xt.NfRange.RangeIPV4.MinPort = htons(rt.Port) xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort - return binary.Marshal(ret, hostarch.ByteOrder, xt) + return marshal.Marshal(&xt) } func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { @@ -297,7 +317,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( var rt linux.XTRedirectTarget buf = buf[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &rt) + rt.UnmarshalUnsafe(buf) // Copy linux.XTRedirectTarget to stack.RedirectTarget. 
target := redirectTarget{RedirectTarget: stack.RedirectTarget{ @@ -336,12 +356,13 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return &target, nil } +// +marshal type nfNATTarget struct { Target linux.XTEntryTarget Range linux.NFNATRange } -const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange +const nfNATMarshalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange type nfNATTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber @@ -358,7 +379,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte { rt := target.(*redirectTarget) nt := nfNATTarget{ Target: linux.XTEntryTarget{ - TargetSize: nfNATMarhsalledSize, + TargetSize: nfNATMarshalledSize, }, Range: linux.NFNATRange{ Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, @@ -371,12 +392,11 @@ func (*nfNATTargetMaker) marshal(target target) []byte { nt.Range.MinProto = htons(rt.Port) nt.Range.MaxProto = nt.Range.MinProto - ret := make([]byte, 0, nfNATMarhsalledSize) - return binary.Marshal(ret, hostarch.ByteOrder, nt) + return marshal.Marshal(&nt) } func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { - if size := nfNATMarhsalledSize; len(buf) < size { + if size := nfNATMarshalledSize; len(buf) < size { nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) return nil, syserr.ErrInvalidArgument } @@ -387,8 +407,8 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize] - binary.Unmarshal(buf, hostarch.ByteOrder, &natRange) + buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] + natRange.UnmarshalUnsafe(buf) // We don't support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -418,6 +438,159 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar return &target, nil } +type snatTargetMakerV4 struct { + NetworkProtocol tcpip.NetworkProtocolNumber +} + +func (st *snatTargetMakerV4) id() targetID { + return targetID{ + name: SNATTargetName, + networkProtocol: st.NetworkProtocol, + } +} + +func (*snatTargetMakerV4) marshal(target target) []byte { + st := target.(*snatTarget) + // This is a snat target named snat. + xt := linux.XTSNATTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTSNATTarget, + }, + } + copy(xt.Target.Name[:], SNATTargetName) + + xt.NfRange.RangeSize = 1 + xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED + xt.NfRange.RangeIPV4.MinPort = htons(st.Port) + xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort + copy(xt.NfRange.RangeIPV4.MinIP[:], st.Addr) + copy(xt.NfRange.RangeIPV4.MaxIP[:], st.Addr) + return marshal.Marshal(&xt) +} + +func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { + if len(buf) < linux.SizeOfXTSNATTarget { + nflog("snatTargetMakerV4: buf has insufficient size for snat target %d", len(buf)) + return nil, syserr.ErrInvalidArgument + } + + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + nflog("snatTargetMakerV4: bad proto %d", p) + return nil, syserr.ErrInvalidArgument + } + + var st linux.XTSNATTarget + buf = buf[:linux.SizeOfXTSNATTarget] + st.UnmarshalUnsafe(buf) + + // Copy linux.XTSNATTarget to stack.SNATTarget. 
+	target := snatTarget{SNATTarget: stack.SNATTarget{
+		NetworkProtocol: filter.NetworkProtocol(),
+	}}
+
+	// RangeSize should be 1.
+	nfRange := st.NfRange
+	if nfRange.RangeSize != 1 {
+		nflog("snatTargetMakerV4: bad rangesize %d", nfRange.RangeSize)
+		return nil, syserr.ErrInvalidArgument
+	}
+
+	// TODO(gvisor.dev/issue/5772): If the rule doesn't specify the source port,
+	// choose one automatically.
+	if nfRange.RangeIPV4.MinPort == 0 {
+		nflog("snatTargetMakerV4: snat target needs to specify a non-zero port")
+		return nil, syserr.ErrInvalidArgument
+	}
+
+	// TODO(gvisor.dev/issue/170): Port range is not supported yet.
+	if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+		nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
+		return nil, syserr.ErrInvalidArgument
+	}
+	if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP {
+		nflog("snatTargetMakerV4: MinIP != MaxIP (%v, %v)", nfRange.RangeIPV4.MinIP, nfRange.RangeIPV4.MaxIP)
+		return nil, syserr.ErrInvalidArgument
+	}
+
+	target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+	target.Port = ntohs(nfRange.RangeIPV4.MinPort)
+
+	return &target, nil
+}
+
+type snatTargetMakerV6 struct {
+	NetworkProtocol tcpip.NetworkProtocolNumber
+}
+
+func (st *snatTargetMakerV6) id() targetID {
+	return targetID{
+		name:            SNATTargetName,
+		networkProtocol: st.NetworkProtocol,
+		revision:        1,
+	}
+}
+
+func (*snatTargetMakerV6) marshal(target target) []byte {
+	st := target.(*snatTarget)
+	nt := nfNATTarget{
+		Target: linux.XTEntryTarget{
+			TargetSize: nfNATMarshalledSize,
+		},
+		Range: linux.NFNATRange{
+			Flags: linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED,
+		},
+	}
+	copy(nt.Target.Name[:], SNATTargetName)
+	copy(nt.Range.MinAddr[:], st.Addr)
+	copy(nt.Range.MaxAddr[:], st.Addr)
+	nt.Range.MinProto = htons(st.Port)
+	nt.Range.MaxProto = nt.Range.MinProto
+
+	return marshal.Marshal(&nt)
+}
+
+func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
+	if size := nfNATMarshalledSize; len(buf) < size {
+		nflog("snatTargetMakerV6: buf has insufficient size (%d) for SNAT V6 target (%d)", len(buf), size)
+		return nil, syserr.ErrInvalidArgument
+	}
+
+	if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+		nflog("snatTargetMakerV6: bad proto %d", p)
+		return nil, syserr.ErrInvalidArgument
+	}
+
+	var natRange linux.NFNATRange
+	buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
+	natRange.UnmarshalUnsafe(buf)
+
+	// TODO(gvisor.dev/issue/5689): Support port or address ranges.
+	if natRange.MinAddr != natRange.MaxAddr {
+		nflog("snatTargetMakerV6: MinAddr and MaxAddr are different")
+		return nil, syserr.ErrInvalidArgument
+	}
+	if natRange.MinProto != natRange.MaxProto {
+		nflog("snatTargetMakerV6: MinProto and MaxProto are different")
+		return nil, syserr.ErrInvalidArgument
+	}
+
+	// TODO(gvisor.dev/issue/5698): Support other NF_NAT_RANGE flags.
+	if natRange.Flags != linux.NF_NAT_RANGE_MAP_IPS|linux.NF_NAT_RANGE_PROTO_SPECIFIED {
+		nflog("snatTargetMakerV6: invalid range flags %d", natRange.Flags)
+		return nil, syserr.ErrInvalidArgument
+	}
+
+	target := snatTarget{
+		SNATTarget: stack.SNATTarget{
+			NetworkProtocol: filter.NetworkProtocol(),
+			Addr:            tcpip.Address(natRange.MinAddr[:]),
+			Port:            ntohs(natRange.MinProto),
+		},
+	}
+
+	return &target, nil
+}
+
 // translateToStandardTarget translates from the value in a
 // linux.XTStandardTarget to a stack.Verdict.
func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) { @@ -453,8 +626,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T return nil, syserr.ErrInvalidArgument } var target linux.XTEntryTarget - buf := optVal[:linux.SizeOfXTEntryTarget] - binary.Unmarshal(buf, hostarch.ByteOrder, &target) + target.UnmarshalUnsafe(optVal[:target.SizeBytes()]) return unmarshalTarget(target, filter, optVal) } @@ -480,7 +652,7 @@ func (jt *JumpTarget) id() targetID { } // Action implements stack.Target.Action. -func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 69557f515..95bb9826e 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,8 +46,7 @@ func (tcpMarshaler) marshal(mr matcher) []byte { DestinationPortStart: matcher.destinationPortStart, DestinationPortEnd: matcher.destinationPortEnd, } - buf := make([]byte, 0, linux.SizeOfXTTCP) - return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, hostarch.ByteOrder, xttcp)) + return marshalEntryMatch(matcherNameTCP, marshal.Marshal(&xttcp)) } // unmarshal implements matchMaker.unmarshal. @@ -60,7 +58,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.XTTCP - binary.Unmarshal(buf[:linux.SizeOfXTTCP], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) nflog("parseMatchers: parsed XTTCP: %+v", matchData) if matchData.Option != 0 || diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 6a60e6bd6..fb8be27e6 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -18,8 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -47,8 +46,7 @@ func (udpMarshaler) marshal(mr matcher) []byte { DestinationPortStart: matcher.destinationPortStart, DestinationPortEnd: matcher.destinationPortEnd, } - buf := make([]byte, 0, linux.SizeOfXTUDP) - return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, hostarch.ByteOrder, xtudp)) + return marshalEntryMatch(matcherNameUDP, marshal.Marshal(&xtudp)) } // unmarshal implements matchMaker.unmarshal. @@ -60,7 +58,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma // For alignment reasons, the match's total size may exceed what's // strictly necessary to hold matchData. 
var matchData linux.XTUDP - binary.Unmarshal(buf[:linux.SizeOfXTUDP], hostarch.ByteOrder, &matchData) + matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) nflog("parseMatchers: parsed XTUDP: %+v", matchData) if matchData.InverseFlags != 0 { diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 171b95c63..64cd263da 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -14,7 +14,7 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/binary", + "//pkg/bits", "//pkg/context", "//pkg/hostarch", "//pkg/marshal", @@ -50,5 +50,7 @@ go_test( deps = [ ":netlink", "//pkg/abi/linux", + "//pkg/marshal", + "//pkg/marshal/primitive", ], ) diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index ab0e68af7..80385bfdc 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -19,15 +19,17 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/hostarch" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" ) // alignPad returns the length of padding required for alignment. // // Preconditions: align is a power of two. func alignPad(length int, align uint) int { - return binary.AlignUp(length, align) - length + return bits.AlignUp(length, align) - length } // Message contains a complete serialized netlink message. @@ -42,7 +44,7 @@ type Message struct { func NewMessage(hdr linux.NetlinkMessageHeader) *Message { return &Message{ hdr: hdr, - buf: binary.Marshal(nil, hostarch.ByteOrder, hdr), + buf: marshal.Marshal(&hdr), } } @@ -58,7 +60,7 @@ func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) { return } var hdr linux.NetlinkMessageHeader - binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBytes) // Msg portion. totalMsgLen := int(hdr.Length) @@ -92,7 +94,7 @@ func (m *Message) Header() linux.NetlinkMessageHeader { // GetData unmarshals the payload message header from this netlink message, and // returns the attributes portion. -func (m *Message) GetData(msg interface{}) (AttrsView, bool) { +func (m *Message) GetData(msg marshal.Marshallable) (AttrsView, bool) { b := BytesView(m.buf) _, ok := b.Extract(linux.NetlinkMessageHeaderSize) @@ -100,12 +102,12 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) { return nil, false } - size := int(binary.Size(msg)) + size := msg.SizeBytes() msgBytes, ok := b.Extract(size) if !ok { return nil, false } - binary.Unmarshal(msgBytes, hostarch.ByteOrder, msg) + msg.UnmarshalUnsafe(msgBytes) numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO) // Linux permits the last message not being aligned, just consume all of it. @@ -131,7 +133,7 @@ func (m *Message) Finalize() []byte { // Align the message. Note that the message length in the header (set // above) is the useful length of the message, not the total aligned // length. See net/netlink/af_netlink.c:__nlmsg_put. - aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO) + aligned := bits.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO) m.putZeros(aligned - len(m.buf)) return m.buf } @@ -145,45 +147,45 @@ func (m *Message) putZeros(n int) { } // Put serializes v into the message. -func (m *Message) Put(v interface{}) { - m.buf = binary.Marshal(m.buf, hostarch.ByteOrder, v) +func (m *Message) Put(v marshal.Marshallable) { + m.buf = append(m.buf, marshal.Marshal(v)...) 
} // PutAttr adds v to the message as a netlink attribute. // // Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize + -// binary.Size(v) fits in math.MaxUint16 bytes. -func (m *Message) PutAttr(atype uint16, v interface{}) { - l := linux.NetlinkAttrHeaderSize + int(binary.Size(v)) +// v.SizeBytes()) fits in math.MaxUint16 bytes. +func (m *Message) PutAttr(atype uint16, v marshal.Marshallable) { + l := linux.NetlinkAttrHeaderSize + v.SizeBytes() if l > math.MaxUint16 { panic(fmt.Sprintf("attribute too large: %d", l)) } - m.Put(linux.NetlinkAttrHeader{ + m.Put(&linux.NetlinkAttrHeader{ Type: atype, Length: uint16(l), }) m.Put(v) // Align the attribute. - aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) + aligned := bits.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } // PutAttrString adds s to the message as a netlink attribute. func (m *Message) PutAttrString(atype uint16, s string) { l := linux.NetlinkAttrHeaderSize + len(s) + 1 - m.Put(linux.NetlinkAttrHeader{ + m.Put(&linux.NetlinkAttrHeader{ Type: atype, Length: uint16(l), }) // String + NUL-termination. - m.Put([]byte(s)) + m.Put(primitive.AsByteSlice([]byte(s))) m.putZeros(1) // Align the attribute. - aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) + aligned := bits.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } @@ -251,7 +253,7 @@ func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest if !ok { return } - binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr) + hdr.UnmarshalUnsafe(hdrBytes) value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize) if !ok { diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go index ef13d9386..968968469 100644 --- a/pkg/sentry/socket/netlink/message_test.go +++ b/pkg/sentry/socket/netlink/message_test.go @@ -20,13 +20,31 @@ import ( "testing" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" ) type dummyNetlinkMsg struct { + marshal.StubMarshallable Foo uint16 } +func (*dummyNetlinkMsg) SizeBytes() int { + return 2 +} + +func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) { + p := primitive.Uint16(m.Foo) + p.MarshalUnsafe(dst) +} + +func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) { + var p primitive.Uint16 + p.UnmarshalUnsafe(src) + m.Foo = uint16(p) +} + func TestParseMessage(t *testing.T) { tests := []struct { desc string diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 744fc74f4..c6c04b4e3 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -11,6 +11,7 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/marshal/primitive", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 5a2255db3..86f6419dc 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -21,6 +21,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -167,7 +168,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { Type: linux.RTM_NEWLINK, }) - m.Put(linux.InterfaceInfoMessage{ + 
m.Put(&linux.InterfaceInfoMessage{ Family: linux.AF_UNSPEC, Type: i.DeviceType, Index: idx, @@ -175,7 +176,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { }) m.PutAttrString(linux.IFLA_IFNAME, i.Name) - m.PutAttr(linux.IFLA_MTU, i.MTU) + m.PutAttr(linux.IFLA_MTU, primitive.AllocateUint32(i.MTU)) mac := make([]byte, 6) brd := mac @@ -183,8 +184,8 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { mac = i.Addr brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) } - m.PutAttr(linux.IFLA_ADDRESS, mac) - m.PutAttr(linux.IFLA_BROADCAST, brd) + m.PutAttr(linux.IFLA_ADDRESS, primitive.AsByteSlice(mac)) + m.PutAttr(linux.IFLA_BROADCAST, primitive.AsByteSlice(brd)) // TODO(gvisor.dev/issue/578): There are many more attributes. } @@ -216,14 +217,15 @@ func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netl Type: linux.RTM_NEWADDR, }) - m.Put(linux.InterfaceAddrMessage{ + m.Put(&linux.InterfaceAddrMessage{ Family: a.Family, PrefixLen: a.PrefixLen, Index: uint32(id), }) - m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr)) - m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr)) + addr := primitive.ByteSlice([]byte(a.Addr)) + m.PutAttr(linux.IFA_LOCAL, &addr) + m.PutAttr(linux.IFA_ADDRESS, &addr) // TODO(gvisor.dev/issue/578): There are many more attributes. } @@ -366,7 +368,7 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net Type: linux.RTM_NEWROUTE, }) - m.Put(linux.RouteMessage{ + m.Put(&linux.RouteMessage{ Family: rt.Family, DstLen: rt.DstLen, SrcLen: rt.SrcLen, @@ -382,18 +384,18 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net Flags: rt.Flags, }) - m.PutAttr(254, []byte{123}) + m.PutAttr(254, primitive.AsByteSlice([]byte{123})) if rt.DstLen > 0 { - m.PutAttr(linux.RTA_DST, rt.DstAddr) + m.PutAttr(linux.RTA_DST, primitive.AsByteSlice(rt.DstAddr)) } if rt.SrcLen > 0 { - m.PutAttr(linux.RTA_SRC, rt.SrcAddr) + m.PutAttr(linux.RTA_SRC, primitive.AsByteSlice(rt.SrcAddr)) } if rt.OutputInterface != 0 { - m.PutAttr(linux.RTA_OIF, rt.OutputInterface) + m.PutAttr(linux.RTA_OIF, primitive.AllocateInt32(rt.OutputInterface)) } if len(rt.GatewayAddr) > 0 { - m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr) + m.PutAttr(linux.RTA_GATEWAY, primitive.AsByteSlice(rt.GatewayAddr)) } // TODO(gvisor.dev/issue/578): There are many more attributes. @@ -503,7 +505,7 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms hdr := msg.Header() // All messages start with a 1 byte protocol family. - var family uint8 + var family primitive.Uint8 if _, ok := msg.GetData(&family); !ok { // Linux ignores messages missing the protocol family. See // net/core/rtnetlink.c:rtnetlink_rcv_msg. 
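Aside: the netlink hunks above replace pkg/binary with the marshal API throughout. A minimal sketch of how a reply is composed under the new API follows; the Message methods are the ones shown above, but the field values and the buildLinkReply helper are illustrative, not part of this change.

package example

import (
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/marshal/primitive"
	"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)

// buildLinkReply composes an RTM_NEWLINK reply. Scalar attribute values
// must now be wrapped in pkg/marshal/primitive types, since PutAttr only
// accepts marshal.Marshallable.
func buildLinkReply() []byte {
	m := netlink.NewMessage(linux.NetlinkMessageHeader{
		Type: linux.RTM_NEWLINK,
	})
	m.Put(&linux.InterfaceInfoMessage{
		Family: linux.AF_UNSPEC,
		Index:  1,
	})
	m.PutAttrString(linux.IFLA_IFNAME, "lo")
	m.PutAttr(linux.IFLA_MTU, primitive.AllocateUint32(65536))
	// Finalize fixes up the header length and pads the buffer to
	// NLMSG_ALIGNTO alignment.
	return m.Finalize()
}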
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 30c297149..d75a2879f 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -20,7 +20,6 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -223,7 +222,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { } var sa linux.SockAddrNetlink - binary.Unmarshal(b[:linux.SockAddrNetlinkSize], hostarch.ByteOrder, &sa) + sa.UnmarshalUnsafe(b[:sa.SizeBytes()]) if sa.Family != linux.AF_NETLINK { return nil, syserr.ErrInvalidArgument @@ -338,16 +337,14 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr } s.mu.Lock() defer s.mu.Unlock() - sendBufferSizeP := primitive.Int32(s.sendBufferSize) - return &sendBufferSizeP, nil + return primitive.AllocateInt32(int32(s.sendBufferSize)), nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - recvBufferSizeP := primitive.Int32(math.MaxInt32) - return &recvBufferSizeP, nil + return primitive.AllocateInt32(math.MaxInt32), nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -484,7 +481,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * Family: linux.AF_NETLINK, PortID: uint32(s.portID), } - return sa, uint32(binary.Size(sa)), nil + return sa, uint32(sa.SizeBytes()), nil } // GetPeerName implements socket.Socket.GetPeerName. @@ -495,7 +492,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * // must be the kernel. PortID: 0, } - return sa, uint32(binary.Size(sa)), nil + return sa, uint32(sa.SizeBytes()), nil } // RecvMsg implements socket.Socket.RecvMsg. @@ -504,7 +501,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags Family: linux.AF_NETLINK, PortID: 0, } - fromLen := uint32(binary.Size(from)) + fromLen := uint32(from.SizeBytes()) trunc := flags&linux.MSG_TRUNC != 0 @@ -640,7 +637,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys }) // Add the dump_done_errno payload. 
- m.Put(int64(0)) + m.Put(primitive.AllocateInt64(0)) _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { @@ -658,7 +655,7 @@ func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ + m.Put(&linux.NetlinkErrorMessage{ Error: int32(-err.ToLinux().Number()), Header: hdr, }) @@ -668,7 +665,7 @@ func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ + m.Put(&linux.NetlinkErrorMessage{ Error: 0, Header: hdr, }) diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index 0b39a5b67..9561b7c25 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -19,7 +19,6 @@ go_library( ], deps = [ "//pkg/abi/linux", - "//pkg/binary", "//pkg/context", "//pkg/hostarch", "//pkg/log", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index ed6572bab..60ef33360 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,6 +26,7 @@ package netstack import ( "bytes" + "encoding/binary" "fmt" "io" "io/ioutil" @@ -35,7 +36,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/log" @@ -199,6 +199,13 @@ var Metrics = tcpip.Stats{ OptionRecordRouteReceived: mustCreateMetric("/netstack/ip/options/record_route_received", "Number of record route options found in received IP packets."), OptionRouterAlertReceived: mustCreateMetric("/netstack/ip/options/router_alert_received", "Number of router alert options found in received IP packets."), OptionUnknownReceived: mustCreateMetric("/netstack/ip/options/unknown_received", "Number of unknown options found in received IP packets."), + Forwarding: tcpip.IPForwardingStats{ + Unrouteable: mustCreateMetric("/netstack/ip/forwarding/unrouteable", "Number of IP packets received which couldn't be routed and thus were not forwarded."), + ExhaustedTTL: mustCreateMetric("/netstack/ip/forwarding/exhausted_ttl", "Number of IP packets received which could not be forwarded due to an exhausted TTL."), + LinkLocalSource: mustCreateMetric("/netstack/ip/forwarding/link_local_source_address", "Number of IP packets received which could not be forwarded due to a link-local source address."), + LinkLocalDestination: mustCreateMetric("/netstack/ip/forwarding/link_local_destination_address", "Number of IP packets received which could not be forwarded due to a link-local destination address."), + Errors: mustCreateMetric("/netstack/ip/forwarding/errors", "Number of IP packets which couldn't be forwarded."), + }, }, ARP: tcpip.ARPStats{ PacketsReceived: mustCreateMetric("/netstack/arp/packets_received", "Number of ARP packets received from the link layer."), @@ -242,6 +249,7 @@ var Metrics = tcpip.Stats{ FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."), Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."), ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."), + FailedPortReservations: 
mustCreateMetric("/netstack/tcp/failed_port_reservations", "Number of time TCP failed to reserve a port."), }, UDP: tcpip.UDPStats{ PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."), @@ -374,9 +382,9 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue }), nil } -var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{})) -var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{})) -var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{})) +var sockAddrInetSize = (*linux.SockAddrInet)(nil).SizeBytes() +var sockAddrInet6Size = (*linux.SockAddrInet6)(nil).SizeBytes() +var sockAddrLinkSize = (*linux.SockAddrLink)(nil).SizeBytes() // bytesToIPAddress converts an IPv4 or IPv6 address from the user to the // netstack representation taking any addresses into account. @@ -612,7 +620,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < sockAddrLinkSize { return syserr.ErrInvalidArgument } - binary.Unmarshal(sockaddr[:sockAddrLinkSize], hostarch.ByteOrder, &a) + a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) if a.Protocol != uint16(s.protocol) { return syserr.ErrInvalidArgument @@ -885,10 +893,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } + size := ep.SocketOptions().GetReceiveBufferSize() if size > math.MaxInt32 { size = math.MaxInt32 @@ -1314,7 +1319,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return &v, nil case linux.IP6T_ORIGINAL_DST: - if outLen < int(binary.Size(linux.SockAddrInet6{})) { + if outLen < sockAddrInet6Size { return nil, syserr.ErrInvalidArgument } @@ -1511,7 +1516,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return &v, nil case linux.SO_ORIGINAL_DST: - if outLen < int(binary.Size(linux.SockAddrInet{})) { + if outLen < sockAddrInetSize { return nil, syserr.ErrInvalidArgument } @@ -1661,7 +1666,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := hostarch.ByteOrder.Uint32(optVal) - ep.SocketOptions().SetSendBufferSize(int64(v), true) + ep.SocketOptions().SetSendBufferSize(int64(v), true /* notify */) return nil case linux.SO_RCVBUF: @@ -1670,7 +1675,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := hostarch.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) + ep.SocketOptions().SetReceiveBufferSize(int64(v), true /* notify */) + return nil case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { @@ -1743,7 +1749,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1756,7 +1762,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - binary.Unmarshal(optVal[:linux.SizeOfTimeval], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1792,7 +1798,11 @@ 
func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Linger - binary.Unmarshal(optVal[:linux.SizeOfLinger], hostarch.ByteOrder, &v) + v.UnmarshalBytes(optVal[:linux.SizeOfLinger]) + + if v != (linux.Linger{}) { + socket.SetSockOptEmitUnimplementedEvent(t, name) + } ep.SocketOptions().SetLinger(tcpip.LingerOption{ Enabled: v.OnOff != 0, @@ -2091,9 +2101,9 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } var ( - inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{})) - inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{})) - inet6MulticastRequestSize = int(binary.Size(linux.Inet6MulticastRequest{})) + inetMulticastRequestSize = (*linux.InetMulticastRequest)(nil).SizeBytes() + inetMulticastRequestWithNICSize = (*linux.InetMulticastRequestWithNIC)(nil).SizeBytes() + inet6MulticastRequestSize = (*linux.Inet6MulticastRequest)(nil).SizeBytes() ) // copyInMulticastRequest copies in a variable-size multicast request. The @@ -2118,12 +2128,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR if len(optVal) >= inetMulticastRequestWithNICSize { var req linux.InetMulticastRequestWithNIC - binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], hostarch.ByteOrder, &req) + req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize]) return req, nil } var req linux.InetMulticastRequestWithNIC - binary.Unmarshal(optVal[:inetMulticastRequestSize], hostarch.ByteOrder, &req.InetMulticastRequest) + req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize]) return req, nil } @@ -2133,7 +2143,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse } var req linux.Inet6MulticastRequest - binary.Unmarshal(optVal[:inet6MulticastRequestSize], hostarch.ByteOrder, &req) + req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize]) return req, nil } @@ -3102,8 +3112,8 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe continue } // Populate ifr.ifr_netmask (type sockaddr). - hostarch.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET)) - hostarch.ByteOrder.PutUint16(ifr.Data[2:4], 0) + hostarch.ByteOrder.PutUint16(ifr.Data[0:], uint16(linux.AF_INET)) + hostarch.ByteOrder.PutUint16(ifr.Data[2:], 0) var mask uint32 = 0xffffffff << (32 - addr.PrefixLen) // Netmask is expected to be returned as a big endian // value. 
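The netstack hunks above all follow one mechanical migration pattern: reflection-based pkg/binary calls give way to the generated Marshallable methods (SizeBytes, MarshalUnsafe/UnmarshalUnsafe, MarshalBytes/UnmarshalBytes). A before/after sketch, using linux.SockAddrInet as a stand-in for any fixed-size ABI struct; the decodeAndReencode helper is hypothetical.

package example

import (
	"gvisor.dev/gvisor/pkg/abi/linux"
	"gvisor.dev/gvisor/pkg/marshal"
)

// decodeAndReencode illustrates the replacement pattern. Previously:
//   binary.Unmarshal(buf[:binary.Size(a)], hostarch.ByteOrder, &a)
//   out := binary.Marshal(nil, hostarch.ByteOrder, a)
func decodeAndReencode(buf []byte) ([]byte, bool) {
	var a linux.SockAddrInet
	// SizeBytes replaces int(binary.Size(...)) in length checks.
	if len(buf) < a.SizeBytes() {
		return nil, false
	}
	// UnmarshalUnsafe copies memory directly and is used for sized,
	// trusted buffers; UnmarshalBytes is the field-by-field variant
	// (used above for SockAddrLink, Timeval and Linger).
	a.UnmarshalUnsafe(buf[:a.SizeBytes()])
	// marshal.Marshal allocates a buffer and serializes into it.
	return marshal.Marshal(&a), true
}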
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 4c3d48096..9e56487a6 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -24,7 +24,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal" @@ -572,19 +571,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { switch family { case unix.AF_INET: var addr linux.SockAddrInet - binary.Unmarshal(data[:unix.SizeofSockaddrInet4], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_INET6: var addr linux.SockAddrInet6 - binary.Unmarshal(data[:unix.SizeofSockaddrInet6], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_UNIX: var addr linux.SockAddrUnix - binary.Unmarshal(data[:unix.SizeofSockaddrUnix], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr case unix.AF_NETLINK: var addr linux.SockAddrNetlink - binary.Unmarshal(data[:unix.SizeofSockaddrNetlink], hostarch.ByteOrder, &addr) + addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) return &addr default: panic(fmt.Sprintf("Unsupported socket family %v", family)) @@ -716,7 +715,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInetSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrInetSize], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrInetSize]) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -729,7 +728,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInet6Size { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrInet6Size], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrInet6Size]) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -745,7 +744,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrLinkSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - binary.Unmarshal(addr[:sockAddrLinkSize], hostarch.ByteOrder, &a) + a.UnmarshalUnsafe(addr[:sockAddrLinkSize]) if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 159b8f90f..33f9aeb06 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -130,7 +130,8 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv } ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) return ep } @@ -175,8 +176,9 @@ func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider idGenerator: uid, stype: stype, } - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) ep.ops.SetSendBufferSize(connected.SendMaxQueueSize(), false /* notify */) + 
ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) return ep } @@ -299,8 +301,9 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } - ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits) + ne.ops.InitHandler(ne, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) + ne.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize} readQueue.InitRefs() @@ -343,11 +346,11 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn return nil default: - // Busy; return ECONNREFUSED per spec. + // Busy; return EAGAIN per spec. ne.Close(ctx) e.Unlock() ce.Unlock() - return syserr.ErrConnectionRefused + return syserr.ErrTryAgain } } @@ -366,6 +369,7 @@ func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint // to reflect this endpoint's send buffer size. if bufSz := e.connected.SetSendBufferSize(e.ops.GetSendBufferSize()); bufSz != e.ops.GetSendBufferSize() { e.ops.SetSendBufferSize(bufSz, false /* notify */) + e.ops.SetReceiveBufferSize(bufSz, false /* notify */) } } diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go index 590b0bd01..b20334d4f 100644 --- a/pkg/sentry/socket/unix/transport/connectioned_state.go +++ b/pkg/sentry/socket/unix/transport/connectioned_state.go @@ -54,5 +54,5 @@ func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEnd // afterLoad is invoked by stateify. func (e *connectionedEndpoint) afterLoad() { - e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits) + e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) } diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index d0df28b59..61338728a 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -45,7 +45,8 @@ func NewConnectionless(ctx context.Context) Endpoint { q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) - ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits) + ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) + ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) return ep } diff --git a/pkg/sentry/socket/unix/transport/connectionless_state.go b/pkg/sentry/socket/unix/transport/connectionless_state.go index 2ef337ec8..1bb71baf7 100644 --- a/pkg/sentry/socket/unix/transport/connectionless_state.go +++ b/pkg/sentry/socket/unix/transport/connectionless_state.go @@ -16,5 +16,5 @@ package transport // afterLoad is invoked by stateify. 
func (e *connectionlessEndpoint) afterLoad() { - e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits) + e.ops.InitHandler(e, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits) } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 0c5f5ab42..837ab4fde 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -868,11 +868,7 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { } func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.ReceiveBufferSizeOption: - default: - log.Warningf("Unsupported socket option: %d", opt) - } + log.Warningf("Unsupported socket option: %d", opt) return nil } @@ -905,19 +901,6 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { } return int(v), nil - case tcpip.ReceiveBufferSizeOption: - e.Lock() - if e.receiver == nil { - e.Unlock() - return -1, &tcpip.ErrNotConnected{} - } - v := e.receiver.RecvMaxQueueSize() - e.Unlock() - if v < 0 { - return -1, &tcpip.ErrQueueSizeNotSupported{} - } - return int(v), nil - default: log.Warningf("Unsupported socket option: %d", opt) return -1, &tcpip.ErrUnknownProtocolOption{} @@ -1029,3 +1012,15 @@ func getSendBufferLimits(tcpip.StackHandler) tcpip.SendBufferSizeOption { Max: maxBufferSize, } } + +// getReceiveBufferLimits implements tcpip.GetReceiveBufferLimits. +// +// We define min, max and default values for the unix socket implementation. +// Unix sockets do not use a receive buffer. +func getReceiveBufferLimits(tcpip.StackHandler) tcpip.ReceiveBufferSizeOption { + return tcpip.ReceiveBufferSizeOption{ + Min: minimumBufferSize, + Default: defaultBufferSize, + Max: maxBufferSize, + } +} diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index 2ebd77f82..1fbbd133c 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -25,7 +25,6 @@ go_library( ":strace_go_proto", "//pkg/abi", "//pkg/abi/linux", - "//pkg/binary", "//pkg/bits", "//pkg/eventchannel", "//pkg/hostarch", diff --git a/pkg/sentry/strace/linux64_amd64.go b/pkg/sentry/strace/linux64_amd64.go index 71b92eaee..d66befe81 100644 --- a/pkg/sentry/strace/linux64_amd64.go +++ b/pkg/sentry/strace/linux64_amd64.go @@ -371,6 +371,7 @@ var linuxAMD64 = SyscallMap{ 433: makeSyscallInfo("fspick", FD, Path, Hex), 434: makeSyscallInfo("pidfd_open", Hex, Hex), 435: makeSyscallInfo("clone3", Hex, Hex), + 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet), } func init() { diff --git a/pkg/sentry/strace/linux64_arm64.go b/pkg/sentry/strace/linux64_arm64.go index bd7361a52..1a2d7d75f 100644 --- a/pkg/sentry/strace/linux64_arm64.go +++ b/pkg/sentry/strace/linux64_arm64.go @@ -312,6 +312,7 @@ var linuxARM64 = SyscallMap{ 433: makeSyscallInfo("fspick", FD, Path, Hex), 434: makeSyscallInfo("pidfd_open", Hex, Hex), 435: makeSyscallInfo("clone3", Hex, Hex), + 441: makeSyscallInfo("epoll_pwait2", FD, EpollEvents, Hex, Timespec, SigSet), } func init() { diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index e5b7f9b96..f4aab25b0 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -20,14 +20,13 @@ import ( "gvisor.dev/gvisor/pkg/abi" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" - - "gvisor.dev/gvisor/pkg/hostarch" ) // SocketFamily are the possible socket(2) families. @@ -162,6 +161,15 @@ var controlMessageType = map[int32]string{ linux.SO_TIMESTAMP: "SO_TIMESTAMP", } +func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights { + count := len(src) / linux.SizeOfControlMessageRight + cmr := make(linux.ControlMessageRights, count) + for i, _ := range cmr { + cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:])) + } + return cmr +} + func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) string { if length > maxBytes { return fmt.Sprintf("%#x (error decoding control: invalid length (%d))", addr, length) @@ -181,7 +189,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var h linux.ControlMessageHeader - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], hostarch.ByteOrder, &h) + h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) var skipData bool level := "SOL_SOCKET" @@ -221,18 +229,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) continue } switch h.Type { case linux.SCM_RIGHTS: - rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight) - - numRights := rightsSize / linux.SizeOfControlMessageRight - fds := make(linux.ControlMessageRights, numRights) - binary.Unmarshal(buf[i:i+rightsSize], hostarch.ByteOrder, &fds) - + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) + fds := unmarshalControlMessageRights(buf[i : i+rightsSize]) rights := make([]string, 0, len(fds)) for _, fd := range fds { rights = append(rights, fmt.Sprint(fd)) @@ -258,7 +262,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var creds linux.ControlMessageCredentials - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], hostarch.ByteOrder, &creds) + creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", @@ -282,7 +286,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) } var tv linux.Timeval - binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], hostarch.ByteOrder, &tv) + tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", @@ -296,7 +300,7 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) default: panic("unreachable") } - i += binary.AlignUp(length, width) + i += bits.AlignUp(length, width) } return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", ")) diff --git a/pkg/sentry/syscalls/epoll.go b/pkg/sentry/syscalls/epoll.go index e115683f8..3b4d79889 100644 --- a/pkg/sentry/syscalls/epoll.go +++ b/pkg/sentry/syscalls/epoll.go @@ -119,7 +119,7 @@ func RemoveEpoll(t *kernel.Task, epfd int32, fd int32) error { } // WaitEpoll implements the epoll_wait(2) linux syscall. 
-func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEvent, error) { +func WaitEpoll(t *kernel.Task, fd int32, max int, timeoutInNanos int64) ([]linux.EpollEvent, error) { // Get epoll from the file descriptor. epollfile := t.GetFile(fd) if epollfile == nil { @@ -136,7 +136,7 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve // Try to read events and return right away if we got them or if the // caller requested a non-blocking "wait". r := e.ReadEvents(max) - if len(r) != 0 || timeout == 0 { + if len(r) != 0 || timeoutInNanos == 0 { return r, nil } @@ -144,8 +144,8 @@ func WaitEpoll(t *kernel.Task, fd int32, max int, timeout int) ([]linux.EpollEve // and register with the epoll object for readability events. var haveDeadline bool var deadline ktime.Time - if timeout > 0 { - timeoutDur := time.Duration(timeout) * time.Millisecond + if timeoutInNanos > 0 { + timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur) haveDeadline = true } diff --git a/pkg/sentry/syscalls/linux/error.go b/pkg/sentry/syscalls/linux/error.go index efec93f73..6eabfd219 100644 --- a/pkg/sentry/syscalls/linux/error.go +++ b/pkg/sentry/syscalls/linux/error.go @@ -29,10 +29,17 @@ import ( ) var ( - partialResultMetric = metric.MustCreateNewUint64Metric("/syscalls/partial_result", true /* sync */, "Whether or not a partial result has occurred for this sandbox.") - partialResultOnce sync.Once + partialResultOnce sync.Once ) +// incrementPartialResultMetric increments the "partial_result" field of the +// weirdness metric. This wrapper exists because sync.Once.Do, used below, +// requires a function that takes no arguments, whereas Increment takes a +// variadic number of arguments. +func incrementPartialResultMetric() { + metric.WeirdnessMetric.Increment("partial_result") +} + // HandleIOErrorVFS2 handles special error cases for partial results. For some // errors, we may consume the error and return only the partial read/write. // @@ -48,7 +55,7 @@ func HandleIOErrorVFS2(ctx context.Context, partialResult bool, ioerr, intr erro root := vfs.RootFromContext(ctx) name, _ := fs.PathnameWithDeleted(ctx, root, f.VirtualDentry()) log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q", partialResult, ioerr, ioerr, op, name) - partialResultOnce.Do(partialResultMetric.Increment) + partialResultOnce.Do(incrementPartialResultMetric) } return nil } @@ -66,7 +73,7 @@ func handleIOError(ctx context.Context, partialResult bool, ioerr, intr error, o // An unknown error is encountered with a partial read/write.
name, _ := f.Dirent.FullName(nil /* ignore chroot */) log.Traceback("Invalid request partialResult %v and err (type %T) %v for %s operation on %q, %T", partialResult, ioerr, ioerr, op, name, f.FileOperations) - partialResultOnce.Do(partialResultMetric.Increment) + partialResultOnce.Do(incrementPartialResultMetric) } return nil } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 2d2212605..090c5ffcb 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -404,6 +404,7 @@ var AMD64 = &kernel.SyscallTable{ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + 441: syscalls.Supported("epoll_pwait2", EpollPwait2), }, Emulate: map[hostarch.Addr]uintptr{ 0xffffffffff600000: 96, // vsyscall gettimeofday(2) @@ -722,6 +723,7 @@ var ARM64 = &kernel.SyscallTable{ 433: syscalls.ErrorWithEvent("fspick", syserror.ENOSYS, "", nil), 434: syscalls.ErrorWithEvent("pidfd_open", syserror.ENOSYS, "", nil), 435: syscalls.ErrorWithEvent("clone3", syserror.ENOSYS, "", nil), + 441: syscalls.Supported("epoll_pwait2", EpollPwait2), }, Emulate: map[hostarch.Addr]uintptr{}, Missing: func(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, error) { diff --git a/pkg/sentry/syscalls/linux/sys_epoll.go b/pkg/sentry/syscalls/linux/sys_epoll.go index 7f460d30b..69cbc98d0 100644 --- a/pkg/sentry/syscalls/linux/sys_epoll.go +++ b/pkg/sentry/syscalls/linux/sys_epoll.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/epoll" @@ -104,14 +105,8 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } } -// EpollWait implements the epoll_wait(2) linux syscall. -func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - epfd := args[0].Int() - eventsAddr := args[1].Pointer() - maxEvents := int(args[2].Int()) - timeout := int(args[3].Int()) - - r, err := syscalls.WaitEpoll(t, epfd, maxEvents, timeout) +func waitEpoll(t *kernel.Task, fd int32, eventsAddr hostarch.Addr, max int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) { + r, err := syscalls.WaitEpoll(t, fd, max, timeoutInNanos) if err != nil { return 0, nil, syserror.ConvertIntr(err, syserror.EINTR) } @@ -123,6 +118,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } return uintptr(len(r)), nil, nil + +} + +// EpollWait implements the epoll_wait(2) linux syscall. +func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + // Convert milliseconds to nanoseconds. + timeoutInNanos := int64(args[3].Int()) * 1000000 + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) } // EpollPwait implements the epoll_pwait(2) linux syscall. @@ -144,4 +150,38 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return EpollWait(t, args) } +// EpollPwait2 implements the epoll_pwait2(2) linux syscall.
+func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutPtr := args[3].Pointer() + maskAddr := args[4].Pointer() + maskSize := uint(args[5].Uint()) + haveTimeout := timeoutPtr != 0 + + var timeoutInNanos int64 = -1 + if haveTimeout { + timeout, err := copyTimespecIn(t, timeoutPtr) + if err != nil { + return 0, nil, err + } + timeoutInNanos = timeout.ToNsec() + + } + + if maskAddr != 0 { + mask, err := CopyInSigSet(t, maskAddr, maskSize) + if err != nil { + return 0, nil, err + } + + oldmask := t.SignalMask() + t.SetSignalMask(mask) + t.SetSavedSignalMask(oldmask) + } + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) +} + // LINT.ThenChange(vfs2/epoll.go) diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 9bdf6d3d8..e07917613 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -35,12 +35,6 @@ import ( // LINT.IfChange -// minListenBacklog is the minimum reasonable backlog for listening sockets. -const minListenBacklog = 8 - -// maxListenBacklog is the maximum allowed backlog for listening sockets. -const maxListenBacklog = 1024 - // maxAddrLen is the maximum socket address length we're willing to accept. const maxAddrLen = 200 @@ -52,6 +46,9 @@ const maxOptLen = 1024 * 8 // buffers upto INT_MAX. const maxControlLen = 10 * 1024 * 1024 +// maxListenBacklog is the maximum supported listen backlog. +const maxListenBacklog = 1024 + // nameLenOffset is the offset from the start of the MessageHeader64 struct to // the NameLen field. const nameLenOffset = 8 @@ -367,7 +364,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Listen implements the linux syscall listen(2). func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() - backlog := args[1].Int() + backlog := args[1].Uint() // Get socket from the file descriptor. file := t.GetFile(fd) @@ -382,14 +379,23 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.ENOTSOCK } - // Per Linux, the backlog is silently capped to reasonable values. - if backlog <= 0 { - backlog = minListenBacklog - } if backlog > maxListenBacklog { + // Linux treats the incoming backlog as a uint with a limit defined by + // sysctl_somaxconn. + // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666 backlog = maxListenBacklog } + // Accept one more than the configured listen backlog to keep parity with + // Linux, which admits one extra connection because of the missing equality + // check here: + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937 + // + // For unix domain sockets, the following check + // https://github.com/torvalds/linux/blob/7d6beb71da3/net/unix/af_unix.c#L1293 + // also lets one extra connect through, since it compares the receive queue + // length with > backlog rather than >=.
+ backlog++ + return 0, nil, s.Listen(t, int(backlog)).ToError() } @@ -457,8 +463,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - vLen := int32(v.SizeBytes()) - if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/epoll.go b/pkg/sentry/syscalls/linux/vfs2/epoll.go index b980aa43e..047d955b6 100644 --- a/pkg/sentry/syscalls/linux/vfs2/epoll.go +++ b/pkg/sentry/syscalls/linux/vfs2/epoll.go @@ -19,6 +19,7 @@ import ( "time" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/hostarch" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" @@ -118,13 +119,7 @@ func EpollCtl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } } -// EpollWait implements Linux syscall epoll_wait(2). -func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - epfd := args[0].Int() - eventsAddr := args[1].Pointer() - maxEvents := int(args[2].Int()) - timeout := int(args[3].Int()) - +func waitEpoll(t *kernel.Task, epfd int32, eventsAddr hostarch.Addr, maxEvents int, timeoutInNanos int64) (uintptr, *kernel.SyscallControl, error) { var _EP_MAX_EVENTS = math.MaxInt32 / sizeofEpollEvent // Linux: fs/eventpoll.c:EP_MAX_EVENTS if maxEvents <= 0 || maxEvents > _EP_MAX_EVENTS { return 0, nil, syserror.EINVAL @@ -158,7 +153,7 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } return 0, nil, err } - if timeout == 0 { + if timeoutInNanos == 0 { return 0, nil, nil } // In the first iteration of this loop, register with the epoll @@ -173,8 +168,8 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys defer epfile.EventUnregister(&w) } else { // Set up the timer if a timeout was specified. - if timeout > 0 && !haveDeadline { - timeoutDur := time.Duration(timeout) * time.Millisecond + if timeoutInNanos > 0 && !haveDeadline { + timeoutDur := time.Duration(timeoutInNanos) * time.Nanosecond deadline = t.Kernel().MonotonicClock().Now().Add(timeoutDur) haveDeadline = true } @@ -186,6 +181,17 @@ func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sys } } } + +} + +// EpollWait implements Linux syscall epoll_wait(2). +func EpollWait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutInNanos := int64(args[3].Int()) * 1000000 + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) } // EpollPwait implements Linux syscall epoll_pwait(2). @@ -199,3 +205,29 @@ func EpollPwait(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return EpollWait(t, args) } + +// EpollPwait2 implements Linux syscall epoll_pwait2(2).
+func EpollPwait2(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + epfd := args[0].Int() + eventsAddr := args[1].Pointer() + maxEvents := int(args[2].Int()) + timeoutPtr := args[3].Pointer() + maskAddr := args[4].Pointer() + maskSize := uint(args[5].Uint()) + haveTimeout := timeoutPtr != 0 + + var timeoutInNanos int64 = -1 + if haveTimeout { + var timeout linux.Timespec + if _, err := timeout.CopyIn(t, timeoutPtr); err != nil { + return 0, nil, err + } + timeoutInNanos = timeout.ToNsec() + } + + if err := setTempSignalSet(t, maskAddr, maskSize); err != nil { + return 0, nil, err + } + + return waitEpoll(t, epfd, eventsAddr, maxEvents, timeoutInNanos) +} diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index a87a66146..69f69e3af 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -35,12 +35,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" ) -// minListenBacklog is the minimum reasonable backlog for listening sockets. -const minListenBacklog = 8 - -// maxListenBacklog is the maximum allowed backlog for listening sockets. -const maxListenBacklog = 1024 - // maxAddrLen is the maximum socket address length we're willing to accept. const maxAddrLen = 200 @@ -52,6 +46,9 @@ const maxOptLen = 1024 * 8 // buffers upto INT_MAX. const maxControlLen = 10 * 1024 * 1024 +// maxListenBacklog is the maximum supported listen backlog. +const maxListenBacklog = 1024 + // nameLenOffset is the offset from the start of the MessageHeader64 struct to // the NameLen field. const nameLenOffset = 8 @@ -371,7 +368,7 @@ func Bind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC // Listen implements the linux syscall listen(2). func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { fd := args[0].Int() - backlog := args[1].Int() + backlog := args[1].Uint() // Get socket from the file descriptor. file := t.GetFileVFS2(fd) @@ -386,14 +383,23 @@ func Listen(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return 0, nil, syserror.ENOTSOCK } - // Per Linux, the backlog is silently capped to reasonable values. - if backlog <= 0 { - backlog = minListenBacklog - } if backlog > maxListenBacklog { + // Linux treats the incoming backlog as a uint with a limit defined by + // sysctl_somaxconn. + // https://github.com/torvalds/linux/blob/7acac4b3196/net/socket.c#L1666 backlog = maxListenBacklog } + // Accept one more than the configured listen backlog to keep parity with + // Linux, which admits one extra connection because of the missing equality + // check here: + // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/sock.h#L937 + // + // For unix domain sockets, the following check + // https://github.com/torvalds/linux/blob/7d6beb71da3/net/unix/af_unix.c#L1293 + // also lets one extra connect through, since it compares the receive queue + // length with > backlog rather than >=.
+ backlog++ + return 0, nil, s.Listen(t, int(backlog)).ToError() } @@ -461,8 +467,7 @@ func GetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, e.ToError() } - vLen := int32(v.SizeBytes()) - if _, err := primitive.CopyInt32Out(t, optLenAddr, vLen); err != nil { + if _, err := primitive.CopyInt32Out(t, optLenAddr, int32(v.SizeBytes())); err != nil { return 0, nil, err } diff --git a/pkg/sentry/syscalls/linux/vfs2/vfs2.go b/pkg/sentry/syscalls/linux/vfs2/vfs2.go index c50fd97eb..0fc81e694 100644 --- a/pkg/sentry/syscalls/linux/vfs2/vfs2.go +++ b/pkg/sentry/syscalls/linux/vfs2/vfs2.go @@ -159,6 +159,7 @@ func Override() { s.Table[327] = syscalls.Supported("preadv2", Preadv2) s.Table[328] = syscalls.Supported("pwritev2", Pwritev2) s.Table[332] = syscalls.Supported("statx", Statx) + s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2) s.Init() // Override ARM64. @@ -269,6 +270,7 @@ func Override() { s.Table[286] = syscalls.Supported("preadv2", Preadv2) s.Table[287] = syscalls.Supported("pwritev2", Pwritev2) s.Table[291] = syscalls.Supported("statx", Statx) + s.Table[441] = syscalls.Supported("epoll_pwait2", EpollPwait2) s.Init() } diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 87d8687ce..1f617ca8f 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -32,6 +32,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "//pkg/gohacks", "//pkg/log", "//pkg/metric", "//pkg/sync", diff --git a/pkg/sentry/time/calibrated_clock.go b/pkg/sentry/time/calibrated_clock.go index f9a93115d..39bf1e0de 100644 --- a/pkg/sentry/time/calibrated_clock.go +++ b/pkg/sentry/time/calibrated_clock.go @@ -25,11 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -// fallbackMetric tracks failed updates. It is not sync, as it is not critical -// that all occurrences are captured and CalibratedClock may fallback many -// times. -var fallbackMetric = metric.MustCreateNewUint64Metric("/time/fallback", false /* sync */, "Incremented when a clock falls back to system calls due to a failed update") - // CalibratedClock implements a clock that tracks a reference clock. // // Users should call Update at regular intervals of around approxUpdateInterval @@ -102,7 +97,7 @@ func (c *CalibratedClock) resetLocked(str string, v ...interface{}) { c.Warningf(str+" Resetting clock; time may jump.", v...) 
c.ready = false c.ref.Reset() - fallbackMetric.Increment() + metric.WeirdnessMetric.Increment("time_fallback") } // updateParams updates the timekeeping parameters based on the passed diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index f612a71b2..176bcc242 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -524,7 +524,7 @@ func (fd *FileDescription) Stat(ctx context.Context, opts StatOptions) (linux.St Start: fd.vd, }) stat, err := fd.vd.mount.fs.impl.StatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, err } return fd.impl.Stat(ctx, opts) @@ -539,7 +539,7 @@ func (fd *FileDescription) SetStat(ctx context.Context, opts SetStatOptions) err Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetStatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.SetStat(ctx, opts) @@ -555,7 +555,7 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { Start: fd.vd, }) statfs, err := fd.vd.mount.fs.impl.StatFSAt(ctx, rp) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, err } return fd.impl.StatFS(ctx) @@ -701,7 +701,7 @@ func (fd *FileDescription) ListXattr(ctx context.Context, size uint64) ([]string Start: fd.vd, }) names, err := fd.vd.mount.fs.impl.ListXattrAt(ctx, rp, size) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return names, err } names, err := fd.impl.ListXattr(ctx, size) @@ -730,7 +730,7 @@ func (fd *FileDescription) GetXattr(ctx context.Context, opts *GetXattrOptions) Start: fd.vd, }) val, err := fd.vd.mount.fs.impl.GetXattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return val, err } return fd.impl.GetXattr(ctx, *opts) @@ -746,7 +746,7 @@ func (fd *FileDescription) SetXattr(ctx context.Context, opts *SetXattrOptions) Start: fd.vd, }) err := fd.vd.mount.fs.impl.SetXattrAt(ctx, rp, *opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.SetXattr(ctx, *opts) @@ -762,7 +762,7 @@ func (fd *FileDescription) RemoveXattr(ctx context.Context, name string) error { Start: fd.vd, }) err := fd.vd.mount.fs.impl.RemoveXattrAt(ctx, rp, name) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } return fd.impl.RemoveXattr(ctx, name) diff --git a/pkg/sentry/vfs/file_description_impl_util.go b/pkg/sentry/vfs/file_description_impl_util.go index 1556b41a3..b87d9690a 100644 --- a/pkg/sentry/vfs/file_description_impl_util.go +++ b/pkg/sentry/vfs/file_description_impl_util.go @@ -252,6 +252,9 @@ type WritableDynamicBytesSource interface { // are backed by a bytes.Buffer that is regenerated when necessary, consistent // with Linux's fs/seq_file.c:single_open(). // +// If data additionally implements WritableDynamicBytesSource, writes are +// dispatched to the implementer. The source data is not automatically modified. +// // DynamicBytesFileDescriptionImpl.SetDataSource() must be called before first // use. 
// diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 922f9e697..82fd382c2 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -826,6 +826,9 @@ func (vfs *VirtualFilesystem) GenerateProcMounts(ctx context.Context, taskRootDi if mnt.Flags.NoExec { opts += ",noexec" } + if mopts := mnt.fs.Impl().MountOptions(); mopts != "" { + opts += "," + mopts + } // Format: // <special device or remote filesystem> <mount point> <filesystem type> <mount options> <needs dump> <fsck order> @@ -970,17 +973,22 @@ func superBlockOpts(mountPath string, mnt *Mount) string { opts += "," + mopts } - // NOTE(b/147673608): If the mount is a cgroup, we also need to include - // the cgroup name in the options. For now we just read that from the - // path. + // NOTE(b/147673608): If the mount is a ramdisk-based fake cgroupfs, we also + // need to include the cgroup name in the options. For now we just read that + // from the path. Note that this is only possible when "cgroup" isn't + // registered as a valid filesystem type. // - // TODO(gvisor.dev/issue/190): Once gVisor has full cgroup support, we - // should get this value from the cgroup itself, and not rely on the - // path. + // TODO(gvisor.dev/issue/190): Once we have removed fake cgroupfs support, + // we should remove this. + if cgroupfs := mnt.vfs.getFilesystemType("cgroup"); cgroupfs != nil && cgroupfs.opts.AllowUserMount { + // Real cgroupfs available. + return opts + } if mnt.fs.FilesystemType().Name() == "cgroup" { splitPath := strings.Split(mountPath, "/") cgroupType := splitPath[len(splitPath)-1] opts += "," + cgroupType } + return opts } diff --git a/pkg/sentry/vfs/opath.go b/pkg/sentry/vfs/opath.go index 39fbac987..47848c76b 100644 --- a/pkg/sentry/vfs/opath.go +++ b/pkg/sentry/vfs/opath.go @@ -121,7 +121,7 @@ func (fd *opathFD) Stat(ctx context.Context, opts StatOptions) (linux.Statx, err Start: fd.vfsfd.vd, }) stat, err := fd.vfsfd.vd.mount.fs.impl.StatAt(ctx, rp, opts) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return stat, err } @@ -134,6 +134,6 @@ func (fd *opathFD) StatFS(ctx context.Context) (linux.Statfs, error) { Start: fd.vfsfd.vd, }) statfs, err := fd.vfsfd.vd.mount.fs.impl.StatFSAt(ctx, rp) - vfsObj.putResolvingPath(ctx, rp) + rp.Release(ctx) return statfs, err } diff --git a/pkg/sentry/vfs/resolving_path.go b/pkg/sentry/vfs/resolving_path.go index e4fd55012..97b898aba 100644 --- a/pkg/sentry/vfs/resolving_path.go +++ b/pkg/sentry/vfs/resolving_path.go @@ -44,13 +44,10 @@ type ResolvingPath struct { start *Dentry pit fspath.Iterator - flags uint16 - mustBeDir bool // final file must be a directory? - mustBeDirOrig bool - symlinks uint8 // number of symlinks traversed - symlinksOrig uint8 - curPart uint8 // index into parts - numOrigParts uint8 + flags uint16 + mustBeDir bool // final file must be a directory? + symlinks uint8 // number of symlinks traversed + curPart uint8 // index into parts creds *auth.Credentials @@ -60,14 +57,9 @@ type ResolvingPath struct { nextStart *Dentry // ref held if not nil absSymlinkTarget fspath.Path - // ResolvingPath must track up to two relative paths: the "current" - // relative path, which is updated whenever a relative symlink is - // encountered, and the "original" relative path, which is updated from the - // current relative path by handleError() when resolution must change - // filesystems (due to reaching a mount boundary or absolute symlink) and - // overwrites the current relative path when Restart() is called.
- parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator - origParts [1 + linux.MaxSymlinkTraversals]fspath.Iterator + // ResolvingPath tracks the current relative path, which is updated whenever + // a relative symlink is encountered. + parts [1 + linux.MaxSymlinkTraversals]fspath.Iterator } const ( @@ -120,6 +112,8 @@ var resolvingPathPool = sync.Pool{ }, } +// getResolvingPath gets a new ResolvingPath from the pool. Caller must call +// ResolvingPath.Release() when done. func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *PathOperation) *ResolvingPath { rp := resolvingPathPool.Get().(*ResolvingPath) rp.vfs = vfs @@ -132,17 +126,37 @@ func (vfs *VirtualFilesystem) getResolvingPath(creds *auth.Credentials, pop *Pat rp.flags |= rpflagsFollowFinalSymlink } rp.mustBeDir = pop.Path.Dir - rp.mustBeDirOrig = pop.Path.Dir rp.symlinks = 0 rp.curPart = 0 - rp.numOrigParts = 1 rp.creds = creds rp.parts[0] = pop.Path.Begin - rp.origParts[0] = pop.Path.Begin return rp } -func (vfs *VirtualFilesystem) putResolvingPath(ctx context.Context, rp *ResolvingPath) { +// Copy creates another ResolvingPath with the same state as the original. +// Copies are independent: using the copy does not change the original, and +// vice versa. +// +// Caller must call Release() when done. +func (rp *ResolvingPath) Copy() *ResolvingPath { + copy := resolvingPathPool.Get().(*ResolvingPath) + *copy = *rp // All fields are shallow copyable. + + // Take extra references for the copy if the original held them. + if copy.flags&rpflagsHaveStartRef != 0 { + copy.start.IncRef() + } + if copy.flags&rpflagsHaveMountRef != 0 { + copy.mount.IncRef() + } + // Reset error state. + copy.nextStart = nil + copy.nextMount = nil + return copy +} + +// Release decrements references if needed and returns the object to the pool. +func (rp *ResolvingPath) Release(ctx context.Context) { rp.root = VirtualDentry{} rp.decRefStartAndMount(ctx) rp.mount = nil @@ -240,25 +254,6 @@ func (rp *ResolvingPath) Advance() { } } -// Restart resets the stream of path components represented by rp to its state -// on entry to the current FilesystemImpl method. -func (rp *ResolvingPath) Restart(ctx context.Context) { - rp.pit = rp.origParts[rp.numOrigParts-1] - rp.mustBeDir = rp.mustBeDirOrig - rp.symlinks = rp.symlinksOrig - rp.curPart = rp.numOrigParts - 1 - copy(rp.parts[:], rp.origParts[:rp.numOrigParts]) - rp.releaseErrorState(ctx) -} - -func (rp *ResolvingPath) relpathCommit() { - rp.mustBeDirOrig = rp.mustBeDir - rp.symlinksOrig = rp.symlinks - rp.numOrigParts = rp.curPart + 1 - copy(rp.origParts[:rp.curPart], rp.parts[:]) - rp.origParts[rp.curPart] = rp.pit -} - // CheckRoot is called before resolving the parent of the Dentry d. If the // Dentry is contextually a VFS root, such that path resolution should treat // d's parent as itself, CheckRoot returns (true, nil). If the Dentry is the @@ -405,11 +400,10 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.flags |= rpflagsHaveMountRef | rpflagsHaveStartRef rp.nextMount = nil rp.nextStart = nil - // Commit the previous FileystemImpl's progress through the relative - // path. (Don't consume the path component that caused us to traverse + // Don't consume the path component that caused us to traverse // through the mount root - i.e. the ".." - because we still need to - // resolve the mount point's parent in the new FilesystemImpl.) - rp.relpathCommit() + // resolve the mount point's parent in the new FilesystemImpl. + // // Restart path resolution on the new Mount.
Don't bother calling // rp.releaseErrorState() since we already set nextMount and nextStart // to nil above. @@ -425,9 +419,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.nextMount = nil // Consume the path component that represented the mount point. rp.Advance() - // Commit the previous FilesystemImpl's progress through the relative - // path. - rp.relpathCommit() // Restart path resolution on the new Mount. rp.releaseErrorState(ctx) return true @@ -442,9 +433,6 @@ func (rp *ResolvingPath) handleError(ctx context.Context, err error) bool { rp.Advance() // Prepend the symlink target to the relative path. rp.relpathPrepend(rp.absSymlinkTarget) - // Commit the previous FilesystemImpl's progress through the relative - // path, including the symlink target we just prepended. - rp.relpathCommit() // Restart path resolution on the new Mount. rp.releaseErrorState(ctx) return true diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 00f1847d8..87fdcf403 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -208,11 +208,11 @@ func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credenti for { err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -230,11 +230,11 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede dentry: d, } rp.mount.IncRef() - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return vd, nil } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return VirtualDentry{}, err } } @@ -252,7 +252,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au } rp.mount.IncRef() name := rp.Component() - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return parentVD, name, nil } if checkInvariants { @@ -261,7 +261,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return VirtualDentry{}, "", err } } @@ -292,7 +292,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential for { err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldVD.DecRef(ctx) return nil } @@ -302,7 +302,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) oldVD.DecRef(ctx) return err } @@ -331,7 +331,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -340,7 +340,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return err } } @@ -366,7 +366,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia for { err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts) if err == nil { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) return nil } if checkInvariants { @@ -375,7 +375,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia } } if !rp.handleError(ctx, err) { - vfs.putResolvingPath(ctx, rp) + rp.Release(ctx) 
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 00f1847d8..87fdcf403 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -208,11 +208,11 @@ func (vfs *VirtualFilesystem) AccessAt(ctx context.Context, creds *auth.Credenti
 	for {
 		err := rp.mount.fs.impl.AccessAt(ctx, rp, creds, ats)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return err
 		}
 	}
@@ -230,11 +230,11 @@ func (vfs *VirtualFilesystem) GetDentryAt(ctx context.Context, creds *auth.Crede
 				dentry: d,
 			}
 			rp.mount.IncRef()
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return vd, nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return VirtualDentry{}, err
 		}
 	}
@@ -252,7 +252,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au
 			}
 			rp.mount.IncRef()
 			name := rp.Component()
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return parentVD, name, nil
 		}
 		if checkInvariants {
@@ -261,7 +261,7 @@ func (vfs *VirtualFilesystem) getParentDirAndName(ctx context.Context, creds *au
 			}
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return VirtualDentry{}, "", err
 		}
 	}
@@ -292,7 +292,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
 	for {
 		err := rp.mount.fs.impl.LinkAt(ctx, rp, oldVD)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			oldVD.DecRef(ctx)
 			return nil
 		}
@@ -302,7 +302,7 @@ func (vfs *VirtualFilesystem) LinkAt(ctx context.Context, creds *auth.Credential
 			}
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			oldVD.DecRef(ctx)
 			return err
 		}
@@ -331,7 +331,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
 	for {
 		err := rp.mount.fs.impl.MkdirAt(ctx, rp, *opts)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil
 		}
 		if checkInvariants {
@@ -340,7 +340,7 @@ func (vfs *VirtualFilesystem) MkdirAt(ctx context.Context, creds *auth.Credentia
 			}
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return err
 		}
 	}
@@ -366,7 +366,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
 	for {
 		err := rp.mount.fs.impl.MknodAt(ctx, rp, *opts)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil
 		}
 		if checkInvariants {
@@ -375,7 +375,7 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
 			}
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return err
 		}
 	}
@@ -425,7 +425,6 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
 	rp := vfs.getResolvingPath(creds, pop)
 	if opts.Flags&linux.O_DIRECTORY != 0 {
 		rp.mustBeDir = true
-		rp.mustBeDirOrig = true
 	}
 	// Ignore O_PATH for verity, as verity performs extra operations on the fd for verification.
 	// The underlying filesystem that verity wraps opens the fd with O_PATH.
@@ -444,7 +443,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
 	for {
 		fd, err := rp.mount.fs.impl.OpenAt(ctx, rp, *opts)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)

 			if opts.FileExec {
 				if fd.Mount().Flags.NoExec {
@@ -468,7 +467,7 @@ func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credential
 			return fd, nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil, err
 		}
 	}
@@ -480,11 +479,11 @@ func (vfs *VirtualFilesystem) ReadlinkAt(ctx context.Context, creds *auth.Creden
 	for {
 		target, err := rp.mount.fs.impl.ReadlinkAt(ctx, rp)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return target, nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return "", err
 		}
 	}
@@ -533,7 +532,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
 	for {
 		err := rp.mount.fs.impl.RenameAt(ctx, rp, oldParentVD, oldName, renameOpts)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			oldParentVD.DecRef(ctx)
 			return nil
 		}
@@ -543,7 +542,7 @@ func (vfs *VirtualFilesystem) RenameAt(ctx context.Context, creds *auth.Credenti
 			}
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			oldParentVD.DecRef(ctx)
 			return err
 		}
@@ -569,7 +568,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
 	for {
 		err := rp.mount.fs.impl.RmdirAt(ctx, rp)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil
 		}
 		if checkInvariants {
@@ -578,7 +577,7 @@ func (vfs *VirtualFilesystem) RmdirAt(ctx context.Context, creds *auth.Credentia
 			}
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return err
 		}
 	}
@@ -590,11 +589,11 @@ func (vfs *VirtualFilesystem) SetStatAt(ctx context.Context, creds *auth.Credent
 	for {
 		err := rp.mount.fs.impl.SetStatAt(ctx, rp, *opts)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return err
 		}
 	}
@@ -606,11 +605,11 @@ func (vfs *VirtualFilesystem) StatAt(ctx context.Context, creds *auth.Credential
 	for {
 		stat, err := rp.mount.fs.impl.StatAt(ctx, rp, *opts)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return stat, nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return linux.Statx{}, err
 		}
 	}
@@ -623,11 +622,11 @@ func (vfs *VirtualFilesystem) StatFSAt(ctx context.Context, creds *auth.Credenti
 	for {
 		statfs, err := rp.mount.fs.impl.StatFSAt(ctx, rp)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return statfs, nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return linux.Statfs{}, err
 		}
 	}
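A caller-side sketch of the contract these wrappers keep: the ResolvingPath and rp.Release(ctx) stay internal to the loop, and callers only manage the returned references. The names vfsObj and pop are assumptions for illustration, not part of the change:

	// vfsObj is a *vfs.VirtualFilesystem; pop is a *vfs.PathOperation
	// built by the caller.
	vd, err := vfsObj.GetDentryAt(ctx, creds, pop, &vfs.GetDentryOptions{})
	if err != nil {
		return err
	}
	// GetDentryAt returns with references on the mount and dentry held
	// (rp.mount.IncRef() in the hunk above); drop them when done.
	defer vd.DecRef(ctx)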
@@ -652,7 +651,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
 	for {
 		err := rp.mount.fs.impl.SymlinkAt(ctx, rp, target)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil
 		}
 		if checkInvariants {
@@ -661,7 +660,7 @@ func (vfs *VirtualFilesystem) SymlinkAt(ctx context.Context, creds *auth.Credent
 			}
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return err
 		}
 	}
@@ -686,7 +685,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
 	for {
 		err := rp.mount.fs.impl.UnlinkAt(ctx, rp)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil
 		}
 		if checkInvariants {
@@ -695,7 +694,7 @@ func (vfs *VirtualFilesystem) UnlinkAt(ctx context.Context, creds *auth.Credenti
 			}
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return err
 		}
 	}
@@ -707,7 +706,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C
 	for {
 		bep, err := rp.mount.fs.impl.BoundEndpointAt(ctx, rp, *opts)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return bep, nil
 		}
 		if checkInvariants {
@@ -716,7 +715,7 @@ func (vfs *VirtualFilesystem) BoundEndpointAt(ctx context.Context, creds *auth.C
 			}
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil, err
 		}
 	}
@@ -729,7 +728,7 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede
 	for {
 		names, err := rp.mount.fs.impl.ListXattrAt(ctx, rp, size)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return names, nil
 		}
 		if err == syserror.ENOTSUP {
@@ -737,11 +736,11 @@ func (vfs *VirtualFilesystem) ListXattrAt(ctx context.Context, creds *auth.Crede
 			// fs/xattr.c:vfs_listxattr() falls back to allowing the security
 			// subsystem to return security extended attributes, which by
 			// default don't exist.
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil, nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil, err
 		}
 	}
@@ -754,11 +753,11 @@ func (vfs *VirtualFilesystem) GetXattrAt(ctx context.Context, creds *auth.Creden
 	for {
 		val, err := rp.mount.fs.impl.GetXattrAt(ctx, rp, *opts)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return val, nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return "", err
 		}
 	}
@@ -771,11 +770,11 @@ func (vfs *VirtualFilesystem) SetXattrAt(ctx context.Context, creds *auth.Creden
 	for {
 		err := rp.mount.fs.impl.SetXattrAt(ctx, rp, *opts)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return err
 		}
 	}
@@ -787,11 +786,11 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre
 	for {
 		err := rp.mount.fs.impl.RemoveXattrAt(ctx, rp, name)
 		if err == nil {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return nil
 		}
 		if !rp.handleError(ctx, err) {
-			vfs.putResolvingPath(ctx, rp)
+			rp.Release(ctx)
 			return err
 		}
 	}
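The ListXattrAt change above preserves the ENOTSUP fallback modeled on Linux's fs/xattr.c:vfs_listxattr(): a filesystem without xattr support yields an empty list rather than an error. A caller-side sketch of the observable behavior (vfsObj, pop, and size are assumed to be in scope):

	names, err := vfsObj.ListXattrAt(ctx, creds, pop, size)
	if err != nil {
		return err
	}
	// names is empty both when no attributes exist and when the underlying
	// filesystem returned ENOTSUP; callers cannot distinguish the two here.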