summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry')
-rw-r--r--pkg/sentry/arch/arch.go15
-rw-r--r--pkg/sentry/arch/arch_state_x86.go17
-rw-r--r--pkg/sentry/arch/signal.go39
-rw-r--r--pkg/sentry/control/pprof.go287
-rw-r--r--pkg/sentry/control/state.go1
-rw-r--r--pkg/sentry/fdimport/fdimport.go1
-rw-r--r--pkg/sentry/fs/copy_up.go13
-rw-r--r--pkg/sentry/fs/copy_up_test.go2
-rw-r--r--pkg/sentry/fs/filetest/filetest.go4
-rw-r--r--pkg/sentry/fs/fs.go2
-rw-r--r--pkg/sentry/fs/gofer/attr.go2
-rw-r--r--pkg/sentry/fs/gofer/inode.go3
-rw-r--r--pkg/sentry/fs/host/inode.go4
-rw-r--r--pkg/sentry/fs/proc/sys.go1
-rw-r--r--pkg/sentry/fs/ramfs/socket.go3
-rw-r--r--pkg/sentry/fs/tmpfs/inode_file.go4
-rw-r--r--pkg/sentry/fsimpl/ext/inode.go4
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_control.go6
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_test.go10
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go5
-rw-r--r--pkg/sentry/fsimpl/fuse/directory.go6
-rw-r--r--pkg/sentry/fsimpl/fuse/file.go8
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go45
-rw-r--r--pkg/sentry/fsimpl/fuse/read_write.go12
-rw-r--r--pkg/sentry/fsimpl/fuse/request_response.go5
-rw-r--r--pkg/sentry/fsimpl/host/host.go5
-rw-r--r--pkg/sentry/fsimpl/overlay/copy_up.go23
-rw-r--r--pkg/sentry/fsimpl/overlay/regular_file.go4
-rw-r--r--pkg/sentry/fsimpl/pipefs/pipefs.go4
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go17
-rw-r--r--pkg/sentry/fsimpl/proc/tasks_sys.go1
-rw-r--r--pkg/sentry/fsimpl/signalfd/signalfd.go5
-rw-r--r--pkg/sentry/fsimpl/verity/filesystem.go6
-rw-r--r--pkg/sentry/fsimpl/verity/verity.go94
-rw-r--r--pkg/sentry/fsimpl/verity/verity_test.go232
-rw-r--r--pkg/sentry/kernel/epoll/epoll.go7
-rw-r--r--pkg/sentry/kernel/fasync/BUILD2
-rw-r--r--pkg/sentry/kernel/fasync/fasync.go96
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go11
-rw-r--r--pkg/sentry/kernel/kernel.go50
-rw-r--r--pkg/sentry/kernel/pipe/vfs.go28
-rw-r--r--pkg/sentry/kernel/ptrace.go4
-rw-r--r--pkg/sentry/kernel/semaphore/semaphore.go115
-rw-r--r--pkg/sentry/kernel/shm/BUILD3
-rw-r--r--pkg/sentry/kernel/signal.go4
-rw-r--r--pkg/sentry/kernel/signalfd/signalfd.go4
-rw-r--r--pkg/sentry/kernel/task_exit.go8
-rw-r--r--pkg/sentry/kernel/task_signals.go16
-rw-r--r--pkg/sentry/memmap/memmap.go2
-rw-r--r--pkg/sentry/mm/aio_context.go79
-rw-r--r--pkg/sentry/mm/aio_context_state.go4
-rw-r--r--pkg/sentry/mm/lifecycle.go2
-rw-r--r--pkg/sentry/mm/mm_test.go43
-rw-r--r--pkg/sentry/pgalloc/pgalloc.go62
-rw-r--r--pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go14
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64.go7
-rw-r--r--pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go40
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go5
-rw-r--r--pkg/sentry/platform/kvm/kvm_arm64.go9
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go1
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go2
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64_unsafe.go4
-rw-r--r--pkg/sentry/platform/ptrace/ptrace.go4
-rw-r--r--pkg/sentry/platform/ptrace/subprocess.go2
-rw-r--r--pkg/sentry/platform/ring0/BUILD11
-rw-r--r--pkg/sentry/platform/ring0/entry_arm64.s165
-rw-r--r--pkg/sentry/platform/ring0/gen_offsets/BUILD3
-rw-r--r--pkg/sentry/platform/ring0/kernel_arm64.go8
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.go16
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64.s17
-rw-r--r--pkg/sentry/platform/ring0/lib_arm64_unsafe.go108
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/control/BUILD17
-rw-r--r--pkg/sentry/socket/control/control.go125
-rw-r--r--pkg/sentry/socket/control/control_test.go59
-rw-r--r--pkg/sentry/socket/hostinet/socket.go179
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go1036
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go6
-rw-r--r--pkg/sentry/socket/netstack/provider.go2
-rw-r--r--pkg/sentry/socket/netstack/provider_vfs2.go2
-rw-r--r--pkg/sentry/socket/netstack/stack.go30
-rw-r--r--pkg/sentry/socket/socket.go304
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go40
-rw-r--r--pkg/sentry/socket/unix/unix.go12
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go2
-rw-r--r--pkg/sentry/strace/BUILD2
-rw-r--r--pkg/sentry/strace/socket.go4
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_aio.go5
-rw-r--r--pkg/sentry/syscalls/linux/sys_file.go28
-rw-r--r--pkg/sentry/syscalls/linux/sys_sem.go54
-rw-r--r--pkg/sentry/syscalls/linux/sys_signal.go16
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go7
-rw-r--r--pkg/sentry/syscalls/linux/sys_thread.go4
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/aio.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/fd.go18
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/ioctl.go2
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/pipe.go5
-rw-r--r--pkg/sentry/syscalls/linux/vfs2/socket.go7
-rw-r--r--pkg/sentry/vfs/epoll.go10
-rw-r--r--pkg/sentry/vfs/file_description.go73
-rw-r--r--pkg/sentry/vfs/mount_unsafe.go36
-rw-r--r--pkg/sentry/vfs/save_restore.go25
104 files changed, 2374 insertions, 1597 deletions
diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go
index d75d665ae..dd2effdf9 100644
--- a/pkg/sentry/arch/arch.go
+++ b/pkg/sentry/arch/arch.go
@@ -365,3 +365,18 @@ func (a SyscallArgument) SizeT() uint {
func (a SyscallArgument) ModeT() uint {
return uint(uint16(a.Value))
}
+
+// ErrFloatingPoint indicates a failed restore due to unusable floating point
+// state.
+type ErrFloatingPoint struct {
+ // supported is the supported floating point state.
+ supported uint64
+
+ // saved is the saved floating point state.
+ saved uint64
+}
+
+// Error returns a sensible description of the restore error.
+func (e ErrFloatingPoint) Error() string {
+ return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
+}
diff --git a/pkg/sentry/arch/arch_state_x86.go b/pkg/sentry/arch/arch_state_x86.go
index 19ce99d25..840e53d33 100644
--- a/pkg/sentry/arch/arch_state_x86.go
+++ b/pkg/sentry/arch/arch_state_x86.go
@@ -17,27 +17,10 @@
package arch
import (
- "fmt"
-
"gvisor.dev/gvisor/pkg/cpuid"
"gvisor.dev/gvisor/pkg/usermem"
)
-// ErrFloatingPoint indicates a failed restore due to unusable floating point
-// state.
-type ErrFloatingPoint struct {
- // supported is the supported floating point state.
- supported uint64
-
- // saved is the saved floating point state.
- saved uint64
-}
-
-// Error returns a sensible description of the restore error.
-func (e ErrFloatingPoint) Error() string {
- return fmt.Sprintf("floating point state contains unsupported features; supported: %#x saved: %#x", e.supported, e.saved)
-}
-
// XSTATE_BV does not exist if FXSAVE is used, but FXSAVE implicitly saves x87
// and SSE state, so this is the equivalent XSTATE_BV value.
const fxsaveBV uint64 = cpuid.XSAVEFeatureX87 | cpuid.XSAVEFeatureSSE
diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go
index c9fb55d00..35d2e07c3 100644
--- a/pkg/sentry/arch/signal.go
+++ b/pkg/sentry/arch/signal.go
@@ -152,23 +152,23 @@ func (s *SignalInfo) FixSignalCodeForUser() {
}
}
-// Pid returns the si_pid field.
-func (s *SignalInfo) Pid() int32 {
+// PID returns the si_pid field.
+func (s *SignalInfo) PID() int32 {
return int32(usermem.ByteOrder.Uint32(s.Fields[0:4]))
}
-// SetPid mutates the si_pid field.
-func (s *SignalInfo) SetPid(val int32) {
+// SetPID mutates the si_pid field.
+func (s *SignalInfo) SetPID(val int32) {
usermem.ByteOrder.PutUint32(s.Fields[0:4], uint32(val))
}
-// Uid returns the si_uid field.
-func (s *SignalInfo) Uid() int32 {
+// UID returns the si_uid field.
+func (s *SignalInfo) UID() int32 {
return int32(usermem.ByteOrder.Uint32(s.Fields[4:8]))
}
-// SetUid mutates the si_uid field.
-func (s *SignalInfo) SetUid(val int32) {
+// SetUID mutates the si_uid field.
+func (s *SignalInfo) SetUID(val int32) {
usermem.ByteOrder.PutUint32(s.Fields[4:8], uint32(val))
}
@@ -251,3 +251,26 @@ func (s *SignalInfo) Arch() uint32 {
func (s *SignalInfo) SetArch(val uint32) {
usermem.ByteOrder.PutUint32(s.Fields[12:16], val)
}
+
+// Band returns the si_band field.
+func (s *SignalInfo) Band() int64 {
+ return int64(usermem.ByteOrder.Uint64(s.Fields[0:8]))
+}
+
+// SetBand mutates the si_band field.
+func (s *SignalInfo) SetBand(val int64) {
+ // Note: this assumes the platform uses `long` as `__ARCH_SI_BAND_T`.
+ // On some platforms, which gVisor doesn't support, `__ARCH_SI_BAND_T` is
+ // `int`. See siginfo.h.
+ usermem.ByteOrder.PutUint64(s.Fields[0:8], uint64(val))
+}
+
+// FD returns the si_fd field.
+func (s *SignalInfo) FD() uint32 {
+ return usermem.ByteOrder.Uint32(s.Fields[8:12])
+}
+
+// SetFD mutates the si_fd field.
+func (s *SignalInfo) SetFD(val uint32) {
+ usermem.ByteOrder.PutUint32(s.Fields[8:12], val)
+}
diff --git a/pkg/sentry/control/pprof.go b/pkg/sentry/control/pprof.go
index 2bf3c45e1..2f3664c57 100644
--- a/pkg/sentry/control/pprof.go
+++ b/pkg/sentry/control/pprof.go
@@ -15,10 +15,10 @@
package control
import (
- "errors"
"runtime"
"runtime/pprof"
"runtime/trace"
+ "time"
"gvisor.dev/gvisor/pkg/fd"
"gvisor.dev/gvisor/pkg/sentry/kernel"
@@ -26,184 +26,263 @@ import (
"gvisor.dev/gvisor/pkg/urpc"
)
-var errNoOutput = errors.New("no output writer provided")
+// Profile includes profile-related RPC stubs. It provides a way to
+// control the built-in runtime profiling facilities.
+//
+// The profile object must be instantiated via NewProfile.
+type Profile struct {
+ // kernel is the kernel under profile. It's immutable.
+ kernel *kernel.Kernel
-// ProfileOpts contains options for the StartCPUProfile/Goroutine RPC call.
-type ProfileOpts struct {
- // File is the filesystem path for the profile.
- File string `json:"path"`
+ // cpuMu protects CPU profiling.
+ cpuMu sync.Mutex
- // FilePayload is the destination for the profiling output.
- urpc.FilePayload
+ // blockMu protects block profiling.
+ blockMu sync.Mutex
+
+ // mutexMu protects mutex profiling.
+ mutexMu sync.Mutex
+
+ // traceMu protects trace profiling.
+ traceMu sync.Mutex
+
+ // done is closed when profiling is done.
+ done chan struct{}
}
-// Profile includes profile-related RPC stubs. It provides a way to
-// control the built-in pprof facility in sentry via sentryctl.
-//
-// The following options to sentryctl are added:
-//
-// - collect CPU profile on-demand.
-// sentryctl -pid <pid> pprof-cpu-start
-// sentryctl -pid <pid> pprof-cpu-stop
-//
-// - dump out the stack trace of current go routines.
-// sentryctl -pid <pid> pprof-goroutine
-type Profile struct {
- // Kernel is the kernel under profile. It's immutable.
- Kernel *kernel.Kernel
+// NewProfile returns a new Profile object.
+func NewProfile(k *kernel.Kernel) *Profile {
+ return &Profile{
+ kernel: k,
+ done: make(chan struct{}),
+ }
+}
- // mu protects the fields below.
- mu sync.Mutex
+// Stop implements urpc.Stopper.Stop.
+func (p *Profile) Stop() {
+ close(p.done)
+}
- // cpuFile is the current CPU profile output file.
- cpuFile *fd.FD
+// CPUProfileOpts contains options specifically for CPU profiles.
+type CPUProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
- // traceFile is the current execution trace output file.
- traceFile *fd.FD
+ // Duration is the duration of the profile.
+ Duration time.Duration `json:"duration"`
}
-// StartCPUProfile is an RPC stub which starts recording the CPU profile in a
-// file.
-func (p *Profile) StartCPUProfile(o *ProfileOpts, _ *struct{}) error {
+// CPU is an RPC stub which collects a CPU profile.
+func (p *Profile) CPU(o *CPUProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
- output, err := fd.NewFromFile(o.FilePayload.Files[0])
- if err != nil {
- return err
- }
+ output := o.FilePayload.Files[0]
+ defer output.Close()
- p.mu.Lock()
- defer p.mu.Unlock()
+ p.cpuMu.Lock()
+ defer p.cpuMu.Unlock()
// Returns an error if profiling is already started.
if err := pprof.StartCPUProfile(output); err != nil {
- output.Close()
return err
}
+ defer pprof.StopCPUProfile()
+
+ // Collect the profile.
+ select {
+ case <-time.After(o.Duration):
+ case <-p.done:
+ }
- p.cpuFile = output
return nil
}
-// StopCPUProfile is an RPC stub which stops the CPU profiling and flush out the
-// profile data. It takes no argument.
-func (p *Profile) StopCPUProfile(_, _ *struct{}) error {
- p.mu.Lock()
- defer p.mu.Unlock()
-
- if p.cpuFile == nil {
- return errors.New("CPU profiling not started")
- }
+// HeapProfileOpts contains options specifically for heap profiles.
+type HeapProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
- pprof.StopCPUProfile()
- p.cpuFile.Close()
- p.cpuFile = nil
- return nil
+ // Delay is the sleep time, similar to Duration. This may
+ // not affect the data collected however, as the heap will
+ // contain only the memory associated with the last alloc.
+ Delay time.Duration `json:"delay"`
}
-// HeapProfile generates a heap profile for the sentry.
-func (p *Profile) HeapProfile(o *ProfileOpts, _ *struct{}) error {
+// Heap generates a heap profile.
+func (p *Profile) Heap(o *HeapProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
+
output := o.FilePayload.Files[0]
defer output.Close()
- runtime.GC() // Get up-to-date statistics.
- if err := pprof.WriteHeapProfile(output); err != nil {
- return err
+
+ // Wait for the given delay.
+ select {
+ case <-time.After(o.Delay):
+ case <-p.done:
}
- return nil
+
+ // Get up-to-date statistics.
+ runtime.GC()
+
+ // Write the given profile.
+ return pprof.WriteHeapProfile(output)
+}
+
+// GoroutineProfileOpts contains options specifically for goroutine profiles.
+type GoroutineProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
}
-// GoroutineProfile is an RPC stub which dumps out the stack trace for all
-// running goroutines.
-func (p *Profile) GoroutineProfile(o *ProfileOpts, _ *struct{}) error {
+// Goroutine dumps out the stack trace for all running goroutines.
+func (p *Profile) Goroutine(o *GoroutineProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
+
output := o.FilePayload.Files[0]
defer output.Close()
- if err := pprof.Lookup("goroutine").WriteTo(output, 2); err != nil {
- return err
- }
- return nil
+
+ return pprof.Lookup("goroutine").WriteTo(output, 2)
+}
+
+// BlockProfileOpts contains options specifically for block profiles.
+type BlockProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
+
+ // Duration is the duration of the profile.
+ Duration time.Duration `json:"duration"`
+
+ // Rate is the block profile rate.
+ Rate int `json:"rate"`
}
-// BlockProfile is an RPC stub which dumps out the stack trace that led to
-// blocking on synchronization primitives.
-func (p *Profile) BlockProfile(o *ProfileOpts, _ *struct{}) error {
+// Block dumps a blocking profile.
+func (p *Profile) Block(o *BlockProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
+
output := o.FilePayload.Files[0]
defer output.Close()
- if err := pprof.Lookup("block").WriteTo(output, 0); err != nil {
- return err
+
+ p.blockMu.Lock()
+ defer p.blockMu.Unlock()
+
+ // Always set the rate. We then wait to collect a profile at this rate,
+ // and disable when we're done. Note that the default rate here is 10,
+ // which asks the runtime to sample, on average, one blocking event per
+ // 10 nanoseconds spent blocked (it is not a percentage). Since
+ // these events should not be super frequent, we expect this to achieve
+ // a reasonable balance between collecting the data we need and imposing
+ // a high performance cost (e.g. skewing even the CPU profile).
+ rate := 10
+ if o.Rate != 0 {
+ rate = o.Rate
}
- return nil
+ runtime.SetBlockProfileRate(rate)
+ defer runtime.SetBlockProfileRate(0)
+
+ // Collect the profile.
+ select {
+ case <-time.After(o.Duration):
+ case <-p.done:
+ }
+
+ return pprof.Lookup("block").WriteTo(output, 0)
+}
+
+// MutexProfileOpts contains options specifically for mutex profiles.
+type MutexProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
+
+ // Duration is the duration of the profile.
+ Duration time.Duration `json:"duration"`
+
+ // Fraction is the mutex profile fraction.
+ Fraction int `json:"fraction"`
}
-// MutexProfile is an RPC stub which dumps out the stack trace of holders of
-// contended mutexes.
-func (p *Profile) MutexProfile(o *ProfileOpts, _ *struct{}) error {
+// Mutex dumps a mutex profile.
+func (p *Profile) Mutex(o *MutexProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
+
output := o.FilePayload.Files[0]
defer output.Close()
- if err := pprof.Lookup("mutex").WriteTo(output, 0); err != nil {
- return err
+
+ p.mutexMu.Lock()
+ defer p.mutexMu.Unlock()
+
+ // Always set the fraction. Like the block rate above, we use
+ // a default rate of 10% for the same reasons.
+ fraction := 10
+ if o.Fraction != 0 {
+ fraction = o.Fraction
}
- return nil
+ runtime.SetMutexProfileFraction(fraction)
+ defer runtime.SetMutexProfileFraction(0)
+
+ // Collect the profile.
+ select {
+ case <-time.After(o.Duration):
+ case <-p.done:
+ }
+
+ return pprof.Lookup("mutex").WriteTo(output, 0)
}
-// StartTrace is an RPC stub which starts collection of an execution trace.
-func (p *Profile) StartTrace(o *ProfileOpts, _ *struct{}) error {
+// TraceProfileOpts contains options specifically for traces.
+type TraceProfileOpts struct {
+ // FilePayload is the destination for the profiling output.
+ urpc.FilePayload
+
+ // Duration is the duration of the profile.
+ Duration time.Duration `json:"duration"`
+}
+
+// Trace is an RPC stub which collects an execution trace.
+func (p *Profile) Trace(o *TraceProfileOpts, _ *struct{}) error {
if len(o.FilePayload.Files) < 1 {
- return errNoOutput
+ return nil // Allowed.
}
output, err := fd.NewFromFile(o.FilePayload.Files[0])
if err != nil {
return err
}
+ defer output.Close()
- p.mu.Lock()
- defer p.mu.Unlock()
+ p.traceMu.Lock()
+ defer p.traceMu.Unlock()
// Returns an error if profiling is already started.
if err := trace.Start(output); err != nil {
output.Close()
return err
}
+ defer trace.Stop()
// Ensure all trace contexts are registered.
- p.Kernel.RebuildTraceContexts()
-
- p.traceFile = output
- return nil
-}
-
-// StopTrace is an RPC stub which stops collection of an ongoing execution
-// trace and flushes the trace data. It takes no argument.
-func (p *Profile) StopTrace(_, _ *struct{}) error {
- p.mu.Lock()
- defer p.mu.Unlock()
+ p.kernel.RebuildTraceContexts()
- if p.traceFile == nil {
- return errors.New("Execution tracing not started")
+ // Wait for the trace.
+ select {
+ case <-time.After(o.Duration):
+ case <-p.done:
}
// Similarly to the case above, if tasks have not ended traces, we will
// lose information. Thus we need to rebuild the tasks in order to have
// complete information. This will not lose information if multiple
// traces are overlapping.
- p.Kernel.RebuildTraceContexts()
+ p.kernel.RebuildTraceContexts()
- trace.Stop()
- p.traceFile.Close()
- p.traceFile = nil
return nil
}
diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go
index d800f2c85..62eaca965 100644
--- a/pkg/sentry/control/state.go
+++ b/pkg/sentry/control/state.go
@@ -62,6 +62,7 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error {
Callback: func(err error) {
if err == nil {
log.Infof("Save succeeded: exiting...")
+ s.Kernel.SetSaveSuccess(false /* autosave */)
} else {
log.Warningf("Save failed: exiting...")
s.Kernel.SetSaveError(err)
diff --git a/pkg/sentry/fdimport/fdimport.go b/pkg/sentry/fdimport/fdimport.go
index 314661475..badd5b073 100644
--- a/pkg/sentry/fdimport/fdimport.go
+++ b/pkg/sentry/fdimport/fdimport.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package fdimport provides the Import function.
package fdimport
import (
diff --git a/pkg/sentry/fs/copy_up.go b/pkg/sentry/fs/copy_up.go
index ff2fe6712..8e0aa9019 100644
--- a/pkg/sentry/fs/copy_up.go
+++ b/pkg/sentry/fs/copy_up.go
@@ -336,7 +336,12 @@ func cleanupUpper(ctx context.Context, parent *Inode, name string, copyUpErr err
// copyUpBuffers is a buffer pool for copying file content. The buffer
// size is the same used by io.Copy.
-var copyUpBuffers = sync.Pool{New: func() interface{} { return make([]byte, 8*usermem.PageSize) }}
+var copyUpBuffers = sync.Pool{
+ New: func() interface{} {
+ b := make([]byte, 8*usermem.PageSize)
+ return &b
+ },
+}
// copyContentsLocked copies the contents of lower to upper. It panics if
// less than size bytes can be copied.
@@ -361,7 +366,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
defer lowerFile.DecRef(ctx)
// Use a buffer pool to minimize allocations.
- buf := copyUpBuffers.Get().([]byte)
+ buf := copyUpBuffers.Get().(*[]byte)
defer copyUpBuffers.Put(buf)
// Transfer the contents.
@@ -371,7 +376,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
// optimizations could be self-defeating. So we leave this as simple as possible.
var offset int64
for {
- nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(buf), offset)
+ nr, err := lowerFile.FileOperations.Read(ctx, lowerFile, usermem.BytesIOSequence(*buf), offset)
if err != nil && err != io.EOF {
return err
}
@@ -383,7 +388,7 @@ func copyContentsLocked(ctx context.Context, upper *Inode, lower *Inode, size in
}
return nil
}
- nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence(buf[:nr]), offset)
+ nw, err := upperFile.FileOperations.Write(ctx, upperFile, usermem.BytesIOSequence((*buf)[:nr]), offset)
if err != nil {
return err
}
diff --git a/pkg/sentry/fs/copy_up_test.go b/pkg/sentry/fs/copy_up_test.go
index c7a11eec1..e04784db2 100644
--- a/pkg/sentry/fs/copy_up_test.go
+++ b/pkg/sentry/fs/copy_up_test.go
@@ -64,7 +64,7 @@ func TestConcurrentCopyUp(t *testing.T) {
wg.Add(1)
go func(o *overlayTestFile) {
if err := o.File.Dirent.Inode.Truncate(ctx, o.File.Dirent, truncateFileSize); err != nil {
- t.Fatalf("failed to copy up: %v", err)
+ t.Errorf("failed to copy up: %v", err)
}
wg.Done()
}(file)
diff --git a/pkg/sentry/fs/filetest/filetest.go b/pkg/sentry/fs/filetest/filetest.go
index 8049538f2..ec3d3f96c 100644
--- a/pkg/sentry/fs/filetest/filetest.go
+++ b/pkg/sentry/fs/filetest/filetest.go
@@ -52,10 +52,10 @@ func NewTestFile(tb testing.TB) *fs.File {
// Read just fails the request.
func (*TestFileOperations) Read(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, fmt.Errorf("Readv not implemented")
+ return 0, fmt.Errorf("TestFileOperations.Read not implemented")
}
// Write just fails the request.
func (*TestFileOperations) Write(context.Context, *fs.File, usermem.IOSequence, int64) (int64, error) {
- return 0, fmt.Errorf("Writev not implemented")
+ return 0, fmt.Errorf("TestFileOperations.Write not implemented")
}
diff --git a/pkg/sentry/fs/fs.go b/pkg/sentry/fs/fs.go
index d2dbff268..a020da53b 100644
--- a/pkg/sentry/fs/fs.go
+++ b/pkg/sentry/fs/fs.go
@@ -65,7 +65,7 @@ var (
// runs with the lock held for reading. AsyncBarrier will take the lock
// for writing, thus ensuring that all Async work completes before
// AsyncBarrier returns.
- workMu sync.RWMutex
+ workMu sync.CrossGoroutineRWMutex
// asyncError is used to store up to one asynchronous execution error.
asyncError = make(chan error, 1)
diff --git a/pkg/sentry/fs/gofer/attr.go b/pkg/sentry/fs/gofer/attr.go
index d481baf77..e5579095b 100644
--- a/pkg/sentry/fs/gofer/attr.go
+++ b/pkg/sentry/fs/gofer/attr.go
@@ -117,8 +117,6 @@ func ntype(pattr p9.Attr) fs.InodeType {
return fs.BlockDevice
case pattr.Mode.IsSocket():
return fs.Socket
- case pattr.Mode.IsRegular():
- fallthrough
default:
return fs.RegularFile
}
diff --git a/pkg/sentry/fs/gofer/inode.go b/pkg/sentry/fs/gofer/inode.go
index 9d6fdd08f..e840b6f5e 100644
--- a/pkg/sentry/fs/gofer/inode.go
+++ b/pkg/sentry/fs/gofer/inode.go
@@ -475,6 +475,9 @@ func (i *inodeOperations) Check(ctx context.Context, inode *fs.Inode, p fs.PermM
func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
switch d.Inode.StableAttr.Type {
case fs.Socket:
+ if i.session().overrides != nil {
+ return nil, syserror.ENXIO
+ }
return i.getFileSocket(ctx, d, flags)
case fs.Pipe:
return i.getFilePipe(ctx, d, flags)
diff --git a/pkg/sentry/fs/host/inode.go b/pkg/sentry/fs/host/inode.go
index fbfba1b58..2c14aa6d9 100644
--- a/pkg/sentry/fs/host/inode.go
+++ b/pkg/sentry/fs/host/inode.go
@@ -276,6 +276,10 @@ func (i *inodeOperations) BoundEndpoint(inode *fs.Inode, path string) transport.
// GetFile implements fs.InodeOperations.GetFile.
func (i *inodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ if fs.IsSocket(d.Inode.StableAttr) {
+ return nil, syserror.ENXIO
+ }
+
return newFile(ctx, d, flags, i), nil
}
diff --git a/pkg/sentry/fs/proc/sys.go b/pkg/sentry/fs/proc/sys.go
index f8aad2dbd..b998fb75d 100644
--- a/pkg/sentry/fs/proc/sys.go
+++ b/pkg/sentry/fs/proc/sys.go
@@ -84,6 +84,7 @@ func (p *proc) newKernelDir(ctx context.Context, msrc *fs.MountSource) *fs.Inode
children := map[string]*fs.Inode{
"hostname": newProcInode(ctx, &h, msrc, fs.SpecialFile, nil),
+ "sem": newStaticProcInode(ctx, msrc, []byte(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
"shmall": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMALL, 10))),
"shmmax": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMAX, 10))),
"shmmni": newStaticProcInode(ctx, msrc, []byte(strconv.FormatUint(linux.SHMMNI, 10))),
diff --git a/pkg/sentry/fs/ramfs/socket.go b/pkg/sentry/fs/ramfs/socket.go
index 29ff004f2..d0c565879 100644
--- a/pkg/sentry/fs/ramfs/socket.go
+++ b/pkg/sentry/fs/ramfs/socket.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -63,7 +64,7 @@ func (s *Socket) BoundEndpoint(*fs.Inode, string) transport.BoundEndpoint {
// GetFile implements fs.InodeOperations.GetFile.
func (s *Socket) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
- return fs.NewFile(ctx, dirent, flags, &socketFileOperations{}), nil
+ return nil, syserror.ENXIO
}
// +stateify savable
diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go
index e04cd608d..ad4aea282 100644
--- a/pkg/sentry/fs/tmpfs/inode_file.go
+++ b/pkg/sentry/fs/tmpfs/inode_file.go
@@ -148,6 +148,10 @@ func (*fileInodeOperations) Rename(ctx context.Context, inode *fs.Inode, oldPare
// GetFile implements fs.InodeOperations.GetFile.
func (f *fileInodeOperations) GetFile(ctx context.Context, d *fs.Dirent, flags fs.FileFlags) (*fs.File, error) {
+ if fs.IsSocket(d.Inode.StableAttr) {
+ return nil, syserror.ENXIO
+ }
+
if flags.Write {
fsmetric.TmpfsOpensW.Increment()
} else if flags.Read {
diff --git a/pkg/sentry/fsimpl/ext/inode.go b/pkg/sentry/fsimpl/ext/inode.go
index 9009ba3c7..4a555bf72 100644
--- a/pkg/sentry/fsimpl/ext/inode.go
+++ b/pkg/sentry/fsimpl/ext/inode.go
@@ -200,7 +200,9 @@ func (in *inode) open(rp *vfs.ResolvingPath, vfsd *vfs.Dentry, opts *vfs.OpenOpt
}
var fd symlinkFD
fd.LockFD.Init(&in.locks)
- fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{})
+ if err := fd.vfsfd.Init(&fd, opts.Flags, mnt, vfsd, &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
return &fd.vfsfd, nil
default:
panic(fmt.Sprintf("unknown inode type: %T", in.impl))
diff --git a/pkg/sentry/fsimpl/fuse/connection_control.go b/pkg/sentry/fsimpl/fuse/connection_control.go
index 1b3459c1d..4ab894965 100644
--- a/pkg/sentry/fsimpl/fuse/connection_control.go
+++ b/pkg/sentry/fsimpl/fuse/connection_control.go
@@ -84,11 +84,7 @@ func (conn *connection) InitSend(creds *auth.Credentials, pid uint32) error {
Flags: fuseDefaultInitFlags,
}
- req, err := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in)
- if err != nil {
- return err
- }
-
+ req := conn.NewRequest(creds, pid, 0, linux.FUSE_INIT, &in)
// Since there is no task to block on and FUSE_INIT is the request
// to unblock other requests, use nil.
return conn.CallAsync(nil, req)
diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go
index 91d16c1cf..d8b0d7657 100644
--- a/pkg/sentry/fsimpl/fuse/connection_test.go
+++ b/pkg/sentry/fsimpl/fuse/connection_test.go
@@ -76,10 +76,7 @@ func TestConnectionAbort(t *testing.T) {
var futNormal []*futureResponse
for i := 0; i < int(numRequests); i++ {
- req, err := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj)
fut, err := conn.callFutureLocked(task, req)
if err != nil {
t.Fatalf("callFutureLocked failed: %v", err)
@@ -105,10 +102,7 @@ func TestConnectionAbort(t *testing.T) {
}
// After abort, Call() should return directly with ENOTCONN.
- req, err := conn.NewRequest(creds, 0, 0, 0, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, 0, 0, 0, testObj)
_, err = conn.Call(task, req)
if err != syserror.ENOTCONN {
t.Fatalf("Incorrect error code received for Call() after connection aborted")
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 95c475a65..bb2d0d31a 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -219,10 +219,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con
data: rand.Uint32(),
}
- req, err := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
- if err != nil {
- t.Fatalf("NewRequest creation failed: %v", err)
- }
+ req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
// Queue up a request.
// Analogous to Call except it doesn't block on the task.
diff --git a/pkg/sentry/fsimpl/fuse/directory.go b/pkg/sentry/fsimpl/fuse/directory.go
index 8f220a04b..fcc5d9a2a 100644
--- a/pkg/sentry/fsimpl/fuse/directory.go
+++ b/pkg/sentry/fsimpl/fuse/directory.go
@@ -68,11 +68,7 @@ func (dir *directoryFD) IterDirents(ctx context.Context, callback vfs.IterDirent
}
// TODO(gVisor.dev/issue/3404): Support FUSE_READDIRPLUS.
- req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in)
- if err != nil {
- return err
- }
-
+ req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), dir.inode().nodeID, linux.FUSE_READDIR, &in)
res, err := fusefs.conn.Call(task, req)
if err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/file.go b/pkg/sentry/fsimpl/fuse/file.go
index 83f2816b7..e138b11f8 100644
--- a/pkg/sentry/fsimpl/fuse/file.go
+++ b/pkg/sentry/fsimpl/fuse/file.go
@@ -83,12 +83,8 @@ func (fd *fileDescription) Release(ctx context.Context) {
opcode = linux.FUSE_RELEASE
}
kernelTask := kernel.TaskFromContext(ctx)
- // ignoring errors and FUSE server reply is analogous to Linux's behavior.
- req, err := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in)
- if err != nil {
- // No way to invoke Call() with an errored request.
- return
- }
+ // Ignoring errors and FUSE server reply is analogous to Linux's behavior.
+ req := conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), fd.inode().nodeID, opcode, &in)
// The reply will be ignored since no callback is defined in asyncCallBack().
conn.CallAsync(kernelTask, req)
}
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index 23e827f90..204d8d143 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -119,7 +119,8 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
deviceDescriptor, err := strconv.ParseInt(deviceDescriptorStr, 10 /* base */, 32 /* bitSize */)
if err != nil {
- return nil, nil, err
+ log.Debugf("%s.GetFilesystem: device FD '%v' not parsable: %v", fsType.Name(), deviceDescriptorStr, err)
+ return nil, nil, syserror.EINVAL
}
kernelTask := kernel.TaskFromContext(ctx)
@@ -128,6 +129,9 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
return nil, nil, syserror.EINVAL
}
fuseFDGeneric := kernelTask.GetFileVFS2(int32(deviceDescriptor))
+ if fuseFDGeneric == nil {
+ return nil, nil, syserror.EINVAL
+ }
defer fuseFDGeneric.DecRef(ctx)
fuseFD, ok := fuseFDGeneric.Impl().(*DeviceFD)
if !ok {
@@ -360,12 +364,8 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr
in.Flags &= ^uint32(linux.O_TRUNC)
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in)
- if err != nil {
- return nil, err
- }
-
// Send the request and receive the reply.
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, &in)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return nil, err
@@ -485,10 +485,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err
return syserror.EINVAL
}
in := linux.FUSEUnlinkIn{Name: name}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
- if err != nil {
- return err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return err
@@ -515,11 +512,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro
task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
in := linux.FUSERmDirIn{Name: name}
- req, err := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in)
- if err != nil {
- return err
- }
-
+ req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in)
res, err := i.fs.conn.Call(task, req)
if err != nil {
return err
@@ -535,10 +528,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
return nil, syserror.EINVAL
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload)
- if err != nil {
- return nil, err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, opcode, payload)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return nil, err
@@ -574,10 +564,7 @@ func (i *inode) Readlink(ctx context.Context, mnt *vfs.Mount) (string, error) {
log.Warningf("fusefs.Inode.Readlink: couldn't get kernel task from context")
return "", syserror.EINVAL
}
- req, err := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{})
- if err != nil {
- return "", err
- }
+ req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_READLINK, &linux.FUSEEmptyIn{})
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
return "", err
@@ -680,11 +667,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp
GetAttrFlags: flags,
Fh: fh,
}
- req, err := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in)
- if err != nil {
- return linux.FUSEAttr{}, err
- }
-
+ req := i.fs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_GETATTR, &in)
res, err := i.fs.conn.Call(task, req)
if err != nil {
return linux.FUSEAttr{}, err
@@ -803,11 +786,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre
UID: opts.Stat.UID,
GID: opts.Stat.GID,
}
- req, err := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in)
- if err != nil {
- return err
- }
-
+ req := conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_SETATTR, &in)
res, err := conn.Call(task, req)
if err != nil {
return err
diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go
index 2d396e84c..23ce91849 100644
--- a/pkg/sentry/fsimpl/fuse/read_write.go
+++ b/pkg/sentry/fsimpl/fuse/read_write.go
@@ -79,13 +79,9 @@ func (fs *filesystem) ReadInPages(ctx context.Context, fd *regularFileFD, off ui
in.Offset = off + (uint64(pagesRead) << usermem.PageShift)
in.Size = pagesCanRead << usermem.PageShift
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in)
- if err != nil {
- return nil, 0, err
- }
-
// TODO(gvisor.dev/issue/3247): support async read.
+ req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_READ, &in)
res, err := fs.conn.Call(t, req)
if err != nil {
return nil, 0, err
@@ -204,11 +200,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64,
in.Offset = off + uint64(written)
in.Size = toWrite
- req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
- if err != nil {
- return 0, err
- }
-
+ req := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in)
req.payload = data[written : written+toWrite]
// TODO(gvisor.dev/issue/3247): support async write.
diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go
index 7fa00569b..41d679358 100644
--- a/pkg/sentry/fsimpl/fuse/request_response.go
+++ b/pkg/sentry/fsimpl/fuse/request_response.go
@@ -70,6 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) {
out.MaxPages = uint16(usermem.ByteOrder.Uint16(src[:2]))
src = src[2:]
}
+ _ = src // Remove unused warning.
}
// SizeBytes is the size of the payload of the FUSE_INIT response.
@@ -104,7 +105,7 @@ type Request struct {
}
// NewRequest creates a new request that can be sent to the FUSE server.
-func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) (*Request, error) {
+func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint64, opcode linux.FUSEOpcode, payload marshal.Marshallable) *Request {
conn.fd.mu.Lock()
defer conn.fd.mu.Unlock()
conn.fd.nextOpID += linux.FUSEOpID(reqIDStep)
@@ -130,7 +131,7 @@ func (conn *connection) NewRequest(creds *auth.Credentials, pid uint32, ino uint
id: hdr.Unique,
hdr: &hdr,
data: buf,
- }, nil
+ }
}
// futureResponse represents an in-flight request, that may or may not have
diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go
index 435a21d77..36a3f6810 100644
--- a/pkg/sentry/fsimpl/host/host.go
+++ b/pkg/sentry/fsimpl/host/host.go
@@ -31,6 +31,7 @@ import (
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/kernfs"
"gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/memmap"
unixsocket "gvisor.dev/gvisor/pkg/sentry/socket/unix"
@@ -499,6 +500,10 @@ func (i *inode) open(ctx context.Context, d *kernfs.Dentry, mnt *vfs.Mount, flag
fileDescription: fileDescription{inode: i},
termios: linux.DefaultReplicaTermios,
}
+ if task := kernel.TaskFromContext(ctx); task != nil {
+ fd.fgProcessGroup = task.ThreadGroup().ProcessGroup()
+ fd.session = fd.fgProcessGroup.Session()
+ }
fd.LockFD.Init(&i.locks)
vfsfd := &fd.vfsfd
if err := vfsfd.Init(fd, flags, mnt, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go
index 469f3a33d..27b00cf6f 100644
--- a/pkg/sentry/fsimpl/overlay/copy_up.go
+++ b/pkg/sentry/fsimpl/overlay/copy_up.go
@@ -16,7 +16,6 @@ package overlay
import (
"fmt"
- "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -129,25 +128,9 @@ func (d *dentry) copyUpLocked(ctx context.Context) error {
return err
}
defer newFD.DecRef(ctx)
- bufIOSeq := usermem.BytesIOSequence(make([]byte, 32*1024)) // arbitrary buffer size
- for {
- readN, readErr := oldFD.Read(ctx, bufIOSeq, vfs.ReadOptions{})
- if readErr != nil && readErr != io.EOF {
- cleanupUndoCopyUp()
- return readErr
- }
- total := int64(0)
- for total < readN {
- writeN, writeErr := newFD.Write(ctx, bufIOSeq.DropFirst64(total), vfs.WriteOptions{})
- total += writeN
- if writeErr != nil {
- cleanupUndoCopyUp()
- return writeErr
- }
- }
- if readErr == io.EOF {
- break
- }
+ if _, err := vfs.CopyRegularFileData(ctx, newFD, oldFD); err != nil {
+ cleanupUndoCopyUp()
+ return err
}
d.mapsMu.Lock()
defer d.mapsMu.Unlock()
diff --git a/pkg/sentry/fsimpl/overlay/regular_file.go b/pkg/sentry/fsimpl/overlay/regular_file.go
index 2b89a7a6d..25c785fd4 100644
--- a/pkg/sentry/fsimpl/overlay/regular_file.go
+++ b/pkg/sentry/fsimpl/overlay/regular_file.go
@@ -103,8 +103,8 @@ func (fd *regularFileFD) currentFDLocked(ctx context.Context) (*vfs.FileDescript
for e, mask := range fd.lowerWaiters {
fd.cachedFD.EventUnregister(e)
upperFD.EventRegister(e, mask)
- if ready&mask != 0 {
- e.Callback.Callback(e)
+ if m := ready & mask; m != 0 {
+ e.Callback.Callback(e, m)
}
}
}
diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go
index 0ecb592cf..429733c10 100644
--- a/pkg/sentry/fsimpl/pipefs/pipefs.go
+++ b/pkg/sentry/fsimpl/pipefs/pipefs.go
@@ -164,11 +164,11 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e
// and write ends of a newly-created pipe, as for pipe(2) and pipe2(2).
//
// Preconditions: mnt.Filesystem() must have been returned by NewFilesystem().
-func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+func NewConnectedPipeFDs(ctx context.Context, mnt *vfs.Mount, flags uint32) (*vfs.FileDescription, *vfs.FileDescription, error) {
fs := mnt.Filesystem().Impl().(*filesystem)
inode := newInode(ctx, fs)
var d kernfs.Dentry
d.Init(&fs.Filesystem, inode)
defer d.DecRef(ctx)
- return inode.pipe.ReaderWriterPair(mnt, d.VFSDentry(), flags)
+ return inode.pipe.ReaderWriterPair(ctx, mnt, d.VFSDentry(), flags)
}
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index a3780b222..75be6129f 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -57,9 +57,6 @@ func getMM(task *kernel.Task) *mm.MemoryManager {
// MemoryManager's users count is incremented, and must be decremented by the
// caller when it is no longer in use.
func getMMIncRef(task *kernel.Task) (*mm.MemoryManager, error) {
- if task.ExitState() == kernel.TaskExitDead {
- return nil, syserror.ESRCH
- }
var m *mm.MemoryManager
task.WithMuLocked(func(t *kernel.Task) {
m = t.MemoryManager()
@@ -111,9 +108,13 @@ var _ dynamicInode = (*auxvData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (d *auxvData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if d.task.ExitState() == kernel.TaskExitDead {
+ return syserror.ESRCH
+ }
m, err := getMMIncRef(d.task)
if err != nil {
- return err
+ // Return empty file.
+ return nil
}
defer m.DecUsers(ctx)
@@ -157,9 +158,13 @@ var _ dynamicInode = (*cmdlineData)(nil)
// Generate implements vfs.DynamicBytesSource.Generate.
func (d *cmdlineData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+ if d.task.ExitState() == kernel.TaskExitDead {
+ return syserror.ESRCH
+ }
m, err := getMMIncRef(d.task)
if err != nil {
- return err
+ // Return empty file.
+ return nil
}
defer m.DecUsers(ctx)
@@ -472,7 +477,7 @@ func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64
}
m, err := getMMIncRef(fd.inode.task)
if err != nil {
- return 0, nil
+ return 0, err
}
defer m.DecUsers(ctx)
// Buffer the read data because of MM locks
diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go
index 7c7afdcfa..25c407d98 100644
--- a/pkg/sentry/fsimpl/proc/tasks_sys.go
+++ b/pkg/sentry/fsimpl/proc/tasks_sys.go
@@ -44,6 +44,7 @@ func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *
return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{
"hostname": fs.newInode(ctx, root, 0444, &hostnameData{}),
+ "sem": fs.newInode(ctx, root, 0444, newStaticFile(fmt.Sprintf("%d\t%d\t%d\t%d\n", linux.SEMMSL, linux.SEMMNS, linux.SEMOPM, linux.SEMMNI))),
"shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)),
"shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)),
"shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)),
diff --git a/pkg/sentry/fsimpl/signalfd/signalfd.go b/pkg/sentry/fsimpl/signalfd/signalfd.go
index 10f1452ef..246bd87bc 100644
--- a/pkg/sentry/fsimpl/signalfd/signalfd.go
+++ b/pkg/sentry/fsimpl/signalfd/signalfd.go
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// Package signalfd provides basic signalfd file implementations.
package signalfd
import (
@@ -98,8 +99,8 @@ func (sfd *SignalFileDescription) Read(ctx context.Context, dst usermem.IOSequen
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
- PID: uint32(info.Pid()),
- UID: uint32(info.Uid()),
+ PID: uint32(info.PID()),
+ UID: uint32(info.UID()),
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go
index 59fcff498..a4ad625bb 100644
--- a/pkg/sentry/fsimpl/verity/filesystem.go
+++ b/pkg/sentry/fsimpl/verity/filesystem.go
@@ -163,7 +163,7 @@ afterSymlink:
// verifyChildLocked verifies the hash of child against the already verified
// hash of the parent to ensure the child is expected. verifyChild triggers a
// sentry panic if unexpected modifications to the file system are detected. In
-// noCrashOnVerificationFailure mode it returns a syserror instead.
+// ErrorOnViolation mode it returns a syserror instead.
//
// Preconditions:
// * fs.renameMu must be locked.
@@ -254,7 +254,7 @@ func (fs *filesystem) verifyChildLocked(ctx context.Context, parent *dentry, chi
return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err))
}
- fdReader := vfs.FileReadWriteSeeker{
+ fdReader := FileReadWriteSeeker{
FD: parentMerkleFD,
Ctx: ctx,
}
@@ -397,7 +397,7 @@ func (fs *filesystem) verifyStatAndChildrenLocked(ctx context.Context, d *dentry
}
}
- fdReader := vfs.FileReadWriteSeeker{
+ fdReader := FileReadWriteSeeker{
FD: fd,
Ctx: ctx,
}
diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go
index add65bee6..a5171b5ad 100644
--- a/pkg/sentry/fsimpl/verity/verity.go
+++ b/pkg/sentry/fsimpl/verity/verity.go
@@ -64,6 +64,10 @@ const (
// tree file for "/foo" is "/.merkle.verity.foo".
merklePrefix = ".merkle.verity."
+ // merkleRootPrefix is the prefix of the Merkle tree root file. This
+ // needs to be different from merklePrefix to avoid name collision.
+ merkleRootPrefix = ".merkleroot.verity."
+
// merkleOffsetInParentXattr is the extended attribute name specifying the
// offset of the child hash in its parent's Merkle tree.
merkleOffsetInParentXattr = "user.merkle.offset"
@@ -88,10 +92,8 @@ const (
)
var (
- // noCrashOnVerificationFailure indicates whether the sandbox should panic
- // whenever verification fails. If true, an error is returned instead of
- // panicking. This should only be set for tests.
- noCrashOnVerificationFailure bool
+ // action specifies the action to take when an integrity violation is detected.
+ action ViolationAction
// verityMu synchronizes concurrent operations that enable verity and perform
// verification checks.
@@ -102,6 +104,18 @@ var (
// content.
type HashAlgorithm int
+// ViolationAction is a type specifying the action when an integrity violation
+// is detected.
+type ViolationAction int
+
+const (
+	// PanicOnViolation terminates the sentry on detected violation. It is the
+	// zero value of ViolationAction.
+	PanicOnViolation ViolationAction = 0
+	// ErrorOnViolation returns an error from the violating system call on
+	// detected violation. The type is spelled out so that the constant is a
+	// ViolationAction rather than an untyped integer constant.
+	ErrorOnViolation ViolationAction = 1
+)
+
// Currently supported hashing algorithms include SHA256 and SHA512.
const (
SHA256 HashAlgorithm = iota
@@ -166,7 +180,7 @@ type filesystem struct {
// its children. So they shouldn't be enabled the same time. This lock
// is for the whole file system to ensure that no more than one file is
// enabled the same time.
- verityMu sync.RWMutex
+ verityMu sync.RWMutex `state:"nosave"`
}
// InternalFilesystemOptions may be passed as
@@ -196,10 +210,8 @@ type InternalFilesystemOptions struct {
// system wrapped by verity file system.
LowerGetFSOptions vfs.GetFilesystemOptions
- // NoCrashOnVerificationFailure indicates whether the sandbox should
- // panic whenever verification fails. If true, an error is returned
- // instead of panicking. This should only be set for tests.
- NoCrashOnVerificationFailure bool
+ // Action specifies the action on an integrity violation.
+ Action ViolationAction
}
// Name implements vfs.FilesystemType.Name.
@@ -211,10 +223,10 @@ func (FilesystemType) Name() string {
func (FilesystemType) Release(ctx context.Context) {}
// alertIntegrityViolation alerts a violation of integrity, which usually means
-// unexpected modification to the file system is detected. In
-// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic.
+// unexpected modification to the file system is detected. In ErrorOnViolation
+// mode, it returns EIO; otherwise it panics.
func alertIntegrityViolation(msg string) error {
- if noCrashOnVerificationFailure {
+ if action == ErrorOnViolation {
return syserror.EIO
}
panic(msg)
@@ -227,7 +239,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
ctx.Warningf("verity.FilesystemType.GetFilesystem: missing verity configs")
return nil, nil, syserror.EINVAL
}
- noCrashOnVerificationFailure = iopts.NoCrashOnVerificationFailure
+ action = iopts.Action
// Mount the lower file system. The lower file system is wrapped inside
// verity, and should not be exposed or connected.
@@ -255,7 +267,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt
lowerVD.IncRef()
d.lowerVD = lowerVD
- rootMerkleName := merklePrefix + iopts.RootMerkleFileName
+ rootMerkleName := merkleRootPrefix + iopts.RootMerkleFileName
lowerMerkleVD, err := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{
Root: lowerVD,
@@ -744,20 +756,20 @@ func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32)
// file /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The
// hash of the generated Merkle tree and the data size is returned. If fd
// points to a regular file, the data is the content of the file. If fd points
-// to a directory, the data is all hahes of its children, written to the Merkle
+// to a directory, the data is all hashes of its children, written to the Merkle
// tree file.
//
// Preconditions: fd.d.fs.verityMu must be locked.
func (fd *fileDescription) generateMerkleLocked(ctx context.Context) ([]byte, uint64, error) {
- fdReader := vfs.FileReadWriteSeeker{
+ fdReader := FileReadWriteSeeker{
FD: fd.lowerFD,
Ctx: ctx,
}
- merkleReader := vfs.FileReadWriteSeeker{
+ merkleReader := FileReadWriteSeeker{
FD: fd.merkleReader,
Ctx: ctx,
}
- merkleWriter := vfs.FileReadWriteSeeker{
+ merkleWriter := FileReadWriteSeeker{
FD: fd.merkleWriter,
Ctx: ctx,
}
@@ -1047,12 +1059,12 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of
return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err))
}
- dataReader := vfs.FileReadWriteSeeker{
+ dataReader := FileReadWriteSeeker{
FD: fd.lowerFD,
Ctx: ctx,
}
- merkleReader := vfs.FileReadWriteSeeker{
+ merkleReader := FileReadWriteSeeker{
FD: fd.merkleReader,
Ctx: ctx,
}
@@ -1101,3 +1113,45 @@ func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t
func (fd *fileDescription) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
return fd.lowerFD.UnlockPOSIX(ctx, uid, start, length, whence)
}
+
+// FileReadWriteSeeker wraps a vfs.FileDescription, plus the context and the
+// read/write options used for every call, so that it can be handed to code
+// expecting io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
+type FileReadWriteSeeker struct {
+	FD    *vfs.FileDescription
+	Ctx   context.Context
+	ROpts vfs.ReadOptions
+	WOpts vfs.WriteOptions
+}
+
+// ReadAt implements io.ReaderAt.ReadAt by forwarding to FD.PRead.
+func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) {
+	count, err := f.FD.PRead(f.Ctx, usermem.BytesIOSequence(p), off, f.ROpts)
+	return int(count), err
+}
+
+// Read implements io.ReadWriteSeeker.Read by forwarding to FD.Read.
+func (f *FileReadWriteSeeker) Read(p []byte) (int, error) {
+	count, err := f.FD.Read(f.Ctx, usermem.BytesIOSequence(p), f.ROpts)
+	return int(count), err
+}
+
+// Seek implements io.ReadWriteSeeker.Seek by forwarding to FD.Seek.
+func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
+	pos, err := f.FD.Seek(f.Ctx, offset, int32(whence))
+	return pos, err
+}
+
+// WriteAt implements io.WriterAt.WriteAt by forwarding to FD.PWrite.
+func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) {
+	src := usermem.BytesIOSequence(p)
+	count, err := f.FD.PWrite(f.Ctx, src, off, f.WOpts)
+	return int(count), err
+}
+
+// Write implements io.ReadWriteSeeker.Write by forwarding to FD.Write.
+func (f *FileReadWriteSeeker) Write(p []byte) (int, error) {
+	src := usermem.BytesIOSequence(p)
+	count, err := f.FD.Write(f.Ctx, src, f.WOpts)
+	return int(count), err
+}
diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go
index 5d1f5de08..30d8b4355 100644
--- a/pkg/sentry/fsimpl/verity/verity_test.go
+++ b/pkg/sentry/fsimpl/verity/verity_test.go
@@ -35,14 +35,16 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// rootMerkleFilename is the name of the root Merkle tree file.
-const rootMerkleFilename = "root.verity"
+const (
+ // rootMerkleFilename is the name of the root Merkle tree file.
+ rootMerkleFilename = "root.verity"
+ // maxDataSize is the maximum data size of a test file.
+ maxDataSize = 100000
+)
-// maxDataSize is the maximum data size written to the file for test.
-const maxDataSize = 100000
+var hashAlgs = []HashAlgorithm{SHA256, SHA512}
-// getD returns a *dentry corresponding to VD.
-func getD(t *testing.T, vd vfs.VirtualDentry) *dentry {
+func dentryFromVD(t *testing.T, vd vfs.VirtualDentry) *dentry {
t.Helper()
d, ok := vd.Dentry().Impl().(*dentry)
if !ok {
@@ -51,10 +53,21 @@ func getD(t *testing.T, vd vfs.VirtualDentry) *dentry {
return d
}
+// dentryFromFD returns the dentry corresponding to fd. It fails the test
+// immediately if fd's implementation is not a *fileDescription.
+func dentryFromFD(t *testing.T, fd *vfs.FileDescription) *dentry {
+	t.Helper()
+	impl := fd.Impl()
+	f, ok := impl.(*fileDescription)
+	if !ok {
+		// Report the dynamic type of the implementation that failed the
+		// assertion; fd itself is always a *vfs.FileDescription.
+		t.Fatalf("can't assert %T as a *fileDescription", impl)
+	}
+	return f.d
+}
+
// newVerityRoot creates a new verity mount, and returns the root. The
// underlying file system is tmpfs. If the error is not nil, then cleanup
// should be called when the root is no longer needed.
func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) {
+ t.Helper()
k, err := testutil.Boot()
if err != nil {
t.Fatalf("testutil.Boot: %v", err)
@@ -79,11 +92,11 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
mntns, err := vfsObj.NewMountNamespace(ctx, auth.CredentialsFromContext(ctx), "", "verity", &vfs.MountOptions{
GetFilesystemOptions: vfs.GetFilesystemOptions{
InternalData: InternalFilesystemOptions{
- RootMerkleFileName: rootMerkleFilename,
- LowerName: "tmpfs",
- Alg: hashAlg,
- AllowRuntimeEnable: true,
- NoCrashOnVerificationFailure: true,
+ RootMerkleFileName: rootMerkleFilename,
+ LowerName: "tmpfs",
+ Alg: hashAlg,
+ AllowRuntimeEnable: true,
+ Action: ErrorOnViolation,
},
},
})
@@ -102,7 +115,6 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
t.Fatalf("testutil.CreateTask: %v", err)
}
- t.Helper()
t.Cleanup(func() {
root.DecRef(ctx)
mntns.DecRef(ctx)
@@ -111,6 +123,8 @@ func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem,
}
// openVerityAt opens a verity file.
+//
+// TODO(chongc): release reference from opening the file when done.
func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.VirtualDentry, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
Root: vd,
@@ -123,6 +137,8 @@ func openVerityAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, vd vfs.Vir
}
// openLowerAt opens the file in the underlying file system.
+//
+// TODO(chongc): release reference from opening the file when done.
func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, path string, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
Root: d.lowerVD,
@@ -135,6 +151,8 @@ func (d *dentry) openLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem,
}
// openLowerMerkleAt opens the Merkle file in the underlying file system.
+//
+// TODO(chongc): release reference from opening the file when done.
func (d *dentry) openLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, flags uint32, mode linux.FileMode) (*vfs.FileDescription, error) {
return vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{
Root: d.lowerMerkleVD,
@@ -190,21 +208,11 @@ func (d *dentry) renameLowerMerkleAt(ctx context.Context, vfsObj *vfs.VirtualFil
}, &vfs.RenameOptions{})
}
-// getDentry returns a *dentry corresponds to fd.
-func getDentry(t *testing.T, fd *vfs.FileDescription) *dentry {
- t.Helper()
- f, ok := fd.Impl().(*fileDescription)
- if !ok {
- t.Fatalf("can't assert %T as a *fileDescription", fd)
- }
- return f.d
-}
-
// newFileFD creates a new file in the verity mount, and returns the FD. The FD
// points to a file that has random data generated.
func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, int, error) {
// Create the file in the underlying file system.
- lowerFD, err := getD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode)
+ lowerFD, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode)
if err != nil {
return nil, 0, err
}
@@ -231,9 +239,20 @@ func newFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem,
return fd, dataSize, err
}
-// corruptRandomBit randomly flips a bit in the file represented by fd.
-func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
- // Flip a random bit in the underlying file.
+// newEmptyFileFD creates a new empty file in the verity mount, and returns the
+// FD. On failure the returned FD is nil and the error comes from whichever
+// open (lower or verity) failed.
+func newEmptyFileFD(ctx context.Context, t *testing.T, vfsObj *vfs.VirtualFilesystem, root vfs.VirtualDentry, filePath string, mode linux.FileMode) (*vfs.FileDescription, error) {
+	// Create the file in the underlying file system. The lower FD is opened
+	// only for its O_CREAT|O_EXCL side effect, so release its reference right
+	// away instead of discarding the FD and leaking it.
+	lowerFD, err := dentryFromVD(t, root).openLowerAt(ctx, vfsObj, filePath, linux.O_RDWR|linux.O_CREAT|linux.O_EXCL, linux.ModeRegular|mode)
+	if err != nil {
+		return nil, err
+	}
+	lowerFD.DecRef(ctx)
+
+	// Now open the verity file descriptor.
+	return openVerityAt(ctx, vfsObj, root, filePath, linux.O_RDONLY, mode)
+}
+
+// flipRandomBit randomly flips a bit in the file represented by fd.
+func flipRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) error {
randomPos := int64(rand.Intn(size))
byteToModify := make([]byte, 1)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(byteToModify), randomPos, vfs.ReadOptions{}); err != nil {
@@ -246,7 +265,14 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er
return nil
}
-var hashAlgs = []HashAlgorithm{SHA256, SHA512}
+// enableVerity issues the FS_IOC_ENABLE_VERITY ioctl on fd and aborts the
+// calling test if the ioctl returns an error.
+func enableVerity(ctx context.Context, t *testing.T, fd *vfs.FileDescription) {
+	t.Helper()
+	// The ioctl request number travels in argument slot 1.
+	ioctlArgs := arch.SyscallArguments{
+		1: {Value: linux.FS_IOC_ENABLE_VERITY},
+	}
+	_, err := fd.Ioctl(ctx, nil /* uio */, ioctlArgs)
+	if err != nil {
+		t.Fatalf("enable verity: %v", err)
+	}
+}
// TestOpen ensures that when a file is created, the corresponding Merkle tree
// file and the root Merkle tree file exist.
@@ -264,12 +290,12 @@ func TestOpen(t *testing.T) {
}
// Ensure that the corresponding Merkle tree file is created.
- if _, err = getDentry(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
+ if _, err = dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err)
}
// Ensure the root merkle tree file is created.
- if _, err = getD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
+ if _, err = dentryFromVD(t, root).openLowerMerkleAt(ctx, vfsObj, linux.O_RDONLY, linux.ModeRegular); err != nil {
t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err)
}
}
@@ -291,11 +317,7 @@ func TestPReadUnmodifiedFileSucceeds(t *testing.T) {
}
// Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
buf := make([]byte, size)
n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{})
@@ -325,11 +347,7 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) {
}
// Enable verity on the file and confirm a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
buf := make([]byte, size)
n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
@@ -343,6 +361,36 @@ func TestReadUnmodifiedFileSucceeds(t *testing.T) {
}
}
+// TestReadUnmodifiedEmptyFileSucceeds ensures that reading from an untouched,
+// empty verity file succeeds after enabling verity for it. The scenario is
+// repeated for every algorithm listed in hashAlgs.
+func TestReadUnmodifiedEmptyFileSucceeds(t *testing.T) {
+ for _, alg := range hashAlgs {
+ // newVerityRoot returns a *kernel.Task as its third result; it is used
+ // as the context for all operations below.
+ vfsObj, root, ctx, err := newVerityRoot(t, alg)
+ if err != nil {
+ t.Fatalf("newVerityRoot: %v", err)
+ }
+
+ filename := "verity-test-empty-file"
+ fd, err := newEmptyFileFD(ctx, t, vfsObj, root, filename, 0644)
+ if err != nil {
+ t.Fatalf("newEmptyFileFD: %v", err)
+ }
+
+ // Enable verity on the file and confirm a normal read succeeds.
+ enableVerity(ctx, t, fd)
+
+ // NOTE(review): buf is a nil slice, so this Read presumably requests zero
+ // bytes and n would be 0 whatever the file contains; if the intent is to
+ // prove the file reads back as empty, a non-empty buffer may be needed.
+ // TODO confirm.
+ var buf []byte
+ n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{})
+ // io.EOF is tolerated; any other error fails the test.
+ if err != nil && err != io.EOF {
+ t.Fatalf("fd.Read: %v", err)
+ }
+
+ if n != 0 {
+ t.Errorf("fd.Read got read length %d, expected 0", n)
+ }
+ }
+}
+
// TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file
// succeeds after enabling verity for it.
func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
@@ -359,11 +407,7 @@ func TestReopenUnmodifiedFileSucceeds(t *testing.T) {
}
// Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Ensure reopening the verity enabled file succeeds.
if _, err = openVerityAt(ctx, vfsObj, root, filename, linux.O_RDONLY, linux.ModeRegular); err != nil {
@@ -387,21 +431,14 @@ func TestOpenNonexistentFile(t *testing.T) {
}
// Enable verity on the file and confirms a normal read succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Enable verity on the parent directory.
parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, parentFD)
// Ensure open an unexpected file in the parent directory fails with
// ENOENT rather than verification failure.
@@ -426,20 +463,16 @@ func TestPReadModifiedFileFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerFD that's read/writable.
- lowerFD, err := getDentry(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
+ lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from the modified file fails.
@@ -466,20 +499,16 @@ func TestReadModifiedFileFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerFD that's read/writable.
- lowerFD, err := getDentry(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
+ lowerFD, err := dentryFromFD(t, fd).openLowerAt(ctx, vfsObj, "", linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
- if err := corruptRandomBit(ctx, lowerFD, size); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerFD, size); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from the modified file fails.
@@ -506,14 +535,10 @@ func TestModifiedMerkleFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Open a new lowerMerkleFD that's read/writable.
- lowerMerkleFD, err := getDentry(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
+ lowerMerkleFD, err := dentryFromFD(t, fd).openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
@@ -524,14 +549,13 @@ func TestModifiedMerkleFails(t *testing.T) {
t.Errorf("lowerMerkleFD.Stat: %v", err)
}
- if err := corruptRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, lowerMerkleFD, int(stat.Size)); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
// Confirm that read from a file with modified Merkle tree fails.
buf := make([]byte, size)
if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil {
- fmt.Println(buf)
t.Fatalf("fd.PRead succeeded with modified Merkle file")
}
}
@@ -554,24 +578,17 @@ func TestModifiedParentMerkleFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
// Enable verity on the parent directory.
parentFD, err := openVerityAt(ctx, vfsObj, root, "", linux.O_RDONLY, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
-
- if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, parentFD)
// Open a new lowerMerkleFD that's read/writable.
- parentLowerMerkleFD, err := getDentry(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
+ parentLowerMerkleFD, err := dentryFromFD(t, fd).parent.openLowerMerkleAt(ctx, vfsObj, linux.O_RDWR, linux.ModeRegular)
if err != nil {
t.Fatalf("OpenAt: %v", err)
}
@@ -591,8 +608,8 @@ func TestModifiedParentMerkleFails(t *testing.T) {
if err != nil {
t.Fatalf("Failed convert size to int: %v", err)
}
- if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
- t.Fatalf("corruptRandomBit: %v", err)
+ if err := flipRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil {
+ t.Fatalf("flipRandomBit: %v", err)
}
parentLowerMerkleFD.DecRef(ctx)
@@ -619,13 +636,8 @@ func TestUnmodifiedStatSucceeds(t *testing.T) {
t.Fatalf("newFileFD: %v", err)
}
- // Enable verity on the file and confirms stat succeeds.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
-
+ // Enable verity on the file and confirm that stat succeeds.
+ enableVerity(ctx, t, fd)
if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil {
t.Errorf("fd.Stat: %v", err)
}
@@ -648,11 +660,7 @@ func TestModifiedStatFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("fd.Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
lowerFD := fd.Impl().(*fileDescription).lowerFD
// Change the stat of the underlying file, and check that stat fails.
@@ -711,19 +719,15 @@ func TestOpenDeletedFileFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
if tc.changeFile {
- if err := getD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil {
+ if err := dentryFromVD(t, root).unlinkLowerAt(ctx, vfsObj, filename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
if tc.changeMerkleFile {
- if err := getD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil {
+ if err := dentryFromVD(t, root).unlinkLowerMerkleAt(ctx, vfsObj, filename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
@@ -776,20 +780,16 @@ func TestOpenRenamedFileFails(t *testing.T) {
}
// Enable verity on the file.
- var args arch.SyscallArguments
- args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY}
- if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil {
- t.Fatalf("Ioctl: %v", err)
- }
+ enableVerity(ctx, t, fd)
newFilename := "renamed-test-file"
if tc.changeFile {
- if err := getD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil {
+ if err := dentryFromVD(t, root).renameLowerAt(ctx, vfsObj, filename, newFilename); err != nil {
t.Fatalf("RenameAt: %v", err)
}
}
if tc.changeMerkleFile {
- if err := getD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil {
+ if err := dentryFromVD(t, root).renameLowerMerkleAt(ctx, vfsObj, filename, newFilename); err != nil {
t.Fatalf("UnlinkAt: %v", err)
}
}
diff --git a/pkg/sentry/kernel/epoll/epoll.go b/pkg/sentry/kernel/epoll/epoll.go
index 15519f0df..61aeca044 100644
--- a/pkg/sentry/kernel/epoll/epoll.go
+++ b/pkg/sentry/kernel/epoll/epoll.go
@@ -273,7 +273,7 @@ func (e *EventPoll) ReadEvents(max int) []linux.EpollEvent {
//
// Callback is called when one of the files we're polling becomes ready. It
// moves said file to the readyList if it's currently in the waiting list.
-func (p *pollEntry) Callback(*waiter.Entry) {
+func (p *pollEntry) Callback(*waiter.Entry, waiter.EventMask) {
e := p.epoll
e.listsMu.Lock()
@@ -306,9 +306,8 @@ func (e *EventPoll) initEntryReadiness(entry *pollEntry) {
f.EventRegister(&entry.waiter, entry.mask)
// Check if the file happens to already be in a ready state.
- ready := f.Readiness(entry.mask) & entry.mask
- if ready != 0 {
- entry.Callback(&entry.waiter)
+ if ready := f.Readiness(entry.mask) & entry.mask; ready != 0 {
+ entry.Callback(&entry.waiter, ready)
}
}
diff --git a/pkg/sentry/kernel/fasync/BUILD b/pkg/sentry/kernel/fasync/BUILD
index 2b3955598..f855f038b 100644
--- a/pkg/sentry/kernel/fasync/BUILD
+++ b/pkg/sentry/kernel/fasync/BUILD
@@ -8,11 +8,13 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
+ "//pkg/sentry/arch",
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
"//pkg/sentry/vfs",
"//pkg/sync",
+ "//pkg/syserror",
"//pkg/waiter",
],
)
diff --git a/pkg/sentry/kernel/fasync/fasync.go b/pkg/sentry/kernel/fasync/fasync.go
index 153d2cd9b..b66d61c6f 100644
--- a/pkg/sentry/kernel/fasync/fasync.go
+++ b/pkg/sentry/kernel/fasync/fasync.go
@@ -17,22 +17,45 @@ package fasync
import (
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/waiter"
)
-// New creates a new fs.FileAsync.
-func New() fs.FileAsync {
- return &FileAsync{}
+// Table to convert waiter event masks into si_band siginfo codes.
+// Taken from fs/fcntl.c:band_table.
+var bandTable = map[waiter.EventMask]int64{
+ // POLL_IN
+ waiter.EventIn: linux.EPOLLIN | linux.EPOLLRDNORM,
+ // POLL_OUT
+ waiter.EventOut: linux.EPOLLOUT | linux.EPOLLWRNORM | linux.EPOLLWRBAND,
+ // POLL_ERR
+ waiter.EventErr: linux.EPOLLERR,
+ // POLL_PRI
+ waiter.EventPri: linux.EPOLLPRI | linux.EPOLLRDBAND,
+ // POLL_HUP
+ waiter.EventHUp: linux.EPOLLHUP | linux.EPOLLERR,
}
-// NewVFS2 creates a new vfs.FileAsync.
-func NewVFS2() vfs.FileAsync {
- return &FileAsync{}
+// New returns a function that creates a new fs.FileAsync with the given file
+// descriptor.
+func New(fd int) func() fs.FileAsync {
+ return func() fs.FileAsync {
+ return &FileAsync{fd: fd}
+ }
+}
+
+// NewVFS2 returns a function that creates a new vfs.FileAsync with the given
+// file descriptor.
+func NewVFS2(fd int) func() vfs.FileAsync {
+ return func() vfs.FileAsync {
+ return &FileAsync{fd: fd}
+ }
}
// FileAsync sends signals when the registered file is ready for IO.
@@ -42,6 +65,12 @@ type FileAsync struct {
// e is immutable after first use (which is protected by mu below).
e waiter.Entry
+ // fd is the file descriptor to notify about.
+ // It is immutable, set at allocation time. This matches Linux semantics in
+ // fs/fcntl.c:fasync_helper.
+ // The fd value is passed to the signal recipient in siginfo.si_fd.
+ fd int
+
// regMu protects registeration and unregistration actions on e.
//
// regMu must be held while registration decisions are being made
@@ -56,6 +85,10 @@ type FileAsync struct {
mu sync.Mutex `state:"nosave"`
requester *auth.Credentials
registered bool
+ // signal is the signal to deliver upon I/O being available.
+ // The default value ("zero signal") means the default SIGIO signal will be
+ // delivered.
+ signal linux.Signal
// Only one of the following is allowed to be non-nil.
recipientPG *kernel.ProcessGroup
@@ -64,10 +97,10 @@ type FileAsync struct {
}
// Callback sends a signal.
-func (a *FileAsync) Callback(e *waiter.Entry) {
+func (a *FileAsync) Callback(e *waiter.Entry, mask waiter.EventMask) {
a.mu.Lock()
+ defer a.mu.Unlock()
if !a.registered {
- a.mu.Unlock()
return
}
t := a.recipientT
@@ -80,19 +113,34 @@ func (a *FileAsync) Callback(e *waiter.Entry) {
}
if t == nil {
// No recipient has been registered.
- a.mu.Unlock()
return
}
c := t.Credentials()
// Logic from sigio_perm in fs/fcntl.c.
- if a.requester.EffectiveKUID == 0 ||
+ permCheck := (a.requester.EffectiveKUID == 0 ||
a.requester.EffectiveKUID == c.SavedKUID ||
a.requester.EffectiveKUID == c.RealKUID ||
a.requester.RealKUID == c.SavedKUID ||
- a.requester.RealKUID == c.RealKUID {
- t.SendSignal(kernel.SignalInfoPriv(linux.SIGIO))
+ a.requester.RealKUID == c.RealKUID)
+ if !permCheck {
+ return
}
- a.mu.Unlock()
+ signalInfo := &arch.SignalInfo{
+ Signo: int32(linux.SIGIO),
+ Code: arch.SignalInfoKernel,
+ }
+ if a.signal != 0 {
+ signalInfo.Signo = int32(a.signal)
+ signalInfo.SetFD(uint32(a.fd))
+ var band int64
+ for m, bandCode := range bandTable {
+ if m&mask != 0 {
+ band |= bandCode
+ }
+ }
+ signalInfo.SetBand(band)
+ }
+ t.SendSignal(signalInfo)
}
// Register sets the file which will be monitored for IO events.
@@ -186,3 +234,25 @@ func (a *FileAsync) ClearOwner() {
a.recipientTG = nil
a.recipientPG = nil
}
+
+// Signal returns which signal will be sent to the signal recipient.
+// A value of zero means the signal to deliver wasn't customized, which means
+// the default signal (SIGIO) will be delivered.
+func (a *FileAsync) Signal() linux.Signal {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ return a.signal
+}
+
+// SetSignal overrides which signal to send when I/O is available.
+// The default behavior can be reset by specifying signal zero, which means
+// to send SIGIO.
+func (a *FileAsync) SetSignal(signal linux.Signal) error {
+ if signal != 0 && !signal.IsValid() {
+ return syserror.EINVAL
+ }
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ a.signal = signal
+ return nil
+}
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index 470d8bf83..f17f9c59c 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -121,18 +121,21 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2
panic("VFS1 and VFS2 files set")
}
- slice := *(*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
+ slicePtr := (*[]unsafe.Pointer)(atomic.LoadPointer(&f.slice))
// Grow the table as required.
- if last := int32(len(slice)); fd >= last {
+ if last := int32(len(*slicePtr)); fd >= last {
end := fd + 1
if end < 2*last {
end = 2 * last
}
- slice = append(slice, make([]unsafe.Pointer, end-last)...)
- atomic.StorePointer(&f.slice, unsafe.Pointer(&slice))
+ newSlice := append(*slicePtr, make([]unsafe.Pointer, end-last)...)
+ slicePtr = &newSlice
+ atomic.StorePointer(&f.slice, unsafe.Pointer(slicePtr))
}
+ slice := *slicePtr
+
var desc *descriptor
if file != nil || fileVFS2 != nil {
desc = &descriptor{
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 2cdcdfc1f..b8627a54f 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -214,9 +214,11 @@ type Kernel struct {
// netlinkPorts manages allocation of netlink socket port IDs.
netlinkPorts *port.Manager
- // saveErr is the error causing the sandbox to exit during save, if
- // any. It is protected by extMu.
- saveErr error `state:"nosave"`
+ // saveStatus is nil if the sandbox has not been saved, errSaved or
+ // errAutoSaved if it has been saved successfully, or the error causing the
+ // sandbox to exit during save.
+ // It is protected by extMu.
+ saveStatus error `state:"nosave"`
// danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
@@ -1481,12 +1483,42 @@ func (k *Kernel) NetlinkPorts() *port.Manager {
return k.netlinkPorts
}
-// SaveError returns the sandbox error that caused the kernel to exit during
-// save.
-func (k *Kernel) SaveError() error {
+var (
+ errSaved = errors.New("sandbox has been successfully saved")
+ errAutoSaved = errors.New("sandbox has been successfully auto-saved")
+)
+
+// SaveStatus returns the sandbox save status. If it was saved successfully,
+// autosaved indicates whether save was triggered by autosave. If it was not
+// saved successfully, err indicates the sandbox error that caused the kernel to
+// exit during save.
+func (k *Kernel) SaveStatus() (saved, autosaved bool, err error) {
+ k.extMu.Lock()
+ defer k.extMu.Unlock()
+ switch k.saveStatus {
+ case nil:
+ return false, false, nil
+ case errSaved:
+ return true, false, nil
+ case errAutoSaved:
+ return true, true, nil
+ default:
+ return false, false, k.saveStatus
+ }
+}
+
+// SetSaveSuccess sets the flag indicating that save completed successfully, if
+// no status was already set.
+func (k *Kernel) SetSaveSuccess(autosave bool) {
k.extMu.Lock()
defer k.extMu.Unlock()
- return k.saveErr
+ if k.saveStatus == nil {
+ if autosave {
+ k.saveStatus = errAutoSaved
+ } else {
+ k.saveStatus = errSaved
+ }
+ }
}
// SetSaveError sets the sandbox error that caused the kernel to exit during
@@ -1494,8 +1526,8 @@ func (k *Kernel) SaveError() error {
func (k *Kernel) SetSaveError(err error) {
k.extMu.Lock()
defer k.extMu.Unlock()
- if k.saveErr == nil {
- k.saveErr = err
+ if k.saveStatus == nil {
+ k.saveStatus = err
}
}
diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go
index 7b23cbe86..2d47d2e82 100644
--- a/pkg/sentry/kernel/pipe/vfs.go
+++ b/pkg/sentry/kernel/pipe/vfs.go
@@ -63,10 +63,19 @@ func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe {
// ReaderWriterPair returns read-only and write-only FDs for vp.
//
// Preconditions: statusFlags should not contain an open access mode.
-func (vp *VFSPipe) ReaderWriterPair(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription) {
+func (vp *VFSPipe) ReaderWriterPair(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32) (*vfs.FileDescription, *vfs.FileDescription, error) {
// Connected pipes share the same locks.
locks := &vfs.FileLocks{}
- return vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks), vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks)
+ r, err := vp.newFD(mnt, vfsd, linux.O_RDONLY|statusFlags, locks)
+ if err != nil {
+ return nil, nil, err
+ }
+ w, err := vp.newFD(mnt, vfsd, linux.O_WRONLY|statusFlags, locks)
+ if err != nil {
+ r.DecRef(ctx)
+ return nil, nil, err
+ }
+ return r, w, nil
}
// Allocate implements vfs.FileDescriptionImpl.Allocate.
@@ -85,7 +94,10 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
return nil, syserror.EINVAL
}
- fd := vp.newFD(mnt, vfsd, statusFlags, locks)
+ fd, err := vp.newFD(mnt, vfsd, statusFlags, locks)
+ if err != nil {
+ return nil, err
+ }
// Named pipes have special blocking semantics during open:
//
@@ -137,16 +149,18 @@ func (vp *VFSPipe) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, s
}
// Preconditions: vp.mu must be held.
-func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) *vfs.FileDescription {
+func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
fd := &VFSPipeFD{
pipe: &vp.pipe,
}
fd.LockFD.Init(locks)
- fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{
+ if err := fd.vfsfd.Init(fd, statusFlags, mnt, vfsd, &vfs.FileDescriptionOptions{
DenyPRead: true,
DenyPWrite: true,
UseDentryMetadata: true,
- })
+ }); err != nil {
+ return nil, err
+ }
switch {
case fd.vfsfd.IsReadable() && fd.vfsfd.IsWritable():
@@ -160,7 +174,7 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l
panic("invalid pipe flags: must be readable, writable, or both")
}
- return &fd.vfsfd
+ return &fd.vfsfd, nil
}
// VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements
diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go
index 1abfe2201..cef58a590 100644
--- a/pkg/sentry/kernel/ptrace.go
+++ b/pkg/sentry/kernel/ptrace.go
@@ -259,8 +259,8 @@ func (t *Task) ptraceTrapLocked(code int32) {
Signo: int32(linux.SIGTRAP),
Code: code,
}
- t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
- t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
if t.beginPtraceStopLocked() {
tracer := t.Tracer()
tracer.signalStop(t, arch.CLD_TRAPPED, int32(linux.SIGTRAP))
diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go
index b99c0bffa..db01e4a97 100644
--- a/pkg/sentry/kernel/semaphore/semaphore.go
+++ b/pkg/sentry/kernel/semaphore/semaphore.go
@@ -29,17 +29,17 @@ import (
)
const (
- valueMax = 32767 // SEMVMX
+ // Maximum semaphore value.
+ valueMax = linux.SEMVMX
- // semaphoresMax is "maximum number of semaphores per semaphore ID" (SEMMSL).
- semaphoresMax = 32000
+ // Maximum number of semaphore sets.
+ setsMax = linux.SEMMNI
- // setMax is "system-wide limit on the number of semaphore sets" (SEMMNI).
- setsMax = 32000
+ // Maximum number of semaphores in a semaphore set.
+ semsMax = linux.SEMMSL
- // semaphoresTotalMax is "system-wide limit on the number of semaphores"
- // (SEMMNS = SEMMNI*SEMMSL).
- semaphoresTotalMax = 1024000000
+ // Maximum number of semaphores in all semaphore sets.
+ semsTotalMax = linux.SEMMNS
)
// Registry maintains a set of semaphores that can be found by key or ID.
@@ -52,6 +52,9 @@ type Registry struct {
mu sync.Mutex `state:"nosave"`
semaphores map[int32]*Set
lastIDUsed int32
+ // indexes maps a set's index in the virtual array to the set's
+ // identifier.
+ indexes map[int32]int32
}
// Set represents a set of semaphores that can be operated atomically.
@@ -113,6 +116,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
return &Registry{
userNS: userNS,
semaphores: make(map[int32]*Set),
+ indexes: make(map[int32]int32),
}
}
@@ -122,7 +126,7 @@ func NewRegistry(userNS *auth.UserNamespace) *Registry {
// be found. If exclusive is true, it fails if a set with the same key already
// exists.
func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linux.FileMode, private, create, exclusive bool) (*Set, error) {
- if nsems < 0 || nsems > semaphoresMax {
+ if nsems < 0 || nsems > semsMax {
return nil, syserror.EINVAL
}
@@ -163,10 +167,13 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
}
// Apply system limits.
+ //
+ // The semaphores and indexes maps in a registry are of the same size,
+ // so only the semaphores map is checked against the system limit here.
if len(r.semaphores) >= setsMax {
return nil, syserror.EINVAL
}
- if r.totalSems() > int(semaphoresTotalMax-nsems) {
+ if r.totalSems() > int(semsTotalMax-nsems) {
return nil, syserror.EINVAL
}
@@ -176,6 +183,53 @@ func (r *Registry) FindOrCreate(ctx context.Context, key, nsems int32, mode linu
return r.newSet(ctx, key, owner, owner, perms, nsems)
}
+// IPCInfo returns information about system-wide semaphore limits and parameters.
+func (r *Registry) IPCInfo() *linux.SemInfo {
+ return &linux.SemInfo{
+ SemMap: linux.SEMMAP,
+ SemMni: linux.SEMMNI,
+ SemMns: linux.SEMMNS,
+ SemMnu: linux.SEMMNU,
+ SemMsl: linux.SEMMSL,
+ SemOpm: linux.SEMOPM,
+ SemUme: linux.SEMUME,
+ SemUsz: linux.SEMUSZ,
+ SemVmx: linux.SEMVMX,
+ SemAem: linux.SEMAEM,
+ }
+}
+
+// SemInfo returns a seminfo structure containing the same information as
+// for IPC_INFO, except that SemUsz field returns the number of existing
+// semaphore sets, and SemAem field returns the number of existing semaphores.
+func (r *Registry) SemInfo() *linux.SemInfo {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ info := r.IPCInfo()
+ info.SemUsz = uint32(len(r.semaphores))
+ info.SemAem = uint32(r.totalSems())
+
+ return info
+}
+
+// HighestIndex returns the index of the highest used entry in
+// the kernel's array.
+func (r *Registry) HighestIndex() int32 {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ // By default, the highest used index is 0 even when
+ // there is no semaphore set.
+ var highestIndex int32
+ for index := range r.indexes {
+ if index > highestIndex {
+ highestIndex = index
+ }
+ }
+ return highestIndex
+}
+
// RemoveID removes set with give 'id' from the registry and marks the set as
// dead. All waiters will be awakened and fail.
func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
@@ -186,6 +240,11 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
if set == nil {
return syserror.EINVAL
}
+ index, found := r.findIndexByID(id)
+ if !found {
+ // Inconsistent state.
+ panic(fmt.Sprintf("unable to find an index for ID: %d", id))
+ }
set.mu.Lock()
defer set.mu.Unlock()
@@ -197,6 +256,7 @@ func (r *Registry) RemoveID(id int32, creds *auth.Credentials) error {
}
delete(r.semaphores, set.ID)
+ delete(r.indexes, index)
set.destroy()
return nil
}
@@ -220,6 +280,11 @@ func (r *Registry) newSet(ctx context.Context, key int32, owner, creator fs.File
continue
}
if r.semaphores[id] == nil {
+ index, found := r.findFirstAvailableIndex()
+ if !found {
+ panic("unable to find an available index")
+ }
+ r.indexes[index] = id
r.lastIDUsed = id
r.semaphores[id] = set
set.ID = id
@@ -238,6 +303,18 @@ func (r *Registry) FindByID(id int32) *Set {
return r.semaphores[id]
}
+// FindByIndex looks up a set given an index.
+func (r *Registry) FindByIndex(index int32) *Set {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ id, present := r.indexes[index]
+ if !present {
+ return nil
+ }
+ return r.semaphores[id]
+}
+
func (r *Registry) findByKey(key int32) *Set {
for _, v := range r.semaphores {
if v.key == key {
@@ -247,6 +324,24 @@ func (r *Registry) findByKey(key int32) *Set {
return nil
}
+func (r *Registry) findIndexByID(id int32) (int32, bool) {
+ for k, v := range r.indexes {
+ if v == id {
+ return k, true
+ }
+ }
+ return 0, false
+}
+
+func (r *Registry) findFirstAvailableIndex() (int32, bool) {
+ for index := int32(0); index < setsMax; index++ {
+ if _, present := r.indexes[index]; !present {
+ return index, true
+ }
+ }
+ return 0, false
+}
+
func (r *Registry) totalSems() int {
totalSems := 0
for _, v := range r.semaphores {
diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD
index 80a592c8f..073e14507 100644
--- a/pkg/sentry/kernel/shm/BUILD
+++ b/pkg/sentry/kernel/shm/BUILD
@@ -6,6 +6,9 @@ package(licenses = ["notice"])
go_template_instance(
name = "shm_refs",
out = "shm_refs.go",
+ consts = {
+ "enableLogging": "true",
+ },
package = "shm",
prefix = "Shm",
template = "//pkg/refsvfs2:refs_template",
diff --git a/pkg/sentry/kernel/signal.go b/pkg/sentry/kernel/signal.go
index e8cce37d0..2488ae7d5 100644
--- a/pkg/sentry/kernel/signal.go
+++ b/pkg/sentry/kernel/signal.go
@@ -73,7 +73,7 @@ func SignalInfoNoInfo(sig linux.Signal, sender, receiver *Task) *arch.SignalInfo
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
- info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.tg.pidns.IDOfThreadGroup(sender.tg)))
+ info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
return info
}
diff --git a/pkg/sentry/kernel/signalfd/signalfd.go b/pkg/sentry/kernel/signalfd/signalfd.go
index 78f718cfe..884966120 100644
--- a/pkg/sentry/kernel/signalfd/signalfd.go
+++ b/pkg/sentry/kernel/signalfd/signalfd.go
@@ -106,8 +106,8 @@ func (s *SignalOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
Signo: uint32(info.Signo),
Errno: info.Errno,
Code: info.Code,
- PID: uint32(info.Pid()),
- UID: uint32(info.Uid()),
+ PID: uint32(info.PID()),
+ UID: uint32(info.UID()),
Status: info.Status(),
Overrun: uint32(info.Overrun()),
Addr: info.Addr(),
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index c5137c282..16986244c 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -368,8 +368,8 @@ func (t *Task) exitChildren() {
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- siginfo.SetPid(int32(c.tg.pidns.tids[t]))
- siginfo.SetUid(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
+ siginfo.SetPID(int32(c.tg.pidns.tids[t]))
+ siginfo.SetUID(int32(t.Credentials().RealKUID.In(c.UserNamespace()).OrOverflow()))
c.tg.signalHandlers.mu.Lock()
c.sendSignalLocked(siginfo, true /* group */)
c.tg.signalHandlers.mu.Unlock()
@@ -698,8 +698,8 @@ func (t *Task) exitNotificationSignal(sig linux.Signal, receiver *Task) *arch.Si
info := &arch.SignalInfo{
Signo: int32(sig),
}
- info.SetPid(int32(receiver.tg.pidns.tids[t]))
- info.SetUid(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.tg.pidns.tids[t]))
+ info.SetUID(int32(t.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
if t.exitStatus.Signaled() {
info.Code = arch.CLD_KILLED
info.SetStatus(int32(t.exitStatus.Signo))
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 42dd3e278..75af3af79 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -914,8 +914,8 @@ func (t *Task) signalStop(target *Task, code int32, status int32) {
Signo: int32(linux.SIGCHLD),
Code: code,
}
- sigchld.SetPid(int32(t.tg.pidns.tids[target]))
- sigchld.SetUid(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ sigchld.SetPID(int32(t.tg.pidns.tids[target]))
+ sigchld.SetUID(int32(target.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
sigchld.SetStatus(status)
// TODO(b/72102453): Set utime, stime.
t.sendSignalLocked(sigchld, true /* group */)
@@ -1022,8 +1022,8 @@ func (*runInterrupt) execute(t *Task) taskRunState {
Signo: int32(sig),
Code: t.ptraceCode,
}
- t.ptraceSiginfo.SetPid(int32(t.tg.pidns.tids[t]))
- t.ptraceSiginfo.SetUid(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ t.ptraceSiginfo.SetPID(int32(t.tg.pidns.tids[t]))
+ t.ptraceSiginfo.SetUID(int32(t.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
} else {
t.ptraceCode = int32(sig)
t.ptraceSiginfo = nil
@@ -1114,11 +1114,11 @@ func (*runInterruptAfterSignalDeliveryStop) execute(t *Task) taskRunState {
if parent == nil {
// Tracer has detached and t was created by Kernel.CreateProcess().
// Pretend the parent is in an ancestor PID + user namespace.
- info.SetPid(0)
- info.SetUid(int32(auth.OverflowUID))
+ info.SetPID(0)
+ info.SetUID(int32(auth.OverflowUID))
} else {
- info.SetPid(int32(t.tg.pidns.tids[parent]))
- info.SetUid(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(t.tg.pidns.tids[parent]))
+ info.SetUID(int32(parent.Credentials().RealKUID.In(t.UserNamespace()).OrOverflow()))
}
}
t.tg.signalHandlers.mu.Lock()
diff --git a/pkg/sentry/memmap/memmap.go b/pkg/sentry/memmap/memmap.go
index 7fd77925f..49e21026e 100644
--- a/pkg/sentry/memmap/memmap.go
+++ b/pkg/sentry/memmap/memmap.go
@@ -160,7 +160,7 @@ func CheckTranslateResult(required, optional MappableRange, at usermem.AccessTyp
// Translations must be contiguous and in increasing order of
// Translation.Source.
if i > 0 && ts[i-1].Source.End != t.Source.Start {
- return fmt.Errorf("Translations %+v and %+v are not contiguous", ts[i-1], t)
+ return fmt.Errorf("Translation %+v and Translation %+v are not contiguous", ts[i-1], t)
}
// At least part of each Translation must be required.
if t.Source.Intersect(required).Length() == 0 {
diff --git a/pkg/sentry/mm/aio_context.go b/pkg/sentry/mm/aio_context.go
index 4c8cd38ed..5ab2ef79f 100644
--- a/pkg/sentry/mm/aio_context.go
+++ b/pkg/sentry/mm/aio_context.go
@@ -36,12 +36,12 @@ type aioManager struct {
contexts map[uint64]*AIOContext
}
-func (a *aioManager) destroy() {
- a.mu.Lock()
- defer a.mu.Unlock()
+func (mm *MemoryManager) destroyAIOManager(ctx context.Context) {
+ mm.aioManager.mu.Lock()
+ defer mm.aioManager.mu.Unlock()
- for _, ctx := range a.contexts {
- ctx.destroy()
+ for id := range mm.aioManager.contexts {
+ mm.destroyAIOContextLocked(ctx, id)
}
}
@@ -68,16 +68,26 @@ func (a *aioManager) newAIOContext(events uint32, id uint64) bool {
// be drained.
//
// Nil is returned if the context does not exist.
-func (a *aioManager) destroyAIOContext(id uint64) *AIOContext {
- a.mu.Lock()
- defer a.mu.Unlock()
- ctx, ok := a.contexts[id]
+//
+// Precondition: mm.aioManager.mu is locked.
+func (mm *MemoryManager) destroyAIOContextLocked(ctx context.Context, id uint64) *AIOContext {
+ aioCtx, ok := mm.aioManager.contexts[id]
if !ok {
return nil
}
- delete(a.contexts, id)
- ctx.destroy()
- return ctx
+
+ // Only unmap after it is assured that the address is a valid aio context, to
+ // prevent random memory from being unmapped.
+ //
+ // Note: It's possible to unmap this address and map something else into
+ // the same address. Then it would be unmapping memory that it doesn't own.
+ // This is, however, the way Linux implements AIO. Keeps the same [weird]
+ // semantics in case anyone relies on it.
+ mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
+
+ delete(mm.aioManager.contexts, id)
+ aioCtx.destroy()
+ return aioCtx
}
// lookupAIOContext looks up the given context.
@@ -140,16 +150,21 @@ func (ctx *AIOContext) checkForDone() {
}
}
-// Prepare reserves space for a new request, returning true if available.
-// Returns false if the context is busy.
-func (ctx *AIOContext) Prepare() bool {
+// Prepare reserves space for a new request, returning nil if available.
+// Returns EAGAIN if the context is busy and EINVAL if the context is dead.
+func (ctx *AIOContext) Prepare() error {
ctx.mu.Lock()
defer ctx.mu.Unlock()
+ if ctx.dead {
+ // Context died after the caller looked it up.
+ return syserror.EINVAL
+ }
if ctx.outstanding >= ctx.maxOutstanding {
- return false
+ // Context is busy.
+ return syserror.EAGAIN
}
ctx.outstanding++
- return true
+ return nil
}
// PopRequest pops a completed request if available, this function does not do
@@ -391,20 +406,13 @@ func (mm *MemoryManager) NewAIOContext(ctx context.Context, events uint32) (uint
// DestroyAIOContext destroys an asynchronous I/O context. It returns the
// destroyed context. nil if the context does not exist.
func (mm *MemoryManager) DestroyAIOContext(ctx context.Context, id uint64) *AIOContext {
- if _, ok := mm.LookupAIOContext(ctx, id); !ok {
+ if !mm.isValidAddr(ctx, id) {
return nil
}
- // Only unmaps after it assured that the address is a valid aio context to
- // prevent random memory from been unmapped.
- //
- // Note: It's possible to unmap this address and map something else into
- // the same address. Then it would be unmapping memory that it doesn't own.
- // This is, however, the way Linux implements AIO. Keeps the same [weird]
- // semantics in case anyone relies on it.
- mm.MUnmap(ctx, usermem.Addr(id), aioRingBufferSize)
-
- return mm.aioManager.destroyAIOContext(id)
+ mm.aioManager.mu.Lock()
+ defer mm.aioManager.mu.Unlock()
+ return mm.destroyAIOContextLocked(ctx, id)
}
// LookupAIOContext looks up the given context. It returns false if the context
@@ -415,13 +423,18 @@ func (mm *MemoryManager) LookupAIOContext(ctx context.Context, id uint64) (*AIOC
return nil, false
}
- // Protect against 'ids' that are inaccessible (Linux also reads 4 bytes
- // from id).
- var buf [4]byte
- _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
- if err != nil {
+ // Protect against 'id' that is inaccessible.
+ if !mm.isValidAddr(ctx, id) {
return nil, false
}
return aioCtx, true
}
+
+// isValidAddr reports whether the address `id` is accessible, by attempting to
+// copy in 4 bytes from it. (Linux also reads 4 bytes from id).
+func (mm *MemoryManager) isValidAddr(ctx context.Context, id uint64) bool {
+ var buf [4]byte
+ _, err := mm.CopyIn(ctx, usermem.Addr(id), buf[:], usermem.IOOpts{})
+ return err == nil
+}
diff --git a/pkg/sentry/mm/aio_context_state.go b/pkg/sentry/mm/aio_context_state.go
index 3dabac1af..e8931922f 100644
--- a/pkg/sentry/mm/aio_context_state.go
+++ b/pkg/sentry/mm/aio_context_state.go
@@ -15,6 +15,6 @@
package mm
// afterLoad is invoked by stateify.
-func (a *AIOContext) afterLoad() {
- a.requestReady = make(chan struct{}, 1)
+func (ctx *AIOContext) afterLoad() {
+ ctx.requestReady = make(chan struct{}, 1)
}
diff --git a/pkg/sentry/mm/lifecycle.go b/pkg/sentry/mm/lifecycle.go
index 09dbc06a4..120707429 100644
--- a/pkg/sentry/mm/lifecycle.go
+++ b/pkg/sentry/mm/lifecycle.go
@@ -253,7 +253,7 @@ func (mm *MemoryManager) DecUsers(ctx context.Context) {
panic(fmt.Sprintf("Invalid MemoryManager.users: %d", users))
}
- mm.aioManager.destroy()
+ mm.destroyAIOManager(ctx)
mm.metadataMu.Lock()
exe := mm.executable
diff --git a/pkg/sentry/mm/mm_test.go b/pkg/sentry/mm/mm_test.go
index acac3d357..bc53bd41e 100644
--- a/pkg/sentry/mm/mm_test.go
+++ b/pkg/sentry/mm/mm_test.go
@@ -229,3 +229,46 @@ func TestIOAfterMProtect(t *testing.T) {
t.Errorf("CopyOut got %d want 1", n)
}
}
+
+// TestAIOPrepareAfterDestroy tests that an AIOContext cannot be prepared after
+// it has been destroyed: Prepare must return EINVAL.
+func TestAIOPrepareAfterDestroy(t *testing.T) {
+	ctx := contexttest.Context(t)
+	mm := testMemoryManager(ctx)
+	defer mm.DecUsers(ctx)
+
+	id, err := mm.NewAIOContext(ctx, 1)
+	if err != nil {
+		t.Fatalf("mm.NewAIOContext got err %v want nil", err)
+	}
+	aioCtx, ok := mm.LookupAIOContext(ctx, id)
+	if !ok {
+		t.Fatalf("AIOContext not found")
+	}
+	mm.DestroyAIOContext(ctx, id)
+
+	if err := aioCtx.Prepare(); err != syserror.EINVAL {
+		t.Errorf("aioCtx.Prepare got err %v want %v", err, syserror.EINVAL)
+		if err == nil {
+			aioCtx.CancelPendingRequest() // Undo the unexpectedly successful Prepare.
+		}
+	}
+}
+
+// TestAIOLookupAfterDestroy tests that AIOContext should not be able to be
+// looked up after memory manager is destroyed.
+func TestAIOLookupAfterDestroy(t *testing.T) {
+ ctx := contexttest.Context(t)
+ mm := testMemoryManager(ctx)
+
+ id, err := mm.NewAIOContext(ctx, 1)
+ if err != nil {
+ mm.DecUsers(ctx)
+ t.Fatalf("mm.NewAIOContext got err %v want nil", err)
+ }
+ mm.DecUsers(ctx) // This destroys the AIOContext manager.
+
+ if _, ok := mm.LookupAIOContext(ctx, id); ok {
+ t.Errorf("AIOContext found even after AIOContext manager is destroyed")
+ }
+}
diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go
index 7c297fb9e..d99be7f46 100644
--- a/pkg/sentry/pgalloc/pgalloc.go
+++ b/pkg/sentry/pgalloc/pgalloc.go
@@ -423,11 +423,7 @@ func (f *MemoryFile) Allocate(length uint64, kind usage.MemoryKind) (memmap.File
}
if f.opts.ManualZeroing {
- if err := f.forEachMappingSlice(fr, func(bs []byte) {
- for i := range bs {
- bs[i] = 0
- }
- }); err != nil {
+ if err := f.manuallyZero(fr); err != nil {
return memmap.FileRange{}, err
}
}
@@ -560,19 +556,39 @@ func (f *MemoryFile) Decommit(fr memmap.FileRange) error {
panic(fmt.Sprintf("invalid range: %v", fr))
}
+ if f.opts.ManualZeroing {
+ // FALLOC_FL_PUNCH_HOLE may not zero pages if ManualZeroing is in
+ // effect.
+ if err := f.manuallyZero(fr); err != nil {
+ return err
+ }
+ } else {
+ if err := f.decommitFile(fr); err != nil {
+ return err
+ }
+ }
+
+ f.markDecommitted(fr)
+ return nil
+}
+
+func (f *MemoryFile) manuallyZero(fr memmap.FileRange) error {
+ return f.forEachMappingSlice(fr, func(bs []byte) {
+ for i := range bs {
+ bs[i] = 0
+ }
+ })
+}
+
+func (f *MemoryFile) decommitFile(fr memmap.FileRange) error {
// "After a successful call, subsequent reads from this range will
// return zeroes. The FALLOC_FL_PUNCH_HOLE flag must be ORed with
// FALLOC_FL_KEEP_SIZE in mode ..." - fallocate(2)
- err := syscall.Fallocate(
+ return syscall.Fallocate(
int(f.file.Fd()),
_FALLOC_FL_PUNCH_HOLE|_FALLOC_FL_KEEP_SIZE,
int64(fr.Start),
int64(fr.Length()))
- if err != nil {
- return err
- }
- f.markDecommitted(fr)
- return nil
}
func (f *MemoryFile) markDecommitted(fr memmap.FileRange) {
@@ -1044,20 +1060,20 @@ func (f *MemoryFile) runReclaim() {
break
}
- if err := f.Decommit(fr); err != nil {
- log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
- // Zero the pages manually. This won't reduce memory usage, but at
- // least ensures that the pages will be zero when reallocated.
- f.forEachMappingSlice(fr, func(bs []byte) {
- for i := range bs {
- bs[i] = 0
+ // If ManualZeroing is in effect, pages will be zeroed on allocation
+ // and may not be freed by decommitFile, so calling decommitFile is
+ // unnecessary.
+ if !f.opts.ManualZeroing {
+ if err := f.decommitFile(fr); err != nil {
+ log.Warningf("Reclaim failed to decommit %v: %v", fr, err)
+ // Zero the pages manually. This won't reduce memory usage, but at
+ // least ensures that the pages will be zero when reallocated.
+ if err := f.manuallyZero(fr); err != nil {
+ panic(fmt.Sprintf("Reclaim failed to decommit or zero %v: %v", fr, err))
}
- })
- // Pretend the pages were decommitted even though they weren't,
- // since the memory accounting implementation has no idea how to
- // deal with this.
- f.markDecommitted(fr)
+ }
}
+ f.markDecommitted(fr)
f.markReclaimed(fr)
}
diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
index acad4c793..f8ccb7430 100644
--- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go
@@ -91,6 +91,13 @@ func bluepillSigBus(c *vCPU) {
}
}
+// bluepillHandleEnosys is responsible for handling ENOSYS errors.
+//
+//go:nosplit
+func bluepillHandleEnosys(c *vCPU) {
+ throw("run failed: ENOSYS")
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection.
//
//go:nosplit
@@ -126,3 +133,10 @@ func bluepillReadyStopGuest(c *vCPU) bool {
}
return true
}
+
+// bluepillArchHandleExit checks architecture specific exitcode.
+//
+//go:nosplit
+func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) {
+ c.die(bluepillArchContext(context), "unknown")
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go
index 965ad66b5..1f09813ba 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64.go
@@ -42,6 +42,13 @@ var (
sErrEsr: _ESR_ELx_SERR_NMI,
},
}
+
+ // vcpuExtDabt is the vCPU event that marks an external data abort as pending.
+ vcpuExtDabt = kvmVcpuEvents{
+ exception: exception{
+ extDabtPending: 1,
+ },
+ }
)
// getTLS returns the value of TPIDR_EL0 register.
diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
index 9433d4da5..4d912769a 100644
--- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go
@@ -85,7 +85,7 @@ func bluepillStopGuest(c *vCPU) {
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 {
- throw("sErr injection failed")
+ throw("bounce sErr injection failed")
}
}
@@ -93,18 +93,54 @@ func bluepillStopGuest(c *vCPU) {
//
//go:nosplit
func bluepillSigBus(c *vCPU) {
+ // Host must support ARM64_HAS_RAS_EXTN.
if _, _, errno := syscall.RawSyscall( // escapes: no.
syscall.SYS_IOCTL,
uintptr(c.fd),
_KVM_SET_VCPU_EVENTS,
uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 {
- throw("sErr injection failed")
+ if errno == syscall.EINVAL {
+ throw("No ARM64_HAS_RAS_EXTN feature in host.")
+ }
+ throw("nmi sErr injection failed")
}
}
+// bluepillExtDabt is responsible for injecting an external data abort.
+//
+//go:nosplit
+func bluepillExtDabt(c *vCPU) {
+ if _, _, errno := syscall.RawSyscall( // escapes: no.
+ syscall.SYS_IOCTL,
+ uintptr(c.fd),
+ _KVM_SET_VCPU_EVENTS,
+ uintptr(unsafe.Pointer(&vcpuExtDabt))); errno != 0 {
+ throw("ext_dabt injection failed")
+ }
+}
+
+// bluepillHandleEnosys is responsible for handling ENOSYS errors.
+//
+//go:nosplit
+func bluepillHandleEnosys(c *vCPU) {
+ bluepillExtDabt(c)
+}
+
// bluepillReadyStopGuest checks whether the current vCPU is ready for sError injection.
//
//go:nosplit
func bluepillReadyStopGuest(c *vCPU) bool {
return true
}
+
+// bluepillArchHandleExit checks architecture specific exitcode.
+//
+//go:nosplit
+func bluepillArchHandleExit(c *vCPU, context unsafe.Pointer) {
+ switch c.runData.exitReason {
+ case _KVM_EXIT_ARM_NISV:
+ bluepillExtDabt(c)
+ default:
+ c.die(bluepillArchContext(context), "unknown")
+ }
+}
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index 75085ac6a..8c5369377 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -148,6 +148,9 @@ func bluepillHandler(context unsafe.Pointer) {
// mode and have interrupts disabled.
bluepillSigBus(c)
continue // Rerun vCPU.
+ case syscall.ENOSYS:
+ bluepillHandleEnosys(c)
+ continue
default:
throw("run failed")
}
@@ -220,7 +223,7 @@ func bluepillHandler(context unsafe.Pointer) {
c.die(bluepillArchContext(context), "entry failed")
return
default:
- c.die(bluepillArchContext(context), "unknown")
+ bluepillArchHandleExit(c, context)
return
}
}
diff --git a/pkg/sentry/platform/kvm/kvm_arm64.go b/pkg/sentry/platform/kvm/kvm_arm64.go
index 0b06a923a..9db1db4e9 100644
--- a/pkg/sentry/platform/kvm/kvm_arm64.go
+++ b/pkg/sentry/platform/kvm/kvm_arm64.go
@@ -47,10 +47,11 @@ type userRegs struct {
}
type exception struct {
- sErrPending uint8
- sErrHasEsr uint8
- pad [6]uint8
- sErrEsr uint64
+ sErrPending uint8
+ sErrHasEsr uint8
+ extDabtPending uint8
+ pad [5]uint8
+ sErrEsr uint64
}
type kvmVcpuEvents struct {
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index 6abaa21c4..2492d57be 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -56,6 +56,7 @@ const (
_KVM_EXIT_FAIL_ENTRY = 0x9
_KVM_EXIT_INTERNAL_ERROR = 0x11
_KVM_EXIT_SYSTEM_EVENT = 0x18
+ _KVM_EXIT_ARM_NISV = 0x1c
)
// KVM capability options.
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
index 54837f20c..aa2d21748 100644
--- a/pkg/sentry/platform/kvm/machine_arm64.go
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -54,7 +54,7 @@ func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) {
pageTable.Map(
usermem.Addr(ring0.KernelStartAddress|pr.virtual),
pr.length,
- pagetables.MapOpts{AccessType: usermem.AnyAccess},
+ pagetables.MapOpts{AccessType: usermem.AnyAccess, Global: true},
pr.physical)
return true // Keep iterating.
diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
index f2459755b..a466acf4d 100644
--- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go
@@ -79,7 +79,7 @@ func (c *vCPU) initArchState() error {
}
// tcr_el1
- data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS | _TCR_A1
+ data = _TCR_TXSZ_VA48 | _TCR_CACHE_FLAGS | _TCR_SHARED | _TCR_TG_FLAGS | _TCR_ASID16 | _TCR_IPS_40BITS
reg.id = _KVM_ARM64_REGS_TCR_EL1
if err := c.setOneRegister(&reg); err != nil {
return err
@@ -103,7 +103,7 @@ func (c *vCPU) initArchState() error {
c.SetTtbr0Kvm(uintptr(data))
// ttbr1_el1
- data = c.machine.kernel.PageTables.TTBR1_EL1(false, 1)
+ data = c.machine.kernel.PageTables.TTBR1_EL1(false, 0)
reg.id = _KVM_ARM64_REGS_TTBR1_EL1
if err := c.setOneRegister(&reg); err != nil {
diff --git a/pkg/sentry/platform/ptrace/ptrace.go b/pkg/sentry/platform/ptrace/ptrace.go
index f56aa3b79..571bfcc2e 100644
--- a/pkg/sentry/platform/ptrace/ptrace.go
+++ b/pkg/sentry/platform/ptrace/ptrace.go
@@ -18,8 +18,8 @@
//
// In a nutshell, it works as follows:
//
-// The creation of a new address space creates a new child processes with a
-// single thread which is traced by a single goroutine.
+// The creation of a new address space creates a new child process with a single
+// thread which is traced by a single goroutine.
//
// A context is just a collection of temporary variables. Calling Switch on a
// context does the following:
diff --git a/pkg/sentry/platform/ptrace/subprocess.go b/pkg/sentry/platform/ptrace/subprocess.go
index 812ab80ef..aacd7ce70 100644
--- a/pkg/sentry/platform/ptrace/subprocess.go
+++ b/pkg/sentry/platform/ptrace/subprocess.go
@@ -590,7 +590,7 @@ func (s *subprocess) switchToApp(c *context, ac arch.Context) bool {
// facilitate vsyscall emulation. See patchSignalInfo.
patchSignalInfo(regs, &c.signalInfo)
return false
- } else if c.signalInfo.Code <= 0 && c.signalInfo.Pid() == int32(os.Getpid()) {
+ } else if c.signalInfo.Code <= 0 && c.signalInfo.PID() == int32(os.Getpid()) {
// The signal was generated by this process. That means
// that it was an interrupt or something else that we
// should bail for. Note that we ignore signals
diff --git a/pkg/sentry/platform/ring0/BUILD b/pkg/sentry/platform/ring0/BUILD
index 679b287c3..2852b7387 100644
--- a/pkg/sentry/platform/ring0/BUILD
+++ b/pkg/sentry/platform/ring0/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "arch_genrule", "go_library")
load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance")
package(licenses = ["notice"])
@@ -39,19 +39,19 @@ go_template_instance(
template = ":defs_arm64",
)
-genrule(
+arch_genrule(
name = "entry_impl_amd64",
srcs = ["entry_amd64.s"],
outs = ["entry_impl_amd64.s"],
- cmd = "(echo -e '// build +amd64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ cmd = "(echo -e '// build +amd64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_amd64.s)) > $@",
tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
)
-genrule(
+arch_genrule(
name = "entry_impl_arm64",
srcs = ["entry_arm64.s"],
outs = ["entry_impl_arm64.s"],
- cmd = "(echo -e '// build +arm64\\n' && $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(SRCS)) > $@",
+ cmd = "(echo -e '// build +arm64\\n' && QEMU $(location //pkg/sentry/platform/ring0/gen_offsets) && cat $(location entry_arm64.s)) > $@",
tools = ["//pkg/sentry/platform/ring0/gen_offsets"],
)
@@ -72,7 +72,6 @@ go_library(
"lib_amd64.s",
"lib_arm64.go",
"lib_arm64.s",
- "lib_arm64_unsafe.go",
"ring0.go",
],
visibility = ["//pkg/sentry:internal"],
diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s
index 155f45ad8..b2bb18257 100644
--- a/pkg/sentry/platform/ring0/entry_arm64.s
+++ b/pkg/sentry/platform/ring0/entry_arm64.s
@@ -132,40 +132,6 @@
MOVD offset+PTRACE_R29(reg), R29; \
MOVD offset+PTRACE_R30(reg), R30;
-// NOP-s
-#define nop31Instructions() \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f; \
- WORD $0xd503201f;
-
#define ESR_ELx_EC_UNKNOWN (0x00)
#define ESR_ELx_EC_WFx (0x01)
/* Unallocated EC: 0x02 */
@@ -305,24 +271,20 @@
WORD $0xd538d092; //MRS TPIDR_EL1, R18
// SWITCH_TO_APP_PAGETABLE sets a new pagetable for a container application.
-#define SWITCH_TO_APP_PAGETABLE(from) \
- MRS TTBR1_EL1, R0; \
- MOVD CPU_APP_ASID(from), R1; \
- BFI $48, R1, $16, R0; \
- MSR R0, TTBR1_EL1; \ // set the ASID in TTBR1_EL1 (since TCR.A1 is set)
- ISB $15; \
- MOVD CPU_TTBR0_APP(from), RSV_REG; \
- MSR RSV_REG, TTBR0_EL1;
+#define SWITCH_TO_APP_PAGETABLE() \
+ MOVD CPU_APP_ASID(RSV_REG), RSV_REG_APP; \
+ MOVD CPU_TTBR0_APP(RSV_REG), RSV_REG; \
+ BFI $48, RSV_REG_APP, $16, RSV_REG; \
+ MSR RSV_REG, TTBR0_EL1; \
+ ISB $15;
// SWITCH_TO_KVM_PAGETABLE sets the kvm pagetable.
-#define SWITCH_TO_KVM_PAGETABLE(from) \
- MRS TTBR1_EL1, R0; \
- MOVD $1, R1; \
- BFI $48, R1, $16, R0; \
- MSR R0, TTBR1_EL1; \
- ISB $15; \
- MOVD CPU_TTBR0_KVM(from), RSV_REG; \
- MSR RSV_REG, TTBR0_EL1;
+#define SWITCH_TO_KVM_PAGETABLE() \
+ MOVD CPU_TTBR0_KVM(RSV_REG), RSV_REG; \
+ MOVD $1, RSV_REG_APP; \
+ BFI $48, RSV_REG_APP, $16, RSV_REG; \
+ MSR RSV_REG, TTBR0_EL1; \
+ ISB $15;
TEXT ·EnableVFP(SB),NOSPLIT,$0
MOVD $FPEN_ENABLE, R0
@@ -530,7 +492,7 @@ do_exit_to_el0:
WORD $0xd538d092 //MRS TPIDR_EL1, R18
- SWITCH_TO_APP_PAGETABLE(RSV_REG)
+ SWITCH_TO_APP_PAGETABLE()
LDP 16*1(RSP), (R0, R1)
LDP 16*0(RSP), (RSV_REG, RSV_REG_APP)
@@ -555,10 +517,10 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1
MOVD R1, RSP
- SWITCH_TO_KVM_PAGETABLE(RSV_REG)
+ REGISTERS_LOAD(RSV_REG, CPU_REGISTERS)
+ SWITCH_TO_KVM_PAGETABLE()
MRS TPIDR_EL1, RSV_REG
- REGISTERS_LOAD(RSV_REG, CPU_REGISTERS)
MOVD CPU_REGISTERS+PTRACE_R9(RSV_REG), RSV_REG_APP
ERET()
@@ -566,8 +528,16 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0
// Start is the CPU entrypoint.
TEXT ·Start(SB),NOSPLIT,$0
// Init.
- MOVD $SCTLR_EL1_DEFAULT, R1
- MSR R1, SCTLR_EL1
+ WORD $0xd508871f // __tlbi(vmalle1)
+ DSB $7 // dsb(nsh)
+
+ MOVD $1<<12, R1 // Reset mdscr_el1 and disable
+ MSR R1, MDSCR_EL1 // access to the DCC from EL0
+ ISB $15
+
+ MRS TTBR1_EL1, R1
+ MSR R1, TTBR0_EL1
+ ISB $15
MOVD $CNTKCTL_EL1_DEFAULT, R1
MSR R1, CNTKCTL_EL1
@@ -576,6 +546,15 @@ TEXT ·Start(SB),NOSPLIT,$0
ORR $0xffff000000000000, RSV_REG, RSV_REG
WORD $0xd518d092 //MSR R18, TPIDR_EL1
+ // Init.
+ MOVD $SCTLR_EL1_DEFAULT, R1 // re-enable the mmu.
+ MSR R1, SCTLR_EL1
+ ISB $15
+ WORD $0xd508751f // ic iallu
+
+ DSB $7 // dsb(nsh)
+ ISB $15
+
B ·kernelExitToEl1(SB)
// El1_sync_invalid is the handler for an invalid EL1_sync.
@@ -748,79 +727,43 @@ TEXT ·El0_error_invalid(SB),NOSPLIT,$0
B ·Shutdown(SB)
// Vectors implements exception vector table.
+// The start address of the exception vector table must be 2 KiB (11-bit) aligned.
+// For details, refer to the Arm developer documentation:
+// https://developer.arm.com/documentation/100933/0100/AArch64-exception-vector-table
+// See also the Linux kernel source: arch/arm64/kernel/entry.S
TEXT ·Vectors(SB),NOSPLIT,$0
+ PCALIGN $2048
B ·El1_sync_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_irq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_fiq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_error_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_sync(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_irq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_fiq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El1_error(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_sync(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_irq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_fiq(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_error(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_sync_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_irq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_fiq_invalid(SB)
- nop31Instructions()
+ PCALIGN $128
B ·El0_error_invalid(SB)
- nop31Instructions()
-
- // The exception-vector-table is required to be 11-bits aligned.
- // Please see Linux source code as reference: arch/arm64/kernel/entry.s.
- // For gvisor, I defined it as 4K in length, filled the 2nd 2K part with NOPs.
- // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
- WORD $0xd503201f //nop
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
-
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
- WORD $0xd503201f
- nop31Instructions()
diff --git a/pkg/sentry/platform/ring0/gen_offsets/BUILD b/pkg/sentry/platform/ring0/gen_offsets/BUILD
index 9742308d8..a9703baf6 100644
--- a/pkg/sentry/platform/ring0/gen_offsets/BUILD
+++ b/pkg/sentry/platform/ring0/gen_offsets/BUILD
@@ -24,6 +24,9 @@ go_binary(
"defs_impl_arm64.go",
"main.go",
],
+ # Use the libc malloc to avoid any extra dependencies. This is required to
+ # pass the sentry deps test.
+ system_malloc = True,
visibility = [
"//pkg/sentry/platform/kvm:__pkg__",
"//pkg/sentry/platform/ring0:__pkg__",
diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go
index 90a7b8392..c05284641 100644
--- a/pkg/sentry/platform/ring0/kernel_arm64.go
+++ b/pkg/sentry/platform/ring0/kernel_arm64.go
@@ -53,11 +53,17 @@ func IsCanonical(addr uint64) bool {
return addr <= 0x0000ffffffffffff || addr > 0xffff000000000000
}
+// SwitchToUser performs an eret.
+//
+// The return value is the exception vector.
+//
+// +checkescape:all
+//
//go:nosplit
func (c *CPU) SwitchToUser(switchOpts SwitchOpts) (vector Vector) {
storeAppASID(uintptr(switchOpts.UserASID))
if switchOpts.Flush {
- FlushTlbAll()
+ FlushTlbByASID(uintptr(switchOpts.UserASID))
}
regs := switchOpts.Registers
diff --git a/pkg/sentry/platform/ring0/lib_arm64.go b/pkg/sentry/platform/ring0/lib_arm64.go
index 0dffd33a3..a490bf3af 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.go
+++ b/pkg/sentry/platform/ring0/lib_arm64.go
@@ -22,19 +22,25 @@ func storeAppASID(asid uintptr)
// LocalFlushTlbAll same as FlushTlbAll, but only applies to the calling CPU.
func LocalFlushTlbAll()
-// FlushTlbAll flush all tlb.
+// FlushTlbByVA invalidates tlb by VA/Last-level/Inner-Shareable.
+func FlushTlbByVA(addr uintptr)
+
+// FlushTlbByASID invalidates tlb by ASID/Inner-Shareable.
+func FlushTlbByASID(asid uintptr)
+
+// FlushTlbAll invalidates all tlb.
func FlushTlbAll()
// CPACREL1 returns the value of the CPACR_EL1 register.
func CPACREL1() (value uintptr)
-// FPCR returns the value of FPCR register.
+// GetFPCR returns the value of FPCR register.
func GetFPCR() (value uintptr)
// SetFPCR writes the FPCR value.
func SetFPCR(value uintptr)
-// FPSR returns the value of FPSR register.
+// GetFPSR returns the value of FPSR register.
func GetFPSR() (value uintptr)
// SetFPSR writes the FPSR value.
@@ -62,6 +68,4 @@ func DisableVFP()
// Init sets function pointers based on architectural features.
//
// This must be called prior to using ring0.
-func Init() {
- rewriteVectors()
-}
+func Init() {}
diff --git a/pkg/sentry/platform/ring0/lib_arm64.s b/pkg/sentry/platform/ring0/lib_arm64.s
index 6f4923539..e39b32841 100644
--- a/pkg/sentry/platform/ring0/lib_arm64.s
+++ b/pkg/sentry/platform/ring0/lib_arm64.s
@@ -15,6 +15,23 @@
#include "funcdata.h"
#include "textflag.h"
+#define TLBI_ASID_SHIFT 48
+
+TEXT ·FlushTlbByVA(SB),NOSPLIT,$0-8
+ MOVD addr+0(FP), R1
+ DSB $10 // dsb(ishst)
+ WORD $0xd50883a1 // tlbi vale1is, x1
+ DSB $11 // dsb(ish)
+ RET
+
+TEXT ·FlushTlbByASID(SB),NOSPLIT,$0-8
+ MOVD asid+0(FP), R1
+ LSL $TLBI_ASID_SHIFT, R1, R1
+ DSB $10 // dsb(ishst)
+ WORD $0xd5088341 // tlbi aside1is, x1
+ DSB $11 // dsb(ish)
+ RET
+
TEXT ·LocalFlushTlbAll(SB),NOSPLIT,$0
DSB $6 // dsb(nshst)
WORD $0xd508871f // __tlbi(vmalle1)
diff --git a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go b/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
deleted file mode 100644
index c05166fea..000000000
--- a/pkg/sentry/platform/ring0/lib_arm64_unsafe.go
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2019 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// +build arm64
-
-package ring0
-
-import (
- "reflect"
- "syscall"
- "unsafe"
-
- "gvisor.dev/gvisor/pkg/safecopy"
- "gvisor.dev/gvisor/pkg/usermem"
-)
-
-const (
- nopInstruction = 0xd503201f
- instSize = unsafe.Sizeof(uint32(0))
- vectorsRawLen = 0x800
-)
-
-func unsafeSlice(addr uintptr, length int) (slice []uint32) {
- hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slice))
- hdr.Data = addr
- hdr.Len = length / int(instSize)
- hdr.Cap = length / int(instSize)
- return slice
-}
-
-// Work around: move ring0.Vectors() into a specific address with 11-bits alignment.
-//
-// According to the design documentation of Arm64,
-// the start address of exception vector table should be 11-bits aligned.
-// Please see the code in linux kernel as reference: arch/arm64/kernel/entry.S
-// But, we can't align a function's start address to a specific address by using golang.
-// We have raised this question in golang community:
-// https://groups.google.com/forum/m/#!topic/golang-dev/RPj90l5x86I
-// This function will be removed when golang supports this feature.
-//
-// There are 2 jobs were implemented in this function:
-// 1, move the start address of exception vector table into the specific address.
-// 2, modify the offset of each instruction.
-func rewriteVectors() {
- vectorsBegin := reflect.ValueOf(Vectors).Pointer()
-
- // The exception-vector-table is required to be 11-bits aligned.
- // And the size is 0x800.
- // Please see the documentation as reference:
- // https://developer.arm.com/docs/100933/0100/aarch64-exception-vector-table
- //
- // But, golang does not allow to set a function's address to a specific value.
- // So, for gvisor, I defined the size of exception-vector-table as 4K,
- // filled the 2nd 2K part with NOP-s.
- // So that, I can safely move the 1st 2K part into the address with 11-bits alignment.
- //
- // So, the prerequisite for this function to work correctly is:
- // vectorsSafeLen >= 0x1000
- // vectorsRawLen = 0x800
- vectorsSafeLen := int(safecopy.FindEndAddress(vectorsBegin) - vectorsBegin)
- if vectorsSafeLen < 2*vectorsRawLen {
- panic("Can't update vectors")
- }
-
- vectorsSafeTable := unsafeSlice(vectorsBegin, vectorsSafeLen) // Now a []uint32
- vectorsRawLen32 := vectorsRawLen / int(instSize)
-
- offset := vectorsBegin & (1<<11 - 1)
- if offset != 0 {
- offset = 1<<11 - offset
- }
-
- pageBegin := (vectorsBegin + offset) & ^uintptr(usermem.PageSize-1)
-
- _, _, errno := syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_WRITE|syscall.PROT_EXEC))
- if errno != 0 {
- panic(errno.Error())
- }
-
- offset = offset / instSize // By index, not bytes.
- // Move exception-vector-table into the specific address, should uses memmove here.
- for i := 1; i <= vectorsRawLen32; i++ {
- vectorsSafeTable[int(offset)+vectorsRawLen32-i] = vectorsSafeTable[vectorsRawLen32-i]
- }
-
- // Adjust branch since instruction was moved forward.
- for i := 0; i < vectorsRawLen32; i++ {
- if vectorsSafeTable[int(offset)+i] != nopInstruction {
- vectorsSafeTable[int(offset)+i] -= uint32(offset)
- }
- }
-
- _, _, errno = syscall.Syscall(syscall.SYS_MPROTECT, uintptr(pageBegin), uintptr(usermem.PageSize), uintptr(syscall.PROT_READ|syscall.PROT_EXEC))
- if errno != 0 {
- panic(errno.Error())
- }
-}
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index a3f775d15..cc1f6bfcc 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -20,6 +20,7 @@ go_library(
"//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/tcpip",
+ "//pkg/tcpip/header",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index ca16d0381..ebcc891b3 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -1,4 +1,4 @@
-load("//tools:defs.bzl", "go_library")
+load("//tools:defs.bzl", "go_library", "go_test")
package(licenses = ["notice"])
@@ -23,7 +23,20 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//pkg/tcpip",
"//pkg/usermem",
],
)
+
+go_test(
+ name = "control_test",
+ size = "small",
+ srcs = ["control_test.go"],
+ library = ":control",
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/socket",
+ "//pkg/usermem",
+ "@com_github_google_go_cmp//cmp:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 70ccf77a7..65b556489 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -344,18 +343,42 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
}
// PackIPPacketInfo packs an IP_PKTINFO socket control message.
-func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
- var p linux.ControlMessageIPPacketInfo
- p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
-
+func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IP,
linux.IP_PKTINFO,
t.Arch().Width(),
- p,
+ packetInfo,
+ )
+}
+
+// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
+func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
+ var level uint32
+ var optType uint32
+ switch originalDstAddress.(type) {
+ case *linux.SockAddrInet:
+ level = linux.SOL_IP
+ optType = linux.IP_RECVORIGDSTADDR
+ case *linux.SockAddrInet6:
+ level = linux.SOL_IPV6
+ optType = linux.IPV6_RECVORIGDSTADDR
+ default:
+ panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg")
+ }
+ return putCmsgStruct(
+ buf, level, optType, t.Arch().Width(), originalDstAddress)
+}
+
+// PackSockExtendedErr packs an IP*_RECVERR socket control message.
+func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ sockErr.CMsgLevel(),
+ sockErr.CMsgType(),
+ t.Arch().Width(),
+ sockErr,
)
}
@@ -384,7 +407,15 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
}
if cmsgs.IP.HasIPPacketInfo {
- buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
+ }
+
+ if cmsgs.IP.OriginalDstAddress != nil {
+ buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
+ }
+
+ if cmsgs.IP.SockErr != nil {
+ buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf)
}
return buf
@@ -416,21 +447,23 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
}
- return space
-}
+ if cmsgs.IP.HasIPPacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
+ }
-// NewIPPacketInfo returns the IPPacketInfo struct.
-func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
- var p tcpip.IPPacketInfo
- p.NIC = tcpip.NICID(packetInfo.NIC)
- copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
- copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+ if cmsgs.IP.OriginalDstAddress != nil {
+ space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
+ }
- return p
+ if cmsgs.IP.SockErr != nil {
+ space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes())
+ }
+
+ return space
}
// Parse parses a raw socket control message into portable objects.
-func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
+func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) (socket.ControlMessages, error) {
var (
cmsgs socket.ControlMessages
fds linux.ControlMessageRights
@@ -454,10 +487,6 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
i += linux.SizeOfControlMessageHeader
length := int(h.Length) - linux.SizeOfControlMessageHeader
- // The use of t.Arch().Width() is analogous to Linux's use of
- // sizeof(long) in CMSG_ALIGN.
- width := t.Arch().Width()
-
switch h.Level {
case linux.SOL_SOCKET:
switch h.Type {
@@ -489,6 +518,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
cmsgs.Unix.Credentials = scmCreds
i += binary.AlignUp(length, width)
+ case linux.SO_TIMESTAMP:
+ if length < linux.SizeOfTimeval {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ var ts linux.Timeval
+ binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &ts)
+ cmsgs.IP.Timestamp = ts.ToNsecCapped()
+ cmsgs.IP.HasTimestamp = true
+ i += binary.AlignUp(length, width)
+
default:
// Unknown message type.
return socket.ControlMessages{}, syserror.EINVAL
@@ -512,7 +551,26 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ cmsgs.IP.PacketInfo = packetInfo
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv4
+ if length < errCmsg.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ cmsgs.IP.SockErr = &errCmsg
i += binary.AlignUp(length, width)
default:
@@ -528,6 +586,25 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
i += binary.AlignUp(length, width)
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
+ i += binary.AlignUp(length, width)
+
+ case linux.IPV6_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv6
+ if length < errCmsg.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ cmsgs.IP.SockErr = &errCmsg
+ i += binary.AlignUp(length, width)
+
default:
return socket.ControlMessages{}, syserror.EINVAL
}
diff --git a/pkg/sentry/socket/control/control_test.go b/pkg/sentry/socket/control/control_test.go
new file mode 100644
index 000000000..d40a4cc85
--- /dev/null
+++ b/pkg/sentry/socket/control/control_test.go
@@ -0,0 +1,59 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package control provides internal representations of socket control
+// messages.
+package control
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func TestParse(t *testing.T) {
+ // Craft the control message to parse.
+ length := linux.SizeOfControlMessageHeader + linux.SizeOfTimeval
+ hdr := linux.ControlMessageHeader{
+ Length: uint64(length),
+ Level: linux.SOL_SOCKET,
+ Type: linux.SO_TIMESTAMP,
+ }
+ buf := make([]byte, 0, length)
+ buf = binary.Marshal(buf, usermem.ByteOrder, &hdr)
+ ts := linux.Timeval{
+ Sec: 2401,
+ Usec: 343,
+ }
+ buf = binary.Marshal(buf, usermem.ByteOrder, &ts)
+
+ cmsg, err := Parse(nil, nil, buf, 8 /* width */)
+ if err != nil {
+ t.Fatalf("Parse(_, _, %+v, _): %v", cmsg, err)
+ }
+
+ want := socket.ControlMessages{
+ IP: socket.IPControlMessages{
+ HasTimestamp: true,
+ Timestamp: ts.ToNsecCapped(),
+ },
+ }
+ if diff := cmp.Diff(want, cmsg); diff != "" {
+ t.Errorf("unexpected message parsed, (-want, +got):\n%s", diff)
+ }
+}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 7d3c4a01c..5b868216d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
case linux.SO_LINGER:
optlen = syscall.SizeofLinger
@@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
optlen = sizeofInt32
case linux.IP_PKTINFO:
optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
}
case linux.SOL_TCP:
switch name {
- case linux.TCP_NODELAY:
+ case linux.TCP_NODELAY, linux.TCP_INQ:
optlen = sizeofInt32
}
}
@@ -416,68 +416,76 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
return nil
}
-// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
- // Only allow known and safe flags.
- //
- // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
- // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the
- // Socket interface's dependence on netstack.
- if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
- }
+func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) {
+ // We always do a non-blocking recv*().
+ sysflags := flags | syscall.MSG_DONTWAIT
- var senderAddr linux.SockAddr
+ msg := syscall.Msghdr{}
+ if len(iovs) > 0 {
+ msg.Iov = &iovs[0]
+ msg.Iovlen = uint64(len(iovs))
+ }
var senderAddrBuf []byte
if senderRequested {
senderAddrBuf = make([]byte, sizeofSockaddr)
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(sizeofSockaddr)
}
-
var controlBuf []byte
- var msgFlags int
-
- recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
- // Refuse to do anything if any part of dst.Addrs was unusable.
- if uint64(dst.NumBytes()) != dsts.NumBytes() {
- return 0, nil
- }
- if dsts.IsEmpty() {
- return 0, nil
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
}
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
+ n, err := recvmsg(s.fd, &msg, sysflags)
+ if err != nil {
+ return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err
+ }
+ return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err
+}
- // We always do a non-blocking recv*().
- sysflags := flags | syscall.MSG_DONTWAIT
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ // Only allow known and safe flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC|syscall.MSG_ERRQUEUE) != 0 {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
+ }
- iovs := safemem.IovecsFromBlockSeq(dsts)
- msg := syscall.Msghdr{
- Iov: &iovs[0],
- Iovlen: uint64(len(iovs)),
- }
- if len(senderAddrBuf) != 0 {
- msg.Name = &senderAddrBuf[0]
- msg.Namelen = uint32(len(senderAddrBuf))
- }
- if controlLen > 0 {
- if controlLen > maxControlLen {
- controlLen = maxControlLen
+ var senderAddrBuf []byte
+ var controlBuf []byte
+ var msgFlags int
+ copyToDst := func() (int64, error) {
+ var n uint64
+ var err error
+ if dst.NumBytes() == 0 {
+ // We want to make the recvmsg(2) call to the host even if dst is empty
+ // to fetch control messages, sender address or errors if any occur.
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen)
+ return int64(n), err
+ }
+
+ recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
}
- controlBuf = make([]byte, controlLen)
- msg.Control = &controlBuf[0]
- msg.Controllen = controlLen
- }
- n, err := recvmsg(s.fd, &msg, sysflags)
- if err != nil {
- return 0, err
- }
- senderAddrBuf = senderAddrBuf[:msg.Namelen]
- msgFlags = int(msg.Flags)
- controlLen = uint64(msg.Controllen)
- return n, nil
- })
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen)
+ return n, err
+ })
+ return dst.CopyOutFrom(t, recvmsgToBlocks)
+ }
var ch chan struct{}
- n, err := dst.CopyOutFrom(t, recvmsgToBlocks)
- if flags&syscall.MSG_DONTWAIT == 0 {
+ n, err := copyToDst()
+ // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT.
+ if flags&(syscall.MSG_DONTWAIT|syscall.MSG_ERRQUEUE) == 0 {
for err == syserror.ErrWouldBlock {
// We only expect blocking to come from the actual syscall, in which
// case it can't have returned any data.
@@ -494,48 +502,85 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
s.EventRegister(&e, waiter.EventIn)
defer s.EventUnregister(&e)
}
- n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
+ n, err = copyToDst()
}
}
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ var senderAddr linux.SockAddr
if senderRequested {
senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
}
- unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil
+}
+func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages {
controlMessages := socket.ControlMessages{}
for _, unixCmsg := range unixControlMessages {
switch unixCmsg.Header.Level {
- case syscall.SOL_IP:
+ case linux.SOL_SOCKET:
switch unixCmsg.Header.Type {
- case syscall.IP_TOS:
+ case linux.SO_TIMESTAMP:
+ controlMessages.IP.HasTimestamp = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp)
+ }
+
+ case linux.SOL_IP:
+ switch unixCmsg.Header.Type {
+ case linux.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
- case syscall.IP_PKTINFO:
+ case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
+ controlMessages.IP.PacketInfo = packetInfo
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
+
+ case syscall.IP_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv4
+ errCmsg.UnmarshalBytes(unixCmsg.Data)
+ controlMessages.IP.SockErr = &errCmsg
}
- case syscall.SOL_IPV6:
+ case linux.SOL_IPV6:
switch unixCmsg.Header.Type {
- case syscall.IPV6_TCLASS:
+ case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
+
+ case syscall.IPV6_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv6
+ errCmsg.UnmarshalBytes(unixCmsg.Data)
+ controlMessages.IP.SockErr = &errCmsg
+ }
+
+ case linux.SOL_TCP:
+ switch unixCmsg.Header.Type {
+ case linux.TCP_INQ:
+ controlMessages.IP.HasInq = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq)
}
}
}
-
- return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
+ return controlMessages
}
// SendMsg implements socket.Socket.SendMsg.
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index fae3b6783..b2206900b 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -25,7 +25,6 @@ go_library(
"//pkg/marshal",
"//pkg/marshal/primitive",
"//pkg/metric",
- "//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index e8a0103bf..dcf898c0a 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -28,9 +28,9 @@ import (
"bytes"
"fmt"
"io"
+ "io/ioutil"
"math"
"reflect"
- "sync/atomic"
"syscall"
"time"
@@ -43,7 +43,6 @@ import (
"gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/metric"
- "gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -84,69 +83,95 @@ var Metrics = tcpip.Stats{
MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."),
DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."),
ICMP: tcpip.ICMPStats{
- V4PacketsSent: tcpip.ICMPv4SentPacketStats{
- ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ V4: tcpip.ICMPv4Stats{
+ PacketsSent: tcpip.ICMPv4SentPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
+ },
+ PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
},
- V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
- ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
- Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
- SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
- Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
- Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
- TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
- InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
- InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ V6: tcpip.ICMPv6Stats{
+ PacketsSent: tcpip.ICMPv6SentPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ },
+ PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
},
- V6PacketsSent: tcpip.ICMPv6SentPacketStats{
- ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ IGMP: tcpip.IGMPStats{
+ PacketsSent: tcpip.IGMPSentPacketStats{
+ IGMPPacketStats: tcpip.IGMPPacketStats{
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."),
},
- Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."),
},
- V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
- ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
- EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
- EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
- DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
- PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
- TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
- ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
- RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
- RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
- NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
- NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
- RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ PacketsReceived: tcpip.IGMPReceivedPacketStats{
+ IGMPPacketStats: tcpip.IGMPPacketStats{
+ MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."),
+ V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."),
+ V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."),
+ LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group messages received by netstack."),
},
- Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
+ Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."),
+ ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."),
+ Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."),
},
},
IP: tcpip.IPStats{
@@ -209,18 +234,6 @@ const sizeOfInt32 int = 4
var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL)
-// ntohs converts a 16-bit number from network byte order to host byte order. It
-// assumes that the host is little endian.
-func ntohs(v uint16) uint16 {
- return v<<8 | v>>8
-}
-
-// htons converts a 16-bit number from host byte order to network byte order. It
-// assumes that the host is little endian.
-func htons(v uint16) uint16 {
- return ntohs(v)
-}
-
// commonEndpoint represents the intersection of a tcpip.Endpoint and a
// transport.Endpoint.
type commonEndpoint interface {
@@ -240,10 +253,6 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
- // transport.Endpoint.SetSockOptBool.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
// transport.Endpoint.SetSockOptInt.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -252,14 +261,14 @@ type commonEndpoint interface {
// transport.Endpoint.GetSockOpt.
GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
- // transport.Endpoint.GetSockOpt.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOptInt.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+ // State returns a socket's lifecycle state. The returned value is
+ // protocol-specific and is primarily used for diagnostics.
+ State() uint32
+
// LastError implements tcpip.Endpoint.LastError and
// transport.Endpoint.LastError.
LastError() *tcpip.Error
@@ -298,19 +307,11 @@ type socketOpsCommon struct {
skType linux.SockType
protocol int
- // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
- // Must be accessed using atomic operations. It must only be written
- // with readMu held but can be read without holding readMu. The latter
- // is required to avoid deadlocks in epoll Readiness checks.
- readViewHasData uint32
-
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
- // readView contains the remaining payload from the last packet.
- readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
+ readCM socket.IPControlMessages
sender tcpip.FullAddress
linkPacketInfo tcpip.LinkPacketInfo
@@ -326,17 +327,15 @@ type socketOpsCommon struct {
// valid when timestampValid is true. It is protected by readMu.
timestampNS int64
- // sockOptInq corresponds to TCP_INQ. It is implemented at this level
- // because it takes into account data from readView.
+ // TODO(b/153685824): Move this to SocketOptions.
+ // sockOptInq corresponds to TCP_INQ.
sockOptInq bool
}
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ endpoint.SocketOptions().SetDelayOption(true)
}
dirent := socket.NewDirent(t, netstackDevice)
@@ -365,127 +364,27 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
return tcpip.Address(addr)
}
-// AddressAndFamily reads an sockaddr struct from the given address and
-// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
-// AF_INET6, and AF_PACKET addresses.
-//
-// AddressAndFamily returns an address and its family.
-func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
- // Make sure we have at least 2 bytes for the address family.
- if len(addr) < 2 {
- return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
- }
-
- // Get the rest of the fields based on the address family.
- switch family := usermem.ByteOrder.Uint16(addr); family {
- case linux.AF_UNIX:
- path := addr[2:]
- if len(path) > linux.UnixPathMax {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- // Drop the terminating NUL (if one exists) and everything after
- // it for filesystem (non-abstract) addresses.
- if len(path) > 0 && path[0] != 0 {
- if n := bytes.IndexByte(path[1:], 0); n >= 0 {
- path = path[:n+1]
- }
- }
- return tcpip.FullAddress{
- Addr: tcpip.Address(path),
- }, family, nil
-
- case linux.AF_INET:
- var a linux.SockAddrInet
- if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
-
- out := tcpip.FullAddress{
- Addr: bytesToIPAddress(a.Addr[:]),
- Port: ntohs(a.Port),
- }
- return out, family, nil
-
- case linux.AF_INET6:
- var a linux.SockAddrInet6
- if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
-
- out := tcpip.FullAddress{
- Addr: bytesToIPAddress(a.Addr[:]),
- Port: ntohs(a.Port),
- }
- if isLinkLocal(out.Addr) {
- out.NIC = tcpip.NICID(a.Scope_id)
- }
- return out, family, nil
-
- case linux.AF_PACKET:
- var a linux.SockAddrLink
- if len(addr) < sockAddrLinkSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
- binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
- if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
- return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
- }
-
- // TODO(gvisor.dev/issue/173): Return protocol too.
- return tcpip.FullAddress{
- NIC: tcpip.NICID(a.InterfaceIndex),
- Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
- }, family, nil
-
- case linux.AF_UNSPEC:
- return tcpip.FullAddress{}, family, nil
-
- default:
- return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
- }
-}
-
func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
-// fetchReadView updates the readView field of the socket if it's currently
-// empty. It assumes that the socket is locked.
-//
// Precondition: s.readMu must be held.
-func (s *socketOpsCommon) fetchReadView() *syserr.Error {
- if len(s.readView) > 0 {
- return nil
- }
- s.readView = nil
- s.sender = tcpip.FullAddress{}
- s.linkPacketInfo = tcpip.LinkPacketInfo{}
+func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) {
+ res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{
+ Peek: peek,
+ NeedRemoteAddr: true,
+ NeedLinkPacketInfo: true,
+ })
- var v buffer.View
- var cms tcpip.ControlMessages
- var err *tcpip.Error
+ // Assign these anyways.
+ s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages)
+ s.sender = res.RemoteAddr
+ s.linkPacketInfo = res.LinkPacketInfo
- switch e := s.Endpoint.(type) {
- // The ordering of these interfaces matters. The most specific
- // interfaces must be specified before the more generic Endpoint
- // interface.
- case tcpip.PacketEndpoint:
- v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo)
- case tcpip.Endpoint:
- v, cms, err = e.Read(&s.sender)
- }
if err != nil {
- atomic.StoreUint32(&s.readViewHasData, 0)
- return syserr.TranslateNetstackError(err)
+ return 0, 0, syserr.TranslateNetstackError(err)
}
-
- s.readView = v
- s.readCM = cms
- atomic.StoreUint32(&s.readViewHasData, 1)
-
- return nil
+ return res.Count, res.Total, nil
}
// Release implements fs.FileOperations.Release.
@@ -502,11 +401,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
return
}
- var v tcpip.LingerOption
- if err := s.Endpoint.GetSockOpt(&v); err != nil {
- return
- }
-
+ v := s.Endpoint.SocketOptions().GetLinger()
// The case for zero timeout is handled in tcp endpoint close function.
// Close is blocked until either:
// 1. The endpoint state is not in any of the states: FIN-WAIT1,
@@ -538,38 +433,14 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// WriteTo implements fs.FileOperations.WriteTo.
func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
s.readMu.Lock()
+ defer s.readMu.Unlock()
- // Copy as much data as possible.
- done := int64(0)
- for count > 0 {
- // This may return a blocking error.
- if err := s.fetchReadView(); err != nil {
- s.readMu.Unlock()
- return done, err.ToError()
- }
-
- // Write to the underlying file.
- n, err := dst.Write(s.readView)
- done += int64(n)
- count -= int64(n)
- if dup {
- // That's all we support for dup. This is generally
- // supported by any Linux system calls, but the
- // expectation is that now a caller will call read to
- // actually remove these bytes from the socket.
- break
- }
-
- // Drop that part of the view.
- s.readView.TrimFront(n)
- if err != nil {
- s.readMu.Unlock()
- return done, err
- }
+ // This may return a blocking error.
+ n, _, err := s.readLocked(dst, int(count), dup /* peek */)
+ if err != nil {
+ return 0, err.ToError()
}
-
- s.readMu.Unlock()
- return done, nil
+ return int64(n), nil
}
// ioSequencePayload implements tcpip.Payload.
@@ -705,17 +576,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
// Readiness returns a mask of ready events for socket s.
func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
- r := s.Endpoint.Readiness(mask)
-
- // Check our cached value iff the caller asked for readability and the
- // endpoint itself is currently not readable.
- if (mask & ^r & waiter.EventIn) != 0 {
- if atomic.LoadUint32(&s.readViewHasData) == 1 {
- r |= waiter.EventIn
- }
- }
-
- return r
+ return s.Endpoint.Readiness(mask)
}
func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
@@ -723,11 +584,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
return nil
}
if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
- v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return syserr.TranslateNetstackError(err)
- }
- if !v {
+ if !s.Endpoint.SocketOptions().GetV6Only() {
return nil
}
}
@@ -751,7 +608,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip
// Connect implements the linux syscall connect(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, family, err := AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -832,7 +689,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
} else {
var err *syserr.Error
- addr, family, err = AddressAndFamily(sockaddr)
+ addr, family, err = socket.AddressAndFamily(sockaddr)
if err != nil {
return err
}
@@ -923,7 +780,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = ConvertAddress(s.family, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -1007,7 +864,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
case linux.SOL_TCP:
- return getSockOptTCP(t, ep, name, outLen)
+ return getSockOptTCP(t, s, ep, name, outLen)
case linux.SOL_IPV6:
return getSockOptIPv6(t, s, ep, name, outPtr, outLen)
@@ -1043,7 +900,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// Get the last error and convert it.
- err := ep.LastError()
+ err := ep.SocketOptions().GetLastError()
if err == nil {
optP := primitive.Int32(0)
return &optP, nil
@@ -1124,10 +981,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return &v, nil
case linux.SO_BINDTODEVICE:
- var v tcpip.BindToDeviceOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetBindToDevice()
if v == 0 {
var b primitive.ByteSlice
return &b, nil
@@ -1170,11 +1024,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.LingerOption
var linger linux.Linger
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetLinger()
if v.Enabled {
linger.OnOff = 1
@@ -1205,13 +1056,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.OutOfBandInlineOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(v)
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline()))
+ return &v, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
@@ -1226,8 +1072,13 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- v := primitive.Int32(boolToInt32(ep.SocketOptions().GetAcceptConn()))
- return &v, nil
+ // This option is only viable for TCP endpoints.
+ var v bool
+ if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) {
+ v = tcp.EndpointState(ep.State()) == tcp.StateListen
+ }
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
@@ -1236,46 +1087,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) {
+ log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.DelayOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(!v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption()))
+ return &v, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.CorkOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption()))
+ return &v, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck()))
+ return &v, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1449,19 +1290,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, _ := s.Type()
+ if family != linux.AF_INET6 {
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only()))
+ return &v, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1493,13 +1339,23 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
+ return &v, nil
+ case linux.IPV6_RECVERR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
+ return &v, nil
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.IP6T_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet6{})) {
@@ -1511,7 +1367,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
+ a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet6), nil
case linux.IP6T_SO_GET_INFO:
@@ -1520,7 +1376,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1540,7 +1396,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return nil, syserr.ErrInvalidArgument
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1560,7 +1416,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return nil, syserr.ErrProtocolNotAvailable
}
@@ -1582,6 +1438,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
// getSockOptIP implements GetSockOpt when level is SOL_IP.
func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return nil, syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1624,7 +1485,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
+ a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
return &a.(*linux.SockAddrInet).Addr, nil
@@ -1633,13 +1494,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop()))
+ return &v, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
@@ -1663,26 +1519,40 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS()))
+ return &v, nil
+
+ case linux.IP_RECVERR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
+ return &v, nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
- if err != nil {
- return nil, syserr.TranslateNetstackError(err)
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo()))
+ return &v, nil
+
+ case linux.IP_HDRINCL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
}
- vP := primitive.Int32(boolToInt32(v))
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded()))
+ return &v, nil
+
+ case linux.IP_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.SO_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet{})) {
@@ -1694,7 +1564,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return nil, syserr.TranslateNetstackError(err)
}
- a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
+ a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet), nil
case linux.IPT_SO_GET_INFO:
@@ -1801,7 +1671,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int
return setSockOptSocket(t, s, ep, name, optVal)
case linux.SOL_TCP:
- return setSockOptTCP(t, ep, name, optVal)
+ return setSockOptTCP(t, s, ep, name, optVal)
case linux.SOL_IPV6:
return setSockOptIPv6(t, s, ep, name, optVal)
@@ -1870,8 +1740,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
name := string(optVal[:n])
if name == "" {
- v := tcpip.BindToDeviceOption(0)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
+ return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(0))
}
s := t.NetworkContext()
if s == nil {
@@ -1879,8 +1748,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
for nicID, nic := range s.Interfaces() {
if nic.Name == name {
- v := tcpip.BindToDeviceOption(nicID)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
+ return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(nicID))
}
}
return syserr.ErrUnknownDevice
@@ -1949,8 +1817,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- opt := tcpip.OutOfBandInlineOption(v)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
+ ep.SocketOptions().SetOutOfBandInline(v != 0)
+ return nil
case linux.SO_NO_CHECK:
if len(optVal) < sizeOfInt32 {
@@ -1973,10 +1841,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- return syserr.TranslateNetstackError(
- ep.SetSockOpt(&tcpip.LingerOption{
- Enabled: v.OnOff != 0,
- Timeout: time.Second * time.Duration(v.Linger)}))
+ ep.SocketOptions().SetLinger(tcpip.LingerOption{
+ Enabled: v.OnOff != 0,
+ Timeout: time.Second * time.Duration(v.Linger),
+ })
+ return nil
case linux.SO_DETACH_FILTER:
// optval is ignored.
@@ -1991,7 +1860,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
// setSockOptTCP implements SetSockOpt when level is SOL_TCP.
-func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) {
+ log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.TCP_NODELAY:
if len(optVal) < sizeOfInt32 {
@@ -1999,7 +1873,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
+ ep.SocketOptions().SetDelayOption(v == 0)
+ return nil
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -2007,7 +1882,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
+ ep.SocketOptions().SetCorkOption(v != 0)
+ return nil
case linux.TCP_QUICKACK:
if len(optVal) < sizeOfInt32 {
@@ -2015,7 +1891,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
+ ep.SocketOptions().SetQuickAck(v != 0)
+ return nil
case linux.TCP_MAXSEG:
if len(optVal) < sizeOfInt32 {
@@ -2127,18 +2004,55 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
+ family, skType, skProto := s.Type()
+ if family != linux.AF_INET6 {
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IPV6_V6ONLY:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
}
+ if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial {
+ return syserr.ErrInvalidEndpointState
+ }
+
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
+ ep.SocketOptions().SetV6Only(v != 0)
+ return nil
+
+ case linux.IPV6_ADD_MEMBERSHIP:
+ req, err := copyInMulticastV6Request(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.AddMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
- case linux.IPV6_ADD_MEMBERSHIP,
- linux.IPV6_DROP_MEMBERSHIP,
- linux.IPV6_IPSEC_POLICY,
+ case linux.IPV6_DROP_MEMBERSHIP:
+ req, err := copyInMulticastV6Request(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.RemoveMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
+
+ case linux.IPV6_IPSEC_POLICY,
linux.IPV6_JOIN_ANYCAST,
linux.IPV6_LEAVE_ANYCAST,
// TODO(b/148887420): Add support for IPV6_PKTINFO.
@@ -2154,6 +2068,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_RECVORIGDSTADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2173,7 +2096,18 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+ ep.SocketOptions().SetReceiveTClass(v != 0)
+ return nil
+ case linux.IPV6_RECVERR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ ep.SocketOptions().SetRecvError(v != 0)
+ return nil
case linux.IP6T_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIP6TReplace {
@@ -2181,7 +2115,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
}
// Only valid for raw IPv6 sockets.
- if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ if skType != linux.SOCK_RAW {
return syserr.ErrProtocolNotAvailable
}
@@ -2206,6 +2140,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
var (
inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{}))
inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{}))
+ inet6MulticastRequestSize = int(binary.Size(linux.Inet6MulticastRequest{}))
)
// copyInMulticastRequest copies in a variable-size multicast request. The
@@ -2239,6 +2174,16 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR
return req, nil
}
+func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syserr.Error) {
+ if len(optVal) < inet6MulticastRequestSize {
+ return linux.Inet6MulticastRequest{}, syserr.ErrInvalidArgument
+ }
+
+ var req linux.Inet6MulticastRequest
+ binary.Unmarshal(optVal[:inet6MulticastRequestSize], usermem.ByteOrder, &req)
+ return req, nil
+}
+
// parseIntOrChar copies either a 32-bit int or an 8-bit uint out of buf.
//
// net/ipv4/ip_sockglue.c:do_ip_setsockopt does this for its socket options.
@@ -2256,6 +2201,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
// setSockOptIP implements SetSockOpt when level is SOL_IP.
func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ if _, ok := ep.(tcpip.Endpoint); !ok {
+ log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name)
+ return syserr.ErrUnknownProtocolOption
+ }
+
switch name {
case linux.IP_MULTICAST_TTL:
v, err := parseIntOrChar(optVal)
@@ -2308,7 +2258,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{
NIC: tcpip.NICID(req.InterfaceIndex),
- InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]),
+ InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]),
}))
case linux.IP_MULTICAST_LOOP:
@@ -2317,7 +2267,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
+ ep.SocketOptions().SetMulticastLoop(v != 0)
+ return nil
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -2353,7 +2304,19 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ ep.SocketOptions().SetReceiveTOS(v != 0)
+ return nil
+
+ case linux.IP_RECVERR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ ep.SocketOptions().SetRecvError(v != 0)
+ return nil
case linux.IP_PKTINFO:
if len(optVal) == 0 {
@@ -2363,7 +2326,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+ ep.SocketOptions().SetReceivePacketInfo(v != 0)
+ return nil
case linux.IP_HDRINCL:
if len(optVal) == 0 {
@@ -2373,7 +2337,20 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+ ep.SocketOptions().SetHeaderIncluded(v != 0)
+ return nil
+
+ case linux.IP_RECVORIGDSTADDR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
case linux.IPT_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIPTReplace {
@@ -2410,10 +2387,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
- linux.IP_RECVORIGDSTADDR,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2487,11 +2462,9 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_MULTICAST_IF,
linux.IPV6_MULTICAST_LOOP,
linux.IPV6_RECVDSTOPTS,
- linux.IPV6_RECVERR,
linux.IPV6_RECVFRAGSIZE,
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
- linux.IPV6_RECVORIGDSTADDR,
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
@@ -2515,7 +2488,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
switch name {
case linux.IP_TOS,
linux.IP_TTL,
- linux.IP_HDRINCL,
linux.IP_OPTIONS,
linux.IP_ROUTER_ALERT,
linux.IP_RECVOPTS,
@@ -2523,7 +2495,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
linux.IP_PKTINFO,
linux.IP_PKTOPTIONS,
linux.IP_MTU_DISCOVER,
- linux.IP_RECVERR,
linux.IP_RECVTTL,
linux.IP_RECVTOS,
linux.IP_MTU,
@@ -2562,72 +2533,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
}
}
-// isLinkLocal determines if the given IPv6 address is link-local. This is the
-// case when it has the fe80::/10 prefix. This check is used to determine when
-// the NICID is relevant for a given IPv6 address.
-func isLinkLocal(addr tcpip.Address) bool {
- return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
-}
-
-// ConvertAddress converts the given address to a native format.
-func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
- switch family {
- case linux.AF_UNIX:
- var out linux.SockAddrUnix
- out.Family = linux.AF_UNIX
- l := len([]byte(addr.Addr))
- for i := 0; i < l; i++ {
- out.Path[i] = int8(addr.Addr[i])
- }
-
- // Linux returns the used length of the address struct (including the
- // null terminator) for filesystem paths. The Family field is 2 bytes.
- // It is sometimes allowed to exclude the null terminator if the
- // address length is the max. Abstract and empty paths always return
- // the full exact length.
- if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
- return &out, uint32(2 + l)
- }
- return &out, uint32(3 + l)
-
- case linux.AF_INET:
- var out linux.SockAddrInet
- copy(out.Addr[:], addr.Addr)
- out.Family = linux.AF_INET
- out.Port = htons(addr.Port)
- return &out, uint32(sockAddrInetSize)
-
- case linux.AF_INET6:
- var out linux.SockAddrInet6
- if len(addr.Addr) == header.IPv4AddressSize {
- // Copy address in v4-mapped format.
- copy(out.Addr[12:], addr.Addr)
- out.Addr[10] = 0xff
- out.Addr[11] = 0xff
- } else {
- copy(out.Addr[:], addr.Addr)
- }
- out.Family = linux.AF_INET6
- out.Port = htons(addr.Port)
- if isLinkLocal(addr.Addr) {
- out.Scope_id = uint32(addr.NIC)
- }
- return &out, uint32(sockAddrInet6Size)
-
- case linux.AF_PACKET:
- // TODO(gvisor.dev/issue/173): Return protocol too.
- var out linux.SockAddrLink
- out.Family = linux.AF_PACKET
- out.InterfaceIndex = int32(addr.NIC)
- out.HardwareAddrLen = header.EthernetAddressSize
- copy(out.HardwareAddr[:], addr.Addr)
- return &out, uint32(sockAddrLinkSize)
-
- default:
- return nil, 0
- }
-}
-
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
@@ -2636,7 +2541,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := ConvertAddress(s.family, addr)
+ a, l := socket.ConvertAddress(s.family, addr)
return a, l, nil
}
@@ -2648,70 +2553,24 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := ConvertAddress(s.family, addr)
+ a, l := socket.ConvertAddress(s.family, addr)
return a, l, nil
}
-// coalescingRead is the fast path for non-blocking, non-peek, stream-based
-// case. It coalesces as many packets as possible before returning to the
-// caller.
+// streamRead is the fast path for non-blocking, non-peek, stream-based sockets.
//
// Precondition: s.readMu must be locked.
-func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) {
- var err *syserr.Error
- var copied int
-
- // Copy as many views as possible into the user-provided buffer.
- for {
- // Always do at least one fetchReadView, even if the number of bytes to
- // read is 0.
- err = s.fetchReadView()
- if err != nil || len(s.readView) == 0 {
- break
- }
- if dst.NumBytes() == 0 {
- break
- }
-
- var n int
- var e error
- if discard {
- n = len(s.readView)
- if int64(n) > dst.NumBytes() {
- n = int(dst.NumBytes())
- }
- } else {
- n, e = dst.CopyOut(ctx, s.readView)
- // Set the control message, even if 0 bytes were read.
- if e == nil {
- s.updateTimestamp()
- }
- }
- copied += n
- s.readView.TrimFront(n)
-
- dst = dst.DropFirst(n)
- if e != nil {
- err = syserr.FromError(e)
- break
- }
- // If we are done reading requested data then stop.
- if dst.NumBytes() == 0 {
- break
- }
- }
-
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
+func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) {
+ // Always do at least one read, even if the number of bytes to read is 0.
+ var n int
+ n, _, err := s.readLocked(dst, count, false /* peek */)
+ if err != nil {
+ return 0, err
}
-
- // If we managed to copy something, we must deliver it.
- if copied > 0 {
- s.Endpoint.ModerateRecvBuf(copied)
- return copied, nil
+ if n > 0 {
+ s.Endpoint.ModerateRecvBuf(n)
}
-
- return 0, err
+ return n, nil
}
func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
@@ -2723,7 +2582,7 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
return
}
cmsg.IP.HasInq = true
- cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
+ cmsg.IP.Inq = int32(rcvBufUsed)
}
func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
@@ -2760,7 +2619,21 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// bytes of data to be discarded, rather than passed back in a
// caller-supplied buffer.
s.readMu.Lock()
- n, err := s.coalescingRead(ctx, dst, trunc)
+
+ var w io.Writer
+ if trunc {
+ w = ioutil.Discard
+ } else {
+ w = dst.Writer(ctx)
+ }
+
+ n, err := s.streamRead(ctx, w, int(dst.NumBytes()))
+
+ if err == nil && !trunc {
+ // Set the control message, even if 0 bytes were read.
+ s.updateTimestamp()
+ }
+
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
s.readMu.Unlock()
@@ -2770,18 +2643,32 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
s.readMu.Lock()
defer s.readMu.Unlock()
- if err := s.fetchReadView(); err != nil {
+ // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
+ // amount that could be read, and does not write to the buffer.
+ isTCPPeekTrunc := !isPacket && peek && trunc
+
+ var w io.Writer
+ if isTCPPeekTrunc {
+ w = ioutil.Discard
+ } else {
+ w = dst.Writer(ctx)
+ }
+
+ var numRead, numTotal int
+ var err *syserr.Error
+ numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek)
+ if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, err
}
- if !isPacket && peek && trunc {
- // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
- // amount that could be read.
+ if isTCPPeekTrunc {
+ // The TCP endpoint does not return the total bytes in the buffer as
+ // numTotal, so we need to query it via the socket option.
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
- available := len(s.readView) + int(rql)
+ available := int(rql)
bufLen := int(dst.NumBytes())
if available < bufLen {
return available, 0, nil, 0, socket.ControlMessages{}, nil
@@ -2789,88 +2676,65 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
}
- n, err := dst.CopyOut(ctx, s.readView)
// Set the control message, even if 0 bytes were read.
- if err == nil {
- s.updateTimestamp()
- }
+ s.updateTimestamp()
+
var addr linux.SockAddr
var addrLen uint32
if isPacket && senderRequested {
- addr, addrLen = ConvertAddress(s.family, s.sender)
+ addr, addrLen = socket.ConvertAddress(s.family, s.sender)
switch v := addr.(type) {
case *linux.SockAddrLink:
- v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol))
v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
}
}
if peek {
- if l := len(s.readView); trunc && l > n {
+ if trunc && numTotal > numRead {
// isPacket must be true.
- return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil
}
-
- if isPacket || err != nil {
- return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err)
- }
-
- // We need to peek beyond the first message.
- dst = dst.DropFirst(n)
- num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
- n, _, err := s.Endpoint.Peek(dsts)
- // TODO(b/78348848): Handle peek timestamp.
- if err != nil {
- return int64(n), syserr.TranslateNetstackError(err).ToError()
- }
- return int64(n), nil
- }})
- n += int(num)
- if err == syserror.ErrWouldBlock && n > 0 {
- // We got some data, so no need to return an error.
- err = nil
- }
- return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err)
+ return numRead, 0, nil, 0, s.controlMessages(), nil
}
var msgLen int
if isPacket {
- msgLen = len(s.readView)
- s.readView = nil
+ msgLen = numTotal
} else {
- msgLen = int(n)
- s.readView.TrimFront(int(n))
- }
-
- if len(s.readView) == 0 {
- atomic.StoreUint32(&s.readViewHasData, 0)
+ msgLen = numRead
}
var flags int
- if msgLen > int(n) {
+ if msgLen > numRead {
flags |= linux.MSG_TRUNC
}
+ n := numRead
if trunc {
n = msgLen
}
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
- return n, flags, addr, addrLen, cmsg, syserr.FromError(err)
+ return n, flags, addr, addrLen, cmsg, nil
}
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
- IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
+ IP: socket.IPControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasInq: s.readCM.HasInq,
+ Inq: s.readCM.Inq,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ OriginalDstAddress: s.readCM.OriginalDstAddress,
+ SockErr: s.readCM.SockErr,
},
}
}
@@ -2887,9 +2751,66 @@ func (s *socketOpsCommon) updateTimestamp() {
}
}
+// dequeueErr is analogous to net/core/skbuff.c:sock_dequeue_err_skb().
+func (s *socketOpsCommon) dequeueErr() *tcpip.SockError {
+ so := s.Endpoint.SocketOptions()
+ err := so.DequeueErr()
+ if err == nil {
+ return nil
+ }
+
+ // Update socket error to reflect ICMP errors in queue.
+ if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() {
+ so.SetLastError(nextErr.Err)
+ } else if err.ErrOrigin.IsICMPErr() {
+ so.SetLastError(nil)
+ }
+ return err
+}
+
+// addrFamilyFromNetProto returns the address family identifier for the given
+// network protocol.
+func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int {
+ switch net {
+ case header.IPv4ProtocolNumber:
+ return linux.AF_INET
+ case header.IPv6ProtocolNumber:
+ return linux.AF_INET6
+ default:
+ panic(fmt.Sprintf("invalid net proto for addr family inference: %d", net))
+ }
+}
+
+// recvErr handles MSG_ERRQUEUE for recvmsg(2).
+// This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error().
+func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ sockErr := s.dequeueErr()
+ if sockErr == nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+
+ // The payload of the original packet that caused the error is passed as
+ // normal data via msg_iovec. -- recvmsg(2)
+ msgFlags := linux.MSG_ERRQUEUE
+ if int(dst.NumBytes()) < len(sockErr.Payload) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ n, err := dst.CopyOut(t, sockErr.Payload)
+
+ // The original destination address of the datagram that caused the error is
+ // supplied via msg_name. -- recvmsg(2)
+ dstAddr, dstAddrLen := socket.ConvertAddress(addrFamilyFromNetProto(sockErr.NetProto), sockErr.Dst)
+ cmgs := socket.ControlMessages{IP: socket.NewIPControlMessages(s.family, tcpip.ControlMessages{SockErr: sockErr})}
+ return n, msgFlags, dstAddr, dstAddrLen, cmgs, syserr.FromError(err)
+}
+
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return s.recvErr(t, dst)
+ }
+
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -2965,7 +2886,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, family, err := AddressAndFamily(to)
+ addrBuf, family, err := socket.AddressAndFamily(to)
if err != nil {
return 0, err
}
@@ -3063,11 +2984,6 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy
return 0, syserr.TranslateNetstackError(terr).ToError()
}
- // Add bytes removed from the endpoint but not yet sent to the caller.
- s.readMu.Lock()
- v += len(s.readView)
- s.readMu.Unlock()
-
if v > math.MaxInt32 {
v = math.MaxInt32
}
@@ -3384,6 +3300,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
return rv
}
+func isTCPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP)
+}
+
+func isUDPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP)
+}
+
+func isICMPSocket(skType linux.SockType, skProto int) bool {
+ return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6)
+}
+
// State implements socket.Socket.State. State translates the internal state
// returned by netstack to values defined by Linux.
func (s *socketOpsCommon) State() uint32 {
@@ -3393,7 +3321,7 @@ func (s *socketOpsCommon) State() uint32 {
}
switch {
- case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP:
+ case isTCPSocket(s.skType, s.protocol):
// TCP socket.
switch tcp.EndpointState(s.Endpoint.State()) {
case tcp.StateEstablished:
@@ -3422,7 +3350,7 @@ func (s *socketOpsCommon) State() uint32 {
// Internal or unknown state.
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP:
+ case isUDPSocket(s.skType, s.protocol):
// UDP socket.
switch udp.EndpointState(s.Endpoint.State()) {
case udp.StateInitial, udp.StateBound, udp.StateClosed:
@@ -3432,7 +3360,7 @@ func (s *socketOpsCommon) State() uint32 {
default:
return 0
}
- case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6:
+ case isICMPSocket(s.skType, s.protocol):
// TODO(b/112063468): Export states for ICMP sockets.
case s.skType == linux.SOCK_RAW:
// TODO(b/112063468): Export states for raw sockets.
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index b0d9e4d9e..b756bfca0 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -51,9 +51,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
// NewVFS2 creates a new endpoint socket.
func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ endpoint.SocketOptions().SetDelayOption(true)
}
mnt := t.Kernel().SocketMount()
@@ -191,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addrLen uint32
if peerAddr != nil {
// Get address of the peer and write it to peer slice.
- addr, addrLen = ConvertAddress(s.family, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index ead3b2b79..c847ff1c7 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol
// protocol is passed in network byte order, but netstack wants it in
// host order.
- netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+ netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol)))
wq := &waiter.Queue{}
ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go
index 2a01143f6..0af805246 100644
--- a/pkg/sentry/socket/netstack/provider_vfs2.go
+++ b/pkg/sentry/socket/netstack/provider_vfs2.go
@@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot
// protocol is passed in network byte order, but netstack wants it in
// host order.
- netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+ netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol)))
wq := &waiter.Queue{}
ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index fa9ac9059..cc0fadeb5 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -324,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
- in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
- out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats
+ out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
0, // Icmp/InMsgs.
- Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
+ Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors.
0, // Icmp/InCsumErrors.
in.DstUnreachable.Value(), // InDestUnreachs.
in.TimeExceeded.Value(), // InTimeExcds.
@@ -343,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.InfoRequest.Value(), // InAddrMasks.
in.InfoReply.Value(), // InAddrMaskReps.
0, // Icmp/OutMsgs.
- Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
- out.DstUnreachable.Value(), // OutDestUnreachs.
- out.TimeExceeded.Value(), // OutTimeExcds.
- out.ParamProblem.Value(), // OutParmProbs.
- out.SrcQuench.Value(), // OutSrcQuenchs.
- out.Redirect.Value(), // OutRedirects.
- out.Echo.Value(), // OutEchos.
- out.EchoReply.Value(), // OutEchoReps.
- out.Timestamp.Value(), // OutTimestamps.
- out.TimestampReply.Value(), // OutTimestampReps.
- out.InfoRequest.Value(), // OutAddrMasks.
- out.InfoReply.Value(), // OutAddrMaskReps.
+ Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors.
+ out.DstUnreachable.Value(), // OutDestUnreachs.
+ out.TimeExceeded.Value(), // OutTimeExcds.
+ out.ParamProblem.Value(), // OutParmProbs.
+ out.SrcQuench.Value(), // OutSrcQuenchs.
+ out.Redirect.Value(), // OutRedirects.
+ out.Echo.Value(), // OutEchos.
+ out.EchoReply.Value(), // OutEchoReps.
+ out.Timestamp.Value(), // OutTimestamps.
+ out.TimestampReply.Value(), // OutTimestampReps.
+ out.InfoRequest.Value(), // OutAddrMasks.
+ out.InfoReply.Value(), // OutAddrMaskReps.
}
case *inet.StatSNMPTCP:
tcp := Metrics.TCP
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fd31479e5..97729dacc 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -18,6 +18,7 @@
package socket
import (
+ "bytes"
"fmt"
"sync/atomic"
"syscall"
@@ -35,6 +36,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -42,7 +44,134 @@ import (
// control messages.
type ControlMessages struct {
Unix transport.ControlMessages
- IP tcpip.ControlMessages
+ IP IPControlMessages
+}
+
+// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format.
+func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+ return p
+}
+
+ // errOriginToLinux maps tcpip socket error origin to Linux socket error origin constants.
+func errOriginToLinux(origin tcpip.SockErrOrigin) uint8 {
+ switch origin {
+ case tcpip.SockExtErrorOriginNone:
+ return linux.SO_EE_ORIGIN_NONE
+ case tcpip.SockExtErrorOriginLocal:
+ return linux.SO_EE_ORIGIN_LOCAL
+ case tcpip.SockExtErrorOriginICMP:
+ return linux.SO_EE_ORIGIN_ICMP
+ case tcpip.SockExtErrorOriginICMP6:
+ return linux.SO_EE_ORIGIN_ICMP6
+ default:
+ panic(fmt.Sprintf("unknown socket origin: %d", origin))
+ }
+}
+
+// sockErrCmsgToLinux converts SockError control message from tcpip format to
+// Linux format.
+func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg {
+ if sockErr == nil {
+ return nil
+ }
+
+ ee := linux.SockExtendedErr{
+ Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()),
+ Origin: errOriginToLinux(sockErr.ErrOrigin),
+ Type: sockErr.ErrType,
+ Code: sockErr.ErrCode,
+ Info: sockErr.ErrInfo,
+ }
+
+ switch sockErr.NetProto {
+ case header.IPv4ProtocolNumber:
+ errMsg := &linux.SockErrCMsgIPv4{SockExtendedErr: ee}
+ if len(sockErr.Offender.Addr) > 0 {
+ addr, _ := ConvertAddress(linux.AF_INET, sockErr.Offender)
+ errMsg.Offender = *addr.(*linux.SockAddrInet)
+ }
+ return errMsg
+ case header.IPv6ProtocolNumber:
+ errMsg := &linux.SockErrCMsgIPv6{SockExtendedErr: ee}
+ if len(sockErr.Offender.Addr) > 0 {
+ addr, _ := ConvertAddress(linux.AF_INET6, sockErr.Offender)
+ errMsg.Offender = *addr.(*linux.SockAddrInet6)
+ }
+ return errMsg
+ default:
+ panic(fmt.Sprintf("invalid net proto for creating SockErrCMsg: %d", sockErr.NetProto))
+ }
+}
+
+ // NewIPControlMessages converts the tcpip ControlMessages (which does not
+// have Linux specific format) to Linux format.
+func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages {
+ var orgDstAddr linux.SockAddr
+ if cmgs.HasOriginalDstAddress {
+ orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress)
+ }
+ return IPControlMessages{
+ HasTimestamp: cmgs.HasTimestamp,
+ Timestamp: cmgs.Timestamp,
+ HasInq: cmgs.HasInq,
+ Inq: cmgs.Inq,
+ HasTOS: cmgs.HasTOS,
+ TOS: cmgs.TOS,
+ HasTClass: cmgs.HasTClass,
+ TClass: cmgs.TClass,
+ HasIPPacketInfo: cmgs.HasIPPacketInfo,
+ PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
+ OriginalDstAddress: orgDstAddr,
+ SockErr: sockErrCmsgToLinux(cmgs.SockErr),
+ }
+}
+
+// IPControlMessages contains socket control messages for IP sockets.
+// This can contain Linux specific structures unlike tcpip.ControlMessages.
+//
+// +stateify savable
+type IPControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packet used to create
+ // the read data was received.
+ Timestamp int64
+
+ // HasInq indicates whether Inq is valid/set.
+ HasInq bool
+
+ // Inq is the number of bytes ready to be received.
+ Inq int32
+
+ // HasTOS indicates whether TOS is valid/set.
+ HasTOS bool
+
+ // TOS is the IPv4 type of service of the associated packet.
+ TOS uint8
+
+ // HasTClass indicates whether TClass is valid/set.
+ HasTClass bool
+
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo linux.ControlMessageIPPacketInfo
+
+ // OriginalDstAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress linux.SockAddr
+
+ // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE).
+ SockErr linux.SockErrCMsg
}
// Release releases Unix domain socket credentials and rights.
@@ -460,3 +589,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
panic(fmt.Sprintf("Unsupported socket family %v", family))
}
}
+
+var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes()
+var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes()
+var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes()
+
+// Ntohs converts a 16-bit number from network byte order to host byte order. It
+// assumes that the host is little endian.
+func Ntohs(v uint16) uint16 {
+ return v<<8 | v>>8
+}
+
+// Htons converts a 16-bit number from host byte order to network byte order. It
+// assumes that the host is little endian.
+func Htons(v uint16) uint16 {
+ return Ntohs(v)
+}
+
+// isLinkLocal determines if the given IPv6 address is link-local. This is the
+// case when it has the fe80::/10 prefix. This check is used to determine when
+// the NICID is relevant for a given IPv6 address.
+func isLinkLocal(addr tcpip.Address) bool {
+ return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
+}
+
+// ConvertAddress converts the given address to a native format.
+func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
+ switch family {
+ case linux.AF_UNIX:
+ var out linux.SockAddrUnix
+ out.Family = linux.AF_UNIX
+ l := len([]byte(addr.Addr))
+ for i := 0; i < l; i++ {
+ out.Path[i] = int8(addr.Addr[i])
+ }
+
+ // Linux returns the used length of the address struct (including the
+ // null terminator) for filesystem paths. The Family field is 2 bytes.
+ // It is sometimes allowed to exclude the null terminator if the
+ // address length is the max. Abstract and empty paths always return
+ // the full exact length.
+ if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
+ return &out, uint32(2 + l)
+ }
+ return &out, uint32(3 + l)
+
+ case linux.AF_INET:
+ var out linux.SockAddrInet
+ copy(out.Addr[:], addr.Addr)
+ out.Family = linux.AF_INET
+ out.Port = Htons(addr.Port)
+ return &out, uint32(sockAddrInetSize)
+
+ case linux.AF_INET6:
+ var out linux.SockAddrInet6
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
+ copy(out.Addr[12:], addr.Addr)
+ out.Addr[10] = 0xff
+ out.Addr[11] = 0xff
+ } else {
+ copy(out.Addr[:], addr.Addr)
+ }
+ out.Family = linux.AF_INET6
+ out.Port = Htons(addr.Port)
+ if isLinkLocal(addr.Addr) {
+ out.Scope_id = uint32(addr.NIC)
+ }
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(gvisor.dev/issue/173): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
+ default:
+ return nil, 0
+ }
+}
+
+// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the
+// netstack representation, mapping the all-zero ("any") address to "".
+func BytesToIPAddress(addr []byte) tcpip.Address {
+ if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
+ return ""
+ }
+ return tcpip.Address(addr)
+}
+
+// AddressAndFamily reads a sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
+//
+// AddressAndFamily returns an address and its family.
+func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
+ // Make sure we have at least 2 bytes for the address family.
+ if len(addr) < 2 {
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
+ }
+
+ // Get the rest of the fields based on the address family.
+ switch family := usermem.ByteOrder.Uint16(addr); family {
+ case linux.AF_UNIX:
+ path := addr[2:]
+ if len(path) > linux.UnixPathMax {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ // Drop the terminating NUL (if one exists) and everything after
+ // it for filesystem (non-abstract) addresses.
+ if len(path) > 0 && path[0] != 0 {
+ if n := bytes.IndexByte(path[1:], 0); n >= 0 {
+ path = path[:n+1]
+ }
+ }
+ return tcpip.FullAddress{
+ Addr: tcpip.Address(path),
+ }, family, nil
+
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if len(addr) < sockAddrInetSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: BytesToIPAddress(a.Addr[:]),
+ Port: Ntohs(a.Port),
+ }
+ return out, family, nil
+
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if len(addr) < sockAddrInet6Size {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: BytesToIPAddress(a.Addr[:]),
+ Port: Ntohs(a.Port),
+ }
+ if isLinkLocal(out.Addr) {
+ out.NIC = tcpip.NICID(a.Scope_id)
+ }
+ return out, family, nil
+
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/173): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
+ case linux.AF_UNSPEC:
+ return tcpip.FullAddress{}, family, nil
+
+ default:
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 4abea90cc..099a56281 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -178,10 +178,6 @@ type Endpoint interface {
// SetSockOpt sets a socket option.
SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool sets a socket option for simple cases when a value has
- // the int type.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -189,10 +185,6 @@ type Endpoint interface {
// GetSockOpt gets a socket option.
GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool gets a socket option for simple cases when a return
- // value has the int type.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
@@ -754,9 +746,6 @@ type baseEndpoint struct {
// or may be used if the endpoint is connected.
path string
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -848,17 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option.
func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- e.linger = *v
- e.Unlock()
- }
- return nil
-}
-
-func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- log.Warningf("Unsupported socket option: %d", opt)
return nil
}
@@ -872,11 +850,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- log.Warningf("Unsupported socket option: %d", opt)
- return false, tcpip.ErrUnknownProtocolOption
-}
-
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -940,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- *o = e.linger
- e.Unlock()
- return nil
-
- default:
- log.Warningf("Unsupported socket option: %T", opt)
- return tcpip.ErrUnknownProtocolOption
- }
+ log.Warningf("Unsupported socket option: %T", opt)
+ return tcpip.ErrUnknownProtocolOption
}
// LastError implements Endpoint.LastError.
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index b32bb7ba8..c59297c80 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -136,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, family, err := netstack.AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
if err == syserr.ErrAddressFamilyNotSupported {
err = syserr.ErrInvalidArgument
@@ -169,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -181,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -255,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -647,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
@@ -682,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index eaf0b0d26..27f705bb2 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -172,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index a920180d3..d36a64ffc 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -32,8 +32,8 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
+ "//pkg/sentry/socket",
"//pkg/sentry/socket/netlink",
- "//pkg/sentry/socket/netstack",
"//pkg/sentry/syscalls/linux",
"//pkg/usermem",
],
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index cc5f70cd4..d943a7cb1 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -23,8 +23,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
- "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -341,7 +341,7 @@ func sockAddr(t *kernel.Task, addr usermem.Addr, length uint32) string {
switch family {
case linux.AF_INET, linux.AF_INET6, linux.AF_UNIX:
- fa, _, err := netstack.AddressAndFamily(b)
+ fa, _, err := socket.AddressAndFamily(b)
if err != nil {
return fmt.Sprintf("%#x {Family: %s, error extracting address: %v}", addr, familyStr, err)
}
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index bb1f715e2..a72df62f6 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{
63: syscalls.Supported("uname", Uname),
64: syscalls.Supported("semget", Semget),
65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
- 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
+ 66: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil),
67: syscalls.Supported("shmdt", Shmdt),
68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
@@ -272,7 +272,7 @@ var AMD64 = &kernel.SyscallTable{
217: syscalls.Supported("getdents64", Getdents64),
218: syscalls.Supported("set_tid_address", SetTidAddress),
219: syscalls.Supported("restart_syscall", RestartSyscall),
- 220: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
+ 220: syscalls.PartiallySupported("semtimedop", Semtimedop, "A non-zero timeout argument isn't supported.", []string{"gvisor.dev/issue/137"}),
221: syscalls.PartiallySupported("fadvise64", Fadvise64, "Not all options are supported.", nil),
222: syscalls.Supported("timer_create", TimerCreate),
223: syscalls.Supported("timer_settime", TimerSettime),
@@ -619,8 +619,8 @@ var ARM64 = &kernel.SyscallTable{
188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
190: syscalls.Supported("semget", Semget),
- 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil),
- 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}),
+ 191: syscalls.PartiallySupported("semctl", Semctl, "Options SEM_STAT_ANY not supported.", nil),
+ 192: syscalls.PartiallySupported("semtimedop", Semtimedop, "A non-zero timeout argument isn't supported.", []string{"gvisor.dev/issue/137"}),
193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil),
194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil),
195: syscalls.PartiallySupported("shmctl", Shmctl, "Options SHM_LOCK, SHM_UNLOCK are not supported.", nil),
diff --git a/pkg/sentry/syscalls/linux/sys_aio.go b/pkg/sentry/syscalls/linux/sys_aio.go
index 0bf313a13..c2285f796 100644
--- a/pkg/sentry/syscalls/linux/sys_aio.go
+++ b/pkg/sentry/syscalls/linux/sys_aio.go
@@ -307,9 +307,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
if !ok {
return syserror.EINVAL
}
- if ready := ctx.Prepare(); !ready {
- // Context is busy.
- return syserror.EAGAIN
+ if err := ctx.Prepare(); err != nil {
+ return err
}
if eventFile != nil {
diff --git a/pkg/sentry/syscalls/linux/sys_file.go b/pkg/sentry/syscalls/linux/sys_file.go
index 519066a47..c33571f43 100644
--- a/pkg/sentry/syscalls/linux/sys_file.go
+++ b/pkg/sentry/syscalls/linux/sys_file.go
@@ -175,6 +175,12 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint
}
}
+ file, err := d.Inode.GetFile(t, d, fileFlags)
+ if err != nil {
+ return syserror.ConvertIntr(err, syserror.ERESTARTSYS)
+ }
+ defer file.DecRef(t)
+
// Truncate is called when O_TRUNC is specified for any kind of
// existing Dirent. Behavior is delegated to the entry's Truncate
// implementation.
@@ -184,12 +190,6 @@ func openAt(t *kernel.Task, dirFD int32, addr usermem.Addr, flags uint) (fd uint
}
}
- file, err := d.Inode.GetFile(t, d, fileFlags)
- if err != nil {
- return syserror.ConvertIntr(err, syserror.ERESTARTSYS)
- }
- defer file.DecRef(t)
-
// Success.
newFD, err := t.NewFDFrom(0, file, kernel.FDFlags{
CloseOnExec: flags&linux.O_CLOEXEC != 0,
@@ -646,7 +646,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if _, err := primitive.CopyInt32In(t, args[2].Pointer(), &set); err != nil {
return 0, nil, err
}
- fSetOwn(t, file, set)
+ fSetOwn(t, int(fd), file, set)
return 0, nil, nil
case linux.FIOGETOWN, linux.SIOCGPGRP:
@@ -901,8 +901,8 @@ func fGetOwn(t *kernel.Task, file *fs.File) int32 {
//
// If who is positive, it represents a PID. If negative, it represents a PGID.
// If the PID or PGID is invalid, the owner is silently unset.
-func fSetOwn(t *kernel.Task, file *fs.File, who int32) error {
- a := file.Async(fasync.New).(*fasync.FileAsync)
+func fSetOwn(t *kernel.Task, fd int, file *fs.File, who int32) error {
+ a := file.Async(fasync.New(fd)).(*fasync.FileAsync)
if who < 0 {
// Check for overflow before flipping the sign.
if who-1 > who {
@@ -1049,7 +1049,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
case linux.F_GETOWN:
return uintptr(fGetOwn(t, file)), nil, nil
case linux.F_SETOWN:
- return 0, nil, fSetOwn(t, file, args[2].Int())
+ return 0, nil, fSetOwn(t, int(fd), file, args[2].Int())
case linux.F_GETOWN_EX:
addr := args[2].Pointer()
owner := fGetOwnEx(t, file)
@@ -1062,7 +1062,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- a := file.Async(fasync.New).(*fasync.FileAsync)
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
switch owner.Type {
case linux.F_OWNER_TID:
task := t.PIDNamespace().TaskWithID(kernel.ThreadID(owner.PID))
@@ -1111,6 +1111,12 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
n, err := sz.SetFifoSize(int64(args[2].Int()))
return uintptr(n), nil, err
+ case linux.F_GETSIG:
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
+ return uintptr(a.Signal()), nil, nil
+ case linux.F_SETSIG:
+ a := file.Async(fasync.New(int(fd))).(*fasync.FileAsync)
+ return 0, nil, a.SetSignal(linux.Signal(args[2].Int()))
default:
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go
index e383a0a87..d324461a3 100644
--- a/pkg/sentry/syscalls/linux/sys_sem.go
+++ b/pkg/sentry/syscalls/linux/sys_sem.go
@@ -48,6 +48,15 @@ func Semget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return uintptr(set.ID), nil, nil
}
+// Semtimedop handles: semtimedop(int semid, struct sembuf *sops, size_t nsops, const struct timespec *timeout)
+func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ // TODO(gvisor.dev/issue/137): A non-zero timeout isn't supported.
+ if args[3].Pointer() != 0 {
+ return 0, nil, syserror.ENOSYS
+ }
+ return Semop(t, args)
+}
+
// Semop handles: semop(int semid, struct sembuf *sops, size_t nsops)
func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
id := args[0].Int()
@@ -146,11 +155,37 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
v, err := getNCnt(t, id, num)
return uintptr(v), nil, err
- case linux.IPC_INFO,
- linux.SEM_INFO,
- linux.SEM_STAT,
- linux.SEM_STAT_ANY:
+ case linux.IPC_INFO:
+ buf := args[3].Pointer()
+ r := t.IPCNamespace().SemaphoreRegistry()
+ info := r.IPCInfo()
+ if _, err := info.CopyOut(t, buf); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(r.HighestIndex()), nil, nil
+ case linux.SEM_INFO:
+ buf := args[3].Pointer()
+ r := t.IPCNamespace().SemaphoreRegistry()
+ info := r.SemInfo()
+ if _, err := info.CopyOut(t, buf); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(r.HighestIndex()), nil, nil
+
+ case linux.SEM_STAT:
+ arg := args[3].Pointer()
+ // id is an index in SEM_STAT.
+ semid, ds, err := semStat(t, id)
+ if err != nil {
+ return 0, nil, err
+ }
+ if _, err := ds.CopyOut(t, arg); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(semid), nil, err
+
+ case linux.SEM_STAT_ANY:
t.Kernel().EmitUnimplementedEvent(t)
fallthrough
@@ -195,6 +230,17 @@ func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) {
return set.GetStat(creds)
}
+func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
+ r := t.IPCNamespace().SemaphoreRegistry()
+ set := r.FindByIndex(index)
+ if set == nil {
+ return 0, nil, syserror.EINVAL
+ }
+ creds := auth.CredentialsFromContext(t)
+ ds, err := set.GetStat(creds)
+ return set.ID, ds, err
+}
+
func setVal(t *kernel.Task, id int32, num int32, val int16) error {
r := t.IPCNamespace().SemaphoreRegistry()
set := r.FindByID(id)
diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go
index e748d33d8..d639c9bf7 100644
--- a/pkg/sentry/syscalls/linux/sys_signal.go
+++ b/pkg/sentry/syscalls/linux/sys_signal.go
@@ -88,8 +88,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(target.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(target.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(target.UserNamespace()).OrOverflow()))
if err := target.SendGroupSignal(info); err != syserror.ESRCH {
return 0, nil, err
}
@@ -127,8 +127,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
err := tg.SendSignal(info)
if err == syserror.ESRCH {
// ESRCH is ignored because it means the task
@@ -171,8 +171,8 @@ func Kill(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallC
Signo: int32(sig),
Code: arch.SignalInfoUser,
}
- info.SetPid(int32(tg.PIDNamespace().IDOfTask(t)))
- info.SetUid(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
+ info.SetPID(int32(tg.PIDNamespace().IDOfTask(t)))
+ info.SetUID(int32(t.Credentials().RealKUID.In(tg.Leader().UserNamespace()).OrOverflow()))
// See note above regarding ESRCH race above.
if err := tg.SendSignal(info); err != syserror.ESRCH {
lastErr = err
@@ -189,8 +189,8 @@ func tkillSigInfo(sender, receiver *kernel.Task, sig linux.Signal) *arch.SignalI
Signo: int32(sig),
Code: arch.SignalInfoTkill,
}
- info.SetPid(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
- info.SetUid(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
+ info.SetPID(int32(receiver.PIDNamespace().IDOfThreadGroup(sender.ThreadGroup())))
+ info.SetUID(int32(sender.Credentials().RealKUID.In(receiver.UserNamespace()).OrOverflow()))
return info
}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 9cd052c3d..fe45225c1 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -749,11 +749,6 @@ func recvSingleMsg(t *kernel.Task, s socket.Socket, msgPtr usermem.Addr, flags i
return 0, err
}
- // FIXME(b/63594852): Pretend we have an empty error queue.
- if flags&linux.MSG_ERRQUEUE != 0 {
- return 0, syserror.EAGAIN
- }
-
// Fast path when no control message nor name buffers are provided.
if msg.ControlLen == 0 && msg.NameLen == 0 {
n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
@@ -1035,7 +1030,7 @@ func sendSingleMsg(t *kernel.Task, s socket.Socket, file *fs.File, msgPtr userme
return 0, err
}
- controlMessages, err := control.Parse(t, s, controlData)
+ controlMessages, err := control.Parse(t, s, controlData, t.Arch().Width())
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/syscalls/linux/sys_thread.go b/pkg/sentry/syscalls/linux/sys_thread.go
index 983f8d396..8e7ac0ffe 100644
--- a/pkg/sentry/syscalls/linux/sys_thread.go
+++ b/pkg/sentry/syscalls/linux/sys_thread.go
@@ -413,8 +413,8 @@ func Waitid(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
si := arch.SignalInfo{
Signo: int32(linux.SIGCHLD),
}
- si.SetPid(int32(wr.TID))
- si.SetUid(int32(wr.UID))
+ si.SetPID(int32(wr.TID))
+ si.SetUID(int32(wr.UID))
// TODO(b/73541790): convert kernel.ExitStatus to functions and make
// WaitResult.Status a linux.WaitStatus.
s := syscall.WaitStatus(wr.Status)
diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go
index 6d0a38330..1365a5a62 100644
--- a/pkg/sentry/syscalls/linux/vfs2/aio.go
+++ b/pkg/sentry/syscalls/linux/vfs2/aio.go
@@ -130,9 +130,8 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user
if !ok {
return syserror.EINVAL
}
- if ready := aioCtx.Prepare(); !ready {
- // Context is busy.
- return syserror.EAGAIN
+ if err := aioCtx.Prepare(); err != nil {
+ return err
}
if eventFD != nil {
diff --git a/pkg/sentry/syscalls/linux/vfs2/fd.go b/pkg/sentry/syscalls/linux/vfs2/fd.go
index 36e89700e..7dd9ef857 100644
--- a/pkg/sentry/syscalls/linux/vfs2/fd.go
+++ b/pkg/sentry/syscalls/linux/vfs2/fd.go
@@ -165,7 +165,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
ownerType = linux.F_OWNER_PGRP
who = -who
}
- return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who)
case linux.F_GETOWN_EX:
owner, hasOwner := getAsyncOwner(t, file)
if !hasOwner {
@@ -179,7 +179,7 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
if err != nil {
return 0, nil, err
}
- return 0, nil, setAsyncOwner(t, file, owner.Type, owner.PID)
+ return 0, nil, setAsyncOwner(t, int(fd), file, owner.Type, owner.PID)
case linux.F_SETPIPE_SZ:
pipefile, ok := file.Impl().(*pipe.VFSPipeFD)
if !ok {
@@ -207,6 +207,16 @@ func Fcntl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
return 0, nil, err
case linux.F_SETLK, linux.F_SETLKW:
return 0, nil, posixLock(t, args, file, cmd)
+ case linux.F_GETSIG:
+ a := file.AsyncHandler()
+ if a == nil {
+ // Default behavior aka SIGIO.
+ return 0, nil, nil
+ }
+ return uintptr(a.(*fasync.FileAsync).Signal()), nil, nil
+ case linux.F_SETSIG:
+ a := file.SetAsyncHandler(fasync.NewVFS2(int(fd))).(*fasync.FileAsync)
+ return 0, nil, a.SetSignal(linux.Signal(args[2].Int()))
default:
// Everything else is not yet supported.
return 0, nil, syserror.EINVAL
@@ -241,7 +251,7 @@ func getAsyncOwner(t *kernel.Task, fd *vfs.FileDescription) (ownerEx linux.FOwne
}
}
-func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32) error {
+func setAsyncOwner(t *kernel.Task, fd int, file *vfs.FileDescription, ownerType, pid int32) error {
switch ownerType {
case linux.F_OWNER_TID, linux.F_OWNER_PID, linux.F_OWNER_PGRP:
// Acceptable type.
@@ -249,7 +259,7 @@ func setAsyncOwner(t *kernel.Task, fd *vfs.FileDescription, ownerType, pid int32
return syserror.EINVAL
}
- a := fd.SetAsyncHandler(fasync.NewVFS2).(*fasync.FileAsync)
+ a := file.SetAsyncHandler(fasync.NewVFS2(fd)).(*fasync.FileAsync)
if pid == 0 {
a.ClearOwner()
return nil
diff --git a/pkg/sentry/syscalls/linux/vfs2/ioctl.go b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
index 2806c3f6f..20c264fef 100644
--- a/pkg/sentry/syscalls/linux/vfs2/ioctl.go
+++ b/pkg/sentry/syscalls/linux/vfs2/ioctl.go
@@ -100,7 +100,7 @@ func Ioctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
ownerType = linux.F_OWNER_PGRP
who = -who
}
- return 0, nil, setAsyncOwner(t, file, ownerType, who)
+ return 0, nil, setAsyncOwner(t, int(fd), file, ownerType, who)
}
ret, err := file.Ioctl(t, t.MemoryManager(), args)
diff --git a/pkg/sentry/syscalls/linux/vfs2/pipe.go b/pkg/sentry/syscalls/linux/vfs2/pipe.go
index ee38fdca0..6986e39fe 100644
--- a/pkg/sentry/syscalls/linux/vfs2/pipe.go
+++ b/pkg/sentry/syscalls/linux/vfs2/pipe.go
@@ -42,7 +42,10 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags int32) error {
if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 {
return syserror.EINVAL
}
- r, w := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK))
+ r, w, err := pipefs.NewConnectedPipeFDs(t, t.Kernel().PipeMount(), uint32(flags&linux.O_NONBLOCK))
+ if err != nil {
+ return err
+ }
defer r.DecRef(t)
defer w.DecRef(t)
diff --git a/pkg/sentry/syscalls/linux/vfs2/socket.go b/pkg/sentry/syscalls/linux/vfs2/socket.go
index 7b33b3f59..f5795b4a8 100644
--- a/pkg/sentry/syscalls/linux/vfs2/socket.go
+++ b/pkg/sentry/syscalls/linux/vfs2/socket.go
@@ -752,11 +752,6 @@ func recvSingleMsg(t *kernel.Task, s socket.SocketVFS2, msgPtr usermem.Addr, fla
return 0, err
}
- // FIXME(b/63594852): Pretend we have an empty error queue.
- if flags&linux.MSG_ERRQUEUE != 0 {
- return 0, syserror.EAGAIN
- }
-
// Fast path when no control message nor name buffers are provided.
if msg.ControlLen == 0 && msg.NameLen == 0 {
n, mflags, _, _, cms, err := s.RecvMsg(t, dst, int(flags), haveDeadline, deadline, false, 0)
@@ -1038,7 +1033,7 @@ func sendSingleMsg(t *kernel.Task, s socket.SocketVFS2, file *vfs.FileDescriptio
return 0, err
}
- controlMessages, err := control.Parse(t, s, controlData)
+ controlMessages, err := control.Parse(t, s, controlData, t.Arch().Width())
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go
index a98aac52b..072655fe8 100644
--- a/pkg/sentry/vfs/epoll.go
+++ b/pkg/sentry/vfs/epoll.go
@@ -204,8 +204,8 @@ func (ep *EpollInstance) AddInterest(file *FileDescription, num int32, event lin
file.EventRegister(&epi.waiter, wmask)
// Check if the file is already ready.
- if file.Readiness(wmask)&wmask != 0 {
- epi.Callback(nil)
+ if m := file.Readiness(wmask) & wmask; m != 0 {
+ epi.Callback(nil, m)
}
// Add epi to file.epolls so that it is removed when the last
@@ -274,8 +274,8 @@ func (ep *EpollInstance) ModifyInterest(file *FileDescription, num int32, event
file.EventRegister(&epi.waiter, wmask)
// Check if the file is already ready with the new mask.
- if file.Readiness(wmask)&wmask != 0 {
- epi.Callback(nil)
+ if m := file.Readiness(wmask) & wmask; m != 0 {
+ epi.Callback(nil, m)
}
return nil
@@ -311,7 +311,7 @@ func (ep *EpollInstance) DeleteInterest(file *FileDescription, num int32) error
}
// Callback implements waiter.EntryCallback.Callback.
-func (epi *epollInterest) Callback(*waiter.Entry) {
+func (epi *epollInterest) Callback(*waiter.Entry, waiter.EventMask) {
newReady := false
epi.epoll.mu.Lock()
if !epi.ready {
diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go
index f9e39a94c..5321ac80a 100644
--- a/pkg/sentry/vfs/file_description.go
+++ b/pkg/sentry/vfs/file_description.go
@@ -15,6 +15,7 @@
package vfs
import (
+ "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -43,7 +44,7 @@ import (
type FileDescription struct {
FileDescriptionRefs
- // flagsMu protects statusFlags and asyncHandler below.
+ // flagsMu protects `statusFlags`, `saved`, and `asyncHandler` below.
flagsMu sync.Mutex `state:"nosave"`
// statusFlags contains status flags, "initialized by open(2) and possibly
@@ -52,6 +53,11 @@ type FileDescription struct {
// access to asyncHandler.
statusFlags uint32
+ // saved is true after beforeSave is called. This is used to prevent
+ // double-unregistration of asyncHandler. This does not work properly for
+ // save-resume, which is not currently supported in gVisor (see b/26588733).
+ saved bool `state:"nosave"`
+
// asyncHandler handles O_ASYNC signal generation. It is set with the
// F_SETOWN or F_SETOWN_EX fcntls. For asyncHandler to be used, O_ASYNC must
// also be set by fcntl(2).
@@ -184,7 +190,7 @@ func (fd *FileDescription) DecRef(ctx context.Context) {
}
fd.vd.DecRef(ctx)
fd.flagsMu.Lock()
- if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
+ if !fd.saved && fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil {
fd.asyncHandler.Unregister(fd)
}
fd.asyncHandler = nil
@@ -834,44 +840,27 @@ func (fd *FileDescription) SetAsyncHandler(newHandler func() FileAsync) FileAsyn
return fd.asyncHandler
}
-// FileReadWriteSeeker is a helper struct to pass a FileDescription as
-// io.Reader/io.Writer/io.ReadSeeker/io.ReaderAt/io.WriterAt/etc.
-type FileReadWriteSeeker struct {
- FD *FileDescription
- Ctx context.Context
- ROpts ReadOptions
- WOpts WriteOptions
-}
-
-// ReadAt implements io.ReaderAt.ReadAt.
-func (f *FileReadWriteSeeker) ReadAt(p []byte, off int64) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.PRead(f.Ctx, dst, off, f.ROpts)
- return int(n), err
-}
-
-// Read implements io.ReadWriteSeeker.Read.
-func (f *FileReadWriteSeeker) Read(p []byte) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.Read(f.Ctx, dst, f.ROpts)
- return int(n), err
-}
-
-// Seek implements io.ReadWriteSeeker.Seek.
-func (f *FileReadWriteSeeker) Seek(offset int64, whence int) (int64, error) {
- return f.FD.Seek(f.Ctx, offset, int32(whence))
-}
-
-// WriteAt implements io.WriterAt.WriteAt.
-func (f *FileReadWriteSeeker) WriteAt(p []byte, off int64) (int, error) {
- dst := usermem.BytesIOSequence(p)
- n, err := f.FD.PWrite(f.Ctx, dst, off, f.WOpts)
- return int(n), err
-}
-
-// Write implements io.ReadWriteSeeker.Write.
-func (f *FileReadWriteSeeker) Write(p []byte) (int, error) {
- buf := usermem.BytesIOSequence(p)
- n, err := f.FD.Write(f.Ctx, buf, f.WOpts)
- return int(n), err
+// CopyRegularFileData streams the contents of srcFD into dstFD. It stops when
+// a read from srcFD reports io.EOF, or when any read or write returns another
+// error. The returned count is the total number of bytes that dstFD.Write
+// reported as written, including on the error paths.
+func CopyRegularFileData(ctx context.Context, dstFD, srcFD *FileDescription) (int64, error) {
+	var copied int64
+	// The scratch buffer size (32 KiB) is arbitrary.
+	scratch := usermem.BytesIOSequence(make([]byte, 32*1024))
+	for {
+		nr, rerr := srcFD.Read(ctx, scratch, ReadOptions{})
+		atEOF := rerr == io.EOF
+		if rerr != nil && !atEOF {
+			return copied, rerr
+		}
+		// Flush whatever was just read (possibly alongside EOF), retrying
+		// until partial writes have drained the pending bytes.
+		for pending := scratch.TakeFirst64(nr); pending.NumBytes() != 0; {
+			nw, werr := dstFD.Write(ctx, pending, WriteOptions{})
+			copied += nw
+			pending = pending.DropFirst64(nw)
+			if werr != nil {
+				return copied, werr
+			}
+		}
+		if atEOF {
+			return copied, nil
+		}
+	}
 }
diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go
index cb48c37a1..0df023713 100644
--- a/pkg/sentry/vfs/mount_unsafe.go
+++ b/pkg/sentry/vfs/mount_unsafe.go
@@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// +build go1.12
-// +build !go1.17
-
-// Check go:linkname function signatures when updating Go version.
-
package vfs
import (
@@ -41,6 +36,15 @@ type mountKey struct {
point unsafe.Pointer // *Dentry
}
+// mountKeyHasher is the hash function that sync.MapKeyHasher returns for maps
+// keyed by mountKey.
+var mountKeyHasher = sync.MapKeyHasher(map[mountKey]struct{}(nil))
+
+// mountKeySeed is the seed handed to mountKeyHasher; it is taken from
+// sync.RandUintptr when the package's variables are initialized.
+var mountKeySeed = sync.RandUintptr()
+
+// hash returns the value mountKeyHasher computes for k using mountKeySeed.
+func (k *mountKey) hash() uintptr {
+	// The pointer is routed through gohacks.Noescape before hashing.
+	keyPtr := gohacks.Noescape(unsafe.Pointer(k))
+	return mountKeyHasher(keyPtr, mountKeySeed)
+}
+
func (mnt *Mount) parent() *Mount {
return (*Mount)(atomic.LoadPointer(&mnt.key.parent))
}
@@ -56,23 +60,17 @@ func (mnt *Mount) getKey() VirtualDentry {
}
}
-func (mnt *Mount) saveKey() VirtualDentry { return mnt.getKey() }
-
// Invariant: mnt.key.parent == nil. vd.Ok().
func (mnt *Mount) setKey(vd VirtualDentry) {
atomic.StorePointer(&mnt.key.parent, unsafe.Pointer(vd.mount))
atomic.StorePointer(&mnt.key.point, unsafe.Pointer(vd.dentry))
}
-func (mnt *Mount) loadKey(vd VirtualDentry) { mnt.setKey(vd) }
-
// mountTable maps (mount parent, mount point) pairs to mounts. It supports
// efficient concurrent lookup, even in the presence of concurrent mutators
// (provided mutation is sufficiently uncommon).
//
// mountTable.Init() must be called on new mountTables before use.
-//
-// +stateify savable
type mountTable struct {
// mountTable is implemented as a seqcount-protected hash table that
// resolves collisions with linear probing, featuring Robin Hood insertion
@@ -84,8 +82,7 @@ type mountTable struct {
// intrinsics and inline assembly, limiting the performance of this
// approach.)
- seq sync.SeqCount `state:"nosave"`
- seed uint32 // for hashing keys
+ seq sync.SeqCount `state:"nosave"`
// size holds both length (number of elements) and capacity (number of
// slots): capacity is stored as its base-2 log (referred to as order) in
@@ -150,7 +147,6 @@ func init() {
// Init must be called exactly once on each mountTable before use.
func (mt *mountTable) Init() {
- mt.seed = rand32()
mt.size = mtInitOrder
mt.slots = newMountTableSlots(mtInitCap)
}
@@ -167,7 +163,7 @@ func newMountTableSlots(cap uintptr) unsafe.Pointer {
// Lookup may be called even if there are concurrent mutators of mt.
func (mt *mountTable) Lookup(parent *Mount, point *Dentry) *Mount {
key := mountKey{parent: unsafe.Pointer(parent), point: unsafe.Pointer(point)}
- hash := memhash(gohacks.Noescape(unsafe.Pointer(&key)), uintptr(mt.seed), mountKeyBytes)
+ hash := key.hash()
loop:
for {
@@ -247,7 +243,7 @@ func (mt *mountTable) Insert(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must not already contain a Mount with the same mount point and parent.
func (mt *mountTable) insertSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
// We're under the maximum load factor if:
//
@@ -346,7 +342,7 @@ func (mt *mountTable) Remove(mount *Mount) {
// * mt.seq must be in a writer critical section.
// * mt must contain mount.
func (mt *mountTable) removeSeqed(mount *Mount) {
- hash := memhash(unsafe.Pointer(&mount.key), uintptr(mt.seed), mountKeyBytes)
+ hash := mount.key.hash()
tcap := uintptr(1) << (mt.size & mtSizeOrderMask)
mask := tcap - 1
slots := mt.slots
@@ -386,9 +382,3 @@ func (mt *mountTable) removeSeqed(mount *Mount) {
off = (off + mountSlotBytes) & offmask
}
}
-
-//go:linkname memhash runtime.memhash
-func memhash(p unsafe.Pointer, seed, s uintptr) uintptr
-
-//go:linkname rand32 runtime.fastrand
-func rand32() uint32
diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go
index 7723ed643..8998a82dd 100644
--- a/pkg/sentry/vfs/save_restore.go
+++ b/pkg/sentry/vfs/save_restore.go
@@ -18,8 +18,10 @@ import (
"fmt"
"sync/atomic"
+ "gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/refsvfs2"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// FilesystemImplSaveRestoreExtension is an optional extension to
@@ -99,6 +101,9 @@ func (vfs *VirtualFilesystem) saveMounts() []*Mount {
return mounts
}
+// saveKey is called by stateify. It reports mnt's current key, as returned by
+// getKey, in VirtualDentry form.
+func (mnt *Mount) saveKey() VirtualDentry {
+	key := mnt.getKey()
+	return key
+}
+
// loadMounts is called by stateify.
func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
if mounts == nil {
@@ -110,6 +115,9 @@ func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) {
}
}
+// loadKey is called by stateify. It installs vd as mnt's key via setKey.
+func (mnt *Mount) loadKey(vd VirtualDentry) {
+	mnt.setKey(vd)
+}
+
func (mnt *Mount) afterLoad() {
if atomic.LoadInt64(&mnt.refs) != 0 {
refsvfs2.Register(mnt)
@@ -120,5 +128,20 @@ func (mnt *Mount) afterLoad() {
func (epi *epollInterest) afterLoad() {
// Mark all epollInterests as ready after restore so that the next call to
// EpollInstance.ReadEvents() rechecks their readiness.
- epi.Callback(nil)
+ epi.Callback(nil, waiter.EventMaskFromLinux(epi.mask))
+}
+
+// beforeSave is called by stateify.
+//
+// It marks fd as saved and, if O_ASYNC is set in fd.statusFlags and an
+// asyncHandler is installed, unregisters fd from that handler.
+func (fd *FileDescription) beforeSave() {
+	fd.saved = true
+	asyncEnabled := fd.statusFlags&linux.O_ASYNC != 0
+	if !asyncEnabled || fd.asyncHandler == nil {
+		return
+	}
+	fd.asyncHandler.Unregister(fd)
+}
+
+// afterLoad is called by stateify.
+//
+// If O_ASYNC is set in fd.statusFlags and an asyncHandler is installed, fd is
+// registered with that handler again.
+func (fd *FileDescription) afterLoad() {
+	asyncEnabled := fd.statusFlags&linux.O_ASYNC != 0
+	if !asyncEnabled || fd.asyncHandler == nil {
+		return
+	}
+	fd.asyncHandler.Register(fd)
}