diff options
Diffstat (limited to 'pkg/sentry')
41 files changed, 850 insertions, 862 deletions
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index c9fb55d00..5138f3bf5 100644 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go @@ -251,3 +251,26 @@ func (s *SignalInfo) Arch() uint32 { func (s *SignalInfo) SetArch(val uint32) { usermem.ByteOrder.PutUint32(s.Fields[12:16], val) } + +// Band returns the si_band field. +func (s *SignalInfo) Band() int64 { + return int64(usermem.ByteOrder.Uint64(s.Fields[0:8])) +} + +// SetBand mutates the si_band field. +func (s *SignalInfo) SetBand(val int64) { + // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`. + // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is + // `int`. See siginfo.h. + usermem.ByteOrder.PutUint64(s.Fields[0:8], uint64(val)) +} + +// FD returns the si_fd field. +func (s *SignalInfo) FD() uint32 { + return usermem.ByteOrder.Uint32(s.Fields[8:12]) +} + +// SetFD mutates the si_fd field. +func (s *SignalInfo) SetFD(val uint32) { + usermem.ByteOrder.PutUint32(s.Fields[8:12], val) +} diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index d2dbff268..a020da53b 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -65,7 +65,7 @@ var ( // runs with the lock held for reading. AsyncBarrier will take the lock // for writing, thus ensuring that all Async work completes before // AsyncBarrier returns. - workMu sync.RWMutex + workMu sync.CrossGoroutineRWMutex // asyncError is used to store up to one asynchronous execution error. asyncError = make(chan error, 1) diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go index f8aad2dbd..b998fb75d 100644 --- a/pkg/sentry/fs/proc/sys.go +++ b/pkg/sentry/fs/proc/sys.go @@ -84,6 +84,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode children := map[string]*fs.Inode{ "hostname": newProcInode(ctx, &h, msrc, fs.SpecialFile, nil), + "sem": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))), "shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))), "shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))), "shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))), diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 437473b4a..3cdb1e659 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1353,16 +1353,11 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { return } if refs > 0 { - if d.cached { - // This isn't strictly necessary (fs.cachedDentries is permitted to - // contain dentries with non-zero refs, which are skipped by - // fs.evictCachedDentryLocked() upon reaching the end of the LRU), - // but since we are already holding fs.renameMu for writing we may - // as well. - d.fs.cachedDentries.Remove(d) - d.fs.cachedDentriesLen-- - d.cached = false - } + // This isn't strictly necessary (fs.cachedDentries is permitted to + // contain dentries with non-zero refs, which are skipped by + // fs.evictCachedDentryLocked() upon reaching the end of the LRU), but + // since we are already holding fs.renameMu for writing we may as well. + d.removeFromCacheLocked() return } // Deleted and invalidated dentries with zero references are no longer @@ -1371,20 +1366,18 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { if d.isDeleted() { d.watches.HandleDeletion(ctx) } - if d.cached { - d.fs.cachedDentries.Remove(d) - d.fs.cachedDentriesLen-- - d.cached = false - } + d.removeFromCacheLocked() d.destroyLocked(ctx) return } - // If d still has inotify watches and it is not deleted or invalidated, we - // cannot cache it and allow it to be evicted. Otherwise, we will lose its - // watches, even if a new dentry is created for the same file in the future. - // Note that the size of d.watches cannot concurrently transition from zero - // to non-zero, because adding a watch requires holding a reference on d. + // If d still has inotify watches and it is not deleted or invalidated, it + // can't be evicted. Otherwise, we will lose its watches, even if a new + // dentry is created for the same file in the future. Note that the size of + // d.watches cannot concurrently transition from zero to non-zero, because + // adding a watch requires holding a reference on d. if d.watches.Size() > 0 { + // As in the refs > 0 case, this is not strictly necessary. + d.removeFromCacheLocked() return } @@ -1415,6 +1408,15 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { } } +// Preconditions: d.fs.renameMu must be locked for writing. +func (d *dentry) removeFromCacheLocked() { + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } +} + // Precondition: fs.renameMu must be locked for writing; it may be temporarily // unlocked. func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { @@ -1428,12 +1430,10 @@ func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { // * fs.cachedDentriesLen != 0. func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { victim := fs.cachedDentries.Back() - fs.cachedDentries.Remove(victim) - fs.cachedDentriesLen-- - victim.cached = false - // victim.refs may have become non-zero from an earlier path resolution - // since it was inserted into fs.cachedDentries. - if atomic.LoadInt64(&victim.refs) == 0 { + victim.removeFromCacheLocked() + // victim.refs or victim.watches.Size() may have become non-zero from an + // earlier path resolution since it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 && victim.watches.Size() == 0 { if victim.parent != nil { victim.parent.dirMu.Lock() if !victim.vfsd.IsDead() { diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 469f3a33d..27b00cf6f 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -16,7 +16,6 @@ package overlay import ( "fmt" - "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -129,25 +128,9 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { return err } defer newFD.DecRef(ctx) - bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size - for { - readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{}) - if readErr != nil && readErr != io.EOF { - cleanupUndoCopyUp() - return readErr - } - total := int64(0) - for total < readN { - writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{}) - total += writeN - if writeErr != nil { - cleanupUndoCopyUp() - return writeErr - } - } - if readErr == io.EOF { - break - } + if _, err := vfs.CopyRegularFileData(ctx, newFD, oldFD); err != nil { + cleanupUndoCopyUp() + return err } d.mapsMu.Lock() defer d.mapsMu.Unlock() diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go index 2b89a7a6d..25c785fd4 100644 --- a/pkg/sentry/fsimpl/overlay/regular_file.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -103,8 +103,8 @@ func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescript for e, mask := range fd.lowerWaiters { fd.cachedFD.EventUnregister(e) upperFD.EventRegister(e, mask) - if ready&mask != 0 { - e.Callback.Callback(e) + if m := ready & mask; m != 0 { + e.Callback.Callback(e, m) } } } diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 7c7afdcfa..25c407d98 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -44,6 +44,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k * return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}), + "sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))), "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)), "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)), "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)), diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 59fcff498..04e7110a3 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -254,7 +254,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } - fdReader := vfs.FileReadWriteSeeker{ + fdReader := FileReadWriteSeeker{ FD: parentMerkleFD, Ctx: ctx, } @@ -397,7 +397,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry } } - fdReader := vfs.FileReadWriteSeeker{ + fdReader := FileReadWriteSeeker{ FD: fd, Ctx: ctx, } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 46346e54d..5788c661f 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -91,9 +91,6 @@ var ( // noCrashOnVerificationFailure indicates whether the sandbox should panic // whenever verification fails. If true, an error is returned instead of // panicking. This should only be set for tests. - // - // TODO(b/165661693): Decide whether to panic or return error based on this - // flag. noCrashOnVerificationFailure bool // verityMu synchronizes concurrent operations that enable verity and perform @@ -752,15 +749,15 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) // // Preconditions: fd.d.fs.verityMu must be locked. func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, uint64, error) { - fdReader := vfs.FileReadWriteSeeker{ + fdReader := FileReadWriteSeeker{ FD: fd.lowerFD, Ctx: ctx, } - merkleReader := vfs.FileReadWriteSeeker{ + merkleReader := FileReadWriteSeeker{ FD: fd.merkleReader, Ctx: ctx, } - merkleWriter := vfs.FileReadWriteSeeker{ + merkleWriter := FileReadWriteSeeker{ FD: fd.merkleWriter, Ctx: ctx, } @@ -1050,12 +1047,12 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } - dataReader := vfs.FileReadWriteSeeker{ + dataReader := FileReadWriteSeeker{ FD: fd.lowerFD, Ctx: ctx, } - merkleReader := vfs.FileReadWriteSeeker{ + merkleReader := FileReadWriteSeeker{ FD: fd.merkleReader, Ctx: ctx, } @@ -1104,3 +1101,45 @@ func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence) } + +// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as +// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. +type FileReadWriteSeeker struct { + FD *vfs.FileDescription + Ctx context.Context + ROpts vfs.ReadOptions + WOpts vfs.WriteOptions +} + +// ReadAt implements io.ReaderAt.ReadAt. +func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts) + return int(n), err +} + +// Read implements io.ReadWriteSeeker.Read. +func (f *FileReadWriteSeeker) Read(p []byte) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.Read(f.Ctx, dst, f.ROpts) + return int(n), err +} + +// Seek implements io.ReadWriteSeeker.Seek. +func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) { + return f.FD.Seek(f.Ctx, offset, int32(whence)) +} + +// WriteAt implements io.WriterAt.WriteAt. +func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts) + return int(n), err +} + +// Write implements io.ReadWriteSeeker.Write. +func (f *FileReadWriteSeeker) Write(p []byte) (int, error) { + buf := usermem.BytesIOSequence(p) + n, err := f.FD.Write(f.Ctx, buf, f.WOpts) + return int(n), err +} diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 5d1f5de08..6ced0afc9 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -35,14 +35,16 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// rootMerkleFilename is the name of the root Merkle tree file. -const rootMerkleFilename = "root.verity" +const ( + // rootMerkleFilename is the name of the root Merkle tree file. + rootMerkleFilename = "root.verity" + // maxDataSize is the maximum data size of a test file. + maxDataSize = 100000 +) -// maxDataSize is the maximum data size written to the file for test. -const maxDataSize = 100000 +var hashAlgs = []HashAlgorithm{SHA256, SHA512} -// getD returns a *dentry corresponding to VD. -func getD(t *testing.T, vd vfs.VirtualDentry) *dentry { +func dentryFromVD(t *testing.T, vd vfs.VirtualDentry) *dentry { t.Helper() d, ok := vd.Dentry().Impl().(*dentry) if !ok { @@ -51,10 +53,21 @@ func getD(t *testing.T, vd vfs.VirtualDentry) *dentry { return d } +// dentryFromFD returns the dentry corresponding to fd. +func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry { + t.Helper() + f, ok := fd.Impl().(*fileDescription) + if !ok { + t.Fatalf("can't assert %T as a *fileDescription", fd) + } + return f.d +} + // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { + t.Helper() k, err := testutil.Boot() if err != nil { t.Fatalf("testutil.Boot: %v", err) @@ -102,7 +115,6 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, t.Fatalf("testutil.CreateTask: %v", err) } - t.Helper() t.Cleanup(func() { root.DecRef(ctx) mntns.DecRef(ctx) @@ -111,6 +123,8 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, } // openVerityAt opens a verity file. +// +// TODO(chongc): release reference from opening the file when done. func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ Root: vd, @@ -123,6 +137,8 @@ func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.Vir } // openLowerAt opens the file in the underlying file system. +// +// TODO(chongc): release reference from opening the file when done. func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ Root: d.lowerVD, @@ -135,6 +151,8 @@ func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, } // openLowerMerkleAt opens the Merkle file in the underlying file system. +// +// TODO(chongc): release reference from opening the file when done. func (d *dentry) openLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ Root: d.lowerMerkleVD, @@ -190,21 +208,11 @@ func (d *dentry) renameLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFil }, &vfs.RenameOptions{}) } -// getDentry returns a *dentry corresponds to fd. -func getDentry(t *testing.T, fd *vfs.FileDescription) *dentry { - t.Helper() - f, ok := fd.Impl().(*fileDescription) - if !ok { - t.Fatalf("can't assert %T as a *fileDescription", fd) - } - return f.d -} - // newFileFD creates a new file in the verity mount, and returns the FD. The FD // points to a file that has random data generated. func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { // Create the file in the underlying file system. - lowerFD, err := getD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) + lowerFD, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) if err != nil { return nil, 0, err } @@ -231,9 +239,8 @@ func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, return fd, dataSize, err } -// corruptRandomBit randomly flips a bit in the file represented by fd. -func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { - // Flip a random bit in the underlying file. +// flipRandomBit randomly flips a bit in the file represented by fd. +func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { randomPos := int64(rand.Intn(size)) byteToModify := make([]byte, 1) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil { @@ -246,7 +253,14 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er return nil } -var hashAlgs = []HashAlgorithm{SHA256, SHA512} +func enableVerity(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { + t.Helper() + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("enable verity: %v", err) + } +} // TestOpen ensures that when a file is created, the corresponding Merkle tree // file and the root Merkle tree file exist. @@ -264,12 +278,12 @@ func TestOpen(t *testing.T) { } // Ensure that the corresponding Merkle tree file is created. - if _, err = getDentry(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { + if _, err = dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) } // Ensure the root merkle tree file is created. - if _, err = getD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { + if _, err = dentryFromVD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) } } @@ -291,11 +305,7 @@ func TestPReadUnmodifiedFileSucceeds(t *testing.T) { } // Enable verity on the file and confirm a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) buf := make([]byte, size) n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) @@ -325,11 +335,7 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) { } // Enable verity on the file and confirm a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) buf := make([]byte, size) n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) @@ -359,11 +365,7 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) { } // Enable verity on the file and confirms a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Ensure reopening the verity enabled file succeeds. if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != nil { @@ -387,21 +389,14 @@ func TestOpenNonexistentFile(t *testing.T) { } // Enable verity on the file and confirms a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Enable verity on the parent directory. parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - - if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, parentFD) // Ensure open an unexpected file in the parent directory fails with // ENOENT rather than verification failure. @@ -426,20 +421,16 @@ func TestPReadModifiedFileFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Open a new lowerFD that's read/writable. - lowerFD, err := getDentry(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) + lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("flipRandomBit: %v", err) } // Confirm that read from the modified file fails. @@ -466,20 +457,16 @@ func TestReadModifiedFileFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Open a new lowerFD that's read/writable. - lowerFD, err := getDentry(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) + lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("flipRandomBit: %v", err) } // Confirm that read from the modified file fails. @@ -506,14 +493,10 @@ func TestModifiedMerkleFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Open a new lowerMerkleFD that's read/writable. - lowerMerkleFD, err := getDentry(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) + lowerMerkleFD, err := dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } @@ -524,14 +507,13 @@ func TestModifiedMerkleFails(t *testing.T) { t.Errorf("lowerMerkleFD.Stat: %v", err) } - if err := corruptRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil { + t.Fatalf("flipRandomBit: %v", err) } // Confirm that read from a file with modified Merkle tree fails. buf := make([]byte, size) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - fmt.Println(buf) t.Fatalf("fd.PRead succeeded with modified Merkle file") } } @@ -554,24 +536,17 @@ func TestModifiedParentMerkleFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Enable verity on the parent directory. parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - - if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, parentFD) // Open a new lowerMerkleFD that's read/writable. - parentLowerMerkleFD, err := getDentry(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) + parentLowerMerkleFD, err := dentryFromFD(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } @@ -591,8 +566,8 @@ func TestModifiedParentMerkleFails(t *testing.T) { if err != nil { t.Fatalf("Failed convert size to int: %v", err) } - if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { + t.Fatalf("flipRandomBit: %v", err) } parentLowerMerkleFD.DecRef(ctx) @@ -619,13 +594,8 @@ func TestUnmodifiedStatSucceeds(t *testing.T) { t.Fatalf("newFileFD: %v", err) } - // Enable verity on the file and confirms stat succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } - + // Enable verity on the file and confirm that stat succeeds. + enableVerity(ctx, t, fd) if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { t.Errorf("fd.Stat: %v", err) } @@ -648,11 +618,7 @@ func TestModifiedStatFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } + enableVerity(ctx, t, fd) lowerFD := fd.Impl().(*fileDescription).lowerFD // Change the stat of the underlying file, and check that stat fails. @@ -711,19 +677,15 @@ func TestOpenDeletedFileFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) if tc.changeFile { - if err := getD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil { + if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil { t.Fatalf("UnlinkAt: %v", err) } } if tc.changeMerkleFile { - if err := getD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil { + if err := dentryFromVD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil { t.Fatalf("UnlinkAt: %v", err) } } @@ -776,20 +738,16 @@ func TestOpenRenamedFileFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) newFilename := "renamed-test-file" if tc.changeFile { - if err := getD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil { + if err := dentryFromVD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil { t.Fatalf("RenameAt: %v", err) } } if tc.changeMerkleFile { - if err := getD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil { + if err := dentryFromVD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil { t.Fatalf("UnlinkAt: %v", err) } } diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 15519f0df..61aeca044 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -273,7 +273,7 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent { // // Callback is called when one of the files we're polling becomes ready. It // moves said file to the readyList if it's currently in the waiting list. -func (p *pollEntry) Callback(*waiter.Entry) { +func (p *pollEntry) Callback(*waiter.Entry, waiter.EventMask) { e := p.epoll e.listsMu.Lock() @@ -306,9 +306,8 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) { f.EventRegister(&entry.waiter, entry.mask) // Check if the file happens to already be in a ready state. - ready := f.Readiness(entry.mask) & entry.mask - if ready != 0 { - entry.Callback(&entry.waiter) + if ready := f.Readiness(entry.mask) & entry.mask; ready != 0 { + entry.Callback(&entry.waiter, ready) } } diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 2b3955598..f855f038b 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -8,11 +8,13 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/sentry/arch", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/sync", + "//pkg/syserror", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index 153d2cd9b..b66d61c6f 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -17,22 +17,45 @@ package fasync import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) -// New creates a new fs.FileAsync. -func New() fs.FileAsync { - return &FileAsync{} +// Table to convert waiter event masks into si_band siginfo codes. +// Taken from fs/fcntl.c:band_table. +var bandTable = map[waiter.EventMask]int64{ + // POLL_IN + waiter.EventIn: linux.EPOLLIN | linux.EPOLLRDNORM, + // POLL_OUT + waiter.EventOut: linux.EPOLLOUT | linux.EPOLLWRNORM | linux.EPOLLWRBAND, + // POLL_ERR + waiter.EventErr: linux.EPOLLERR, + // POLL_PRI + waiter.EventPri: linux.EPOLLPRI | linux.EPOLLRDBAND, + // POLL_HUP + waiter.EventHUp: linux.EPOLLHUP | linux.EPOLLERR, } -// NewVFS2 creates a new vfs.FileAsync. -func NewVFS2() vfs.FileAsync { - return &FileAsync{} +// New returns a function that creates a new fs.FileAsync with the given file +// descriptor. +func New(fd int) func() fs.FileAsync { + return func() fs.FileAsync { + return &FileAsync{fd: fd} + } +} + +// NewVFS2 returns a function that creates a new vfs.FileAsync with the given +// file descriptor. +func NewVFS2(fd int) func() vfs.FileAsync { + return func() vfs.FileAsync { + return &FileAsync{fd: fd} + } } // FileAsync sends signals when the registered file is ready for IO. @@ -42,6 +65,12 @@ type FileAsync struct { // e is immutable after first use (which is protected by mu below). e waiter.Entry + // fd is the file descriptor to notify about. + // It is immutable, set at allocation time. This matches Linux semantics in + // fs/fcntl.c:fasync_helper. + // The fd value is passed to the signal recipient in siginfo.si_fd. + fd int + // regMu protects registeration and unregistration actions on e. // // regMu must be held while registration decisions are being made @@ -56,6 +85,10 @@ type FileAsync struct { mu sync.Mutex `state:"nosave"` requester *auth.Credentials registered bool + // signal is the signal to deliver upon I/O being available. + // The default value ("zero signal") means the default SIGIO signal will be + // delivered. + signal linux.Signal // Only one of the following is allowed to be non-nil. recipientPG *kernel.ProcessGroup @@ -64,10 +97,10 @@ type FileAsync struct { } // Callback sends a signal. -func (a *FileAsync) Callback(e *waiter.Entry) { +func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) { a.mu.Lock() + defer a.mu.Unlock() if !a.registered { - a.mu.Unlock() return } t := a.recipientT @@ -80,19 +113,34 @@ func (a *FileAsync) Callback(e *waiter.Entry) { } if t == nil { // No recipient has been registered. - a.mu.Unlock() return } c := t.Credentials() // Logic from sigio_perm in fs/fcntl.c. - if a.requester.EffectiveKUID == 0 || + permCheck := (a.requester.EffectiveKUID == 0 || a.requester.EffectiveKUID == c.SavedKUID || a.requester.EffectiveKUID == c.RealKUID || a.requester.RealKUID == c.SavedKUID || - a.requester.RealKUID == c.RealKUID { - t.SendSignal(kernel.SignalInfoPriv(linux.SIGIO)) + a.requester.RealKUID == c.RealKUID) + if !permCheck { + return } - a.mu.Unlock() + signalInfo := &arch.SignalInfo{ + Signo: int32(linux.SIGIO), + Code: arch.SignalInfoKernel, + } + if a.signal != 0 { + signalInfo.Signo = int32(a.signal) + signalInfo.SetFD(uint32(a.fd)) + var band int64 + for m, bandCode := range bandTable { + if m&mask != 0 { + band |= bandCode + } + } + signalInfo.SetBand(band) + } + t.SendSignal(signalInfo) } // Register sets the file which will be monitored for IO events. @@ -186,3 +234,25 @@ func (a *FileAsync) ClearOwner() { a.recipientTG = nil a.recipientPG = nil } + +// Signal returns which signal will be sent to the signal recipient. +// A value of zero means the signal to deliver wasn't customized, which means +// the default signal (SIGIO) will be delivered. +func (a *FileAsync) Signal() linux.Signal { + a.mu.Lock() + defer a.mu.Unlock() + return a.signal +} + +// SetSignal overrides which signal to send when I/O is available. +// The default behavior can be reset by specifying signal zero, which means +// to send SIGIO. +func (a *FileAsync) SetSignal(signal linux.Signal) error { + if signal != 0 && !signal.IsValid() { + return syserror.EINVAL + } + a.mu.Lock() + defer a.mu.Unlock() + a.signal = signal + return nil +} diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index 470d8bf83..f17f9c59c 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -121,18 +121,21 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 panic("VFS1 and VFS2 files set") } - slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) + slicePtr := (*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) // Grow the table as required. - if last := int32(len(slice)); fd >= last { + if last := int32(len(*slicePtr)); fd >= last { end := fd + 1 if end < 2*last { end = 2 * last } - slice = append(slice, make([]unsafe.Pointer, end-last)...) - atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) + newSlice := append(*slicePtr, make([]unsafe.Pointer, end-last)...) + slicePtr = &newSlice + atomic.StorePointer(&f.slice, unsafe.Pointer(slicePtr)) } + slice := *slicePtr + var desc *descriptor if file != nil || fileVFS2 != nil { desc = &descriptor{ diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index b99c0bffa..31198d772 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -29,17 +29,17 @@ import ( ) const ( - valueMax = 32767 // SEMVMX + // Maximum semaphore value. + valueMax = linux.SEMVMX - // semaphoresMax is "maximum number of semaphores per semaphore ID" (SEMMSL). - semaphoresMax = 32000 + // Maximum number of semaphore sets. + setsMax = linux.SEMMNI - // setMax is "system-wide limit on the number of semaphore sets" (SEMMNI). - setsMax = 32000 + // Maximum number of semaphroes in a semaphore set. + semsMax = linux.SEMMSL - // semaphoresTotalMax is "system-wide limit on the number of semaphores" - // (SEMMNS = SEMMNI*SEMMSL). - semaphoresTotalMax = 1024000000 + // Maximum number of semaphores in all semaphroe sets. + semsTotalMax = linux.SEMMNS ) // Registry maintains a set of semaphores that can be found by key or ID. @@ -122,7 +122,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { // be found. If exclusive is true, it fails if a set with the same key already // exists. func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) { - if nsems < 0 || nsems > semaphoresMax { + if nsems < 0 || nsems > semsMax { return nil, syserror.EINVAL } @@ -166,7 +166,7 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu if len(r.semaphores) >= setsMax { return nil, syserror.EINVAL } - if r.totalSems() > int(semaphoresTotalMax-nsems) { + if r.totalSems() > int(semsTotalMax-nsems) { return nil, syserror.EINVAL } @@ -176,6 +176,22 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu return r.newSet(ctx, key, owner, owner, perms, nsems) } +// IPCInfo returns information about system-wide semaphore limits and parameters. +func (r *Registry) IPCInfo() *linux.SemInfo { + return &linux.SemInfo{ + SemMap: linux.SEMMAP, + SemMni: linux.SEMMNI, + SemMns: linux.SEMMNS, + SemMnu: linux.SEMMNU, + SemMsl: linux.SEMMSL, + SemOpm: linux.SEMOPM, + SemUme: linux.SEMUME, + SemUsz: 0, // SemUsz not supported. + SemVmx: linux.SEMVMX, + SemAem: linux.SEMAEM, + } +} + // RemoveID removes set with give 'id' from the registry and marks the set as // dead. All waiters will be awakened and fail. func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 7c297fb9e..d99be7f46 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -423,11 +423,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File } if f.opts.ManualZeroing { - if err := f.forEachMappingSlice(fr, func(bs []byte) { - for i := range bs { - bs[i] = 0 - } - }); err != nil { + if err := f.manuallyZero(fr); err != nil { return memmap.FileRange{}, err } } @@ -560,19 +556,39 @@ func (f *MemoryFile) Decommit(fr memmap.FileRange) error { panic(fmt.Sprintf("invalid range: %v", fr)) } + if f.opts.ManualZeroing { + // FALLOC_FL_PUNCH_HOLE may not zero pages if ManualZeroing is in + // effect. + if err := f.manuallyZero(fr); err != nil { + return err + } + } else { + if err := f.decommitFile(fr); err != nil { + return err + } + } + + f.markDecommitted(fr) + return nil +} + +func (f *MemoryFile) manuallyZero(fr memmap.FileRange) error { + return f.forEachMappingSlice(fr, func(bs []byte) { + for i := range bs { + bs[i] = 0 + } + }) +} + +func (f *MemoryFile) decommitFile(fr memmap.FileRange) error { // "After a successful call, subsequent reads from this range will // return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with // FALLOC_FL_KEEP_SIZE in mode ..." - fallocate(2) - err := syscall.Fallocate( + return syscall.Fallocate( int(f.file.Fd()), _FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE, int64(fr.Start), int64(fr.Length())) - if err != nil { - return err - } - f.markDecommitted(fr) - return nil } func (f *MemoryFile) markDecommitted(fr memmap.FileRange) { @@ -1044,20 +1060,20 @@ func (f *MemoryFile) runReclaim() { break } - if err := f.Decommit(fr); err != nil { - log.Warningf("Reclaim failed to decommit %v: %v", fr, err) - // Zero the pages manually. This won't reduce memory usage, but at - // least ensures that the pages will be zero when reallocated. - f.forEachMappingSlice(fr, func(bs []byte) { - for i := range bs { - bs[i] = 0 + // If ManualZeroing is in effect, pages will be zeroed on allocation + // and may not be freed by decommitFile, so calling decommitFile is + // unnecessary. + if !f.opts.ManualZeroing { + if err := f.decommitFile(fr); err != nil { + log.Warningf("Reclaim failed to decommit %v: %v", fr, err) + // Zero the pages manually. This won't reduce memory usage, but at + // least ensures that the pages will be zero when reallocated. + if err := f.manuallyZero(fr); err != nil { + panic(fmt.Sprintf("Reclaim failed to decommit or zero %v: %v", fr, err)) } - }) - // Pretend the pages were decommitted even though they weren't, - // since the memory accounting implementation has no idea how to - // deal with this. - f.markDecommitted(fr) + } } + f.markDecommitted(fr) f.markReclaimed(fr) } diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index f56aa3b79..571bfcc2e 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -18,8 +18,8 @@ // // In a nutshell, it works as follows: // -// The creation of a new address space creates a new child processes with a -// single thread which is traced by a single goroutine. +// The creation of a new address space creates a new child process with a single +// thread which is traced by a single goroutine. // // A context is just a collection of temporary variables. Calling Switch on a // context does the following: diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 679b287c3..5d01d21dd 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -72,7 +72,6 @@ go_library( "lib_amd64.s", "lib_arm64.go", "lib_arm64.s", - "lib_arm64_unsafe.go", "ring0.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 5f4b2105a..cf0bf3528 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -132,40 +132,6 @@ MOVD offset+PTRACE_R29(reg), R29; \ MOVD offset+PTRACE_R30(reg), R30; -// NOP-s -#define nop31Instructions() \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; - #define ESR_ELx_EC_UNKNOWN (0x00) #define ESR_ELx_EC_WFx (0x01) /* Unallocated EC: 0x02 */ @@ -764,79 +730,43 @@ TEXT ·El0_error_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) // Vectors implements exception vector table. +// The start address of exception vector table should be 11-bits aligned. +// For detail, please refer to arm developer document: +// https://developer.arm.com/documentation/100933/0100/AArch64-exception-vector-table +// Also can refer to the code in linux kernel: arch/arm64/kernel/entry.S TEXT ·Vectors(SB),NOSPLIT,$0 + PCALIGN $2048 B ·El1_sync_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_irq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_fiq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_error_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_sync(SB) - nop31Instructions() + PCALIGN $128 B ·El1_irq(SB) - nop31Instructions() + PCALIGN $128 B ·El1_fiq(SB) - nop31Instructions() + PCALIGN $128 B ·El1_error(SB) - nop31Instructions() + PCALIGN $128 B ·El0_sync(SB) - nop31Instructions() + PCALIGN $128 B ·El0_irq(SB) - nop31Instructions() + PCALIGN $128 B ·El0_fiq(SB) - nop31Instructions() + PCALIGN $128 B ·El0_error(SB) - nop31Instructions() + PCALIGN $128 B ·El0_sync_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_irq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_fiq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_error_invalid(SB) - nop31Instructions() - - // The exception-vector-table is required to be 11-bits aligned. - // Please see Linux source code as reference: arch/arm64/kernel/entry.s. - // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs. - // So that, I can safely move the 1st 2K part into the address with 11-bits alignment. - WORD $0xd503201f //nop - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 0dffd33a3..ef0d8974d 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -62,6 +62,4 @@ func DisableVFP() // Init sets function pointers based on architectural features. // // This must be called prior to using ring0. -func Init() { - rewriteVectors() -} +func Init() {} diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go deleted file mode 100644 index c05166fea..000000000 --- a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -package ring0 - -import ( - "reflect" - "syscall" - "unsafe" - - "gvisor.dev/gvisor/pkg/safecopy" - "gvisor.dev/gvisor/pkg/usermem" -) - -const ( - nopInstruction = 0xd503201f - instSize = unsafe.Sizeof(uint32(0)) - vectorsRawLen = 0x800 -) - -func unsafeSlice(addr uintptr, length int) (slice []uint32) { - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) - hdr.Data = addr - hdr.Len = length / int(instSize) - hdr.Cap = length / int(instSize) - return slice -} - -// Work around: move ring0.Vectors() into a specific address with 11-bits alignment. -// -// According to the design documentation of Arm64, -// the start address of exception vector table should be 11-bits aligned. -// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S -// But, we can't align a function's start address to a specific address by using golang. -// We have raised this question in golang community: -// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I -// This function will be removed when golang supports this feature. -// -// There are 2 jobs were implemented in this function: -// 1, move the start address of exception vector table into the specific address. -// 2, modify the offset of each instruction. -func rewriteVectors() { - vectorsBegin := reflect.ValueOf(Vectors).Pointer() - - // The exception-vector-table is required to be 11-bits aligned. - // And the size is 0x800. - // Please see the documentation as reference: - // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table - // - // But, golang does not allow to set a function's address to a specific value. - // So, for gvisor, I defined the size of exception-vector-table as 4K, - // filled the 2nd 2K part with NOP-s. - // So that, I can safely move the 1st 2K part into the address with 11-bits alignment. - // - // So, the prerequisite for this function to work correctly is: - // vectorsSafeLen >= 0x1000 - // vectorsRawLen = 0x800 - vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin) - if vectorsSafeLen < 2*vectorsRawLen { - panic("Can't update vectors") - } - - vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32 - vectorsRawLen32 := vectorsRawLen / int(instSize) - - offset := vectorsBegin & (1<<11 - 1) - if offset != 0 { - offset = 1<<11 - offset - } - - pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1) - - _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)) - if errno != 0 { - panic(errno.Error()) - } - - offset = offset / instSize // By index, not bytes. - // Move exception-vector-table into the specific address, should uses memmove here. - for i := 1; i <= vectorsRawLen32; i++ { - vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i] - } - - // Adjust branch since instruction was moved forward. - for i := 0; i < vectorsRawLen32; i++ { - if vectorsSafeTable[int(offset)+i] != nopInstruction { - vectorsSafeTable[int(offset)+i] -= uint32(offset) - } - } - - _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC)) - if errno != 0 { - panic(errno.Error()) - } -} diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index a3f775d15..cc1f6bfcc 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/tcpip", + "//pkg/tcpip/header", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e8a0103bf..5e9ab97ad 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -84,69 +84,73 @@ var Metrics = tcpip.Stats{ MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), ICMP: tcpip.ICMPStats{ - V4PacketsSent: tcpip.ICMPv4SentPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + V4: tcpip.ICMPv4Stats{ + PacketsSent: tcpip.ICMPv4SentPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), - }, - V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - V6PacketsSent: tcpip.ICMPv6SentPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + V6: tcpip.ICMPv6Stats{ + PacketsSent: tcpip.ICMPv6SentPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), - }, - V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), }, - Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), }, }, IP: tcpip.IPStats{ @@ -209,18 +213,6 @@ const sizeOfInt32 int = 4 var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) -// ntohs converts a 16-bit number from network byte order to host byte order. It -// assumes that the host is little endian. -func ntohs(v uint16) uint16 { - return v<<8 | v>>8 -} - -// htons converts a 16-bit number from host byte order to network byte order. It -// assumes that the host is little endian. -func htons(v uint16) uint16 { - return ntohs(v) -} - // commonEndpoint represents the intersection of a tcpip.Endpoint and a // transport.Endpoint. type commonEndpoint interface { @@ -240,10 +232,6 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and - // transport.Endpoint.SetSockOptBool. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -252,14 +240,14 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and - // transport.Endpoint.GetSockOpt. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 + // LastError implements tcpip.Endpoint.LastError and // transport.Endpoint.LastError. LastError() *tcpip.Error @@ -334,9 +322,7 @@ type socketOpsCommon struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } dirent := socket.NewDirent(t, netstackDevice) @@ -365,88 +351,6 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } -// AddressAndFamily reads an sockaddr struct from the given address and -// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, -// AF_INET6, and AF_PACKET addresses. -// -// AddressAndFamily returns an address and its family. -func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { - // Make sure we have at least 2 bytes for the address family. - if len(addr) < 2 { - return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument - } - - // Get the rest of the fields based on the address family. - switch family := usermem.ByteOrder.Uint16(addr); family { - case linux.AF_UNIX: - path := addr[2:] - if len(path) > linux.UnixPathMax { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - // Drop the terminating NUL (if one exists) and everything after - // it for filesystem (non-abstract) addresses. - if len(path) > 0 && path[0] != 0 { - if n := bytes.IndexByte(path[1:], 0); n >= 0 { - path = path[:n+1] - } - } - return tcpip.FullAddress{ - Addr: tcpip.Address(path), - }, family, nil - - case linux.AF_INET: - var a linux.SockAddrInet - if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - return out, family, nil - - case linux.AF_INET6: - var a linux.SockAddrInet6 - if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - if isLinkLocal(out.Addr) { - out.NIC = tcpip.NICID(a.Scope_id) - } - return out, family, nil - - case linux.AF_PACKET: - var a linux.SockAddrLink - if len(addr) < sockAddrLinkSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) - if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/173): Return protocol too. - return tcpip.FullAddress{ - NIC: tcpip.NICID(a.InterfaceIndex), - Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), - }, family, nil - - case linux.AF_UNSPEC: - return tcpip.FullAddress{}, family, nil - - default: - return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported - } -} - func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } @@ -723,11 +627,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { return nil } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { - v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return syserr.TranslateNetstackError(err) - } - if !v { + if !s.Endpoint.SocketOptions().GetV6Only() { return nil } } @@ -751,7 +651,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -832,7 +732,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } } else { var err *syserr.Error - addr, family, err = AddressAndFamily(sockaddr) + addr, family, err = socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -923,7 +823,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -1007,7 +907,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return getSockOptSocket(t, s, ep, family, skType, name, outLen) case linux.SOL_TCP: - return getSockOptTCP(t, ep, name, outLen) + return getSockOptTCP(t, s, ep, name, outLen) case linux.SOL_IPV6: return getSockOptIPv6(t, s, ep, name, outPtr, outLen) @@ -1226,8 +1126,13 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v := primitive.Int32(boolToInt32(ep.SocketOptions().GetAcceptConn())) - return &v, nil + // This option is only viable for TCP endpoints. + var v bool + if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) { + v = tcp.EndpointState(ep.State()) == tcp.StateListen + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1236,46 +1141,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.DelayOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(!v)) - return &vP, nil + v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption())) + return &v, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.CorkOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption())) + return &v, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.QuickAckOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck())) + return &v, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1449,19 +1344,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + + family, skType, _ := s.Type() + if family != linux.AF_INET6 { + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only())) + return &v, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1493,13 +1393,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) + return &v, nil case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { @@ -1511,7 +1406,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet6), nil case linux.IP6T_SO_GET_INFO: @@ -1520,7 +1415,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1540,7 +1435,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1560,7 +1455,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1582,6 +1477,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name // getSockOptIP implements GetSockOpt when level is SOL_IP. func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1624,7 +1524,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) return &a.(*linux.SockAddrInet).Addr, nil @@ -1633,13 +1533,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop())) + return &v, nil case linux.IP_TOS: // Length handling for parity with Linux. @@ -1663,26 +1558,24 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) + return &v, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo())) + return &v, nil + + case linux.IP_HDRINCL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) + return &v, nil case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { @@ -1694,7 +1587,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet), nil case linux.IPT_SO_GET_INFO: @@ -1801,7 +1694,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int return setSockOptSocket(t, s, ep, name, optVal) case linux.SOL_TCP: - return setSockOptTCP(t, ep, name, optVal) + return setSockOptTCP(t, s, ep, name, optVal) case linux.SOL_IPV6: return setSockOptIPv6(t, s, ep, name, optVal) @@ -1991,7 +1884,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } // setSockOptTCP implements SetSockOpt when level is SOL_TCP. -func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if len(optVal) < sizeOfInt32 { @@ -1999,7 +1897,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) + ep.SocketOptions().SetDelayOption(v == 0) + return nil case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -2007,7 +1906,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) + ep.SocketOptions().SetCorkOption(v != 0) + return nil case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -2015,7 +1915,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) + ep.SocketOptions().SetQuickAck(v != 0) + return nil case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -2127,14 +2028,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + + family, skType, skProto := s.Type() + if family != linux.AF_INET6 { + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } + if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { + return syserr.ErrInvalidEndpointState + } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + return syserr.ErrInvalidEndpointState + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) + ep.SocketOptions().SetV6Only(v != 0) + return nil case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, @@ -2173,7 +2091,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + ep.SocketOptions().SetReceiveTClass(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2181,7 +2100,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return syserr.ErrProtocolNotAvailable } @@ -2256,6 +2175,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { // setSockOptIP implements SetSockOpt when level is SOL_IP. func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_MULTICAST_TTL: v, err := parseIntOrChar(optVal) @@ -2308,7 +2232,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{ NIC: tcpip.NICID(req.InterfaceIndex), - InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]), + InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]), })) case linux.IP_MULTICAST_LOOP: @@ -2317,7 +2241,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) + ep.SocketOptions().SetMulticastLoop(v != 0) + return nil case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2353,7 +2278,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + ep.SocketOptions().SetReceiveTOS(v != 0) + return nil case linux.IP_PKTINFO: if len(optVal) == 0 { @@ -2363,7 +2289,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + ep.SocketOptions().SetReceivePacketInfo(v != 0) + return nil case linux.IP_HDRINCL: if len(optVal) == 0 { @@ -2373,7 +2300,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + ep.SocketOptions().SetHeaderIncluded(v != 0) + return nil case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { @@ -2515,7 +2443,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { switch name { case linux.IP_TOS, linux.IP_TTL, - linux.IP_HDRINCL, linux.IP_OPTIONS, linux.IP_ROUTER_ALERT, linux.IP_RECVOPTS, @@ -2562,72 +2489,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { } } -// isLinkLocal determines if the given IPv6 address is link-local. This is the -// case when it has the fe80::/10 prefix. This check is used to determine when -// the NICID is relevant for a given IPv6 address. -func isLinkLocal(addr tcpip.Address) bool { - return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 -} - -// ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { - switch family { - case linux.AF_UNIX: - var out linux.SockAddrUnix - out.Family = linux.AF_UNIX - l := len([]byte(addr.Addr)) - for i := 0; i < l; i++ { - out.Path[i] = int8(addr.Addr[i]) - } - - // Linux returns the used length of the address struct (including the - // null terminator) for filesystem paths. The Family field is 2 bytes. - // It is sometimes allowed to exclude the null terminator if the - // address length is the max. Abstract and empty paths always return - // the full exact length. - if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return &out, uint32(2 + l) - } - return &out, uint32(3 + l) - - case linux.AF_INET: - var out linux.SockAddrInet - copy(out.Addr[:], addr.Addr) - out.Family = linux.AF_INET - out.Port = htons(addr.Port) - return &out, uint32(sockAddrInetSize) - - case linux.AF_INET6: - var out linux.SockAddrInet6 - if len(addr.Addr) == header.IPv4AddressSize { - // Copy address in v4-mapped format. - copy(out.Addr[12:], addr.Addr) - out.Addr[10] = 0xff - out.Addr[11] = 0xff - } else { - copy(out.Addr[:], addr.Addr) - } - out.Family = linux.AF_INET6 - out.Port = htons(addr.Port) - if isLinkLocal(addr.Addr) { - out.Scope_id = uint32(addr.NIC) - } - return &out, uint32(sockAddrInet6Size) - - case linux.AF_PACKET: - // TODO(gvisor.dev/issue/173): Return protocol too. - var out linux.SockAddrLink - out.Family = linux.AF_PACKET - out.InterfaceIndex = int32(addr.NIC) - out.HardwareAddrLen = header.EthernetAddressSize - copy(out.HardwareAddr[:], addr.Addr) - return &out, uint32(sockAddrLinkSize) - - default: - return nil, 0 - } -} - // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { @@ -2636,7 +2497,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2648,7 +2509,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2797,10 +2658,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { - addr, addrLen = ConvertAddress(s.family, s.sender) + addr, addrLen = socket.ConvertAddress(s.family, s.sender) switch v := addr.(type) { case *linux.SockAddrLink: - v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol)) v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) } } @@ -2965,7 +2826,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, family, err := AddressAndFamily(to) + addrBuf, family, err := socket.AddressAndFamily(to) if err != nil { return 0, err } @@ -3384,6 +3245,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { return rv } +func isTCPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP) +} + +func isUDPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP) +} + +func isICMPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6) +} + // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. func (s *socketOpsCommon) State() uint32 { @@ -3393,7 +3266,7 @@ func (s *socketOpsCommon) State() uint32 { } switch { - case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: + case isTCPSocket(s.skType, s.protocol): // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -3422,7 +3295,7 @@ func (s *socketOpsCommon) State() uint32 { // Internal or unknown state. return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + case isUDPSocket(s.skType, s.protocol): // UDP socket. switch udp.EndpointState(s.Endpoint.State()) { case udp.StateInitial, udp.StateBound, udp.StateClosed: @@ -3432,7 +3305,7 @@ func (s *socketOpsCommon) State() uint32 { default: return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + case isICMPSocket(s.skType, s.protocol): // TODO(b/112063468): Export states for ICMP sockets. case s.skType == linux.SOCK_RAW: // TODO(b/112063468): Export states for raw sockets. diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index b0d9e4d9e..b756bfca0 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -51,9 +51,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // NewVFS2 creates a new endpoint socket. func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } mnt := t.Kernel().SocketMount() @@ -191,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addrLen uint32 if peerAddr != nil { // Get address of the peer and write it to peer slice. - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index ead3b2b79..c847ff1c7 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go index 2a01143f6..0af805246 100644 --- a/pkg/sentry/socket/netstack/provider_vfs2.go +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index fa9ac9059..cc0fadeb5 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -324,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { 0, // Support Ip/FragCreates. } case *inet.StatSNMPICMP: - in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats - out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPICMP{ 0, // Icmp/InMsgs. - Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors. 0, // Icmp/InCsumErrors. in.DstUnreachable.Value(), // InDestUnreachs. in.TimeExceeded.Value(), // InTimeExcds. @@ -343,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { in.InfoRequest.Value(), // InAddrMasks. in.InfoReply.Value(), // InAddrMaskReps. 0, // Icmp/OutMsgs. - Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. - out.DstUnreachable.Value(), // OutDestUnreachs. - out.TimeExceeded.Value(), // OutTimeExcds. - out.ParamProblem.Value(), // OutParmProbs. - out.SrcQuench.Value(), // OutSrcQuenchs. - out.Redirect.Value(), // OutRedirects. - out.Echo.Value(), // OutEchos. - out.EchoReply.Value(), // OutEchoReps. - out.Timestamp.Value(), // OutTimestamps. - out.TimestampReply.Value(), // OutTimestampReps. - out.InfoRequest.Value(), // OutAddrMasks. - out.InfoReply.Value(), // OutAddrMaskReps. + Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. } case *inet.StatSNMPTCP: tcp := Metrics.TCP diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fd31479e5..9049e8a21 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -18,6 +18,7 @@ package socket import ( + "bytes" "fmt" "sync/atomic" "syscall" @@ -35,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/usermem" ) @@ -460,3 +462,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { panic(fmt.Sprintf("Unsupported socket family %v", family)) } } + +var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes() +var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes() +var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes() + +// Ntohs converts a 16-bit number from network byte order to host byte order. It +// assumes that the host is little endian. +func Ntohs(v uint16) uint16 { + return v<<8 | v>>8 +} + +// Htons converts a 16-bit number from host byte order to network byte order. It +// assumes that the host is little endian. +func Htons(v uint16) uint16 { + return Ntohs(v) +} + +// isLinkLocal determines if the given IPv6 address is link-local. This is the +// case when it has the fe80::/10 prefix. This check is used to determine when +// the NICID is relevant for a given IPv6 address. +func isLinkLocal(addr tcpip.Address) bool { + return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 +} + +// ConvertAddress converts the given address to a native format. +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { + switch family { + case linux.AF_UNIX: + var out linux.SockAddrUnix + out.Family = linux.AF_UNIX + l := len([]byte(addr.Addr)) + for i := 0; i < l; i++ { + out.Path[i] = int8(addr.Addr[i]) + } + + // Linux returns the used length of the address struct (including the + // null terminator) for filesystem paths. The Family field is 2 bytes. + // It is sometimes allowed to exclude the null terminator if the + // address length is the max. Abstract and empty paths always return + // the full exact length. + if l == 0 || out.Path[0] == 0 || l == len(out.Path) { + return &out, uint32(2 + l) + } + return &out, uint32(3 + l) + + case linux.AF_INET: + var out linux.SockAddrInet + copy(out.Addr[:], addr.Addr) + out.Family = linux.AF_INET + out.Port = Htons(addr.Port) + return &out, uint32(sockAddrInetSize) + + case linux.AF_INET6: + var out linux.SockAddrInet6 + if len(addr.Addr) == header.IPv4AddressSize { + // Copy address in v4-mapped format. + copy(out.Addr[12:], addr.Addr) + out.Addr[10] = 0xff + out.Addr[11] = 0xff + } else { + copy(out.Addr[:], addr.Addr) + } + out.Family = linux.AF_INET6 + out.Port = Htons(addr.Port) + if isLinkLocal(addr.Addr) { + out.Scope_id = uint32(addr.NIC) + } + return &out, uint32(sockAddrInet6Size) + + case linux.AF_PACKET: + // TODO(gvisor.dev/issue/173): Return protocol too. + var out linux.SockAddrLink + out.Family = linux.AF_PACKET + out.InterfaceIndex = int32(addr.NIC) + out.HardwareAddrLen = header.EthernetAddressSize + copy(out.HardwareAddr[:], addr.Addr) + return &out, uint32(sockAddrLinkSize) + + default: + return nil, 0 + } +} + +// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the +// netstack representation taking any addresses into account. +func BytesToIPAddress(addr []byte) tcpip.Address { + if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) { + return "" + } + return tcpip.Address(addr) +} + +// AddressAndFamily reads an sockaddr struct from the given address and +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, +// AF_INET6, and AF_PACKET addresses. +// +// AddressAndFamily returns an address and its family. +func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { + // Make sure we have at least 2 bytes for the address family. + if len(addr) < 2 { + return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument + } + + // Get the rest of the fields based on the address family. + switch family := usermem.ByteOrder.Uint16(addr); family { + case linux.AF_UNIX: + path := addr[2:] + if len(path) > linux.UnixPathMax { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + // Drop the terminating NUL (if one exists) and everything after + // it for filesystem (non-abstract) addresses. + if len(path) > 0 && path[0] != 0 { + if n := bytes.IndexByte(path[1:], 0); n >= 0 { + path = path[:n+1] + } + } + return tcpip.FullAddress{ + Addr: tcpip.Address(path), + }, family, nil + + case linux.AF_INET: + var a linux.SockAddrInet + if len(addr) < sockAddrInetSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + return out, family, nil + + case linux.AF_INET6: + var a linux.SockAddrInet6 + if len(addr) < sockAddrInet6Size { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + if isLinkLocal(out.Addr) { + out.NIC = tcpip.NICID(a.Scope_id) + } + return out, family, nil + + case linux.AF_PACKET: + var a linux.SockAddrLink + if len(addr) < sockAddrLinkSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) + if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/173): Return protocol too. + return tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + }, family, nil + + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, family, nil + + default: + return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported + } +} diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 4abea90cc..0247e93fa 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -178,10 +178,6 @@ type Endpoint interface { // SetSockOpt sets a socket option. SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool sets a socket option for simple cases when a value has - // the int type. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt sets a socket option for simple cases when a value has // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -189,10 +185,6 @@ type Endpoint interface { // GetSockOpt gets a socket option. GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool gets a socket option for simple cases when a return - // value has the int type. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) @@ -857,11 +849,6 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { return nil } -func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - log.Warningf("Unsupported socket option: %d", opt) - return nil -} - func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { case tcpip.SendBufferSizeOption: @@ -872,11 +859,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - log.Warningf("Unsupported socket option: %d", opt) - return false, tcpip.ErrUnknownProtocolOption -} - func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index b32bb7ba8..c59297c80 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -136,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, family, err := netstack.AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { if err == syserr.ErrAddressFamilyNotSupported { err = syserr.ErrInvalidArgument @@ -169,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -181,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -255,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -647,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -682,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index eaf0b0d26..27f705bb2 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -172,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index a920180d3..d36a64ffc 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -32,8 +32,8 @@ go_library( "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/kernel", + "//pkg/sentry/socket", "//pkg/sentry/socket/netlink", - "//pkg/sentry/socket/netstack", "//pkg/sentry/syscalls/linux", "//pkg/usermem", ], diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index cc5f70cd4..d943a7cb1 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" - "gvisor.dev/gvisor/pkg/sentry/socket/netstack" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/usermem" ) @@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, _, err := netstack.AddressAndFamily(b) + fa, _, err := socket.AddressAndFamily(b) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index bb1f715e2..cff442846 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 519066a47..8db587401 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -646,7 +646,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } - fSetOwn(t, file, set) + fSetOwn(t, int(fd), file, set) return 0, nil, nil case linux.FIOGETOWN, linux.SIOCGPGRP: @@ -901,8 +901,8 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 { // // If who is positive, it represents a PID. If negative, it represents a PGID. // If the PID or PGID is invalid, the owner is silently unset. -func fSetOwn(t *kernel.Task, file *fs.File, who int32) error { - a := file.Async(fasync.New).(*fasync.FileAsync) +func fSetOwn(t *kernel.Task, fd int, file *fs.File, who int32) error { + a := file.Async(fasync.New(fd)).(*fasync.FileAsync) if who < 0 { // Check for overflow before flipping the sign. if who-1 > who { @@ -1049,7 +1049,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_GETOWN: return uintptr(fGetOwn(t, file)), nil, nil case linux.F_SETOWN: - return 0, nil, fSetOwn(t, file, args[2].Int()) + return 0, nil, fSetOwn(t, int(fd), file, args[2].Int()) case linux.F_GETOWN_EX: addr := args[2].Pointer() owner := fGetOwnEx(t, file) @@ -1062,7 +1062,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - a := file.Async(fasync.New).(*fasync.FileAsync) + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) switch owner.Type { case linux.F_OWNER_TID: task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID)) @@ -1111,6 +1111,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } n, err := sz.SetFifoSize(int64(args[2].Int())) return uintptr(n), nil, err + case linux.F_GETSIG: + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) + return uintptr(a.Signal()), nil, nil + case linux.F_SETSIG: + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) + return 0, nil, a.SetSignal(linux.Signal(args[2].Int())) default: // Everything else is not yet supported. return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index e383a0a87..a1601676f 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -146,8 +146,15 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal v, err := getNCnt(t, id, num) return uintptr(v), nil, err - case linux.IPC_INFO, - linux.SEM_INFO, + case linux.IPC_INFO: + buf := args[3].Pointer() + r := t.IPCNamespace().SemaphoreRegistry() + info := r.IPCInfo() + _, err := info.CopyOut(t, buf) + // TODO(gvisor.dev/issue/137): Return the index of the highest used entry. + return 0, nil, err + + case linux.SEM_INFO, linux.SEM_STAT, linux.SEM_STAT_ANY: diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 36e89700e..7dd9ef857 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -165,7 +165,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ownerType = linux.F_OWNER_PGRP who = -who } - return 0, nil, setAsyncOwner(t, file, ownerType, who) + return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who) case linux.F_GETOWN_EX: owner, hasOwner := getAsyncOwner(t, file) if !hasOwner { @@ -179,7 +179,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID) + return 0, nil, setAsyncOwner(t, int(fd), file, owner.Type, owner.PID) case linux.F_SETPIPE_SZ: pipefile, ok := file.Impl().(*pipe.VFSPipeFD) if !ok { @@ -207,6 +207,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err case linux.F_SETLK, linux.F_SETLKW: return 0, nil, posixLock(t, args, file, cmd) + case linux.F_GETSIG: + a := file.AsyncHandler() + if a == nil { + // Default behavior aka SIGIO. + return 0, nil, nil + } + return uintptr(a.(*fasync.FileAsync).Signal()), nil, nil + case linux.F_SETSIG: + a := file.SetAsyncHandler(fasync.NewVFS2(int(fd))).(*fasync.FileAsync) + return 0, nil, a.SetSignal(linux.Signal(args[2].Int())) default: // Everything else is not yet supported. return 0, nil, syserror.EINVAL @@ -241,7 +251,7 @@ func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwne } } -func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error { +func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, pid int32) error { switch ownerType { case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP: // Acceptable type. @@ -249,7 +259,7 @@ func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32 return syserror.EINVAL } - a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync) + a := file.SetAsyncHandler(fasync.NewVFS2(fd)).(*fasync.FileAsync) if pid == 0 { a.ClearOwner() return nil diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 2806c3f6f..20c264fef 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -100,7 +100,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ownerType = linux.F_OWNER_PGRP who = -who } - return 0, nil, setAsyncOwner(t, file, ownerType, who) + return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who) } ret, err := file.Ioctl(t, t.MemoryManager(), args) diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index a98aac52b..072655fe8 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -204,8 +204,8 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin file.EventRegister(&epi.waiter, wmask) // Check if the file is already ready. - if file.Readiness(wmask)&wmask != 0 { - epi.Callback(nil) + if m := file.Readiness(wmask) & wmask; m != 0 { + epi.Callback(nil, m) } // Add epi to file.epolls so that it is removed when the last @@ -274,8 +274,8 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event file.EventRegister(&epi.waiter, wmask) // Check if the file is already ready with the new mask. - if file.Readiness(wmask)&wmask != 0 { - epi.Callback(nil) + if m := file.Readiness(wmask) & wmask; m != 0 { + epi.Callback(nil, m) } return nil @@ -311,7 +311,7 @@ func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error } // Callback implements waiter.EntryCallback.Callback. -func (epi *epollInterest) Callback(*waiter.Entry) { +func (epi *epollInterest) Callback(*waiter.Entry, waiter.EventMask) { newReady := false epi.epoll.mu.Lock() if !epi.ready { diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index f9e39a94c..5321ac80a 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -15,6 +15,7 @@ package vfs import ( + "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -43,7 +44,7 @@ import ( type FileDescription struct { FileDescriptionRefs - // flagsMu protects statusFlags and asyncHandler below. + // flagsMu protects `statusFlags`, `saved`, and `asyncHandler` below. flagsMu sync.Mutex `state:"nosave"` // statusFlags contains status flags, "initialized by open(2) and possibly @@ -52,6 +53,11 @@ type FileDescription struct { // access to asyncHandler. statusFlags uint32 + // saved is true after beforeSave is called. This is used to prevent + // double-unregistration of asyncHandler. This does not work properly for + // save-resume, which is not currently supported in gVisor (see b/26588733). + saved bool `state:"nosave"` + // asyncHandler handles O_ASYNC signal generation. It is set with the // F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must // also be set by fcntl(2). @@ -184,7 +190,7 @@ func (fd *FileDescription) DecRef(ctx context.Context) { } fd.vd.DecRef(ctx) fd.flagsMu.Lock() - if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + if !fd.saved && fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { fd.asyncHandler.Unregister(fd) } fd.asyncHandler = nil @@ -834,44 +840,27 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn return fd.asyncHandler } -// FileReadWriteSeeker is a helper struct to pass a FileDescription as -// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. -type FileReadWriteSeeker struct { - FD *FileDescription - Ctx context.Context - ROpts ReadOptions - WOpts WriteOptions -} - -// ReadAt implements io.ReaderAt.ReadAt. -func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts) - return int(n), err -} - -// Read implements io.ReadWriteSeeker.Read. -func (f *FileReadWriteSeeker) Read(p []byte) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.Read(f.Ctx, dst, f.ROpts) - return int(n), err -} - -// Seek implements io.ReadWriteSeeker.Seek. -func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) { - return f.FD.Seek(f.Ctx, offset, int32(whence)) -} - -// WriteAt implements io.WriterAt.WriteAt. -func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts) - return int(n), err -} - -// Write implements io.ReadWriteSeeker.Write. -func (f *FileReadWriteSeeker) Write(p []byte) (int, error) { - buf := usermem.BytesIOSequence(p) - n, err := f.FD.Write(f.Ctx, buf, f.WOpts) - return int(n), err +// CopyRegularFileData copies data from srcFD to dstFD until reading from srcFD +// returns EOF or an error. It returns the number of bytes copied. +func CopyRegularFileData(ctx context.Context, dstFD, srcFD *FileDescription) (int64, error) { + done := int64(0) + buf := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size + for { + readN, readErr := srcFD.Read(ctx, buf, ReadOptions{}) + if readErr != nil && readErr != io.EOF { + return done, readErr + } + src := buf.TakeFirst64(readN) + for src.NumBytes() != 0 { + writeN, writeErr := dstFD.Write(ctx, src, WriteOptions{}) + done += writeN + src = src.DropFirst64(writeN) + if writeErr != nil { + return done, writeErr + } + } + if readErr == io.EOF { + return done, nil + } + } } diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go index 7723ed643..8f070ed53 100644 --- a/pkg/sentry/vfs/save_restore.go +++ b/pkg/sentry/vfs/save_restore.go @@ -18,8 +18,10 @@ import ( "fmt" "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/waiter" ) // FilesystemImplSaveRestoreExtension is an optional extension to @@ -120,5 +122,20 @@ func (mnt *Mount) afterLoad() { func (epi *epollInterest) afterLoad() { // Mark all epollInterests as ready after restore so that the next call to // EpollInstance.ReadEvents() rechecks their readiness. - epi.Callback(nil) + epi.Callback(nil, waiter.EventMaskFromLinux(epi.mask)) +} + +// beforeSave is called by stateify. +func (fd *FileDescription) beforeSave() { + fd.saved = true + if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + fd.asyncHandler.Unregister(fd) + } +} + +// afterLoad is called by stateify. +func (fd *FileDescription) afterLoad() { + if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + fd.asyncHandler.Register(fd) + } } |