diff options
Diffstat (limited to 'pkg/sentry')
104 files changed, 2374 insertions, 1597 deletions
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index d75d665ae..dd2effdf9 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -365,3 +365,18 @@ func (a SyscallArgument) SizeT() uint { func (a SyscallArgument) ModeT() uint { return uint(uint16(a.Value)) } + +// ErrFloatingPoint indicates a failed restore due to unusable floating point +// state. +type ErrFloatingPoint struct { + // supported is the supported floating point state. + supported uint64 + + // saved is the saved floating point state. + saved uint64 +} + +// Error returns a sensible description of the restore error. +func (e ErrFloatingPoint) Error() string { + return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) +} diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go index 19ce99d25..840e53d33 100644 --- a/pkg/sentry/arch/arch_state_x86.go +++ b/pkg/sentry/arch/arch_state_x86.go @@ -17,27 +17,10 @@ package arch import ( - "fmt" - "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/usermem" ) -// ErrFloatingPoint indicates a failed restore due to unusable floating point -// state. -type ErrFloatingPoint struct { - // supported is the supported floating point state. - supported uint64 - - // saved is the saved floating point state. - saved uint64 -} - -// Error returns a sensible description of the restore error. -func (e ErrFloatingPoint) Error() string { - return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved) -} - // XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87 // and SSE state, so this is the equivalent XSTATE_BV value. const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index c9fb55d00..35d2e07c3 100644 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go @@ -152,23 +152,23 @@ func (s *SignalInfo) FixSignalCodeForUser() { } } -// Pid returns the si_pid field. -func (s *SignalInfo) Pid() int32 { +// PID returns the si_pid field. +func (s *SignalInfo) PID() int32 { return int32(usermem.ByteOrder.Uint32(s.Fields[0:4])) } -// SetPid mutates the si_pid field. -func (s *SignalInfo) SetPid(val int32) { +// SetPID mutates the si_pid field. +func (s *SignalInfo) SetPID(val int32) { usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val)) } -// Uid returns the si_uid field. -func (s *SignalInfo) Uid() int32 { +// UID returns the si_uid field. +func (s *SignalInfo) UID() int32 { return int32(usermem.ByteOrder.Uint32(s.Fields[4:8])) } -// SetUid mutates the si_uid field. -func (s *SignalInfo) SetUid(val int32) { +// SetUID mutates the si_uid field. +func (s *SignalInfo) SetUID(val int32) { usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val)) } @@ -251,3 +251,26 @@ func (s *SignalInfo) Arch() uint32 { func (s *SignalInfo) SetArch(val uint32) { usermem.ByteOrder.PutUint32(s.Fields[12:16], val) } + +// Band returns the si_band field. +func (s *SignalInfo) Band() int64 { + return int64(usermem.ByteOrder.Uint64(s.Fields[0:8])) +} + +// SetBand mutates the si_band field. +func (s *SignalInfo) SetBand(val int64) { + // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`. + // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is + // `int`. See siginfo.h. + usermem.ByteOrder.PutUint64(s.Fields[0:8], uint64(val)) +} + +// FD returns the si_fd field. +func (s *SignalInfo) FD() uint32 { + return usermem.ByteOrder.Uint32(s.Fields[8:12]) +} + +// SetFD mutates the si_fd field. +func (s *SignalInfo) SetFD(val uint32) { + usermem.ByteOrder.PutUint32(s.Fields[8:12], val) +} diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go index 2bf3c45e1..2f3664c57 100644 --- a/pkg/sentry/control/pprof.go +++ b/pkg/sentry/control/pprof.go @@ -15,10 +15,10 @@ package control import ( - "errors" "runtime" "runtime/pprof" "runtime/trace" + "time" "gvisor.dev/gvisor/pkg/fd" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -26,184 +26,263 @@ import ( "gvisor.dev/gvisor/pkg/urpc" ) -var errNoOutput = errors.New("no output writer provided") +// Profile includes profile-related RPC stubs. It provides a way to +// control the built-in runtime profiling facilities. +// +// The profile object must be instantied via NewProfile. +type Profile struct { + // kernel is the kernel under profile. It's immutable. + kernel *kernel.Kernel -// ProfileOpts contains options for the StartCPUProfile/Goroutine RPC call. -type ProfileOpts struct { - // File is the filesystem path for the profile. - File string `json:"path"` + // cpuMu protects CPU profiling. + cpuMu sync.Mutex - // FilePayload is the destination for the profiling output. - urpc.FilePayload + // blockMu protects block profiling. + blockMu sync.Mutex + + // mutexMu protects mutex profiling. + mutexMu sync.Mutex + + // traceMu protects trace profiling. + traceMu sync.Mutex + + // done is closed when profiling is done. + done chan struct{} } -// Profile includes profile-related RPC stubs. It provides a way to -// control the built-in pprof facility in sentry via sentryctl. -// -// The following options to sentryctl are added: -// -// - collect CPU profile on-demand. -// sentryctl -pid <pid> pprof-cpu-start -// sentryctl -pid <pid> pprof-cpu-stop -// -// - dump out the stack trace of current go routines. -// sentryctl -pid <pid> pprof-goroutine -type Profile struct { - // Kernel is the kernel under profile. It's immutable. - Kernel *kernel.Kernel +// NewProfile returns a new Profile object. +func NewProfile(k *kernel.Kernel) *Profile { + return &Profile{ + kernel: k, + done: make(chan struct{}), + } +} - // mu protects the fields below. - mu sync.Mutex +// Stop implements urpc.Stopper.Stop. +func (p *Profile) Stop() { + close(p.done) +} - // cpuFile is the current CPU profile output file. - cpuFile *fd.FD +// CPUProfileOpts contains options specifically for CPU profiles. +type CPUProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload - // traceFile is the current execution trace output file. - traceFile *fd.FD + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` } -// StartCPUProfile is an RPC stub which starts recording the CPU profile in a -// file. -func (p *Profile) StartCPUProfile(o *ProfileOpts, _ *struct{}) error { +// CPU is an RPC stub which collects a CPU profile. +func (p *Profile) CPU(o *CPUProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } - output, err := fd.NewFromFile(o.FilePayload.Files[0]) - if err != nil { - return err - } + output := o.FilePayload.Files[0] + defer output.Close() - p.mu.Lock() - defer p.mu.Unlock() + p.cpuMu.Lock() + defer p.cpuMu.Unlock() // Returns an error if profiling is already started. if err := pprof.StartCPUProfile(output); err != nil { - output.Close() return err } + defer pprof.StopCPUProfile() + + // Collect the profile. + select { + case <-time.After(o.Duration): + case <-p.done: + } - p.cpuFile = output return nil } -// StopCPUProfile is an RPC stub which stops the CPU profiling and flush out the -// profile data. It takes no argument. -func (p *Profile) StopCPUProfile(_, _ *struct{}) error { - p.mu.Lock() - defer p.mu.Unlock() - - if p.cpuFile == nil { - return errors.New("CPU profiling not started") - } +// HeapProfileOpts contains options specifically for heap profiles. +type HeapProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload - pprof.StopCPUProfile() - p.cpuFile.Close() - p.cpuFile = nil - return nil + // Delay is the sleep time, similar to Duration. This may + // not affect the data collected however, as the heap will + // continue only the memory associated with the last alloc. + Delay time.Duration `json:"delay"` } -// HeapProfile generates a heap profile for the sentry. -func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error { +// Heap generates a heap profile. +func (p *Profile) Heap(o *HeapProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() - runtime.GC() // Get up-to-date statistics. - if err := pprof.WriteHeapProfile(output); err != nil { - return err + + // Wait for the given delay. + select { + case <-time.After(o.Delay): + case <-p.done: } - return nil + + // Get up-to-date statistics. + runtime.GC() + + // Write the given profile. + return pprof.WriteHeapProfile(output) +} + +// GoroutineProfileOpts contains options specifically for goroutine profiles. +type GoroutineProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload } -// GoroutineProfile is an RPC stub which dumps out the stack trace for all -// running goroutines. -func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error { +// Goroutine dumps out the stack trace for all running goroutines. +func (p *Profile) Goroutine(o *GoroutineProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() - if err := pprof.Lookup("goroutine").WriteTo(output, 2); err != nil { - return err - } - return nil + + return pprof.Lookup("goroutine").WriteTo(output, 2) +} + +// BlockProfileOpts contains options specifically for block profiles. +type BlockProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload + + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` + + // Rate is the block profile rate. + Rate int `json:"rate"` } -// BlockProfile is an RPC stub which dumps out the stack trace that led to -// blocking on synchronization primitives. -func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error { +// Block dumps a blocking profile. +func (p *Profile) Block(o *BlockProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() - if err := pprof.Lookup("block").WriteTo(output, 0); err != nil { - return err + + p.blockMu.Lock() + defer p.blockMu.Unlock() + + // Always set the rate. We then wait to collect a profile at this rate, + // and disable when we're done. Note that the default here is 10%, which + // will record a stacktrace 10% of the time when blocking occurs. Since + // these events should not be super frequent, we expect this to achieve + // a reasonable balance between collecting the data we need and imposing + // a high performance cost (e.g. skewing even the CPU profile). + rate := 10 + if o.Rate != 0 { + rate = o.Rate } - return nil + runtime.SetBlockProfileRate(rate) + defer runtime.SetBlockProfileRate(0) + + // Collect the profile. + select { + case <-time.After(o.Duration): + case <-p.done: + } + + return pprof.Lookup("block").WriteTo(output, 0) +} + +// MutexProfileOpts contains options specifically for mutex profiles. +type MutexProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload + + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` + + // Fraction is the mutex profile fraction. + Fraction int `json:"fraction"` } -// MutexProfile is an RPC stub which dumps out the stack trace of holders of -// contended mutexes. -func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error { +// Mutex dumps a mutex profile. +func (p *Profile) Mutex(o *MutexProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } + output := o.FilePayload.Files[0] defer output.Close() - if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil { - return err + + p.mutexMu.Lock() + defer p.mutexMu.Unlock() + + // Always set the fraction. Like the block rate above, we use + // a default rate of 10% for the same reasons. + fraction := 10 + if o.Fraction != 0 { + fraction = o.Fraction } - return nil + runtime.SetMutexProfileFraction(fraction) + defer runtime.SetMutexProfileFraction(0) + + // Collect the profile. + select { + case <-time.After(o.Duration): + case <-p.done: + } + + return pprof.Lookup("mutex").WriteTo(output, 0) } -// StartTrace is an RPC stub which starts collection of an execution trace. -func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error { +// TraceProfileOpts contains options specifically for traces. +type TraceProfileOpts struct { + // FilePayload is the destination for the profiling output. + urpc.FilePayload + + // Duration is the duration of the profile. + Duration time.Duration `json:"duration"` +} + +// Trace is an RPC stub which starts collection of an execution trace. +func (p *Profile) Trace(o *TraceProfileOpts, _ *struct{}) error { if len(o.FilePayload.Files) < 1 { - return errNoOutput + return nil // Allowed. } output, err := fd.NewFromFile(o.FilePayload.Files[0]) if err != nil { return err } + defer output.Close() - p.mu.Lock() - defer p.mu.Unlock() + p.traceMu.Lock() + defer p.traceMu.Unlock() // Returns an error if profiling is already started. if err := trace.Start(output); err != nil { output.Close() return err } + defer trace.Stop() // Ensure all trace contexts are registered. - p.Kernel.RebuildTraceContexts() - - p.traceFile = output - return nil -} - -// StopTrace is an RPC stub which stops collection of an ongoing execution -// trace and flushes the trace data. It takes no argument. -func (p *Profile) StopTrace(_, _ *struct{}) error { - p.mu.Lock() - defer p.mu.Unlock() + p.kernel.RebuildTraceContexts() - if p.traceFile == nil { - return errors.New("Execution tracing not started") + // Wait for the trace. + select { + case <-time.After(o.Duration): + case <-p.done: } // Similarly to the case above, if tasks have not ended traces, we will // lose information. Thus we need to rebuild the tasks in order to have // complete information. This will not lose information if multiple // traces are overlapping. - p.Kernel.RebuildTraceContexts() + p.kernel.RebuildTraceContexts() - trace.Stop() - p.traceFile.Close() - p.traceFile = nil return nil } diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go index d800f2c85..62eaca965 100644 --- a/pkg/sentry/control/state.go +++ b/pkg/sentry/control/state.go @@ -62,6 +62,7 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error { Callback: func(err error) { if err == nil { log.Infof("Save succeeded: exiting...") + s.Kernel.SetSaveSuccess(false /* autosave */) } else { log.Warningf("Save failed: exiting...") s.Kernel.SetSaveError(err) diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go index 314661475..badd5b073 100644 --- a/pkg/sentry/fdimport/fdimport.go +++ b/pkg/sentry/fdimport/fdimport.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package fdimport provides the Import function. package fdimport import ( diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go index ff2fe6712..8e0aa9019 100644 --- a/pkg/sentry/fs/copy_up.go +++ b/pkg/sentry/fs/copy_up.go @@ -336,7 +336,12 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err // copyUpBuffers is a buffer pool for copying file content. The buffer // size is the same used by io.Copy. -var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }} +var copyUpBuffers = sync.Pool{ + New: func() interface{} { + b := make([]byte, 8*usermem.PageSize) + return &b + }, +} // copyContentsLocked copies the contents of lower to upper. It panics if // less than size bytes can be copied. @@ -361,7 +366,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in defer lowerFile.DecRef(ctx) // Use a buffer pool to minimize allocations. - buf := copyUpBuffers.Get().([]byte) + buf := copyUpBuffers.Get().(*[]byte) defer copyUpBuffers.Put(buf) // Transfer the contents. @@ -371,7 +376,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in // optimizations could be self-defeating. So we leave this as simple as possible. var offset int64 for { - nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset) + nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(*buf), offset) if err != nil && err != io.EOF { return err } @@ -383,7 +388,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in } return nil } - nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset) + nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence((*buf)[:nr]), offset) if err != nil { return err } diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go index c7a11eec1..e04784db2 100644 --- a/pkg/sentry/fs/copy_up_test.go +++ b/pkg/sentry/fs/copy_up_test.go @@ -64,7 +64,7 @@ func TestConcurrentCopyUp(t *testing.T) { wg.Add(1) go func(o *overlayTestFile) { if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil { - t.Fatalf("failed to copy up: %v", err) + t.Errorf("failed to copy up: %v", err) } wg.Done() }(file) diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go index 8049538f2..ec3d3f96c 100644 --- a/pkg/sentry/fs/filetest/filetest.go +++ b/pkg/sentry/fs/filetest/filetest.go @@ -52,10 +52,10 @@ func NewTestFile(tb testing.TB) *fs.File { // Read just fails the request. func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Readv not implemented") + return 0, fmt.Errorf("TestFileOperations.Read not implemented") } // Write just fails the request. func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) { - return 0, fmt.Errorf("Writev not implemented") + return 0, fmt.Errorf("TestFileOperations.Write not implemented") } diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go index d2dbff268..a020da53b 100644 --- a/pkg/sentry/fs/fs.go +++ b/pkg/sentry/fs/fs.go @@ -65,7 +65,7 @@ var ( // runs with the lock held for reading. AsyncBarrier will take the lock // for writing, thus ensuring that all Async work completes before // AsyncBarrier returns. - workMu sync.RWMutex + workMu sync.CrossGoroutineRWMutex // asyncError is used to store up to one asynchronous execution error. asyncError = make(chan error, 1) diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go index d481baf77..e5579095b 100644 --- a/pkg/sentry/fs/gofer/attr.go +++ b/pkg/sentry/fs/gofer/attr.go @@ -117,8 +117,6 @@ func ntype(pattr p9.Attr) fs.InodeType { return fs.BlockDevice case pattr.Mode.IsSocket(): return fs.Socket - case pattr.Mode.IsRegular(): - fallthrough default: return fs.RegularFile } diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go index 9d6fdd08f..e840b6f5e 100644 --- a/pkg/sentry/fs/gofer/inode.go +++ b/pkg/sentry/fs/gofer/inode.go @@ -475,6 +475,9 @@ func (i *inodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermM func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { switch d.Inode.StableAttr.Type { case fs.Socket: + if i.session().overrides != nil { + return nil, syserror.ENXIO + } return i.getFileSocket(ctx, d, flags) case fs.Pipe: return i.getFilePipe(ctx, d, flags) diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go index fbfba1b58..2c14aa6d9 100644 --- a/pkg/sentry/fs/host/inode.go +++ b/pkg/sentry/fs/host/inode.go @@ -276,6 +276,10 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport. // GetFile implements fs.InodeOperations.GetFile. func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + if fs.IsSocket(d.Inode.StableAttr) { + return nil, syserror.ENXIO + } + return newFile(ctx, d, flags, i), nil } diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go index f8aad2dbd..b998fb75d 100644 --- a/pkg/sentry/fs/proc/sys.go +++ b/pkg/sentry/fs/proc/sys.go @@ -84,6 +84,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode children := map[string]*fs.Inode{ "hostname": newProcInode(ctx, &h, msrc, fs.SpecialFile, nil), + "sem": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))), "shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))), "shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))), "shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))), diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go index 29ff004f2..d0c565879 100644 --- a/pkg/sentry/fs/ramfs/socket.go +++ b/pkg/sentry/fs/ramfs/socket.go @@ -20,6 +20,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) @@ -63,7 +64,7 @@ func (s *Socket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint { // GetFile implements fs.FileOperations.GetFile. func (s *Socket) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - return fs.NewFile(ctx, dirent, flags, &socketFileOperations{}), nil + return nil, syserror.ENXIO } // +stateify savable diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index e04cd608d..ad4aea282 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -148,6 +148,10 @@ func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldPare // GetFile implements fs.InodeOperations.GetFile. func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + if fs.IsSocket(d.Inode.StableAttr) { + return nil, syserror.ENXIO + } + if flags.Write { fsmetric.TmpfsOpensW.Increment() } else if flags.Read { diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go index 9009ba3c7..4a555bf72 100644 --- a/pkg/sentry/fsimpl/ext/inode.go +++ b/pkg/sentry/fsimpl/ext/inode.go @@ -200,7 +200,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt } var fd symlinkFD fd.LockFD.Init(&in.locks) - fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}) + if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } return &fd.vfsfd, nil default: panic(fmt.Sprintf("unknown inode type: %T", in.impl)) diff --git a/pkg/sentry/fsimpl/fuse/connection_control.go b/pkg/sentry/fsimpl/fuse/connection_control.go index 1b3459c1d..4ab894965 100644 --- a/pkg/sentry/fsimpl/fuse/connection_control.go +++ b/pkg/sentry/fsimpl/fuse/connection_control.go @@ -84,11 +84,7 @@ func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error { Flags: fuseDefaultInitFlags, } - req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in) - if err != nil { - return err - } - + req := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in) // Since there is no task to block on and FUSE_INIT is the request // to unblock other requests, use nil. return conn.CallAsync(nil, req) diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go index 91d16c1cf..d8b0d7657 100644 --- a/pkg/sentry/fsimpl/fuse/connection_test.go +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -76,10 +76,7 @@ func TestConnectionAbort(t *testing.T) { var futNormal []*futureResponse for i := 0; i < int(numRequests); i++ { - req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) - if err != nil { - t.Fatalf("NewRequest creation failed: %v", err) - } + req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) fut, err := conn.callFutureLocked(task, req) if err != nil { t.Fatalf("callFutureLocked failed: %v", err) @@ -105,10 +102,7 @@ func TestConnectionAbort(t *testing.T) { } // After abort, Call() should return directly with ENOTCONN. - req, err := conn.NewRequest(creds, 0, 0, 0, testObj) - if err != nil { - t.Fatalf("NewRequest creation failed: %v", err) - } + req := conn.NewRequest(creds, 0, 0, 0, testObj) _, err = conn.Call(task, req) if err != syserror.ENOTCONN { t.Fatalf("Incorrect error code received for Call() after connection aborted") diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 95c475a65..bb2d0d31a 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -219,10 +219,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con data: rand.Uint32(), } - req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) - if err != nil { - t.Fatalf("NewRequest creation failed: %v", err) - } + req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) // Queue up a request. // Analogous to Call except it doesn't block on the task. diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go index 8f220a04b..fcc5d9a2a 100644 --- a/pkg/sentry/fsimpl/fuse/directory.go +++ b/pkg/sentry/fsimpl/fuse/directory.go @@ -68,11 +68,7 @@ func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirent } // TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS. - req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in) - if err != nil { - return err - } - + req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in) res, err := fusefs.conn.Call(task, req) if err != nil { return err diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go index 83f2816b7..e138b11f8 100644 --- a/pkg/sentry/fsimpl/fuse/file.go +++ b/pkg/sentry/fsimpl/fuse/file.go @@ -83,12 +83,8 @@ func (fd *fileDescription) Release(ctx context.Context) { opcode = linux.FUSE_RELEASE } kernelTask := kernel.TaskFromContext(ctx) - // ignoring errors and FUSE server reply is analogous to Linux's behavior. - req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in) - if err != nil { - // No way to invoke Call() with an errored request. - return - } + // Ignoring errors and FUSE server reply is analogous to Linux's behavior. + req := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in) // The reply will be ignored since no callback is defined in asyncCallBack(). conn.CallAsync(kernelTask, req) } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index 23e827f90..204d8d143 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -119,7 +119,8 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */) if err != nil { - return nil, nil, err + log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err) + return nil, nil, syserror.EINVAL } kernelTask := kernel.TaskFromContext(ctx) @@ -128,6 +129,9 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, syserror.EINVAL } fuseFDGeneric := kernelTask.GetFileVFS2(int32(deviceDescriptor)) + if fuseFDGeneric == nil { + return nil, nil, syserror.EINVAL + } defer fuseFDGeneric.DecRef(ctx) fuseFD, ok := fuseFDGeneric.Impl().(*DeviceFD) if !ok { @@ -360,12 +364,8 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr in.Flags &= ^uint32(linux.O_TRUNC) } - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in) - if err != nil { - return nil, err - } - // Send the request and receive the reply. + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return nil, err @@ -485,10 +485,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err return syserror.EINVAL } in := linux.FUSEUnlinkIn{Name: name} - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) - if err != nil { - return err - } + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return err @@ -515,11 +512,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) in := linux.FUSERmDirIn{Name: name} - req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) - if err != nil { - return err - } - + req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) res, err := i.fs.conn.Call(task, req) if err != nil { return err @@ -535,10 +528,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) return nil, syserror.EINVAL } - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload) - if err != nil { - return nil, err - } + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return nil, err @@ -574,10 +564,7 @@ func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) { log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context") return "", syserror.EINVAL } - req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{}) - if err != nil { - return "", err - } + req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{}) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { return "", err @@ -680,11 +667,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp GetAttrFlags: flags, Fh: fh, } - req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in) - if err != nil { - return linux.FUSEAttr{}, err - } - + req := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in) res, err := i.fs.conn.Call(task, req) if err != nil { return linux.FUSEAttr{}, err @@ -803,11 +786,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre UID: opts.Stat.UID, GID: opts.Stat.GID, } - req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in) - if err != nil { - return err - } - + req := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in) res, err := conn.Call(task, req) if err != nil { return err diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go index 2d396e84c..23ce91849 100644 --- a/pkg/sentry/fsimpl/fuse/read_write.go +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -79,13 +79,9 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui in.Offset = off + (uint64(pagesRead) << usermem.PageShift) in.Size = pagesCanRead << usermem.PageShift - req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in) - if err != nil { - return nil, 0, err - } - // TODO(gvisor.dev/issue/3247): support async read. + req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in) res, err := fs.conn.Call(t, req) if err != nil { return nil, 0, err @@ -204,11 +200,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, in.Offset = off + uint64(written) in.Size = toWrite - req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in) - if err != nil { - return 0, err - } - + req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in) req.payload = data[written : written+toWrite] // TODO(gvisor.dev/issue/3247): support async write. diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go index 7fa00569b..41d679358 100644 --- a/pkg/sentry/fsimpl/fuse/request_response.go +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -70,6 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) { out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2])) src = src[2:] } + _ = src // Remove unused warning. } // SizeBytes is the size of the payload of the FUSE_INIT response. @@ -104,7 +105,7 @@ type Request struct { } // NewRequest creates a new request that can be sent to the FUSE server. -func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) { +func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) *Request { conn.fd.mu.Lock() defer conn.fd.mu.Unlock() conn.fd.nextOpID += linux.FUSEOpID(reqIDStep) @@ -130,7 +131,7 @@ func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint id: hdr.Unique, hdr: &hdr, data: buf, - }, nil + } } // futureResponse represents an in-flight request, that may or may not have diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 435a21d77..36a3f6810 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -31,6 +31,7 @@ import ( fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs" "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -499,6 +500,10 @@ func (i *inode) open(ctx context.Context, d *kernfs.Dentry, mnt *vfs.Mount, flag fileDescription: fileDescription{inode: i}, termios: linux.DefaultReplicaTermios, } + if task := kernel.TaskFromContext(ctx); task != nil { + fd.fgProcessGroup = task.ThreadGroup().ProcessGroup() + fd.session = fd.fgProcessGroup.Session() + } fd.LockFD.Init(&i.locks) vfsfd := &fd.vfsfd if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 469f3a33d..27b00cf6f 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -16,7 +16,6 @@ package overlay import ( "fmt" - "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -129,25 +128,9 @@ func (d *dentry) copyUpLocked(ctx context.Context) error { return err } defer newFD.DecRef(ctx) - bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size - for { - readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{}) - if readErr != nil && readErr != io.EOF { - cleanupUndoCopyUp() - return readErr - } - total := int64(0) - for total < readN { - writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{}) - total += writeN - if writeErr != nil { - cleanupUndoCopyUp() - return writeErr - } - } - if readErr == io.EOF { - break - } + if _, err := vfs.CopyRegularFileData(ctx, newFD, oldFD); err != nil { + cleanupUndoCopyUp() + return err } d.mapsMu.Lock() defer d.mapsMu.Unlock() diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go index 2b89a7a6d..25c785fd4 100644 --- a/pkg/sentry/fsimpl/overlay/regular_file.go +++ b/pkg/sentry/fsimpl/overlay/regular_file.go @@ -103,8 +103,8 @@ func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescript for e, mask := range fd.lowerWaiters { fd.cachedFD.EventUnregister(e) upperFD.EventRegister(e, mask) - if ready&mask != 0 { - e.Callback.Callback(e) + if m := ready & mask; m != 0 { + e.Callback.Callback(e, m) } } } diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index 0ecb592cf..429733c10 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -164,11 +164,11 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e // and write ends of a newly-created pipe, as for pipe(2) and pipe2(2). // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). -func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) { +func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription, error) { fs := mnt.Filesystem().Impl().(*filesystem) inode := newInode(ctx, fs) var d kernfs.Dentry d.Init(&fs.Filesystem, inode) defer d.DecRef(ctx) - return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags) + return inode.pipe.ReaderWriterPair(ctx, mnt, d.VFSDentry(), flags) } diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index a3780b222..75be6129f 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -57,9 +57,6 @@ func getMM(task *kernel.Task) *mm.MemoryManager { // MemoryManager's users count is incremented, and must be decremented by the // caller when it is no longer in use. func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) { - if task.ExitState() == kernel.TaskExitDead { - return nil, syserror.ESRCH - } var m *mm.MemoryManager task.WithMuLocked(func(t *kernel.Task) { m = t.MemoryManager() @@ -111,9 +108,13 @@ var _ dynamicInode = (*auxvData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if d.task.ExitState() == kernel.TaskExitDead { + return syserror.ESRCH + } m, err := getMMIncRef(d.task) if err != nil { - return err + // Return empty file. + return nil } defer m.DecUsers(ctx) @@ -157,9 +158,13 @@ var _ dynamicInode = (*cmdlineData)(nil) // Generate implements vfs.DynamicBytesSource.Generate. func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error { + if d.task.ExitState() == kernel.TaskExitDead { + return syserror.ESRCH + } m, err := getMMIncRef(d.task) if err != nil { - return err + // Return empty file. + return nil } defer m.DecUsers(ctx) @@ -472,7 +477,7 @@ func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64 } m, err := getMMIncRef(fd.inode.task) if err != nil { - return 0, nil + return 0, err } defer m.DecUsers(ctx) // Buffer the read data because of MM locks diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 7c7afdcfa..25c407d98 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -44,6 +44,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k * return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}), + "sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))), "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)), "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)), "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)), diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go index 10f1452ef..246bd87bc 100644 --- a/pkg/sentry/fsimpl/signalfd/signalfd.go +++ b/pkg/sentry/fsimpl/signalfd/signalfd.go @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package signalfd provides basic signalfd file implementations. package signalfd import ( @@ -98,8 +99,8 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, - PID: uint32(info.Pid()), - UID: uint32(info.Uid()), + PID: uint32(info.PID()), + UID: uint32(info.UID()), Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 59fcff498..a4ad625bb 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -163,7 +163,7 @@ afterSymlink: // verifyChildLocked verifies the hash of child against the already verified // hash of the parent to ensure the child is expected. verifyChild triggers a // sentry panic if unexpected modifications to the file system are detected. In -// noCrashOnVerificationFailure mode it returns a syserror instead. +// ErrorOnViolation mode it returns a syserror instead. // // Preconditions: // * fs.renameMu must be locked. @@ -254,7 +254,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } - fdReader := vfs.FileReadWriteSeeker{ + fdReader := FileReadWriteSeeker{ FD: parentMerkleFD, Ctx: ctx, } @@ -397,7 +397,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry } } - fdReader := vfs.FileReadWriteSeeker{ + fdReader := FileReadWriteSeeker{ FD: fd, Ctx: ctx, } diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index add65bee6..a5171b5ad 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -64,6 +64,10 @@ const ( // tree file for "/foo" is "/.merkle.verity.foo". merklePrefix = ".merkle.verity." + // merkleRootPrefix is the prefix of the Merkle tree root file. This + // needs to be different from merklePrefix to avoid name collision. + merkleRootPrefix = ".merkleroot.verity." + // merkleOffsetInParentXattr is the extended attribute name specifying the // offset of the child hash in its parent's Merkle tree. merkleOffsetInParentXattr = "user.merkle.offset" @@ -88,10 +92,8 @@ const ( ) var ( - // noCrashOnVerificationFailure indicates whether the sandbox should panic - // whenever verification fails. If true, an error is returned instead of - // panicking. This should only be set for tests. - noCrashOnVerificationFailure bool + // action specifies the action towards detected violation. + action ViolationAction // verityMu synchronizes concurrent operations that enable verity and perform // verification checks. @@ -102,6 +104,18 @@ var ( // content. type HashAlgorithm int +// ViolationAction is a type specifying the action when an integrity violation +// is detected. +type ViolationAction int + +const ( + // PanicOnViolation terminates the sentry on detected violation. + PanicOnViolation ViolationAction = 0 + // ErrorOnViolation returns an error from the violating system call on + // detected violation. + ErrorOnViolation = 1 +) + // Currently supported hashing algorithms include SHA256 and SHA512. const ( SHA256 HashAlgorithm = iota @@ -166,7 +180,7 @@ type filesystem struct { // its children. So they shouldn't be enabled the same time. This lock // is for the whole file system to ensure that no more than one file is // enabled the same time. - verityMu sync.RWMutex + verityMu sync.RWMutex `state:"nosave"` } // InternalFilesystemOptions may be passed as @@ -196,10 +210,8 @@ type InternalFilesystemOptions struct { // system wrapped by verity file system. LowerGetFSOptions vfs.GetFilesystemOptions - // NoCrashOnVerificationFailure indicates whether the sandbox should - // panic whenever verification fails. If true, an error is returned - // instead of panicking. This should only be set for tests. - NoCrashOnVerificationFailure bool + // Action specifies the action on an integrity violation. + Action ViolationAction } // Name implements vfs.FilesystemType.Name. @@ -211,10 +223,10 @@ func (FilesystemType) Name() string { func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means -// unexpected modification to the file system is detected. In -// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic. +// unexpected modification to the file system is detected. In ErrorOnViolation +// mode, it returns EIO, otherwise it panic. func alertIntegrityViolation(msg string) error { - if noCrashOnVerificationFailure { + if action == ErrorOnViolation { return syserror.EIO } panic(msg) @@ -227,7 +239,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs") return nil, nil, syserror.EINVAL } - noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure + action = iopts.Action // Mount the lower file system. The lower file system is wrapped inside // verity, and should not be exposed or connected. @@ -255,7 +267,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt lowerVD.IncRef() d.lowerVD = lowerVD - rootMerkleName := merklePrefix + iopts.RootMerkleFileName + rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ Root: lowerVD, @@ -744,20 +756,20 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) // file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The // hash of the generated Merkle tree and the data size is returned. If fd // points to a regular file, the data is the content of the file. If fd points -// to a directory, the data is all hahes of its children, written to the Merkle +// to a directory, the data is all hashes of its children, written to the Merkle // tree file. // // Preconditions: fd.d.fs.verityMu must be locked. func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, uint64, error) { - fdReader := vfs.FileReadWriteSeeker{ + fdReader := FileReadWriteSeeker{ FD: fd.lowerFD, Ctx: ctx, } - merkleReader := vfs.FileReadWriteSeeker{ + merkleReader := FileReadWriteSeeker{ FD: fd.merkleReader, Ctx: ctx, } - merkleWriter := vfs.FileReadWriteSeeker{ + merkleWriter := FileReadWriteSeeker{ FD: fd.merkleWriter, Ctx: ctx, } @@ -1047,12 +1059,12 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } - dataReader := vfs.FileReadWriteSeeker{ + dataReader := FileReadWriteSeeker{ FD: fd.lowerFD, Ctx: ctx, } - merkleReader := vfs.FileReadWriteSeeker{ + merkleReader := FileReadWriteSeeker{ FD: fd.merkleReader, Ctx: ctx, } @@ -1101,3 +1113,45 @@ func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence) } + +// FileReadWriteSeeker is a helper struct to pass a vfs.FileDescription as +// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. +type FileReadWriteSeeker struct { + FD *vfs.FileDescription + Ctx context.Context + ROpts vfs.ReadOptions + WOpts vfs.WriteOptions +} + +// ReadAt implements io.ReaderAt.ReadAt. +func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts) + return int(n), err +} + +// Read implements io.ReadWriteSeeker.Read. +func (f *FileReadWriteSeeker) Read(p []byte) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.Read(f.Ctx, dst, f.ROpts) + return int(n), err +} + +// Seek implements io.ReadWriteSeeker.Seek. +func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) { + return f.FD.Seek(f.Ctx, offset, int32(whence)) +} + +// WriteAt implements io.WriterAt.WriteAt. +func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) { + dst := usermem.BytesIOSequence(p) + n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts) + return int(n), err +} + +// Write implements io.ReadWriteSeeker.Write. +func (f *FileReadWriteSeeker) Write(p []byte) (int, error) { + buf := usermem.BytesIOSequence(p) + n, err := f.FD.Write(f.Ctx, buf, f.WOpts) + return int(n), err +} diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index 5d1f5de08..30d8b4355 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -35,14 +35,16 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// rootMerkleFilename is the name of the root Merkle tree file. -const rootMerkleFilename = "root.verity" +const ( + // rootMerkleFilename is the name of the root Merkle tree file. + rootMerkleFilename = "root.verity" + // maxDataSize is the maximum data size of a test file. + maxDataSize = 100000 +) -// maxDataSize is the maximum data size written to the file for test. -const maxDataSize = 100000 +var hashAlgs = []HashAlgorithm{SHA256, SHA512} -// getD returns a *dentry corresponding to VD. -func getD(t *testing.T, vd vfs.VirtualDentry) *dentry { +func dentryFromVD(t *testing.T, vd vfs.VirtualDentry) *dentry { t.Helper() d, ok := vd.Dentry().Impl().(*dentry) if !ok { @@ -51,10 +53,21 @@ func getD(t *testing.T, vd vfs.VirtualDentry) *dentry { return d } +// dentryFromFD returns the dentry corresponding to fd. +func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry { + t.Helper() + f, ok := fd.Impl().(*fileDescription) + if !ok { + t.Fatalf("can't assert %T as a *fileDescription", fd) + } + return f.d +} + // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { + t.Helper() k, err := testutil.Boot() if err != nil { t.Fatalf("testutil.Boot: %v", err) @@ -79,11 +92,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{ GetFilesystemOptions: vfs.GetFilesystemOptions{ InternalData: InternalFilesystemOptions{ - RootMerkleFileName: rootMerkleFilename, - LowerName: "tmpfs", - Alg: hashAlg, - AllowRuntimeEnable: true, - NoCrashOnVerificationFailure: true, + RootMerkleFileName: rootMerkleFilename, + LowerName: "tmpfs", + Alg: hashAlg, + AllowRuntimeEnable: true, + Action: ErrorOnViolation, }, }, }) @@ -102,7 +115,6 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, t.Fatalf("testutil.CreateTask: %v", err) } - t.Helper() t.Cleanup(func() { root.DecRef(ctx) mntns.DecRef(ctx) @@ -111,6 +123,8 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, } // openVerityAt opens a verity file. +// +// TODO(chongc): release reference from opening the file when done. func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ Root: vd, @@ -123,6 +137,8 @@ func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.Vir } // openLowerAt opens the file in the underlying file system. +// +// TODO(chongc): release reference from opening the file when done. func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ Root: d.lowerVD, @@ -135,6 +151,8 @@ func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, } // openLowerMerkleAt opens the Merkle file in the underlying file system. +// +// TODO(chongc): release reference from opening the file when done. func (d *dentry) openLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) { return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ Root: d.lowerMerkleVD, @@ -190,21 +208,11 @@ func (d *dentry) renameLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFil }, &vfs.RenameOptions{}) } -// getDentry returns a *dentry corresponds to fd. -func getDentry(t *testing.T, fd *vfs.FileDescription) *dentry { - t.Helper() - f, ok := fd.Impl().(*fileDescription) - if !ok { - t.Fatalf("can't assert %T as a *fileDescription", fd) - } - return f.d -} - // newFileFD creates a new file in the verity mount, and returns the FD. The FD // points to a file that has random data generated. func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) { // Create the file in the underlying file system. - lowerFD, err := getD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) + lowerFD, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) if err != nil { return nil, 0, err } @@ -231,9 +239,20 @@ func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, return fd, dataSize, err } -// corruptRandomBit randomly flips a bit in the file represented by fd. -func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { - // Flip a random bit in the underlying file. +// newEmptyFileFD creates a new empty file in the verity mount, and returns the FD. +func newEmptyFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, error) { + // Create the file in the underlying file system. + _, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode) + if err != nil { + return nil, err + } + // Now open the verity file descriptor. + fd, err := openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode) + return fd, err +} + +// flipRandomBit randomly flips a bit in the file represented by fd. +func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error { randomPos := int64(rand.Intn(size)) byteToModify := make([]byte, 1) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil { @@ -246,7 +265,14 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er return nil } -var hashAlgs = []HashAlgorithm{SHA256, SHA512} +func enableVerity(ctx context.Context, t *testing.T, fd *vfs.FileDescription) { + t.Helper() + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("enable verity: %v", err) + } +} // TestOpen ensures that when a file is created, the corresponding Merkle tree // file and the root Merkle tree file exist. @@ -264,12 +290,12 @@ func TestOpen(t *testing.T) { } // Ensure that the corresponding Merkle tree file is created. - if _, err = getDentry(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { + if _, err = dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) } // Ensure the root merkle tree file is created. - if _, err = getD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { + if _, err = dentryFromVD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil { t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) } } @@ -291,11 +317,7 @@ func TestPReadUnmodifiedFileSucceeds(t *testing.T) { } // Enable verity on the file and confirm a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) buf := make([]byte, size) n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) @@ -325,11 +347,7 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) { } // Enable verity on the file and confirm a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) buf := make([]byte, size) n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) @@ -343,6 +361,36 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) { } } +// TestReadUnmodifiedEmptyFileSucceeds ensures that read from an untouched empty verity +// file succeeds after enabling verity for it. +func TestReadUnmodifiedEmptyFileSucceeds(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-empty-file" + fd, err := newEmptyFileFD(ctx, t, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newEmptyFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + enableVerity(ctx, t, fd) + + var buf []byte + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != 0 { + t.Errorf("fd.Read got read length %d, expected 0", n) + } + } +} + // TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file // succeeds after enabling verity for it. func TestReopenUnmodifiedFileSucceeds(t *testing.T) { @@ -359,11 +407,7 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) { } // Enable verity on the file and confirms a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Ensure reopening the verity enabled file succeeds. if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != nil { @@ -387,21 +431,14 @@ func TestOpenNonexistentFile(t *testing.T) { } // Enable verity on the file and confirms a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Enable verity on the parent directory. parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - - if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, parentFD) // Ensure open an unexpected file in the parent directory fails with // ENOENT rather than verification failure. @@ -426,20 +463,16 @@ func TestPReadModifiedFileFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Open a new lowerFD that's read/writable. - lowerFD, err := getDentry(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) + lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("flipRandomBit: %v", err) } // Confirm that read from the modified file fails. @@ -466,20 +499,16 @@ func TestReadModifiedFileFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Open a new lowerFD that's read/writable. - lowerFD, err := getDentry(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) + lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("flipRandomBit: %v", err) } // Confirm that read from the modified file fails. @@ -506,14 +535,10 @@ func TestModifiedMerkleFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Open a new lowerMerkleFD that's read/writable. - lowerMerkleFD, err := getDentry(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) + lowerMerkleFD, err := dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } @@ -524,14 +549,13 @@ func TestModifiedMerkleFails(t *testing.T) { t.Errorf("lowerMerkleFD.Stat: %v", err) } - if err := corruptRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil { + t.Fatalf("flipRandomBit: %v", err) } // Confirm that read from a file with modified Merkle tree fails. buf := make([]byte, size) if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - fmt.Println(buf) t.Fatalf("fd.PRead succeeded with modified Merkle file") } } @@ -554,24 +578,17 @@ func TestModifiedParentMerkleFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) // Enable verity on the parent directory. parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } - - if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, parentFD) // Open a new lowerMerkleFD that's read/writable. - parentLowerMerkleFD, err := getDentry(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) + parentLowerMerkleFD, err := dentryFromFD(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular) if err != nil { t.Fatalf("OpenAt: %v", err) } @@ -591,8 +608,8 @@ func TestModifiedParentMerkleFails(t *testing.T) { if err != nil { t.Fatalf("Failed convert size to int: %v", err) } - if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) + if err := flipRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { + t.Fatalf("flipRandomBit: %v", err) } parentLowerMerkleFD.DecRef(ctx) @@ -619,13 +636,8 @@ func TestUnmodifiedStatSucceeds(t *testing.T) { t.Fatalf("newFileFD: %v", err) } - // Enable verity on the file and confirms stat succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } - + // Enable verity on the file and confirm that stat succeeds. + enableVerity(ctx, t, fd) if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { t.Errorf("fd.Stat: %v", err) } @@ -648,11 +660,7 @@ func TestModifiedStatFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } + enableVerity(ctx, t, fd) lowerFD := fd.Impl().(*fileDescription).lowerFD // Change the stat of the underlying file, and check that stat fails. @@ -711,19 +719,15 @@ func TestOpenDeletedFileFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) if tc.changeFile { - if err := getD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil { + if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil { t.Fatalf("UnlinkAt: %v", err) } } if tc.changeMerkleFile { - if err := getD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil { + if err := dentryFromVD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil { t.Fatalf("UnlinkAt: %v", err) } } @@ -776,20 +780,16 @@ func TestOpenRenamedFileFails(t *testing.T) { } // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } + enableVerity(ctx, t, fd) newFilename := "renamed-test-file" if tc.changeFile { - if err := getD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil { + if err := dentryFromVD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil { t.Fatalf("RenameAt: %v", err) } } if tc.changeMerkleFile { - if err := getD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil { + if err := dentryFromVD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil { t.Fatalf("UnlinkAt: %v", err) } } diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go index 15519f0df..61aeca044 100644 --- a/pkg/sentry/kernel/epoll/epoll.go +++ b/pkg/sentry/kernel/epoll/epoll.go @@ -273,7 +273,7 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent { // // Callback is called when one of the files we're polling becomes ready. It // moves said file to the readyList if it's currently in the waiting list. -func (p *pollEntry) Callback(*waiter.Entry) { +func (p *pollEntry) Callback(*waiter.Entry, waiter.EventMask) { e := p.epoll e.listsMu.Lock() @@ -306,9 +306,8 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) { f.EventRegister(&entry.waiter, entry.mask) // Check if the file happens to already be in a ready state. - ready := f.Readiness(entry.mask) & entry.mask - if ready != 0 { - entry.Callback(&entry.waiter) + if ready := f.Readiness(entry.mask) & entry.mask; ready != 0 { + entry.Callback(&entry.waiter, ready) } } diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD index 2b3955598..f855f038b 100644 --- a/pkg/sentry/kernel/fasync/BUILD +++ b/pkg/sentry/kernel/fasync/BUILD @@ -8,11 +8,13 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/sentry/arch", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/vfs", "//pkg/sync", + "//pkg/syserror", "//pkg/waiter", ], ) diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go index 153d2cd9b..b66d61c6f 100644 --- a/pkg/sentry/kernel/fasync/fasync.go +++ b/pkg/sentry/kernel/fasync/fasync.go @@ -17,22 +17,45 @@ package fasync import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/waiter" ) -// New creates a new fs.FileAsync. -func New() fs.FileAsync { - return &FileAsync{} +// Table to convert waiter event masks into si_band siginfo codes. +// Taken from fs/fcntl.c:band_table. +var bandTable = map[waiter.EventMask]int64{ + // POLL_IN + waiter.EventIn: linux.EPOLLIN | linux.EPOLLRDNORM, + // POLL_OUT + waiter.EventOut: linux.EPOLLOUT | linux.EPOLLWRNORM | linux.EPOLLWRBAND, + // POLL_ERR + waiter.EventErr: linux.EPOLLERR, + // POLL_PRI + waiter.EventPri: linux.EPOLLPRI | linux.EPOLLRDBAND, + // POLL_HUP + waiter.EventHUp: linux.EPOLLHUP | linux.EPOLLERR, } -// NewVFS2 creates a new vfs.FileAsync. -func NewVFS2() vfs.FileAsync { - return &FileAsync{} +// New returns a function that creates a new fs.FileAsync with the given file +// descriptor. +func New(fd int) func() fs.FileAsync { + return func() fs.FileAsync { + return &FileAsync{fd: fd} + } +} + +// NewVFS2 returns a function that creates a new vfs.FileAsync with the given +// file descriptor. +func NewVFS2(fd int) func() vfs.FileAsync { + return func() vfs.FileAsync { + return &FileAsync{fd: fd} + } } // FileAsync sends signals when the registered file is ready for IO. @@ -42,6 +65,12 @@ type FileAsync struct { // e is immutable after first use (which is protected by mu below). e waiter.Entry + // fd is the file descriptor to notify about. + // It is immutable, set at allocation time. This matches Linux semantics in + // fs/fcntl.c:fasync_helper. + // The fd value is passed to the signal recipient in siginfo.si_fd. + fd int + // regMu protects registeration and unregistration actions on e. // // regMu must be held while registration decisions are being made @@ -56,6 +85,10 @@ type FileAsync struct { mu sync.Mutex `state:"nosave"` requester *auth.Credentials registered bool + // signal is the signal to deliver upon I/O being available. + // The default value ("zero signal") means the default SIGIO signal will be + // delivered. + signal linux.Signal // Only one of the following is allowed to be non-nil. recipientPG *kernel.ProcessGroup @@ -64,10 +97,10 @@ type FileAsync struct { } // Callback sends a signal. -func (a *FileAsync) Callback(e *waiter.Entry) { +func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) { a.mu.Lock() + defer a.mu.Unlock() if !a.registered { - a.mu.Unlock() return } t := a.recipientT @@ -80,19 +113,34 @@ func (a *FileAsync) Callback(e *waiter.Entry) { } if t == nil { // No recipient has been registered. - a.mu.Unlock() return } c := t.Credentials() // Logic from sigio_perm in fs/fcntl.c. - if a.requester.EffectiveKUID == 0 || + permCheck := (a.requester.EffectiveKUID == 0 || a.requester.EffectiveKUID == c.SavedKUID || a.requester.EffectiveKUID == c.RealKUID || a.requester.RealKUID == c.SavedKUID || - a.requester.RealKUID == c.RealKUID { - t.SendSignal(kernel.SignalInfoPriv(linux.SIGIO)) + a.requester.RealKUID == c.RealKUID) + if !permCheck { + return } - a.mu.Unlock() + signalInfo := &arch.SignalInfo{ + Signo: int32(linux.SIGIO), + Code: arch.SignalInfoKernel, + } + if a.signal != 0 { + signalInfo.Signo = int32(a.signal) + signalInfo.SetFD(uint32(a.fd)) + var band int64 + for m, bandCode := range bandTable { + if m&mask != 0 { + band |= bandCode + } + } + signalInfo.SetBand(band) + } + t.SendSignal(signalInfo) } // Register sets the file which will be monitored for IO events. @@ -186,3 +234,25 @@ func (a *FileAsync) ClearOwner() { a.recipientTG = nil a.recipientPG = nil } + +// Signal returns which signal will be sent to the signal recipient. +// A value of zero means the signal to deliver wasn't customized, which means +// the default signal (SIGIO) will be delivered. +func (a *FileAsync) Signal() linux.Signal { + a.mu.Lock() + defer a.mu.Unlock() + return a.signal +} + +// SetSignal overrides which signal to send when I/O is available. +// The default behavior can be reset by specifying signal zero, which means +// to send SIGIO. +func (a *FileAsync) SetSignal(signal linux.Signal) error { + if signal != 0 && !signal.IsValid() { + return syserror.EINVAL + } + a.mu.Lock() + defer a.mu.Unlock() + a.signal = signal + return nil +} diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index 470d8bf83..f17f9c59c 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -121,18 +121,21 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 panic("VFS1 and VFS2 files set") } - slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) + slicePtr := (*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice)) // Grow the table as required. - if last := int32(len(slice)); fd >= last { + if last := int32(len(*slicePtr)); fd >= last { end := fd + 1 if end < 2*last { end = 2 * last } - slice = append(slice, make([]unsafe.Pointer, end-last)...) - atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) + newSlice := append(*slicePtr, make([]unsafe.Pointer, end-last)...) + slicePtr = &newSlice + atomic.StorePointer(&f.slice, unsafe.Pointer(slicePtr)) } + slice := *slicePtr + var desc *descriptor if file != nil || fileVFS2 != nil { desc = &descriptor{ diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 2cdcdfc1f..b8627a54f 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -214,9 +214,11 @@ type Kernel struct { // netlinkPorts manages allocation of netlink socket port IDs. netlinkPorts *port.Manager - // saveErr is the error causing the sandbox to exit during save, if - // any. It is protected by extMu. - saveErr error `state:"nosave"` + // saveStatus is nil if the sandbox has not been saved, errSaved or + // errAutoSaved if it has been saved successfully, or the error causing the + // sandbox to exit during save. + // It is protected by extMu. + saveStatus error `state:"nosave"` // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` @@ -1481,12 +1483,42 @@ func (k *Kernel) NetlinkPorts() *port.Manager { return k.netlinkPorts } -// SaveError returns the sandbox error that caused the kernel to exit during -// save. -func (k *Kernel) SaveError() error { +var ( + errSaved = errors.New("sandbox has been successfully saved") + errAutoSaved = errors.New("sandbox has been successfully auto-saved") +) + +// SaveStatus returns the sandbox save status. If it was saved successfully, +// autosaved indicates whether save was triggered by autosave. If it was not +// saved successfully, err indicates the sandbox error that caused the kernel to +// exit during save. +func (k *Kernel) SaveStatus() (saved, autosaved bool, err error) { + k.extMu.Lock() + defer k.extMu.Unlock() + switch k.saveStatus { + case nil: + return false, false, nil + case errSaved: + return true, false, nil + case errAutoSaved: + return true, true, nil + default: + return false, false, k.saveStatus + } +} + +// SetSaveSuccess sets the flag indicating that save completed successfully, if +// no status was already set. +func (k *Kernel) SetSaveSuccess(autosave bool) { k.extMu.Lock() defer k.extMu.Unlock() - return k.saveErr + if k.saveStatus == nil { + if autosave { + k.saveStatus = errAutoSaved + } else { + k.saveStatus = errSaved + } + } } // SetSaveError sets the sandbox error that caused the kernel to exit during @@ -1494,8 +1526,8 @@ func (k *Kernel) SaveError() error { func (k *Kernel) SetSaveError(err error) { k.extMu.Lock() defer k.extMu.Unlock() - if k.saveErr == nil { - k.saveErr = err + if k.saveStatus == nil { + k.saveStatus = err } } diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 7b23cbe86..2d47d2e82 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -63,10 +63,19 @@ func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe { // ReaderWriterPair returns read-only and write-only FDs for vp. // // Preconditions: statusFlags should not contain an open access mode. -func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) { +func (vp *VFSPipe) ReaderWriterPair(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription, error) { // Connected pipes share the same locks. locks := &vfs.FileLocks{} - return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) + r, err := vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks) + if err != nil { + return nil, nil, err + } + w, err := vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks) + if err != nil { + r.DecRef(ctx) + return nil, nil, err + } + return r, w, nil } // Allocate implements vfs.FileDescriptionImpl.Allocate. @@ -85,7 +94,10 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s return nil, syserror.EINVAL } - fd := vp.newFD(mnt, vfsd, statusFlags, locks) + fd, err := vp.newFD(mnt, vfsd, statusFlags, locks) + if err != nil { + return nil, err + } // Named pipes have special blocking semantics during open: // @@ -137,16 +149,18 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s } // Preconditions: vp.mu must be held. -func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) *vfs.FileDescription { +func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) { fd := &VFSPipeFD{ pipe: &vp.pipe, } fd.LockFD.Init(locks) - fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{ + if err := fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{ DenyPRead: true, DenyPWrite: true, UseDentryMetadata: true, - }) + }); err != nil { + return nil, err + } switch { case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable(): @@ -160,7 +174,7 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l panic("invalid pipe flags: must be readable, writable, or both") } - return &fd.vfsfd + return &fd.vfsfd, nil } // VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 1abfe2201..cef58a590 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -259,8 +259,8 @@ func (t *Task) ptraceTrapLocked(code int32) { Signo: int32(linux.SIGTRAP), Code: code, } - t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t])) - t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t])) + t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) if t.beginPtraceStopLocked() { tracer := t.Tracer() tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP)) diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index b99c0bffa..db01e4a97 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -29,17 +29,17 @@ import ( ) const ( - valueMax = 32767 // SEMVMX + // Maximum semaphore value. + valueMax = linux.SEMVMX - // semaphoresMax is "maximum number of semaphores per semaphore ID" (SEMMSL). - semaphoresMax = 32000 + // Maximum number of semaphore sets. + setsMax = linux.SEMMNI - // setMax is "system-wide limit on the number of semaphore sets" (SEMMNI). - setsMax = 32000 + // Maximum number of semaphroes in a semaphore set. + semsMax = linux.SEMMSL - // semaphoresTotalMax is "system-wide limit on the number of semaphores" - // (SEMMNS = SEMMNI*SEMMSL). - semaphoresTotalMax = 1024000000 + // Maximum number of semaphores in all semaphroe sets. + semsTotalMax = linux.SEMMNS ) // Registry maintains a set of semaphores that can be found by key or ID. @@ -52,6 +52,9 @@ type Registry struct { mu sync.Mutex `state:"nosave"` semaphores map[int32]*Set lastIDUsed int32 + // indexes maintains a mapping between a set's index in virtual array and + // its identifier. + indexes map[int32]int32 } // Set represents a set of semaphores that can be operated atomically. @@ -113,6 +116,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { return &Registry{ userNS: userNS, semaphores: make(map[int32]*Set), + indexes: make(map[int32]int32), } } @@ -122,7 +126,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry { // be found. If exclusive is true, it fails if a set with the same key already // exists. func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) { - if nsems < 0 || nsems > semaphoresMax { + if nsems < 0 || nsems > semsMax { return nil, syserror.EINVAL } @@ -163,10 +167,13 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu } // Apply system limits. + // + // Map semaphores and map indexes in a registry are of the same size, + // check map semaphores only here for the system limit. if len(r.semaphores) >= setsMax { return nil, syserror.EINVAL } - if r.totalSems() > int(semaphoresTotalMax-nsems) { + if r.totalSems() > int(semsTotalMax-nsems) { return nil, syserror.EINVAL } @@ -176,6 +183,53 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu return r.newSet(ctx, key, owner, owner, perms, nsems) } +// IPCInfo returns information about system-wide semaphore limits and parameters. +func (r *Registry) IPCInfo() *linux.SemInfo { + return &linux.SemInfo{ + SemMap: linux.SEMMAP, + SemMni: linux.SEMMNI, + SemMns: linux.SEMMNS, + SemMnu: linux.SEMMNU, + SemMsl: linux.SEMMSL, + SemOpm: linux.SEMOPM, + SemUme: linux.SEMUME, + SemUsz: linux.SEMUSZ, + SemVmx: linux.SEMVMX, + SemAem: linux.SEMAEM, + } +} + +// SemInfo returns a seminfo structure containing the same information as +// for IPC_INFO, except that SemUsz field returns the number of existing +// semaphore sets, and SemAem field returns the number of existing semaphores. +func (r *Registry) SemInfo() *linux.SemInfo { + r.mu.Lock() + defer r.mu.Unlock() + + info := r.IPCInfo() + info.SemUsz = uint32(len(r.semaphores)) + info.SemAem = uint32(r.totalSems()) + + return info +} + +// HighestIndex returns the index of the highest used entry in +// the kernel's array. +func (r *Registry) HighestIndex() int32 { + r.mu.Lock() + defer r.mu.Unlock() + + // By default, highest used index is 0 even though + // there is no semaphroe set. + var highestIndex int32 + for index := range r.indexes { + if index > highestIndex { + highestIndex = index + } + } + return highestIndex +} + // RemoveID removes set with give 'id' from the registry and marks the set as // dead. All waiters will be awakened and fail. func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { @@ -186,6 +240,11 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { if set == nil { return syserror.EINVAL } + index, found := r.findIndexByID(id) + if !found { + // Inconsistent state. + panic(fmt.Sprintf("unable to find an index for ID: %d", id)) + } set.mu.Lock() defer set.mu.Unlock() @@ -197,6 +256,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error { } delete(r.semaphores, set.ID) + delete(r.indexes, index) set.destroy() return nil } @@ -220,6 +280,11 @@ func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.File continue } if r.semaphores[id] == nil { + index, found := r.findFirstAvailableIndex() + if !found { + panic("unable to find an available index") + } + r.indexes[index] = id r.lastIDUsed = id r.semaphores[id] = set set.ID = id @@ -238,6 +303,18 @@ func (r *Registry) FindByID(id int32) *Set { return r.semaphores[id] } +// FindByIndex looks up a set given an index. +func (r *Registry) FindByIndex(index int32) *Set { + r.mu.Lock() + defer r.mu.Unlock() + + id, present := r.indexes[index] + if !present { + return nil + } + return r.semaphores[id] +} + func (r *Registry) findByKey(key int32) *Set { for _, v := range r.semaphores { if v.key == key { @@ -247,6 +324,24 @@ func (r *Registry) findByKey(key int32) *Set { return nil } +func (r *Registry) findIndexByID(id int32) (int32, bool) { + for k, v := range r.indexes { + if v == id { + return k, true + } + } + return 0, false +} + +func (r *Registry) findFirstAvailableIndex() (int32, bool) { + for index := int32(0); index < setsMax; index++ { + if _, present := r.indexes[index]; !present { + return index, true + } + } + return 0, false +} + func (r *Registry) totalSems() int { totalSems := 0 for _, v := range r.semaphores { diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index 80a592c8f..073e14507 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -6,6 +6,9 @@ package(licenses = ["notice"]) go_template_instance( name = "shm_refs", out = "shm_refs.go", + consts = { + "enableLogging": "true", + }, package = "shm", prefix = "Shm", template = "//pkg/refsvfs2:refs_template", diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go index e8cce37d0..2488ae7d5 100644 --- a/pkg/sentry/kernel/signal.go +++ b/pkg/sentry/kernel/signal.go @@ -73,7 +73,7 @@ func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) - info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg))) + info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) return info } diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go index 78f718cfe..884966120 100644 --- a/pkg/sentry/kernel/signalfd/signalfd.go +++ b/pkg/sentry/kernel/signalfd/signalfd.go @@ -106,8 +106,8 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS Signo: uint32(info.Signo), Errno: info.Errno, Code: info.Code, - PID: uint32(info.Pid()), - UID: uint32(info.Uid()), + PID: uint32(info.PID()), + UID: uint32(info.UID()), Status: info.Status(), Overrun: uint32(info.Overrun()), Addr: info.Addr(), diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index c5137c282..16986244c 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -368,8 +368,8 @@ func (t *Task) exitChildren() { Signo: int32(sig), Code: arch.SignalInfoUser, } - siginfo.SetPid(int32(c.tg.pidns.tids[t])) - siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) + siginfo.SetPID(int32(c.tg.pidns.tids[t])) + siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow())) c.tg.signalHandlers.mu.Lock() c.sendSignalLocked(siginfo, true /* group */) c.tg.signalHandlers.mu.Unlock() @@ -698,8 +698,8 @@ func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.Si info := &arch.SignalInfo{ Signo: int32(sig), } - info.SetPid(int32(receiver.tg.pidns.tids[t])) - info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.tg.pidns.tids[t])) + info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) if t.exitStatus.Signaled() { info.Code = arch.CLD_KILLED info.SetStatus(int32(t.exitStatus.Signo)) diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index 42dd3e278..75af3af79 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -914,8 +914,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) { Signo: int32(linux.SIGCHLD), Code: code, } - sigchld.SetPid(int32(t.tg.pidns.tids[target])) - sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + sigchld.SetPID(int32(t.tg.pidns.tids[target])) + sigchld.SetUID(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) sigchld.SetStatus(status) // TODO(b/72102453): Set utime, stime. t.sendSignalLocked(sigchld, true /* group */) @@ -1022,8 +1022,8 @@ func (*runInterrupt) execute(t *Task) taskRunState { Signo: int32(sig), Code: t.ptraceCode, } - t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t])) - t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t])) + t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) } else { t.ptraceCode = int32(sig) t.ptraceSiginfo = nil @@ -1114,11 +1114,11 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState { if parent == nil { // Tracer has detached and t was created by Kernel.CreateProcess(). // Pretend the parent is in an ancestor PID + user namespace. - info.SetPid(0) - info.SetUid(int32(auth.OverflowUID)) + info.SetPID(0) + info.SetUID(int32(auth.OverflowUID)) } else { - info.SetPid(int32(t.tg.pidns.tids[parent])) - info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) + info.SetPID(int32(t.tg.pidns.tids[parent])) + info.SetUID(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow())) } } t.tg.signalHandlers.mu.Lock() diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go index 7fd77925f..49e21026e 100644 --- a/pkg/sentry/memmap/memmap.go +++ b/pkg/sentry/memmap/memmap.go @@ -160,7 +160,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp // Translations must be contiguous and in increasing order of // Translation.Source. if i > 0 && ts[i-1].Source.End != t.Source.Start { - return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t) + return fmt.Errorf("Translation %+v and Translation %+v are not contiguous", ts[i-1], t) } // At least part of each Translation must be required. if t.Source.Intersect(required).Length() == 0 { diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go index 4c8cd38ed..5ab2ef79f 100644 --- a/pkg/sentry/mm/aio_context.go +++ b/pkg/sentry/mm/aio_context.go @@ -36,12 +36,12 @@ type aioManager struct { contexts map[uint64]*AIOContext } -func (a *aioManager) destroy() { - a.mu.Lock() - defer a.mu.Unlock() +func (mm *MemoryManager) destroyAIOManager(ctx context.Context) { + mm.aioManager.mu.Lock() + defer mm.aioManager.mu.Unlock() - for _, ctx := range a.contexts { - ctx.destroy() + for id := range mm.aioManager.contexts { + mm.destroyAIOContextLocked(ctx, id) } } @@ -68,16 +68,26 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool { // be drained. // // Nil is returned if the context does not exist. -func (a *aioManager) destroyAIOContext(id uint64) *AIOContext { - a.mu.Lock() - defer a.mu.Unlock() - ctx, ok := a.contexts[id] +// +// Precondition: mm.aioManager.mu is locked. +func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) *AIOContext { + aioCtx, ok := mm.aioManager.contexts[id] if !ok { return nil } - delete(a.contexts, id) - ctx.destroy() - return ctx + + // Only unmaps after it assured that the address is a valid aio context to + // prevent random memory from been unmapped. + // + // Note: It's possible to unmap this address and map something else into + // the same address. Then it would be unmapping memory that it doesn't own. + // This is, however, the way Linux implements AIO. Keeps the same [weird] + // semantics in case anyone relies on it. + mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) + + delete(mm.aioManager.contexts, id) + aioCtx.destroy() + return aioCtx } // lookupAIOContext looks up the given context. @@ -140,16 +150,21 @@ func (ctx *AIOContext) checkForDone() { } } -// Prepare reserves space for a new request, returning true if available. -// Returns false if the context is busy. -func (ctx *AIOContext) Prepare() bool { +// Prepare reserves space for a new request, returning nil if available. +// Returns EAGAIN if the context is busy and EINVAL if the context is dead. +func (ctx *AIOContext) Prepare() error { ctx.mu.Lock() defer ctx.mu.Unlock() + if ctx.dead { + // Context died after the caller looked it up. + return syserror.EINVAL + } if ctx.outstanding >= ctx.maxOutstanding { - return false + // Context is busy. + return syserror.EAGAIN } ctx.outstanding++ - return true + return nil } // PopRequest pops a completed request if available, this function does not do @@ -391,20 +406,13 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint // DestroyAIOContext destroys an asynchronous I/O context. It returns the // destroyed context. nil if the context does not exist. func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext { - if _, ok := mm.LookupAIOContext(ctx, id); !ok { + if !mm.isValidAddr(ctx, id) { return nil } - // Only unmaps after it assured that the address is a valid aio context to - // prevent random memory from been unmapped. - // - // Note: It's possible to unmap this address and map something else into - // the same address. Then it would be unmapping memory that it doesn't own. - // This is, however, the way Linux implements AIO. Keeps the same [weird] - // semantics in case anyone relies on it. - mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize) - - return mm.aioManager.destroyAIOContext(id) + mm.aioManager.mu.Lock() + defer mm.aioManager.mu.Unlock() + return mm.destroyAIOContextLocked(ctx, id) } // LookupAIOContext looks up the given context. It returns false if the context @@ -415,13 +423,18 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC return nil, false } - // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes - // from id). - var buf [4]byte - _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) - if err != nil { + // Protect against 'id' that is inaccessible. + if !mm.isValidAddr(ctx, id) { return nil, false } return aioCtx, true } + +// isValidAddr determines if the address `id` is valid. (Linux also reads 4 +// bytes from id). +func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool { + var buf [4]byte + _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{}) + return err == nil +} diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go index 3dabac1af..e8931922f 100644 --- a/pkg/sentry/mm/aio_context_state.go +++ b/pkg/sentry/mm/aio_context_state.go @@ -15,6 +15,6 @@ package mm // afterLoad is invoked by stateify. -func (a *AIOContext) afterLoad() { - a.requestReady = make(chan struct{}, 1) +func (ctx *AIOContext) afterLoad() { + ctx.requestReady = make(chan struct{}, 1) } diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go index 09dbc06a4..120707429 100644 --- a/pkg/sentry/mm/lifecycle.go +++ b/pkg/sentry/mm/lifecycle.go @@ -253,7 +253,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) { panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users)) } - mm.aioManager.destroy() + mm.destroyAIOManager(ctx) mm.metadataMu.Lock() exe := mm.executable diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go index acac3d357..bc53bd41e 100644 --- a/pkg/sentry/mm/mm_test.go +++ b/pkg/sentry/mm/mm_test.go @@ -229,3 +229,46 @@ func TestIOAfterMProtect(t *testing.T) { t.Errorf("CopyOut got %d want 1", n) } } + +// TestAIOPrepareAfterDestroy tests that AIOContext should not be able to be +// prepared after destruction. +func TestAIOPrepareAfterDestroy(t *testing.T) { + ctx := contexttest.Context(t) + mm := testMemoryManager(ctx) + defer mm.DecUsers(ctx) + + id, err := mm.NewAIOContext(ctx, 1) + if err != nil { + t.Fatalf("mm.NewAIOContext got err %v want nil", err) + } + aioCtx, ok := mm.LookupAIOContext(ctx, id) + if !ok { + t.Fatalf("AIOContext not found") + } + mm.DestroyAIOContext(ctx, id) + + // Prepare should fail because aioCtx should be destroyed. + if err := aioCtx.Prepare(); err != syserror.EINVAL { + t.Errorf("aioCtx.Prepare got err %v want nil", err) + } else if err == nil { + aioCtx.CancelPendingRequest() + } +} + +// TestAIOLookupAfterDestroy tests that AIOContext should not be able to be +// looked up after memory manager is destroyed. +func TestAIOLookupAfterDestroy(t *testing.T) { + ctx := contexttest.Context(t) + mm := testMemoryManager(ctx) + + id, err := mm.NewAIOContext(ctx, 1) + if err != nil { + mm.DecUsers(ctx) + t.Fatalf("mm.NewAIOContext got err %v want nil", err) + } + mm.DecUsers(ctx) // This destroys the AIOContext manager. + + if _, ok := mm.LookupAIOContext(ctx, id); ok { + t.Errorf("AIOContext found even after AIOContext manager is destroyed") + } +} diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index 7c297fb9e..d99be7f46 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -423,11 +423,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File } if f.opts.ManualZeroing { - if err := f.forEachMappingSlice(fr, func(bs []byte) { - for i := range bs { - bs[i] = 0 - } - }); err != nil { + if err := f.manuallyZero(fr); err != nil { return memmap.FileRange{}, err } } @@ -560,19 +556,39 @@ func (f *MemoryFile) Decommit(fr memmap.FileRange) error { panic(fmt.Sprintf("invalid range: %v", fr)) } + if f.opts.ManualZeroing { + // FALLOC_FL_PUNCH_HOLE may not zero pages if ManualZeroing is in + // effect. + if err := f.manuallyZero(fr); err != nil { + return err + } + } else { + if err := f.decommitFile(fr); err != nil { + return err + } + } + + f.markDecommitted(fr) + return nil +} + +func (f *MemoryFile) manuallyZero(fr memmap.FileRange) error { + return f.forEachMappingSlice(fr, func(bs []byte) { + for i := range bs { + bs[i] = 0 + } + }) +} + +func (f *MemoryFile) decommitFile(fr memmap.FileRange) error { // "After a successful call, subsequent reads from this range will // return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with // FALLOC_FL_KEEP_SIZE in mode ..." - fallocate(2) - err := syscall.Fallocate( + return syscall.Fallocate( int(f.file.Fd()), _FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE, int64(fr.Start), int64(fr.Length())) - if err != nil { - return err - } - f.markDecommitted(fr) - return nil } func (f *MemoryFile) markDecommitted(fr memmap.FileRange) { @@ -1044,20 +1060,20 @@ func (f *MemoryFile) runReclaim() { break } - if err := f.Decommit(fr); err != nil { - log.Warningf("Reclaim failed to decommit %v: %v", fr, err) - // Zero the pages manually. This won't reduce memory usage, but at - // least ensures that the pages will be zero when reallocated. - f.forEachMappingSlice(fr, func(bs []byte) { - for i := range bs { - bs[i] = 0 + // If ManualZeroing is in effect, pages will be zeroed on allocation + // and may not be freed by decommitFile, so calling decommitFile is + // unnecessary. + if !f.opts.ManualZeroing { + if err := f.decommitFile(fr); err != nil { + log.Warningf("Reclaim failed to decommit %v: %v", fr, err) + // Zero the pages manually. This won't reduce memory usage, but at + // least ensures that the pages will be zero when reallocated. + if err := f.manuallyZero(fr); err != nil { + panic(fmt.Sprintf("Reclaim failed to decommit or zero %v: %v", fr, err)) } - }) - // Pretend the pages were decommitted even though they weren't, - // since the memory accounting implementation has no idea how to - // deal with this. - f.markDecommitted(fr) + } } + f.markDecommitted(fr) f.markReclaimed(fr) } diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index acad4c793..f8ccb7430 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -91,6 +91,13 @@ func bluepillSigBus(c *vCPU) { } } +// bluepillHandleEnosys is reponsible for handling enosys error. +// +//go:nosplit +func bluepillHandleEnosys(c *vCPU) { + throw("run failed: ENOSYS") +} + // bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection. // //go:nosplit @@ -126,3 +133,10 @@ func bluepillReadyStopGuest(c *vCPU) bool { } return true } + +// bluepillArchHandleExit checks architecture specific exitcode. +// +//go:nosplit +func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) { + c.die(bluepillArchContext(context), "unknown") +} diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 965ad66b5..1f09813ba 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -42,6 +42,13 @@ var ( sErrEsr: _ESR_ELx_SERR_NMI, }, } + + // vcpuExtDabt is the event of ext_dabt. + vcpuExtDabt = kvmVcpuEvents{ + exception: exception{ + extDabtPending: 1, + }, + } ) // getTLS returns the value of TPIDR_EL0 register. diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index 9433d4da5..4d912769a 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -85,7 +85,7 @@ func bluepillStopGuest(c *vCPU) { uintptr(c.fd), _KVM_SET_VCPU_EVENTS, uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 { - throw("sErr injection failed") + throw("bounce sErr injection failed") } } @@ -93,18 +93,54 @@ func bluepillStopGuest(c *vCPU) { // //go:nosplit func bluepillSigBus(c *vCPU) { + // Host must support ARM64_HAS_RAS_EXTN. if _, _, errno := syscall.RawSyscall( // escapes: no. syscall.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 { - throw("sErr injection failed") + if errno == syscall.EINVAL { + throw("No ARM64_HAS_RAS_EXTN feature in host.") + } + throw("nmi sErr injection failed") } } +// bluepillExtDabt is reponsible for injecting external data abort. +// +//go:nosplit +func bluepillExtDabt(c *vCPU) { + if _, _, errno := syscall.RawSyscall( // escapes: no. + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_VCPU_EVENTS, + uintptr(unsafe.Pointer(&vcpuExtDabt))); errno != 0 { + throw("ext_dabt injection failed") + } +} + +// bluepillHandleEnosys is reponsible for handling enosys error. +// +//go:nosplit +func bluepillHandleEnosys(c *vCPU) { + bluepillExtDabt(c) +} + // bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection. // //go:nosplit func bluepillReadyStopGuest(c *vCPU) bool { return true } + +// bluepillArchHandleExit checks architecture specific exitcode. +// +//go:nosplit +func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) { + switch c.runData.exitReason { + case _KVM_EXIT_ARM_NISV: + bluepillExtDabt(c) + default: + c.die(bluepillArchContext(context), "unknown") + } +} diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index 75085ac6a..8c5369377 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -148,6 +148,9 @@ func bluepillHandler(context unsafe.Pointer) { // mode and have interrupts disabled. bluepillSigBus(c) continue // Rerun vCPU. + case syscall.ENOSYS: + bluepillHandleEnosys(c) + continue default: throw("run failed") } @@ -220,7 +223,7 @@ func bluepillHandler(context unsafe.Pointer) { c.die(bluepillArchContext(context), "entry failed") return default: - c.die(bluepillArchContext(context), "unknown") + bluepillArchHandleExit(c, context) return } } diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go index 0b06a923a..9db1db4e9 100644 --- a/pkg/sentry/platform/kvm/kvm_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_arm64.go @@ -47,10 +47,11 @@ type userRegs struct { } type exception struct { - sErrPending uint8 - sErrHasEsr uint8 - pad [6]uint8 - sErrEsr uint64 + sErrPending uint8 + sErrHasEsr uint8 + extDabtPending uint8 + pad [5]uint8 + sErrEsr uint64 } type kvmVcpuEvents struct { diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index 6abaa21c4..2492d57be 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -56,6 +56,7 @@ const ( _KVM_EXIT_FAIL_ENTRY = 0x9 _KVM_EXIT_INTERNAL_ERROR = 0x11 _KVM_EXIT_SYSTEM_EVENT = 0x18 + _KVM_EXIT_ARM_NISV = 0x1c ) // KVM capability options. diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 54837f20c..aa2d21748 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -54,7 +54,7 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { pageTable.Map( usermem.Addr(ring0.KernelStartAddress|pr.virtual), pr.length, - pagetables.MapOpts{AccessType: usermem.AnyAccess}, + pagetables.MapOpts{AccessType: usermem.AnyAccess, Global: true}, pr.physical) return true // Keep iterating. diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index f2459755b..a466acf4d 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -79,7 +79,7 @@ func (c *vCPU) initArchState() error { } // tcr_el1 - data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS | _TCR_A1 + data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS reg.id = _KVM_ARM64_REGS_TCR_EL1 if err := c.setOneRegister(®); err != nil { return err @@ -103,7 +103,7 @@ func (c *vCPU) initArchState() error { c.SetTtbr0Kvm(uintptr(data)) // ttbr1_el1 - data = c.machine.kernel.PageTables.TTBR1_EL1(false, 1) + data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0) reg.id = _KVM_ARM64_REGS_TTBR1_EL1 if err := c.setOneRegister(®); err != nil { diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go index f56aa3b79..571bfcc2e 100644 --- a/pkg/sentry/platform/ptrace/ptrace.go +++ b/pkg/sentry/platform/ptrace/ptrace.go @@ -18,8 +18,8 @@ // // In a nutshell, it works as follows: // -// The creation of a new address space creates a new child processes with a -// single thread which is traced by a single goroutine. +// The creation of a new address space creates a new child process with a single +// thread which is traced by a single goroutine. // // A context is just a collection of temporary variables. Calling Switch on a // context does the following: diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go index 812ab80ef..aacd7ce70 100644 --- a/pkg/sentry/platform/ptrace/subprocess.go +++ b/pkg/sentry/platform/ptrace/subprocess.go @@ -590,7 +590,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool { // facilitate vsyscall emulation. See patchSignalInfo. patchSignalInfo(regs, &c.signalInfo) return false - } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) { + } else if c.signalInfo.Code <= 0 && c.signalInfo.PID() == int32(os.Getpid()) { // The signal was generated by this process. That means // that it was an interrupt or something else that we // should bail for. Note that we ignore signals diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD index 679b287c3..2852b7387 100644 --- a/pkg/sentry/platform/ring0/BUILD +++ b/pkg/sentry/platform/ring0/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "arch_genrule", "go_library") load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) @@ -39,19 +39,19 @@ go_template_instance( template = ":defs_arm64", ) -genrule( +arch_genrule( name = "entry_impl_amd64", srcs = ["entry_amd64.s"], outs = ["entry_impl_amd64.s"], - cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) -genrule( +arch_genrule( name = "entry_impl_arm64", srcs = ["entry_arm64.s"], outs = ["entry_impl_arm64.s"], - cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@", + cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@", tools = ["//pkg/sentry/platform/ring0/gen_offsets"], ) @@ -72,7 +72,6 @@ go_library( "lib_amd64.s", "lib_arm64.go", "lib_arm64.s", - "lib_arm64_unsafe.go", "ring0.go", ], visibility = ["//pkg/sentry:internal"], diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 155f45ad8..b2bb18257 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -132,40 +132,6 @@ MOVD offset+PTRACE_R29(reg), R29; \ MOVD offset+PTRACE_R30(reg), R30; -// NOP-s -#define nop31Instructions() \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; \ - WORD $0xd503201f; - #define ESR_ELx_EC_UNKNOWN (0x00) #define ESR_ELx_EC_WFx (0x01) /* Unallocated EC: 0x02 */ @@ -305,24 +271,20 @@ WORD $0xd538d092; //MRS TPIDR_EL1, R18 // SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application. -#define SWITCH_TO_APP_PAGETABLE(from) \ - MRS TTBR1_EL1, R0; \ - MOVD CPU_APP_ASID(from), R1; \ - BFI $48, R1, $16, R0; \ - MSR R0, TTBR1_EL1; \ // set the ASID in TTBR1_EL1 (since TCR.A1 is set) - ISB $15; \ - MOVD CPU_TTBR0_APP(from), RSV_REG; \ - MSR RSV_REG, TTBR0_EL1; +#define SWITCH_TO_APP_PAGETABLE() \ + MOVD CPU_APP_ASID(RSV_REG), RSV_REG_APP; \ + MOVD CPU_TTBR0_APP(RSV_REG), RSV_REG; \ + BFI $48, RSV_REG_APP, $16, RSV_REG; \ + MSR RSV_REG, TTBR0_EL1; \ + ISB $15; // SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable. -#define SWITCH_TO_KVM_PAGETABLE(from) \ - MRS TTBR1_EL1, R0; \ - MOVD $1, R1; \ - BFI $48, R1, $16, R0; \ - MSR R0, TTBR1_EL1; \ - ISB $15; \ - MOVD CPU_TTBR0_KVM(from), RSV_REG; \ - MSR RSV_REG, TTBR0_EL1; +#define SWITCH_TO_KVM_PAGETABLE() \ + MOVD CPU_TTBR0_KVM(RSV_REG), RSV_REG; \ + MOVD $1, RSV_REG_APP; \ + BFI $48, RSV_REG_APP, $16, RSV_REG; \ + MSR RSV_REG, TTBR0_EL1; \ + ISB $15; TEXT ·EnableVFP(SB),NOSPLIT,$0 MOVD $FPEN_ENABLE, R0 @@ -530,7 +492,7 @@ do_exit_to_el0: WORD $0xd538d092 //MRS TPIDR_EL1, R18 - SWITCH_TO_APP_PAGETABLE(RSV_REG) + SWITCH_TO_APP_PAGETABLE() LDP 16*1(RSP), (R0, R1) LDP 16*0(RSP), (RSV_REG, RSV_REG_APP) @@ -555,10 +517,10 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1 MOVD R1, RSP - SWITCH_TO_KVM_PAGETABLE(RSV_REG) + REGISTERS_LOAD(RSV_REG, CPU_REGISTERS) + SWITCH_TO_KVM_PAGETABLE() MRS TPIDR_EL1, RSV_REG - REGISTERS_LOAD(RSV_REG, CPU_REGISTERS) MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP ERET() @@ -566,8 +528,16 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 // Start is the CPU entrypoint. TEXT ·Start(SB),NOSPLIT,$0 // Init. - MOVD $SCTLR_EL1_DEFAULT, R1 - MSR R1, SCTLR_EL1 + WORD $0xd508871f // __tlbi(vmalle1) + DSB $7 // dsb(nsh) + + MOVD $1<<12, R1 // Reset mdscr_el1 and disable + MSR R1, MDSCR_EL1 // access to the DCC from EL0 + ISB $15 + + MRS TTBR1_EL1, R1 + MSR R1, TTBR0_EL1 + ISB $15 MOVD $CNTKCTL_EL1_DEFAULT, R1 MSR R1, CNTKCTL_EL1 @@ -576,6 +546,15 @@ TEXT ·Start(SB),NOSPLIT,$0 ORR $0xffff000000000000, RSV_REG, RSV_REG WORD $0xd518d092 //MSR R18, TPIDR_EL1 + // Init. + MOVD $SCTLR_EL1_DEFAULT, R1 // re-enable the mmu. + MSR R1, SCTLR_EL1 + ISB $15 + WORD $0xd508751f // ic iallu + + DSB $7 // dsb(nsh) + ISB $15 + B ·kernelExitToEl1(SB) // El1_sync_invalid is the handler for an invalid EL1_sync. @@ -748,79 +727,43 @@ TEXT ·El0_error_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) // Vectors implements exception vector table. +// The start address of exception vector table should be 11-bits aligned. +// For detail, please refer to arm developer document: +// https://developer.arm.com/documentation/100933/0100/AArch64-exception-vector-table +// Also can refer to the code in linux kernel: arch/arm64/kernel/entry.S TEXT ·Vectors(SB),NOSPLIT,$0 + PCALIGN $2048 B ·El1_sync_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_irq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_fiq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_error_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El1_sync(SB) - nop31Instructions() + PCALIGN $128 B ·El1_irq(SB) - nop31Instructions() + PCALIGN $128 B ·El1_fiq(SB) - nop31Instructions() + PCALIGN $128 B ·El1_error(SB) - nop31Instructions() + PCALIGN $128 B ·El0_sync(SB) - nop31Instructions() + PCALIGN $128 B ·El0_irq(SB) - nop31Instructions() + PCALIGN $128 B ·El0_fiq(SB) - nop31Instructions() + PCALIGN $128 B ·El0_error(SB) - nop31Instructions() + PCALIGN $128 B ·El0_sync_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_irq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_fiq_invalid(SB) - nop31Instructions() + PCALIGN $128 B ·El0_error_invalid(SB) - nop31Instructions() - - // The exception-vector-table is required to be 11-bits aligned. - // Please see Linux source code as reference: arch/arm64/kernel/entry.s. - // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs. - // So that, I can safely move the 1st 2K part into the address with 11-bits alignment. - WORD $0xd503201f //nop - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() - WORD $0xd503201f - nop31Instructions() diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD index 9742308d8..a9703baf6 100644 --- a/pkg/sentry/platform/ring0/gen_offsets/BUILD +++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD @@ -24,6 +24,9 @@ go_binary( "defs_impl_arm64.go", "main.go", ], + # Use the libc malloc to avoid any extra dependencies. This is required to + # pass the sentry deps test. + system_malloc = True, visibility = [ "//pkg/sentry/platform/kvm:__pkg__", "//pkg/sentry/platform/ring0:__pkg__", diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index 90a7b8392..c05284641 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -53,11 +53,17 @@ func IsCanonical(addr uint64) bool { return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000 } +// SwitchToUser performs an eret. +// +// The return value is the exception vector. +// +// +checkescape:all +// //go:nosplit func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) { storeAppASID(uintptr(switchOpts.UserASID)) if switchOpts.Flush { - FlushTlbAll() + FlushTlbByASID(uintptr(switchOpts.UserASID)) } regs := switchOpts.Registers diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go index 0dffd33a3..a490bf3af 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.go +++ b/pkg/sentry/platform/ring0/lib_arm64.go @@ -22,19 +22,25 @@ func storeAppASID(asid uintptr) // LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU. func LocalFlushTlbAll() -// FlushTlbAll flush all tlb. +// FlushTlbByVA invalidates tlb by VA/Last-level/Inner-Shareable. +func FlushTlbByVA(addr uintptr) + +// FlushTlbByASID invalidates tlb by ASID/Inner-Shareable. +func FlushTlbByASID(asid uintptr) + +// FlushTlbAll invalidates all tlb. func FlushTlbAll() // CPACREL1 returns the value of the CPACR_EL1 register. func CPACREL1() (value uintptr) -// FPCR returns the value of FPCR register. +// GetFPCR returns the value of FPCR register. func GetFPCR() (value uintptr) // SetFPCR writes the FPCR value. func SetFPCR(value uintptr) -// FPSR returns the value of FPSR register. +// GetFPSR returns the value of FPSR register. func GetFPSR() (value uintptr) // SetFPSR writes the FPSR value. @@ -62,6 +68,4 @@ func DisableVFP() // Init sets function pointers based on architectural features. // // This must be called prior to using ring0. -func Init() { - rewriteVectors() -} +func Init() {} diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s index 6f4923539..e39b32841 100644 --- a/pkg/sentry/platform/ring0/lib_arm64.s +++ b/pkg/sentry/platform/ring0/lib_arm64.s @@ -15,6 +15,23 @@ #include "funcdata.h" #include "textflag.h" +#define TLBI_ASID_SHIFT 48 + +TEXT ·FlushTlbByVA(SB),NOSPLIT,$0-8 + MOVD addr+0(FP), R1 + DSB $10 // dsb(ishst) + WORD $0xd50883a1 // tlbi vale1is, x1 + DSB $11 // dsb(ish) + RET + +TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8 + MOVD asid+0(FP), R1 + LSL $TLBI_ASID_SHIFT, R1, R1 + DSB $10 // dsb(ishst) + WORD $0xd5088341 // tlbi aside1is, x1 + DSB $11 // dsb(ish) + RET + TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0 DSB $6 // dsb(nshst) WORD $0xd508871f // __tlbi(vmalle1) diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go deleted file mode 100644 index c05166fea..000000000 --- a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2019 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build arm64 - -package ring0 - -import ( - "reflect" - "syscall" - "unsafe" - - "gvisor.dev/gvisor/pkg/safecopy" - "gvisor.dev/gvisor/pkg/usermem" -) - -const ( - nopInstruction = 0xd503201f - instSize = unsafe.Sizeof(uint32(0)) - vectorsRawLen = 0x800 -) - -func unsafeSlice(addr uintptr, length int) (slice []uint32) { - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) - hdr.Data = addr - hdr.Len = length / int(instSize) - hdr.Cap = length / int(instSize) - return slice -} - -// Work around: move ring0.Vectors() into a specific address with 11-bits alignment. -// -// According to the design documentation of Arm64, -// the start address of exception vector table should be 11-bits aligned. -// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S -// But, we can't align a function's start address to a specific address by using golang. -// We have raised this question in golang community: -// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I -// This function will be removed when golang supports this feature. -// -// There are 2 jobs were implemented in this function: -// 1, move the start address of exception vector table into the specific address. -// 2, modify the offset of each instruction. -func rewriteVectors() { - vectorsBegin := reflect.ValueOf(Vectors).Pointer() - - // The exception-vector-table is required to be 11-bits aligned. - // And the size is 0x800. - // Please see the documentation as reference: - // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table - // - // But, golang does not allow to set a function's address to a specific value. - // So, for gvisor, I defined the size of exception-vector-table as 4K, - // filled the 2nd 2K part with NOP-s. - // So that, I can safely move the 1st 2K part into the address with 11-bits alignment. - // - // So, the prerequisite for this function to work correctly is: - // vectorsSafeLen >= 0x1000 - // vectorsRawLen = 0x800 - vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin) - if vectorsSafeLen < 2*vectorsRawLen { - panic("Can't update vectors") - } - - vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32 - vectorsRawLen32 := vectorsRawLen / int(instSize) - - offset := vectorsBegin & (1<<11 - 1) - if offset != 0 { - offset = 1<<11 - offset - } - - pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1) - - _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC)) - if errno != 0 { - panic(errno.Error()) - } - - offset = offset / instSize // By index, not bytes. - // Move exception-vector-table into the specific address, should uses memmove here. - for i := 1; i <= vectorsRawLen32; i++ { - vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i] - } - - // Adjust branch since instruction was moved forward. - for i := 0; i < vectorsRawLen32; i++ { - if vectorsSafeTable[int(offset)+i] != nopInstruction { - vectorsSafeTable[int(offset)+i] -= uint32(offset) - } - } - - _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC)) - if errno != 0 { - panic(errno.Error()) - } -} diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index a3f775d15..cc1f6bfcc 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/tcpip", + "//pkg/tcpip/header", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index ca16d0381..ebcc891b3 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -23,7 +23,20 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/syserror", - "//pkg/tcpip", "//pkg/usermem", ], ) + +go_test( + name = "control_test", + size = "small", + srcs = ["control_test.go"], + library = ":control", + deps = [ + "//pkg/abi/linux", + "//pkg/binary", + "//pkg/sentry/socket", + "//pkg/usermem", + "@com_github_google_go_cmp//cmp:go_default_library", + ], +) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 70ccf77a7..65b556489 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" ) @@ -344,18 +343,42 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { } // PackIPPacketInfo packs an IP_PKTINFO socket control message. -func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte { - var p linux.ControlMessageIPPacketInfo - p.NIC = int32(packetInfo.NIC) - copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) - copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) - +func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, linux.IP_PKTINFO, t.Arch().Width(), - p, + packetInfo, + ) +} + +// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. +func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte { + var level uint32 + var optType uint32 + switch originalDstAddress.(type) { + case *linux.SockAddrInet: + level = linux.SOL_IP + optType = linux.IP_RECVORIGDSTADDR + case *linux.SockAddrInet6: + level = linux.SOL_IPV6 + optType = linux.IPV6_RECVORIGDSTADDR + default: + panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg") + } + return putCmsgStruct( + buf, level, optType, t.Arch().Width(), originalDstAddress) +} + +// PackSockExtendedErr packs an IP*_RECVERR socket control message. +func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte { + return putCmsgStruct( + buf, + sockErr.CMsgLevel(), + sockErr.CMsgType(), + t.Arch().Width(), + sockErr, ) } @@ -384,7 +407,15 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt } if cmsgs.IP.HasIPPacketInfo { - buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) + buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf) + } + + if cmsgs.IP.OriginalDstAddress != nil { + buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) + } + + if cmsgs.IP.SockErr != nil { + buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf) } return buf @@ -416,21 +447,23 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, linux.SizeOfControlMessageTClass) } - return space -} + if cmsgs.IP.HasIPPacketInfo { + space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo) + } -// NewIPPacketInfo returns the IPPacketInfo struct. -func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo { - var p tcpip.IPPacketInfo - p.NIC = tcpip.NICID(packetInfo.NIC) - copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:]) - copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:]) + if cmsgs.IP.OriginalDstAddress != nil { + space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) + } - return p + if cmsgs.IP.SockErr != nil { + space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes()) + } + + return space } // Parse parses a raw socket control message into portable objects. -func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) { +func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) (socket.ControlMessages, error) { var ( cmsgs socket.ControlMessages fds linux.ControlMessageRights @@ -454,10 +487,6 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con i += linux.SizeOfControlMessageHeader length := int(h.Length) - linux.SizeOfControlMessageHeader - // The use of t.Arch().Width() is analogous to Linux's use of - // sizeof(long) in CMSG_ALIGN. - width := t.Arch().Width() - switch h.Level { case linux.SOL_SOCKET: switch h.Type { @@ -489,6 +518,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con cmsgs.Unix.Credentials = scmCreds i += binary.AlignUp(length, width) + case linux.SO_TIMESTAMP: + if length < linux.SizeOfTimeval { + return socket.ControlMessages{}, syserror.EINVAL + } + var ts linux.Timeval + binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &ts) + cmsgs.IP.Timestamp = ts.ToNsecCapped() + cmsgs.IP.HasTimestamp = true + i += binary.AlignUp(length, width) + default: // Unknown message type. return socket.ControlMessages{}, syserror.EINVAL @@ -512,7 +551,26 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo) + cmsgs.IP.PacketInfo = packetInfo + i += binary.AlignUp(length, width) + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr + i += binary.AlignUp(length, width) + + case linux.IP_RECVERR: + var errCmsg linux.SockErrCMsgIPv4 + if length < errCmsg.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + + errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + cmsgs.IP.SockErr = &errCmsg i += binary.AlignUp(length, width) default: @@ -528,6 +586,25 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) i += binary.AlignUp(length, width) + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr + i += binary.AlignUp(length, width) + + case linux.IPV6_RECVERR: + var errCmsg linux.SockErrCMsgIPv6 + if length < errCmsg.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + + errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + cmsgs.IP.SockErr = &errCmsg + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go new file mode 100644 index 000000000..d40a4cc85 --- /dev/null +++ b/pkg/sentry/socket/control/control_test.go @@ -0,0 +1,59 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package control provides internal representations of socket control +// messages. +package control + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/usermem" +) + +func TestParse(t *testing.T) { + // Craft the control message to parse. + length := linux.SizeOfControlMessageHeader + linux.SizeOfTimeval + hdr := linux.ControlMessageHeader{ + Length: uint64(length), + Level: linux.SOL_SOCKET, + Type: linux.SO_TIMESTAMP, + } + buf := make([]byte, 0, length) + buf = binary.Marshal(buf, usermem.ByteOrder, &hdr) + ts := linux.Timeval{ + Sec: 2401, + Usec: 343, + } + buf = binary.Marshal(buf, usermem.ByteOrder, &ts) + + cmsg, err := Parse(nil, nil, buf, 8 /* width */) + if err != nil { + t.Fatalf("Parse(_, _, %+v, _): %v", cmsg, err) + } + + want := socket.ControlMessages{ + IP: socket.IPControlMessages{ + HasTimestamp: true, + Timestamp: ts.ToNsecCapped(), + }, + } + if diff := cmp.Diff(want, cmsg); diff != "" { + t.Errorf("unexpected message parsed, (-want, +got):\n%s", diff) + } +} diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 7d3c4a01c..5b868216d 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 case linux.SO_LINGER: optlen = syscall.SizeofLinger @@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR: optlen = sizeofInt32 case linux.IP_PKTINFO: optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 } case linux.SOL_TCP: switch name { - case linux.TCP_NODELAY: + case linux.TCP_NODELAY, linux.TCP_INQ: optlen = sizeofInt32 } } @@ -416,68 +416,76 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] return nil } -// RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { - // Only allow known and safe flags. - // - // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary - // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the - // Socket interface's dependence on netstack. - if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { - return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument - } +func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) { + // We always do a non-blocking recv*(). + sysflags := flags | syscall.MSG_DONTWAIT - var senderAddr linux.SockAddr + msg := syscall.Msghdr{} + if len(iovs) > 0 { + msg.Iov = &iovs[0] + msg.Iovlen = uint64(len(iovs)) + } var senderAddrBuf []byte if senderRequested { senderAddrBuf = make([]byte, sizeofSockaddr) + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(sizeofSockaddr) } - var controlBuf []byte - var msgFlags int - - recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { - // Refuse to do anything if any part of dst.Addrs was unusable. - if uint64(dst.NumBytes()) != dsts.NumBytes() { - return 0, nil - } - if dsts.IsEmpty() { - return 0, nil + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } + n, err := recvmsg(s.fd, &msg, sysflags) + if err != nil { + return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err + } + return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err +} - // We always do a non-blocking recv*(). - sysflags := flags | syscall.MSG_DONTWAIT +// RecvMsg implements socket.Socket.RecvMsg. +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + // Only allow known and safe flags. + if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC|syscall.MSG_ERRQUEUE) != 0 { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument + } - iovs := safemem.IovecsFromBlockSeq(dsts) - msg := syscall.Msghdr{ - Iov: &iovs[0], - Iovlen: uint64(len(iovs)), - } - if len(senderAddrBuf) != 0 { - msg.Name = &senderAddrBuf[0] - msg.Namelen = uint32(len(senderAddrBuf)) - } - if controlLen > 0 { - if controlLen > maxControlLen { - controlLen = maxControlLen + var senderAddrBuf []byte + var controlBuf []byte + var msgFlags int + copyToDst := func() (int64, error) { + var n uint64 + var err error + if dst.NumBytes() == 0 { + // We want to make the recvmsg(2) call to the host even if dst is empty + // to fetch control messages, sender address or errors if any occur. + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen) + return int64(n), err + } + + recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { + // Refuse to do anything if any part of dst.Addrs was unusable. + if uint64(dst.NumBytes()) != dsts.NumBytes() { + return 0, nil } - controlBuf = make([]byte, controlLen) - msg.Control = &controlBuf[0] - msg.Controllen = controlLen - } - n, err := recvmsg(s.fd, &msg, sysflags) - if err != nil { - return 0, err - } - senderAddrBuf = senderAddrBuf[:msg.Namelen] - msgFlags = int(msg.Flags) - controlLen = uint64(msg.Controllen) - return n, nil - }) + if dsts.IsEmpty() { + return 0, nil + } + + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen) + return n, err + }) + return dst.CopyOutFrom(t, recvmsgToBlocks) + } var ch chan struct{} - n, err := dst.CopyOutFrom(t, recvmsgToBlocks) - if flags&syscall.MSG_DONTWAIT == 0 { + n, err := copyToDst() + // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT. + if flags&(syscall.MSG_DONTWAIT|syscall.MSG_ERRQUEUE) == 0 { for err == syserror.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which // case it can't have returned any data. @@ -494,48 +502,85 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags s.EventRegister(&e, waiter.EventIn) defer s.EventUnregister(&e) } - n, err = dst.CopyOutFrom(t, recvmsgToBlocks) + n, err = copyToDst() } } if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + var senderAddr linux.SockAddr if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf) if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil +} +func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages { controlMessages := socket.ControlMessages{} for _, unixCmsg := range unixControlMessages { switch unixCmsg.Header.Level { - case syscall.SOL_IP: + case linux.SOL_SOCKET: switch unixCmsg.Header.Type { - case syscall.IP_TOS: + case linux.SO_TIMESTAMP: + controlMessages.IP.HasTimestamp = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp) + } + + case linux.SOL_IP: + switch unixCmsg.Header.Type { + case linux.IP_TOS: controlMessages.IP.HasTOS = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) - case syscall.IP_PKTINFO: + case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) + controlMessages.IP.PacketInfo = packetInfo + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr + + case syscall.IP_RECVERR: + var errCmsg linux.SockErrCMsgIPv4 + errCmsg.UnmarshalBytes(unixCmsg.Data) + controlMessages.IP.SockErr = &errCmsg } - case syscall.SOL_IPV6: + case linux.SOL_IPV6: switch unixCmsg.Header.Type { - case syscall.IPV6_TCLASS: + case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr + + case syscall.IPV6_RECVERR: + var errCmsg linux.SockErrCMsgIPv6 + errCmsg.UnmarshalBytes(unixCmsg.Data) + controlMessages.IP.SockErr = &errCmsg + } + + case linux.SOL_TCP: + switch unixCmsg.Header.Type { + case linux.TCP_INQ: + controlMessages.IP.HasInq = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq) } } } - - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil + return controlMessages } // SendMsg implements socket.Socket.SendMsg. diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index fae3b6783..b2206900b 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -25,7 +25,6 @@ go_library( "//pkg/marshal", "//pkg/marshal/primitive", "//pkg/metric", - "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e8a0103bf..dcf898c0a 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -28,9 +28,9 @@ import ( "bytes" "fmt" "io" + "io/ioutil" "math" "reflect" - "sync/atomic" "syscall" "time" @@ -43,7 +43,6 @@ import ( "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/metric" - "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -84,69 +83,95 @@ var Metrics = tcpip.Stats{ MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), ICMP: tcpip.ICMPStats{ - V4PacketsSent: tcpip.ICMPv4SentPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + V4: tcpip.ICMPv4Stats{ + PacketsSent: tcpip.ICMPv4SentPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), + }, + PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), }, - V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + V6: tcpip.ICMPv6Stats{ + PacketsSent: tcpip.ICMPv6SentPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), + }, + PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), }, - Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - V6PacketsSent: tcpip.ICMPv6SentPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + IGMP: tcpip.IGMPStats{ + PacketsSent: tcpip.IGMPSentPacketStats{ + IGMPPacketStats: tcpip.IGMPPacketStats{ + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."), }, - V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + PacketsReceived: tcpip.IGMPReceivedPacketStats{ + IGMPPacketStats: tcpip.IGMPPacketStats{ + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group messages received by netstack."), }, - Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), + Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."), + ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."), + Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."), }, }, IP: tcpip.IPStats{ @@ -209,18 +234,6 @@ const sizeOfInt32 int = 4 var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) -// ntohs converts a 16-bit number from network byte order to host byte order. It -// assumes that the host is little endian. -func ntohs(v uint16) uint16 { - return v<<8 | v>>8 -} - -// htons converts a 16-bit number from host byte order to network byte order. It -// assumes that the host is little endian. -func htons(v uint16) uint16 { - return ntohs(v) -} - // commonEndpoint represents the intersection of a tcpip.Endpoint and a // transport.Endpoint. type commonEndpoint interface { @@ -240,10 +253,6 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and - // transport.Endpoint.SetSockOptBool. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -252,14 +261,14 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and - // transport.Endpoint.GetSockOpt. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 + // LastError implements tcpip.Endpoint.LastError and // transport.Endpoint.LastError. LastError() *tcpip.Error @@ -298,19 +307,11 @@ type socketOpsCommon struct { skType linux.SockType protocol int - // readViewHasData is 1 iff readView has data to be read, 0 otherwise. - // Must be accessed using atomic operations. It must only be written - // with readMu held but can be read without holding readMu. The latter - // is required to avoid deadlocks in epoll Readiness checks. - readViewHasData uint32 - // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` - // readView contains the remaining payload from the last packet. - readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages + readCM socket.IPControlMessages sender tcpip.FullAddress linkPacketInfo tcpip.LinkPacketInfo @@ -326,17 +327,15 @@ type socketOpsCommon struct { // valid when timestampValid is true. It is protected by readMu. timestampNS int64 - // sockOptInq corresponds to TCP_INQ. It is implemented at this level - // because it takes into account data from readView. + // TODO(b/153685824): Move this to SocketOptions. + // sockOptInq corresponds to TCP_INQ. sockOptInq bool } // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } dirent := socket.NewDirent(t, netstackDevice) @@ -365,127 +364,27 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } -// AddressAndFamily reads an sockaddr struct from the given address and -// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, -// AF_INET6, and AF_PACKET addresses. -// -// AddressAndFamily returns an address and its family. -func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { - // Make sure we have at least 2 bytes for the address family. - if len(addr) < 2 { - return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument - } - - // Get the rest of the fields based on the address family. - switch family := usermem.ByteOrder.Uint16(addr); family { - case linux.AF_UNIX: - path := addr[2:] - if len(path) > linux.UnixPathMax { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - // Drop the terminating NUL (if one exists) and everything after - // it for filesystem (non-abstract) addresses. - if len(path) > 0 && path[0] != 0 { - if n := bytes.IndexByte(path[1:], 0); n >= 0 { - path = path[:n+1] - } - } - return tcpip.FullAddress{ - Addr: tcpip.Address(path), - }, family, nil - - case linux.AF_INET: - var a linux.SockAddrInet - if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - return out, family, nil - - case linux.AF_INET6: - var a linux.SockAddrInet6 - if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - if isLinkLocal(out.Addr) { - out.NIC = tcpip.NICID(a.Scope_id) - } - return out, family, nil - - case linux.AF_PACKET: - var a linux.SockAddrLink - if len(addr) < sockAddrLinkSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) - if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/173): Return protocol too. - return tcpip.FullAddress{ - NIC: tcpip.NICID(a.InterfaceIndex), - Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), - }, family, nil - - case linux.AF_UNSPEC: - return tcpip.FullAddress{}, family, nil - - default: - return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported - } -} - func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } -// fetchReadView updates the readView field of the socket if it's currently -// empty. It assumes that the socket is locked. -// // Precondition: s.readMu must be held. -func (s *socketOpsCommon) fetchReadView() *syserr.Error { - if len(s.readView) > 0 { - return nil - } - s.readView = nil - s.sender = tcpip.FullAddress{} - s.linkPacketInfo = tcpip.LinkPacketInfo{} +func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) { + res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{ + Peek: peek, + NeedRemoteAddr: true, + NeedLinkPacketInfo: true, + }) - var v buffer.View - var cms tcpip.ControlMessages - var err *tcpip.Error + // Assign these anyways. + s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages) + s.sender = res.RemoteAddr + s.linkPacketInfo = res.LinkPacketInfo - switch e := s.Endpoint.(type) { - // The ordering of these interfaces matters. The most specific - // interfaces must be specified before the more generic Endpoint - // interface. - case tcpip.PacketEndpoint: - v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) - case tcpip.Endpoint: - v, cms, err = e.Read(&s.sender) - } if err != nil { - atomic.StoreUint32(&s.readViewHasData, 0) - return syserr.TranslateNetstackError(err) + return 0, 0, syserr.TranslateNetstackError(err) } - - s.readView = v - s.readCM = cms - atomic.StoreUint32(&s.readViewHasData, 1) - - return nil + return res.Count, res.Total, nil } // Release implements fs.FileOperations.Release. @@ -502,11 +401,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) { return } - var v tcpip.LingerOption - if err := s.Endpoint.GetSockOpt(&v); err != nil { - return - } - + v := s.Endpoint.SocketOptions().GetLinger() // The case for zero timeout is handled in tcp endpoint close function. // Close is blocked until either: // 1. The endpoint state is not in any of the states: FIN-WAIT1, @@ -538,38 +433,14 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // WriteTo implements fs.FileOperations.WriteTo. func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { s.readMu.Lock() + defer s.readMu.Unlock() - // Copy as much data as possible. - done := int64(0) - for count > 0 { - // This may return a blocking error. - if err := s.fetchReadView(); err != nil { - s.readMu.Unlock() - return done, err.ToError() - } - - // Write to the underlying file. - n, err := dst.Write(s.readView) - done += int64(n) - count -= int64(n) - if dup { - // That's all we support for dup. This is generally - // supported by any Linux system calls, but the - // expectation is that now a caller will call read to - // actually remove these bytes from the socket. - break - } - - // Drop that part of the view. - s.readView.TrimFront(n) - if err != nil { - s.readMu.Unlock() - return done, err - } + // This may return a blocking error. + n, _, err := s.readLocked(dst, int(count), dup /* peek */) + if err != nil { + return 0, err.ToError() } - - s.readMu.Unlock() - return done, nil + return int64(n), nil } // ioSequencePayload implements tcpip.Payload. @@ -705,17 +576,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader // Readiness returns a mask of ready events for socket s. func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { - r := s.Endpoint.Readiness(mask) - - // Check our cached value iff the caller asked for readability and the - // endpoint itself is currently not readable. - if (mask & ^r & waiter.EventIn) != 0 { - if atomic.LoadUint32(&s.readViewHasData) == 1 { - r |= waiter.EventIn - } - } - - return r + return s.Endpoint.Readiness(mask) } func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { @@ -723,11 +584,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { return nil } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { - v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return syserr.TranslateNetstackError(err) - } - if !v { + if !s.Endpoint.SocketOptions().GetV6Only() { return nil } } @@ -751,7 +608,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -832,7 +689,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } } else { var err *syserr.Error - addr, family, err = AddressAndFamily(sockaddr) + addr, family, err = socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -923,7 +780,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -1007,7 +864,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return getSockOptSocket(t, s, ep, family, skType, name, outLen) case linux.SOL_TCP: - return getSockOptTCP(t, ep, name, outLen) + return getSockOptTCP(t, s, ep, name, outLen) case linux.SOL_IPV6: return getSockOptIPv6(t, s, ep, name, outPtr, outLen) @@ -1043,7 +900,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // Get the last error and convert it. - err := ep.LastError() + err := ep.SocketOptions().GetLastError() if err == nil { optP := primitive.Int32(0) return &optP, nil @@ -1124,10 +981,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return &v, nil case linux.SO_BINDTODEVICE: - var v tcpip.BindToDeviceOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + v := ep.SocketOptions().GetBindToDevice() if v == 0 { var b primitive.ByteSlice return &b, nil @@ -1170,11 +1024,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.LingerOption var linger linux.Linger - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + v := ep.SocketOptions().GetLinger() if v.Enabled { linger.OnOff = 1 @@ -1205,13 +1056,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.OutOfBandInlineOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(v) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline())) + return &v, nil case linux.SO_NO_CHECK: if outLen < sizeOfInt32 { @@ -1226,8 +1072,13 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v := primitive.Int32(boolToInt32(ep.SocketOptions().GetAcceptConn())) - return &v, nil + // This option is only viable for TCP endpoints. + var v bool + if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) { + v = tcp.EndpointState(ep.State()) == tcp.StateListen + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1236,46 +1087,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.DelayOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(!v)) - return &vP, nil + v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption())) + return &v, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.CorkOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption())) + return &v, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.QuickAckOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck())) + return &v, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1449,19 +1290,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + + family, skType, _ := s.Type() + if family != linux.AF_INET6 { + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only())) + return &v, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1493,13 +1339,23 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) + return &v, nil + case linux.IPV6_RECVERR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + return &v, nil + + case linux.IPV6_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { @@ -1511,7 +1367,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet6), nil case linux.IP6T_SO_GET_INFO: @@ -1520,7 +1376,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1540,7 +1396,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1560,7 +1416,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1582,6 +1438,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name // getSockOptIP implements GetSockOpt when level is SOL_IP. func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1624,7 +1485,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) return &a.(*linux.SockAddrInet).Addr, nil @@ -1633,13 +1494,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop())) + return &v, nil case linux.IP_TOS: // Length handling for parity with Linux. @@ -1663,26 +1519,40 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) + return &v, nil + + case linux.IP_RECVERR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError())) + return &v, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo())) + return &v, nil + + case linux.IP_HDRINCL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) + return &v, nil + + case linux.IP_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { @@ -1694,7 +1564,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet), nil case linux.IPT_SO_GET_INFO: @@ -1801,7 +1671,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int return setSockOptSocket(t, s, ep, name, optVal) case linux.SOL_TCP: - return setSockOptTCP(t, ep, name, optVal) + return setSockOptTCP(t, s, ep, name, optVal) case linux.SOL_IPV6: return setSockOptIPv6(t, s, ep, name, optVal) @@ -1870,8 +1740,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } name := string(optVal[:n]) if name == "" { - v := tcpip.BindToDeviceOption(0) - return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) + return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(0)) } s := t.NetworkContext() if s == nil { @@ -1879,8 +1748,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } for nicID, nic := range s.Interfaces() { if nic.Name == name { - v := tcpip.BindToDeviceOption(nicID) - return syserr.TranslateNetstackError(ep.SetSockOpt(&v)) + return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(nicID)) } } return syserr.ErrUnknownDevice @@ -1949,8 +1817,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - opt := tcpip.OutOfBandInlineOption(v) - return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) + ep.SocketOptions().SetOutOfBandInline(v != 0) + return nil case linux.SO_NO_CHECK: if len(optVal) < sizeOfInt32 { @@ -1973,10 +1841,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return syserr.TranslateNetstackError( - ep.SetSockOpt(&tcpip.LingerOption{ - Enabled: v.OnOff != 0, - Timeout: time.Second * time.Duration(v.Linger)})) + ep.SocketOptions().SetLinger(tcpip.LingerOption{ + Enabled: v.OnOff != 0, + Timeout: time.Second * time.Duration(v.Linger), + }) + return nil case linux.SO_DETACH_FILTER: // optval is ignored. @@ -1991,7 +1860,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } // setSockOptTCP implements SetSockOpt when level is SOL_TCP. -func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if len(optVal) < sizeOfInt32 { @@ -1999,7 +1873,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) + ep.SocketOptions().SetDelayOption(v == 0) + return nil case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -2007,7 +1882,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) + ep.SocketOptions().SetCorkOption(v != 0) + return nil case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -2015,7 +1891,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) + ep.SocketOptions().SetQuickAck(v != 0) + return nil case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -2127,18 +2004,55 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + + family, skType, skProto := s.Type() + if family != linux.AF_INET6 { + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } + if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { + return syserr.ErrInvalidEndpointState + } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + return syserr.ErrInvalidEndpointState + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) + ep.SocketOptions().SetV6Only(v != 0) + return nil + + case linux.IPV6_ADD_MEMBERSHIP: + req, err := copyInMulticastV6Request(optVal) + if err != nil { + return err + } + + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.AddMembershipOption{ + NIC: tcpip.NICID(req.InterfaceIndex), + MulticastAddr: tcpip.Address(req.MulticastAddr[:]), + })) - case linux.IPV6_ADD_MEMBERSHIP, - linux.IPV6_DROP_MEMBERSHIP, - linux.IPV6_IPSEC_POLICY, + case linux.IPV6_DROP_MEMBERSHIP: + req, err := copyInMulticastV6Request(optVal) + if err != nil { + return err + } + + return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.RemoveMembershipOption{ + NIC: tcpip.NICID(req.InterfaceIndex), + MulticastAddr: tcpip.Address(req.MulticastAddr[:]), + })) + + case linux.IPV6_IPSEC_POLICY, linux.IPV6_JOIN_ANYCAST, linux.IPV6_LEAVE_ANYCAST, // TODO(b/148887420): Add support for IPV6_PKTINFO. @@ -2154,6 +2068,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_RECVORIGDSTADDR: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2173,7 +2096,18 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + ep.SocketOptions().SetReceiveTClass(v != 0) + return nil + case linux.IPV6_RECVERR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + ep.SocketOptions().SetRecvError(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2181,7 +2115,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return syserr.ErrProtocolNotAvailable } @@ -2206,6 +2140,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name var ( inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{})) inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{})) + inet6MulticastRequestSize = int(binary.Size(linux.Inet6MulticastRequest{})) ) // copyInMulticastRequest copies in a variable-size multicast request. The @@ -2239,6 +2174,16 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR return req, nil } +func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syserr.Error) { + if len(optVal) < inet6MulticastRequestSize { + return linux.Inet6MulticastRequest{}, syserr.ErrInvalidArgument + } + + var req linux.Inet6MulticastRequest + binary.Unmarshal(optVal[:inet6MulticastRequestSize], usermem.ByteOrder, &req) + return req, nil +} + // parseIntOrChar copies either a 32-bit int or an 8-bit uint out of buf. // // net/ipv4/ip_sockglue.c:do_ip_setsockopt does this for its socket options. @@ -2256,6 +2201,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { // setSockOptIP implements SetSockOpt when level is SOL_IP. func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_MULTICAST_TTL: v, err := parseIntOrChar(optVal) @@ -2308,7 +2258,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{ NIC: tcpip.NICID(req.InterfaceIndex), - InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]), + InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]), })) case linux.IP_MULTICAST_LOOP: @@ -2317,7 +2267,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) + ep.SocketOptions().SetMulticastLoop(v != 0) + return nil case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2353,7 +2304,19 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + ep.SocketOptions().SetReceiveTOS(v != 0) + return nil + + case linux.IP_RECVERR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + ep.SocketOptions().SetRecvError(v != 0) + return nil case linux.IP_PKTINFO: if len(optVal) == 0 { @@ -2363,7 +2326,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + ep.SocketOptions().SetReceivePacketInfo(v != 0) + return nil case linux.IP_HDRINCL: if len(optVal) == 0 { @@ -2373,7 +2337,20 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + ep.SocketOptions().SetHeaderIncluded(v != 0) + return nil + + case linux.IP_RECVORIGDSTADDR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { @@ -2410,10 +2387,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_NODEFRAG, linux.IP_OPTIONS, linux.IP_PASSSEC, - linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, - linux.IP_RECVORIGDSTADDR, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2487,11 +2462,9 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_MULTICAST_IF, linux.IPV6_MULTICAST_LOOP, linux.IPV6_RECVDSTOPTS, - linux.IPV6_RECVERR, linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, - linux.IPV6_RECVORIGDSTADDR, linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, @@ -2515,7 +2488,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { switch name { case linux.IP_TOS, linux.IP_TTL, - linux.IP_HDRINCL, linux.IP_OPTIONS, linux.IP_ROUTER_ALERT, linux.IP_RECVOPTS, @@ -2523,7 +2495,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { linux.IP_PKTINFO, linux.IP_PKTOPTIONS, linux.IP_MTU_DISCOVER, - linux.IP_RECVERR, linux.IP_RECVTTL, linux.IP_RECVTOS, linux.IP_MTU, @@ -2562,72 +2533,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { } } -// isLinkLocal determines if the given IPv6 address is link-local. This is the -// case when it has the fe80::/10 prefix. This check is used to determine when -// the NICID is relevant for a given IPv6 address. -func isLinkLocal(addr tcpip.Address) bool { - return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 -} - -// ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { - switch family { - case linux.AF_UNIX: - var out linux.SockAddrUnix - out.Family = linux.AF_UNIX - l := len([]byte(addr.Addr)) - for i := 0; i < l; i++ { - out.Path[i] = int8(addr.Addr[i]) - } - - // Linux returns the used length of the address struct (including the - // null terminator) for filesystem paths. The Family field is 2 bytes. - // It is sometimes allowed to exclude the null terminator if the - // address length is the max. Abstract and empty paths always return - // the full exact length. - if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return &out, uint32(2 + l) - } - return &out, uint32(3 + l) - - case linux.AF_INET: - var out linux.SockAddrInet - copy(out.Addr[:], addr.Addr) - out.Family = linux.AF_INET - out.Port = htons(addr.Port) - return &out, uint32(sockAddrInetSize) - - case linux.AF_INET6: - var out linux.SockAddrInet6 - if len(addr.Addr) == header.IPv4AddressSize { - // Copy address in v4-mapped format. - copy(out.Addr[12:], addr.Addr) - out.Addr[10] = 0xff - out.Addr[11] = 0xff - } else { - copy(out.Addr[:], addr.Addr) - } - out.Family = linux.AF_INET6 - out.Port = htons(addr.Port) - if isLinkLocal(addr.Addr) { - out.Scope_id = uint32(addr.NIC) - } - return &out, uint32(sockAddrInet6Size) - - case linux.AF_PACKET: - // TODO(gvisor.dev/issue/173): Return protocol too. - var out linux.SockAddrLink - out.Family = linux.AF_PACKET - out.InterfaceIndex = int32(addr.NIC) - out.HardwareAddrLen = header.EthernetAddressSize - copy(out.HardwareAddr[:], addr.Addr) - return &out, uint32(sockAddrLinkSize) - - default: - return nil, 0 - } -} - // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { @@ -2636,7 +2541,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2648,70 +2553,24 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } -// coalescingRead is the fast path for non-blocking, non-peek, stream-based -// case. It coalesces as many packets as possible before returning to the -// caller. +// streamRead is the fast path for non-blocking, non-peek, stream-based socket. // // Precondition: s.readMu must be locked. -func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) { - var err *syserr.Error - var copied int - - // Copy as many views as possible into the user-provided buffer. - for { - // Always do at least one fetchReadView, even if the number of bytes to - // read is 0. - err = s.fetchReadView() - if err != nil || len(s.readView) == 0 { - break - } - if dst.NumBytes() == 0 { - break - } - - var n int - var e error - if discard { - n = len(s.readView) - if int64(n) > dst.NumBytes() { - n = int(dst.NumBytes()) - } - } else { - n, e = dst.CopyOut(ctx, s.readView) - // Set the control message, even if 0 bytes were read. - if e == nil { - s.updateTimestamp() - } - } - copied += n - s.readView.TrimFront(n) - - dst = dst.DropFirst(n) - if e != nil { - err = syserr.FromError(e) - break - } - // If we are done reading requested data then stop. - if dst.NumBytes() == 0 { - break - } - } - - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) +func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) { + // Always do at least one read, even if the number of bytes to read is 0. + var n int + n, _, err := s.readLocked(dst, count, false /* peek */) + if err != nil { + return 0, err } - - // If we managed to copy something, we must deliver it. - if copied > 0 { - s.Endpoint.ModerateRecvBuf(copied) - return copied, nil + if n > 0 { + s.Endpoint.ModerateRecvBuf(n) } - - return 0, err + return n, nil } func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { @@ -2723,7 +2582,7 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { return } cmsg.IP.HasInq = true - cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) + cmsg.IP.Inq = int32(rcvBufUsed) } func toLinuxPacketType(pktType tcpip.PacketType) uint8 { @@ -2760,7 +2619,21 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq // bytes of data to be discarded, rather than passed back in a // caller-supplied buffer. s.readMu.Lock() - n, err := s.coalescingRead(ctx, dst, trunc) + + var w io.Writer + if trunc { + w = ioutil.Discard + } else { + w = dst.Writer(ctx) + } + + n, err := s.streamRead(ctx, w, int(dst.NumBytes())) + + if err == nil && !trunc { + // Set the control message, even if 0 bytes were read. + s.updateTimestamp() + } + cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) s.readMu.Unlock() @@ -2770,18 +2643,32 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq s.readMu.Lock() defer s.readMu.Unlock() - if err := s.fetchReadView(); err != nil { + // MSG_TRUNC with MSG_PEEK on a TCP socket returns the + // amount that could be read, and does not write to buffer. + isTCPPeekTrunc := !isPacket && peek && trunc + + var w io.Writer + if isTCPPeekTrunc { + w = ioutil.Discard + } else { + w = dst.Writer(ctx) + } + + var numRead, numTotal int + var err *syserr.Error + numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek) + if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, err } - if !isPacket && peek && trunc { - // MSG_TRUNC with MSG_PEEK on a TCP socket returns the - // amount that could be read. + if isTCPPeekTrunc { + // TCP endpoint does not return the total bytes in buffer as numTotal. + // We need to query it from socket option. rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } - available := len(s.readView) + int(rql) + available := int(rql) bufLen := int(dst.NumBytes()) if available < bufLen { return available, 0, nil, 0, socket.ControlMessages{}, nil @@ -2789,88 +2676,65 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq return bufLen, 0, nil, 0, socket.ControlMessages{}, nil } - n, err := dst.CopyOut(ctx, s.readView) // Set the control message, even if 0 bytes were read. - if err == nil { - s.updateTimestamp() - } + s.updateTimestamp() + var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { - addr, addrLen = ConvertAddress(s.family, s.sender) + addr, addrLen = socket.ConvertAddress(s.family, s.sender) switch v := addr.(type) { case *linux.SockAddrLink: - v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol)) v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) } } if peek { - if l := len(s.readView); trunc && l > n { + if trunc && numTotal > numRead { // isPacket must be true. - return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err) + return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil } - - if isPacket || err != nil { - return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err) - } - - // We need to peek beyond the first message. - dst = dst.DropFirst(n) - num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) { - n, _, err := s.Endpoint.Peek(dsts) - // TODO(b/78348848): Handle peek timestamp. - if err != nil { - return int64(n), syserr.TranslateNetstackError(err).ToError() - } - return int64(n), nil - }}) - n += int(num) - if err == syserror.ErrWouldBlock && n > 0 { - // We got some data, so no need to return an error. - err = nil - } - return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err) + return numRead, 0, nil, 0, s.controlMessages(), nil } var msgLen int if isPacket { - msgLen = len(s.readView) - s.readView = nil + msgLen = numTotal } else { - msgLen = int(n) - s.readView.TrimFront(int(n)) - } - - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) + msgLen = numRead } var flags int - if msgLen > int(n) { + if msgLen > numRead { flags |= linux.MSG_TRUNC } + n := numRead if trunc { n = msgLen } cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) - return n, flags, addr, addrLen, cmsg, syserr.FromError(err) + return n, flags, addr, addrLen, cmsg, nil } func (s *socketOpsCommon) controlMessages() socket.ControlMessages { return socket.ControlMessages{ - IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, - HasTClass: s.readCM.HasTClass, - TClass: s.readCM.TClass, - HasIPPacketInfo: s.readCM.HasIPPacketInfo, - PacketInfo: s.readCM.PacketInfo, + IP: socket.IPControlMessages{ + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasInq: s.readCM.HasInq, + Inq: s.readCM.Inq, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + OriginalDstAddress: s.readCM.OriginalDstAddress, + SockErr: s.readCM.SockErr, }, } } @@ -2887,9 +2751,66 @@ func (s *socketOpsCommon) updateTimestamp() { } } +// dequeueErr is analogous to net/core/skbuff.c:sock_dequeue_err_skb(). +func (s *socketOpsCommon) dequeueErr() *tcpip.SockError { + so := s.Endpoint.SocketOptions() + err := so.DequeueErr() + if err == nil { + return nil + } + + // Update socket error to reflect ICMP errors in queue. + if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() { + so.SetLastError(nextErr.Err) + } else if err.ErrOrigin.IsICMPErr() { + so.SetLastError(nil) + } + return err +} + +// addrFamilyFromNetProto returns the address family identifier for the given +// network protocol. +func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int { + switch net { + case header.IPv4ProtocolNumber: + return linux.AF_INET + case header.IPv6ProtocolNumber: + return linux.AF_INET6 + default: + panic(fmt.Sprintf("invalid net proto for addr family inference: %d", net)) + } +} + +// recvErr handles MSG_ERRQUEUE for recvmsg(2). +// This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error(). +func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + sockErr := s.dequeueErr() + if sockErr == nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + } + + // The payload of the original packet that caused the error is passed as + // normal data via msg_iovec. -- recvmsg(2) + msgFlags := linux.MSG_ERRQUEUE + if int(dst.NumBytes()) < len(sockErr.Payload) { + msgFlags |= linux.MSG_TRUNC + } + n, err := dst.CopyOut(t, sockErr.Payload) + + // The original destination address of the datagram that caused the error is + // supplied via msg_name. -- recvmsg(2) + dstAddr, dstAddrLen := socket.ConvertAddress(addrFamilyFromNetProto(sockErr.NetProto), sockErr.Dst) + cmgs := socket.ControlMessages{IP: socket.NewIPControlMessages(s.family, tcpip.ControlMessages{SockErr: sockErr})} + return n, msgFlags, dstAddr, dstAddrLen, cmgs, syserr.FromError(err) +} + // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { + if flags&linux.MSG_ERRQUEUE != 0 { + return s.recvErr(t, dst) + } + trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -2965,7 +2886,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, family, err := AddressAndFamily(to) + addrBuf, family, err := socket.AddressAndFamily(to) if err != nil { return 0, err } @@ -3063,11 +2984,6 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy return 0, syserr.TranslateNetstackError(terr).ToError() } - // Add bytes removed from the endpoint but not yet sent to the caller. - s.readMu.Lock() - v += len(s.readView) - s.readMu.Unlock() - if v > math.MaxInt32 { v = math.MaxInt32 } @@ -3384,6 +3300,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { return rv } +func isTCPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP) +} + +func isUDPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP) +} + +func isICMPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6) +} + // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. func (s *socketOpsCommon) State() uint32 { @@ -3393,7 +3321,7 @@ func (s *socketOpsCommon) State() uint32 { } switch { - case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: + case isTCPSocket(s.skType, s.protocol): // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -3422,7 +3350,7 @@ func (s *socketOpsCommon) State() uint32 { // Internal or unknown state. return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + case isUDPSocket(s.skType, s.protocol): // UDP socket. switch udp.EndpointState(s.Endpoint.State()) { case udp.StateInitial, udp.StateBound, udp.StateClosed: @@ -3432,7 +3360,7 @@ func (s *socketOpsCommon) State() uint32 { default: return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + case isICMPSocket(s.skType, s.protocol): // TODO(b/112063468): Export states for ICMP sockets. case s.skType == linux.SOCK_RAW: // TODO(b/112063468): Export states for raw sockets. diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index b0d9e4d9e..b756bfca0 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -51,9 +51,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // NewVFS2 creates a new endpoint socket. func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } mnt := t.Kernel().SocketMount() @@ -191,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addrLen uint32 if peerAddr != nil { // Get address of the peer and write it to peer slice. - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index ead3b2b79..c847ff1c7 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go index 2a01143f6..0af805246 100644 --- a/pkg/sentry/socket/netstack/provider_vfs2.go +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index fa9ac9059..cc0fadeb5 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -324,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { 0, // Support Ip/FragCreates. } case *inet.StatSNMPICMP: - in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats - out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPICMP{ 0, // Icmp/InMsgs. - Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors. 0, // Icmp/InCsumErrors. in.DstUnreachable.Value(), // InDestUnreachs. in.TimeExceeded.Value(), // InTimeExcds. @@ -343,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { in.InfoRequest.Value(), // InAddrMasks. in.InfoReply.Value(), // InAddrMaskReps. 0, // Icmp/OutMsgs. - Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. - out.DstUnreachable.Value(), // OutDestUnreachs. - out.TimeExceeded.Value(), // OutTimeExcds. - out.ParamProblem.Value(), // OutParmProbs. - out.SrcQuench.Value(), // OutSrcQuenchs. - out.Redirect.Value(), // OutRedirects. - out.Echo.Value(), // OutEchos. - out.EchoReply.Value(), // OutEchoReps. - out.Timestamp.Value(), // OutTimestamps. - out.TimestampReply.Value(), // OutTimestampReps. - out.InfoRequest.Value(), // OutAddrMasks. - out.InfoReply.Value(), // OutAddrMaskReps. + Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. } case *inet.StatSNMPTCP: tcp := Metrics.TCP diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fd31479e5..97729dacc 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -18,6 +18,7 @@ package socket import ( + "bytes" "fmt" "sync/atomic" "syscall" @@ -35,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/usermem" ) @@ -42,7 +44,134 @@ import ( // control messages. type ControlMessages struct { Unix transport.ControlMessages - IP tcpip.ControlMessages + IP IPControlMessages +} + +// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format. +func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo { + var p linux.ControlMessageIPPacketInfo + p.NIC = int32(packetInfo.NIC) + copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) + copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + return p +} + +// errOriginToLinux maps tcpip socket origin to Linux socket origin constants. +func errOriginToLinux(origin tcpip.SockErrOrigin) uint8 { + switch origin { + case tcpip.SockExtErrorOriginNone: + return linux.SO_EE_ORIGIN_NONE + case tcpip.SockExtErrorOriginLocal: + return linux.SO_EE_ORIGIN_LOCAL + case tcpip.SockExtErrorOriginICMP: + return linux.SO_EE_ORIGIN_ICMP + case tcpip.SockExtErrorOriginICMP6: + return linux.SO_EE_ORIGIN_ICMP6 + default: + panic(fmt.Sprintf("unknown socket origin: %d", origin)) + } +} + +// sockErrCmsgToLinux converts SockError control message from tcpip format to +// Linux format. +func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg { + if sockErr == nil { + return nil + } + + ee := linux.SockExtendedErr{ + Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()), + Origin: errOriginToLinux(sockErr.ErrOrigin), + Type: sockErr.ErrType, + Code: sockErr.ErrCode, + Info: sockErr.ErrInfo, + } + + switch sockErr.NetProto { + case header.IPv4ProtocolNumber: + errMsg := &linux.SockErrCMsgIPv4{SockExtendedErr: ee} + if len(sockErr.Offender.Addr) > 0 { + addr, _ := ConvertAddress(linux.AF_INET, sockErr.Offender) + errMsg.Offender = *addr.(*linux.SockAddrInet) + } + return errMsg + case header.IPv6ProtocolNumber: + errMsg := &linux.SockErrCMsgIPv6{SockExtendedErr: ee} + if len(sockErr.Offender.Addr) > 0 { + addr, _ := ConvertAddress(linux.AF_INET6, sockErr.Offender) + errMsg.Offender = *addr.(*linux.SockAddrInet6) + } + return errMsg + default: + panic(fmt.Sprintf("invalid net proto for creating SockErrCMsg: %d", sockErr.NetProto)) + } +} + +// NewIPControlMessages converts the tcpip ControlMessgaes (which does not +// have Linux specific format) to Linux format. +func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages { + var orgDstAddr linux.SockAddr + if cmgs.HasOriginalDstAddress { + orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress) + } + return IPControlMessages{ + HasTimestamp: cmgs.HasTimestamp, + Timestamp: cmgs.Timestamp, + HasInq: cmgs.HasInq, + Inq: cmgs.Inq, + HasTOS: cmgs.HasTOS, + TOS: cmgs.TOS, + HasTClass: cmgs.HasTClass, + TClass: cmgs.TClass, + HasIPPacketInfo: cmgs.HasIPPacketInfo, + PacketInfo: packetInfoToLinux(cmgs.PacketInfo), + OriginalDstAddress: orgDstAddr, + SockErr: sockErrCmsgToLinux(cmgs.SockErr), + } +} + +// IPControlMessages contains socket control messages for IP sockets. +// This can contain Linux specific structures unlike tcpip.ControlMessages. +// +// +stateify savable +type IPControlMessages struct { + // HasTimestamp indicates whether Timestamp is valid/set. + HasTimestamp bool + + // Timestamp is the time (in ns) that the last packet used to create + // the read data was received. + Timestamp int64 + + // HasInq indicates whether Inq is valid/set. + HasInq bool + + // Inq is the number of bytes ready to be received. + Inq int32 + + // HasTOS indicates whether Tos is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS uint8 + + // HasTClass indicates whether TClass is valid/set. + HasTClass bool + + // TClass is the IPv6 traffic class of the associated packet. + TClass uint32 + + // HasIPPacketInfo indicates whether PacketInfo is set. + HasIPPacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + PacketInfo linux.ControlMessageIPPacketInfo + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress linux.SockAddr + + // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE). + SockErr linux.SockErrCMsg } // Release releases Unix domain socket credentials and rights. @@ -460,3 +589,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { panic(fmt.Sprintf("Unsupported socket family %v", family)) } } + +var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes() +var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes() +var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes() + +// Ntohs converts a 16-bit number from network byte order to host byte order. It +// assumes that the host is little endian. +func Ntohs(v uint16) uint16 { + return v<<8 | v>>8 +} + +// Htons converts a 16-bit number from host byte order to network byte order. It +// assumes that the host is little endian. +func Htons(v uint16) uint16 { + return Ntohs(v) +} + +// isLinkLocal determines if the given IPv6 address is link-local. This is the +// case when it has the fe80::/10 prefix. This check is used to determine when +// the NICID is relevant for a given IPv6 address. +func isLinkLocal(addr tcpip.Address) bool { + return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 +} + +// ConvertAddress converts the given address to a native format. +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { + switch family { + case linux.AF_UNIX: + var out linux.SockAddrUnix + out.Family = linux.AF_UNIX + l := len([]byte(addr.Addr)) + for i := 0; i < l; i++ { + out.Path[i] = int8(addr.Addr[i]) + } + + // Linux returns the used length of the address struct (including the + // null terminator) for filesystem paths. The Family field is 2 bytes. + // It is sometimes allowed to exclude the null terminator if the + // address length is the max. Abstract and empty paths always return + // the full exact length. + if l == 0 || out.Path[0] == 0 || l == len(out.Path) { + return &out, uint32(2 + l) + } + return &out, uint32(3 + l) + + case linux.AF_INET: + var out linux.SockAddrInet + copy(out.Addr[:], addr.Addr) + out.Family = linux.AF_INET + out.Port = Htons(addr.Port) + return &out, uint32(sockAddrInetSize) + + case linux.AF_INET6: + var out linux.SockAddrInet6 + if len(addr.Addr) == header.IPv4AddressSize { + // Copy address in v4-mapped format. + copy(out.Addr[12:], addr.Addr) + out.Addr[10] = 0xff + out.Addr[11] = 0xff + } else { + copy(out.Addr[:], addr.Addr) + } + out.Family = linux.AF_INET6 + out.Port = Htons(addr.Port) + if isLinkLocal(addr.Addr) { + out.Scope_id = uint32(addr.NIC) + } + return &out, uint32(sockAddrInet6Size) + + case linux.AF_PACKET: + // TODO(gvisor.dev/issue/173): Return protocol too. + var out linux.SockAddrLink + out.Family = linux.AF_PACKET + out.InterfaceIndex = int32(addr.NIC) + out.HardwareAddrLen = header.EthernetAddressSize + copy(out.HardwareAddr[:], addr.Addr) + return &out, uint32(sockAddrLinkSize) + + default: + return nil, 0 + } +} + +// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the +// netstack representation taking any addresses into account. +func BytesToIPAddress(addr []byte) tcpip.Address { + if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) { + return "" + } + return tcpip.Address(addr) +} + +// AddressAndFamily reads an sockaddr struct from the given address and +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, +// AF_INET6, and AF_PACKET addresses. +// +// AddressAndFamily returns an address and its family. +func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { + // Make sure we have at least 2 bytes for the address family. + if len(addr) < 2 { + return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument + } + + // Get the rest of the fields based on the address family. + switch family := usermem.ByteOrder.Uint16(addr); family { + case linux.AF_UNIX: + path := addr[2:] + if len(path) > linux.UnixPathMax { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + // Drop the terminating NUL (if one exists) and everything after + // it for filesystem (non-abstract) addresses. + if len(path) > 0 && path[0] != 0 { + if n := bytes.IndexByte(path[1:], 0); n >= 0 { + path = path[:n+1] + } + } + return tcpip.FullAddress{ + Addr: tcpip.Address(path), + }, family, nil + + case linux.AF_INET: + var a linux.SockAddrInet + if len(addr) < sockAddrInetSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + return out, family, nil + + case linux.AF_INET6: + var a linux.SockAddrInet6 + if len(addr) < sockAddrInet6Size { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + if isLinkLocal(out.Addr) { + out.NIC = tcpip.NICID(a.Scope_id) + } + return out, family, nil + + case linux.AF_PACKET: + var a linux.SockAddrLink + if len(addr) < sockAddrLinkSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) + if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/173): Return protocol too. + return tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + }, family, nil + + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, family, nil + + default: + return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported + } +} diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 4abea90cc..099a56281 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -178,10 +178,6 @@ type Endpoint interface { // SetSockOpt sets a socket option. SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool sets a socket option for simple cases when a value has - // the int type. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt sets a socket option for simple cases when a value has // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -189,10 +185,6 @@ type Endpoint interface { // GetSockOpt gets a socket option. GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool gets a socket option for simple cases when a return - // value has the int type. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) @@ -754,9 +746,6 @@ type baseEndpoint struct { // or may be used if the endpoint is connected. path string - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -848,17 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - e.linger = *v - e.Unlock() - } - return nil -} - -func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - log.Warningf("Unsupported socket option: %d", opt) return nil } @@ -872,11 +850,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - log.Warningf("Unsupported socket option: %d", opt) - return false, tcpip.ErrUnknownProtocolOption -} - func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -940,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - *o = e.linger - e.Unlock() - return nil - - default: - log.Warningf("Unsupported socket option: %T", opt) - return tcpip.ErrUnknownProtocolOption - } + log.Warningf("Unsupported socket option: %T", opt) + return tcpip.ErrUnknownProtocolOption } // LastError implements Endpoint.LastError. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index b32bb7ba8..c59297c80 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -136,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, family, err := netstack.AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { if err == syserr.ErrAddressFamilyNotSupported { err = syserr.ErrInvalidArgument @@ -169,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -181,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -255,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -647,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -682,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index eaf0b0d26..27f705bb2 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -172,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD index a920180d3..d36a64ffc 100644 --- a/pkg/sentry/strace/BUILD +++ b/pkg/sentry/strace/BUILD @@ -32,8 +32,8 @@ go_library( "//pkg/seccomp", "//pkg/sentry/arch", "//pkg/sentry/kernel", + "//pkg/sentry/socket", "//pkg/sentry/socket/netlink", - "//pkg/sentry/socket/netstack", "//pkg/sentry/syscalls/linux", "//pkg/usermem", ], diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index cc5f70cd4..d943a7cb1 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -23,8 +23,8 @@ import ( "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" - "gvisor.dev/gvisor/pkg/sentry/socket/netstack" slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux" "gvisor.dev/gvisor/pkg/usermem" ) @@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string { switch family { case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX: - fa, _, err := netstack.AddressAndFamily(b) + fa, _, err := socket.AddressAndFamily(b) if err != nil { return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err) } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index bb1f715e2..a72df62f6 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -272,7 +272,7 @@ var AMD64 = &kernel.SyscallTable{ 217: syscalls.Supported("getdents64", Getdents64), 218: syscalls.Supported("set_tid_address", SetTidAddress), 219: syscalls.Supported("restart_syscall", RestartSyscall), - 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), + 220: syscalls.PartiallySupported("semtimedop", Semtimedop, "A non-zero timeout argument isn't supported.", []string{"gvisor.dev/issue/137"}), 221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil), 222: syscalls.Supported("timer_create", TimerCreate), 223: syscalls.Supported("timer_settime", TimerSettime), @@ -619,8 +619,8 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), - 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil), + 192: syscalls.PartiallySupported("semtimedop", Semtimedop, "A non-zero timeout argument isn't supported.", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), 195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go index 0bf313a13..c2285f796 100644 --- a/pkg/sentry/syscalls/linux/sys_aio.go +++ b/pkg/sentry/syscalls/linux/sys_aio.go @@ -307,9 +307,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if !ok { return syserror.EINVAL } - if ready := ctx.Prepare(); !ready { - // Context is busy. - return syserror.EAGAIN + if err := ctx.Prepare(); err != nil { + return err } if eventFile != nil { diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go index 519066a47..c33571f43 100644 --- a/pkg/sentry/syscalls/linux/sys_file.go +++ b/pkg/sentry/syscalls/linux/sys_file.go @@ -175,6 +175,12 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint } } + file, err := d.Inode.GetFile(t, d, fileFlags) + if err != nil { + return syserror.ConvertIntr(err, syserror.ERESTARTSYS) + } + defer file.DecRef(t) + // Truncate is called when O_TRUNC is specified for any kind of // existing Dirent. Behavior is delegated to the entry's Truncate // implementation. @@ -184,12 +190,6 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint } } - file, err := d.Inode.GetFile(t, d, fileFlags) - if err != nil { - return syserror.ConvertIntr(err, syserror.ERESTARTSYS) - } - defer file.DecRef(t) - // Success. newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{ CloseOnExec: flags&linux.O_CLOEXEC != 0, @@ -646,7 +646,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil { return 0, nil, err } - fSetOwn(t, file, set) + fSetOwn(t, int(fd), file, set) return 0, nil, nil case linux.FIOGETOWN, linux.SIOCGPGRP: @@ -901,8 +901,8 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 { // // If who is positive, it represents a PID. If negative, it represents a PGID. // If the PID or PGID is invalid, the owner is silently unset. -func fSetOwn(t *kernel.Task, file *fs.File, who int32) error { - a := file.Async(fasync.New).(*fasync.FileAsync) +func fSetOwn(t *kernel.Task, fd int, file *fs.File, who int32) error { + a := file.Async(fasync.New(fd)).(*fasync.FileAsync) if who < 0 { // Check for overflow before flipping the sign. if who-1 > who { @@ -1049,7 +1049,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall case linux.F_GETOWN: return uintptr(fGetOwn(t, file)), nil, nil case linux.F_SETOWN: - return 0, nil, fSetOwn(t, file, args[2].Int()) + return 0, nil, fSetOwn(t, int(fd), file, args[2].Int()) case linux.F_GETOWN_EX: addr := args[2].Pointer() owner := fGetOwnEx(t, file) @@ -1062,7 +1062,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - a := file.Async(fasync.New).(*fasync.FileAsync) + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) switch owner.Type { case linux.F_OWNER_TID: task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID)) @@ -1111,6 +1111,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } n, err := sz.SetFifoSize(int64(args[2].Int())) return uintptr(n), nil, err + case linux.F_GETSIG: + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) + return uintptr(a.Signal()), nil, nil + case linux.F_SETSIG: + a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync) + return 0, nil, a.SetSignal(linux.Signal(args[2].Int())) default: // Everything else is not yet supported. return 0, nil, syserror.EINVAL diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index e383a0a87..d324461a3 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -48,6 +48,15 @@ func Semget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return uintptr(set.ID), nil, nil } +// Semtimedop handles: semop(int semid, struct sembuf *sops, size_t nsops, const struct timespec *timeout) +func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + // TODO(gvisor.dev/issue/137): A non-zero timeout isn't supported. + if args[3].Pointer() != 0 { + return 0, nil, syserror.ENOSYS + } + return Semop(t, args) +} + // Semop handles: semop(int semid, struct sembuf *sops, size_t nsops) func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { id := args[0].Int() @@ -146,11 +155,37 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal v, err := getNCnt(t, id, num) return uintptr(v), nil, err - case linux.IPC_INFO, - linux.SEM_INFO, - linux.SEM_STAT, - linux.SEM_STAT_ANY: + case linux.IPC_INFO: + buf := args[3].Pointer() + r := t.IPCNamespace().SemaphoreRegistry() + info := r.IPCInfo() + if _, err := info.CopyOut(t, buf); err != nil { + return 0, nil, err + } + return uintptr(r.HighestIndex()), nil, nil + case linux.SEM_INFO: + buf := args[3].Pointer() + r := t.IPCNamespace().SemaphoreRegistry() + info := r.SemInfo() + if _, err := info.CopyOut(t, buf); err != nil { + return 0, nil, err + } + return uintptr(r.HighestIndex()), nil, nil + + case linux.SEM_STAT: + arg := args[3].Pointer() + // id is an index in SEM_STAT. + semid, ds, err := semStat(t, id) + if err != nil { + return 0, nil, err + } + if _, err := ds.CopyOut(t, arg); err != nil { + return 0, nil, err + } + return uintptr(semid), nil, err + + case linux.SEM_STAT_ANY: t.Kernel().EmitUnimplementedEvent(t) fallthrough @@ -195,6 +230,17 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { return set.GetStat(creds) } +func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByIndex(index) + if set == nil { + return 0, nil, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + ds, err := set.GetStat(creds) + return set.ID, ds, err +} + func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index e748d33d8..d639c9bf7 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -88,8 +88,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(target.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) + info.SetPID(int32(target.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow())) if err := target.SendGroupSignal(info); err != syserror.ESRCH { return 0, nil, err } @@ -127,8 +127,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(tg.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) + info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) err := tg.SendSignal(info) if err == syserror.ESRCH { // ESRCH is ignored because it means the task @@ -171,8 +171,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC Signo: int32(sig), Code: arch.SignalInfoUser, } - info.SetPid(int32(tg.PIDNamespace().IDOfTask(t))) - info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) + info.SetPID(int32(tg.PIDNamespace().IDOfTask(t))) + info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow())) // See note above regarding ESRCH race above. if err := tg.SendSignal(info); err != syserror.ESRCH { lastErr = err @@ -189,8 +189,8 @@ func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalI Signo: int32(sig), Code: arch.SignalInfoTkill, } - info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) - info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) + info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup()))) + info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow())) return info } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 9cd052c3d..fe45225c1 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -749,11 +749,6 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i return 0, err } - // FIXME(b/63594852): Pretend we have an empty error queue. - if flags&linux.MSG_ERRQUEUE != 0 { - return 0, syserror.EAGAIN - } - // Fast path when no control message nor name buffers are provided. if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) @@ -1035,7 +1030,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme return 0, err } - controlMessages, err := control.Parse(t, s, controlData) + controlMessages, err := control.Parse(t, s, controlData, t.Arch().Width()) if err != nil { return 0, err } diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go index 983f8d396..8e7ac0ffe 100644 --- a/pkg/sentry/syscalls/linux/sys_thread.go +++ b/pkg/sentry/syscalls/linux/sys_thread.go @@ -413,8 +413,8 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal si := arch.SignalInfo{ Signo: int32(linux.SIGCHLD), } - si.SetPid(int32(wr.TID)) - si.SetUid(int32(wr.UID)) + si.SetPID(int32(wr.TID)) + si.SetUID(int32(wr.UID)) // TODO(b/73541790): convert kernel.ExitStatus to functions and make // WaitResult.Status a linux.WaitStatus. s := syscall.WaitStatus(wr.Status) diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index 6d0a38330..1365a5a62 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -130,9 +130,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user if !ok { return syserror.EINVAL } - if ready := aioCtx.Prepare(); !ready { - // Context is busy. - return syserror.EAGAIN + if err := aioCtx.Prepare(); err != nil { + return err } if eventFD != nil { diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go index 36e89700e..7dd9ef857 100644 --- a/pkg/sentry/syscalls/linux/vfs2/fd.go +++ b/pkg/sentry/syscalls/linux/vfs2/fd.go @@ -165,7 +165,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ownerType = linux.F_OWNER_PGRP who = -who } - return 0, nil, setAsyncOwner(t, file, ownerType, who) + return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who) case linux.F_GETOWN_EX: owner, hasOwner := getAsyncOwner(t, file) if !hasOwner { @@ -179,7 +179,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall if err != nil { return 0, nil, err } - return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID) + return 0, nil, setAsyncOwner(t, int(fd), file, owner.Type, owner.PID) case linux.F_SETPIPE_SZ: pipefile, ok := file.Impl().(*pipe.VFSPipeFD) if !ok { @@ -207,6 +207,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall return 0, nil, err case linux.F_SETLK, linux.F_SETLKW: return 0, nil, posixLock(t, args, file, cmd) + case linux.F_GETSIG: + a := file.AsyncHandler() + if a == nil { + // Default behavior aka SIGIO. + return 0, nil, nil + } + return uintptr(a.(*fasync.FileAsync).Signal()), nil, nil + case linux.F_SETSIG: + a := file.SetAsyncHandler(fasync.NewVFS2(int(fd))).(*fasync.FileAsync) + return 0, nil, a.SetSignal(linux.Signal(args[2].Int())) default: // Everything else is not yet supported. return 0, nil, syserror.EINVAL @@ -241,7 +251,7 @@ func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwne } } -func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error { +func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, pid int32) error { switch ownerType { case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP: // Acceptable type. @@ -249,7 +259,7 @@ func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32 return syserror.EINVAL } - a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync) + a := file.SetAsyncHandler(fasync.NewVFS2(fd)).(*fasync.FileAsync) if pid == 0 { a.ClearOwner() return nil diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go index 2806c3f6f..20c264fef 100644 --- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go +++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go @@ -100,7 +100,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall ownerType = linux.F_OWNER_PGRP who = -who } - return 0, nil, setAsyncOwner(t, file, ownerType, who) + return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who) } ret, err := file.Ioctl(t, t.MemoryManager(), args) diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go index ee38fdca0..6986e39fe 100644 --- a/pkg/sentry/syscalls/linux/vfs2/pipe.go +++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go @@ -42,7 +42,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { return syserror.EINVAL } - r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) + r, w, err := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK)) + if err != nil { + return err + } defer r.DecRef(t) defer w.DecRef(t) diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go index 7b33b3f59..f5795b4a8 100644 --- a/pkg/sentry/syscalls/linux/vfs2/socket.go +++ b/pkg/sentry/syscalls/linux/vfs2/socket.go @@ -752,11 +752,6 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla return 0, err } - // FIXME(b/63594852): Pretend we have an empty error queue. - if flags&linux.MSG_ERRQUEUE != 0 { - return 0, syserror.EAGAIN - } - // Fast path when no control message nor name buffers are provided. if msg.ControlLen == 0 && msg.NameLen == 0 { n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0) @@ -1038,7 +1033,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio return 0, err } - controlMessages, err := control.Parse(t, s, controlData) + controlMessages, err := control.Parse(t, s, controlData, t.Arch().Width()) if err != nil { return 0, err } diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index a98aac52b..072655fe8 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -204,8 +204,8 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin file.EventRegister(&epi.waiter, wmask) // Check if the file is already ready. - if file.Readiness(wmask)&wmask != 0 { - epi.Callback(nil) + if m := file.Readiness(wmask) & wmask; m != 0 { + epi.Callback(nil, m) } // Add epi to file.epolls so that it is removed when the last @@ -274,8 +274,8 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event file.EventRegister(&epi.waiter, wmask) // Check if the file is already ready with the new mask. - if file.Readiness(wmask)&wmask != 0 { - epi.Callback(nil) + if m := file.Readiness(wmask) & wmask; m != 0 { + epi.Callback(nil, m) } return nil @@ -311,7 +311,7 @@ func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error } // Callback implements waiter.EntryCallback.Callback. -func (epi *epollInterest) Callback(*waiter.Entry) { +func (epi *epollInterest) Callback(*waiter.Entry, waiter.EventMask) { newReady := false epi.epoll.mu.Lock() if !epi.ready { diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index f9e39a94c..5321ac80a 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -15,6 +15,7 @@ package vfs import ( + "io" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -43,7 +44,7 @@ import ( type FileDescription struct { FileDescriptionRefs - // flagsMu protects statusFlags and asyncHandler below. + // flagsMu protects `statusFlags`, `saved`, and `asyncHandler` below. flagsMu sync.Mutex `state:"nosave"` // statusFlags contains status flags, "initialized by open(2) and possibly @@ -52,6 +53,11 @@ type FileDescription struct { // access to asyncHandler. statusFlags uint32 + // saved is true after beforeSave is called. This is used to prevent + // double-unregistration of asyncHandler. This does not work properly for + // save-resume, which is not currently supported in gVisor (see b/26588733). + saved bool `state:"nosave"` + // asyncHandler handles O_ASYNC signal generation. It is set with the // F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must // also be set by fcntl(2). @@ -184,7 +190,7 @@ func (fd *FileDescription) DecRef(ctx context.Context) { } fd.vd.DecRef(ctx) fd.flagsMu.Lock() - if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + if !fd.saved && fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { fd.asyncHandler.Unregister(fd) } fd.asyncHandler = nil @@ -834,44 +840,27 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn return fd.asyncHandler } -// FileReadWriteSeeker is a helper struct to pass a FileDescription as -// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc. -type FileReadWriteSeeker struct { - FD *FileDescription - Ctx context.Context - ROpts ReadOptions - WOpts WriteOptions -} - -// ReadAt implements io.ReaderAt.ReadAt. -func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts) - return int(n), err -} - -// Read implements io.ReadWriteSeeker.Read. -func (f *FileReadWriteSeeker) Read(p []byte) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.Read(f.Ctx, dst, f.ROpts) - return int(n), err -} - -// Seek implements io.ReadWriteSeeker.Seek. -func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) { - return f.FD.Seek(f.Ctx, offset, int32(whence)) -} - -// WriteAt implements io.WriterAt.WriteAt. -func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) { - dst := usermem.BytesIOSequence(p) - n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts) - return int(n), err -} - -// Write implements io.ReadWriteSeeker.Write. -func (f *FileReadWriteSeeker) Write(p []byte) (int, error) { - buf := usermem.BytesIOSequence(p) - n, err := f.FD.Write(f.Ctx, buf, f.WOpts) - return int(n), err +// CopyRegularFileData copies data from srcFD to dstFD until reading from srcFD +// returns EOF or an error. It returns the number of bytes copied. +func CopyRegularFileData(ctx context.Context, dstFD, srcFD *FileDescription) (int64, error) { + done := int64(0) + buf := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size + for { + readN, readErr := srcFD.Read(ctx, buf, ReadOptions{}) + if readErr != nil && readErr != io.EOF { + return done, readErr + } + src := buf.TakeFirst64(readN) + for src.NumBytes() != 0 { + writeN, writeErr := dstFD.Write(ctx, src, WriteOptions{}) + done += writeN + src = src.DropFirst64(writeN) + if writeErr != nil { + return done, writeErr + } + } + if readErr == io.EOF { + return done, nil + } + } } diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index cb48c37a1..0df023713 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// +build go1.12 -// +build !go1.17 - -// Check go:linkname function signatures when updating Go version. - package vfs import ( @@ -41,6 +36,15 @@ type mountKey struct { point unsafe.Pointer // *Dentry } +var ( + mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil)) + mountKeySeed = sync.RandUintptr() +) + +func (k *mountKey) hash() uintptr { + return mountKeyHasher(gohacks.Noescape(unsafe.Pointer(k)), mountKeySeed) +} + func (mnt *Mount) parent() *Mount { return (*Mount)(atomic.LoadPointer(&mnt.key.parent)) } @@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry { } } -func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } - // Invariant: mnt.key.parent == nil. vd.Ok(). func (mnt *Mount) setKey(vd VirtualDentry) { atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount)) atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry)) } -func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } - // mountTable maps (mount parent, mount point) pairs to mounts. It supports // efficient concurrent lookup, even in the presence of concurrent mutators // (provided mutation is sufficiently uncommon). // // mountTable.Init() must be called on new mountTables before use. -// -// +stateify savable type mountTable struct { // mountTable is implemented as a seqcount-protected hash table that // resolves collisions with linear probing, featuring Robin Hood insertion @@ -84,8 +82,7 @@ type mountTable struct { // intrinsics and inline assembly, limiting the performance of this // approach.) - seq sync.SeqCount `state:"nosave"` - seed uint32 // for hashing keys + seq sync.SeqCount `state:"nosave"` // size holds both length (number of elements) and capacity (number of // slots): capacity is stored as its base-2 log (referred to as order) in @@ -150,7 +147,6 @@ func init() { // Init must be called exactly once on each mountTable before use. func (mt *mountTable) Init() { - mt.seed = rand32() mt.size = mtInitOrder mt.slots = newMountTableSlots(mtInitCap) } @@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer { // Lookup may be called even if there are concurrent mutators of mt. func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount { key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)} - hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes) + hash := key.hash() loop: for { @@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must not already contain a Mount with the same mount point and parent. func (mt *mountTable) insertSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() // We're under the maximum load factor if: // @@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) { // * mt.seq must be in a writer critical section. // * mt must contain mount. func (mt *mountTable) removeSeqed(mount *Mount) { - hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes) + hash := mount.key.hash() tcap := uintptr(1) << (mt.size & mtSizeOrderMask) mask := tcap - 1 slots := mt.slots @@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) { off = (off + mountSlotBytes) & offmask } } - -//go:linkname memhash runtime.memhash -func memhash(p unsafe.Pointer, seed, s uintptr) uintptr - -//go:linkname rand32 runtime.fastrand -func rand32() uint32 diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go index 7723ed643..8998a82dd 100644 --- a/pkg/sentry/vfs/save_restore.go +++ b/pkg/sentry/vfs/save_restore.go @@ -18,8 +18,10 @@ import ( "fmt" "sync/atomic" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/waiter" ) // FilesystemImplSaveRestoreExtension is an optional extension to @@ -99,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount { return mounts } +// saveKey is called by stateify. +func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() } + // loadMounts is called by stateify. func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { if mounts == nil { @@ -110,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { } } +// loadKey is called by stateify. +func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) } + func (mnt *Mount) afterLoad() { if atomic.LoadInt64(&mnt.refs) != 0 { refsvfs2.Register(mnt) @@ -120,5 +128,20 @@ func (mnt *Mount) afterLoad() { func (epi *epollInterest) afterLoad() { // Mark all epollInterests as ready after restore so that the next call to // EpollInstance.ReadEvents() rechecks their readiness. - epi.Callback(nil) + epi.Callback(nil, waiter.EventMaskFromLinux(epi.mask)) +} + +// beforeSave is called by stateify. +func (fd *FileDescription) beforeSave() { + fd.saved = true + if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + fd.asyncHandler.Unregister(fd) + } +} + +// afterLoad is called by stateify. +func (fd *FileDescription) afterLoad() { + if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { + fd.asyncHandler.Register(fd) + } } |