Diffstat (limited to 'pkg/sentry')
124 files changed, 2343 insertions, 1792 deletions
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index d75d665ae..dd2effdf9 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -365,3 +365,18 @@ func (a SyscallArgument) SizeT() uint { func (a SyscallArgument) ModeT() uint { return uint(uint16(a.Value)) } + +// ErrFloatingPoint indicates a failed restore due to unusable floating point +// state. +type ErrFloatingPoint struct { + // supported is the supported floating point state. + supported uint64 + + // saved is the saved floating point state. + saved uint64 +} + +// Error returns a sensible description of the restore error. +func (e ErrFloatingPoint) Error() string { + return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) +} diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go index 19ce99d25..840e53d33 100644 --- a/pkg/sentry/arch/arch_state_x86.go +++ b/pkg/sentry/arch/arch_state_x86.go @@ -17,27 +17,10 @@ package arch import ( - "fmt" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/usermem" ) -// ErrFloatingPoint indicates a failed restore due to unusable floating point -// state. -type ErrFloatingPoint struct { - // supported is the supported floating point state. - supported uint64 - - // saved is the saved floating point state. - saved uint64 -} - -// Error returns a sensible description of the restore error. -func (e ErrFloatingPoint) Error() string { - return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) -} - // XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87 // and SSE state, so this is the equivalent XSTATE_BV value. const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index c9fb55d00..35d2e07c3 100644 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go @@ -152,23 +152,23 @@ func (s *SignalInfo) FixSignalCodeForUser() { } } -// Pid returns the si_pid field. -func (s *SignalInfo) Pid() int32 { +// PID returns the si_pid field. +func (s *SignalInfo) PID() int32 { return int32(usermem.ByteOrder.Uint32(s.Fields[0:4])) } -// SetPid mutates the si_pid field. -func (s *SignalInfo) SetPid(val int32) { +// SetPID mutates the si_pid field. +func (s *SignalInfo) SetPID(val int32) { usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) } -// Uid returns the si_uid field. -func (s *SignalInfo) Uid() int32 { +// UID returns the si_uid field. +func (s *SignalInfo) UID() int32 { return int32(usermem.ByteOrder.Uint32(s.Fields[4:8])) } -// SetUid mutates the si_uid field. -func (s *SignalInfo) SetUid(val int32) { +// SetUID mutates the si_uid field. +func (s *SignalInfo) SetUID(val int32) { usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) } @@ -251,3 +251,26 @@ func (s *SignalInfo) Arch() uint32 { func (s *SignalInfo) SetArch(val uint32) { usermem.ByteOrder.PutUint32(s.Fields[12:16], val) } + +// Band returns the si_band field. +func (s *SignalInfo) Band() int64 { + return int64(usermem.ByteOrder.Uint64(s.Fields[0:8])) +} + +// SetBand mutates the si_band field. +func (s *SignalInfo) SetBand(val int64) { + // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`. + // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is + // `int`. See siginfo.h. + usermem.ByteOrder.PutUint64(s.Fields[0:8], uint64(val)) +} + +// FD returns the si_fd field. 
+func (s *SignalInfo) FD() uint32 { + return usermem.ByteOrder.Uint32(s.Fields[8:12]) +} + +// SetFD mutates the si_fd field. +func (s *SignalInfo) SetFD(val uint32) { + usermem.ByteOrder.PutUint32(s.Fields[8:12], val) +} diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 2bf3c45e1..91b8fb44f 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -193,7 +193,7 @@ func (p *Profile) StopTrace(_, _ *struct{}) error { defer p.mu.Unlock() if p.traceFile == nil { - return errors.New("Execution tracing not started") + return errors.New("execution tracing not started") } // Similarly to the case above, if tasks have not ended traces, we will diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go index d800f2c85..62eaca965 100644 --- a/pkg/sentry/control/state.go +++ b/pkg/sentry/control/state.go @@ -62,6 +62,7 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error { Callback: func(err error) { if err == nil { log.Infof("Save succeeded: exiting...") + s.Kernel.SetSaveSuccess(false /* autosave */) } else { log.Warningf("Save failed: exiting...") s.Kernel.SetSaveError(err) diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go index 314661475..badd5b073 100644 --- a/pkg/sentry/fdimport/fdimport.go +++ b/pkg/sentry/fdimport/fdimport.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package fdimport provides the Import function. package fdimport import ( diff --git a/pkg/sentry/fs/BUILD b/pkg/sentry/fs/BUILD index ea85ab33c..5c3e852e9 100644 --- a/pkg/sentry/fs/BUILD +++ b/pkg/sentry/fs/BUILD @@ -49,13 +49,13 @@ go_library( "//pkg/amutex", "//pkg/context", "//pkg/log", - "//pkg/metric", "//pkg/p9", "//pkg/refs", "//pkg/secio", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs/lock", + "//pkg/sentry/fsmetric", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index ff2fe6712..8e0aa9019 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -336,7 +336,12 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err // copyUpBuffers is a buffer pool for copying file content. The buffer // size is the same used by io.Copy. -var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }} +var copyUpBuffers = sync.Pool{ + New: func() interface{} { + b := make([]byte, 8*usermem.PageSize) + return &b + }, +} // copyContentsLocked copies the contents of lower to upper. It panics if // less than size bytes can be copied. @@ -361,7 +366,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in defer lowerFile.DecRef(ctx) // Use a buffer pool to minimize allocations. - buf := copyUpBuffers.Get().([]byte) + buf := copyUpBuffers.Get().(*[]byte) defer copyUpBuffers.Put(buf) // Transfer the contents. @@ -371,7 +376,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in // optimizations could be self-defeating. So we leave this as simple as possible. 
var offset int64 for { - nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset) + nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(*buf), offset) if err != nil && err != io.EOF { return err } @@ -383,7 +388,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in } return nil } - nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset) + nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence((*buf)[:nr]), offset) if err != nil { return err } diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index c7a11eec1..e04784db2 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -64,7 +64,7 @@ func TestConcurrentCopyUp(t *testing.T) { wg.Add(1) go func(o *overlayTestFile) { if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil { - t.Fatalf("failed to copy up: %v", err) + t.Errorf("failed to copy up: %v", err) } wg.Done() }(file) diff --git a/pkg/sentry/fs/file.go b/pkg/sentry/fs/file.go index 72ea70fcf..57f904801 100644 --- a/pkg/sentry/fs/file.go +++ b/pkg/sentry/fs/file.go @@ -17,13 +17,12 @@ package fs import ( "math" "sync/atomic" - "time" "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/limits" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/uniqueid" @@ -33,28 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -var ( - // RecordWaitTime controls writing metrics for filesystem reads. - // Enabling this comes at a small CPU cost due to performing two - // monotonic clock reads per read call. - // - // Note that this is only performed in the direct read path, and may - // not be consistently applied for other forms of reads, such as - // splice. - RecordWaitTime = false - - reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.") - readWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.") -) - -// IncrementWait increments the given wait time metric, if enabled. -func IncrementWait(m *metric.Uint64Metric, start time.Time) { - if !RecordWaitTime { - return - } - m.IncrementBy(uint64(time.Since(start))) -} - // FileMaxOffset is the maximum possible file offset. const FileMaxOffset = math.MaxInt64 @@ -257,22 +234,19 @@ func (f *File) Readdir(ctx context.Context, serializer DentrySerializer) error { // // Returns syserror.ErrInterrupted if reading was interrupted. func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error) { - var start time.Time - if RecordWaitTime { - start = time.Now() - } + start := fsmetric.StartReadWait() + defer fsmetric.FinishReadWait(fsmetric.ReadWait, start) + if !f.mu.Lock(ctx) { - IncrementWait(readWait, start) return 0, syserror.ErrInterrupted } - reads.Increment() + fsmetric.Reads.Increment() n, err := f.FileOperations.Read(ctx, f, dst, f.offset) if n > 0 && !f.flags.NonSeekable { atomic.AddInt64(&f.offset, n) } f.mu.Unlock() - IncrementWait(readWait, start) return n, err } @@ -282,19 +256,16 @@ func (f *File) Readv(ctx context.Context, dst usermem.IOSequence) (int64, error) // // Otherwise same as Readv. 
func (f *File) Preadv(ctx context.Context, dst usermem.IOSequence, offset int64) (int64, error) { - var start time.Time - if RecordWaitTime { - start = time.Now() - } + start := fsmetric.StartReadWait() + defer fsmetric.FinishReadWait(fsmetric.ReadWait, start) + if !f.mu.Lock(ctx) { - IncrementWait(readWait, start) return 0, syserror.ErrInterrupted } - reads.Increment() + fsmetric.Reads.Increment() n, err := f.FileOperations.Read(ctx, f, dst, offset) f.mu.Unlock() - IncrementWait(readWait, start) return n, err } diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go index 8049538f2..ec3d3f96c 100644 --- a/pkg/sentry/fs/filetest/filetest.go +++ b/pkg/sentry/fs/filetest/filetest.go @@ -52,10 +52,10 @@ func NewTestFile(tb testing.TB) *fs.File { // Read just fails the request. func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Readv not implemented") + return 0, fmt.Errorf("TestFileOperations.Read not implemented") } // Write just fails the request. func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Writev not implemented") + return 0, fmt.Errorf("TestFileOperations.Write not implemented") } diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index d2dbff268..a020da53b 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -65,7 +65,7 @@ var ( // runs with the lock held for reading. AsyncBarrier will take the lock // for writing, thus ensuring that all Async work completes before // AsyncBarrier returns. - workMu sync.RWMutex + workMu sync.CrossGoroutineRWMutex // asyncError is used to store up to one asynchronous execution error. asyncError = make(chan error, 1) diff --git a/pkg/sentry/fs/gofer/BUILD b/pkg/sentry/fs/gofer/BUILD index fea135eea..4c30098cd 100644 --- a/pkg/sentry/fs/gofer/BUILD +++ b/pkg/sentry/fs/gofer/BUILD @@ -28,7 +28,6 @@ go_library( "//pkg/context", "//pkg/fd", "//pkg/log", - "//pkg/metric", "//pkg/p9", "//pkg/refs", "//pkg/safemem", @@ -38,6 +37,7 @@ go_library( "//pkg/sentry/fs/fdpipe", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/host", + "//pkg/sentry/fsmetric", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go index d481baf77..e5579095b 100644 --- a/pkg/sentry/fs/gofer/attr.go +++ b/pkg/sentry/fs/gofer/attr.go @@ -117,8 +117,6 @@ func ntype(pattr p9.Attr) fs.InodeType { return fs.BlockDevice case pattr.Mode.IsSocket(): return fs.Socket - case pattr.Mode.IsRegular(): - fallthrough default: return fs.RegularFile } diff --git a/pkg/sentry/fs/gofer/file.go b/pkg/sentry/fs/gofer/file.go index c0bc63a32..bb63448cb 100644 --- a/pkg/sentry/fs/gofer/file.go +++ b/pkg/sentry/fs/gofer/file.go @@ -21,27 +21,17 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) -var ( - opensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a writable+executable file was opened from a gofer.") - opens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of 
times a 9P file was opened from a gofer.") - opensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a host file was opened from a gofer.") - reads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.") - readWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.") - readsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.") - readWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.") -) - // fileOperations implements fs.FileOperations for a remote file system. // // +stateify savable @@ -101,14 +91,14 @@ func NewFile(ctx context.Context, dirent *fs.Dirent, name string, flags fs.FileF } if flags.Write { if err := dirent.Inode.CheckPermission(ctx, fs.PermMask{Execute: true}); err == nil { - opensWX.Increment() + fsmetric.GoferOpensWX.Increment() log.Warningf("Opened a writable executable: %q", name) } } if handles.Host != nil { - opensHost.Increment() + fsmetric.GoferOpensHost.Increment() } else { - opens9P.Increment() + fsmetric.GoferOpens9P.Increment() } return fs.NewFile(ctx, dirent, flags, f) } @@ -278,20 +268,17 @@ func (f *fileOperations) Write(ctx context.Context, file *fs.File, src usermem.I // use this function rather than using a defer in Read() to avoid the performance hit of defer. func (f *fileOperations) incrementReadCounters(start time.Time) { if f.handles.Host != nil { - readsHost.Increment() - fs.IncrementWait(readWaitHost, start) + fsmetric.GoferReadsHost.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start) } else { - reads9P.Increment() - fs.IncrementWait(readWait9P, start) + fsmetric.GoferReads9P.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start) } } // Read implements fs.FileOperations.Read. func (f *fileOperations) Read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { - var start time.Time - if fs.RecordWaitTime { - start = time.Now() - } + start := fsmetric.StartReadWait() if fs.IsDir(file.Dirent.Inode.StableAttr) { // Not all remote file systems enforce this so this client does. f.incrementReadCounters(start) diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 3a225fd39..9d6fdd08f 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -117,7 +117,7 @@ type inodeFileState struct { // loading is acquired when the inodeFileState begins an asynchronous // load. It releases when the load is complete. Callers that require all // state to be available should call waitForLoad() to ensure that. - loading sync.Mutex `state:".(struct{})"` + loading sync.CrossGoroutineMutex `state:".(struct{})"` // savedUAttr is only allocated during S/R. It points to the save-time // unstable attributes and is used to validate restore-time ones. 
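The fs/copy_up.go change above switches copyUpBuffers from storing []byte values to *[]byte pointers. This follows a common Go idiom (flagged by staticcheck as SA6002): passing a slice value to sync.Pool.Put boxes the three-word slice header into an interface{}, which allocates on every Put, whereas a pointer converts to interface{} without allocating. A minimal standalone sketch of the pattern; names are illustrative, not gVisor's:

```go
package main

import (
	"fmt"
	"sync"
)

// bufPool stores *[]byte rather than []byte: converting a slice value to
// interface{} for Put copies the slice header to the heap on every call,
// while a pointer converts without allocating.
var bufPool = sync.Pool{
	New: func() interface{} {
		b := make([]byte, 32*1024)
		return &b
	},
}

func copyChunked(dst, src []byte) int {
	bufp := bufPool.Get().(*[]byte)
	defer bufPool.Put(bufp)
	buf := *bufp

	n := 0
	for n < len(src) && n < len(dst) {
		c := copy(buf, src[n:])
		c = copy(dst[n:n+c], buf[:c])
		n += c
	}
	return n
}

func main() {
	src := make([]byte, 100*1024)
	dst := make([]byte, len(src))
	fmt.Println(copyChunked(dst, src)) // 102400
}
```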
diff --git a/pkg/sentry/fs/inode.go b/pkg/sentry/fs/inode.go index 004910453..9b3d8166a 100644 --- a/pkg/sentry/fs/inode.go +++ b/pkg/sentry/fs/inode.go @@ -18,9 +18,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" @@ -28,8 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -var opens = metric.MustCreateNewUint64Metric("/fs/opens", false /* sync */, "Number of file opens.") - // Inode is a file system object that can be simultaneously referenced by different // components of the VFS (Dirent, fs.File, etc). // @@ -247,7 +245,7 @@ func (i *Inode) GetFile(ctx context.Context, d *Dirent, flags FileFlags) (*File, if i.overlay != nil { return overlayGetFile(ctx, i.overlay, d, flags) } - opens.Increment() + fsmetric.Opens.Increment() return i.InodeOperations.GetFile(ctx, d, flags) } diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go index f8aad2dbd..b998fb75d 100644 --- a/pkg/sentry/fs/proc/sys.go +++ b/pkg/sentry/fs/proc/sys.go @@ -84,6 +84,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode children := map[string]*fs.Inode{ "hostname": newProcInode(ctx, &h, msrc, fs.SpecialFile, nil), + "sem": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))), "shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))), "shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))), "shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))), diff --git a/pkg/sentry/fs/tmpfs/BUILD b/pkg/sentry/fs/tmpfs/BUILD index aa7199014..b521a86a2 100644 --- a/pkg/sentry/fs/tmpfs/BUILD +++ b/pkg/sentry/fs/tmpfs/BUILD @@ -15,12 +15,12 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", - "//pkg/metric", "//pkg/safemem", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/ramfs", + "//pkg/sentry/fsmetric", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/pipe", diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index d6c65301c..e04cd608d 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -18,14 +18,13 @@ import ( "fmt" "io" "math" - "time" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -35,13 +34,6 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -var ( - opensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.") - opensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.") - reads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.") - readWait = 
metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.") -) - // fileInodeOperations implements fs.InodeOperations for a regular tmpfs file. // These files are backed by pages allocated from a platform.Memory, and may be // directly mapped. @@ -157,9 +149,9 @@ func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldPare // GetFile implements fs.InodeOperations.GetFile. func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { if flags.Write { - opensW.Increment() + fsmetric.TmpfsOpensW.Increment() } else if flags.Read { - opensRO.Increment() + fsmetric.TmpfsOpensRO.Increment() } flags.Pread = true flags.Pwrite = true @@ -319,14 +311,12 @@ func (*fileInodeOperations) StatFS(context.Context) (fs.Info, error) { } func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { - var start time.Time - if fs.RecordWaitTime { - start = time.Now() - } - reads.Increment() + start := fsmetric.StartReadWait() + defer fsmetric.FinishReadWait(fsmetric.TmpfsReadWait, start) + fsmetric.TmpfsReads.Increment() + // Zero length reads for tmpfs are no-ops. if dst.NumBytes() == 0 { - fs.IncrementWait(readWait, start) return 0, nil } @@ -343,7 +333,6 @@ func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst userm size := f.attr.Size f.dataMu.RUnlock() if offset >= size { - fs.IncrementWait(readWait, start) return 0, io.EOF } @@ -354,7 +343,6 @@ func (f *fileInodeOperations) read(ctx context.Context, file *fs.File, dst userm f.attr.AccessTime = ktime.NowFromContext(ctx) f.attrMu.Unlock() } - fs.IncrementWait(readWait, start) return n, err } diff --git a/pkg/sentry/fsimpl/fuse/connection_control.go b/pkg/sentry/fsimpl/fuse/connection_control.go index 1b3459c1d..4ab894965 100644 --- a/pkg/sentry/fsimpl/fuse/connection_control.go +++ b/pkg/sentry/fsimpl/fuse/connection_control.go @@ -84,11 +84,7 @@ func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error { Flags: fuseDefaultInitFlags, } - req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in) - if err != nil { - return err - } - + req := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in) // Since there is no task to block on and FUSE_INIT is the request // to unblock other requests, use nil. return conn.CallAsync(nil, req) diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go index 91d16c1cf..d8b0d7657 100644 --- a/pkg/sentry/fsimpl/fuse/connection_test.go +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -76,10 +76,7 @@ func TestConnectionAbort(t *testing.T) { var futNormal []*futureResponse for i := 0; i < int(numRequests); i++ { - req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) - if err != nil { - t.Fatalf("NewRequest creation failed: %v", err) - } + req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) fut, err := conn.callFutureLocked(task, req) if err != nil { t.Fatalf("callFutureLocked failed: %v", err) @@ -105,10 +102,7 @@ func TestConnectionAbort(t *testing.T) { } // After abort, Call() should return directly with ENOTCONN. 
- req, err := conn.NewRequest(creds, 0, 0, 0, testObj) - if err != nil { - t.Fatalf("NewRequest creation failed: %v", err) - } + req := conn.NewRequest(creds, 0, 0, 0, testObj) _, err = conn.Call(task, req) if err != syserror.ENOTCONN { t.Fatalf("Incorrect error code received for Call() after connection aborted") diff --git a/pkg/sentry/fsimpl/fuse/dev.go b/pkg/sentry/fsimpl/fuse/dev.go index 89c3ef079..1bbe6fdb7 100644 --- a/pkg/sentry/fsimpl/fuse/dev.go +++ b/pkg/sentry/fsimpl/fuse/dev.go @@ -363,7 +363,7 @@ func (fd *DeviceFD) Readiness(mask waiter.EventMask) waiter.EventMask { func (fd *DeviceFD) readinessLocked(mask waiter.EventMask) waiter.EventMask { var ready waiter.EventMask - if fd.fs.umounted { + if fd.fs == nil || fd.fs.umounted { ready |= waiter.EventErr return ready & mask } diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 95c475a65..bb2d0d31a 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -219,10 +219,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con data: rand.Uint32(), } - req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) - if err != nil { - t.Fatalf("NewRequest creation failed: %v", err) - } + req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) // Queue up a request. // Analogous to Call except it doesn't block on the task. diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go index 8f220a04b..fcc5d9a2a 100644 --- a/pkg/sentry/fsimpl/fuse/directory.go +++ b/pkg/sentry/fsimpl/fuse/directory.go @@ -68,11 +68,7 @@ func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirent } // TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS. - req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in) - if err != nil { - return err - } - + req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in) res, err := fusefs.conn.Call(task, req) if err != nil { return err diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go index 83f2816b7..e138b11f8 100644 --- a/pkg/sentry/fsimpl/fuse/file.go +++ b/pkg/sentry/fsimpl/fuse/file.go @@ -83,12 +83,8 @@ func (fd *fileDescription) Release(ctx context.Context) { opcode = linux.FUSE_RELEASE } kernelTask := kernel.TaskFromContext(ctx) - // ignoring errors and FUSE server reply is analogous to Linux's behavior. - req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in) - if err != nil { - // No way to invoke Call() with an errored request. - return - } + // Ignoring errors and FUSE server reply is analogous to Linux's behavior. + req := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in) // The reply will be ignored since no callback is defined in asyncCallBack(). 
conn.CallAsync(kernelTask, req) } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 23e827f90..3af807a21 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -119,7 +119,8 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */) if err != nil { - return nil, nil, err + log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err) + return nil, nil, syserror.EINVAL } kernelTask := kernel.TaskFromContext(ctx) @@ -360,12 +361,8 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr in.Flags &= ^uint32(linux.O_TRUNC) } - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in) - if err != nil { - return nil, err - } - // Send the request and receive the reply. + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return nil, err @@ -485,10 +482,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err return syserror.EINVAL } in := linux.FUSEUnlinkIn{Name: name} - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) - if err != nil { - return err - } + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return err @@ -515,11 +509,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) in := linux.FUSERmDirIn{Name: name} - req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) - if err != nil { - return err - } - + req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) res, err := i.fs.conn.Call(task, req) if err != nil { return err @@ -535,10 +525,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) return nil, syserror.EINVAL } - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload) - if err != nil { - return nil, err - } + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return nil, err @@ -574,10 +561,7 @@ func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context") return "", syserror.EINVAL } - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{}) - if err != nil { - return "", err - } + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{}) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return "", err @@ -680,11 +664,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp 
GetAttrFlags: flags, Fh: fh, } - req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in) - if err != nil { - return linux.FUSEAttr{}, err - } - + req := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in) res, err := i.fs.conn.Call(task, req) if err != nil { return linux.FUSEAttr{}, err @@ -803,11 +783,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre UID: opts.Stat.UID, GID: opts.Stat.GID, } - req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in) - if err != nil { - return err - } - + req := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in) res, err := conn.Call(task, req) if err != nil { return err diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go index 2d396e84c..23ce91849 100644 --- a/pkg/sentry/fsimpl/fuse/read_write.go +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -79,13 +79,9 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui in.Offset = off + (uint64(pagesRead) << usermem.PageShift) in.Size = pagesCanRead << usermem.PageShift - req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in) - if err != nil { - return nil, 0, err - } - // TODO(gvisor.dev/issue/3247): support async read. + req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in) res, err := fs.conn.Call(t, req) if err != nil { return nil, 0, err @@ -204,11 +200,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, in.Offset = off + uint64(written) in.Size = toWrite - req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in) - if err != nil { - return 0, err - } - + req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in) req.payload = data[written : written+toWrite] // TODO(gvisor.dev/issue/3247): support async write. diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go index 7fa00569b..41d679358 100644 --- a/pkg/sentry/fsimpl/fuse/request_response.go +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -70,6 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) { out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2])) src = src[2:] } + _ = src // Remove unused warning. } // SizeBytes is the size of the payload of the FUSE_INIT response. @@ -104,7 +105,7 @@ type Request struct { } // NewRequest creates a new request that can be sent to the FUSE server. 
-func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) { +func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) *Request { conn.fd.mu.Lock() defer conn.fd.mu.Unlock() conn.fd.nextOpID += linux.FUSEOpID(reqIDStep) @@ -130,7 +131,7 @@ func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint id: hdr.Unique, hdr: &hdr, data: buf, - }, nil + } } // futureResponse represents an in-flight request, that may or may not have diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index 4c3e9acf8..807b6ed1f 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -59,6 +59,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", "//pkg/sentry/fsimpl/host", + "//pkg/sentry/fsmetric", "//pkg/sentry/hostfd", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 2294c490e..df27554d3 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/sentry/fsimpl/host" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" @@ -985,14 +986,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open switch d.fileType() { case linux.S_IFREG: if !d.fs.opts.regularFilesUseSpecialFileFD { - if err := d.ensureSharedHandle(ctx, ats&vfs.MayRead != 0, ats&vfs.MayWrite != 0, trunc); err != nil { + if err := d.ensureSharedHandle(ctx, ats.MayRead(), ats.MayWrite(), trunc); err != nil { return nil, err } - fd := ®ularFileFD{} - fd.LockFD.Init(&d.locks) - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ - AllowDirectIO: true, - }); err != nil { + fd, err := newRegularFileFD(mnt, d, opts.Flags) + if err != nil { return nil, err } vfd = &fd.vfsfd @@ -1019,6 +1017,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{}); err != nil { return nil, err } + if atomic.LoadInt32(&d.readFD) >= 0 { + fsmetric.GoferOpensHost.Increment() + } else { + fsmetric.GoferOpens9P.Increment() + } return &fd.vfsfd, nil case linux.S_IFLNK: // Can't open symlinks without O_PATH (which is unimplemented). @@ -1110,7 +1113,7 @@ retry: return nil, err } } - fd, err := newSpecialFileFD(h, mnt, d, &d.locks, opts.Flags) + fd, err := newSpecialFileFD(h, mnt, d, opts.Flags) if err != nil { h.close(ctx) return nil, err @@ -1205,11 +1208,8 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving // Finally, construct a file description representing the created file. 
var childVFSFD *vfs.FileDescription if useRegularFileFD { - fd := ®ularFileFD{} - fd.LockFD.Init(&child.locks) - if err := fd.vfsfd.Init(fd, opts.Flags, mnt, &child.vfsd, &vfs.FileDescriptionOptions{ - AllowDirectIO: true, - }); err != nil { + fd, err := newRegularFileFD(mnt, child, opts.Flags) + if err != nil { return nil, err } childVFSFD = &fd.vfsfd @@ -1221,7 +1221,7 @@ func (d *dentry) createAndOpenChildLocked(ctx context.Context, rp *vfs.Resolving if fdobj != nil { h.fd = int32(fdobj.Release()) } - fd, err := newSpecialFileFD(h, mnt, child, &d.locks, opts.Flags) + fd, err := newSpecialFileFD(h, mnt, child, opts.Flags) if err != nil { h.close(ctx) return nil, err diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index 75a836899..3cdb1e659 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -743,7 +743,9 @@ type dentry struct { // for memory mappings. If mmapFD is -1, no such FD is available, and the // internal page cache implementation is used for memory mappings instead. // - // These fields are protected by handleMu. + // These fields are protected by handleMu. readFD, writeFD, and mmapFD are + // additionally written using atomic memory operations, allowing them to be + // read (albeit racily) with atomic.LoadInt32() without locking handleMu. // // readFile and writeFile may or may not represent the same p9.File. Once // either p9.File transitions from closed (isNil() == true) to open @@ -1351,16 +1353,11 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { return } if refs > 0 { - if d.cached { - // This isn't strictly necessary (fs.cachedDentries is permitted to - // contain dentries with non-zero refs, which are skipped by - // fs.evictCachedDentryLocked() upon reaching the end of the LRU), - // but since we are already holding fs.renameMu for writing we may - // as well. - d.fs.cachedDentries.Remove(d) - d.fs.cachedDentriesLen-- - d.cached = false - } + // This isn't strictly necessary (fs.cachedDentries is permitted to + // contain dentries with non-zero refs, which are skipped by + // fs.evictCachedDentryLocked() upon reaching the end of the LRU), but + // since we are already holding fs.renameMu for writing we may as well. + d.removeFromCacheLocked() return } // Deleted and invalidated dentries with zero references are no longer @@ -1369,20 +1366,18 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { if d.isDeleted() { d.watches.HandleDeletion(ctx) } - if d.cached { - d.fs.cachedDentries.Remove(d) - d.fs.cachedDentriesLen-- - d.cached = false - } + d.removeFromCacheLocked() d.destroyLocked(ctx) return } - // If d still has inotify watches and it is not deleted or invalidated, we - // cannot cache it and allow it to be evicted. Otherwise, we will lose its - // watches, even if a new dentry is created for the same file in the future. - // Note that the size of d.watches cannot concurrently transition from zero - // to non-zero, because adding a watch requires holding a reference on d. + // If d still has inotify watches and it is not deleted or invalidated, it + // can't be evicted. Otherwise, we will lose its watches, even if a new + // dentry is created for the same file in the future. Note that the size of + // d.watches cannot concurrently transition from zero to non-zero, because + // adding a watch requires holding a reference on d. if d.watches.Size() > 0 { + // As in the refs > 0 case, this is not strictly necessary. 
+ d.removeFromCacheLocked() return } @@ -1413,6 +1408,15 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { } } +// Preconditions: d.fs.renameMu must be locked for writing. +func (d *dentry) removeFromCacheLocked() { + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } +} + // Precondition: fs.renameMu must be locked for writing; it may be temporarily // unlocked. func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { @@ -1426,12 +1430,10 @@ func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { // * fs.cachedDentriesLen != 0. func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { victim := fs.cachedDentries.Back() - fs.cachedDentries.Remove(victim) - fs.cachedDentriesLen-- - victim.cached = false - // victim.refs may have become non-zero from an earlier path resolution - // since it was inserted into fs.cachedDentries. - if atomic.LoadInt64(&victim.refs) == 0 { + victim.removeFromCacheLocked() + // victim.refs or victim.watches.Size() may have become non-zero from an + // earlier path resolution since it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 && victim.watches.Size() == 0 { if victim.parent != nil { victim.parent.dirMu.Lock() if !victim.vfsd.IsDead() { @@ -1668,7 +1670,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool } fdsToClose = append(fdsToClose, d.readFD) invalidateTranslations = true - d.readFD = h.fd + atomic.StoreInt32(&d.readFD, h.fd) } else { // Otherwise, we want to avoid invalidating existing // memmap.Translations (which is expensive); instead, use @@ -1689,15 +1691,15 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool h.fd = d.readFD } } else { - d.readFD = h.fd + atomic.StoreInt32(&d.readFD, h.fd) } if d.writeFD != h.fd && d.writeFD >= 0 { fdsToClose = append(fdsToClose, d.writeFD) } - d.writeFD = h.fd - d.mmapFD = h.fd + atomic.StoreInt32(&d.writeFD, h.fd) + atomic.StoreInt32(&d.mmapFD, h.fd) } else if openReadable && d.readFD < 0 { - d.readFD = h.fd + atomic.StoreInt32(&d.readFD, h.fd) // If the file has not been opened for writing, the new FD may // be used for read-only memory mappings. If the file was // previously opened for reading (without an FD), then existing @@ -1705,10 +1707,10 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // invalidate those mappings. if d.writeFile.isNil() { invalidateTranslations = !d.readFile.isNil() - d.mmapFD = h.fd + atomic.StoreInt32(&d.mmapFD, h.fd) } } else if openWritable && d.writeFD < 0 { - d.writeFD = h.fd + atomic.StoreInt32(&d.writeFD, h.fd) if d.readFD >= 0 { // We have an existing read-only FD, but the file has just // been opened for writing, so we need to start supporting @@ -1717,7 +1719,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // writable memory mappings. Switch to using the internal // page cache. invalidateTranslations = true - d.mmapFD = -1 + atomic.StoreInt32(&d.mmapFD, -1) } } else { // The new FD is not useful. @@ -1729,7 +1731,7 @@ func (d *dentry) ensureSharedHandle(ctx context.Context, read, write, trunc bool // memory mappings. However, we have no writable host FD. Switch to // using the internal page cache. invalidateTranslations = true - d.mmapFD = -1 + atomic.StoreInt32(&d.mmapFD, -1) } // Switch to new fids. 
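The dentry comment above notes that readFD, writeFD, and mmapFD remain protected by handleMu but are now also written with atomic stores, so fast paths can read them racily with atomic.LoadInt32 and no locking. A minimal sketch of that "mutex for writers, atomic loads for lock-free snapshot reads" pattern, with illustrative names:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// node keeps fd consistent with related state under mu, but also publishes
// every update with an atomic store so readers that only need a snapshot of
// fd can skip the lock entirely.
type node struct {
	mu sync.Mutex // protects fd together with other (omitted) handle state
	fd int32      // -1 when no host FD is available
}

// setFD is called by writers that also update related state under mu.
func (n *node) setFD(fd int32) {
	n.mu.Lock()
	defer n.mu.Unlock()
	atomic.StoreInt32(&n.fd, fd)
}

// hasFD is a lock-free, possibly stale check used on fast paths (e.g. to
// pick which metric to increment); callers that need fd to stay valid must
// still take mu.
func (n *node) hasFD() bool {
	return atomic.LoadInt32(&n.fd) >= 0
}

func main() {
	n := &node{fd: -1}
	fmt.Println(n.hasFD()) // false
	n.setFD(3)
	fmt.Println(n.hasFD()) // true
}
```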
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index 652142ecc..283b220bb 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -26,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" @@ -48,6 +49,25 @@ type regularFileFD struct { off int64 } +func newRegularFileFD(mnt *vfs.Mount, d *dentry, flags uint32) (*regularFileFD, error) { + fd := ®ularFileFD{} + fd.LockFD.Init(&d.locks) + if err := fd.vfsfd.Init(fd, flags, mnt, &d.vfsd, &vfs.FileDescriptionOptions{ + AllowDirectIO: true, + }); err != nil { + return nil, err + } + if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { + fsmetric.GoferOpensWX.Increment() + } + if atomic.LoadInt32(&d.mmapFD) >= 0 { + fsmetric.GoferOpensHost.Increment() + } else { + fsmetric.GoferOpens9P.Increment() + } + return fd, nil +} + // Release implements vfs.FileDescriptionImpl.Release. func (fd *regularFileFD) Release(context.Context) { } @@ -89,6 +109,18 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + start := fsmetric.StartReadWait() + d := fd.dentry() + defer func() { + if atomic.LoadInt32(&d.readFD) >= 0 { + fsmetric.GoferReadsHost.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start) + } else { + fsmetric.GoferReads9P.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start) + } + }() + if offset < 0 { return 0, syserror.EINVAL } @@ -102,7 +134,6 @@ func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs // Check for reading at EOF before calling into MM (but not under // InteropModeShared, which makes d.size unreliable). - d := fd.dentry() if d.cachedMetadataAuthoritative() && uint64(offset) >= atomic.LoadUint64(&d.size) { return 0, io.EOF } @@ -647,10 +678,7 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt // Whether or not we have a host FD, we're not allowed to use it. return syserror.ENODEV } - d.handleMu.RLock() - haveFD := d.mmapFD >= 0 - d.handleMu.RUnlock() - if !haveFD { + if atomic.LoadInt32(&d.mmapFD) < 0 { return syserror.ENODEV } default: @@ -668,10 +696,7 @@ func (d *dentry) mayCachePages() bool { if d.fs.opts.forcePageCache { return true } - d.handleMu.RLock() - haveFD := d.mmapFD >= 0 - d.handleMu.RUnlock() - return haveFD + return atomic.LoadInt32(&d.mmapFD) >= 0 } // AddMapping implements memmap.Mappable.AddMapping. 
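regularFileFD.PRead above times the read with fsmetric.StartReadWait and a deferred fsmetric.FinishReadWait, picking the host or 9P wait metric at return time. A self-contained sketch of how such helpers can be shaped so the two clock reads disappear when recording is disabled; the fsmetric package itself is not part of this diff, so the gate variable and the plain atomic counter here are assumptions:

```go
package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// recordWaitTime gates the two clock reads per call; when false,
// StartReadWait returns the zero Time and FinishReadWait is a no-op.
var recordWaitTime = true

// readWaitNanos is a stand-in for a nanoseconds metric.
var readWaitNanos int64

func StartReadWait() time.Time {
	if !recordWaitTime {
		return time.Time{}
	}
	return time.Now()
}

func FinishReadWait(m *int64, start time.Time) {
	if !recordWaitTime {
		return
	}
	atomic.AddInt64(m, time.Since(start).Nanoseconds())
}

func read(buf []byte) int {
	start := StartReadWait()
	defer FinishReadWait(&readWaitNanos, start)
	time.Sleep(time.Millisecond) // pretend to wait on I/O
	return len(buf)
}

func main() {
	read(make([]byte, 4096))
	fmt.Println(atomic.LoadInt64(&readWaitNanos) > 0) // true
}
```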
diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 625400c0b..089955a96 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -23,6 +23,7 @@ import ( "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" @@ -70,7 +71,7 @@ type specialFileFD struct { buf []byte } -func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { +func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, flags uint32) (*specialFileFD, error) { ftype := d.fileType() seekable := ftype == linux.S_IFREG || ftype == linux.S_IFCHR || ftype == linux.S_IFBLK haveQueue := (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && h.fd >= 0 @@ -80,7 +81,7 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, seekable: seekable, haveQueue: haveQueue, } - fd.LockFD.Init(locks) + fd.LockFD.Init(&d.locks) if haveQueue { if err := fdnotifier.AddFD(h.fd, &fd.queue); err != nil { return nil, err @@ -98,6 +99,14 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, d.fs.syncMu.Lock() d.fs.specialFileFDs[fd] = struct{}{} d.fs.syncMu.Unlock() + if fd.vfsfd.IsWritable() && (atomic.LoadUint32(&d.mode)&0111 != 0) { + fsmetric.GoferOpensWX.Increment() + } + if h.fd >= 0 { + fsmetric.GoferOpensHost.Increment() + } else { + fsmetric.GoferOpens9P.Increment() + } return fd, nil } @@ -161,6 +170,17 @@ func (fd *specialFileFD) Allocate(ctx context.Context, mode, offset, length uint // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + start := fsmetric.StartReadWait() + defer func() { + if fd.handle.fd >= 0 { + fsmetric.GoferReadsHost.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWaitHost, start) + } else { + fsmetric.GoferReads9P.Increment() + fsmetric.FinishReadWait(fsmetric.GoferReadWait9P, start) + } + }() + if fd.seekable && offset < 0 { return 0, syserror.EINVAL } diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index c14abcff4..565d723f0 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -286,7 +286,7 @@ func (d *Dentry) cacheLocked(ctx context.Context) { refs := atomic.LoadInt64(&d.refs) if refs == -1 { // Dentry has already been destroyed. 
- panic(fmt.Sprintf("cacheLocked called on a dentry which has already been destroyed: %v", d)) + return } if refs > 0 { if d.cached { diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 469f3a33d..27b00cf6f 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -16,7 +16,6 @@ package overlay import ( "fmt" - "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -129,25 +128,9 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { return err } defer newFD.DecRef(ctx) - bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size - for { - readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{}) - if readErr != nil && readErr != io.EOF { - cleanupUndoCopyUp() - return readErr - } - total := int64(0) - for total < readN { - writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{}) - total += writeN - if writeErr != nil { - cleanupUndoCopyUp() - return writeErr - } - } - if readErr == io.EOF { - break - } + if _, err := vfs.CopyRegularFileData(ctx, newFD, oldFD); err != nil { + cleanupUndoCopyUp() + return err } d.mapsMu.Lock() defer d.mapsMu.Unlock() diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go index 2b89a7a6d..25c785fd4 100644 --- a/pkg/sentry/fsimpl/overlay/regular_file.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -103,8 +103,8 @@ func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescript for e, mask := range fd.lowerWaiters { fd.cachedFD.EventUnregister(e) upperFD.EventRegister(e, mask) - if ready&mask != 0 { - e.Callback.Callback(e) + if m := ready & mask; m != 0 { + e.Callback.Callback(e, m) } } } diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 5cf8a071a..d4f6a5a9b 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -208,7 +208,7 @@ func (n *netUnixData) Generate(ctx context.Context, buf *bytes.Buffer) error { for _, se := range n.kernel.ListSockets() { s := se.SockVFS2 if !s.TryIncRef() { - log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s) + // Racing with socket destruction, this is ok. continue } if family, _, _ := s.Impl().(socket.SocketVFS2).Type(); family != linux.AF_UNIX { @@ -351,7 +351,7 @@ func commonGenerateTCP(ctx context.Context, buf *bytes.Buffer, k *kernel.Kernel, for _, se := range k.ListSockets() { s := se.SockVFS2 if !s.TryIncRef() { - log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s) + // Racing with socket destruction, this is ok. continue } sops, ok := s.Impl().(socket.SocketVFS2) @@ -516,7 +516,7 @@ func (d *netUDPData) Generate(ctx context.Context, buf *bytes.Buffer) error { for _, se := range d.kernel.ListSockets() { s := se.SockVFS2 if !s.TryIncRef() { - log.Debugf("Couldn't get reference on %v in socket table, racing with destruction?", s) + // Racing with socket destruction, this is ok. 
continue } sops, ok := s.Impl().(socket.SocketVFS2) diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 7c7afdcfa..25c407d98 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -44,6 +44,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k * return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}), + "sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))), "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)), "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)), "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)), diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index 10f1452ef..246bd87bc 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package signalfd provides basic signalfd file implementations. package signalfd import ( @@ -98,8 +99,8 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, - PID: uint32(info.Pid()), - UID: uint32(info.Uid()), + PID: uint32(info.PID()), + UID: uint32(info.UID()), Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index fe520b6fd..09957c2b7 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -67,6 +67,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", + "//pkg/sentry/fsmetric", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/time", diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index e39cd305b..9296db2fb 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -381,6 +382,8 @@ afterTrailingSymlink: creds := rp.Credentials() child := fs.newDentry(fs.newRegularFile(creds.EffectiveKUID, creds.EffectiveKGID, opts.Mode)) parentDir.insertChildLocked(child, name) + child.IncRef() + defer child.DecRef(ctx) unlock() fd, err := child.open(ctx, rp, &opts, true) if err != nil { @@ -437,6 +440,11 @@ func (d *dentry) open(ctx context.Context, rp *vfs.ResolvingPath, opts *vfs.Open return nil, err } } + if fd.vfsfd.IsWritable() { + fsmetric.TmpfsOpensW.Increment() + } else if fd.vfsfd.IsReadable() { + fsmetric.TmpfsOpensRO.Increment() + } return &fd.vfsfd, nil case *directory: // Can't open directories writably. 
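Both proc implementations above (fs/proc/sys.go and fsimpl/proc/tasks_sys.go) add a static /proc/sys/kernel/sem file built from the same format string. The four tab-separated fields are the SysV semaphore limits SEMMSL, SEMMNS, SEMOPM, and SEMMNI. A small sketch of the formatting, using Linux's customary default values as placeholders; the actual constants come from gVisor's abi/linux package and are not shown in this diff:

```go
package main

import "fmt"

// Placeholder values matching Linux's usual defaults; gVisor takes them from
// the abi/linux package (SEMMSL, SEMMNS, SEMOPM, SEMMNI).
const (
	semmsl = 32000      // max semaphores per semaphore set
	semmns = 1024000000 // max semaphores system-wide
	semopm = 500        // max operations per semop call
	semmni = 32000      // max number of semaphore sets
)

func main() {
	// Same format string used by newKernelDir and newSysDir above.
	fmt.Printf("%d\t%d\t%d\t%d\n", semmsl, semmns, semopm, semmni)
}
```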
diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index f8e0cffb0..6255a7c84 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sentry/pgalloc" @@ -359,6 +360,10 @@ func (fd *regularFileFD) Allocate(ctx context.Context, mode, offset, length uint // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *regularFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + start := fsmetric.StartReadWait() + defer fsmetric.FinishReadWait(fsmetric.TmpfsReadWait, start) + fsmetric.TmpfsReads.Increment() + if offset < 0 { return 0, syserror.EINVAL } diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index add5dd48e..a4ad625bb 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -107,8 +107,10 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de // Dentries which may have a reference count of zero, and which therefore // should be dropped once traversal is complete, are appended to ds. // -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. -// !rp.Done(). +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. +// * !rp.Done(). func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { if !d.isDir() { return nil, syserror.ENOTDIR @@ -158,15 +160,19 @@ afterSymlink: return child, nil } -// verifyChild verifies the hash of child against the already verified hash of -// the parent to ensure the child is expected. verifyChild triggers a sentry -// panic if unexpected modifications to the file system are detected. In -// noCrashOnVerificationFailure mode it returns a syserror instead. -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// verifyChildLocked verifies the hash of child against the already verified +// hash of the parent to ensure the child is expected. verifyChild triggers a +// sentry panic if unexpected modifications to the file system are detected. In +// ErrorOnViolation mode it returns a syserror instead. +// +// Preconditions: +// * fs.renameMu must be locked. +// * d.dirMu must be locked. +// // TODO(b/166474175): Investigate all possible errors returned in this // function, and make sure we differentiate all errors that indicate unexpected // modifications to the file system from the ones that are not harmful. -func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { +func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, child *dentry) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() // Get the path to the child dentry. 
This is only used to provide path @@ -248,7 +254,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } - fdReader := vfs.FileReadWriteSeeker{ + fdReader := FileReadWriteSeeker{ FD: parentMerkleFD, Ctx: ctx, } @@ -268,7 +274,8 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contain the hash of the children in the parent Merkle tree when // Verify returns with success. var buf bytes.Buffer - if _, err := merkletree.Verify(&merkletree.VerifyParams{ + parent.hashMu.RLock() + _, err = merkletree.Verify(&merkletree.VerifyParams{ Out: &buf, File: &fdReader, Tree: &fdReader, @@ -284,21 +291,27 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), Expected: parent.hash, DataAndTreeInSameFile: true, - }); err != nil && err != io.EOF { + }) + parent.hashMu.RUnlock() + if err != nil && err != io.EOF { return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child hash when it's verified the first time. + child.hashMu.Lock() if len(child.hash) == 0 { child.hash = buf.Bytes() } + child.hashMu.Unlock() return child, nil } -// verifyStatAndChildren verifies the stat and children names against the +// verifyStatAndChildrenLocked verifies the stat and children names against the // verified hash. The mode/uid/gid and childrenNames of the file is cached // after verified. -func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat linux.Statx) error { +// +// Preconditions: d.dirMu must be locked. +func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry, stat linux.Statx) error { vfsObj := fs.vfsfs.VirtualFilesystem() // Get the path to the child dentry. This is only used to provide path @@ -384,12 +397,13 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat } } - fdReader := vfs.FileReadWriteSeeker{ + fdReader := FileReadWriteSeeker{ FD: fd, Ctx: ctx, } var buf bytes.Buffer + d.hashMu.RLock() params := &merkletree.VerifyParams{ Out: &buf, Tree: &fdReader, @@ -407,6 +421,7 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat Expected: d.hash, DataAndTreeInSameFile: false, } + d.hashMu.RUnlock() if atomic.LoadUint32(&d.mode)&linux.S_IFMT == linux.S_IFDIR { params.DataAndTreeInSameFile = true } @@ -421,7 +436,9 @@ func (fs *filesystem) verifyStatAndChildren(ctx context.Context, d *dentry, stat return nil } -// Preconditions: fs.renameMu must be locked. d.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if child, ok := parent.children[name]; ok { // If verity is enabled on child, we should check again whether @@ -470,7 +487,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s // be cached before enabled. 
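In the verifyChildLocked hunk above, parent.hash is read under hashMu.RLock while merkletree.Verify runs, and child.hash is cached under the write lock the first time verification succeeds. Below is a small self-contained sketch of that read-mostly caching pattern; the node type and cachedHash helper are illustrative stand-ins, not the actual verity dentry.

package main

import "sync"

// node stands in for a verity dentry: hash is written once, after the first
// successful verification, and read many times afterwards, so it is guarded
// by an RWMutex like dentry.hashMu above.
type node struct {
	hashMu sync.RWMutex
	hash   []byte
}

// cachedHash returns the cached hash, computing and storing it on first use.
func (n *node) cachedHash(compute func() []byte) []byte {
	n.hashMu.RLock()
	h := n.hash
	n.hashMu.RUnlock()
	if len(h) != 0 {
		return h
	}
	n.hashMu.Lock()
	defer n.hashMu.Unlock()
	// Re-check after taking the write lock; another goroutine may have
	// filled the cache in the meantime.
	if len(n.hash) == 0 {
		n.hash = compute()
	}
	return n.hash
}

func main() {
	n := &node{}
	_ = n.cachedHash(func() []byte { return []byte{0xab} })
}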
if fs.allowRuntimeEnable { if parent.verityEnabled() { - if _, err := fs.verifyChild(ctx, parent, child); err != nil { + if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil { return nil, err } } @@ -486,7 +503,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s if err != nil { return nil, err } - if err := fs.verifyStatAndChildren(ctx, child, stat); err != nil { + if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil { return nil, err } } @@ -506,7 +523,9 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s return child, nil } -// Preconditions: fs.renameMu must be locked. parent.dirMu must be locked. +// Preconditions: +// * fs.renameMu must be locked. +// * parent.dirMu must be locked. func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() @@ -597,13 +616,13 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // allowRuntimeEnable mode and the parent directory hasn't been enabled // yet. if parent.verityEnabled() { - if _, err := fs.verifyChild(ctx, parent, child); err != nil { + if _, err := fs.verifyChildLocked(ctx, parent, child); err != nil { child.destroyLocked(ctx) return nil, err } } if child.verityEnabled() { - if err := fs.verifyStatAndChildren(ctx, child, stat); err != nil { + if err := fs.verifyStatAndChildrenLocked(ctx, child, stat); err != nil { child.destroyLocked(ctx) return nil, err } @@ -617,7 +636,9 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // rp.Start().Impl().(*dentry)). It does not check that the returned directory // is searchable by the provider of rp. // -// Preconditions: fs.renameMu must be locked. !rp.Done(). +// Preconditions: +// * fs.renameMu must be locked. +// * !rp.Done(). func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() @@ -958,11 +979,13 @@ func (fs *filesystem) StatAt(ctx context.Context, rp *vfs.ResolvingPath, opts vf if err != nil { return linux.Statx{}, err } + d.dirMu.Lock() if d.verityEnabled() { - if err := fs.verifyStatAndChildren(ctx, d, stat); err != nil { + if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil { return linux.Statx{}, err } } + d.dirMu.Unlock() return stat, nil } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 87dabe038..66029c64d 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -19,6 +19,18 @@ // The verity file system is read-only, except for one case: when // allowRuntimeEnable is true, additional Merkle files can be generated using // the FS_IOC_ENABLE_VERITY ioctl. +// +// Lock order: +// +// filesystem.renameMu +// dentry.dirMu +// fileDescription.mu +// filesystem.verityMu +// dentry.hashMu +// +// Locking dentry.dirMu in multiple dentries requires that parent dentries are +// locked before child dentries, and that filesystem.renameMu is locked to +// stabilize this relationship. package verity import ( @@ -52,6 +64,10 @@ const ( // tree file for "/foo" is "/.merkle.verity.foo". merklePrefix = ".merkle.verity." + // merkleRootPrefix is the prefix of the Merkle tree root file. This + // needs to be different from merklePrefix to avoid name collision. + merkleRootPrefix = ".merkleroot.verity." 
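The new package comment above documents the verity lock order (filesystem.renameMu, then dentry.dirMu, then fileDescription.mu, then filesystem.verityMu, then dentry.hashMu). The sketch below only illustrates taking a subset of those locks in that order; the filesystem and dentry types here are placeholders, not the real ones.

package main

import "sync"

// Placeholder types; the real lock order also includes fileDescription.mu and
// filesystem.verityMu between dirMu and hashMu.
type filesystem struct {
	renameMu sync.RWMutex
}

type dentry struct {
	dirMu  sync.Mutex
	hashMu sync.RWMutex
	hash   []byte
}

// readHash takes the locks strictly in the documented order
// (renameMu -> dirMu -> hashMu) so that two concurrent traversals can never
// deadlock against each other.
func readHash(fs *filesystem, parent, child *dentry) []byte {
	fs.renameMu.RLock()
	defer fs.renameMu.RUnlock()
	parent.dirMu.Lock()
	defer parent.dirMu.Unlock()
	child.hashMu.RLock()
	defer child.hashMu.RUnlock()
	return child.hash
}

func main() {
	fs := &filesystem{}
	parent, child := &dentry{}, &dentry{hash: []byte{1}}
	_ = readHash(fs, parent, child)
}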
+ // merkleOffsetInParentXattr is the extended attribute name specifying the // offset of the child hash in its parent's Merkle tree. merkleOffsetInParentXattr = "user.merkle.offset" @@ -76,13 +92,8 @@ const ( ) var ( - // noCrashOnVerificationFailure indicates whether the sandbox should panic - // whenever verification fails. If true, an error is returned instead of - // panicking. This should only be set for tests. - // - // TODO(b/165661693): Decide whether to panic or return error based on this - // flag. - noCrashOnVerificationFailure bool + // action specifies the action towards detected violation. + action ViolationAction // verityMu synchronizes concurrent operations that enable verity and perform // verification checks. @@ -93,6 +104,18 @@ var ( // content. type HashAlgorithm int +// ViolationAction is a type specifying the action when an integrity violation +// is detected. +type ViolationAction int + +const ( + // PanicOnViolation terminates the sentry on detected violation. + PanicOnViolation ViolationAction = 0 + // ErrorOnViolation returns an error from the violating system call on + // detected violation. + ErrorOnViolation = 1 +) + // Currently supported hashing algorithms include SHA256 and SHA512. const ( SHA256 HashAlgorithm = iota @@ -187,10 +210,8 @@ type InternalFilesystemOptions struct { // system wrapped by verity file system. LowerGetFSOptions vfs.GetFilesystemOptions - // NoCrashOnVerificationFailure indicates whether the sandbox should - // panic whenever verification fails. If true, an error is returned - // instead of panicking. This should only be set for tests. - NoCrashOnVerificationFailure bool + // Action specifies the action on an integrity violation. + Action ViolationAction } // Name implements vfs.FilesystemType.Name. @@ -202,10 +223,10 @@ func (FilesystemType) Name() string { func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means -// unexpected modification to the file system is detected. In -// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic. +// unexpected modification to the file system is detected. In ErrorOnViolation +// mode, it returns EIO, otherwise it panic. func alertIntegrityViolation(msg string) error { - if noCrashOnVerificationFailure { + if action == ErrorOnViolation { return syserror.EIO } panic(msg) @@ -218,7 +239,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") return nil, nil, syserror.EINVAL } - noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure + action = iopts.Action // Mount the lower file system. The lower file system is wrapped inside // verity, and should not be exposed or connected. 
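With noCrashOnVerificationFailure replaced by the ViolationAction enum, callers choose the failure behavior when building the internal mount options. A hedged sketch of that wiring, using only the option fields visible in this diff (all other mount plumbing omitted):

// Only fields that appear in this diff are shown; everything else about
// setting up the verity mount is omitted.
iopts := verity.InternalFilesystemOptions{
	RootMerkleFileName: "root.verity",
	LowerName:          "tmpfs",
	Alg:                verity.SHA256,
	AllowRuntimeEnable: true,
	// ErrorOnViolation surfaces a detected violation as EIO from the
	// violating syscall; PanicOnViolation (the zero value) kills the sentry.
	Action: verity.ErrorOnViolation,
}
_ = iopts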
@@ -246,7 +267,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt lowerVD.IncRef() d.lowerVD = lowerVD - rootMerkleName := merklePrefix + iopts.RootMerkleFileName + rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, @@ -372,12 +393,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, alertIntegrityViolation(fmt.Sprintf("Failed to deserialize childrenNames: %v", err)) } - if err := fs.verifyStatAndChildren(ctx, d, stat); err != nil { + if err := fs.verifyStatAndChildrenLocked(ctx, d, stat); err != nil { return nil, nil, err } } + d.hashMu.Lock() copy(d.hash, iopts.RootHash) + d.hashMu.Unlock() d.vfsd.Init(d) fs.rootDentry = d @@ -402,7 +425,8 @@ type dentry struct { fs *filesystem // mode, uid, gid and size are the file mode, owner, group, and size of - // the file in the underlying file system. + // the file in the underlying file system. They are set when a dentry + // is initialized, and never modified. mode uint32 uid uint32 gid uint32 @@ -425,18 +449,22 @@ type dentry struct { // childrenNames stores the name of all children of the dentry. This is // used by verity to check whether a child is expected. This is only - // populated by enableVerity. + // populated by enableVerity. childrenNames is also protected by dirMu. childrenNames map[string]struct{} - // lowerVD is the VirtualDentry in the underlying file system. + // lowerVD is the VirtualDentry in the underlying file system. It is + // never modified after initialized. lowerVD vfs.VirtualDentry // lowerMerkleVD is the VirtualDentry of the corresponding Merkle tree - // in the underlying file system. + // in the underlying file system. It is never modified after + // initialized. lowerMerkleVD vfs.VirtualDentry - // hash is the calculated hash for the current file or directory. - hash []byte + // hash is the calculated hash for the current file or directory. hash + // is protected by hashMu. + hashMu sync.RWMutex `state:"nosave"` + hash []byte } // newDentry creates a new dentry representing the given verity file. The @@ -519,7 +547,9 @@ func (d *dentry) checkDropLocked(ctx context.Context) { // destroyLocked destroys the dentry. // -// Preconditions: d.fs.renameMu must be locked for writing. d.refs == 0. +// Preconditions: +// * d.fs.renameMu must be locked for writing. +// * d.refs == 0. func (d *dentry) destroyLocked(ctx context.Context) { switch atomic.LoadInt64(&d.refs) { case 0: @@ -599,6 +629,8 @@ func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) // mode, it returns true if the target has been enabled with // ioctl(FS_IOC_ENABLE_VERITY). func (d *dentry) verityEnabled() bool { + d.hashMu.RLock() + defer d.hashMu.RUnlock() return !d.fs.allowRuntimeEnable || len(d.hash) != 0 } @@ -678,11 +710,13 @@ func (fd *fileDescription) Stat(ctx context.Context, opts vfs.StatOptions) (linu if err != nil { return linux.Statx{}, err } + fd.d.dirMu.Lock() if fd.d.verityEnabled() { - if err := fd.d.fs.verifyStatAndChildren(ctx, fd.d, stat); err != nil { + if err := fd.d.fs.verifyStatAndChildrenLocked(ctx, fd.d, stat); err != nil { return linux.Statx{}, err } } + fd.d.dirMu.Unlock() return stat, nil } @@ -718,22 +752,24 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) return offset, nil } -// generateMerkle generates a Merkle tree file for fd. 
If fd points to a file -// /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash -// of the generated Merkle tree and the data size is returned. If fd points to -// a regular file, the data is the content of the file. If fd points to a -// directory, the data is all hahes of its children, written to the Merkle tree -// file. -func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, error) { - fdReader := vfs.FileReadWriteSeeker{ +// generateMerkleLocked generates a Merkle tree file for fd. If fd points to a +// file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The +// hash of the generated Merkle tree and the data size is returned. If fd +// points to a regular file, the data is the content of the file. If fd points +// to a directory, the data is all hashes of its children, written to the Merkle +// tree file. +// +// Preconditions: fd.d.fs.verityMu must be locked. +func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, uint64, error) { + fdReader := FileReadWriteSeeker{ FD: fd.lowerFD, Ctx: ctx, } - merkleReader := vfs.FileReadWriteSeeker{ + merkleReader := FileReadWriteSeeker{ FD: fd.merkleReader, Ctx: ctx, } - merkleWriter := vfs.FileReadWriteSeeker{ + merkleWriter := FileReadWriteSeeker{ FD: fd.merkleWriter, Ctx: ctx, } @@ -793,11 +829,14 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, return hash, uint64(params.Size), err } -// recordChildren writes the names of fd's children into the corresponding -// Merkle tree file, and saves the offset/size of the map into xattrs. +// recordChildrenLocked writes the names of fd's children into the +// corresponding Merkle tree file, and saves the offset/size of the map into +// xattrs. // -// Preconditions: fd.d.isDir() == true -func (fd *fileDescription) recordChildren(ctx context.Context) error { +// Preconditions: +// * fd.d.fs.verityMu must be locked. +// * fd.d.isDir() == true. +func (fd *fileDescription) recordChildrenLocked(ctx context.Context) error { // Record the children names in the Merkle tree file. childrenNames, err := json.Marshal(fd.d.childrenNames) if err != nil { @@ -847,7 +886,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } - hash, dataSize, err := fd.generateMerkle(ctx) + hash, dataSize, err := fd.generateMerkleLocked(ctx) if err != nil { return 0, err } @@ -888,11 +927,13 @@ func (fd *fileDescription) enableVerity(ctx context.Context) (uintptr, error) { } if fd.d.isDir() { - if err := fd.recordChildren(ctx); err != nil { + if err := fd.recordChildrenLocked(ctx); err != nil { return 0, err } } - fd.d.hash = append(fd.d.hash, hash...) + fd.d.hashMu.Lock() + fd.d.hash = hash + fd.d.hashMu.Unlock() return 0, nil } @@ -904,6 +945,9 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm } var metadata linux.DigestMetadata + fd.d.hashMu.RLock() + defer fd.d.hashMu.RUnlock() + // If allowRuntimeEnable is true, an empty fd.d.hash indicates that // verity is not enabled for the file. 
If allowRuntimeEnable is false, // this is an integrity violation because all files should have verity @@ -940,11 +984,13 @@ func (fd *fileDescription) measureVerity(ctx context.Context, verityDigest userm func (fd *fileDescription) verityFlags(ctx context.Context, flags usermem.Addr) (uintptr, error) { f := int32(0) + fd.d.hashMu.RLock() // All enabled files should store a hash. This flag is not settable via // FS_IOC_SETFLAGS. if len(fd.d.hash) != 0 { f |= linux.FS_VERITY_FL } + fd.d.hashMu.RUnlock() t := kernel.TaskFromContext(ctx) if t == nil { @@ -1013,16 +1059,17 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } - dataReader := vfs.FileReadWriteSeeker{ + dataReader := FileReadWriteSeeker{ FD: fd.lowerFD, Ctx: ctx, } - merkleReader := vfs.FileReadWriteSeeker{ + merkleReader := FileReadWriteSeeker{ FD: fd.merkleReader, Ctx: ctx, } + fd.d.hashMu.RLock() n, err := merkletree.Verify(&merkletree.VerifyParams{ Out: dst.Writer(ctx), File: &dataReader, @@ -1040,6 +1087,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of Expected: fd.d.hash, DataAndTreeInSameFile: false, }) + fd.d.hashMu.RUnlock() if err != nil { return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } @@ -1065,3 +1113,45 @@ func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence) } + +// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as +// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. +type FileReadWriteSeeker struct { + FD *vfs.FileDescription + Ctx context.Context + ROpts vfs.ReadOptions + WOpts vfs.WriteOptions +} + +// ReadAt implements io.ReaderAt.ReadAt. +func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts) + return int(n), err +} + +// Read implements io.ReadWriteSeeker.Read. +func (f *FileReadWriteSeeker) Read(p []byte) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.Read(f.Ctx, dst, f.ROpts) + return int(n), err +} + +// Seek implements io.ReadWriteSeeker.Seek. +func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) { + return f.FD.Seek(f.Ctx, offset, int32(whence)) +} + +// WriteAt implements io.WriterAt.WriteAt. +func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts) + return int(n), err +} + +// Write implements io.ReadWriteSeeker.Write. +func (f *FileReadWriteSeeker) Write(p []byte) (int, error) { + buf := usermem.BytesIOSequence(p) + n, err := f.FD.Write(f.Ctx, buf, f.WOpts) + return int(n), err +} diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 7196e74eb..30d8b4355 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -35,16 +35,39 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// rootMerkleFilename is the name of the root Merkle tree file. -const rootMerkleFilename = "root.verity" +const ( + // rootMerkleFilename is the name of the root Merkle tree file. 
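FileReadWriteSeeker, defined above, adapts a vfs.FileDescription to the standard io interfaces so it can be passed to merkletree and, in principle, to any generic io helper. A usage sketch, assuming ctx is an existing context and srcFD/dstFD are two already-open file descriptions:

// ctx, srcFD and dstFD (*vfs.FileDescription) are assumed to already exist.
src := &verity.FileReadWriteSeeker{FD: srcFD, Ctx: ctx}
dst := &verity.FileReadWriteSeeker{FD: dstFD, Ctx: ctx}

// The adapter implements io.Reader and io.Writer, so plain io.Copy works;
// ReadAt/WriteAt similarly make it usable wherever io.ReaderAt/io.WriterAt
// are expected (which is how merkletree consumes it).
if _, err := io.Copy(dst, src); err != nil {
	// Handle or propagate the error.
}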
+ rootMerkleFilename = "root.verity" + // maxDataSize is the maximum data size of a test file. + maxDataSize = 100000 +) + +var hashAlgs = []HashAlgorithm{SHA256, SHA512} -// maxDataSize is the maximum data size written to the file for test. -const maxDataSize = 100000 +func dentryFromVD(t *testing.T, vd vfs.VirtualDentry) *dentry { + t.Helper() + d, ok := vd.Dentry().Impl().(*dentry) + if !ok { + t.Fatalf("can't assert %T as a *dentry", vd) + } + return d +} + +// dentryFromFD returns the dentry corresponding to fd. +func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry { + t.Helper() + f, ok := fd.Impl().(*fileDescription) + if !ok { + t.Fatalf("can't assert %T as a *fileDescription", fd) + } + return f.d +} // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { + t.Helper() k, err := testutil.Boot() if err != nil { t.Fatalf("testutil.Boot: %v", err) @@ -69,11 +92,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ InternalData: InternalFilesystemOptions{ - RootMerkleFileName: rootMerkleFilename, - LowerName: "tmpfs", - Alg: hashAlg, - AllowRuntimeEnable: true, - NoCrashOnVerificationFailure: true, + RootMerkleFileName: rootMerkleFilename, + LowerName: "tmpfs", + Alg: hashAlg, + AllowRuntimeEnable: true, + Action: ErrorOnViolation, }, }, }) @@ -92,7 +115,6 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, t.Fatalf("testutil.CreateTask: %v", err) } - t.Helper() t.Cleanup(func() { root.DecRef(ctx) mntns.DecRef(ctx) @@ -100,21 +122,97 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, return vfsObj, root, task, nil } -// newFileFD creates a new file in the verity mount, and returns the FD. The FD -// points to a file that has random data generated. -func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { - creds := auth.CredentialsFromContext(ctx) - lowerRoot := root.Dentry().Impl().(*dentry).lowerVD +// openVerityAt opens a verity file. +// +// TODO(chongc): release reference from opening the file when done. +func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { + return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: vd, + Start: vd, + Path: fspath.Parse(path), + }, &vfs.OpenOptions{ + Flags: flags, + Mode: mode, + }) +} - // Create the file in the underlying file system. - lowerFD, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(filePath), +// openLowerAt opens the file in the underlying file system. +// +// TODO(chongc): release reference from opening the file when done. 
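The openVerityAt and openLowerAt test helpers in this hunk deliberately leak a reference (see the TODOs). One possible, purely illustrative way to plug that in tests is a wrapper that registers the DecRef with t.Cleanup, in the same style newVerityRoot already uses; openVerityAtCleanup below is hypothetical and not part of this change.

// openVerityAtCleanup is a hypothetical wrapper: it opens the verity file via
// openVerityAt and registers the DecRef with the test so the reference is
// released when the test finishes.
func openVerityAtCleanup(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
	t.Helper()
	fd, err := openVerityAt(ctx, vfsObj, vd, path, flags, mode)
	if err != nil {
		return nil, err
	}
	t.Cleanup(func() { fd.DecRef(ctx) })
	return fd, nil
}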
+func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { + return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(path), + }, &vfs.OpenOptions{ + Flags: flags, + Mode: mode, + }) +} + +// openLowerMerkleAt opens the Merkle file in the underlying file system. +// +// TODO(chongc): release reference from opening the file when done. +func (d *dentry) openLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { + return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerMerkleVD, + Start: d.lowerMerkleVD, }, &vfs.OpenOptions{ - Flags: linux.O_RDWR | linux.O_CREAT | linux.O_EXCL, - Mode: linux.ModeRegular | mode, + Flags: flags, + Mode: mode, + }) +} + +// unlinkLowerAt deletes the file in the underlying file system. +func (d *dentry) unlinkLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string) error { + return vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(path), }) +} + +// unlinkLowerMerkleAt deletes the Merkle file in the underlying file system. +func (d *dentry) unlinkLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string) error { + return vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(merklePrefix + path), + }) +} + +// renameLowerAt renames file name to newName in the underlying file system. +func (d *dentry) renameLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, name string, newName string) error { + return vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(name), + }, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(newName), + }, &vfs.RenameOptions{}) +} + +// renameLowerMerkleAt renames Merkle file name to newName in the underlying +// file system. +func (d *dentry) renameLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, name string, newName string) error { + return vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(merklePrefix + name), + }, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(merklePrefix + newName), + }, &vfs.RenameOptions{}) +} + +// newFileFD creates a new file in the verity mount, and returns the FD. The FD +// points to a file that has random data generated. +func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { + // Create the file in the underlying file system. + lowerFD, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) if err != nil { return nil, 0, err } @@ -137,20 +235,24 @@ func newFileFD(ctx context.Context, vfsObj *vfs.VirtualFilesystem, root vfs.Virt lowerFD.DecRef(ctx) // Now open the verity file descriptor. 
- fd, err := vfsObj.OpenAt(ctx, creds, &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filePath), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular | mode, - }) + fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode) return fd, dataSize, err } -// corruptRandomBit randomly flips a bit in the file represented by fd. -func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { - // Flip a random bit in the underlying file. +// newEmptyFileFD creates a new empty file in the verity mount, and returns the FD. +func newEmptyFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, error) { + // Create the file in the underlying file system. + _, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) + if err != nil { + return nil, err + } + // Now open the verity file descriptor. + fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode) + return fd, err +} + +// flipRandomBit randomly flips a bit in the file represented by fd. +func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { randomPos := int64(rand.Intn(size)) byteToModify := make([]byte, 1) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil { @@ -163,7 +265,14 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er return nil } -var hashAlgs = []HashAlgorithm{SHA256, SHA512} +func enableVerity(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { + t.Helper() + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("enable verity: %v", err) + } +} // TestOpen ensures that when a file is created, the corresponding Merkle tree // file and the root Merkle tree file exist. @@ -175,30 +284,18 @@ func TestOpen(t *testing.T) { } filename := "verity-test-file" - if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { t.Fatalf("newFileFD: %v", err) } // Ensure that the corresponding Merkle tree file is created. - lowerRoot := root.Dentry().Impl().(*dentry).lowerVD - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(merklePrefix + filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }); err != nil { + if _, err = dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) } // Ensure the root merkle tree file is created. 
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(merklePrefix + rootMerkleFilename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }); err != nil { + if _, err = dentryFromVD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) } } @@ -214,17 +311,13 @@ func TestPReadUnmodifiedFileSucceeds(t *testing.T) { } filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file and confirm a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) buf := make([]byte, size) n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) @@ -248,17 +341,13 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) { } filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file and confirm a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) buf := make([]byte, size) n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) @@ -272,6 +361,36 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) { } } +// TestReadUnmodifiedEmptyFileSucceeds ensures that read from an untouched empty verity +// file succeeds after enabling verity for it. +func TestReadUnmodifiedEmptyFileSucceeds(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-empty-file" + fd, err := newEmptyFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newEmptyFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + enableVerity(ctx, t, fd) + + var buf []byte + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != 0 { + t.Errorf("fd.Read got read length %d, expected 0", n) + } + } +} + // TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file // succeeds after enabling verity for it. func TestReopenUnmodifiedFileSucceeds(t *testing.T) { @@ -282,27 +401,16 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) { } filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file and confirms a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Ensure reopening the verity enabled file succeeds. 
- if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err != nil { + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != nil { t.Errorf("reopen enabled file failed: %v", err) } } @@ -317,43 +425,24 @@ func TestOpenNonexistentFile(t *testing.T) { } filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file and confirms a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Enable verity on the parent directory. - parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) + parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - - if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, parentFD) // Ensure open an unexpected file in the parent directory fails with // ENOENT rather than verification failure. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename + "abc"), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err != syserror.ENOENT { + if _, err = openVerityAt(ctx, vfsObj, root, filename+"abc", linux.O_RDONLY, linux.ModeRegular); err != syserror.ENOENT { t.Errorf("OpenAt unexpected error: %v", err) } } @@ -368,33 +457,22 @@ func TestPReadModifiedFileFails(t *testing.T) { } filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Open a new lowerFD that's read/writable. - lowerVD := fd.Impl().(*fileDescription).d.lowerVD - - lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerVD, - Start: lowerVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) + lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("flipRandomBit: %v", err) } // Confirm that read from the modified file fails. @@ -415,33 +493,22 @@ func TestReadModifiedFileFails(t *testing.T) { } filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file. 
- var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Open a new lowerFD that's read/writable. - lowerVD := fd.Impl().(*fileDescription).d.lowerVD - - lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerVD, - Start: lowerVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) + lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("flipRandomBit: %v", err) } // Confirm that read from the modified file fails. @@ -462,27 +529,16 @@ func TestModifiedMerkleFails(t *testing.T) { } filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, size, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Open a new lowerMerkleFD that's read/writable. - lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD - - lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerMerkleVD, - Start: lowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) + lowerMerkleFD, err := dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } @@ -493,14 +549,13 @@ func TestModifiedMerkleFails(t *testing.T) { t.Errorf("lowerMerkleFD.Stat: %v", err) } - if err := corruptRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil { + t.Fatalf("flipRandomBit: %v", err) } // Confirm that read from a file with modified Merkle tree fails. buf := make([]byte, size) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - fmt.Println(buf) t.Fatalf("fd.PRead succeeded with modified Merkle file") } } @@ -517,42 +572,23 @@ func TestModifiedParentMerkleFails(t *testing.T) { } filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Enable verity on the parent directory. 
- parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) + parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - - if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, parentFD) // Open a new lowerMerkleFD that's read/writable. - parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD - - parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: parentLowerMerkleVD, - Start: parentLowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) + parentLowerMerkleFD, err := dentryFromFD(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } @@ -572,21 +608,14 @@ func TestModifiedParentMerkleFails(t *testing.T) { if err != nil { t.Fatalf("Failed convert size to int: %v", err) } - if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { + t.Fatalf("flipRandomBit: %v", err) } parentLowerMerkleFD.DecRef(ctx) // Ensure reopening the verity enabled file fails. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err == nil { + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err == nil { t.Errorf("OpenAt file with modified parent Merkle succeeded") } } @@ -602,18 +631,13 @@ func TestUnmodifiedStatSucceeds(t *testing.T) { } filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } - // Enable verity on the file and confirms stat succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } - + // Enable verity on the file and confirm that stat succeeds. + enableVerity(ctx, t, fd) if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { t.Errorf("fd.Stat: %v", err) } @@ -630,17 +654,13 @@ func TestModifiedStatFails(t *testing.T) { } filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } + enableVerity(ctx, t, fd) lowerFD := fd.Impl().(*fileDescription).lowerFD // Change the stat of the underlying file, and check that stat fails. @@ -663,73 +683,57 @@ func TestModifiedStatFails(t *testing.T) { // and/or the corresponding Merkle tree file fails with the verity error. func TestOpenDeletedFileFails(t *testing.T) { testCases := []struct { + name string // The original file is removed if changeFile is true. 
changeFile bool // The Merkle tree file is removed if changeMerkleFile is true. changeMerkleFile bool }{ { + name: "FileOnly", changeFile: true, changeMerkleFile: false, }, { + name: "MerkleOnly", changeFile: false, changeMerkleFile: true, }, { + name: "FileAndMerkle", changeFile: true, changeMerkleFile: true, }, } for _, tc := range testCases { - t.Run(fmt.Sprintf("changeFile:%t, changeMerkleFile:%t", tc.changeFile, tc.changeMerkleFile), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { vfsObj, root, ctx, err := newVerityRoot(t, SHA256) if err != nil { t.Fatalf("newVerityRoot: %v", err) } filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) - rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD if tc.changeFile { - if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(filename), - }); err != nil { + if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil { t.Fatalf("UnlinkAt: %v", err) } } if tc.changeMerkleFile { - if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(merklePrefix + filename), - }); err != nil { + if err := dentryFromVD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil { t.Fatalf("UnlinkAt: %v", err) } } // Ensure reopening the verity enabled file fails. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err != syserror.EIO { + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { t.Errorf("got OpenAt error: %v, expected EIO", err) } }) @@ -740,82 +744,58 @@ func TestOpenDeletedFileFails(t *testing.T) { // and/or the corresponding Merkle tree file fails with the verity error. func TestOpenRenamedFileFails(t *testing.T) { testCases := []struct { + name string // The original file is renamed if changeFile is true. changeFile bool // The Merkle tree file is renamed if changeMerkleFile is true. changeMerkleFile bool }{ { + name: "FileOnly", changeFile: true, changeMerkleFile: false, }, { + name: "MerkleOnly", changeFile: false, changeMerkleFile: true, }, { + name: "FileAndMerkle", changeFile: true, changeMerkleFile: true, }, } for _, tc := range testCases { - t.Run(fmt.Sprintf("changeFile:%t, changeMerkleFile:%t", tc.changeFile, tc.changeMerkleFile), func(t *testing.T) { + t.Run(tc.name, func(t *testing.T) { vfsObj, root, ctx, err := newVerityRoot(t, SHA256) if err != nil { t.Fatalf("newVerityRoot: %v", err) } filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + fd, _, err := newFileFD(ctx, t, vfsObj, root, filename, 0644) if err != nil { t.Fatalf("newFileFD: %v", err) } // Enable verity on the file. 
- var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) - rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD newFilename := "renamed-test-file" if tc.changeFile { - if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(filename), - }, &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(newFilename), - }, &vfs.RenameOptions{}); err != nil { + if err := dentryFromVD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil { t.Fatalf("RenameAt: %v", err) } } if tc.changeMerkleFile { - if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(merklePrefix + filename), - }, &vfs.PathOperation{ - Root: rootLowerVD, - Start: rootLowerVD, - Path: fspath.Parse(merklePrefix + newFilename), - }, &vfs.RenameOptions{}); err != nil { + if err := dentryFromVD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil { t.Fatalf("UnlinkAt: %v", err) } } // Ensure reopening the verity enabled file fails. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err != syserror.EIO { + if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != syserror.EIO { t.Errorf("got OpenAt error: %v, expected EIO", err) } }) diff --git a/pkg/sentry/fsmetric/BUILD b/pkg/sentry/fsmetric/BUILD new file mode 100644 index 000000000..4e86fbdd8 --- /dev/null +++ b/pkg/sentry/fsmetric/BUILD @@ -0,0 +1,10 @@ +load("//tools:defs.bzl", "go_library") + +licenses(["notice"]) + +go_library( + name = "fsmetric", + srcs = ["fsmetric.go"], + visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/metric"], +) diff --git a/pkg/sentry/fsmetric/fsmetric.go b/pkg/sentry/fsmetric/fsmetric.go new file mode 100644 index 000000000..7e535b527 --- /dev/null +++ b/pkg/sentry/fsmetric/fsmetric.go @@ -0,0 +1,83 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fsmetric defines filesystem metrics that are used by both VFS1 and +// VFS2. +// +// TODO(gvisor.dev/issue/1624): Once VFS1 is deleted, inline these metrics into +// VFS2. +package fsmetric + +import ( + "time" + + "gvisor.dev/gvisor/pkg/metric" +) + +// RecordWaitTime enables the ReadWait, GoferReadWait9P, GoferReadWaitHost, and +// TmpfsReadWait metrics. Enabling this comes at a CPU cost due to performing +// three clock reads per read call. +// +// Note that this is only performed in the direct read path, and may not be +// consistently applied for other forms of reads, such as splice. 
+var RecordWaitTime = false + +// Metrics that apply to all filesystems. +var ( + Opens = metric.MustCreateNewUint64Metric("/fs/opens", false /* sync */, "Number of file opens.") + Reads = metric.MustCreateNewUint64Metric("/fs/reads", false /* sync */, "Number of file reads.") + ReadWait = metric.MustCreateNewUint64NanosecondsMetric("/fs/read_wait", false /* sync */, "Time waiting on file reads, in nanoseconds.") +) + +// Metrics that only apply to fs/gofer and fsimpl/gofer. +var ( + GoferOpensWX = metric.MustCreateNewUint64Metric("/gofer/opened_write_execute_file", true /* sync */, "Number of times a executable file was opened writably from a gofer.") + GoferOpens9P = metric.MustCreateNewUint64Metric("/gofer/opens_9p", false /* sync */, "Number of times a file was opened from a gofer and did not have a host file descriptor.") + GoferOpensHost = metric.MustCreateNewUint64Metric("/gofer/opens_host", false /* sync */, "Number of times a file was opened from a gofer and did have a host file descriptor.") + GoferReads9P = metric.MustCreateNewUint64Metric("/gofer/reads_9p", false /* sync */, "Number of 9P file reads from a gofer.") + GoferReadWait9P = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_9p", false /* sync */, "Time waiting on 9P file reads from a gofer, in nanoseconds.") + GoferReadsHost = metric.MustCreateNewUint64Metric("/gofer/reads_host", false /* sync */, "Number of host file reads from a gofer.") + GoferReadWaitHost = metric.MustCreateNewUint64NanosecondsMetric("/gofer/read_wait_host", false /* sync */, "Time waiting on host file reads from a gofer, in nanoseconds.") +) + +// Metrics that only apply to fs/tmpfs and fsimpl/tmpfs. +var ( + TmpfsOpensRO = metric.MustCreateNewUint64Metric("/in_memory_file/opens_ro", false /* sync */, "Number of times an in-memory file was opened in read-only mode.") + TmpfsOpensW = metric.MustCreateNewUint64Metric("/in_memory_file/opens_w", false /* sync */, "Number of times an in-memory file was opened in write mode.") + TmpfsReads = metric.MustCreateNewUint64Metric("/in_memory_file/reads", false /* sync */, "Number of in-memory file reads.") + TmpfsReadWait = metric.MustCreateNewUint64NanosecondsMetric("/in_memory_file/read_wait", false /* sync */, "Time waiting on in-memory file reads, in nanoseconds.") +) + +// StartReadWait indicates the beginning of a file read. +func StartReadWait() time.Time { + if !RecordWaitTime { + return time.Time{} + } + return time.Now() +} + +// FinishReadWait indicates the end of a file read whose time is accounted by +// m. start must be the value returned by the corresponding call to +// StartReadWait. +// +// FinishReadWait is marked nosplit for performance since it's often called +// from defer statements, which prevents it from being inlined +// (https://github.com/golang/go/issues/38471). +//go:nosplit +func FinishReadWait(m *metric.Uint64Metric, start time.Time) { + if !RecordWaitTime { + return + } + m.IncrementBy(uint64(time.Since(start).Nanoseconds())) +} diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 15519f0df..61aeca044 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -273,7 +273,7 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent { // // Callback is called when one of the files we're polling becomes ready. It // moves said file to the readyList if it's currently in the waiting list. 
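StartReadWait and FinishReadWait above are meant to bracket a read so the elapsed time is charged to a per-filesystem wait metric, with RecordWaitTime gating the clock reads. A hedged sketch of the calling pattern, mirroring the tmpfs PRead hunk earlier in this diff; myFileFD and its data.read method are made up for illustration.

// myFileFD and its data.read method are made up; the fsmetric calls and the
// PRead signature are the ones used by the tmpfs hunk above.
func (fd *myFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
	// StartReadWait returns the zero time and FinishReadWait is a no-op when
	// RecordWaitTime is false, so the clock is only read when wait-time
	// recording is enabled.
	start := fsmetric.StartReadWait()
	defer fsmetric.FinishReadWait(fsmetric.TmpfsReadWait, start)
	fsmetric.TmpfsReads.Increment()

	return fd.data.read(ctx, dst, offset) // stand-in for the real read path
}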
-func (p *pollEntry) Callback(*waiter.Entry) { +func (p *pollEntry) Callback(*waiter.Entry, waiter.EventMask) { e := p.epoll e.listsMu.Lock() @@ -306,9 +306,8 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) { f.EventRegister(&entry.waiter, entry.mask) // Check if the file happens to already be in a ready state. - ready := f.Readiness(entry.mask) & entry.mask - if ready != 0 { - entry.Callback(&entry.waiter) + if ready := f.Readiness(entry.mask) & entry.mask; ready != 0 { + entry.Callback(&entry.waiter, ready) } } diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 2b3955598..f855f038b 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -8,11 +8,13 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/sentry/arch", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/sync", + "//pkg/syserror", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index 153d2cd9b..b66d61c6f 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -17,22 +17,45 @@ package fasync import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) -// New creates a new fs.FileAsync. -func New() fs.FileAsync { - return &FileAsync{} +// Table to convert waiter event masks into si_band siginfo codes. +// Taken from fs/fcntl.c:band_table. +var bandTable = map[waiter.EventMask]int64{ + // POLL_IN + waiter.EventIn: linux.EPOLLIN | linux.EPOLLRDNORM, + // POLL_OUT + waiter.EventOut: linux.EPOLLOUT | linux.EPOLLWRNORM | linux.EPOLLWRBAND, + // POLL_ERR + waiter.EventErr: linux.EPOLLERR, + // POLL_PRI + waiter.EventPri: linux.EPOLLPRI | linux.EPOLLRDBAND, + // POLL_HUP + waiter.EventHUp: linux.EPOLLHUP | linux.EPOLLERR, } -// NewVFS2 creates a new vfs.FileAsync. -func NewVFS2() vfs.FileAsync { - return &FileAsync{} +// New returns a function that creates a new fs.FileAsync with the given file +// descriptor. +func New(fd int) func() fs.FileAsync { + return func() fs.FileAsync { + return &FileAsync{fd: fd} + } +} + +// NewVFS2 returns a function that creates a new vfs.FileAsync with the given +// file descriptor. +func NewVFS2(fd int) func() vfs.FileAsync { + return func() vfs.FileAsync { + return &FileAsync{fd: fd} + } } // FileAsync sends signals when the registered file is ready for IO. @@ -42,6 +65,12 @@ type FileAsync struct { // e is immutable after first use (which is protected by mu below). e waiter.Entry + // fd is the file descriptor to notify about. + // It is immutable, set at allocation time. This matches Linux semantics in + // fs/fcntl.c:fasync_helper. + // The fd value is passed to the signal recipient in siginfo.si_fd. + fd int + // regMu protects registeration and unregistration actions on e. // // regMu must be held while registration decisions are being made @@ -56,6 +85,10 @@ type FileAsync struct { mu sync.Mutex `state:"nosave"` requester *auth.Credentials registered bool + // signal is the signal to deliver upon I/O being available. + // The default value ("zero signal") means the default SIGIO signal will be + // delivered. + signal linux.Signal // Only one of the following is allowed to be non-nil. 
recipientPG *kernel.ProcessGroup @@ -64,10 +97,10 @@ type FileAsync struct { } // Callback sends a signal. -func (a *FileAsync) Callback(e *waiter.Entry) { +func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) { a.mu.Lock() + defer a.mu.Unlock() if !a.registered { - a.mu.Unlock() return } t := a.recipientT @@ -80,19 +113,34 @@ func (a *FileAsync) Callback(e *waiter.Entry) { } if t == nil { // No recipient has been registered. - a.mu.Unlock() return } c := t.Credentials() // Logic from sigio_perm in fs/fcntl.c. - if a.requester.EffectiveKUID == 0 || + permCheck := (a.requester.EffectiveKUID == 0 || a.requester.EffectiveKUID == c.SavedKUID || a.requester.EffectiveKUID == c.RealKUID || a.requester.RealKUID == c.SavedKUID || - a.requester.RealKUID == c.RealKUID { - t.SendSignal(kernel.SignalInfoPriv(linux.SIGIO)) + a.requester.RealKUID == c.RealKUID) + if !permCheck { + return } - a.mu.Unlock() + signalInfo := &arch.SignalInfo{ + Signo: int32(linux.SIGIO), + Code: arch.SignalInfoKernel, + } + if a.signal != 0 { + signalInfo.Signo = int32(a.signal) + signalInfo.SetFD(uint32(a.fd)) + var band int64 + for m, bandCode := range bandTable { + if m&mask != 0 { + band |= bandCode + } + } + signalInfo.SetBand(band) + } + t.SendSignal(signalInfo) } // Register sets the file which will be monitored for IO events. @@ -186,3 +234,25 @@ func (a *FileAsync) ClearOwner() { a.recipientTG = nil a.recipientPG = nil } + +// Signal returns which signal will be sent to the signal recipient. +// A value of zero means the signal to deliver wasn't customized, which means +// the default signal (SIGIO) will be delivered. +func (a *FileAsync) Signal() linux.Signal { + a.mu.Lock() + defer a.mu.Unlock() + return a.signal +} + +// SetSignal overrides which signal to send when I/O is available. +// The default behavior can be reset by specifying signal zero, which means +// to send SIGIO. +func (a *FileAsync) SetSignal(signal linux.Signal) error { + if signal != 0 && !signal.IsValid() { + return syserror.EINVAL + } + a.mu.Lock() + defer a.mu.Unlock() + a.signal = signal + return nil +} diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index 470d8bf83..f17f9c59c 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -121,18 +121,21 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 panic("VFS1 and VFS2 files set") } - slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) + slicePtr := (*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) // Grow the table as required. - if last := int32(len(slice)); fd >= last { + if last := int32(len(*slicePtr)); fd >= last { end := fd + 1 if end < 2*last { end = 2 * last } - slice = append(slice, make([]unsafe.Pointer, end-last)...) - atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) + newSlice := append(*slicePtr, make([]unsafe.Pointer, end-last)...) + slicePtr = &newSlice + atomic.StorePointer(&f.slice, unsafe.Pointer(slicePtr)) } + slice := *slicePtr + var desc *descriptor if file != nil || fileVFS2 != nil { desc = &descriptor{ diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 2cdcdfc1f..b8627a54f 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -214,9 +214,11 @@ type Kernel struct { // netlinkPorts manages allocation of netlink socket port IDs. netlinkPorts *port.Manager - // saveErr is the error causing the sandbox to exit during save, if - // any. It is protected by extMu. 
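The fasync Callback above computes si_band by OR-ing the band code of every ready event in the mask before delivering the configured signal along with the registered fd. A small standalone sketch of just that band computation; the event constants and EPOLL values are local stand-ins for the waiter and linux packages.

package main

import "fmt"

// Local stand-ins for waiter.EventIn/EventOut/EventErr and the EPOLL* codes;
// the real mapping is bandTable in the fasync hunk above.
type eventMask uint64

const (
	eventIn  eventMask = 1 << iota // readable
	eventOut                       // writable
	eventErr                       // error condition
)

var bandCodes = map[eventMask]int64{
	eventIn:  0x0001 | 0x0040,          // EPOLLIN | EPOLLRDNORM
	eventOut: 0x0004 | 0x0100 | 0x0200, // EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND
	eventErr: 0x0008,                   // EPOLLERR
}

// siBand mirrors the loop in FileAsync.Callback: every event present in mask
// contributes its band code to si_band.
func siBand(mask eventMask) int64 {
	var band int64
	for m, code := range bandCodes {
		if m&mask != 0 {
			band |= code
		}
	}
	return band
}

func main() {
	fmt.Printf("%#x\n", siBand(eventIn|eventErr))
}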
- saveErr error `state:"nosave"` + // saveStatus is nil if the sandbox has not been saved, errSaved or + // errAutoSaved if it has been saved successfully, or the error causing the + // sandbox to exit during save. + // It is protected by extMu. + saveStatus error `state:"nosave"` // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` @@ -1481,12 +1483,42 @@ func (k *Kernel) NetlinkPorts() *port.Manager { return k.netlinkPorts } -// SaveError returns the sandbox error that caused the kernel to exit during -// save. -func (k *Kernel) SaveError() error { +var ( + errSaved = errors.New("sandbox has been successfully saved") + errAutoSaved = errors.New("sandbox has been successfully auto-saved") +) + +// SaveStatus returns the sandbox save status. If it was saved successfully, +// autosaved indicates whether save was triggered by autosave. If it was not +// saved successfully, err indicates the sandbox error that caused the kernel to +// exit during save. +func (k *Kernel) SaveStatus() (saved, autosaved bool, err error) { + k.extMu.Lock() + defer k.extMu.Unlock() + switch k.saveStatus { + case nil: + return false, false, nil + case errSaved: + return true, false, nil + case errAutoSaved: + return true, true, nil + default: + return false, false, k.saveStatus + } +} + +// SetSaveSuccess sets the flag indicating that save completed successfully, if +// no status was already set. +func (k *Kernel) SetSaveSuccess(autosave bool) { k.extMu.Lock() defer k.extMu.Unlock() - return k.saveErr + if k.saveStatus == nil { + if autosave { + k.saveStatus = errAutoSaved + } else { + k.saveStatus = errSaved + } + } } // SetSaveError sets the sandbox error that caused the kernel to exit during @@ -1494,8 +1526,8 @@ func (k *Kernel) SaveError() error { func (k *Kernel) SetSaveError(err error) { k.extMu.Lock() defer k.extMu.Unlock() - if k.saveErr == nil { - k.saveErr = err + if k.saveStatus == nil { + k.saveStatus = err } } diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 1abfe2201..cef58a590 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -259,8 +259,8 @@ func (t *Task) ptraceTrapLocked(code int32) { Signo: int32(linux.SIGTRAP), Code: code, } - t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t])) - t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t])) + t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) if t.beginPtraceStopLocked() { tracer := t.Tracer() tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP)) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index b99c0bffa..db01e4a97 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -29,17 +29,17 @@ import ( ) const ( - valueMax = 32767 // SEMVMX + // Maximum semaphore value. + valueMax = linux.SEMVMX - // semaphoresMax is "maximum number of semaphores per semaphore ID" (SEMMSL). - semaphoresMax = 32000 + // Maximum number of semaphore sets. + setsMax = linux.SEMMNI - // setMax is "system-wide limit on the number of semaphore sets" (SEMMNI). - setsMax = 32000 + // Maximum number of semaphores in a semaphore set. + semsMax = linux.SEMMSL - // semaphoresTotalMax is "system-wide limit on the number of semaphores" - // (SEMMNS = SEMMNI*SEMMSL). 
- semaphoresTotalMax = 1024000000 + // Maximum number of semaphores in all semaphore sets. + semsTotalMax = linux.SEMMNS ) // Registry maintains a set of semaphores that can be found by key or ID. @@ -52,6 +52,9 @@ type Registry struct { mu sync.Mutex `state:"nosave"` semaphores map[int32]*Set lastIDUsed int32 + // indexes maintains a mapping between a set's index in the virtual array and + // its identifier. + indexes map[int32]int32 } // Set represents a set of semaphores that can be operated atomically. @@ -113,6 +116,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { return &Registry{ userNS: userNS, semaphores: make(map[int32]*Set), + indexes: make(map[int32]int32), } } @@ -122,7 +126,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { // be found. If exclusive is true, it fails if a set with the same key already // exists. func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) { - if nsems < 0 || nsems > semaphoresMax { + if nsems < 0 || nsems > semsMax { return nil, syserror.EINVAL } @@ -163,10 +167,13 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu } // Apply system limits. + // + // The semaphores and indexes maps in a registry are always the same size, + // so only the semaphores map is checked against the system limit here. if len(r.semaphores) >= setsMax { return nil, syserror.EINVAL } - if r.totalSems() > int(semaphoresTotalMax-nsems) { + if r.totalSems() > int(semsTotalMax-nsems) { return nil, syserror.EINVAL } @@ -176,6 +183,53 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu return r.newSet(ctx, key, owner, owner, perms, nsems) } +// IPCInfo returns information about system-wide semaphore limits and parameters. +func (r *Registry) IPCInfo() *linux.SemInfo { + return &linux.SemInfo{ + SemMap: linux.SEMMAP, + SemMni: linux.SEMMNI, + SemMns: linux.SEMMNS, + SemMnu: linux.SEMMNU, + SemMsl: linux.SEMMSL, + SemOpm: linux.SEMOPM, + SemUme: linux.SEMUME, + SemUsz: linux.SEMUSZ, + SemVmx: linux.SEMVMX, + SemAem: linux.SEMAEM, + } +} + +// SemInfo returns a seminfo structure containing the same information as +// for IPC_INFO, except that the SemUsz field returns the number of existing +// semaphore sets, and the SemAem field returns the number of existing semaphores. +func (r *Registry) SemInfo() *linux.SemInfo { + r.mu.Lock() + defer r.mu.Unlock() + + info := r.IPCInfo() + info.SemUsz = uint32(len(r.semaphores)) + info.SemAem = uint32(r.totalSems()) + + return info +} + +// HighestIndex returns the index of the highest used entry in +// the kernel's array. +func (r *Registry) HighestIndex() int32 { + r.mu.Lock() + defer r.mu.Unlock() + + // By default, the highest used index is 0 even when + // there is no semaphore set. + var highestIndex int32 + for index := range r.indexes { + if index > highestIndex { + highestIndex = index + } + } + return highestIndex +} + // RemoveID removes the set with the given 'id' from the registry and marks the set as // dead. All waiters will be awakened and fail. func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { @@ -186,6 +240,11 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { if set == nil { return syserror.EINVAL } + index, found := r.findIndexByID(id) + if !found { + // Inconsistent state. 
+ panic(fmt.Sprintf("unable to find an index for ID: %d", id)) + } set.mu.Lock() defer set.mu.Unlock() @@ -197,6 +256,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { } delete(r.semaphores, set.ID) + delete(r.indexes, index) set.destroy() return nil } @@ -220,6 +280,11 @@ func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.File continue } if r.semaphores[id] == nil { + index, found := r.findFirstAvailableIndex() + if !found { + panic("unable to find an available index") + } + r.indexes[index] = id r.lastIDUsed = id r.semaphores[id] = set set.ID = id @@ -238,6 +303,18 @@ func (r *Registry) FindByID(id int32) *Set { return r.semaphores[id] } +// FindByIndex looks up a set given an index. +func (r *Registry) FindByIndex(index int32) *Set { + r.mu.Lock() + defer r.mu.Unlock() + + id, present := r.indexes[index] + if !present { + return nil + } + return r.semaphores[id] +} + func (r *Registry) findByKey(key int32) *Set { for _, v := range r.semaphores { if v.key == key { @@ -247,6 +324,24 @@ func (r *Registry) findByKey(key int32) *Set { return nil } +func (r *Registry) findIndexByID(id int32) (int32, bool) { + for k, v := range r.indexes { + if v == id { + return k, true + } + } + return 0, false +} + +func (r *Registry) findFirstAvailableIndex() (int32, bool) { + for index := int32(0); index < setsMax; index++ { + if _, present := r.indexes[index]; !present { + return index, true + } + } + return 0, false +} + func (r *Registry) totalSems() int { totalSems := 0 for _, v := range r.semaphores { diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 80a592c8f..073e14507 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -6,6 +6,9 @@ package(licenses = ["notice"]) go_template_instance( name = "shm_refs", out = "shm_refs.go", + consts = { + "enableLogging": "true", + }, package = "shm", prefix = "Shm", template = "//pkg/refsvfs2:refs_template", diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go index e8cce37d0..2488ae7d5 100644 --- a/pkg/sentry/kernel/signal.go +++ b/pkg/sentry/kernel/signal.go @@ -73,7 +73,7 @@ func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) - info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) + info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) return info } diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 78f718cfe..884966120 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -106,8 +106,8 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, - PID: uint32(info.Pid()), - UID: uint32(info.Uid()), + PID: uint32(info.PID()), + UID: uint32(info.UID()), Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), diff --git a/pkg/sentry/kernel/syslog.go b/pkg/sentry/kernel/syslog.go index a83ce219c..3fee7aa68 100644 --- a/pkg/sentry/kernel/syslog.go +++ b/pkg/sentry/kernel/syslog.go @@ -75,6 +75,12 @@ func (s *syslog) Log() []byte { "Checking naughty and nice process list...", // Check it up to twice. 
"Granting licence to kill(2)...", // British spelling for British movie. "Letting the watchdogs out...", + "Conjuring /dev/null black hole...", + "Adversarially training Redcode AI...", + "Singleplexing /dev/ptmx...", + "Recruiting cron-ies...", + "Verifying that no non-zero bytes made their way into /dev/zero...", + "Accelerating teletypewriter to 9600 baud...", } selectMessage := func() string { diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index c5137c282..16986244c 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -368,8 +368,8 @@ func (t *Task) exitChildren() { Signo: int32(sig), Code: arch.SignalInfoUser, } - siginfo.SetPid(int32(c.tg.pidns.tids[t])) - siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) + siginfo.SetPID(int32(c.tg.pidns.tids[t])) + siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) c.tg.signalHandlers.mu.Lock() c.sendSignalLocked(siginfo, true /* group */) c.tg.signalHandlers.mu.Unlock() @@ -698,8 +698,8 @@ func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.Si info := &arch.SignalInfo{ Signo: int32(sig), } - info.SetPid(int32(receiver.tg.pidns.tids[t])) - info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.tg.pidns.tids[t])) + info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) if t.exitStatus.Signaled() { info.Code = arch.CLD_KILLED info.SetStatus(int32(t.exitStatus.Signo)) diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 42dd3e278..75af3af79 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -914,8 +914,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) { Signo: int32(linux.SIGCHLD), Code: code, } - sigchld.SetPid(int32(t.tg.pidns.tids[target])) - sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + sigchld.SetPID(int32(t.tg.pidns.tids[target])) + sigchld.SetUID(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) sigchld.SetStatus(status) // TODO(b/72102453): Set utime, stime. t.sendSignalLocked(sigchld, true /* group */) @@ -1022,8 +1022,8 @@ func (*runInterrupt) execute(t *Task) taskRunState { Signo: int32(sig), Code: t.ptraceCode, } - t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t])) - t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t])) + t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) } else { t.ptraceCode = int32(sig) t.ptraceSiginfo = nil @@ -1114,11 +1114,11 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { if parent == nil { // Tracer has detached and t was created by Kernel.CreateProcess(). // Pretend the parent is in an ancestor PID + user namespace. 
- info.SetPid(0) - info.SetUid(int32(auth.OverflowUID)) + info.SetPID(0) + info.SetUID(int32(auth.OverflowUID)) } else { - info.SetPid(int32(t.tg.pidns.tids[parent])) - info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + info.SetPID(int32(t.tg.pidns.tids[parent])) + info.SetUID(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) } } t.tg.signalHandlers.mu.Lock() diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 7fd77925f..49e21026e 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -160,7 +160,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp // Translations must be contiguous and in increasing order of // Translation.Source. if i > 0 && ts[i-1].Source.End != t.Source.Start { - return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t) + return fmt.Errorf("Translation %+v and Translation %+v are not contiguous", ts[i-1], t) } // At least part of each Translation must be required. if t.Source.Intersect(required).Length() == 0 { diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 4c8cd38ed..5ab2ef79f 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -36,12 +36,12 @@ type aioManager struct { contexts map[uint64]*AIOContext } -func (a *aioManager) destroy() { - a.mu.Lock() - defer a.mu.Unlock() +func (mm *MemoryManager) destroyAIOManager(ctx context.Context) { + mm.aioManager.mu.Lock() + defer mm.aioManager.mu.Unlock() - for _, ctx := range a.contexts { - ctx.destroy() + for id := range mm.aioManager.contexts { + mm.destroyAIOContextLocked(ctx, id) } } @@ -68,16 +68,26 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool { // be drained. // // Nil is returned if the context does not exist. -func (a *aioManager) destroyAIOContext(id uint64) *AIOContext { - a.mu.Lock() - defer a.mu.Unlock() - ctx, ok := a.contexts[id] +// +// Precondition: mm.aioManager.mu is locked. +func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) *AIOContext { + aioCtx, ok := mm.aioManager.contexts[id] if !ok { return nil } - delete(a.contexts, id) - ctx.destroy() - return ctx + + // Only unmaps after it assured that the address is a valid aio context to + // prevent random memory from been unmapped. + // + // Note: It's possible to unmap this address and map something else into + // the same address. Then it would be unmapping memory that it doesn't own. + // This is, however, the way Linux implements AIO. Keeps the same [weird] + // semantics in case anyone relies on it. + mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) + + delete(mm.aioManager.contexts, id) + aioCtx.destroy() + return aioCtx } // lookupAIOContext looks up the given context. @@ -140,16 +150,21 @@ func (ctx *AIOContext) checkForDone() { } } -// Prepare reserves space for a new request, returning true if available. -// Returns false if the context is busy. -func (ctx *AIOContext) Prepare() bool { +// Prepare reserves space for a new request, returning nil if available. +// Returns EAGAIN if the context is busy and EINVAL if the context is dead. +func (ctx *AIOContext) Prepare() error { ctx.mu.Lock() defer ctx.mu.Unlock() + if ctx.dead { + // Context died after the caller looked it up. + return syserror.EINVAL + } if ctx.outstanding >= ctx.maxOutstanding { - return false + // Context is busy. 
+ return syserror.EAGAIN } ctx.outstanding++ - return true + return nil } // PopRequest pops a completed request if available, this function does not do @@ -391,20 +406,13 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint // DestroyAIOContext destroys an asynchronous I/O context. It returns the // destroyed context. nil if the context does not exist. func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext { - if _, ok := mm.LookupAIOContext(ctx, id); !ok { + if !mm.isValidAddr(ctx, id) { return nil } - // Only unmaps after it assured that the address is a valid aio context to - // prevent random memory from been unmapped. - // - // Note: It's possible to unmap this address and map something else into - // the same address. Then it would be unmapping memory that it doesn't own. - // This is, however, the way Linux implements AIO. Keeps the same [weird] - // semantics in case anyone relies on it. - mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) - - return mm.aioManager.destroyAIOContext(id) + mm.aioManager.mu.Lock() + defer mm.aioManager.mu.Unlock() + return mm.destroyAIOContextLocked(ctx, id) } // LookupAIOContext looks up the given context. It returns false if the context @@ -415,13 +423,18 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC return nil, false } - // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes - // from id). - var buf [4]byte - _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) - if err != nil { + // Protect against 'id' that is inaccessible. + if !mm.isValidAddr(ctx, id) { return nil, false } return aioCtx, true } + +// isValidAddr determines if the address `id` is valid. (Linux also reads 4 +// bytes from id). +func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool { + var buf [4]byte + _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) + return err == nil +} diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go index 3dabac1af..e8931922f 100644 --- a/pkg/sentry/mm/aio_context_state.go +++ b/pkg/sentry/mm/aio_context_state.go @@ -15,6 +15,6 @@ package mm // afterLoad is invoked by stateify. -func (a *AIOContext) afterLoad() { - a.requestReady = make(chan struct{}, 1) +func (ctx *AIOContext) afterLoad() { + ctx.requestReady = make(chan struct{}, 1) } diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 09dbc06a4..120707429 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -253,7 +253,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) { panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users)) } - mm.aioManager.destroy() + mm.destroyAIOManager(ctx) mm.metadataMu.Lock() exe := mm.executable diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index acac3d357..bc53bd41e 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -229,3 +229,46 @@ func TestIOAfterMProtect(t *testing.T) { t.Errorf("CopyOut got %d want 1", n) } } + +// TestAIOPrepareAfterDestroy tests that AIOContext should not be able to be +// prepared after destruction. 
+func TestAIOPrepareAfterDestroy(t *testing.T) { + ctx := contexttest.Context(t) + mm := testMemoryManager(ctx) + defer mm.DecUsers(ctx) + + id, err := mm.NewAIOContext(ctx, 1) + if err != nil { + t.Fatalf("mm.NewAIOContext got err %v want nil", err) + } + aioCtx, ok := mm.LookupAIOContext(ctx, id) + if !ok { + t.Fatalf("AIOContext not found") + } + mm.DestroyAIOContext(ctx, id) + + // Prepare should fail because aioCtx should be destroyed. + if err := aioCtx.Prepare(); err != syserror.EINVAL { + t.Errorf("aioCtx.Prepare got err %v want nil", err) + } else if err == nil { + aioCtx.CancelPendingRequest() + } +} + +// TestAIOLookupAfterDestroy tests that AIOContext should not be able to be +// looked up after memory manager is destroyed. +func TestAIOLookupAfterDestroy(t *testing.T) { + ctx := contexttest.Context(t) + mm := testMemoryManager(ctx) + + id, err := mm.NewAIOContext(ctx, 1) + if err != nil { + mm.DecUsers(ctx) + t.Fatalf("mm.NewAIOContext got err %v want nil", err) + } + mm.DecUsers(ctx) // This destroys the AIOContext manager. + + if _, ok := mm.LookupAIOContext(ctx, id); ok { + t.Errorf("AIOContext found even after AIOContext manager is destroyed") + } +} diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 7c297fb9e..d99be7f46 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -423,11 +423,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File } if f.opts.ManualZeroing { - if err := f.forEachMappingSlice(fr, func(bs []byte) { - for i := range bs { - bs[i] = 0 - } - }); err != nil { + if err := f.manuallyZero(fr); err != nil { return memmap.FileRange{}, err } } @@ -560,19 +556,39 @@ func (f *MemoryFile) Decommit(fr memmap.FileRange) error { panic(fmt.Sprintf("invalid range: %v", fr)) } + if f.opts.ManualZeroing { + // FALLOC_FL_PUNCH_HOLE may not zero pages if ManualZeroing is in + // effect. + if err := f.manuallyZero(fr); err != nil { + return err + } + } else { + if err := f.decommitFile(fr); err != nil { + return err + } + } + + f.markDecommitted(fr) + return nil +} + +func (f *MemoryFile) manuallyZero(fr memmap.FileRange) error { + return f.forEachMappingSlice(fr, func(bs []byte) { + for i := range bs { + bs[i] = 0 + } + }) +} + +func (f *MemoryFile) decommitFile(fr memmap.FileRange) error { // "After a successful call, subsequent reads from this range will // return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with // FALLOC_FL_KEEP_SIZE in mode ..." - fallocate(2) - err := syscall.Fallocate( + return syscall.Fallocate( int(f.file.Fd()), _FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE, int64(fr.Start), int64(fr.Length())) - if err != nil { - return err - } - f.markDecommitted(fr) - return nil } func (f *MemoryFile) markDecommitted(fr memmap.FileRange) { @@ -1044,20 +1060,20 @@ func (f *MemoryFile) runReclaim() { break } - if err := f.Decommit(fr); err != nil { - log.Warningf("Reclaim failed to decommit %v: %v", fr, err) - // Zero the pages manually. This won't reduce memory usage, but at - // least ensures that the pages will be zero when reallocated. - f.forEachMappingSlice(fr, func(bs []byte) { - for i := range bs { - bs[i] = 0 + // If ManualZeroing is in effect, pages will be zeroed on allocation + // and may not be freed by decommitFile, so calling decommitFile is + // unnecessary. + if !f.opts.ManualZeroing { + if err := f.decommitFile(fr); err != nil { + log.Warningf("Reclaim failed to decommit %v: %v", fr, err) + // Zero the pages manually. 
This won't reduce memory usage, but at + least ensures that the pages will be zero when reallocated. + if err := f.manuallyZero(fr); err != nil { + panic(fmt.Sprintf("Reclaim failed to decommit or zero %v: %v", fr, err)) } - }) - // Pretend the pages were decommitted even though they weren't, - // since the memory accounting implementation has no idea how to - // deal with this. - f.markDecommitted(fr) + } } + f.markDecommitted(fr) f.markReclaimed(fr) } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index acad4c793..f8ccb7430 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -91,6 +91,13 @@ func bluepillSigBus(c *vCPU) { } } +// bluepillHandleEnosys is responsible for handling the ENOSYS error. +// +//go:nosplit +func bluepillHandleEnosys(c *vCPU) { + throw("run failed: ENOSYS") +} + // bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection. // //go:nosplit @@ -126,3 +133,10 @@ func bluepillReadyStopGuest(c *vCPU) bool { } return true } + +// bluepillArchHandleExit checks the architecture-specific exit code. +// +//go:nosplit +func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) { + c.die(bluepillArchContext(context), "unknown") +} diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 965ad66b5..1f09813ba 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -42,6 +42,13 @@ var ( sErrEsr: _ESR_ELx_SERR_NMI, }, } + + // vcpuExtDabt is the ext_dabt event. + vcpuExtDabt = kvmVcpuEvents{ + exception: exception{ + extDabtPending: 1, + }, + } ) // getTLS returns the value of TPIDR_EL0 register. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index 9433d4da5..4d912769a 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -85,7 +85,7 @@ func bluepillStopGuest(c *vCPU) { uintptr(c.fd), _KVM_SET_VCPU_EVENTS, uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 { - throw("sErr injection failed") + throw("bounce sErr injection failed") } } @@ -93,18 +93,54 @@ func bluepillStopGuest(c *vCPU) { // //go:nosplit func bluepillSigBus(c *vCPU) { + // Host must support ARM64_HAS_RAS_EXTN. if _, _, errno := syscall.RawSyscall( // escapes: no. syscall.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 { - throw("sErr injection failed") + if errno == syscall.EINVAL { + throw("No ARM64_HAS_RAS_EXTN feature in host.") + } + throw("nmi sErr injection failed") } } +// bluepillExtDabt is responsible for injecting an external data abort. +// +//go:nosplit +func bluepillExtDabt(c *vCPU) { + if _, _, errno := syscall.RawSyscall( // escapes: no. + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_VCPU_EVENTS, + uintptr(unsafe.Pointer(&vcpuExtDabt))); errno != 0 { + throw("ext_dabt injection failed") + } +} + +// bluepillHandleEnosys is responsible for handling the ENOSYS error. +// +//go:nosplit +func bluepillHandleEnosys(c *vCPU) { + bluepillExtDabt(c) +} + // bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection. // //go:nosplit func bluepillReadyStopGuest(c *vCPU) bool { return true } + +// bluepillArchHandleExit checks the architecture-specific exit code. 
+// +//go:nosplit +func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) { + switch c.runData.exitReason { + case _KVM_EXIT_ARM_NISV: + bluepillExtDabt(c) + default: + c.die(bluepillArchContext(context), "unknown") + } +} diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 75085ac6a..8c5369377 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -148,6 +148,9 @@ func bluepillHandler(context unsafe.Pointer) { // mode and have interrupts disabled. bluepillSigBus(c) continue // Rerun vCPU. + case syscall.ENOSYS: + bluepillHandleEnosys(c) + continue default: throw("run failed") } @@ -220,7 +223,7 @@ func bluepillHandler(context unsafe.Pointer) { c.die(bluepillArchContext(context), "entry failed") return default: - c.die(bluepillArchContext(context), "unknown") + bluepillArchHandleExit(c, context) return } } diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 0b06a923a..9db1db4e9 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -47,10 +47,11 @@ type userRegs struct { } type exception struct { - sErrPending uint8 - sErrHasEsr uint8 - pad [6]uint8 - sErrEsr uint64 + sErrPending uint8 + sErrHasEsr uint8 + extDabtPending uint8 + pad [5]uint8 + sErrEsr uint64 } type kvmVcpuEvents struct { diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index 6abaa21c4..2492d57be 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -56,6 +56,7 @@ const ( _KVM_EXIT_FAIL_ENTRY = 0x9 _KVM_EXIT_INTERNAL_ERROR = 0x11 _KVM_EXIT_SYSTEM_EVENT = 0x18 + _KVM_EXIT_ARM_NISV = 0x1c ) // KVM capability options. diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index fd92c3873..3f5be276b 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -263,13 +263,6 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) return usermem.NoAccess, platform.ErrContextInterrupt case ring0.El0SyncUndef: return c.fault(int32(syscall.SIGILL), info) - case ring0.El1SyncUndef: - *info = arch.SignalInfo{ - Signo: int32(syscall.SIGILL), - Code: 1, // ILL_ILLOPC (illegal opcode). - } - info.SetAddr(switchOpts.Registers.Pc) // Include address. - return usermem.AccessType{}, platform.ErrContextSignal default: panic(fmt.Sprintf("unexpected vector: 0x%x", vector)) } diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index f56aa3b79..571bfcc2e 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -18,8 +18,8 @@ // // In a nutshell, it works as follows: // -// The creation of a new address space creates a new child processes with a -// single thread which is traced by a single goroutine. +// The creation of a new address space creates a new child process with a single +// thread which is traced by a single goroutine. // // A context is just a collection of temporary variables. 
Calling Switch on a // context does the following: diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 812ab80ef..aacd7ce70 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -590,7 +590,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { // facilitate vsyscall emulation. See patchSignalInfo. patchSignalInfo(regs, &c.signalInfo) return false - } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) { + } else if c.signalInfo.Code <= 0 && c.signalInfo.PID() == int32(os.Getpid()) { // The signal was generated by this process. That means // that it was an interrupt or something else that we // should bail for. Note that we ignore signals diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 679b287c3..2852b7387 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "arch_genrule", "go_library") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -39,19 +39,19 @@ go_template_instance( template = ":defs_arm64", ) -genrule( +arch_genrule( name = "entry_impl_amd64", srcs = ["entry_amd64.s"], outs = ["entry_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) -genrule( +arch_genrule( name = "entry_impl_arm64", srcs = ["entry_arm64.s"], outs = ["entry_impl_arm64.s"], - cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) @@ -72,7 +72,6 @@ go_library( "lib_amd64.s", "lib_arm64.go", "lib_arm64.s", - "lib_arm64_unsafe.go", "ring0.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go index 327d48465..c51df2811 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -90,6 +90,7 @@ const ( El0SyncIa El0SyncFpsimdAcc El0SyncSveAcc + El0SyncFpsimdExc El0SyncSys El0SyncSpPc El0SyncUndef diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index f489ad352..cf0bf3528 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -132,40 +132,6 @@ MOVD offset+PTRACE_R29(reg), R29; \ MOVD offset+PTRACE_R30(reg), R30; -// NOP-s -#define nop31Instructions() \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD 
$0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; - #define ESR_ELx_EC_UNKNOWN (0x00) #define ESR_ELx_EC_WFx (0x01) /* Unallocated EC: 0x02 */ @@ -324,6 +290,18 @@ MOVD CPU_TTBR0_KVM(from), RSV_REG; \ MSR RSV_REG, TTBR0_EL1; +TEXT ·EnableVFP(SB),NOSPLIT,$0 + MOVD $FPEN_ENABLE, R0 + WORD $0xd5181040 //MSR R0, CPACR_EL1 + ISB $15 + RET + +TEXT ·DisableVFP(SB),NOSPLIT,$0 + MOVD $0, R0 + WORD $0xd5181040 //MSR R0, CPACR_EL1 + ISB $15 + RET + #define VFP_ENABLE \ MOVD $FPEN_ENABLE, R0; \ WORD $0xd5181040; \ //MSR R0, CPACR_EL1 @@ -370,12 +348,12 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. -// EXCEPTION_WITH_ERROR is a common exception handler function. -#define EXCEPTION_WITH_ERROR(user, vector) \ +// EXCEPTION_EL0 is a common el0 exception handler function. +#define EXCEPTION_EL0(vector) \ WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 WORD $0xd538601a; \ //MRS FAR_EL1, R26 MOVD R26, CPU_FAULT_ADDR(RSV_REG); \ - MOVD $user, R3; \ + MOVD $1, R3; \ MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user. MOVD $vector, R3; \ MOVD R3, CPU_VECTOR_CODE(RSV_REG); \ @@ -383,6 +361,12 @@ MOVD R3, CPU_ERROR_CODE(RSV_REG); \ B ·kernelExitToEl1(SB); +// EXCEPTION_EL1 is a common el1 exception handler function. +#define EXCEPTION_EL1(vector) \ + MOVD $vector, R3; \ + MOVD R3, 8(RSP); \ + B ·HaltEl1ExceptionAndResume(SB); + // storeAppASID writes the application's asid value. TEXT ·storeAppASID(SB),NOSPLIT,$0-8 MOVD asid+0(FP), R1 @@ -430,6 +414,16 @@ TEXT ·HaltEl1SvcAndResume(SB),NOSPLIT,$0 CALL ·kernelSyscall(SB) // Call the trampoline. B ·kernelExitToEl1(SB) // Resume. +// HaltEl1ExceptionAndResume calls Hooks.KernelException and resume. +TEXT ·HaltEl1ExceptionAndResume(SB),NOSPLIT,$0-8 + WORD $0xd538d092 // MRS TPIDR_EL1, R18 + MOVD CPU_SELF(RSV_REG), R3 // Load vCPU. + MOVD R3, 8(RSP) // First argument (vCPU). + MOVD vector+0(FP), R3 + MOVD R3, 16(RSP) // Second argument (vector). + CALL ·kernelException(SB) // Call the trampoline. + B ·kernelExitToEl1(SB) // Resume. + // Shutdown stops the guest. TEXT ·Shutdown(SB),NOSPLIT,$0 // PSCI EVENT. @@ -592,39 +586,22 @@ TEXT ·El1_sync(SB),NOSPLIT,$0 B el1_invalid el1_da: + EXCEPTION_EL1(El1SyncDa) el1_ia: - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - WORD $0xd538601a //MRS FAR_EL1, R26 - - MOVD R26, CPU_FAULT_ADDR(RSV_REG) - - MOVD $0, CPU_ERROR_TYPE(RSV_REG) - - MOVD $PageFault, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - B ·HaltAndResume(SB) - + EXCEPTION_EL1(El1SyncIa) el1_sp_pc: - B ·Shutdown(SB) - + EXCEPTION_EL1(El1SyncSpPc) el1_undef: - B ·Shutdown(SB) - + EXCEPTION_EL1(El1SyncUndef) el1_svc: - MOVD $0, CPU_ERROR_CODE(RSV_REG) - MOVD $0, CPU_ERROR_TYPE(RSV_REG) B ·HaltEl1SvcAndResume(SB) - el1_dbg: - B ·Shutdown(SB) - + EXCEPTION_EL1(El1SyncDbg) el1_fpsimd_acc: VFP_ENABLE B ·kernelExitToEl1(SB) // Resume. - el1_invalid: - B ·Shutdown(SB) + EXCEPTION_EL1(El1SyncInv) // El1_irq is the handler for El1_irq. 
TEXT ·El1_irq(SB),NOSPLIT,$0 @@ -680,28 +657,21 @@ el0_svc: el0_da: el0_ia: - EXCEPTION_WITH_ERROR(1, PageFault) - + EXCEPTION_EL0(PageFault) el0_fpsimd_acc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncFpsimdAcc) el0_sve_acc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncSveAcc) el0_fpsimd_exc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncFpsimdExc) el0_sp_pc: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncSpPc) el0_undef: - EXCEPTION_WITH_ERROR(1, El0SyncUndef) - + EXCEPTION_EL0(El0SyncUndef) el0_dbg: - B ·Shutdown(SB) - + EXCEPTION_EL0(El0SyncDbg) el0_invalid: - B ·Shutdown(SB) + EXCEPTION_EL0(El0SyncInv) TEXT ·El0_irq(SB),NOSPLIT,$0 B ·Shutdown(SB) @@ -760,79 +730,43 @@ TEXT ·El0_error_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) // Vectors implements exception vector table. +// The start address of exception vector table should be 11-bits aligned. +// For detail, please refer to arm developer document: +// https://developer.arm.com/documentation/100933/0100/AArch64-exception-vector-table +// Also can refer to the code in linux kernel: arch/arm64/kernel/entry.S TEXT ·Vectors(SB),NOSPLIT,$0 + PCALIGN $2048 B ·El1_sync_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_irq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_fiq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_error_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_sync(SB) - nop31Instructions() + PCALIGN $128 B ·El1_irq(SB) - nop31Instructions() + PCALIGN $128 B ·El1_fiq(SB) - nop31Instructions() + PCALIGN $128 B ·El1_error(SB) - nop31Instructions() + PCALIGN $128 B ·El0_sync(SB) - nop31Instructions() + PCALIGN $128 B ·El0_irq(SB) - nop31Instructions() + PCALIGN $128 B ·El0_fiq(SB) - nop31Instructions() + PCALIGN $128 B ·El0_error(SB) - nop31Instructions() + PCALIGN $128 B ·El0_sync_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_irq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_fiq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_error_invalid(SB) - nop31Instructions() - - // The exception-vector-table is required to be 11-bits aligned. - // Please see Linux source code as reference: arch/arm64/kernel/entry.s. - // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs. - // So that, I can safely move the 1st 2K part into the address with 11-bits alignment. - WORD $0xd503201f //nop - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 9742308d8..a9703baf6 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -24,6 +24,9 @@ go_binary( "defs_impl_arm64.go", "main.go", ], + # Use the libc malloc to avoid any extra dependencies. This is required to + # pass the sentry deps test. 
+ system_malloc = True, visibility = [ "//pkg/sentry/platform/kvm:__pkg__", "//pkg/sentry/platform/ring0:__pkg__", diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index c1c808b96..c05284641 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -24,6 +24,10 @@ func HaltAndResume() //go:nosplit func HaltEl1SvcAndResume() +// HaltEl1ExceptionAndResume calls Hooks.KernelException and resume. +//go:nosplit +func HaltEl1ExceptionAndResume() + // init initializes architecture-specific state. func (k *Kernel) init(maxCPUs int) { } @@ -49,6 +53,12 @@ func IsCanonical(addr uint64) bool { return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000 } +// SwitchToUser performs an eret. +// +// The return value is the exception vector. +// +// +checkescape:all +// //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { storeAppASID(uintptr(switchOpts.UserASID)) @@ -61,11 +71,13 @@ func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { regs.Pstate &= ^uint64(PsrFlagsClear) regs.Pstate |= UserFlagsSet + EnableVFP() LoadFloatingPoint(switchOpts.FloatingPointState) kernelExitToEl0() SaveFloatingPoint(switchOpts.FloatingPointState) + DisableVFP() vector = c.vecCode diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index bf1c655f4..a490bf3af 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -34,13 +34,13 @@ func FlushTlbAll() // CPACREL1 returns the value of the CPACR_EL1 register. func CPACREL1() (value uintptr) -// FPCR returns the value of FPCR register. +// GetFPCR returns the value of FPCR register. func GetFPCR() (value uintptr) // SetFPCR writes the FPCR value. func SetFPCR(value uintptr) -// FPSR returns the value of FPSR register. +// GetFPSR returns the value of FPSR register. func GetFPSR() (value uintptr) // SetFPSR writes the FPSR value. @@ -59,9 +59,13 @@ func LoadFloatingPoint(*byte) // SaveFloatingPoint saves floating point state. func SaveFloatingPoint(*byte) +// EnableVFP enables fpsimd. +func EnableVFP() + +// DisableVFP disables fpsimd. +func DisableVFP() + // Init sets function pointers based on architectural features. // // This must be called prior to using ring0. -func Init() { - rewriteVectors() -} +func Init() {} diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 675a8bdb7..e39b32841 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -52,62 +52,47 @@ TEXT ·CPACREL1(SB),NOSPLIT,$0-8 RET TEXT ·GetFPCR(SB),NOSPLIT,$0-8 - WORD $0xd53b4201 // MRS NZCV, R1 + MOVD FPCR, R1 MOVD R1, ret+0(FP) RET TEXT ·GetFPSR(SB),NOSPLIT,$0-8 - WORD $0xd53b4421 // MRS FPSR, R1 + MOVD FPSR, R1 MOVD R1, ret+0(FP) RET TEXT ·SetFPCR(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R1 - WORD $0xd51b4201 // MSR R1, NZCV + MOVD R1, FPCR RET TEXT ·SetFPSR(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R1 - WORD $0xd51b4421 // MSR R1, FPSR + MOVD R1, FPSR RET TEXT ·SaveVRegs(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R0 // Skip aarch64_ctx, fpsr, fpcr. 
- FMOVD F0, 16*1(R0) - FMOVD F1, 16*2(R0) - FMOVD F2, 16*3(R0) - FMOVD F3, 16*4(R0) - FMOVD F4, 16*5(R0) - FMOVD F5, 16*6(R0) - FMOVD F6, 16*7(R0) - FMOVD F7, 16*8(R0) - FMOVD F8, 16*9(R0) - FMOVD F9, 16*10(R0) - FMOVD F10, 16*11(R0) - FMOVD F11, 16*12(R0) - FMOVD F12, 16*13(R0) - FMOVD F13, 16*14(R0) - FMOVD F14, 16*15(R0) - FMOVD F15, 16*16(R0) - FMOVD F16, 16*17(R0) - FMOVD F17, 16*18(R0) - FMOVD F18, 16*19(R0) - FMOVD F19, 16*20(R0) - FMOVD F20, 16*21(R0) - FMOVD F21, 16*22(R0) - FMOVD F22, 16*23(R0) - FMOVD F23, 16*24(R0) - FMOVD F24, 16*25(R0) - FMOVD F25, 16*26(R0) - FMOVD F26, 16*27(R0) - FMOVD F27, 16*28(R0) - FMOVD F28, 16*29(R0) - FMOVD F29, 16*30(R0) - FMOVD F30, 16*31(R0) - FMOVD F31, 16*32(R0) - ISB $15 + ADD $16, R0, R0 + + WORD $0xad000400 // stp q0, q1, [x0] + WORD $0xad010c02 // stp q2, q3, [x0, #32] + WORD $0xad021404 // stp q4, q5, [x0, #64] + WORD $0xad031c06 // stp q6, q7, [x0, #96] + WORD $0xad042408 // stp q8, q9, [x0, #128] + WORD $0xad052c0a // stp q10, q11, [x0, #160] + WORD $0xad06340c // stp q12, q13, [x0, #192] + WORD $0xad073c0e // stp q14, q15, [x0, #224] + WORD $0xad084410 // stp q16, q17, [x0, #256] + WORD $0xad094c12 // stp q18, q19, [x0, #288] + WORD $0xad0a5414 // stp q20, q21, [x0, #320] + WORD $0xad0b5c16 // stp q22, q23, [x0, #352] + WORD $0xad0c6418 // stp q24, q25, [x0, #384] + WORD $0xad0d6c1a // stp q26, q27, [x0, #416] + WORD $0xad0e741c // stp q28, q29, [x0, #448] + WORD $0xad0f7c1e // stp q30, q31, [x0, #480] RET @@ -115,39 +100,24 @@ TEXT ·LoadVRegs(SB),NOSPLIT,$0-8 MOVD addr+0(FP), R0 // Skip aarch64_ctx, fpsr, fpcr. - FMOVD 16*1(R0), F0 - FMOVD 16*2(R0), F1 - FMOVD 16*3(R0), F2 - FMOVD 16*4(R0), F3 - FMOVD 16*5(R0), F4 - FMOVD 16*6(R0), F5 - FMOVD 16*7(R0), F6 - FMOVD 16*8(R0), F7 - FMOVD 16*9(R0), F8 - FMOVD 16*10(R0), F9 - FMOVD 16*11(R0), F10 - FMOVD 16*12(R0), F11 - FMOVD 16*13(R0), F12 - FMOVD 16*14(R0), F13 - FMOVD 16*15(R0), F14 - FMOVD 16*16(R0), F15 - FMOVD 16*17(R0), F16 - FMOVD 16*18(R0), F17 - FMOVD 16*19(R0), F18 - FMOVD 16*20(R0), F19 - FMOVD 16*21(R0), F20 - FMOVD 16*22(R0), F21 - FMOVD 16*23(R0), F22 - FMOVD 16*24(R0), F23 - FMOVD 16*25(R0), F24 - FMOVD 16*26(R0), F25 - FMOVD 16*27(R0), F26 - FMOVD 16*28(R0), F27 - FMOVD 16*29(R0), F28 - FMOVD 16*30(R0), F29 - FMOVD 16*31(R0), F30 - FMOVD 16*32(R0), F31 - ISB $15 + ADD $16, R0, R0 + + WORD $0xad400400 // ldp q0, q1, [x0] + WORD $0xad410c02 // ldp q2, q3, [x0, #32] + WORD $0xad421404 // ldp q4, q5, [x0, #64] + WORD $0xad431c06 // ldp q6, q7, [x0, #96] + WORD $0xad442408 // ldp q8, q9, [x0, #128] + WORD $0xad452c0a // ldp q10, q11, [x0, #160] + WORD $0xad46340c // ldp q12, q13, [x0, #192] + WORD $0xad473c0e // ldp q14, q15, [x0, #224] + WORD $0xad484410 // ldp q16, q17, [x0, #256] + WORD $0xad494c12 // ldp q18, q19, [x0, #288] + WORD $0xad4a5414 // ldp q20, q21, [x0, #320] + WORD $0xad4b5c16 // ldp q22, q23, [x0, #352] + WORD $0xad4c6418 // ldp q24, q25, [x0, #384] + WORD $0xad4d6c1a // ldp q26, q27, [x0, #416] + WORD $0xad4e741c // ldp q28, q29, [x0, #448] + WORD $0xad4f7c1e // ldp q30, q31, [x0, #480] RET diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go deleted file mode 100644 index c05166fea..000000000 --- a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -package ring0 - -import ( - "reflect" - "syscall" - "unsafe" - - "gvisor.dev/gvisor/pkg/safecopy" - "gvisor.dev/gvisor/pkg/usermem" -) - -const ( - nopInstruction = 0xd503201f - instSize = unsafe.Sizeof(uint32(0)) - vectorsRawLen = 0x800 -) - -func unsafeSlice(addr uintptr, length int) (slice []uint32) { - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) - hdr.Data = addr - hdr.Len = length / int(instSize) - hdr.Cap = length / int(instSize) - return slice -} - -// Work around: move ring0.Vectors() into a specific address with 11-bits alignment. -// -// According to the design documentation of Arm64, -// the start address of exception vector table should be 11-bits aligned. -// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S -// But, we can't align a function's start address to a specific address by using golang. -// We have raised this question in golang community: -// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I -// This function will be removed when golang supports this feature. -// -// There are 2 jobs were implemented in this function: -// 1, move the start address of exception vector table into the specific address. -// 2, modify the offset of each instruction. -func rewriteVectors() { - vectorsBegin := reflect.ValueOf(Vectors).Pointer() - - // The exception-vector-table is required to be 11-bits aligned. - // And the size is 0x800. - // Please see the documentation as reference: - // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table - // - // But, golang does not allow to set a function's address to a specific value. - // So, for gvisor, I defined the size of exception-vector-table as 4K, - // filled the 2nd 2K part with NOP-s. - // So that, I can safely move the 1st 2K part into the address with 11-bits alignment. - // - // So, the prerequisite for this function to work correctly is: - // vectorsSafeLen >= 0x1000 - // vectorsRawLen = 0x800 - vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin) - if vectorsSafeLen < 2*vectorsRawLen { - panic("Can't update vectors") - } - - vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32 - vectorsRawLen32 := vectorsRawLen / int(instSize) - - offset := vectorsBegin & (1<<11 - 1) - if offset != 0 { - offset = 1<<11 - offset - } - - pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1) - - _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)) - if errno != 0 { - panic(errno.Error()) - } - - offset = offset / instSize // By index, not bytes. - // Move exception-vector-table into the specific address, should uses memmove here. - for i := 1; i <= vectorsRawLen32; i++ { - vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i] - } - - // Adjust branch since instruction was moved forward. 
- for i := 0; i < vectorsRawLen32; i++ { - if vectorsSafeTable[int(offset)+i] != nopInstruction { - vectorsSafeTable[int(offset)+i] -= uint32(offset) - } - } - - _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC)) - if errno != 0 { - panic(errno.Error()) - } -} diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index 53bc3353c..b5652deb9 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -70,6 +70,7 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa) fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc) fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc) + fmt.Fprintf(w, "#define El0SyncFpsimdExc 0x%02x\n", El0SyncFpsimdExc) fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys) fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc) fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef) diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index bc16a1622..7605d0cb2 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -58,6 +58,15 @@ type PageTables struct { readOnlyShared bool } +// Init initializes a set of PageTables. +// +//go:nosplit +func (p *PageTables) Init(allocator Allocator) { + p.Allocator = allocator + p.root = p.Allocator.NewPTEs() + p.rootPhysical = p.Allocator.PhysicalFor(p.root) +} + // NewWithUpper returns new PageTables. // // upperSharedPageTables are used for mapping the upper of addresses, @@ -73,14 +82,17 @@ type PageTables struct { func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uintptr) *PageTables { p := new(PageTables) p.Init(a) + if upperSharedPageTables != nil { if !upperSharedPageTables.readOnlyShared { panic("Only read-only shared pagetables can be used as upper") } p.upperSharedPageTables = upperSharedPageTables p.upperStart = upperStart - p.cloneUpperShared() } + + p.InitArch(a) + return p } diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go index a4e416af7..520161755 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -24,6 +24,14 @@ import ( // archPageTables is architecture-specific data. type archPageTables struct { + // root is the pagetable root for kernel space. + root *PTEs + + // rootPhysical is the cached physical address of the root. + // + // This is saved only to prevent constant translation. + rootPhysical uintptr + asid uint16 } @@ -38,7 +46,7 @@ func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 { // //go:nosplit func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 { - return uint64(p.upperSharedPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset + return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset } // Bits in page table entries. 
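// Editor's note: the following is an illustrative, self-contained sketch (not
// part of the patch above) of the TTBR packing performed by TTBR0_EL1/TTBR1_EL1:
// the page-table root's physical address is OR-ed with the ASID shifted into the
// upper bits of the register value. The mask and offset constants below are
// assumptions made only so this example stands alone; the real constants are
// defined elsewhere in this package.
package main

import "fmt"

const (
	ttbrASIDOffset = 48     // assumed: the ASID occupies TTBRn_EL1[63:48]
	ttbrASIDMask   = 0xffff // assumed: 16-bit ASIDs
)

// makeTTBR mirrors the return expression of TTBR1_EL1 in the hunk above.
func makeTTBR(rootPhysical uintptr, asid uint16) uint64 {
	return uint64(rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset
}

func main() {
	fmt.Printf("TTBR value: %#x\n", makeTTBR(0x40000000, 7))
}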
diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go index e7ab887e5..4bdde8448 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go @@ -41,13 +41,13 @@ const ( entriesPerPage = 512 ) -// Init initializes a set of PageTables. +// InitArch does some additional initialization related to the architecture. // //go:nosplit -func (p *PageTables) Init(allocator Allocator) { - p.Allocator = allocator - p.root = p.Allocator.NewPTEs() - p.rootPhysical = p.Allocator.PhysicalFor(p.root) +func (p *PageTables) InitArch(allocator Allocator) { + if p.upperSharedPageTables != nil { + p.cloneUpperShared() + } } func pgdIndex(upperStart uintptr) uintptr { diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go index 5392bf27a..ad0e30c88 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go @@ -42,13 +42,16 @@ const ( entriesPerPage = 512 ) -// Init initializes a set of PageTables. +// InitArch does some additional initialization related to the architecture. // //go:nosplit -func (p *PageTables) Init(allocator Allocator) { - p.Allocator = allocator - p.root = p.Allocator.NewPTEs() - p.rootPhysical = p.Allocator.PhysicalFor(p.root) +func (p *PageTables) InitArch(allocator Allocator) { + if p.upperSharedPageTables != nil { + p.cloneUpperShared() + } else { + p.archPageTables.root = p.Allocator.NewPTEs() + p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root) + } } // cloneUpperShared clone the upper from the upper shared page tables. @@ -59,7 +62,8 @@ func (p *PageTables) cloneUpperShared() { panic("upperStart should be the same as upperBottom") } - // nothing to do for arm. + p.archPageTables.root = p.upperSharedPageTables.archPageTables.root + p.archPageTables.rootPhysical = p.upperSharedPageTables.archPageTables.rootPhysical } // PTEs is a collection of entries. 
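// Editor's note: a simplified, self-contained sketch (not part of the patch) of
// the root-selection rule that the arm64 walker change below relies on: addresses
// at or above the upper (kernel) half use the kernel root now kept in
// archPageTables, while lower-half addresses keep using the per-address-space
// root. The upperBottom value and the string stand-ins are assumptions made
// purely for illustration.
package main

import "fmt"

const upperBottom uintptr = 0xffff000000000000 // assumed start of the upper half

type pageTablesSketch struct {
	root           string // stand-in for PageTables.root (user half)
	archKernelRoot string // stand-in for archPageTables.root (kernel half, arm64)
}

// rootFor mirrors the branch at the top of iterateRangeCanonical below.
func (p *pageTablesSketch) rootFor(start uintptr) string {
	if start >= upperBottom {
		return p.archKernelRoot
	}
	return p.root
}

func main() {
	p := &pageTablesSketch{root: "user root", archKernelRoot: "kernel root"}
	fmt.Println(p.rootFor(0x0000000000400000)) // user root
	fmt.Println(p.rootFor(0xffff000000400000)) // kernel root
}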
diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go index 157c9a7cc..c261d393a 100644 --- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go @@ -116,7 +116,7 @@ func next(start uintptr, size uintptr) uintptr { func (w *Walker) iterateRangeCanonical(start, end uintptr) { pgdEntryIndex := w.pageTables.root if start >= upperBottom { - pgdEntryIndex = w.pageTables.upperSharedPageTables.root + pgdEntryIndex = w.pageTables.archPageTables.root } for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index a3f775d15..cc1f6bfcc 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/tcpip", + "//pkg/tcpip/header", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index ca16d0381..fb7c5dc61 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -23,7 +23,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/syserror", - "//pkg/tcpip", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 70ccf77a7..b88cdca48 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" ) @@ -344,21 +343,34 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { } // PackIPPacketInfo packs an IP_PKTINFO socket control message. -func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte { - var p linux.ControlMessageIPPacketInfo - p.NIC = int32(packetInfo.NIC) - copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) - copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) - +func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, linux.IP_PKTINFO, t.Arch().Width(), - p, + packetInfo, ) } +// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. +func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte { + var level uint32 + var optType uint32 + switch originalDstAddress.(type) { + case *linux.SockAddrInet: + level = linux.SOL_IP + optType = linux.IP_RECVORIGDSTADDR + case *linux.SockAddrInet6: + level = linux.SOL_IPV6 + optType = linux.IPV6_RECVORIGDSTADDR + default: + panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg") + } + return putCmsgStruct( + buf, level, optType, t.Arch().Width(), originalDstAddress) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. 
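// Editor's note: a small self-contained sketch (not part of the patch) of the
// (level, type) selection that PackOriginalDstAddress above performs: IPv4
// original-destination addresses are packed under (SOL_IP, IP_RECVORIGDSTADDR)
// and IPv6 ones under (SOL_IPV6, IPV6_RECVORIGDSTADDR). The numeric constants
// are the standard Linux values, repeated here only so the example stands alone.
package main

import "fmt"

const (
	solIP               = 0  // SOL_IP
	solIPv6             = 41 // SOL_IPV6
	ipRecvOrigDstAddr   = 20 // IP_RECVORIGDSTADDR
	ipv6RecvOrigDstAddr = 74 // IPV6_RECVORIGDSTADDR
)

// origDstCmsgLevelType picks the control-message level and type based on the
// address family of the original destination address.
func origDstCmsgLevelType(isIPv6 bool) (level, typ uint32) {
	if isIPv6 {
		return solIPv6, ipv6RecvOrigDstAddr
	}
	return solIP, ipRecvOrigDstAddr
}

func main() {
	fmt.Println(origDstCmsgLevelType(false)) // 0 20
	fmt.Println(origDstCmsgLevelType(true))  // 41 74
}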
@@ -384,7 +396,11 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt } if cmsgs.IP.HasIPPacketInfo { - buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) + buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf) + } + + if cmsgs.IP.OriginalDstAddress != nil { + buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) } return buf @@ -416,17 +432,15 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, linux.SizeOfControlMessageTClass) } - return space -} + if cmsgs.IP.HasIPPacketInfo { + space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo) + } -// NewIPPacketInfo returns the IPPacketInfo struct. -func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo { - var p tcpip.IPPacketInfo - p.NIC = tcpip.NICID(packetInfo.NIC) - copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:]) - copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:]) + if cmsgs.IP.OriginalDstAddress != nil { + space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) + } - return p + return space } // Parse parses a raw socket control message into portable objects. @@ -489,6 +503,14 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con cmsgs.Unix.Credentials = scmCreds i += binary.AlignUp(length, width) + case linux.SO_TIMESTAMP: + if length < linux.SizeOfTimeval { + return socket.ControlMessages{}, syserror.EINVAL + } + cmsgs.IP.HasTimestamp = true + binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &cmsgs.IP.Timestamp) + i += binary.AlignUp(length, width) + default: // Unknown message type. return socket.ControlMessages{}, syserror.EINVAL @@ -512,7 +534,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo) + cmsgs.IP.PacketInfo = packetInfo + i += binary.AlignUp(length, width) + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr i += binary.AlignUp(length, width) default: @@ -528,6 +559,15 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) i += binary.AlignUp(length, width) + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 7d3c4a01c..1f220c343 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_IPV6: switch name { - case 
linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 case linux.SO_LINGER: optlen = syscall.SizeofLinger @@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR: optlen = sizeofInt32 case linux.IP_PKTINFO: optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 } case linux.SOL_TCP: switch name { - case linux.TCP_NODELAY: + case linux.TCP_NODELAY, linux.TCP_INQ: optlen = sizeofInt32 } } @@ -416,6 +416,37 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] return nil } +func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) { + // We always do a non-blocking recv*(). + sysflags := flags | syscall.MSG_DONTWAIT + + msg := syscall.Msghdr{} + if len(iovs) > 0 { + msg.Iov = &iovs[0] + msg.Iovlen = uint64(len(iovs)) + } + var senderAddrBuf []byte + if senderRequested { + senderAddrBuf = make([]byte, sizeofSockaddr) + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(sizeofSockaddr) + } + var controlBuf []byte + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } + n, err := recvmsg(s.fd, &msg, sysflags) + if err != nil { + return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err + } + return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err +} + // RecvMsg implements socket.Socket.RecvMsg. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Only allow known and safe flags. @@ -427,56 +458,36 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } - var senderAddr linux.SockAddr var senderAddrBuf []byte - if senderRequested { - senderAddrBuf = make([]byte, sizeofSockaddr) - } - var controlBuf []byte var msgFlags int - - recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { - // Refuse to do anything if any part of dst.Addrs was unusable. - if uint64(dst.NumBytes()) != dsts.NumBytes() { - return 0, nil - } - if dsts.IsEmpty() { - return 0, nil - } - - // We always do a non-blocking recv*(). 
- sysflags := flags | syscall.MSG_DONTWAIT - - iovs := safemem.IovecsFromBlockSeq(dsts) - msg := syscall.Msghdr{ - Iov: &iovs[0], - Iovlen: uint64(len(iovs)), - } - if len(senderAddrBuf) != 0 { - msg.Name = &senderAddrBuf[0] - msg.Namelen = uint32(len(senderAddrBuf)) - } - if controlLen > 0 { - if controlLen > maxControlLen { - controlLen = maxControlLen + copyToDst := func() (int64, error) { + var n uint64 + var err error + if dst.NumBytes() == 0 { + // We want to make the recvmsg(2) call to the host even if dst is empty + // to fetch control messages, sender address or errors if any occur. + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen) + return int64(n), err + } + + recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { + // Refuse to do anything if any part of dst.Addrs was unusable. + if uint64(dst.NumBytes()) != dsts.NumBytes() { + return 0, nil + } + if dsts.IsEmpty() { + return 0, nil } - controlBuf = make([]byte, controlLen) - msg.Control = &controlBuf[0] - msg.Controllen = controlLen - } - n, err := recvmsg(s.fd, &msg, sysflags) - if err != nil { - return 0, err - } - senderAddrBuf = senderAddrBuf[:msg.Namelen] - msgFlags = int(msg.Flags) - controlLen = uint64(msg.Controllen) - return n, nil - }) + + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen) + return n, err + }) + return dst.CopyOutFrom(t, recvmsgToBlocks) + } var ch chan struct{} - n, err := dst.CopyOutFrom(t, recvmsgToBlocks) + n, err := copyToDst() if flags&syscall.MSG_DONTWAIT == 0 { for err == syserror.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which @@ -494,48 +505,75 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags s.EventRegister(&e, waiter.EventIn) defer s.EventUnregister(&e) } - n, err = dst.CopyOutFrom(t, recvmsgToBlocks) + n, err = copyToDst() } } if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + var senderAddr linux.SockAddr if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf) if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil +} +func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages { controlMessages := socket.ControlMessages{} for _, unixCmsg := range unixControlMessages { switch unixCmsg.Header.Level { - case syscall.SOL_IP: + case linux.SOL_SOCKET: + switch unixCmsg.Header.Type { + case linux.SO_TIMESTAMP: + controlMessages.IP.HasTimestamp = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp) + } + + case linux.SOL_IP: switch unixCmsg.Header.Type { - case syscall.IP_TOS: + case linux.IP_TOS: controlMessages.IP.HasTOS = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) - case syscall.IP_PKTINFO: + case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - 
controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) + controlMessages.IP.PacketInfo = packetInfo + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr } - case syscall.SOL_IPV6: + case linux.SOL_IPV6: switch unixCmsg.Header.Type { - case syscall.IPV6_TCLASS: + case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr + } + + case linux.SOL_TCP: + switch unixCmsg.Header.Type { + case linux.TCP_INQ: + controlMessages.IP.HasInq = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq) } } } - - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil + return controlMessages } // SendMsg implements socket.Socket.SendMsg. diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 3baad098b..057f4d294 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -120,9 +120,6 @@ type socketOpsCommon struct { // fixed buffer but only consume this many bytes. sendBufferSize uint32 - // passcred indicates if this socket wants SCM credentials. - passcred bool - // filter indicates that this socket has a BPF filter "installed". // // TODO(gvisor.dev/issue/1119): We don't actually support filtering, @@ -201,10 +198,7 @@ func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { // Passcred implements transport.Credentialer.Passcred. func (s *socketOpsCommon) Passcred() bool { - s.mu.Lock() - passcred := s.passcred - s.mu.Unlock() - return passcred + return s.ep.SocketOptions().GetPassCred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. 
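Aside: the netlink change above is one instance of a pattern that runs through this diff: per-socket boolean fields guarded by a mutex, and the SetSockOptBool/GetSockOptBool endpoint methods, are replaced by getters and setters on a shared SocketOptions value. The toy stand-in below illustrates the shape of that pattern with an atomically stored flag; it is an illustration under that assumption, not the actual tcpip.SocketOptions implementation:

package main

import (
	"fmt"
	"sync/atomic"
)

// boolOptions is a toy stand-in for a SocketOptions-style container: each
// boolean option is an atomic flag, so readers and writers need no explicit
// locking of the socket.
type boolOptions struct {
	passCred uint32
}

// SetPassCred records whether SCM credentials are requested.
func (o *boolOptions) SetPassCred(v bool) {
	var flag uint32
	if v {
		flag = 1
	}
	atomic.StoreUint32(&o.passCred, flag)
}

// GetPassCred reports whether SCM credentials are requested.
func (o *boolOptions) GetPassCred() bool {
	return atomic.LoadUint32(&o.passCred) != 0
}

func main() {
	var opts boolOptions
	opts.SetPassCred(true) // e.g. handling SO_PASSCRED with a non-zero optval
	fmt.Println(opts.GetPassCred())
}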
@@ -419,9 +413,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] } passcred := usermem.ByteOrder.Uint32(opt) - s.mu.Lock() - s.passcred = passcred != 0 - s.mu.Unlock() + s.ep.SocketOptions().SetPassCred(passcred != 0) return nil case linux.SO_ATTACH_FILTER: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 7d0ae15ca..23d5cab9c 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -84,69 +84,95 @@ var Metrics = tcpip.Stats{ MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), ICMP: tcpip.ICMPStats{ - V4PacketsSent: tcpip.ICMPv4SentPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + V4: tcpip.ICMPv4Stats{ + PacketsSent: tcpip.ICMPv4SentPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), + ParamProblem: 
mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), + }, + PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), }, - V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: 
mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + V6: tcpip.ICMPv6Stats{ + PacketsSent: tcpip.ICMPv6SentPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), + }, + PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), + 
DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), }, - Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - V6PacketsSent: tcpip.ICMPv6SentPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + IGMP: tcpip.IGMPStats{ + 
PacketsSent: tcpip.IGMPSentPacketStats{ + IGMPPacketStats: tcpip.IGMPPacketStats{ + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."), }, - V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + PacketsReceived: tcpip.IGMPReceivedPacketStats{ + IGMPPacketStats: tcpip.IGMPPacketStats{ + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group 
messages received by netstack."), }, - Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), + Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."), + ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."), + Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."), }, }, IP: tcpip.IPStats{ @@ -209,18 +235,6 @@ const sizeOfInt32 int = 4 var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) -// ntohs converts a 16-bit number from network byte order to host byte order. It -// assumes that the host is little endian. -func ntohs(v uint16) uint16 { - return v<<8 | v>>8 -} - -// htons converts a 16-bit number from host byte order to network byte order. It -// assumes that the host is little endian. -func htons(v uint16) uint16 { - return ntohs(v) -} - // commonEndpoint represents the intersection of a tcpip.Endpoint and a // transport.Endpoint. type commonEndpoint interface { @@ -240,10 +254,6 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and - // transport.Endpoint.SetSockOptBool. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -252,18 +262,20 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and - // transport.Endpoint.GetSockOpt. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) - // LastError implements tcpip.Endpoint.LastError. + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 + + // LastError implements tcpip.Endpoint.LastError and + // transport.Endpoint.LastError. LastError() *tcpip.Error - // SocketOptions implements tcpip.Endpoint.SocketOptions. + // SocketOptions implements tcpip.Endpoint.SocketOptions and + // transport.Endpoint.SocketOptions. SocketOptions() *tcpip.SocketOptions } @@ -308,7 +320,7 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages + readCM socket.IPControlMessages sender tcpip.FullAddress linkPacketInfo tcpip.LinkPacketInfo @@ -332,9 +344,7 @@ type socketOpsCommon struct { // New creates a new endpoint socket. 
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } dirent := socket.NewDirent(t, netstackDevice) @@ -363,88 +373,6 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } -// AddressAndFamily reads an sockaddr struct from the given address and -// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, -// AF_INET6, and AF_PACKET addresses. -// -// AddressAndFamily returns an address and its family. -func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { - // Make sure we have at least 2 bytes for the address family. - if len(addr) < 2 { - return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument - } - - // Get the rest of the fields based on the address family. - switch family := usermem.ByteOrder.Uint16(addr); family { - case linux.AF_UNIX: - path := addr[2:] - if len(path) > linux.UnixPathMax { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - // Drop the terminating NUL (if one exists) and everything after - // it for filesystem (non-abstract) addresses. - if len(path) > 0 && path[0] != 0 { - if n := bytes.IndexByte(path[1:], 0); n >= 0 { - path = path[:n+1] - } - } - return tcpip.FullAddress{ - Addr: tcpip.Address(path), - }, family, nil - - case linux.AF_INET: - var a linux.SockAddrInet - if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - return out, family, nil - - case linux.AF_INET6: - var a linux.SockAddrInet6 - if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - if isLinkLocal(out.Addr) { - out.NIC = tcpip.NICID(a.Scope_id) - } - return out, family, nil - - case linux.AF_PACKET: - var a linux.SockAddrLink - if len(addr) < sockAddrLinkSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) - if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/173): Return protocol too. 
- return tcpip.FullAddress{ - NIC: tcpip.NICID(a.InterfaceIndex), - Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), - }, family, nil - - case linux.AF_UNSPEC: - return tcpip.FullAddress{}, family, nil - - default: - return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported - } -} - func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } @@ -480,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = v - s.readCM = cms + s.readCM = socket.NewIPControlMessages(s.family, cms) atomic.StoreUint32(&s.readViewHasData, 1) return nil @@ -500,11 +428,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) { return } - var v tcpip.LingerOption - if err := s.Endpoint.GetSockOpt(&v); err != nil { - return - } - + v := s.Endpoint.SocketOptions().GetLinger() // The case for zero timeout is handled in tcp endpoint close function. // Close is blocked until either: // 1. The endpoint state is not in any of the states: FIN-WAIT1, @@ -721,11 +645,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { return nil } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { - v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return syserr.TranslateNetstackError(err) - } - if !v { + if !s.Endpoint.SocketOptions().GetV6Only() { return nil } } @@ -749,7 +669,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -830,7 +750,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } } else { var err *syserr.Error - addr, family, err = AddressAndFamily(sockaddr) + addr, family, err = socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -921,7 +841,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -1005,7 +925,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return getSockOptSocket(t, s, ep, family, skType, name, outLen) case linux.SOL_TCP: - return getSockOptTCP(t, ep, name, outLen) + return getSockOptTCP(t, s, ep, name, outLen) case linux.SOL_IPV6: return getSockOptIPv6(t, s, ep, name, outPtr, outLen) @@ -1041,7 +961,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // Get the last error and convert it. 
- err := ep.LastError() + err := ep.SocketOptions().GetLastError() if err == nil { optP := primitive.Int32(0) return &optP, nil @@ -1068,13 +988,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.PasscredOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetPassCred())) + return &v, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1115,25 +1030,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReuseAddress())) + return &v, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReusePortOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReusePort())) + return &v, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1174,24 +1080,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetKeepAlive())) + return &v, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - var v tcpip.LingerOption var linger linux.Linger - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + v := ep.SocketOptions().GetLinger() if v.Enabled { linger.OnOff = 1 @@ -1222,34 +1120,26 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.OutOfBandInlineOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(v) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline())) + return &v, nil case linux.SO_NO_CHECK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum())) + return &v, nil case linux.SO_ACCEPTCONN: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.AcceptConnOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + // This option is only viable for TCP endpoints. 
+ var v bool + if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) { + v = tcp.EndpointState(ep.State()) == tcp.StateListen } vP := primitive.Int32(boolToInt32(v)) return &vP, nil @@ -1261,46 +1151,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.DelayOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(!v)) - return &vP, nil + v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption())) + return &v, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.CorkOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption())) + return &v, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.QuickAckOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck())) + return &v, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1474,19 +1354,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. 
func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + + family, skType, _ := s.Type() + if family != linux.AF_INET6 { + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only())) + return &v, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1518,13 +1403,16 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) + return &v, nil + + case linux.IPV6_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { @@ -1536,7 +1424,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet6), nil case linux.IP6T_SO_GET_INFO: @@ -1545,7 +1433,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1565,7 +1453,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1585,7 +1473,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1607,6 +1495,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name // getSockOptIP implements GetSockOpt when level is SOL_IP. 
func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1649,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) return &a.(*linux.SockAddrInet).Addr, nil @@ -1658,13 +1551,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop())) + return &v, nil case linux.IP_TOS: // Length handling for parity with Linux. @@ -1688,26 +1576,32 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) + return &v, nil + + case linux.IP_PKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo())) + return &v, nil - case linux.IP_PKTINFO: + case linux.IP_HDRINCL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) + return &v, nil + + case linux.IP_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { @@ -1719,7 +1613,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet), nil case linux.IPT_SO_GET_INFO: @@ -1826,7 +1720,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int return setSockOptSocket(t, s, ep, name, optVal) case linux.SOL_TCP: - return setSockOptTCP(t, ep, name, optVal) + return setSockOptTCP(t, s, ep, name, optVal) case linux.SOL_IPV6: return setSockOptIPv6(t, s, ep, name, optVal) @@ -1876,7 +1770,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) + ep.SocketOptions().SetReuseAddress(v != 0) + return nil case 
linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1884,7 +1779,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) + ep.SocketOptions().SetReusePort(v != 0) + return nil case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) @@ -1923,7 +1819,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) + ep.SocketOptions().SetPassCred(v != 0) + return nil case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -1931,7 +1828,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) + ep.SocketOptions().SetKeepAlive(v != 0) + return nil case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { @@ -1970,8 +1868,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - opt := tcpip.OutOfBandInlineOption(v) - return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) + ep.SocketOptions().SetOutOfBandInline(v != 0) + return nil case linux.SO_NO_CHECK: if len(optVal) < sizeOfInt32 { @@ -1979,7 +1877,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + ep.SocketOptions().SetNoChecksum(v != 0) + return nil case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { @@ -1993,10 +1892,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return syserr.TranslateNetstackError( - ep.SetSockOpt(&tcpip.LingerOption{ - Enabled: v.OnOff != 0, - Timeout: time.Second * time.Duration(v.Linger)})) + ep.SocketOptions().SetLinger(tcpip.LingerOption{ + Enabled: v.OnOff != 0, + Timeout: time.Second * time.Duration(v.Linger), + }) + return nil case linux.SO_DETACH_FILTER: // optval is ignored. @@ -2011,7 +1911,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } // setSockOptTCP implements SetSockOpt when level is SOL_TCP. 
-func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if len(optVal) < sizeOfInt32 { @@ -2019,7 +1924,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) + ep.SocketOptions().SetDelayOption(v == 0) + return nil case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -2027,7 +1933,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) + ep.SocketOptions().SetCorkOption(v != 0) + return nil case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -2035,7 +1942,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) + ep.SocketOptions().SetQuickAck(v != 0) + return nil case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -2147,14 +2055,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + + family, skType, skProto := s.Type() + if family != linux.AF_INET6 { + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } + if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { + return syserr.ErrInvalidEndpointState + } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + return syserr.ErrInvalidEndpointState + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) + ep.SocketOptions().SetV6Only(v != 0) + return nil case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, @@ -2174,6 +2099,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_RECVORIGDSTADDR: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2193,7 +2127,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + ep.SocketOptions().SetReceiveTClass(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2201,7 +2136,7 @@ 
func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return syserr.ErrProtocolNotAvailable } @@ -2276,6 +2211,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { // setSockOptIP implements SetSockOpt when level is SOL_IP. func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_MULTICAST_TTL: v, err := parseIntOrChar(optVal) @@ -2328,7 +2268,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{ NIC: tcpip.NICID(req.InterfaceIndex), - InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]), + InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]), })) case linux.IP_MULTICAST_LOOP: @@ -2337,7 +2277,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) + ep.SocketOptions().SetMulticastLoop(v != 0) + return nil case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2373,7 +2314,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + ep.SocketOptions().SetReceiveTOS(v != 0) + return nil case linux.IP_PKTINFO: if len(optVal) == 0 { @@ -2383,7 +2325,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + ep.SocketOptions().SetReceivePacketInfo(v != 0) + return nil case linux.IP_HDRINCL: if len(optVal) == 0 { @@ -2393,7 +2336,20 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + ep.SocketOptions().SetHeaderIncluded(v != 0) + return nil + + case linux.IP_RECVORIGDSTADDR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { @@ -2433,7 +2389,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, - linux.IP_RECVORIGDSTADDR, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2511,7 +2466,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, - linux.IPV6_RECVORIGDSTADDR, linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, @@ -2535,7 +2489,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { switch name { case linux.IP_TOS, linux.IP_TTL, - linux.IP_HDRINCL, linux.IP_OPTIONS, linux.IP_ROUTER_ALERT, linux.IP_RECVOPTS, @@ -2582,72 +2535,6 @@ func emitUnimplementedEventIP(t *kernel.Task, 
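Several SOL_IP cases above go through parseIntOrChar, whose signature appears only in a context line here: Linux accepts either a single byte or a 4-byte integer for these options. The following is a hedged re-implementation of that behavior for readers following along; the real helper returns a *syserr.Error and uses usermem.ByteOrder:

package main

import (
    "encoding/binary"
    "errors"
    "fmt"
)

// parseIntOrChar decodes an option value that may be passed either as a char
// or as an int, mirroring the helper referenced above (sketch only).
func parseIntOrChar(buf []byte) (int32, error) {
    switch {
    case len(buf) >= 4:
        return int32(binary.LittleEndian.Uint32(buf)), nil
    case len(buf) >= 1:
        return int32(buf[0]), nil
    default:
        return 0, errors.New("EINVAL: empty option value")
    }
}

func main() {
    asChar, _ := parseIntOrChar([]byte{1})
    asInt, _ := parseIntOrChar([]byte{1, 0, 0, 0})
    fmt.Println(asChar, asInt) // 1 1
}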
name int) { } } -// isLinkLocal determines if the given IPv6 address is link-local. This is the -// case when it has the fe80::/10 prefix. This check is used to determine when -// the NICID is relevant for a given IPv6 address. -func isLinkLocal(addr tcpip.Address) bool { - return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 -} - -// ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { - switch family { - case linux.AF_UNIX: - var out linux.SockAddrUnix - out.Family = linux.AF_UNIX - l := len([]byte(addr.Addr)) - for i := 0; i < l; i++ { - out.Path[i] = int8(addr.Addr[i]) - } - - // Linux returns the used length of the address struct (including the - // null terminator) for filesystem paths. The Family field is 2 bytes. - // It is sometimes allowed to exclude the null terminator if the - // address length is the max. Abstract and empty paths always return - // the full exact length. - if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return &out, uint32(2 + l) - } - return &out, uint32(3 + l) - - case linux.AF_INET: - var out linux.SockAddrInet - copy(out.Addr[:], addr.Addr) - out.Family = linux.AF_INET - out.Port = htons(addr.Port) - return &out, uint32(sockAddrInetSize) - - case linux.AF_INET6: - var out linux.SockAddrInet6 - if len(addr.Addr) == header.IPv4AddressSize { - // Copy address in v4-mapped format. - copy(out.Addr[12:], addr.Addr) - out.Addr[10] = 0xff - out.Addr[11] = 0xff - } else { - copy(out.Addr[:], addr.Addr) - } - out.Family = linux.AF_INET6 - out.Port = htons(addr.Port) - if isLinkLocal(addr.Addr) { - out.Scope_id = uint32(addr.NIC) - } - return &out, uint32(sockAddrInet6Size) - - case linux.AF_PACKET: - // TODO(gvisor.dev/issue/173): Return protocol too. - var out linux.SockAddrLink - out.Family = linux.AF_PACKET - out.InterfaceIndex = int32(addr.NIC) - out.HardwareAddrLen = header.EthernetAddressSize - copy(out.HardwareAddr[:], addr.Addr) - return &out, uint32(sockAddrLinkSize) - - default: - return nil, 0 - } -} - // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { @@ -2656,7 +2543,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2668,7 +2555,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2686,7 +2573,7 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ // Always do at least one fetchReadView, even if the number of bytes to // read is 0. err = s.fetchReadView() - if err != nil { + if err != nil || len(s.readView) == 0 { break } if dst.NumBytes() == 0 { @@ -2709,15 +2596,20 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ } copied += n s.readView.TrimFront(n) - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) - } dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) break } + // If we are done reading requested data then stop. 
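ConvertAddress and its isLinkLocal helper are deleted here because they move into pkg/sentry/socket, where later hunks add them back for the netstack, unix, and strace packages to share. A sketch of how a caller uses the relocated helper, based only on the signature visible in this diff:

package main

import (
    "fmt"

    "gvisor.dev/gvisor/pkg/abi/linux"
    "gvisor.dev/gvisor/pkg/sentry/socket"
    "gvisor.dev/gvisor/pkg/tcpip"
)

func main() {
    // 127.0.0.1:8080 as a netstack FullAddress...
    fa := tcpip.FullAddress{Addr: tcpip.Address("\x7f\x00\x00\x01"), Port: 8080}

    // ...converted to the linux.SockAddrInet form that getsockname(2) and
    // getpeername(2) copy back to userspace, plus the length to report.
    sa, length := socket.ConvertAddress(linux.AF_INET, fa)
    fmt.Printf("%T, reported length %d\n", sa, length)
}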
+ if dst.NumBytes() == 0 { + break + } + } + + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) } // If we managed to copy something, we must deliver it. @@ -2812,10 +2704,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { - addr, addrLen = ConvertAddress(s.family, s.sender) + addr, addrLen = socket.ConvertAddress(s.family, s.sender) switch v := addr.(type) { case *linux.SockAddrLink: - v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol)) v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) } } @@ -2833,7 +2725,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq // We need to peek beyond the first message. dst = dst.DropFirst(n) num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) { - n, _, err := s.Endpoint.Peek(dsts) + n, err := s.Endpoint.Peek(dsts) // TODO(b/78348848): Handle peek timestamp. if err != nil { return int64(n), syserr.TranslateNetstackError(err).ToError() @@ -2877,15 +2769,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq func (s *socketOpsCommon) controlMessages() socket.ControlMessages { return socket.ControlMessages{ - IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, - HasTClass: s.readCM.HasTClass, - TClass: s.readCM.TClass, - HasIPPacketInfo: s.readCM.HasIPPacketInfo, - PacketInfo: s.readCM.PacketInfo, + IP: socket.IPControlMessages{ + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + OriginalDstAddress: s.readCM.OriginalDstAddress, }, } } @@ -2980,7 +2873,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, family, err := AddressAndFamily(to) + addrBuf, family, err := socket.AddressAndFamily(to) if err != nil { return 0, err } @@ -3399,6 +3292,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { return rv } +func isTCPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP) +} + +func isUDPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP) +} + +func isICMPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6) +} + // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. func (s *socketOpsCommon) State() uint32 { @@ -3408,7 +3313,7 @@ func (s *socketOpsCommon) State() uint32 { } switch { - case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: + case isTCPSocket(s.skType, s.protocol): // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -3437,7 +3342,7 @@ func (s *socketOpsCommon) State() uint32 { // Internal or unknown state. 
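The new isTCPSocket/isUDPSocket/isICMPSocket predicates are more than a cleanup: in Go, && binds tighter than ||, so the old State() conditions such as "s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP" grouped as (A && B) || C and matched any socket whose protocol happened to be IPPROTO_TCP, for example a raw socket. A standalone demonstration, with the Linux constant values inlined purely for the example:

package main

import "fmt"

const (
    sockStream = 1 // linux.SOCK_STREAM
    sockRaw    = 3 // linux.SOCK_RAW
    protoTCP   = 6 // IPPROTO_TCP
)

func main() {
    skType, proto := sockRaw, protoTCP // a raw TCP socket, not a stream socket
    old := skType == sockStream && proto == 0 || proto == protoTCP
    fixed := skType == sockStream && (proto == 0 || proto == protoTCP)
    fmt.Println(old, fixed) // true false
}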
return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + case isUDPSocket(s.skType, s.protocol): // UDP socket. switch udp.EndpointState(s.Endpoint.State()) { case udp.StateInitial, udp.StateBound, udp.StateClosed: @@ -3447,7 +3352,7 @@ func (s *socketOpsCommon) State() uint32 { default: return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + case isICMPSocket(s.skType, s.protocol): // TODO(b/112063468): Export states for ICMP sockets. case s.skType == linux.SOCK_RAW: // TODO(b/112063468): Export states for raw sockets. diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index b0d9e4d9e..b756bfca0 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -51,9 +51,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // NewVFS2 creates a new endpoint socket. func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } mnt := t.Kernel().SocketMount() @@ -191,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addrLen uint32 if peerAddr != nil { // Get address of the peer and write it to peer slice. - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index ead3b2b79..c847ff1c7 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go index 2a01143f6..0af805246 100644 --- a/pkg/sentry/socket/netstack/provider_vfs2.go +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index fa9ac9059..cc0fadeb5 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -324,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { 0, // Support Ip/FragCreates. 
} case *inet.StatSNMPICMP: - in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats - out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPICMP{ 0, // Icmp/InMsgs. - Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors. 0, // Icmp/InCsumErrors. in.DstUnreachable.Value(), // InDestUnreachs. in.TimeExceeded.Value(), // InTimeExcds. @@ -343,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { in.InfoRequest.Value(), // InAddrMasks. in.InfoReply.Value(), // InAddrMaskReps. 0, // Icmp/OutMsgs. - Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. - out.DstUnreachable.Value(), // OutDestUnreachs. - out.TimeExceeded.Value(), // OutTimeExcds. - out.ParamProblem.Value(), // OutParmProbs. - out.SrcQuench.Value(), // OutSrcQuenchs. - out.Redirect.Value(), // OutRedirects. - out.Echo.Value(), // OutEchos. - out.EchoReply.Value(), // OutEchoReps. - out.Timestamp.Value(), // OutTimestamps. - out.TimestampReply.Value(), // OutTimestampReps. - out.InfoRequest.Value(), // OutAddrMasks. - out.InfoReply.Value(), // OutAddrMaskReps. + Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. } case *inet.StatSNMPTCP: tcp := Metrics.TCP diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fd31479e5..bcc426e33 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -18,6 +18,7 @@ package socket import ( + "bytes" "fmt" "sync/atomic" "syscall" @@ -35,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/usermem" ) @@ -42,7 +44,79 @@ import ( // control messages. type ControlMessages struct { Unix transport.ControlMessages - IP tcpip.ControlMessages + IP IPControlMessages +} + +// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format. +func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo { + var p linux.ControlMessageIPPacketInfo + p.NIC = int32(packetInfo.NIC) + copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) + copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + return p +} + +// NewIPControlMessages converts the tcpip ControlMessgaes (which does not +// have Linux specific format) to Linux format. 
+func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages { + var orgDstAddr linux.SockAddr + if cmgs.HasOriginalDstAddress { + orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress) + } + return IPControlMessages{ + HasTimestamp: cmgs.HasTimestamp, + Timestamp: cmgs.Timestamp, + HasInq: cmgs.HasInq, + Inq: cmgs.Inq, + HasTOS: cmgs.HasTOS, + TOS: cmgs.TOS, + HasTClass: cmgs.HasTClass, + TClass: cmgs.TClass, + HasIPPacketInfo: cmgs.HasIPPacketInfo, + PacketInfo: packetInfoToLinux(cmgs.PacketInfo), + OriginalDstAddress: orgDstAddr, + } +} + +// IPControlMessages contains socket control messages for IP sockets. +// This can contain Linux specific structures unlike tcpip.ControlMessages. +// +// +stateify savable +type IPControlMessages struct { + // HasTimestamp indicates whether Timestamp is valid/set. + HasTimestamp bool + + // Timestamp is the time (in ns) that the last packet used to create + // the read data was received. + Timestamp int64 + + // HasInq indicates whether Inq is valid/set. + HasInq bool + + // Inq is the number of bytes ready to be received. + Inq int32 + + // HasTOS indicates whether Tos is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS uint8 + + // HasTClass indicates whether TClass is valid/set. + HasTClass bool + + // TClass is the IPv6 traffic class of the associated packet. + TClass uint32 + + // HasIPPacketInfo indicates whether PacketInfo is set. + HasIPPacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + PacketInfo linux.ControlMessageIPPacketInfo + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress linux.SockAddr } // Release releases Unix domain socket credentials and rights. @@ -460,3 +534,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { panic(fmt.Sprintf("Unsupported socket family %v", family)) } } + +var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes() +var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes() +var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes() + +// Ntohs converts a 16-bit number from network byte order to host byte order. It +// assumes that the host is little endian. +func Ntohs(v uint16) uint16 { + return v<<8 | v>>8 +} + +// Htons converts a 16-bit number from host byte order to network byte order. It +// assumes that the host is little endian. +func Htons(v uint16) uint16 { + return Ntohs(v) +} + +// isLinkLocal determines if the given IPv6 address is link-local. This is the +// case when it has the fe80::/10 prefix. This check is used to determine when +// the NICID is relevant for a given IPv6 address. +func isLinkLocal(addr tcpip.Address) bool { + return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 +} + +// ConvertAddress converts the given address to a native format. +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { + switch family { + case linux.AF_UNIX: + var out linux.SockAddrUnix + out.Family = linux.AF_UNIX + l := len([]byte(addr.Addr)) + for i := 0; i < l; i++ { + out.Path[i] = int8(addr.Addr[i]) + } + + // Linux returns the used length of the address struct (including the + // null terminator) for filesystem paths. The Family field is 2 bytes. + // It is sometimes allowed to exclude the null terminator if the + // address length is the max. Abstract and empty paths always return + // the full exact length. 
+ if l == 0 || out.Path[0] == 0 || l == len(out.Path) { + return &out, uint32(2 + l) + } + return &out, uint32(3 + l) + + case linux.AF_INET: + var out linux.SockAddrInet + copy(out.Addr[:], addr.Addr) + out.Family = linux.AF_INET + out.Port = Htons(addr.Port) + return &out, uint32(sockAddrInetSize) + + case linux.AF_INET6: + var out linux.SockAddrInet6 + if len(addr.Addr) == header.IPv4AddressSize { + // Copy address in v4-mapped format. + copy(out.Addr[12:], addr.Addr) + out.Addr[10] = 0xff + out.Addr[11] = 0xff + } else { + copy(out.Addr[:], addr.Addr) + } + out.Family = linux.AF_INET6 + out.Port = Htons(addr.Port) + if isLinkLocal(addr.Addr) { + out.Scope_id = uint32(addr.NIC) + } + return &out, uint32(sockAddrInet6Size) + + case linux.AF_PACKET: + // TODO(gvisor.dev/issue/173): Return protocol too. + var out linux.SockAddrLink + out.Family = linux.AF_PACKET + out.InterfaceIndex = int32(addr.NIC) + out.HardwareAddrLen = header.EthernetAddressSize + copy(out.HardwareAddr[:], addr.Addr) + return &out, uint32(sockAddrLinkSize) + + default: + return nil, 0 + } +} + +// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the +// netstack representation taking any addresses into account. +func BytesToIPAddress(addr []byte) tcpip.Address { + if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) { + return "" + } + return tcpip.Address(addr) +} + +// AddressAndFamily reads an sockaddr struct from the given address and +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, +// AF_INET6, and AF_PACKET addresses. +// +// AddressAndFamily returns an address and its family. +func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { + // Make sure we have at least 2 bytes for the address family. + if len(addr) < 2 { + return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument + } + + // Get the rest of the fields based on the address family. + switch family := usermem.ByteOrder.Uint16(addr); family { + case linux.AF_UNIX: + path := addr[2:] + if len(path) > linux.UnixPathMax { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + // Drop the terminating NUL (if one exists) and everything after + // it for filesystem (non-abstract) addresses. 
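Ntohs and Htons, added above, are plain 16-bit byte swaps and are documented as assuming a little-endian host, which holds for the amd64 and arm64 targets gVisor supports. A quick standalone check of the identity they rely on (the swap is its own inverse, which is why Htons simply calls Ntohs):

package main

import "fmt"

// ntohs swaps the two bytes of a 16-bit value, the same operation as
// socket.Ntohs in the hunk above.
func ntohs(v uint16) uint16 { return v<<8 | v>>8 }

func main() {
    port := uint16(8080)
    wire := ntohs(port) // network byte order on a little-endian host
    fmt.Printf("host %d -> wire %#04x -> host %d\n", port, wire, ntohs(wire))
}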
+ if len(path) > 0 && path[0] != 0 { + if n := bytes.IndexByte(path[1:], 0); n >= 0 { + path = path[:n+1] + } + } + return tcpip.FullAddress{ + Addr: tcpip.Address(path), + }, family, nil + + case linux.AF_INET: + var a linux.SockAddrInet + if len(addr) < sockAddrInetSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + return out, family, nil + + case linux.AF_INET6: + var a linux.SockAddrInet6 + if len(addr) < sockAddrInet6Size { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + if isLinkLocal(out.Addr) { + out.NIC = tcpip.NICID(a.Scope_id) + } + return out, family, nil + + case linux.AF_PACKET: + var a linux.SockAddrLink + if len(addr) < sockAddrLinkSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) + if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/173): Return protocol too. + return tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + }, family, nil + + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, family, nil + + default: + return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported + } +} diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 6d9e502bd..9f7aca305 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -118,28 +118,24 @@ var ( // NewConnectioned creates a new unbound connectionedEndpoint. func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint { - return &connectionedEndpoint{ + return newConnectioned(ctx, stype, uid) +} + +func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint { + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { - a := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } - b := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } + a := newConnectioned(ctx, stype, uid) + b := newConnectioned(ctx, stype, uid) q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} q1.InitRefs() @@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. 
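The unix transport hunks above all enforce one invariant: every path that constructs an endpoint now calls ep.ops.InitHandler(ep) so the embedded tcpip.SocketOptions knows which endpoint backs it, which is why NewConnectioned and NewPair funnel through a single newConnectioned helper. A toy sketch of that construction pattern; the names here are illustrative, not gVisor's:

package main

import "fmt"

// options is a toy stand-in for tcpip.SocketOptions, which must be pointed at
// its owning endpoint before use.
type options struct{ handler interface{} }

func (o *options) InitHandler(h interface{}) { o.handler = h }

type endpoint struct {
    ops options
}

// newEndpoint is the single constructor that public constructors delegate to,
// so the InitHandler call cannot be forgotten in one copy of the code.
func newEndpoint() *endpoint {
    ep := &endpoint{}
    ep.ops.InitHandler(ep)
    return ep
}

func newPair() (*endpoint, *endpoint) { return newEndpoint(), newEndpoint() }

func main() {
    a, b := newPair()
    fmt.Println(a.ops.handler == a, b.ops.handler == b) // true true
}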
func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { - return &connectionedEndpoint{ + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // ID implements ConnectingEndpoint.ID. @@ -298,6 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } + ne.ops.InitHandler(ne) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} readQueue.InitRefs() diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 1406971bc..0813ad87d 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -44,6 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint { q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} + ep.ops.InitHandler(ep) return ep } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 18a50e9f8..099a56281 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,8 +16,6 @@ package transport import ( - "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" @@ -180,10 +178,6 @@ type Endpoint interface { // SetSockOpt sets a socket option. SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool sets a socket option for simple cases when a value has - // the int type. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt sets a socket option for simple cases when a value has // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -191,10 +185,6 @@ type Endpoint interface { // GetSockOpt gets a socket option. GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool gets a socket option for simple cases when a return - // value has the int type. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) @@ -203,10 +193,11 @@ type Endpoint interface { // procfs. State() uint32 - // LastError implements tcpip.Endpoint.LastError. + // LastError clears and returns the last error reported by the endpoint. LastError() *tcpip.Error - // SocketOptions implements tcpip.Endpoint.SocketOptions. + // SocketOptions returns the structure which contains all the socket + // level options. SocketOptions() *tcpip.SocketOptions } @@ -739,10 +730,7 @@ func (e *connectedEndpoint) CloseUnread() { // +stateify savable type baseEndpoint struct { *waiter.Queue - - // passcred specifies whether SCM_CREDENTIALS socket control messages are - // enabled on this endpoint. Must be accessed atomically. - passcred int32 + tcpip.DefaultSocketOptionsHandler // Mutex protects the below fields. sync.Mutex `state:"nosave"` @@ -758,9 +746,7 @@ type baseEndpoint struct { // or may be used if the endpoint is connected. path string - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - + // ops is used to get socket level options. 
ops tcpip.SocketOptions } @@ -786,7 +772,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) { // Passcred implements Credentialer.Passcred. func (e *baseEndpoint) Passcred() bool { - return atomic.LoadInt32(&e.passcred) != 0 + return e.SocketOptions().GetPassCred() } // ConnectedPasscred implements Credentialer.ConnectedPasscred. @@ -796,14 +782,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool { return e.connected != nil && e.connected.Passcred() } -func (e *baseEndpoint) setPasscred(pc bool) { - if pc { - atomic.StoreInt32(&e.passcred, 1) - } else { - atomic.StoreInt32(&e.passcred, 0) - } -} - // Connected implements ConnectingEndpoint.Connected. func (e *baseEndpoint) Connected() bool { return e.receiver != nil && e.connected != nil @@ -859,23 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - e.linger = *v - e.Unlock() - } - return nil -} - -func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.PasscredOption: - e.setPasscred(v) - case tcpip.ReuseAddressOption: - default: - log.Warningf("Unsupported socket option: %d", opt) - } return nil } @@ -889,20 +850,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: - return false, nil - - case tcpip.PasscredOption: - return e.Passcred(), nil - - default: - log.Warningf("Unsupported socket option: %d", opt) - return false, tcpip.ErrUnknownProtocolOption - } -} - func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -966,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - *o = e.linger - e.Unlock() - return nil - - default: - log.Warningf("Unsupported socket option: %T", opt) - return tcpip.ErrUnknownProtocolOption - } + log.Warningf("Unsupported socket option: %T", opt) + return tcpip.ErrUnknownProtocolOption } // LastError implements Endpoint.LastError. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 3e520d2ee..c59297c80 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -115,9 +115,6 @@ type socketOpsCommon struct { // bound, they cannot be modified. abstractName string abstractNamespace *kernel.AbstractSocketNamespace - - // ops is used to get socket level options. - ops tcpip.SocketOptions } func (s *socketOpsCommon) isPacket() bool { @@ -139,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. 
func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, family, err := netstack.AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { if err == syserr.ErrAddressFamilyNotSupported { err = syserr.ErrInvalidArgument @@ -172,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -184,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -258,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -650,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -685,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index eaf0b0d26..27f705bb2 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -172,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index a920180d3..d36a64ffc 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -32,8 +32,8 @@ go_library( "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/kernel", + "//pkg/sentry/socket", "//pkg/sentry/socket/netlink", - "//pkg/sentry/socket/netstack", "//pkg/sentry/syscalls/linux", "//pkg/usermem", ], diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index cc5f70cd4..d943a7cb1 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" - "gvisor.dev/gvisor/pkg/sentry/socket/netstack" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/usermem" ) @@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, _, err := netstack.AddressAndFamily(b) + fa, _, err := 
socket.AddressAndFamily(b) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index bb1f715e2..b815e498f 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index 0bf313a13..c2285f796 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -307,9 +307,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if !ok { return syserror.EINVAL } - if ready := ctx.Prepare(); !ready { - // Context is busy. - return syserror.EAGAIN + if err := ctx.Prepare(); err != nil { + return err } if eventFile != nil { diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 519066a47..8db587401 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -646,7 +646,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } - fSetOwn(t, file, set) + fSetOwn(t, int(fd), file, set) return 0, nil, nil case linux.FIOGETOWN, linux.SIOCGPGRP: @@ -901,8 +901,8 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 { // // If who is positive, it represents a PID. If negative, it represents a PGID. // If the PID or PGID is invalid, the owner is silently unset. -func fSetOwn(t *kernel.Task, file *fs.File, who int32) error { - a := file.Async(fasync.New).(*fasync.FileAsync) +func fSetOwn(t *kernel.Task, fd int, file *fs.File, who int32) error { + a := file.Async(fasync.New(fd)).(*fasync.FileAsync) if who < 0 { // Check for overflow before flipping the sign. 
if who-1 > who { @@ -1049,7 +1049,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_GETOWN: return uintptr(fGetOwn(t, file)), nil, nil case linux.F_SETOWN: - return 0, nil, fSetOwn(t, file, args[2].Int()) + return 0, nil, fSetOwn(t, int(fd), file, args[2].Int()) case linux.F_GETOWN_EX: addr := args[2].Pointer() owner := fGetOwnEx(t, file) @@ -1062,7 +1062,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - a := file.Async(fasync.New).(*fasync.FileAsync) + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) switch owner.Type { case linux.F_OWNER_TID: task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID)) @@ -1111,6 +1111,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } n, err := sz.SetFifoSize(int64(args[2].Int())) return uintptr(n), nil, err + case linux.F_GETSIG: + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) + return uintptr(a.Signal()), nil, nil + case linux.F_SETSIG: + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) + return 0, nil, a.SetSignal(linux.Signal(args[2].Int())) default: // Everything else is not yet supported. return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index e383a0a87..1166cd7bb 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -146,11 +146,37 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal v, err := getNCnt(t, id, num) return uintptr(v), nil, err - case linux.IPC_INFO, - linux.SEM_INFO, - linux.SEM_STAT, - linux.SEM_STAT_ANY: + case linux.IPC_INFO: + buf := args[3].Pointer() + r := t.IPCNamespace().SemaphoreRegistry() + info := r.IPCInfo() + if _, err := info.CopyOut(t, buf); err != nil { + return 0, nil, err + } + return uintptr(r.HighestIndex()), nil, nil + + case linux.SEM_INFO: + buf := args[3].Pointer() + r := t.IPCNamespace().SemaphoreRegistry() + info := r.SemInfo() + if _, err := info.CopyOut(t, buf); err != nil { + return 0, nil, err + } + return uintptr(r.HighestIndex()), nil, nil + + case linux.SEM_STAT: + arg := args[3].Pointer() + // id is an index in SEM_STAT. 
+ semid, ds, err := semStat(t, id) + if err != nil { + return 0, nil, err + } + if _, err := ds.CopyOut(t, arg); err != nil { + return 0, nil, err + } + return uintptr(semid), nil, err + case linux.SEM_STAT_ANY: t.Kernel().EmitUnimplementedEvent(t) fallthrough @@ -195,6 +221,17 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { return set.GetStat(creds) } +func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByIndex(index) + if set == nil { + return 0, nil, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + ds, err := set.GetStat(creds) + return set.ID, ds, err +} + func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index e748d33d8..d639c9bf7 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -88,8 +88,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(target.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) + info.SetPID(int32(target.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) if err := target.SendGroupSignal(info); err != syserror.ESRCH { return 0, nil, err } @@ -127,8 +127,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(tg.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) + info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) err := tg.SendSignal(info) if err == syserror.ESRCH { // ESRCH is ignored because it means the task @@ -171,8 +171,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(tg.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) + info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) // See note above regarding ESRCH race above. 
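The semctl additions implement IPC_INFO, SEM_INFO, and SEM_STAT; the subtle part of SEM_STAT, visible in semStat above, is that the argument is an index into the registry while the syscall's return value is the set's ID (FindByIndex followed by returning set.ID). A toy model of that index/ID distinction:

package main

import (
    "errors"
    "fmt"
)

type semSet struct{ ID int32 }

// registry is a toy stand-in for the sentry's semaphore registry: slots are
// addressed by index, but each set carries its own kernel-assigned ID.
type registry struct{ slots []*semSet }

func (r *registry) findByIndex(idx int32) *semSet {
    if idx < 0 || int(idx) >= len(r.slots) {
        return nil
    }
    return r.slots[idx]
}

// semStat mirrors the shape of the new helper: look up by index, return the ID.
func (r *registry) semStat(idx int32) (int32, error) {
    set := r.findByIndex(idx)
    if set == nil {
        return 0, errors.New("EINVAL")
    }
    return set.ID, nil
}

func main() {
    r := &registry{slots: []*semSet{{ID: 32769}}}
    id, err := r.semStat(0)
    fmt.Println(id, err) // 32769 <nil>
}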
if err := tg.SendSignal(info); err != syserror.ESRCH { lastErr = err @@ -189,8 +189,8 @@ func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalI Signo: int32(sig), Code: arch.SignalInfoTkill, } - info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) - info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) + info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) return info } diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 983f8d396..8e7ac0ffe 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -413,8 +413,8 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal si := arch.SignalInfo{ Signo: int32(linux.SIGCHLD), } - si.SetPid(int32(wr.TID)) - si.SetUid(int32(wr.UID)) + si.SetPID(int32(wr.TID)) + si.SetUID(int32(wr.UID)) // TODO(b/73541790): convert kernel.ExitStatus to functions and make // WaitResult.Status a linux.WaitStatus. s := syscall.WaitStatus(wr.Status) diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index 6d0a38330..1365a5a62 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -130,9 +130,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if !ok { return syserror.EINVAL } - if ready := aioCtx.Prepare(); !ready { - // Context is busy. - return syserror.EAGAIN + if err := aioCtx.Prepare(); err != nil { + return err } if eventFD != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 36e89700e..7dd9ef857 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -165,7 +165,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ownerType = linux.F_OWNER_PGRP who = -who } - return 0, nil, setAsyncOwner(t, file, ownerType, who) + return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who) case linux.F_GETOWN_EX: owner, hasOwner := getAsyncOwner(t, file) if !hasOwner { @@ -179,7 +179,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID) + return 0, nil, setAsyncOwner(t, int(fd), file, owner.Type, owner.PID) case linux.F_SETPIPE_SZ: pipefile, ok := file.Impl().(*pipe.VFSPipeFD) if !ok { @@ -207,6 +207,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err case linux.F_SETLK, linux.F_SETLKW: return 0, nil, posixLock(t, args, file, cmd) + case linux.F_GETSIG: + a := file.AsyncHandler() + if a == nil { + // Default behavior aka SIGIO. + return 0, nil, nil + } + return uintptr(a.(*fasync.FileAsync).Signal()), nil, nil + case linux.F_SETSIG: + a := file.SetAsyncHandler(fasync.NewVFS2(int(fd))).(*fasync.FileAsync) + return 0, nil, a.SetSignal(linux.Signal(args[2].Int())) default: // Everything else is not yet supported. 
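With F_GETSIG and F_SETSIG handled on both the legacy and VFS2 fcntl paths above, applications running under gVisor can choose which signal O_ASYNC readiness notifications deliver instead of the default SIGIO. A small userspace example of the calls this enables, using golang.org/x/sys/unix; /dev/null is just a placeholder descriptor:

package main

import (
    "fmt"
    "os"

    "golang.org/x/sys/unix"
)

func main() {
    f, err := os.Open("/dev/null")
    if err != nil {
        panic(err)
    }
    defer f.Close()

    // Ask for SIGUSR1 instead of the default SIGIO for async notifications.
    if _, err := unix.FcntlInt(f.Fd(), unix.F_SETSIG, int(unix.SIGUSR1)); err != nil {
        panic(err)
    }

    sig, err := unix.FcntlInt(f.Fd(), unix.F_GETSIG, 0)
    fmt.Println(sig, err) // 10 <nil> on amd64 (SIGUSR1)
}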
return 0, nil, syserror.EINVAL @@ -241,7 +251,7 @@ func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwne } } -func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error { +func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, pid int32) error { switch ownerType { case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP: // Acceptable type. @@ -249,7 +259,7 @@ func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32 return syserror.EINVAL } - a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync) + a := file.SetAsyncHandler(fasync.NewVFS2(fd)).(*fasync.FileAsync) if pid == 0 { a.ClearOwner() return nil diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 2806c3f6f..20c264fef 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -100,7 +100,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ownerType = linux.F_OWNER_PGRP who = -who } - return 0, nil, setAsyncOwner(t, file, ownerType, who) + return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who) } ret, err := file.Ioctl(t, t.MemoryManager(), args) diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index 440c9307c..a3868bf16 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -105,6 +105,7 @@ go_library( "//pkg/sentry/arch", "//pkg/sentry/fs", "//pkg/sentry/fs/lock", + "//pkg/sentry/fsmetric", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", "//pkg/sentry/limits", diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index a98aac52b..072655fe8 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -204,8 +204,8 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin file.EventRegister(&epi.waiter, wmask) // Check if the file is already ready. - if file.Readiness(wmask)&wmask != 0 { - epi.Callback(nil) + if m := file.Readiness(wmask) & wmask; m != 0 { + epi.Callback(nil, m) } // Add epi to file.epolls so that it is removed when the last @@ -274,8 +274,8 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event file.EventRegister(&epi.waiter, wmask) // Check if the file is already ready with the new mask. - if file.Readiness(wmask)&wmask != 0 { - epi.Callback(nil) + if m := file.Readiness(wmask) & wmask; m != 0 { + epi.Callback(nil, m) } return nil @@ -311,7 +311,7 @@ func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error } // Callback implements waiter.EntryCallback.Callback. -func (epi *epollInterest) Callback(*waiter.Entry) { +func (epi *epollInterest) Callback(*waiter.Entry, waiter.EventMask) { newReady := false epi.epoll.mu.Lock() if !epi.ready { diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 936f9fc71..5321ac80a 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -15,12 +15,14 @@ package vfs import ( + "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsmetric" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" "gvisor.dev/gvisor/pkg/sync" @@ -42,7 +44,7 @@ import ( type FileDescription struct { FileDescriptionRefs - // flagsMu protects statusFlags and asyncHandler below. 
+ // flagsMu protects `statusFlags`, `saved`, and `asyncHandler` below. flagsMu sync.Mutex `state:"nosave"` // statusFlags contains status flags, "initialized by open(2) and possibly @@ -51,6 +53,11 @@ type FileDescription struct { // access to asyncHandler. statusFlags uint32 + // saved is true after beforeSave is called. This is used to prevent + // double-unregistration of asyncHandler. This does not work properly for + // save-resume, which is not currently supported in gVisor (see b/26588733). + saved bool `state:"nosave"` + // asyncHandler handles O_ASYNC signal generation. It is set with the // F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must // also be set by fcntl(2). @@ -183,7 +190,7 @@ func (fd *FileDescription) DecRef(ctx context.Context) { } fd.vd.DecRef(ctx) fd.flagsMu.Lock() - if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + if !fd.saved && fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { fd.asyncHandler.Unregister(fd) } fd.asyncHandler = nil @@ -583,7 +590,11 @@ func (fd *FileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of if !fd.readable { return 0, syserror.EBADF } - return fd.impl.PRead(ctx, dst, offset, opts) + start := fsmetric.StartReadWait() + n, err := fd.impl.PRead(ctx, dst, offset, opts) + fsmetric.Reads.Increment() + fsmetric.FinishReadWait(fsmetric.ReadWait, start) + return n, err } // Read is similar to PRead, but does not specify an offset. @@ -591,7 +602,11 @@ func (fd *FileDescription) Read(ctx context.Context, dst usermem.IOSequence, opt if !fd.readable { return 0, syserror.EBADF } - return fd.impl.Read(ctx, dst, opts) + start := fsmetric.StartReadWait() + n, err := fd.impl.Read(ctx, dst, opts) + fsmetric.Reads.Increment() + fsmetric.FinishReadWait(fsmetric.ReadWait, start) + return n, err } // PWrite writes src to the file represented by fd, starting at the given @@ -825,44 +840,27 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn return fd.asyncHandler } -// FileReadWriteSeeker is a helper struct to pass a FileDescription as -// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. -type FileReadWriteSeeker struct { - FD *FileDescription - Ctx context.Context - ROpts ReadOptions - WOpts WriteOptions -} - -// ReadAt implements io.ReaderAt.ReadAt. -func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts) - return int(n), err -} - -// Read implements io.ReadWriteSeeker.Read. -func (f *FileReadWriteSeeker) Read(p []byte) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.Read(f.Ctx, dst, f.ROpts) - return int(n), err -} - -// Seek implements io.ReadWriteSeeker.Seek. -func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) { - return f.FD.Seek(f.Ctx, offset, int32(whence)) -} - -// WriteAt implements io.WriterAt.WriteAt. -func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts) - return int(n), err -} - -// Write implements io.ReadWriteSeeker.Write. -func (f *FileReadWriteSeeker) Write(p []byte) (int, error) { - buf := usermem.BytesIOSequence(p) - n, err := f.FD.Write(f.Ctx, buf, f.WOpts) - return int(n), err +// CopyRegularFileData copies data from srcFD to dstFD until reading from srcFD +// returns EOF or an error. It returns the number of bytes copied. 
+func CopyRegularFileData(ctx context.Context, dstFD, srcFD *FileDescription) (int64, error) { + done := int64(0) + buf := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size + for { + readN, readErr := srcFD.Read(ctx, buf, ReadOptions{}) + if readErr != nil && readErr != io.EOF { + return done, readErr + } + src := buf.TakeFirst64(readN) + for src.NumBytes() != 0 { + writeN, writeErr := dstFD.Write(ctx, src, WriteOptions{}) + done += writeN + src = src.DropFirst64(writeN) + if writeErr != nil { + return done, writeErr + } + } + if readErr == io.EOF { + return done, nil + } + } } diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index cb48c37a1..0df023713 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.12 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. - package vfs import ( @@ -41,6 +36,15 @@ type mountKey struct { point unsafe.Pointer // *Dentry } +var ( + mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil)) + mountKeySeed = sync.RandUintptr() +) + +func (k *mountKey) hash() uintptr { + return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed) +} + func (mnt *Mount) parent() *Mount { return (*Mount)(atomic.LoadPointer(&mnt.key.parent)) } @@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry { } } -func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } - // Invariant: mnt.key.parent == nil. vd.Ok(). func (mnt *Mount) setKey(vd VirtualDentry) { atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) } -func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } - // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). // // mountTable.Init() must be called on new mountTables before use. -// -// +stateify savable type mountTable struct { // mountTable is implemented as a seqcount-protected hash table that // resolves collisions with linear probing, featuring Robin Hood insertion @@ -84,8 +82,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq sync.SeqCount `state:"nosave"` - seed uint32 // for hashing keys + seq sync.SeqCount `state:"nosave"` // size holds both length (number of elements) and capacity (number of // slots): capacity is stored as its base-2 log (referred to as order) in @@ -150,7 +147,6 @@ func init() { // Init must be called exactly once on each mountTable before use. func (mt *mountTable) Init() { - mt.seed = rand32() mt.size = mtInitOrder mt.slots = newMountTableSlots(mtInitCap) } @@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer { // Lookup may be called even if there are concurrent mutators of mt. func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount { key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)} - hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) + hash := key.hash() loop: for { @@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must not already contain a Mount with the same mount point and parent. 
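Back in file_description.go, the FileReadWriteSeeker adapter is removed and CopyRegularFileData added; its loop retries short writes until the chunk read from srcFD is drained, and treats io.EOF as the normal stop condition only after those buffered bytes have been written. The same discipline expressed with the standard io interfaces, as a self-contained sketch:

package main

import (
    "bytes"
    "fmt"
    "io"
    "strings"
)

// copyAll mirrors the structure of CopyRegularFileData: read a chunk, write
// it out fully (handling short writes), and finish only on EOF from the
// read side.
func copyAll(dst io.Writer, src io.Reader) (int64, error) {
    var done int64
    buf := make([]byte, 32*1024)
    for {
        n, readErr := src.Read(buf)
        if readErr != nil && readErr != io.EOF {
            return done, readErr
        }
        chunk := buf[:n]
        for len(chunk) > 0 {
            w, writeErr := dst.Write(chunk)
            done += int64(w)
            chunk = chunk[w:]
            if writeErr != nil {
                return done, writeErr
            }
        }
        if readErr == io.EOF {
            return done, nil
        }
    }
}

func main() {
    var out bytes.Buffer
    n, err := copyAll(&out, strings.NewReader("hello, sentry"))
    fmt.Println(n, err, out.String()) // 13 <nil> hello, sentry
}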
 func (mt *mountTable) insertSeqed(mount *Mount) {
-	hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+	hash := mount.key.hash()
 
 	// We're under the maximum load factor if:
 	//
@@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) {
 // * mt.seq must be in a writer critical section.
 // * mt must contain mount.
 func (mt *mountTable) removeSeqed(mount *Mount) {
-	hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+	hash := mount.key.hash()
 	tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
 	mask := tcap - 1
 	slots := mt.slots
@@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) {
 		off = (off + mountSlotBytes) & offmask
 	}
 }
-
-//go:linkname memhash runtime.memhash
-func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
-
-//go:linkname rand32 runtime.fastrand
-func rand32() uint32
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
index 7723ed643..8998a82dd 100644
--- a/pkg/sentry/vfs/save_restore.go
+++ b/pkg/sentry/vfs/save_restore.go
@@ -18,8 +18,10 @@ import (
 	"fmt"
 	"sync/atomic"
 
+	"gvisor.dev/gvisor/pkg/abi/linux"
 	"gvisor.dev/gvisor/pkg/context"
 	"gvisor.dev/gvisor/pkg/refsvfs2"
+	"gvisor.dev/gvisor/pkg/waiter"
 )
 
 // FilesystemImplSaveRestoreExtension is an optional extension to
@@ -99,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount {
 	return mounts
 }
 
+// saveKey is called by stateify.
+func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
+
 // loadMounts is called by stateify.
 func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
 	if mounts == nil {
@@ -110,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
 	}
 }
 
+// loadKey is called by stateify.
+func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
+
 func (mnt *Mount) afterLoad() {
 	if atomic.LoadInt64(&mnt.refs) != 0 {
 		refsvfs2.Register(mnt)
 	}
@@ -120,5 +128,20 @@ func (mnt *Mount) afterLoad() {
 func (epi *epollInterest) afterLoad() {
 	// Mark all epollInterests as ready after restore so that the next call to
 	// EpollInstance.ReadEvents() rechecks their readiness.
-	epi.Callback(nil)
+	epi.Callback(nil, waiter.EventMaskFromLinux(epi.mask))
+}
+
+// beforeSave is called by stateify.
+func (fd *FileDescription) beforeSave() {
+	fd.saved = true
+	if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+		fd.asyncHandler.Unregister(fd)
+	}
+}
+
+// afterLoad is called by stateify.
+func (fd *FileDescription) afterLoad() {
+	if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+		fd.asyncHandler.Register(fd)
+	}
 }
diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go
index 48d6252f7..6fd1bb0b2 100644
--- a/pkg/sentry/vfs/vfs.go
+++ b/pkg/sentry/vfs/vfs.go
@@ -41,6 +41,7 @@ import (
 	"gvisor.dev/gvisor/pkg/abi/linux"
 	"gvisor.dev/gvisor/pkg/context"
 	"gvisor.dev/gvisor/pkg/fspath"
+	"gvisor.dev/gvisor/pkg/sentry/fsmetric"
 	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
 	"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
 	"gvisor.dev/gvisor/pkg/sync"
@@ -381,6 +382,8 @@ func (vfs *VirtualFilesystem) MknodAt(ctx context.Context, creds *auth.Credentia
 // OpenAt returns a FileDescription providing access to the file at the given
 // path. A reference is taken on the returned FileDescription.
 func (vfs *VirtualFilesystem) OpenAt(ctx context.Context, creds *auth.Credentials, pop *PathOperation, opts *OpenOptions) (*FileDescription, error) {
+	fsmetric.Opens.Increment()
+
 	// Remove:
 	//
 	// - O_CLOEXEC, which affects file descriptors and therefore must be
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 1d1062aeb..8e3146d8d 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -338,6 +338,7 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
 		tid := w.k.TaskSet().Root.IDOfTask(t)
 		buf.WriteString(fmt.Sprintf("\tTask tid: %v (goroutine %d), entered RunSys state %v ago.\n", tid, t.GoroutineID(), now.Sub(o.lastUpdateTime)))
 	}
+	buf.WriteString("Search for 'goroutine <id>' in the stack dump to find the offending goroutine(s)")
 	// Force stack dump only if a new task is detected.
 	w.doAction(w.TaskTimeoutAction, newTaskFound, &buf)
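
The fsmetric changes above (Opens.Increment in OpenAt, and the StartReadWait/Reads.Increment/FinishReadWait calls wrapped around FileDescription.Read and PRead) all follow the same wrap-the-inner-call pattern: record a start time, invoke the implementation, bump a counter, and account the elapsed wait. The sketch below is a minimal, self-contained stand-in for that pattern using only the standard library; instrumentedRead, startReadWait, finishReadWait, and the package-level counters are illustrative names, not the real fsmetric API, which records into gVisor's metric framework rather than plain atomics.

// readmetric_sketch.go: a rough stand-in for the instrumentation added around
// FileDescription.Read/PRead in the diff above.
package main

import (
	"fmt"
	"strings"
	"sync/atomic"
	"time"
)

var (
	reads        int64 // number of read calls, analogous to fsmetric.Reads
	readWaitNano int64 // cumulative time spent in the inner read, analogous to ReadWait
)

func startReadWait() time.Time { return time.Now() }

func finishReadWait(start time.Time) {
	atomic.AddInt64(&readWaitNano, int64(time.Since(start)))
}

// instrumentedRead shows the wrapping pattern: time the inner call, bump the
// counter, then return the inner result unchanged.
func instrumentedRead(read func([]byte) (int, error), p []byte) (int, error) {
	start := startReadWait()
	n, err := read(p)
	atomic.AddInt64(&reads, 1)
	finishReadWait(start)
	return n, err
}

func main() {
	r := strings.NewReader("some data")
	buf := make([]byte, 4)
	n, err := instrumentedRead(r.Read, buf)
	fmt.Println(n, err, atomic.LoadInt64(&reads), time.Duration(atomic.LoadInt64(&readWaitNano)))
}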
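
CopyRegularFileData, added in the vfs diff above, is a plain read/write loop: read up to 32 KiB, drain the buffer with repeated writes so short writes are not lost, count bytes copied before surfacing a write error, and treat io.EOF from the reader as the only successful terminator. Below is a minimal sketch of the same loop expressed against standard io interfaces instead of vfs.FileDescription; copyAll is a hypothetical helper, and the 32 KiB buffer mirrors the arbitrary choice in the diff.

// copyloop_sketch.go: the CopyRegularFileData loop over io.Reader/io.Writer.
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// copyAll reads from src until EOF, writing everything read to dst. Like the
// vfs helper, it counts a partial write toward the total before returning the
// write error, and only io.EOF from the reader counts as success.
func copyAll(dst io.Writer, src io.Reader) (int64, error) {
	var done int64
	buf := make([]byte, 32*1024)
	for {
		readN, readErr := src.Read(buf)
		if readErr != nil && readErr != io.EOF {
			return done, readErr
		}
		pending := buf[:readN]
		for len(pending) != 0 {
			writeN, writeErr := dst.Write(pending)
			done += int64(writeN)
			pending = pending[writeN:]
			if writeErr != nil {
				return done, writeErr
			}
		}
		if readErr == io.EOF {
			return done, nil
		}
	}
}

func main() {
	var dst bytes.Buffer
	n, err := copyAll(&dst, strings.NewReader("hello, regular file"))
	fmt.Println(n, err, dst.String())
}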
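
The mount_unsafe.go changes above stop go:linkname'ing runtime.memhash and runtime.fastrand and instead hash mountKey with a hasher obtained from gVisor's sync package plus a per-process random seed. The sketch below shows the same idea using the standard library's hash/maphash as a stand-in (maphash.Bytes needs Go 1.19, unsafe.Slice needs Go 1.17); pairKey and keySeed are illustrative names, and the real mountKey stores unsafe.Pointer fields and passes the key through gohacks.Noescape, which this simplification avoids by using uintptr fields.

// keyhash_sketch.go: hashing a fixed-size key struct with a seeded hasher,
// instead of linking against unexported runtime functions.
package main

import (
	"fmt"
	"hash/maphash"
	"unsafe"
)

// pairKey plays the role of mountKey: a small fixed-size struct used as a
// hash-table key.
type pairKey struct {
	parent uintptr
	point  uintptr
}

// keySeed is chosen once per process, like mountKeySeed in the diff, so the
// table layout is randomized across runs but stable within one run.
var keySeed = maphash.MakeSeed()

// hash hashes the key's raw in-memory bytes. This only works because pairKey
// has no padding holes and (in this simplified form) no pointers the GC must
// track.
func (k *pairKey) hash() uint64 {
	b := unsafe.Slice((*byte)(unsafe.Pointer(k)), unsafe.Sizeof(*k))
	return maphash.Bytes(keySeed, b)
}

func main() {
	k := pairKey{parent: 0x1000, point: 0x2000}
	fmt.Printf("%#x\n", k.hash())
}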
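
The save_restore.go hooks above unregister the O_ASYNC handler in FileDescription.beforeSave, set fd.saved so the guard added to DecRef does not unregister it a second time on the saving side, and re-register the handler in afterLoad on the restored side. The following is a toy model of that control flow only; fakeFD and fakeAsync are hypothetical stand-ins for FileDescription, FileAsync, and the stateify-generated calls, not the real machinery.

// async_saverestore_sketch.go: why the `saved` flag prevents double
// unregistration across save/restore.
package main

import "fmt"

type fakeAsync struct{ registered bool }

func (a *fakeAsync) Register()   { a.registered = true }
func (a *fakeAsync) Unregister() { a.registered = false }

type fakeFD struct {
	oASYNC  bool
	saved   bool
	handler *fakeAsync
}

// beforeSave mirrors FileDescription.beforeSave: unregister once and remember
// that we did.
func (fd *fakeFD) beforeSave() {
	fd.saved = true
	if fd.oASYNC && fd.handler != nil {
		fd.handler.Unregister()
	}
}

// afterLoad mirrors FileDescription.afterLoad: re-register on the restored side.
func (fd *fakeFD) afterLoad() {
	if fd.oASYNC && fd.handler != nil {
		fd.handler.Register()
	}
}

// decRef mirrors the guard added to FileDescription.DecRef: skip the
// unregistration if beforeSave already performed it.
func (fd *fakeFD) decRef() {
	if !fd.saved && fd.oASYNC && fd.handler != nil {
		fd.handler.Unregister()
	}
	fd.handler = nil
}

func main() {
	fd := &fakeFD{oASYNC: true, handler: &fakeAsync{registered: true}}
	fd.beforeSave() // save side: unregister exactly once
	fd.decRef()     // later teardown on the save side: no second unregister
	restored := &fakeFD{oASYNC: true, handler: &fakeAsync{}}
	restored.afterLoad() // restore side: re-register
	fmt.Println(restored.handler.registered)
}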