diff options
Diffstat (limited to 'pkg')
47 files changed, 1569 insertions, 505 deletions
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 61790221b..3f39edb66 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -215,6 +215,8 @@ type AtomicRefCount struct { // LeakMode configures the leak checker. type LeakMode uint32 +// TODO(gvisor.dev/issue/1624): Simplify down to two modes once vfs1 ref +// counting is gone. const ( // UninitializedLeakChecking indicates that the leak checker has not yet been initialized. UninitializedLeakChecking LeakMode = iota @@ -244,6 +246,11 @@ func SetLeakMode(mode LeakMode) { atomic.StoreUint32(&leakMode, uint32(mode)) } +// GetLeakMode returns the current leak mode. +func GetLeakMode() LeakMode { + return LeakMode(atomic.LoadUint32(&leakMode)) +} + const maxStackFrames = 40 type fileLine struct { diff --git a/pkg/refs_vfs2/BUILD b/pkg/refs_vfs2/BUILD new file mode 100644 index 000000000..7f180c7bd --- /dev/null +++ b/pkg/refs_vfs2/BUILD @@ -0,0 +1,28 @@ +load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template") + +package(licenses = ["notice"]) + +go_template( + name = "refs_template", + srcs = [ + "refs_template.go", + ], + types = [ + "T", + ], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/log", + "//pkg/refs", + ], +) + +go_library( + name = "refs", + srcs = [ + "refs.go", + ], + visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/context"], +) diff --git a/pkg/refs_vfs2/refs.go b/pkg/refs_vfs2/refs.go new file mode 100644 index 000000000..ee01b17b0 --- /dev/null +++ b/pkg/refs_vfs2/refs.go @@ -0,0 +1,36 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package refs defines an interface for a reference-counted object. +package refs + +import ( + "gvisor.dev/gvisor/pkg/context" +) + +// RefCounter is the interface to be implemented by objects that are reference +// counted. +type RefCounter interface { + // IncRef increments the reference counter on the object. + IncRef() + + // DecRef decrements the object's reference count. Users of refs_template.Refs + // may specify a destructor to be called once the reference count reaches zero. + DecRef(ctx context.Context) + + // TryIncRef attempts to increment the reference count, but may fail if all + // references have already been dropped, in which case it returns false. If + // true is returned, then a valid reference is now held on the object. + TryIncRef() bool +} diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refs_vfs2/refs_template.go new file mode 100644 index 000000000..3e5b458c7 --- /dev/null +++ b/pkg/refs_vfs2/refs_template.go @@ -0,0 +1,133 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package refs_template defines a template that can be used by reference counted +// objects. +package refs_template + +import ( + "runtime" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/log" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" +) + +// T is the type of the reference counted object. It is only used to customize +// debug output when leak checking. +type T interface{} + +// ownerType is used to customize logging. Note that we use a pointer to T so +// that we do not copy the entire object when passed as a format parameter. +var ownerType *T + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// Note that the number of references is actually refCount + 1 so that a default +// zero-value Refs object contains one reference. +// +// +stateify savable +type Refs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +func (r *Refs) finalize() { + var note string + switch refs_vfs1.GetLeakMode() { + case refs_vfs1.NoLeakChecking: + return + case refs_vfs1.UninitializedLeakChecking: + note = "(Leak checker uninitialized): " + } + if n := r.ReadRefs(); n != 0 { + log.Warningf("%sAtomicRefCount %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n) + } +} + +// EnableLeakCheck checks for reference leaks when Refs gets garbage collected. +func (r *Refs) EnableLeakCheck() { + if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { + runtime.SetFinalizer(r, (*Refs).finalize) + } +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *Refs) ReadRefs() int64 { + // Account for the internal -1 offset on refcounts. + return atomic.LoadInt64(&r.refCount) + 1 +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *Refs) IncRef() { + if v := atomic.AddInt64(&r.refCount, 1); v <= 0 { + panic("Incrementing non-positive ref count") + } +} + +// TryIncRef implements refs.RefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *Refs) TryIncRef() bool { + const speculativeRef = 1 << 32 + v := atomic.AddInt64(&r.refCount, speculativeRef) + if int32(v) < 0 { + // This object has already been freed. + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + // Turn into a real reference. + atomic.AddInt64(&r.refCount, -speculativeRef+1) + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *Refs) DecRef(destroy func()) { + switch v := atomic.AddInt64(&r.refCount, -1); { + case v < -1: + panic("Decrementing non-positive ref count") + + case v == -1: + // Call the destructor. + if destroy != nil { + destroy() + } + } +} diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 663e51989..2bf3c45e1 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -49,6 +49,9 @@ type ProfileOpts struct { // - dump out the stack trace of current go routines. // sentryctl -pid <pid> pprof-goroutine type Profile struct { + // Kernel is the kernel under profile. It's immutable. + Kernel *kernel.Kernel + // mu protects the fields below. mu sync.Mutex @@ -57,9 +60,6 @@ type Profile struct { // traceFile is the current execution trace output file. traceFile *fd.FD - - // Kernel is the kernel under profile. - Kernel *kernel.Kernel } // StartCPUProfile is an RPC stub which starts recording the CPU profile in a diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index 852ec3c5c..a40625e19 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -160,8 +160,8 @@ func (fd *tunFD) EventUnregister(e *waiter.Entry) { fd.device.EventUnregister(e) } -// isNetTunSupported returns whether /dev/net/tun device is supported for s. -func isNetTunSupported(s inet.Stack) bool { +// IsNetTunSupported returns whether /dev/net/tun device is supported for s. +func IsNetTunSupported(s inet.Stack) bool { _, ok := s.(*netstack.Stack) return ok } diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index 702fdd392..8615b60f0 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -272,6 +272,96 @@ func (f *tcpSackFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSeque return n, f.tcpSack.stack.SetTCPSACKEnabled(*f.tcpSack.enabled) } +// +stateify savable +type tcpRecovery struct { + fsutil.SimpleFileInode + + stack inet.Stack `state:"wait"` + recovery inet.TCPLossRecovery +} + +func newTCPRecoveryInode(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { + ts := &tcpRecovery{ + SimpleFileInode: *fsutil.NewSimpleFileInode(ctx, fs.RootOwner, fs.FilePermsFromMode(0644), linux.PROC_SUPER_MAGIC), + stack: s, + } + sattr := fs.StableAttr{ + DeviceID: device.ProcDevice.DeviceID(), + InodeID: device.ProcDevice.NextIno(), + BlockSize: usermem.PageSize, + Type: fs.SpecialFile, + } + return fs.NewInode(ctx, ts, msrc, sattr) +} + +// Truncate implements fs.InodeOperations.Truncate. +func (*tcpRecovery) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// GetFile implements fs.InodeOperations.GetFile. +func (r *tcpRecovery) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + flags.Pread = true + flags.Pwrite = true + return fs.NewFile(ctx, dirent, flags, &tcpRecoveryFile{ + tcpRecovery: r, + stack: r.stack, + }), nil +} + +// +stateify savable +type tcpRecoveryFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + tcpRecovery *tcpRecovery + + stack inet.Stack `state:"wait"` +} + +// Read implements fs.FileOperations.Read. +func (f *tcpRecoveryFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + return 0, io.EOF + } + + recovery, err := f.stack.TCPRecovery() + if err != nil { + return 0, err + } + f.tcpRecovery.recovery = recovery + s := fmt.Sprintf("%d\n", f.tcpRecovery.recovery) + n, err := dst.CopyOut(ctx, []byte(s)) + return int64(n), err +} + +// Write implements fs.FileOperations.Write. +func (f *tcpRecoveryFile) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, offset int64) (int64, error) { + if src.NumBytes() == 0 { + return 0, nil + } + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + f.tcpRecovery.recovery = inet.TCPLossRecovery(v) + if err := f.tcpRecovery.stack.SetTCPRecovery(f.tcpRecovery.recovery); err != nil { + return 0, err + } + return n, nil +} + func (p *proc) newSysNetCore(ctx context.Context, msrc *fs.MountSource, s inet.Stack) *fs.Inode { // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the @@ -351,6 +441,11 @@ func (p *proc) newSysNetIPv4Dir(ctx context.Context, msrc *fs.MountSource, s ine contents["tcp_wmem"] = newTCPMemInode(ctx, msrc, s, tcpWMem) } + // Add tcp_recovery. + if _, err := s.TCPRecovery(); err == nil { + contents["tcp_recovery"] = newTCPRecoveryInode(ctx, msrc, s) + } + d := ramfs.NewDir(ctx, contents, fs.RootOwner, fs.FilePermsFromMode(0555)) return newProcInode(ctx, d, msrc, fs.SpecialDirectory, nil) } diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 4a800dcf9..16787116f 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -85,5 +85,6 @@ go_test( deps = [ "//pkg/p9", "//pkg/sentry/contexttest", + "//pkg/sentry/pgalloc", ], ) diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 1679066ba..2a8011eb4 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -90,10 +90,8 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { uid: uint32(opts.kuid), gid: uint32(opts.kgid), blockSize: usermem.PageSize, // arbitrary - handle: handle{ - fd: -1, - }, - nlink: uint32(2), + hostFD: -1, + nlink: uint32(2), } switch opts.mode.FileType() { case linux.S_IFDIR: @@ -205,14 +203,14 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { off := uint64(0) const count = 64 * 1024 // for consistency with the vfs1 client d.handleMu.RLock() - if !d.handleReadable { + if d.readFile.isNil() { // This should not be possible because a readable handle should // have been opened when the calling directoryFD was opened. d.handleMu.RUnlock() panic("gofer.dentry.getDirents called without a readable handle") } for { - p9ds, err := d.handle.file.readdir(ctx, off, count) + p9ds, err := d.readFile.readdir(ctx, off, count) if err != nil { d.handleMu.RUnlock() return nil, err @@ -304,5 +302,5 @@ func (fd *directoryFD) Seek(ctx context.Context, offset int64, whence int32) (in // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *directoryFD) Sync(ctx context.Context) error { - return fd.dentry().handle.sync(ctx) + return fd.dentry().syncRemoteFile(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index e6af37d0d..eaef2594d 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -15,6 +15,7 @@ package gofer import ( + "math" "sync" "sync/atomic" @@ -54,7 +55,7 @@ func (fs *filesystem) Sync(ctx context.Context) error { // Sync regular files. for _, d := range ds { - err := d.syncSharedHandle(ctx) + err := d.syncCachedFile(ctx) d.DecRef(ctx) if err != nil && retErr == nil { retErr = err @@ -724,8 +725,29 @@ func (fs *filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. if rp.Mount() != vd.Mount() { return syserror.EXDEV } - // 9P2000.L supports hard links, but we don't. - return syserror.EPERM + d := vd.Dentry().Impl().(*dentry) + if d.isDir() { + return syserror.EPERM + } + gid := auth.KGID(atomic.LoadUint32(&d.gid)) + uid := auth.KUID(atomic.LoadUint32(&d.uid)) + mode := linux.FileMode(atomic.LoadUint32(&d.mode)) + if err := vfs.MayLink(rp.Credentials(), mode, uid, gid); err != nil { + return err + } + if d.nlink == 0 { + return syserror.ENOENT + } + if d.nlink == math.MaxUint32 { + return syserror.EMLINK + } + if err := parent.file.link(ctx, d.file, childName); err != nil { + return err + } + + // Success! + atomic.AddUint32(&d.nlink, 1) + return nil }, nil) } @@ -1085,12 +1107,18 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving useRegularFileFD := child.fileType() == linux.S_IFREG && !d.fs.opts.regularFilesUseSpecialFileFD if useRegularFileFD { child.handleMu.Lock() - child.handle.file = openFile - if fdobj != nil { - child.handle.fd = int32(fdobj.Release()) + if vfs.MayReadFileWithOpenFlags(opts.Flags) { + child.readFile = openFile + if fdobj != nil { + child.hostFD = int32(fdobj.Release()) + } + } else if fdobj != nil { + // Can't use fdobj if it's not readable. + fdobj.Close() + } + if vfs.MayWriteFileWithOpenFlags(opts.Flags) { + child.writeFile = openFile } - child.handleReadable = vfs.MayReadFileWithOpenFlags(opts.Flags) - child.handleWritable = vfs.MayWriteFileWithOpenFlags(opts.Flags) child.handleMu.Unlock() } // Insert the dentry into the tree. diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 2e5575d8d..63e589859 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -192,10 +192,14 @@ const ( // // - File timestamps are based on client clocks. This ensures that users of // the client observe timestamps that are coherent with their own clocks - // and consistent with Linux's semantics. However, since it is not always - // possible for clients to set arbitrary atimes and mtimes, and never - // possible for clients to set arbitrary ctimes, file timestamp changes are - // stored in the client only and never sent to the remote filesystem. + // and consistent with Linux's semantics (in particular, it is not always + // possible for clients to set arbitrary atimes and mtimes depending on the + // remote filesystem implementation, and never possible for clients to set + // arbitrary ctimes.) If a dentry containing a client-defined atime or + // mtime is evicted from cache, client timestamps will be sent to the + // remote filesystem on a best-effort basis to attempt to ensure that + // timestamps will be preserved when another dentry representing the same + // file is instantiated. InteropModeExclusive InteropMode = iota // InteropModeWritethrough is appropriate when there are read-only users of @@ -502,9 +506,9 @@ func (fs *filesystem) Release(ctx context.Context) { for d := range fs.syncableDentries { d.handleMu.Lock() d.dataMu.Lock() - if d.handleWritable { + if h := d.writeHandleLocked(); h.isOpen() { // Write dirty cached data to the remote file. - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil { log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err) } // TODO(jamieliu): Do we need to flushf/fsync d? @@ -514,9 +518,9 @@ func (fs *filesystem) Release(ctx context.Context) { d.dirty.RemoveAll() d.dataMu.Unlock() // Close the host fd if one exists. - if d.handle.fd >= 0 { - syscall.Close(int(d.handle.fd)) - d.handle.fd = -1 + if d.hostFD >= 0 { + syscall.Close(int(d.hostFD)) + d.hostFD = -1 } d.handleMu.Unlock() } @@ -558,8 +562,6 @@ type dentry struct { // filesystem.renameMu. name string - // We don't support hard links, so each dentry maps 1:1 to an inode. - // file is the unopened p9.File that backs this dentry. file is immutable. // // If file.isNil(), this dentry represents a synthetic file, i.e. a file @@ -620,9 +622,20 @@ type dentry struct { mtime int64 ctime int64 btime int64 - // File size, protected by both metadataMu and dataMu (i.e. both must be + // File size, which differs from other metadata in two ways: + // + // - We make a best-effort attempt to keep it up to date even if + // !dentry.cachedMetadataAuthoritative() for the sake of O_APPEND writes. + // + // - size is protected by both metadataMu and dataMu (i.e. both must be // locked to mutate it; locking either is sufficient to access it). size uint64 + // If this dentry does not represent a synthetic file, deleted is 0, and + // atimeDirty/mtimeDirty are non-zero, atime/mtime may have diverged from the + // remote file's timestamps, which should be updated when this dentry is + // evicted. + atimeDirty uint32 + mtimeDirty uint32 // nlink counts the number of hard links to this dentry. It's updated and // accessed using atomic operations. It's not protected by metadataMu like the @@ -635,30 +648,28 @@ type dentry struct { // the file into memmap.MappingSpaces. mappings is protected by mapsMu. mappings memmap.MappingSet - // If this dentry represents a regular file or directory: - // - // - handle is the I/O handle used by all regularFileFDs/directoryFDs - // representing this dentry. - // - // - handleReadable is true if handle is readable. + // - If this dentry represents a regular file or directory, readFile is the + // p9.File used for reads by all regularFileFDs/directoryFDs representing + // this dentry. // - // - handleWritable is true if handle is writable. + // - If this dentry represents a regular file, writeFile is the p9.File + // used for writes by all regularFileFDs representing this dentry. // - // Invariants: - // - // - If handleReadable == handleWritable == false, then handle.file == nil - // (i.e. there is no open handle). Conversely, if handleReadable || - // handleWritable == true, then handle.file != nil (i.e. there is an open - // handle). - // - // - handleReadable and handleWritable cannot transition from true to false - // (i.e. handles may not be downgraded). + // - If this dentry represents a regular file, hostFD is the host FD used + // for memory mappings and I/O (when applicable) in preference to readFile + // and writeFile. hostFD is always readable; if !writeFile.isNil(), it must + // also be writable. If hostFD is -1, no such host FD is available. // // These fields are protected by handleMu. - handleMu sync.RWMutex - handle handle - handleReadable bool - handleWritable bool + // + // readFile and writeFile may or may not represent the same p9.File. Once + // either p9.File transitions from closed (isNil() == true) to open + // (isNil() == false), it may be mutated with handleMu locked, but cannot + // be closed until the dentry is destroyed. + handleMu sync.RWMutex + readFile p9file + writeFile p9file + hostFD int32 dataMu sync.RWMutex @@ -672,7 +683,7 @@ type dentry struct { // tracks dirty segments in cache. dirty is protected by dataMu. dirty fsutil.DirtySet - // pf implements platform.File for mappings of handle.fd. + // pf implements platform.File for mappings of hostFD. pf dentryPlatformFile // If this dentry represents a symbolic link, InteropModeShared is not in @@ -734,9 +745,7 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), blockSize: usermem.PageSize, - handle: handle{ - fd: -1, - }, + hostFD: -1, } d.pf.dentry = d if mask.UID { @@ -803,10 +812,12 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { if attr.BlockSize != 0 { atomic.StoreUint32(&d.blockSize, uint32(attr.BlockSize)) } - if mask.ATime { + // Don't override newer client-defined timestamps with old server-defined + // ones. + if mask.ATime && atomic.LoadUint32(&d.atimeDirty) == 0 { atomic.StoreInt64(&d.atime, dentryTimestampFromP9(attr.ATimeSeconds, attr.ATimeNanoSeconds)) } - if mask.MTime { + if mask.MTime && atomic.LoadUint32(&d.mtimeDirty) == 0 { atomic.StoreInt64(&d.mtime, dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds)) } if mask.CTime { @@ -825,9 +836,13 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { // Preconditions: !d.isSynthetic() func (d *dentry) updateFromGetattr(ctx context.Context) error { - // Use d.handle.file, which represents a 9P fid that has been opened, in - // preference to d.file, which represents a 9P fid that has not. This may - // be significantly more efficient in some implementations. + // Use d.readFile or d.writeFile, which represent 9P fids that have been + // opened, in preference to d.file, which represents a 9P fid that has not. + // This may be significantly more efficient in some implementations. Prefer + // d.writeFile over d.readFile since some filesystem implementations may + // update a writable handle's metadata after writes to that handle, without + // making metadata updates immediately visible to read-only handles + // representing the same file. var ( file p9file handleMuRLocked bool @@ -837,8 +852,11 @@ func (d *dentry) updateFromGetattr(ctx context.Context) error { d.metadataMu.Lock() defer d.metadataMu.Unlock() d.handleMu.RLock() - if !d.handle.file.isNil() { - file = d.handle.file + if !d.writeFile.isNil() { + file = d.writeFile + handleMuRLocked = true + } else if !d.readFile.isNil() { + file = d.readFile handleMuRLocked = true } else { file = d.file @@ -903,51 +921,44 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return err } defer mnt.EndWrite() - setLocalAtime := false - setLocalMtime := false + + if stat.Mask&linux.STATX_SIZE != 0 { + // Reject attempts to truncate files other than regular files, since + // filesystem implementations may return the wrong errno. + switch mode.FileType() { + case linux.S_IFREG: + // ok + case linux.S_IFDIR: + return syserror.EISDIR + default: + return syserror.EINVAL + } + } + + var now int64 if d.cachedMetadataAuthoritative() { - // Timestamp updates will be handled locally. - setLocalAtime = stat.Mask&linux.STATX_ATIME != 0 - setLocalMtime = stat.Mask&linux.STATX_MTIME != 0 - stat.Mask &^= linux.STATX_ATIME | linux.STATX_MTIME - - // Prepare for truncate. - if stat.Mask&linux.STATX_SIZE != 0 { - switch mode.FileType() { - case linux.ModeRegular: - if !setLocalMtime { - // Truncate updates mtime. - setLocalMtime = true - stat.Mtime.Nsec = linux.UTIME_NOW - } - case linux.ModeDirectory: - return syserror.EISDIR - default: - return syserror.EINVAL + // Truncate updates mtime. + if stat.Mask&(linux.STATX_SIZE|linux.STATX_MTIME) == linux.STATX_SIZE { + stat.Mask |= linux.STATX_MTIME + stat.Mtime = linux.StatxTimestamp{ + Nsec: linux.UTIME_NOW, } } + + // Use client clocks for timestamps. + now = d.fs.clock.Now().Nanoseconds() + if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW { + stat.Atime = statxTimestampFromDentry(now) + } + if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW { + stat.Mtime = statxTimestampFromDentry(now) + } } + d.metadataMu.Lock() defer d.metadataMu.Unlock() - if stat.Mask&linux.STATX_SIZE != 0 { - // The size needs to be changed even when - // !d.cachedMetadataAuthoritative() because d.mappings has to be - // updated. - d.updateFileSizeLocked(stat.Size) - } if !d.isSynthetic() { if stat.Mask != 0 { - if stat.Mask&linux.STATX_SIZE != 0 { - // Check whether to allow a truncate request to be made. - switch d.mode & linux.S_IFMT { - case linux.S_IFREG: - // Allow. - case linux.S_IFDIR: - return syserror.EISDIR - default: - return syserror.EINVAL - } - } if err := d.file.setAttr(ctx, p9.SetAttrMask{ Permissions: stat.Mask&linux.STATX_MODE != 0, UID: stat.Mask&linux.STATX_UID != 0, @@ -969,6 +980,12 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs }); err != nil { return err } + if stat.Mask&linux.STATX_SIZE != 0 { + // d.size should be kept up to date, and privatized + // copy-on-write mappings of truncated pages need to be + // invalidated, even if InteropModeShared is in effect. + d.updateFileSizeLocked(stat.Size) + } } if d.fs.opts.interop == InteropModeShared { // There's no point to updating d's metadata in this case since @@ -978,7 +995,6 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs return nil } } - now := d.fs.clock.Now().Nanoseconds() if stat.Mask&linux.STATX_MODE != 0 { atomic.StoreUint32(&d.mode, d.fileType()|uint32(stat.Mode)) } @@ -988,23 +1004,18 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs if stat.Mask&linux.STATX_GID != 0 { atomic.StoreUint32(&d.gid, stat.GID) } - if setLocalAtime { - if stat.Atime.Nsec == linux.UTIME_NOW { - atomic.StoreInt64(&d.atime, now) - } else { - atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) - } - // Restore mask bits that we cleared earlier. - stat.Mask |= linux.STATX_ATIME + // Note that stat.Atime.Nsec and stat.Mtime.Nsec can't be UTIME_NOW because + // if d.cachedMetadataAuthoritative() then we converted stat.Atime and + // stat.Mtime to client-local timestamps above, and if + // !d.cachedMetadataAuthoritative() then we returned after calling + // d.file.setAttr(). For the same reason, now must have been initialized. + if stat.Mask&linux.STATX_ATIME != 0 { + atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) + atomic.StoreUint32(&d.atimeDirty, 0) } - if setLocalMtime { - if stat.Mtime.Nsec == linux.UTIME_NOW { - atomic.StoreInt64(&d.mtime, now) - } else { - atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) - } - // Restore mask bits that we cleared earlier. - stat.Mask |= linux.STATX_MTIME + if stat.Mask&linux.STATX_MTIME != 0 { + atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) + atomic.StoreUint32(&d.mtimeDirty, 0) } atomic.StoreInt64(&d.ctime, now) return nil @@ -1147,7 +1158,8 @@ func (d *dentry) OnZeroWatches(ctx context.Context) { // operation. One of the calls may destroy the dentry, so subsequent calls will // do nothing. // -// Preconditions: d.fs.renameMu must be locked for writing. +// Preconditions: d.fs.renameMu must be locked for writing; it may be +// temporarily unlocked. func (d *dentry) checkCachingLocked(ctx context.Context) { // Dentries with a non-zero reference count must be retained. (The only way // to obtain a reference on a dentry with zero references is via path @@ -1227,11 +1239,13 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { } } -// destroyLocked destroys the dentry. It may flushes dirty pages from cache, -// close p9 file and remove reference on parent dentry. +// destroyLocked destroys the dentry. // -// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. d is -// not a child dentry. +// Preconditions: +// * d.fs.renameMu must be locked for writing; it may be temporarily unlocked. +// * d.refs == 0. +// * d.parent.children[d.name] != d, i.e. d is not reachable by path traversal +// from its former parent dentry. func (d *dentry) destroyLocked(ctx context.Context) { switch atomic.LoadInt64(&d.refs) { case 0: @@ -1243,26 +1257,64 @@ func (d *dentry) destroyLocked(ctx context.Context) { panic("dentry.destroyLocked() called with references on the dentry") } + // Allow the following to proceed without renameMu locked to improve + // scalability. + d.fs.renameMu.Unlock() + + mf := d.fs.mfp.MemoryFile() d.handleMu.Lock() - if !d.handle.file.isNil() { - mf := d.fs.mfp.MemoryFile() - d.dataMu.Lock() + d.dataMu.Lock() + if h := d.writeHandleLocked(); h.isOpen() { // Write dirty pages back to the remote filesystem. - if d.handleWritable { - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil { - log.Warningf("gofer.dentry.DecRef: failed to write dirty data back: %v", err) - } + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { + log.Warningf("gofer.dentry.destroyLocked: failed to write dirty data back: %v", err) } - // Discard cached data. + } + // Discard cached data. + if !d.cache.IsEmpty() { + mf.MarkAllUnevictable(d) d.cache.DropAll(mf) d.dirty.RemoveAll() - d.dataMu.Unlock() - // Clunk open fids and close open host FDs. - d.handle.close(ctx) + } + d.dataMu.Unlock() + // Clunk open fids and close open host FDs. + if !d.readFile.isNil() { + d.readFile.close(ctx) + } + if !d.writeFile.isNil() && d.readFile != d.writeFile { + d.writeFile.close(ctx) + } + d.readFile = p9file{} + d.writeFile = p9file{} + if d.hostFD >= 0 { + syscall.Close(int(d.hostFD)) + d.hostFD = -1 } d.handleMu.Unlock() if !d.file.isNil() { + if !d.isDeleted() { + // Write dirty timestamps back to the remote filesystem. + atimeDirty := atomic.LoadUint32(&d.atimeDirty) != 0 + mtimeDirty := atomic.LoadUint32(&d.mtimeDirty) != 0 + if atimeDirty || mtimeDirty { + atime := atomic.LoadInt64(&d.atime) + mtime := atomic.LoadInt64(&d.mtime) + if err := d.file.setAttr(ctx, p9.SetAttrMask{ + ATime: atimeDirty, + ATimeNotSystemTime: atimeDirty, + MTime: mtimeDirty, + MTimeNotSystemTime: mtimeDirty, + }, p9.SetAttr{ + ATimeSeconds: uint64(atime / 1e9), + ATimeNanoSeconds: uint64(atime % 1e9), + MTimeSeconds: uint64(mtime / 1e9), + MTimeNanoSeconds: uint64(mtime % 1e9), + }); err != nil { + log.Warningf("gofer.dentry.destroyLocked: failed to write dirty timestamps back: %v", err) + } + } + } d.file.close(ctx) d.file = p9file{} // Remove d from the set of syncable dentries. @@ -1270,6 +1322,9 @@ func (d *dentry) destroyLocked(ctx context.Context) { delete(d.fs.syncableDentries, d) d.fs.syncMu.Unlock() } + + d.fs.renameMu.Lock() + // Drop the reference held by d on its parent without recursively locking // d.fs.renameMu. if d.parent != nil { @@ -1369,80 +1424,120 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // O_TRUNC). if !trunc { d.handleMu.RLock() - if (!read || d.handleReadable) && (!write || d.handleWritable) { - // The current handle is sufficient. + if (!read || !d.readFile.isNil()) && (!write || !d.writeFile.isNil()) { + // Current handles are sufficient. d.handleMu.RUnlock() return nil } d.handleMu.RUnlock() } - haveOldFD := false + fdToClose := int32(-1) + invalidateTranslations := false d.handleMu.Lock() - if (read && !d.handleReadable) || (write && !d.handleWritable) || trunc { - // Get a new handle. - wantReadable := d.handleReadable || read - wantWritable := d.handleWritable || write - h, err := openHandle(ctx, d.file, wantReadable, wantWritable, trunc) + if (read && d.readFile.isNil()) || (write && d.writeFile.isNil()) || trunc { + // Get a new handle. If this file has been opened for both reading and + // writing, try to get a single handle that is usable for both: + // + // - Writable memory mappings of a host FD require that the host FD is + // opened for both reading and writing. + // + // - NOTE(b/141991141): Some filesystems may not ensure coherence + // between multiple handles for the same file. + openReadable := !d.readFile.isNil() || read + openWritable := !d.writeFile.isNil() || write + h, err := openHandle(ctx, d.file, openReadable, openWritable, trunc) + if err == syserror.EACCES && (openReadable != read || openWritable != write) { + // It may not be possible to use a single handle for both + // reading and writing, since permissions on the file may have + // changed to e.g. disallow reading after previously being + // opened for reading. In this case, we have no choice but to + // use separate handles for reading and writing. + ctx.Debugf("gofer.dentry.ensureSharedHandle: bifurcating read/write handles for dentry %p", d) + openReadable = read + openWritable = write + h, err = openHandle(ctx, d.file, openReadable, openWritable, trunc) + } if err != nil { d.handleMu.Unlock() return err } - if !d.handle.file.isNil() { - // Check that old and new handles are compatible: If the old handle - // includes a host file descriptor but the new one does not, or - // vice versa, old and new memory mappings may be incoherent. - haveOldFD = d.handle.fd >= 0 - haveNewFD := h.fd >= 0 - if haveOldFD != haveNewFD { - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: can't change host FD availability from %v to %v across dentry handle upgrade", haveOldFD, haveNewFD) - h.close(ctx) - return syserror.EIO - } - if haveOldFD { - // We may have raced with callers of d.pf.FD() that are now - // using the old file descriptor, preventing us from safely - // closing it. We could handle this by invalidating existing - // memmap.Translations, but this is expensive. Instead, use - // dup3 to make the old file descriptor refer to the new file - // description, then close the new file descriptor (which is no - // longer needed). Racing callers may use the old or new file - // description, but this doesn't matter since they refer to the - // same file (unless d.fs.opts.overlayfsStaleRead is true, - // which we handle separately). - if err := syscall.Dup3(int(h.fd), int(d.handle.fd), syscall.O_CLOEXEC); err != nil { + + if d.hostFD < 0 && openReadable && h.fd >= 0 { + // We have no existing FD; use the new FD for at least reading. + d.hostFD = h.fd + } else if d.hostFD >= 0 && d.writeFile.isNil() && openWritable { + // We have an existing read-only FD, but the file has just been + // opened for writing, so we need to start supporting writable memory + // mappings. This may race with callers of d.pf.FD() using the existing + // FD, so in most cases we need to delay closing the old FD until after + // invalidating memmap.Translations that might have observed it. + if !openReadable || h.fd < 0 { + // We don't have a read/write FD, so we have no FD that can be + // used to create writable memory mappings. Switch to using the + // internal page cache. + invalidateTranslations = true + fdToClose = d.hostFD + d.hostFD = -1 + } else if d.fs.opts.overlayfsStaleRead { + // We do have a read/write FD, but it may not be coherent with + // the existing read-only FD, so we must switch to mappings of + // the new FD in both the application and sentry. + if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil { d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, d.handle.fd, err) + ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err) h.close(ctx) return err } - syscall.Close(int(h.fd)) - h.fd = d.handle.fd - if d.fs.opts.overlayfsStaleRead { - // Replace sentry mappings of the old FD with mappings of - // the new FD, since the two are not necessarily coherent. - if err := d.pf.hostFileMapper.RegenerateMappings(int(h.fd)); err != nil { - d.handleMu.Unlock() - ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to replace sentry mappings of old FD with mappings of new FD: %v", err) - h.close(ctx) - return err - } + invalidateTranslations = true + fdToClose = d.hostFD + d.hostFD = h.fd + } else { + // We do have a read/write FD. To avoid invalidating existing + // memmap.Translations (which is expensive), use dup3 to make + // the old file descriptor refer to the new file description, + // then close the new file descriptor (which is no longer + // needed). Racing callers of d.pf.FD() may use the old or new + // file description, but this doesn't matter since they refer + // to the same file, and any racing mappings must be read-only. + if err := syscall.Dup3(int(h.fd), int(d.hostFD), syscall.O_CLOEXEC); err != nil { + oldHostFD := d.hostFD + d.handleMu.Unlock() + ctx.Warningf("gofer.dentry.ensureSharedHandle: failed to dup fd %d to fd %d: %v", h.fd, oldHostFD, err) + h.close(ctx) + return err } - // Clunk the old fid before making the new handle visible (by - // unlocking d.handleMu). - d.handle.file.close(ctx) + fdToClose = h.fd } + } else { + // h.fd is not useful. + fdToClose = h.fd + } + + // Switch to new fids. + var oldReadFile p9file + if openReadable { + oldReadFile = d.readFile + d.readFile = h.file + } + var oldWriteFile p9file + if openWritable { + oldWriteFile = d.writeFile + d.writeFile = h.file + } + // NOTE(b/141991141): Clunk old fids before making new fids visible (by + // unlocking d.handleMu). + if !oldReadFile.isNil() { + oldReadFile.close(ctx) + } + if !oldWriteFile.isNil() && oldReadFile != oldWriteFile { + oldWriteFile.close(ctx) } - // Switch to the new handle. - d.handle = h - d.handleReadable = wantReadable - d.handleWritable = wantWritable } d.handleMu.Unlock() - if d.fs.opts.overlayfsStaleRead && haveOldFD { - // Invalidate application mappings that may be using the old FD; they + if invalidateTranslations { + // Invalidate application mappings that may be using an old FD; they // will be replaced with mappings using the new FD after future calls // to d.Translate(). This requires holding d.mapsMu, which precedes // d.handleMu in the lock order. @@ -1450,7 +1545,51 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool d.mappings.InvalidateAll(memmap.InvalidateOpts{}) d.mapsMu.Unlock() } + if fdToClose >= 0 { + syscall.Close(int(fdToClose)) + } + + return nil +} + +// Preconditions: d.handleMu must be locked. +func (d *dentry) readHandleLocked() handle { + return handle{ + file: d.readFile, + fd: d.hostFD, + } +} +// Preconditions: d.handleMu must be locked. +func (d *dentry) writeHandleLocked() handle { + return handle{ + file: d.writeFile, + fd: d.hostFD, + } +} + +func (d *dentry) syncRemoteFile(ctx context.Context) error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + return d.syncRemoteFileLocked(ctx) +} + +// Preconditions: d.handleMu must be locked. +func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { + // If we have a host FD, fsyncing it is likely to be faster than an fsync + // RPC. + if d.hostFD >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(d.hostFD)) + ctx.UninterruptibleSleepFinish(false) + return err + } + if !d.writeFile.isNil() { + return d.writeFile.fsync(ctx) + } + if !d.readFile.isNil() { + return d.readFile.fsync(ctx) + } return nil } diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index 56d80bcf8..bfe75dfe4 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -20,10 +20,13 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/contexttest" + "gvisor.dev/gvisor/pkg/sentry/pgalloc" ) func TestDestroyIdempotent(t *testing.T) { + ctx := contexttest.Context(t) fs := filesystem{ + mfp: pgalloc.MemoryFileProviderFromContext(ctx), syncableDentries: make(map[*dentry]struct{}), opts: filesystemOptions{ // Test relies on no dentry being held in the cache. @@ -31,7 +34,6 @@ func TestDestroyIdempotent(t *testing.T) { }, } - ctx := contexttest.Context(t) attr := &p9.Attr{ Mode: p9.ModeRegular, } @@ -50,6 +52,8 @@ func TestDestroyIdempotent(t *testing.T) { } parent.cacheNewChildLocked(child, "child") + fs.renameMu.Lock() + defer fs.renameMu.Unlock() child.checkCachingLocked(ctx) if got := atomic.LoadInt64(&child.refs); got != -1 { t.Fatalf("child.refs=%d, want: -1", got) diff --git a/pkg/sentry/fsimpl/gofer/handle.go b/pkg/sentry/fsimpl/gofer/handle.go index 8792ca4f2..104157512 100644 --- a/pkg/sentry/fsimpl/gofer/handle.go +++ b/pkg/sentry/fsimpl/gofer/handle.go @@ -63,6 +63,10 @@ func openHandle(ctx context.Context, file p9file, read, write, trunc bool) (hand }, nil } +func (h *handle) isOpen() bool { + return !h.file.isNil() +} + func (h *handle) close(ctx context.Context) { h.file.close(ctx) h.file = p9file{} @@ -124,18 +128,3 @@ func (h *handle) writeFromBlocksAt(ctx context.Context, srcs safemem.BlockSeq, o } return cp, cperr } - -func (h *handle) sync(ctx context.Context) error { - // Handle most common case first. - if h.fd >= 0 { - ctx.UninterruptibleSleepStart(false) - err := syscall.Fsync(int(h.fd)) - ctx.UninterruptibleSleepFinish(false) - return err - } - if h.file.isNil() { - // File hasn't been touched, there is nothing to sync. - return nil - } - return h.file.fsync(ctx) -} diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index db6bed4f6..7e1cbf065 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -64,34 +64,34 @@ func (fd *regularFileFD) OnClose(ctx context.Context) error { } d.handleMu.RLock() defer d.handleMu.RUnlock() - return d.handle.file.flush(ctx) + if d.writeFile.isNil() { + return nil + } + return d.writeFile.flush(ctx) } // Allocate implements vfs.FileDescriptionImpl.Allocate. func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint64) error { - d := fd.dentry() d.metadataMu.Lock() defer d.metadataMu.Unlock() - size := offset + length - // Allocating a smaller size is a noop. - if size <= d.size { + size := offset + length + if d.cachedMetadataAuthoritative() && size <= d.size { return nil } - d.handleMu.Lock() - defer d.handleMu.Unlock() - - err := d.handle.file.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + d.handleMu.RLock() + err := d.writeFile.allocate(ctx, p9.ToAllocateMode(mode), offset, length) + d.handleMu.RUnlock() if err != nil { return err } d.dataMu.Lock() atomic.StoreUint64(&d.size, size) d.dataMu.Unlock() - if !d.cachedMetadataAuthoritative() { + if d.cachedMetadataAuthoritative() { d.touchCMtimeLocked() } return nil @@ -113,7 +113,7 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // Check for reading at EOF before calling into MM (but not under // InteropModeShared, which makes d.size unreliable). d := fd.dentry() - if d.fs.opts.interop != InteropModeShared && uint64(offset) >= atomic.LoadUint64(&d.size) { + if d.cachedMetadataAuthoritative() && uint64(offset) >= atomic.LoadUint64(&d.size) { return 0, io.EOF } @@ -217,16 +217,23 @@ func (fd *regularFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off n, err := src.CopyInTo(ctx, rw) if err != nil { - return n, offset, err + return n, offset + n, err } if n > 0 && fd.vfsfd.StatusFlags()&(linux.O_DSYNC|linux.O_SYNC) != 0 { - // Write dirty cached pages touched by the write back to the remote file. + // Note that if any of the following fail, then we can't guarantee that + // any data was actually written with the semantics of O_DSYNC or + // O_SYNC, so we return zero bytes written. Compare Linux's + // mm/filemap.c:generic_file_write_iter() => + // include/linux/fs.h:generic_write_sync(). + // + // Write dirty cached pages touched by the write back to the remote + // file. if err := d.writeback(ctx, offset, src.NumBytes()); err != nil { - return n, offset, err + return 0, offset, err } // Request the remote filesystem to sync the remote file. - if err := d.handle.sync(ctx); err != nil { - return n, offset, err + if err := d.syncRemoteFile(ctx); err != nil { + return 0, offset, err } } return n, offset + n, nil @@ -317,10 +324,11 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) // coherence with memory-mapped I/O), or if InteropModeShared is in effect // (which prevents us from caching file contents and makes dentry.size // unreliable), or if the file was opened O_DIRECT, read directly from - // dentry.handle without locking dentry.dataMu. + // dentry.readHandleLocked() without locking dentry.dataMu. rw.d.handleMu.RLock() - if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { - n, err := rw.d.handle.readToBlocksAt(rw.ctx, dsts, rw.off) + h := rw.d.readHandleLocked() + if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { + n, err := h.readToBlocksAt(rw.ctx, dsts, rw.off) rw.d.handleMu.RUnlock() rw.off += n return n, err @@ -388,7 +396,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) End: gapEnd, } optMR := gap.Range() - err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, rw.d.handle.readToBlocksAt) + err := rw.d.cache.Fill(rw.ctx, reqMR, maxFillRange(reqMR, optMR), mf, usage.PageCache, h.readToBlocksAt) mf.MarkEvictable(rw.d, pgalloc.EvictableRange{optMR.Start, optMR.End}) seg, gap = rw.d.cache.Find(rw.off) if !seg.Ok() { @@ -403,7 +411,7 @@ func (rw *dentryReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) } else { // Read directly from the file. gapDsts := dsts.TakeFirst64(gapMR.Length()) - n, err := rw.d.handle.readToBlocksAt(rw.ctx, gapDsts, gapMR.Start) + n, err := h.readToBlocksAt(rw.ctx, gapDsts, gapMR.Start) done += n rw.off += n dsts = dsts.DropFirst64(n) @@ -435,11 +443,12 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro // If we have a mmappable host FD (which must be used here to ensure // coherence with memory-mapped I/O), or if InteropModeShared is in effect // (which prevents us from caching file contents), or if the file was - // opened with O_DIRECT, write directly to dentry.handle without locking - // dentry.dataMu. + // opened with O_DIRECT, write directly to dentry.writeHandleLocked() + // without locking dentry.dataMu. rw.d.handleMu.RLock() - if (rw.d.handle.fd >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { - n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, srcs, rw.off) + h := rw.d.writeHandleLocked() + if (rw.d.hostFD >= 0 && !rw.d.fs.opts.forcePageCache) || rw.d.fs.opts.interop == InteropModeShared || rw.direct { + n, err := h.writeFromBlocksAt(rw.ctx, srcs, rw.off) rw.off += n rw.d.dataMu.Lock() if rw.off > rw.d.size { @@ -501,7 +510,7 @@ func (rw *dentryReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, erro // for detecting or avoiding this. gapMR := gap.Range().Intersect(mr) gapSrcs := srcs.TakeFirst64(gapMR.Length()) - n, err := rw.d.handle.writeFromBlocksAt(rw.ctx, gapSrcs, gapMR.Start) + n, err := h.writeFromBlocksAt(rw.ctx, gapSrcs, gapMR.Start) done += n rw.off += n srcs = srcs.DropFirst64(n) @@ -527,7 +536,7 @@ exitLoop: if err := fsutil.SyncDirty(rw.ctx, memmap.MappableRange{ Start: start, End: rw.off, - }, &rw.d.cache, &rw.d.dirty, rw.d.size, mf, rw.d.handle.writeFromBlocksAt); err != nil { + }, &rw.d.cache, &rw.d.dirty, rw.d.size, mf, h.writeFromBlocksAt); err != nil { // We have no idea how many bytes were actually flushed. rw.off = start done = 0 @@ -545,6 +554,7 @@ func (d *dentry) writeback(ctx context.Context, offset, size int64) error { } d.handleMu.RLock() defer d.handleMu.RUnlock() + h := d.writeHandleLocked() d.dataMu.Lock() defer d.dataMu.Unlock() // Compute the range of valid bytes (overflow-checked). @@ -558,7 +568,7 @@ func (d *dentry) writeback(ctx context.Context, offset, size int64) error { return fsutil.SyncDirty(ctx, memmap.MappableRange{ Start: uint64(offset), End: uint64(end), - }, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) + }, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) } // Seek implements vfs.FileDescriptionImpl.Seek. @@ -615,24 +625,23 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6 // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *regularFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncSharedHandle(ctx) + return fd.dentry().syncCachedFile(ctx) } -func (d *dentry) syncSharedHandle(ctx context.Context) error { +func (d *dentry) syncCachedFile(ctx context.Context) error { d.handleMu.RLock() defer d.handleMu.RUnlock() - if d.handleWritable { + if h := d.writeHandleLocked(); h.isOpen() { d.dataMu.Lock() // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), d.handle.writeFromBlocksAt) + err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) d.dataMu.Unlock() if err != nil { return err } } - // Sync the remote file. - return d.handle.sync(ctx) + return d.syncRemoteFileLocked(ctx) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. @@ -656,7 +665,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt return syserror.ENODEV } d.handleMu.RLock() - haveFD := d.handle.fd >= 0 + haveFD := d.hostFD >= 0 d.handleMu.RUnlock() if !haveFD { return syserror.ENODEV @@ -677,7 +686,7 @@ func (d *dentry) mayCachePages() bool { return true } d.handleMu.RLock() - haveFD := d.handle.fd >= 0 + haveFD := d.hostFD >= 0 d.handleMu.RUnlock() return haveFD } @@ -735,7 +744,7 @@ func (d *dentry) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, // Translate implements memmap.Mappable.Translate. func (d *dentry) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { d.handleMu.RLock() - if d.handle.fd >= 0 && !d.fs.opts.forcePageCache { + if d.hostFD >= 0 && !d.fs.opts.forcePageCache { d.handleMu.RUnlock() mr := optional if d.fs.opts.limitHostFDTranslation { @@ -771,7 +780,8 @@ func (d *dentry) Translate(ctx context.Context, required, optional memmap.Mappab } mf := d.fs.mfp.MemoryFile() - cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, d.handle.readToBlocksAt) + h := d.readHandleLocked() + cerr := d.cache.Fill(ctx, required, maxFillRange(required, optional), mf, usage.PageCache, h.readToBlocksAt) var ts []memmap.Translation var translatedEnd uint64 @@ -840,9 +850,12 @@ func (d *dentry) InvalidateUnsavable(ctx context.Context) error { // Write the cache's contents back to the remote file so that if we have a // host fd after restore, the remote file's contents are coherent. mf := d.fs.mfp.MemoryFile() + d.handleMu.RLock() + defer d.handleMu.RUnlock() + h := d.writeHandleLocked() d.dataMu.Lock() defer d.dataMu.Unlock() - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { return err } @@ -857,20 +870,23 @@ func (d *dentry) InvalidateUnsavable(ctx context.Context) error { // Evict implements pgalloc.EvictableMemoryUser.Evict. func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { + mr := memmap.MappableRange{er.Start, er.End} + mf := d.fs.mfp.MemoryFile() d.mapsMu.Lock() defer d.mapsMu.Unlock() + d.handleMu.RLock() + defer d.handleMu.RUnlock() + h := d.writeHandleLocked() d.dataMu.Lock() defer d.dataMu.Unlock() - mr := memmap.MappableRange{er.Start, er.End} - mf := d.fs.mfp.MemoryFile() // Only allow pages that are no longer memory-mapped to be evicted. for mgap := d.mappings.LowerBoundGap(mr.Start); mgap.Ok() && mgap.Start() < mr.End; mgap = mgap.NextGap() { mgapMR := mgap.Range().Intersect(mr) if mgapMR.Length() == 0 { continue } - if err := fsutil.SyncDirty(ctx, mgapMR, &d.cache, &d.dirty, d.size, mf, d.handle.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirty(ctx, mgapMR, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { log.Warningf("Failed to writeback cached data %v: %v", mgapMR, err) } d.cache.Drop(mgapMR, mf) @@ -882,8 +898,8 @@ func (d *dentry) Evict(ctx context.Context, er pgalloc.EvictableRange) { // cannot implement both vfs.DentryImpl.IncRef and memmap.File.IncRef. // // dentryPlatformFile is only used when a host FD representing the remote file -// is available (i.e. dentry.handle.fd >= 0), and that FD is used for -// application memory mappings (i.e. !filesystem.opts.forcePageCache). +// is available (i.e. dentry.hostFD >= 0), and that FD is used for application +// memory mappings (i.e. !filesystem.opts.forcePageCache). type dentryPlatformFile struct { *dentry @@ -891,8 +907,8 @@ type dentryPlatformFile struct { // by dentry.dataMu. fdRefs fsutil.FrameRefSet - // If this dentry represents a regular file, and handle.fd >= 0, - // hostFileMapper caches mappings of handle.fd. + // If this dentry represents a regular file, and dentry.hostFD >= 0, + // hostFileMapper caches mappings of dentry.hostFD. hostFileMapper fsutil.HostFileMapper // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. @@ -916,15 +932,13 @@ func (d *dentryPlatformFile) DecRef(fr memmap.FileRange) { // MapInternal implements memmap.File.MapInternal. func (d *dentryPlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { d.handleMu.RLock() - bs, err := d.hostFileMapper.MapInternal(fr, int(d.handle.fd), at.Write) - d.handleMu.RUnlock() - return bs, err + defer d.handleMu.RUnlock() + return d.hostFileMapper.MapInternal(fr, int(d.hostFD), at.Write) } // FD implements memmap.File.FD. func (d *dentryPlatformFile) FD() int { d.handleMu.RLock() - fd := d.handle.fd - d.handleMu.RUnlock() - return int(fd) + defer d.handleMu.RUnlock() + return int(d.hostFD) } diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index fc269ef2b..a6368fdd0 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -17,6 +17,7 @@ package gofer import ( "sync" "sync/atomic" + "syscall" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -279,5 +280,13 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncSharedHandle(ctx) + // If we have a host FD, fsyncing it is likely to be faster than an fsync + // RPC. + if fd.handle.fd >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(fd.handle.fd)) + ctx.UninterruptibleSleepFinish(false) + return err + } + return fd.handle.file.fsync(ctx) } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 0eef4e16e..2cb8191b9 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -47,6 +47,7 @@ func (d *dentry) touchAtime(mnt *vfs.Mount) { now := d.fs.clock.Now().Nanoseconds() d.metadataMu.Lock() atomic.StoreInt64(&d.atime, now) + atomic.StoreUint32(&d.atimeDirty, 1) d.metadataMu.Unlock() mnt.EndWrite() } @@ -67,6 +68,7 @@ func (d *dentry) touchCMtime() { d.metadataMu.Lock() atomic.StoreInt64(&d.mtime, now) atomic.StoreInt64(&d.ctime, now) + atomic.StoreUint32(&d.mtimeDirty, 1) d.metadataMu.Unlock() } @@ -76,4 +78,5 @@ func (d *dentry) touchCMtimeLocked() { now := d.fs.clock.Now().Nanoseconds() atomic.StoreInt64(&d.mtime, now) atomic.StoreInt64(&d.ctime, now) + atomic.StoreUint32(&d.mtimeDirty, 1) } diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 6dac2afa4..b71778128 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -55,7 +55,8 @@ func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) *ke if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]*kernfs.Dentry{ "ipv4": kernfs.NewStaticDir(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, map[string]*kernfs.Dentry{ - "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}), + "tcp_recovery": fs.newDentry(root, fs.NextIno(), 0644, &tcpRecoveryData{stack: stack}), + "tcp_sack": fs.newDentry(root, fs.NextIno(), 0644, &tcpSackData{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the @@ -207,3 +208,49 @@ func (d *tcpSackData) Write(ctx context.Context, src usermem.IOSequence, offset *d.enabled = v != 0 return n, d.stack.SetTCPSACKEnabled(*d.enabled) } + +// tcpRecoveryData implements vfs.WritableDynamicBytesSource for +// /proc/sys/net/ipv4/tcp_recovery. +// +// +stateify savable +type tcpRecoveryData struct { + kernfs.DynamicBytesFile + + stack inet.Stack `state:"wait"` +} + +var _ vfs.WritableDynamicBytesSource = (*tcpRecoveryData)(nil) + +// Generate implements vfs.DynamicBytesSource. +func (d *tcpRecoveryData) Generate(ctx context.Context, buf *bytes.Buffer) error { + recovery, err := d.stack.TCPRecovery() + if err != nil { + return err + } + + buf.WriteString(fmt.Sprintf("%d\n", recovery)) + return nil +} + +func (d *tcpRecoveryData) Write(ctx context.Context, src usermem.IOSequence, offset int64) (int64, error) { + if offset != 0 { + // No need to handle partial writes thus far. + return 0, syserror.EINVAL + } + if src.NumBytes() == 0 { + return 0, nil + } + + // Limit the amount of memory allocated. + src = src.TakeFirst(usermem.PageSize - 1) + + var v int32 + n, err := usermem.CopyInt32StringInVec(ctx, src.IO, src.Addrs, &v, src.Opts) + if err != nil { + return 0, err + } + if err := d.stack.SetTCPRecovery(inet.TCPLossRecovery(v)); err != nil { + return 0, err + } + return n, nil +} diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index fb77f95cc..065812065 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -566,7 +566,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if replaced != nil { newParentDir.removeChildLocked(replaced) if replaced.inode.isDir() { - newParentDir.inode.decLinksLocked(ctx) // from replaced's ".." + // Remove links for replaced/. and replaced/.. + replaced.inode.decLinksLocked(ctx) + newParentDir.inode.decLinksLocked(ctx) } replaced.inode.decLinksLocked(ctx) } diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 2916a0644..c0b4831d1 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -56,6 +56,12 @@ type Stack interface { // settings. SetTCPSACKEnabled(enabled bool) error + // TCPRecovery returns the TCP loss detection algorithm. + TCPRecovery() (TCPLossRecovery, error) + + // SetTCPRecovery attempts to change TCP loss detection algorithm. + SetTCPRecovery(recovery TCPLossRecovery) error + // Statistics reports stack statistics. Statistics(stat interface{}, arg string) error @@ -189,3 +195,14 @@ type StatSNMPUDP [8]uint64 // StatSNMPUDPLite describes UdpLite line of /proc/net/snmp. type StatSNMPUDPLite [8]uint64 + +// TCPLossRecovery indicates TCP loss detection and recovery methods to use. +type TCPLossRecovery int32 + +// Loss recovery constants from include/net/tcp.h which are used to set +// /proc/sys/net/ipv4/tcp_recovery. +const ( + TCP_RACK_LOSS_DETECTION TCPLossRecovery = 1 << iota + TCP_RACK_STATIC_REO_WND + TCP_RACK_NO_DUPTHRESH +) diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index d8961fc94..9771f01fc 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -25,6 +25,7 @@ type TestStack struct { TCPRecvBufSize TCPBufferSize TCPSendBufSize TCPBufferSize TCPSACKFlag bool + Recovery TCPLossRecovery } // NewTestStack returns a TestStack with no network interfaces. The value of @@ -91,6 +92,17 @@ func (s *TestStack) SetTCPSACKEnabled(enabled bool) error { return nil } +// TCPRecovery implements Stack.TCPRecovery. +func (s *TestStack) TCPRecovery() (TCPLossRecovery, error) { + return s.Recovery, nil +} + +// SetTCPRecovery implements Stack.SetTCPRecovery. +func (s *TestStack) SetTCPRecovery(recovery TCPLossRecovery) error { + s.Recovery = recovery + return nil +} + // Statistics implements inet.Stack.Statistics. func (s *TestStack) Statistics(stat interface{}, arg string) error { return nil diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index dba563160..ed5ae03d3 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -49,7 +49,7 @@ func bluepillArchEnter(context *arch.SignalContext64) (c *vCPU) { regs.Sp = context.Sp regs.Pc = context.Pc regs.Pstate = context.Pstate - regs.Pstate &^= uint64(ring0.KernelFlagsClear) + regs.Pstate &^= uint64(ring0.PsrFlagsClear) regs.Pstate |= ring0.KernelFlagsSet return } @@ -63,7 +63,7 @@ func bluepillArchExit(c *vCPU, context *arch.SignalContext64) { context.Sp = regs.Sp context.Pc = regs.Pc context.Pstate = regs.Pstate - context.Pstate &^= uint64(ring0.UserFlagsClear) + context.Pstate &^= uint64(ring0.PsrFlagsClear) context.Pstate |= ring0.UserFlagsSet lazyVfp := c.GetLazyVFP() diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index 8b64f3a1e..b35c930e2 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -41,7 +41,7 @@ func fpsimdPtr(addr *byte) *arch.FpsimdContext { func dieArchSetup(c *vCPU, context *arch.SignalContext64, guestRegs *userRegs) { // If the vCPU is in user mode, we set the stack to the stored stack // value in the vCPU itself. We don't want to unwind the user stack. - if guestRegs.Regs.Pstate&ring0.PSR_MODE_MASK == ring0.PSR_MODE_EL0t { + if guestRegs.Regs.Pstate&ring0.PsrModeMask == ring0.UserFlagsSet { regs := c.CPU.Registers() context.Regs[0] = regs.Regs[0] context.Sp = regs.Sp diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index ff8c068c0..307a7645f 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -151,12 +151,6 @@ func (c *vCPU) initArchState() error { // the MMIO address base. arm64HypercallMMIOBase = toLocation - data = ring0.PsrDefaultSet | ring0.KernelFlagsSet - reg.id = _KVM_ARM64_REGS_PSTATE - if err := c.setOneRegister(®); err != nil { - return err - } - // Initialize the PCID database. if hasGuestPCID { // Note that NewPCIDs may return a nil table here, in which diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go index 8122ac6e2..87a573cc4 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -26,30 +26,31 @@ const ( _PMD_PGT_SIZE = 0x4000 _PTE_PGT_BASE = 0x7000 _PTE_PGT_SIZE = 0x1000 - - _PSR_D_BIT = 0x00000200 - _PSR_A_BIT = 0x00000100 - _PSR_I_BIT = 0x00000080 - _PSR_F_BIT = 0x00000040 ) const ( - // PSR bits - PSR_MODE_EL0t = 0x00000000 - PSR_MODE_EL1t = 0x00000004 - PSR_MODE_EL1h = 0x00000005 - PSR_MODE_MASK = 0x0000000f + // DAIF bits:debug, sError, IRQ, FIQ. + _PSR_D_BIT = 0x00000200 + _PSR_A_BIT = 0x00000100 + _PSR_I_BIT = 0x00000080 + _PSR_F_BIT = 0x00000040 + _PSR_DAIF_SHIFT = 6 + _PSR_DAIF_MASK = 0xf << _PSR_DAIF_SHIFT - // KernelFlagsSet should always be set in the kernel. - KernelFlagsSet = PSR_MODE_EL1h + // PSR bits. + _PSR_MODE_EL0t = 0x00000000 + _PSR_MODE_EL1t = 0x00000004 + _PSR_MODE_EL1h = 0x00000005 + _PSR_MODE_MASK = 0x0000000f - // UserFlagsSet are always set in userspace. - UserFlagsSet = PSR_MODE_EL0t + PsrFlagsClear = _PSR_MODE_MASK | _PSR_DAIF_MASK + PsrModeMask = _PSR_MODE_MASK - KernelFlagsClear = PSR_MODE_MASK - UserFlagsClear = PSR_MODE_MASK + // KernelFlagsSet should always be set in the kernel. + KernelFlagsSet = _PSR_MODE_EL1h | _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT - PsrDefaultSet = _PSR_D_BIT | _PSR_A_BIT | _PSR_I_BIT | _PSR_F_BIT + // UserFlagsSet are always set in userspace. + UserFlagsSet = _PSR_MODE_EL0t ) // Vector is an exception vector. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 9fd02d628..d8a7bc2f9 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -312,12 +312,6 @@ ISB $15; \ DSB $15; -#define IRQ_ENABLE \ - MSR $2, DAIFSet; - -#define IRQ_DISABLE \ - MSR $2, DAIFClr; - #define VFP_ENABLE \ MOVD $FPEN_ENABLE, R0; \ WORD $0xd5181040; \ //MSR R0, CPACR_EL1 @@ -509,8 +503,6 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 // Start is the CPU entrypoint. TEXT ·Start(SB),NOSPLIT,$0 - IRQ_DISABLE - // Init. MOVD $SCTLR_EL1_DEFAULT, R1 MSR R1, SCTLR_EL1 diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index d483ff03c..42009dac0 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -56,7 +56,7 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { // Sanitize registers. regs := switchOpts.Registers - regs.Pstate &= ^uint64(UserFlagsClear) + regs.Pstate &= ^uint64(PsrFlagsClear) regs.Pstate |= UserFlagsSet LoadFloatingPoint(switchOpts.FloatingPointState) diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index a48082631..fda3dcb35 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -53,6 +53,7 @@ type Stack struct { interfaceAddrs map[int32][]inet.InterfaceAddr routes []inet.Route supportsIPv6 bool + tcpRecovery inet.TCPLossRecovery tcpRecvBufSize inet.TCPBufferSize tcpSendBufSize inet.TCPBufferSize tcpSACKEnabled bool @@ -350,6 +351,16 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + return s.tcpRecovery, nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + return syserror.EACCES +} + // getLine reads one line from proc file, with specified prefix. // The last argument, withHeader, specifies if it contains line header. func getLine(f *os.File, prefix string, withHeader bool) string { diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 67737ae87..f0fe18684 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -207,6 +207,20 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() } +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + var recovery tcp.Recovery + if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil { + return 0, syserr.TranslateNetstackError(err).ToError() + } + return inet.TCPLossRecovery(recovery), nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError() +} + // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { switch stats := stat.(type) { diff --git a/pkg/sentry/syscalls/linux/vfs2/memfd.go b/pkg/sentry/syscalls/linux/vfs2/memfd.go index 519583e4e..c4c0f9e0a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/memfd.go +++ b/pkg/sentry/syscalls/linux/vfs2/memfd.go @@ -51,6 +51,7 @@ func MemfdCreate(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S if err != nil { return 0, nil, err } + defer file.DecRef(t) fd, err := t.NewFDFromVFS2(0, file, kernel.FDFlags{ CloseOnExec: cloExec, diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 576ab3920..d3c1197e3 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -356,6 +356,8 @@ type FileDescriptionImpl interface { // Allocate grows the file to offset + length bytes. // Only mode == 0 is supported currently. + // + // Preconditions: The FileDescription was opened for writing. Allocate(ctx context.Context, mode, offset, length uint64) error // waiter.Waitable methods may be used to poll for I/O events. @@ -565,6 +567,9 @@ func (fd *FileDescription) StatFS(ctx context.Context) (linux.Statfs, error) { // Allocate grows file represented by FileDescription to offset + length bytes. func (fd *FileDescription) Allocate(ctx context.Context, mode, offset, length uint64) error { + if !fd.IsWritable() { + return syserror.EBADF + } return fd.impl.Allocate(ctx, mode, offset, length) } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 62ac932bb..d0d1efd0d 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -101,6 +101,11 @@ const ( // IPv4Version is the version of the ipv4 protocol. IPv4Version = 4 + // IPv4AllSystems is the all systems IPv4 multicast address as per + // IANA's IPv4 Multicast Address Space Registry. See + // https://www.iana.org/assignments/multicast-addresses/multicast-addresses.xhtml. + IPv4AllSystems tcpip.Address = "\xe0\x00\x00\x01" + // IPv4Broadcast is the broadcast address of the IPv4 procotol. IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff" diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go index 644ba7c33..5d286ccbc 100644 --- a/pkg/tcpip/stack/ndp_test.go +++ b/pkg/tcpip/stack/ndp_test.go @@ -1689,13 +1689,7 @@ func containsV6Addr(list []tcpip.ProtocolAddress, item tcpip.AddressWithPrefix) AddressWithPrefix: item, } - for _, i := range list { - if i == protocolAddress { - return true - } - } - - return false + return containsAddr(list, protocolAddress) } // TestNoAutoGenAddr tests that SLAAC is not performed when configured not to. diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 1d37716c2..27e1feec0 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -115,17 +115,15 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // channel is returned for the top level caller to block. Channel is closed // once address resolution is complete (success or not). func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { - if linkRes != nil { - if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { - e := NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), - } - return e, nil, nil + if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { + e := NeighborEntry{ + Addr: remoteAddr, + LocalAddr: localAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAt: time.Now(), } + return e, nil, nil } entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) @@ -289,8 +287,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { // HandleProbe implements NUDHandler.HandleProbe by following the logic defined // in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled // by the caller. -func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) { - entry := n.getOrCreateEntry(remoteAddr, localAddr, nil) +func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index 4cb2c9c6b..b4fa69e3e 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -335,32 +335,6 @@ func TestNeighborCacheEntry(t *testing.T) { } } -// TestNeighborCacheEntryNoLinkAddress verifies calling entry() without a -// LinkAddressResolver returns ErrNoLinkAddress. -func TestNeighborCacheEntryNoLinkAddress(t *testing.T) { - nudDisp := testNUDDispatcher{} - c := DefaultNUDConfigurations() - clock := newFakeClock() - neigh := newTestNeighborCache(&nudDisp, c, clock) - store := newTestEntryStore() - - entry, ok := store.entry(0) - if !ok { - t.Fatalf("store.entry(0) not found") - } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, nil, nil) - if err != tcpip.ErrNoLinkAddress { - t.Errorf("got neigh.entry(%s, %s, nil, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) - } - - // No events should have been dispatched. - nudDisp.mu.Lock() - defer nudDisp.mu.Unlock() - if diff := cmp.Diff(nudDisp.events, []testEntryEventInfo(nil)); diff != "" { - t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) - } -} - func TestNeighborCacheRemoveEntry(t *testing.T) { config := DefaultNUDConfigurations() @@ -1048,9 +1022,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Fatalf("c.store.entry(0) not found") } c.neigh.addStaticEntry(entry.Addr, entry.LinkAddr) - e, _, err := c.neigh.entry(entry.Addr, "", nil, nil) + e, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil) if err != nil { - t.Errorf("unexpected error from c.neigh.entry(%s, \"\", nil nil): %s", entry.Addr, err) + t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) } want := NeighborEntry{ Addr: entry.Addr, @@ -1059,7 +1033,7 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("c.neigh.entry(%s, \"\", nil, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) + t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } wantEvents := []testEntryEventInfo{ diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 08c9ccd25..b769fb2fa 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -236,7 +236,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1, entryTestAddr2, nudState, &linkRes) + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) // Stub out ndpState to verify modification of default routers. nic.mu.ndp = ndpState{ @@ -2344,6 +2344,106 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { nudDisp.mu.Unlock() } +// TestEntryUnknownToStaleToProbeToReachable exercises the following scenario: +// 1. Probe is received +// 2. Entry is created in Stale +// 3. Packet is queued on the entry +// 4. Entry transitions to Delay then Probe +// 5. Probe is sent +func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { + c := DefaultNUDConfigurations() + // Eliminate random factors from ReachableTime computation so the transition + // from Probe to Reachable will only take BaseReachableTime duration. + c.MinRandomFactor = 1 + c.MaxRandomFactor = 1 + + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handleProbeLocked(entryTestLinkAddr1) + e.handlePacketQueuedLocked() + e.mu.Unlock() + + clock.advance(c.DelayFirstProbeTime) + + wantProbes := []entryTestProbeInfo{ + // Probe caused by the Delay-to-Probe transition + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() + if got, want := e.neigh.State, Probe; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if got, want := e.neigh.State, Reachable; got != want { + t.Errorf("got e.neigh.State = %q, want = %q", got, want) + } + if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) + } + e.mu.Unlock() + + clock.advance(c.BaseReachableTime) + + wantEvents := []testEntryEventInfo{ + { + EventType: entryTestAdded, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, + { + EventType: entryTestChanged, + NICID: entryTestNICID, + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, + } + nudDisp.mu.Lock() + if diff := cmp.Diff(nudDisp.events, wantEvents, eventDiffOpts()...); diff != "" { + t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) + } + nudDisp.mu.Unlock() +} + func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() // Eliminate random factors from ReachableTime computation so the transition diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index f21066fce..eaaf756cd 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -217,6 +217,11 @@ func (n *NIC) disableLocked() *tcpip.Error { } if _, ok := n.stack.networkProtocols[header.IPv4ProtocolNumber]; ok { + // The NIC may have already left the multicast group. + if err := n.leaveGroupLocked(header.IPv4AllSystems, false /* force */); err != nil && err != tcpip.ErrBadLocalAddress { + return err + } + // The address may have already been removed. if err := n.removePermanentAddressLocked(ipv4BroadcastAddr.AddressWithPrefix.Address); err != nil && err != tcpip.ErrBadLocalAddress { return err @@ -255,6 +260,13 @@ func (n *NIC) enable() *tcpip.Error { if _, err := n.addAddressLocked(ipv4BroadcastAddr, NeverPrimaryEndpoint, permanent, static, false /* deprecated */); err != nil { return err } + + // As per RFC 1122 section 3.3.7, all hosts should join the all-hosts + // multicast group. Note, the IANA calls the all-hosts multicast group the + // all-systems multicast group. + if err := n.joinGroupLocked(header.IPv4ProtocolNumber, header.IPv4AllSystems); err != nil { + return err + } } // Join the IPv6 All-Nodes Multicast group if the stack is configured to @@ -609,6 +621,9 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // If none exists a temporary one may be created if we are in promiscuous mode // or spoofing. Promiscuous mode will only be checked if promiscuous is true. // Similarly, spoofing will only be checked if spoofing is true. +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, tempRef getRefBehaviour) *referencedNetworkEndpoint { n.mu.RLock() @@ -633,6 +648,16 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t } } + // Check if address is a broadcast address for the endpoint's network. + // + // Only IPv4 has a notion of broadcast addresses. + if protocol == header.IPv4ProtocolNumber { + if ref := n.getRefForBroadcastRLocked(address); ref != nil { + n.mu.RUnlock() + return ref + } + } + // A usable reference was not found, create a temporary one if requested by // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous @@ -670,8 +695,34 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t return ref } +// getRefForBroadcastLocked returns an endpoint where address is the IPv4 +// broadcast address for the endpoint's network. +// +// n.mu MUST be read locked. +func (n *NIC) getRefForBroadcastRLocked(address tcpip.Address) *referencedNetworkEndpoint { + for _, ref := range n.mu.endpoints { + // Only IPv4 has a notion of broadcast addresses. + if ref.protocol != header.IPv4ProtocolNumber { + continue + } + + addr := ref.addrWithPrefix() + subnet := addr.Subnet() + if subnet.IsBroadcast(address) && ref.tryIncRef() { + return ref + } + } + + return nil +} + /// getRefOrCreateTempLocked returns an existing endpoint for address or creates /// and returns a temporary endpoint. +// +// If the address is the IPv4 broadcast address for an endpoint's network, that +// endpoint will be returned. +// +// n.mu must be write locked. func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { if ref, ok := n.mu.endpoints[NetworkEndpointID{address}]; ok { // No need to check the type as we are ok with expired endpoints at this @@ -685,6 +736,15 @@ func (n *NIC) getRefOrCreateTempLocked(protocol tcpip.NetworkProtocolNumber, add n.removeEndpointLocked(ref) } + // Check if address is a broadcast address for an endpoint's network. + // + // Only IPv4 has a notion of broadcast addresses. + if protocol == header.IPv4ProtocolNumber { + if ref := n.getRefForBroadcastRLocked(address); ref != nil { + return ref + } + } + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index f848d50ad..e1ec15487 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -177,7 +177,7 @@ type NUDHandler interface { // Neighbor Solicitation for ARP or NDP, respectively). Validation of the // probe needs to be performed before calling this function since the // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress) + HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7189e8e7e..5b19c5d59 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1985,8 +1985,8 @@ func generateRandInt64() int64 { // FindNetworkEndpoint returns the network endpoint for the given address. func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) (NetworkEndpoint, *tcpip.Error) { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() for _, nic := range s.nics { id := NetworkEndpointID{address} @@ -2005,8 +2005,8 @@ func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, addres // FindNICNameFromID returns the name of the nic for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() nic, ok := s.nics[id] if !ok { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index f22062889..0b6deda02 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -277,6 +277,17 @@ func (l *linkEPWithMockedAttach) isAttached() bool { return l.attached } +// Checks to see if list contains an address. +func containsAddr(list []tcpip.ProtocolAddress, item tcpip.ProtocolAddress) bool { + for _, i := range list { + if i == item { + return true + } + } + + return false +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. @@ -1704,7 +1715,7 @@ func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, sub // Trying the next address should always fail since it is outside the range. if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = 0", fakeNetNumber, tcpip.Address(addrBytes), gotNicID) } } @@ -3089,6 +3100,13 @@ func TestIPv6SourceAddressSelectionScopeAndSameAddress(t *testing.T) { func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { const nicID = 1 + broadcastAddr := tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: header.IPv4Broadcast, + PrefixLen: 32, + }, + } e := loopback.New() s := stack.New(stack.Options{ @@ -3099,49 +3117,41 @@ func TestAddRemoveIPv4BroadcastAddressOnNICEnableDisable(t *testing.T) { t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) } - allStackAddrs := s.AllAddresses() - allNICAddrs, ok := allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } } // Enabling the NIC should add the IPv4 broadcast address. if err := s.EnableNIC(nicID); err != nil { t.Fatalf("s.EnableNIC(%d): %s", nicID, err) } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 1 { - t.Fatalf("got len(allNICAddrs) = %d, want = 1", l) - } - want := tcpip.ProtocolAddress{ - Protocol: header.IPv4ProtocolNumber, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: header.IPv4Broadcast, - PrefixLen: 32, - }, - } - if allNICAddrs[0] != want { - t.Fatalf("got allNICAddrs[0] = %+v, want = %+v", allNICAddrs[0], want) + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if !containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, want = %+v", allNICAddrs, broadcastAddr) + } } // Disabling the NIC should remove the IPv4 broadcast address. if err := s.DisableNIC(nicID); err != nil { t.Fatalf("s.DisableNIC(%d): %s", nicID, err) } - allStackAddrs = s.AllAddresses() - allNICAddrs, ok = allStackAddrs[nicID] - if !ok { - t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) - } - if l := len(allNICAddrs); l != 0 { - t.Fatalf("got len(allNICAddrs) = %d, want = 0", l) + + { + allStackAddrs := s.AllAddresses() + if allNICAddrs, ok := allStackAddrs[nicID]; !ok { + t.Fatalf("entry for %d missing from allStackAddrs = %+v", nicID, allStackAddrs) + } else if containsAddr(allNICAddrs, broadcastAddr) { + t.Fatalf("got allNICAddrs = %+v, don't want = %+v", allNICAddrs, broadcastAddr) + } } } @@ -3189,50 +3199,93 @@ func TestLeaveIPv6SolicitedNodeAddrBeforeAddrRemoval(t *testing.T) { } } -func TestJoinLeaveAllNodesMulticastOnNICEnableDisable(t *testing.T) { +func TestJoinLeaveMulticastOnNICEnableDisable(t *testing.T) { const nicID = 1 - e := loopback.New() - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocol{ipv6.NewProtocol()}, - }) - nicOpts := stack.NICOptions{Disabled: true} - if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { - t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + }{ + { + name: "IPv6 All-Nodes", + proto: header.IPv6ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + }, + { + name: "IPv4 All-Systems", + proto: header.IPv4ProtocolNumber, + addr: header.IPv4AllSystems, + }, } - // Should not be in the IPv6 all-nodes multicast group yet because the NIC has - // not been enabled yet. - isInGroup, err := s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + e := loopback.New() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + }) + nicOpts := stack.NICOptions{Disabled: true} + if err := s.CreateNICWithOptions(nicID, e, nicOpts); err != nil { + t.Fatalf("CreateNIC(%d, _, %+v) = %s", nicID, nicOpts, err) + } - // The all-nodes multicast group should be joined when the NIC is enabled. - if err := s.EnableNIC(nicID); err != nil { - t.Fatalf("s.EnableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if !isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, header.IPv6AllNodesMulticastAddress) - } + // Should not be in the multicast group yet because the NIC has not been + // enabled yet. + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } - // The all-nodes multicast group should be left when the NIC is disabled. - if err := s.DisableNIC(nicID); err != nil { - t.Fatalf("s.DisableNIC(%d): %s", nicID, err) - } - isInGroup, err = s.IsInGroup(nicID, header.IPv6AllNodesMulticastAddress) - if err != nil { - t.Fatalf("IsInGroup(%d, %s): %s", nicID, header.IPv6AllNodesMulticastAddress, err) - } - if isInGroup { - t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, header.IPv6AllNodesMulticastAddress) + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // The multicast group should be left when the NIC is disabled. + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + + // The all-nodes multicast group should be joined when the NIC is enabled. + if err := s.EnableNIC(nicID); err != nil { + t.Fatalf("s.EnableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if !isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = false, want = true", nicID, test.addr) + } + + // Leaving the group before disabling the NIC should not cause an error. + if err := s.LeaveGroup(test.proto, nicID, test.addr); err != nil { + t.Fatalf("s.LeaveGroup(%d, %d, %s): %s", test.proto, nicID, test.addr, err) + } + + if err := s.DisableNIC(nicID); err != nil { + t.Fatalf("s.DisableNIC(%d): %s", nicID, err) + } + + if isInGroup, err := s.IsInGroup(nicID, test.addr); err != nil { + t.Fatalf("IsInGroup(%d, %s): %s", nicID, test.addr, err) + } else if isInGroup { + t.Fatalf("got IsInGroup(%d, %s) = true, want = false", nicID, test.addr) + } + }) } } diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD new file mode 100644 index 000000000..7fff30462 --- /dev/null +++ b/pkg/tcpip/tests/integration/BUILD @@ -0,0 +1,21 @@ +load("//tools:defs.bzl", "go_test") + +package(licenses = ["notice"]) + +go_test( + name = "integration_test", + size = "small", + srcs = ["multicast_broadcast_test.go"], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go new file mode 100644 index 000000000..d9b2d147a --- /dev/null +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -0,0 +1,274 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const defaultMTU = 1280 + +// TestIncomingMulticastAndBroadcast tests receiving a packet destined to some +// multicast or broadcast address. +func TestIncomingMulticastAndBroadcast(t *testing.T) { + const ( + nicID = 1 + remotePort = 5555 + localPort = 80 + ttl = 255 + ) + + data := []byte{1, 2, 3, 4} + + // Local IPv4 subnet: 192.168.1.58/24 + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + + // Local IPv6 subnet: 200a::1/64 + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + + // Remote addrs. + remoteIPv4Addr := tcpip.Address("\x64\x0a\x7b\x18") + remoteIPv6Addr := tcpip.Address("\x20\x0b\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02") + + rxIPv4UDP := func(e *channel.Endpoint, dst tcpip.Address) { + payloadLen := header.UDPMinimumSize + len(data) + totalLen := header.IPv4MinimumSize + payloadLen + hdr := buffer.NewPrependable(totalLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv4Addr, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + Protocol: uint8(udp.ProtocolNumber), + TTL: ttl, + SrcAddr: remoteIPv4Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv4ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + rxIPv6UDP := func(e *channel.Endpoint, dst tcpip.Address) { + payloadLen := header.UDPMinimumSize + len(data) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + payloadLen) + u := header.UDP(hdr.Prepend(payloadLen)) + u.Encode(&header.UDPFields{ + SrcPort: remotePort, + DstPort: localPort, + Length: uint16(payloadLen), + }) + copy(u.Payload(), data) + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, remoteIPv6Addr, dst, uint16(payloadLen)) + sum = header.Checksum(data, sum) + u.SetChecksum(^u.CalculateChecksum(sum)) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLen), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: ttl, + SrcAddr: remoteIPv6Addr, + DstAddr: dst, + }) + + e.InjectInbound(header.IPv6ProtocolNumber, &stack.PacketBuffer{ + Data: hdr.View().ToVectorisedView(), + }) + } + + tests := []struct { + name string + bindAddr tcpip.Address + dstAddr tcpip.Address + expectRx bool + }{ + { + name: "IPv4 unicast binding to unicast", + bindAddr: ipv4Addr.Address, + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + { + name: "IPv4 unicast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4Addr.Address, + expectRx: false, + }, + { + name: "IPv4 unicast binding to wildcard", + dstAddr: ipv4Addr.Address, + expectRx: true, + }, + + { + name: "IPv4 directed broadcast binding to subnet broadcast", + bindAddr: ipv4SubnetBcast, + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + { + name: "IPv4 directed broadcast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: ipv4SubnetBcast, + expectRx: false, + }, + { + name: "IPv4 directed broadcast binding to wildcard", + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + + { + name: "IPv4 broadcast binding to broadcast", + bindAddr: header.IPv4Broadcast, + dstAddr: header.IPv4Broadcast, + expectRx: true, + }, + { + name: "IPv4 broadcast binding to subnet broadcast", + bindAddr: ipv4SubnetBcast, + dstAddr: header.IPv4Broadcast, + expectRx: false, + }, + { + name: "IPv4 broadcast binding to wildcard", + dstAddr: ipv4SubnetBcast, + expectRx: true, + }, + + { + name: "IPv4 all-systems multicast binding to all-systems multicast", + bindAddr: header.IPv4AllSystems, + dstAddr: header.IPv4AllSystems, + expectRx: true, + }, + { + name: "IPv4 all-systems multicast binding to wildcard", + dstAddr: header.IPv4AllSystems, + expectRx: true, + }, + { + name: "IPv4 all-systems multicast binding to unicast", + bindAddr: ipv4Addr.Address, + dstAddr: header.IPv4AllSystems, + expectRx: false, + }, + + // IPv6 has no notion of a broadcast. + { + name: "IPv6 unicast binding to wildcard", + dstAddr: ipv6Addr.Address, + expectRx: true, + }, + { + name: "IPv6 broadcast-like address binding to wildcard", + dstAddr: ipv6SubnetBcast, + expectRx: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID, err) + } + ipv4ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: ipv4Addr} + if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv4ProtoAddr, err) + } + ipv6ProtoAddr := tcpip.ProtocolAddress{Protocol: header.IPv6ProtocolNumber, AddressWithPrefix: ipv6Addr} + if err := s.AddProtocolAddress(nicID, ipv6ProtoAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID, ipv6ProtoAddr, err) + } + + var netproto tcpip.NetworkProtocolNumber + var rxUDP func(*channel.Endpoint, tcpip.Address) + switch l := len(test.dstAddr); l { + case header.IPv4AddressSize: + netproto = header.IPv4ProtocolNumber + rxUDP = rxIPv4UDP + case header.IPv6AddressSize: + netproto = header.IPv6ProtocolNumber + rxUDP = rxIPv6UDP + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) + } + defer ep.Close() + + bindAddr := tcpip.FullAddress{Addr: test.bindAddr, Port: localPort} + if err := ep.Bind(bindAddr); err != nil { + t.Fatalf("ep.Bind(%+v): %s", bindAddr, err) + } + + rxUDP(e, test.dstAddr) + if gotPayload, _, err := ep.Read(nil); test.expectRx { + if err != nil { + t.Fatalf("Read(nil): %s", err) + } + if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } + } else { + if err != tcpip.ErrWouldBlock { + t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + } + } + }) + } +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 39ea38fe6..b8b52b03d 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1777,15 +1777,8 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Same as effectively disabling TCPLinger timeout. v = 0 } - var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption - if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { - // We were unable to retrieve a stack config, just use - // the DefaultTCPLingerTimeout. - if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { - stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) - } - } - // Cap it to the stack wide TCPLinger timeout. + // Cap it to MaxTCPLingerTimeout. + stkTCPLingerTimeout := tcpip.TCPLingerTimeoutOption(MaxTCPLingerTimeout) if v > stkTCPLingerTimeout { v = stkTCPLingerTimeout } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b34e47bbd..2e5093b36 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -62,6 +62,10 @@ const ( // FIN_WAIT_2 state before being marked closed. DefaultTCPLingerTimeout = 60 * time.Second + // MaxTCPLingerTimeout is the maximum amount of time that sockets + // linger in FIN_WAIT_2 state before being marked closed. + MaxTCPLingerTimeout = 120 * time.Second + // DefaultTCPTimeWaitTimeout is the amount of time that sockets linger // in TIME_WAIT state before being marked closed. DefaultTCPTimeWaitTimeout = 60 * time.Second @@ -80,6 +84,25 @@ const ( // enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. type SACKEnabled bool +// Recovery is used by stack.(*Stack).TransportProtocolOption to +// set loss detection algorithm in TCP. +type Recovery int32 + +const ( + // RACKLossDetection indicates RACK is used for loss detection and + // recovery. + RACKLossDetection Recovery = 1 << iota + + // RACKStaticReoWnd indicates the reordering window should not be + // adjusted when DSACK is received. + RACKStaticReoWnd + + // RACKNoDupTh indicates RACK should not consider the classic three + // duplicate acknowledgements rule to mark the segments as lost. This + // is used when reordering is not detected. + RACKNoDupTh +) + // DelayEnabled is used by stack.(Stack*).TransportProtocolOption to // enable/disable Nagle's algorithm in TCP. type DelayEnabled bool @@ -161,6 +184,7 @@ func (s *synRcvdCounter) Threshold() uint64 { type protocol struct { mu sync.RWMutex sackEnabled bool + recovery Recovery delayEnabled bool sendBufferSize SendBufferSizeOption recvBufferSize ReceiveBufferSizeOption @@ -280,6 +304,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case Recovery: + p.mu.Lock() + p.recovery = Recovery(v) + p.mu.Unlock() + return nil + case DelayEnabled: p.mu.Lock() p.delayEnabled = bool(v) @@ -394,6 +424,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil + case *Recovery: + p.mu.RLock() + *v = Recovery(p.recovery) + p.mu.RUnlock() + return nil + case *DelayEnabled: p.mu.RLock() *v = DelayEnabled(p.delayEnabled) @@ -535,6 +571,7 @@ func NewProtocol() stack.TransportProtocol { minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, + recovery: RACKLossDetection, } p.dispatcher.init(runtime.GOMAXPROCS(0)) return &p diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index fb25b86b9..1b58eb91b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6112,7 +6112,7 @@ func TestTCPLingerTimeout(t *testing.T) { {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, // Values > stack's TCPLingerTimeout are capped to the stack's // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) - {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + {"AboveMaxLingerTimeout", 125 * time.Second, 120 * time.Second}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { diff --git a/pkg/test/dockerutil/container.go b/pkg/test/dockerutil/container.go index 5a2157951..052b6b99d 100644 --- a/pkg/test/dockerutil/container.go +++ b/pkg/test/dockerutil/container.go @@ -58,12 +58,6 @@ type Container struct { // a handle to restart the profile. Generally, tests/benchmarks using // profiles need to run as root. profiles []Profile - - // Stores streams attached to the container. Used by WaitForOutputSubmatch. - streams types.HijackedResponse - - // stores previously read data from the attached streams. - streamBuf bytes.Buffer } // RunOpts are options for running a container. @@ -175,11 +169,25 @@ func (c *Container) SpawnProcess(ctx context.Context, r RunOpts, args ...string) return Process{}, err } + // Open a connection to the container for parsing logs and for TTY. + stream, err := c.client.ContainerAttach(ctx, c.id, + types.ContainerAttachOptions{ + Stream: true, + Stdin: true, + Stdout: true, + Stderr: true, + }) + if err != nil { + return Process{}, fmt.Errorf("connect failed container id %s: %v", c.id, err) + } + + c.cleanups = append(c.cleanups, func() { stream.Close() }) + if err := c.Start(ctx); err != nil { return Process{}, err } - return Process{container: c, conn: c.streams}, nil + return Process{container: c, conn: stream}, nil } // Run is analogous to 'docker run'. @@ -273,23 +281,6 @@ func (c *Container) hostConfig(r RunOpts) *container.HostConfig { // Start is analogous to 'docker start'. func (c *Container) Start(ctx context.Context) error { - - // Open a connection to the container for parsing logs and for TTY. - streams, err := c.client.ContainerAttach(ctx, c.id, - types.ContainerAttachOptions{ - Stream: true, - Stdin: true, - Stdout: true, - Stderr: true, - }) - if err != nil { - return fmt.Errorf("failed to connect to container: %v", err) - } - - c.streams = streams - c.cleanups = append(c.cleanups, func() { - c.streams.Close() - }) if err := c.client.ContainerStart(ctx, c.id, types.ContainerStartOptions{}); err != nil { return fmt.Errorf("ContainerStart failed: %v", err) } @@ -485,34 +476,19 @@ func (c *Container) WaitForOutput(ctx context.Context, pattern string, timeout t // WaitForOutputSubmatch searches container logs for the given // pattern or times out. It returns any regexp submatches as well. func (c *Container) WaitForOutputSubmatch(ctx context.Context, pattern string, timeout time.Duration) ([]string, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() re := regexp.MustCompile(pattern) - if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil { - return matches, nil - } - - for exp := time.Now().Add(timeout); time.Now().Before(exp); { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - - c.streams.Conn.SetDeadline(time.Now().Add(50 * time.Millisecond)) - _, err := stdcopy.StdCopy(&c.streamBuf, &c.streamBuf, c.streams.Reader) - + for { + logs, err := c.Logs(ctx) if err != nil { - // check that it wasn't a timeout - if nerr, ok := err.(net.Error); !ok || !nerr.Timeout() { - return nil, err - } + return nil, fmt.Errorf("failed to get logs: %v logs: %s", err, logs) } - - if matches := re.FindStringSubmatch(c.streamBuf.String()); matches != nil { + if matches := re.FindStringSubmatch(logs); matches != nil { return matches, nil } + time.Sleep(50 * time.Millisecond) } - - return nil, fmt.Errorf("timeout waiting for output %q: out: %s", re.String(), c.streamBuf.String()) } // Kill kills the container. @@ -537,8 +513,18 @@ func (c *Container) CleanUp(ctx context.Context) { for _, profile := range c.profiles { profile.OnCleanUp(c) } + // Forget profiles. c.profiles = nil + + // Execute all cleanups. We execute cleanups here to close any + // open connections to the container before closing. Open connections + // can cause Kill and Remove to hang. + for _, c := range c.cleanups { + c() + } + c.cleanups = nil + // Kill the container. if err := c.Kill(ctx); err != nil && !strings.Contains(err.Error(), "is not running") { // Just log; can't do anything here. @@ -550,9 +536,4 @@ func (c *Container) CleanUp(ctx context.Context) { } // Forget all mounts. c.mounts = nil - // Execute all cleanups. - for _, c := range c.cleanups { - c() - } - c.cleanups = nil } diff --git a/pkg/test/dockerutil/profile.go b/pkg/test/dockerutil/profile.go index 1fab33083..f0396ef24 100644 --- a/pkg/test/dockerutil/profile.go +++ b/pkg/test/dockerutil/profile.go @@ -49,17 +49,16 @@ type Profile interface { // should have --profile set as an option in /etc/docker/daemon.json in // order for profiling to work with Pprof. type Pprof struct { - BasePath string // path to put profiles - BlockProfile bool - CPUProfile bool - GoRoutineProfile bool - HeapProfile bool - MutexProfile bool - Duration time.Duration // duration to run profiler e.g. '10s' or '1m'. - shouldRun bool - cmd *exec.Cmd - stdout io.ReadCloser - stderr io.ReadCloser + BasePath string // path to put profiles + BlockProfile bool + CPUProfile bool + HeapProfile bool + MutexProfile bool + Duration time.Duration // duration to run profiler e.g. '10s' or '1m'. + shouldRun bool + cmd *exec.Cmd + stdout io.ReadCloser + stderr io.ReadCloser } // MakePprofFromFlags makes a Pprof profile from flags. @@ -68,13 +67,12 @@ func MakePprofFromFlags(c *Container) *Pprof { return nil } return &Pprof{ - BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name), - BlockProfile: *pprofBlock, - CPUProfile: *pprofCPU, - GoRoutineProfile: *pprofGo, - HeapProfile: *pprofHeap, - MutexProfile: *pprofMutex, - Duration: *duration, + BasePath: filepath.Join(*pprofBaseDir, c.runtime, c.Name), + BlockProfile: *pprofBlock, + CPUProfile: *pprofCPU, + HeapProfile: *pprofHeap, + MutexProfile: *pprofMutex, + Duration: *duration, } } @@ -138,9 +136,6 @@ func (p *Pprof) makeProfileArgs(c *Container) []string { if p.CPUProfile { ret = append(ret, fmt.Sprintf("--profile-cpu=%s", filepath.Join(p.BasePath, "cpu.pprof"))) } - if p.GoRoutineProfile { - ret = append(ret, fmt.Sprintf("--profile-goroutine=%s", filepath.Join(p.BasePath, "go.pprof"))) - } if p.HeapProfile { ret = append(ret, fmt.Sprintf("--profile-heap=%s", filepath.Join(p.BasePath, "heap.pprof"))) } diff --git a/pkg/test/dockerutil/profile_test.go b/pkg/test/dockerutil/profile_test.go index b7b4d7618..8c4ffe483 100644 --- a/pkg/test/dockerutil/profile_test.go +++ b/pkg/test/dockerutil/profile_test.go @@ -51,13 +51,12 @@ func TestPprof(t *testing.T) { { name: "All", pprof: Pprof{ - BasePath: basePath, - BlockProfile: true, - CPUProfile: true, - GoRoutineProfile: true, - HeapProfile: true, - MutexProfile: true, - Duration: 2 * time.Second, + BasePath: basePath, + BlockProfile: true, + CPUProfile: true, + HeapProfile: true, + MutexProfile: true, + Duration: 2 * time.Second, }, expectedFiles: []string{block, cpu, goprofle, heap, mutex}, }, |