Diffstat (limited to 'pkg')
39 files changed, 552 insertions, 168 deletions
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go index 8390915a2..e7c3a3a0b 100644 --- a/pkg/flipcall/ctrl_futex.go +++ b/pkg/flipcall/ctrl_futex.go @@ -113,7 +113,7 @@ func (ep *Endpoint) enterFutexWait() error { return nil case epsBlocked | epsShutdown: atomic.AddInt32(&ep.ctrl.state, -epsBlocked) - return shutdownError{} + return ShutdownError{} default: // Most likely due to ep.enterFutexWait() being called concurrently // from multiple goroutines. diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go index 386cee42c..3cdb576e1 100644 --- a/pkg/flipcall/flipcall.go +++ b/pkg/flipcall/flipcall.go @@ -136,8 +136,8 @@ func (ep *Endpoint) unmapPacket() { // Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(), // ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer -// Endpoint, to unblock and return errors. It does not wait for concurrent -// calls to return. Successive calls to Shutdown have no effect. +// Endpoint, to unblock and return ShutdownErrors. It does not wait for +// concurrent calls to return. Successive calls to Shutdown have no effect. // // Shutdown is the only Endpoint method that may be called concurrently with // other methods on the same Endpoint. @@ -154,10 +154,12 @@ func (ep *Endpoint) isShutdownLocally() bool { return atomic.LoadUint32(&ep.shutdown) != 0 } -type shutdownError struct{} +// ShutdownError is returned by most Endpoint methods after Endpoint.Shutdown() +// has been called. +type ShutdownError struct{} // Error implements error.Error. -func (shutdownError) Error() string { +func (ShutdownError) Error() string { return "flipcall connection shutdown" } diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go index b127a2bbb..168c1ccff 100644 --- a/pkg/flipcall/futex_linux.go +++ b/pkg/flipcall/futex_linux.go @@ -61,7 +61,7 @@ func (ep *Endpoint) futexSwitchToPeer() error { if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) { switch cs := atomic.LoadUint32(ep.connState()); cs { case csShutdown: - return shutdownError{} + return ShutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs) } @@ -81,14 +81,14 @@ func (ep *Endpoint) futexSwitchFromPeer() error { return nil case ep.inactiveState: if ep.isShutdownLocally() { - return shutdownError{} + return ShutdownError{} } if err := ep.futexWaitConnState(ep.inactiveState); err != nil { return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err) } continue case csShutdown: - return shutdownError{} + return ShutdownError{} default: return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs) } diff --git a/pkg/p9/server.go b/pkg/p9/server.go index e717e6161..40b8fa023 100644 --- a/pkg/p9/server.go +++ b/pkg/p9/server.go @@ -453,7 +453,11 @@ func (cs *connState) initializeChannels() (err error) { go func() { // S/R-SAFE: Server side. defer cs.channelWg.Done() if err := res.service(cs); err != nil { - log.Warningf("p9.channel.service: %v", err) + // Don't log flipcall.ShutdownErrors, which we expect to be + // returned during server shutdown. 
+ if _, ok := err.(flipcall.ShutdownError); !ok { + log.Warningf("p9.channel.service: %v", err) + } } }() } diff --git a/pkg/sentry/context/context.go b/pkg/sentry/context/context.go index dfd62cbdb..23e009ef3 100644 --- a/pkg/sentry/context/context.go +++ b/pkg/sentry/context/context.go @@ -12,10 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package context defines the sentry's Context type. +// Package context defines an internal context type. +// +// The given Context conforms to the standard Go context, but mandates +// additional methods that are specific to the kernel internals. Note however, +// that the Context described by this package carries additional constraints +// regarding concurrent access and retaining beyond the scope of a call. +// +// See the Context type for complete details. package context import ( + "context" + "time" + "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/log" ) @@ -59,6 +69,7 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) { type Context interface { log.Logger amutex.Sleeper + context.Context // UninterruptibleSleepStart indicates the beginning of an uninterruptible // sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate @@ -72,19 +83,36 @@ type Context interface { // AddressSpace is activated. Normally activate is the same value as the // deactivate parameter passed to UninterruptibleSleepStart. UninterruptibleSleepFinish(activate bool) +} + +// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep +// methods for anonymous embedding in other types that do not implement sleeps. +type NoopSleeper struct { + amutex.NoopSleeper +} + +// UninterruptibleSleepStart does nothing. +func (NoopSleeper) UninterruptibleSleepStart(bool) {} + +// UninterruptibleSleepFinish does nothing. +func (NoopSleeper) UninterruptibleSleepFinish(bool) {} + +// Deadline returns zero values, meaning no deadline. +func (NoopSleeper) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done returns nil. +func (NoopSleeper) Done() <-chan struct{} { + return nil +} - // Value returns the value associated with this Context for key, or nil if - // no value is associated with key. Successive calls to Value with the same - // key returns the same result. - // - // A key identifies a specific value in a Context. Functions that wish to - // retrieve values from Context typically allocate a key in a global - // variable then use that key as the argument to Context.Value. A key can - // be any type that supports equality; packages should define keys as an - // unexported type to avoid collisions. - Value(key interface{}) interface{} +// Err returns nil. +func (NoopSleeper) Err() error { + return nil } +// logContext implements basic logging. type logContext struct { log.Logger NoopSleeper @@ -95,19 +123,6 @@ func (logContext) Value(key interface{}) interface{} { return nil } -// NoopSleeper is a noop implementation of amutex.Sleeper and -// Context.UninterruptibleSleep* methods for anonymous embedding in other types -// that do not want to notify kernel.Task about sleeps. -type NoopSleeper struct { - amutex.NoopSleeper -} - -// UninterruptibleSleepStart does nothing. -func (NoopSleeper) UninterruptibleSleepStart(bool) {} - -// UninterruptibleSleepFinish does nothing. -func (NoopSleeper) UninterruptibleSleepFinish(bool) {} - // bgContext is the context returned by context.Background. 
var bgContext = &logContext{Logger: log.Log()} diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go index e3f5b0d83..3c9dceaba 100644 --- a/pkg/sentry/kernel/context.go +++ b/pkg/sentry/kernel/context.go @@ -15,6 +15,8 @@ package kernel import ( + "time" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/context" ) @@ -97,6 +99,21 @@ func TaskFromContext(ctx context.Context) *Task { return nil } +// Deadline implements context.Context.Deadline. +func (*Task) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +// Done implements context.Context.Done. +func (*Task) Done() <-chan struct{} { + return nil +} + +// Err implements context.Context.Err. +func (*Task) Err() error { + return nil +} + // AsyncContext returns a context.Context that may be used by goroutines that // do work on behalf of t and therefore share its contextual values, but are // not t's task goroutine (e.g. asynchronous I/O). @@ -129,6 +146,21 @@ func (ctx taskAsyncContext) IsLogging(level log.Level) bool { return ctx.t.IsLogging(level) } +// Deadline implements context.Context.Deadline. +func (ctx taskAsyncContext) Deadline() (time.Time, bool) { + return ctx.t.Deadline() +} + +// Done implements context.Context.Done. +func (ctx taskAsyncContext) Done() <-chan struct{} { + return ctx.t.Done() +} + +// Err implements context.Context.Err. +func (ctx taskAsyncContext) Err() error { + return ctx.t.Err() +} + // Value implements context.Context.Value. func (ctx taskAsyncContext) Value(key interface{}) interface{} { return ctx.t.Value(key) diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index e64d648e2..28ba950bd 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -391,7 +391,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if err := state.Save(w, k.FeatureSet(), nil); err != nil { + if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) @@ -399,7 +399,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // Save the kernel state. kernelStart := time.Now() var stats state.Stats - if err := state.Save(w, k, &stats); err != nil { + if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil { return err } log.Infof("Kernel save stats: %s", &stats) @@ -407,7 +407,7 @@ func (k *Kernel) SaveTo(w io.Writer) error { // Save the memory file's state. memoryStart := time.Now() - if err := k.mf.SaveTo(w); err != nil { + if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil { return err } log.Infof("Memory save took [%s].", time.Since(memoryStart)) @@ -542,7 +542,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if err := state.Load(r, &features, nil); err != nil { + if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -558,7 +558,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the kernel state. 
kernelStart := time.Now() var stats state.Stats - if err := state.Load(r, k, &stats); err != nil { + if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil { return err } log.Infof("Kernel load stats: %s", &stats) @@ -566,7 +566,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks) // Load the memory file's state. memoryStart := time.Now() - if err := k.mf.LoadFrom(r); err != nil { + if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { return err } log.Infof("Memory load took [%s].", time.Since(memoryStart)) @@ -1322,6 +1322,7 @@ func (k *Kernel) ListSockets() []*SocketEntry { return socks } +// supervisorContext is a privileged context. type supervisorContext struct { context.NoopSleeper log.Logger diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go index 1effc7735..aafce1d00 100644 --- a/pkg/sentry/pgalloc/save_restore.go +++ b/pkg/sentry/pgalloc/save_restore.go @@ -16,6 +16,7 @@ package pgalloc import ( "bytes" + "context" "fmt" "io" "runtime" @@ -29,7 +30,7 @@ import ( ) // SaveTo writes f's state to the given stream. -func (f *MemoryFile) SaveTo(w io.Writer) error { +func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error { // Wait for reclaim. f.mu.Lock() defer f.mu.Unlock() @@ -78,10 +79,10 @@ func (f *MemoryFile) SaveTo(w io.Writer) error { } // Save metadata. - if err := state.Save(w, &f.fileSize, nil); err != nil { + if err := state.Save(ctx, w, &f.fileSize, nil); err != nil { return err } - if err := state.Save(w, &f.usage, nil); err != nil { + if err := state.Save(ctx, w, &f.usage, nil); err != nil { return err } @@ -114,9 +115,9 @@ func (f *MemoryFile) SaveTo(w io.Writer) error { } // LoadFrom loads MemoryFile state from the given stream. -func (f *MemoryFile) LoadFrom(r io.Reader) error { +func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error { // Load metadata. - if err := state.Load(r, &f.fileSize, nil); err != nil { + if err := state.Load(ctx, r, &f.fileSize, nil); err != nil { return err } if err := f.file.Truncate(f.fileSize); err != nil { @@ -124,7 +125,7 @@ func (f *MemoryFile) LoadFrom(r io.Reader) error { } newMappings := make([]uintptr, f.fileSize>>chunkShift) f.mappings.Store(newMappings) - if err := state.Load(r, &f.usage, nil); err != nil { + if err := state.Load(ctx, r, &f.usage, nil); err != nil { return err } diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 31fa48ec5..6803d488c 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -23,6 +23,7 @@ go_library( "machine.go", "machine_amd64.go", "machine_amd64_unsafe.go", + "machine_arm64.go", "machine_unsafe.go", "physical_map.go", "virtual_map.go", diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go index acd41f73d..ea8b9632e 100644 --- a/pkg/sentry/platform/kvm/address_space.go +++ b/pkg/sentry/platform/kvm/address_space.go @@ -127,7 +127,7 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac // not have physical mappings, the KVM module may inject // spurious exceptions when emulation fails (i.e. it tries to // emulate because the RIP is pointed at those pages). - as.machine.mapPhysical(physical, length) + as.machine.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE) // Install the page table mappings. 
Note that the ordering is // important; if the pagetable mappings were installed before diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/allocator.go index 80942e9c9..3f35414bb 100644 --- a/pkg/sentry/platform/kvm/allocator.go +++ b/pkg/sentry/platform/kvm/allocator.go @@ -54,7 +54,7 @@ func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr { // //go:nosplit func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs { - virtualStart, physicalStart, _, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions) if !ok { panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical)) } diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index b97476053..f6459cda9 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -46,9 +46,9 @@ func yield() { // calculateBluepillFault calculates the fault address range. // //go:nosplit -func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) { +func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) { alignedPhysical := physical &^ uintptr(usermem.PageSize-1) - for _, pr := range physicalRegions { + for _, pr := range phyRegions { end := pr.physical + pr.length if physical < pr.physical || physical >= end { continue @@ -77,12 +77,12 @@ func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, leng // The corresponding virtual address is returned. This may throw on error. // //go:nosplit -func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { +func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegion, flags uint32) (uintptr, bool) { // Paging fault: we need to map the underlying physical pages for this // fault. This all has to be done in this function because we're in a // signal handler context. (We can't call any functions that might // split the stack.) - virtualStart, physicalStart, length, ok := calculateBluepillFault(physical) + virtualStart, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { return 0, false } @@ -96,7 +96,7 @@ func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) { yield() // Race with another call. slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0)) } - errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart) + errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags) if errno == 0 { // Successfully added region; we can increment nextSlot and // allow another set to proceed here. diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index ee730ad70..ca011ef78 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -80,13 +80,17 @@ func bluepillHandler(context unsafe.Pointer) { // interrupted KVM. Since we're in a signal handler // currently, all signals are masked and the signal // must have been delivered directly to this thread. + timeout := syscall.Timespec{} sig, _, errno := syscall.RawSyscall6( syscall.SYS_RT_SIGTIMEDWAIT, uintptr(unsafe.Pointer(&bounceSignalMask)), - 0, // siginfo. - 0, // timeout. - 8, // sigset size. + 0, // siginfo. + uintptr(unsafe.Pointer(&timeout)), // timeout. + 8, // sigset size. 
0, 0) + if errno == syscall.EAGAIN { + continue + } if errno != 0 { throw("error waiting for pending signal") } @@ -162,7 +166,7 @@ func bluepillHandler(context unsafe.Pointer) { // For MMIO, the physical address is the first data item. physical := uintptr(c.runData.data[0]) - virtual, ok := handleBluepillFault(c.machine, physical) + virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE) if !ok { c.die(bluepillArchContext(context), "invalid physical address") return diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go index d05f05c29..766131d60 100644 --- a/pkg/sentry/platform/kvm/kvm_const.go +++ b/pkg/sentry/platform/kvm/kvm_const.go @@ -62,3 +62,10 @@ const ( _KVM_NR_INTERRUPTS = 0x100 _KVM_NR_CPUID_ENTRIES = 0x100 ) + +// KVM kvm_memory_region::flags. +const ( + _KVM_MEM_LOG_DIRTY_PAGES = uint32(1) << 0 + _KVM_MEM_READONLY = uint32(1) << 1 + _KVM_MEM_FLAGS_NONE = 0 +) diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index cc6c138b2..7d02ebf19 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -215,6 +215,17 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) + var physicalRegionsReadOnly []physicalRegion + var physicalRegionsAvailable []physicalRegion + + physicalRegionsReadOnly = rdonlyRegionsForSetMem() + physicalRegionsAvailable = availableRegionsForSetMem() + + // Map all read-only regions. + for _, r := range physicalRegionsReadOnly { + m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY) + } + // Ensure that the currently mapped virtual regions are actually // available in the VM. Note that this doesn't guarantee no future // faults, however it should guarantee that everything is available to @@ -223,6 +234,13 @@ func newMachine(vm int) (*machine, error) { if excludeVirtualRegion(vr) { return // skip region. } + + for _, r := range physicalRegionsReadOnly { + if vr.virtual == r.virtual { + return + } + } + for virtual := vr.virtual; virtual < vr.virtual+vr.length; { physical, length, ok := translateToPhysical(virtual) if !ok { @@ -236,7 +254,7 @@ func newMachine(vm int) (*machine, error) { } // Ensure the physical range is mapped. - m.mapPhysical(physical, length) + m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE) virtual += length } }) @@ -256,9 +274,9 @@ func newMachine(vm int) (*machine, error) { // not available. This attempts to be efficient for calls in the hot path. // // This panics on error. -func (m *machine) mapPhysical(physical, length uintptr) { +func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) { for end := physical + length; physical < end; { - _, physicalStart, length, ok := calculateBluepillFault(physical) + _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions) if !ok { // Should never happen. panic("mapPhysical on unknown physical address") @@ -266,7 +284,7 @@ func (m *machine) mapPhysical(physical, length uintptr) { if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok { // Not present in the cache; requires setting the slot. 
- if _, ok := handleBluepillFault(m, physical); !ok { + if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok { panic("handleBluepillFault failed") } } diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index c1cbe33be..b99fe425e 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -355,3 +355,13 @@ func (m *machine) retryInGuest(fn func()) { } } } + +// On the x86 platform, the flags for "setMemoryRegion" can always be set to 0. +// There is no need to return read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + return nil +} + +func availableRegionsForSetMem() (phyRegions []physicalRegion) { + return physicalRegions +} diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go index 506ec9af1..61227cafb 100644 --- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go @@ -26,30 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/time" ) -// setMemoryRegion initializes a region. -// -// This may be called from bluepillHandler, and therefore returns an errno -// directly (instead of wrapping in an error) to avoid allocations. -// -//go:nosplit -func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno { - userRegion := userMemoryRegion{ - slot: uint32(slot), - flags: 0, - guestPhysAddr: uint64(physical), - memorySize: uint64(length), - userspaceAddr: uint64(virtual), - } - - // Set the region. - _, _, errno := syscall.RawSyscall( - syscall.SYS_IOCTL, - uintptr(m.fd), - _KVM_SET_USER_MEMORY_REGION, - uintptr(unsafe.Pointer(&userRegion))) - return errno -} - // loadSegments copies the current segments. // // This may be called from within the signal context and throws on error. diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go new file mode 100644 index 000000000..b7e2cfb9d --- /dev/null +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -0,0 +1,61 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kvm + +// Get all read-only physicalRegions. +func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { + var rdonlyRegions []region + + applyVirtualRegions(func(vr virtualRegion) { + if excludeVirtualRegion(vr) { + return + } + + if !vr.accessType.Write && vr.accessType.Read { + rdonlyRegions = append(rdonlyRegions, vr.region) + } + }) + + for _, r := range rdonlyRegions { + physical, _, ok := translateToPhysical(r.virtual) + if !ok { + continue + } + + phyRegions = append(phyRegions, physicalRegion{ + region: region{ + virtual: r.virtual, + length: r.length, + }, + physical: physical, + }) + } + + return phyRegions +} + +// Get all available physicalRegions.
+func availableRegionsForSetMem() (phyRegions []physicalRegion) { + var excludeRegions []region + applyVirtualRegions(func(vr virtualRegion) { + if !vr.accessType.Write { + excludeRegions = append(excludeRegions, vr.region) + } + }) + + phyRegions = computePhysicalRegions(excludeRegions) + + return phyRegions +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index e00c7ae40..ed9433311 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -35,6 +35,30 @@ func entersyscall() //go:linkname exitsyscall runtime.exitsyscall func exitsyscall() +// setMemoryRegion initializes a region. +// +// This may be called from bluepillHandler, and therefore returns an errno +// directly (instead of wrapping in an error) to avoid allocations. +// +//go:nosplit +func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno { + userRegion := userMemoryRegion{ + slot: uint32(slot), + flags: uint32(flags), + guestPhysAddr: uint64(physical), + memorySize: uint64(length), + userspaceAddr: uint64(virtual), + } + + // Set the region. + _, _, errno := syscall.RawSyscall( + syscall.SYS_IOCTL, + uintptr(m.fd), + _KVM_SET_USER_MEMORY_REGION, + uintptr(unsafe.Pointer(&userRegion))) + return errno +} + // mapRunData maps the vCPU run data. func mapRunData(fd int) (*runData, error) { r, _, errno := syscall.RawSyscall6( diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index f95803f91..79589e3c8 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index b2732ca29..05dac4f0a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -27,6 +27,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -61,7 +62,7 @@ var netlinkSocketDevice = device.NewAnonDevice() // This implementation only supports userspace sending and receiving messages // to/from the kernel. // -// Socket implements socket.Socket. +// Socket implements socket.Socket and transport.Credentialer. // // +stateify savable type Socket struct { @@ -104,9 +105,13 @@ type Socket struct { // sendBufferSize is the send buffer "size". We don't actually have a // fixed buffer but only consume this many bytes. sendBufferSize uint32 + + // passcred indicates if this socket wants SCM credentials. + passcred bool } var _ socket.Socket = (*Socket)(nil) +var _ transport.Credentialer = (*Socket)(nil) // NewSocket creates a new Socket. func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { @@ -172,6 +177,22 @@ func (s *Socket) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } +// Passcred implements transport.Credentialer.Passcred. +func (s *Socket) Passcred() bool { + s.mu.Lock() + passcred := s.passcred + s.mu.Unlock() + return passcred +} + +// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. 
+func (s *Socket) ConnectedPasscred() bool { + // This socket is connected to the kernel, which doesn't need creds. + // + // This is arbitrary, as ConnectedPasscred on this type has no callers. + return false +} + // Ioctl implements fs.FileOperations.Ioctl. func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) { // TODO(b/68878065): no ioctls supported. @@ -309,9 +330,20 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem. // We don't have limit on receiving size. return int32(math.MaxInt32), nil + case linux.SO_PASSCRED: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + var passcred int32 + if s.Passcred() { + passcred = 1 + } + return passcred, nil + default: socket.GetSockOptEmitUnimplementedEvent(t, name) } + case linux.SOL_NETLINK: switch name { case linux.NETLINK_BROADCAST_ERROR, @@ -348,6 +380,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy s.sendBufferSize = size s.mu.Unlock() return nil + case linux.SO_RCVBUF: if len(opt) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -355,6 +388,18 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy // We don't have limit on receiving size. So just accept anything as // valid for compatibility. return nil + + case linux.SO_PASSCRED: + if len(opt) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + passcred := usermem.ByteOrder.Uint32(opt) + + s.mu.Lock() + s.passcred = passcred != 0 + s.mu.Unlock() + return nil + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -483,6 +528,26 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ }) } +// kernelSCM implements control.SCMCredentials with credentials that represent +// the kernel itself rather than a Task. +// +// +stateify savable +type kernelSCM struct{} + +// Equals implements transport.CredentialsControlMessage.Equals. +func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool { + _, ok := oc.(kernelSCM) + return ok +} + +// Credentials implements control.SCMCredentials.Credentials. +func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) { + return 0, auth.RootUID, auth.RootGID +} + +// kernelCreds is the concrete version of kernelSCM used in all creds. +var kernelCreds = &kernelSCM{} + // sendResponse sends the response messages in ms back to userspace. func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error { // Linux combines multiple netlink messages into a single datagram. @@ -491,10 +556,15 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error bufs = append(bufs, m.Finalize()) } + // All messages are from the kernel. + cms := transport.ControlMessages{ + Credentials: kernelCreds, + } + if len(bufs) > 0 { // RecvMsg never receives the address, so we don't need to send // one. - _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{}) // If the buffer is full, we simply drop messages, just like // Linux. if err != nil && err != syserr.ErrWouldBlock { @@ -521,7 +591,7 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error // Add the dump_done_errno payload. 
m.Put(int64(0)) - _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{}) + _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err } diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 145102c0d..ecce6c69f 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -42,8 +42,35 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" ) -// DefaultTimeout is a resonable timeout value for most applications. -const DefaultTimeout = 3 * time.Minute +// Opts configures the watchdog. +type Opts struct { + // TaskTimeout is the amount of time to allow a task to execute the + // same syscall without blocking before it's declared stuck. + TaskTimeout time.Duration + + // TaskTimeoutAction indicates what action to take when a stuck tasks + // is detected. + TaskTimeoutAction Action + + // StartupTimeout is the amount of time to allow between watchdog + // creation and calling watchdog.Start. + StartupTimeout time.Duration + + // StartupTimeoutAction indicates what action to take when + // watchdog.Start is not called within the timeout. + StartupTimeoutAction Action +} + +// DefaultOpts is a default set of options for the watchdog. +var DefaultOpts = Opts{ + // Task timeout. + TaskTimeout: 3 * time.Minute, + TaskTimeoutAction: LogWarning, + + // Startup timeout. + StartupTimeout: 30 * time.Second, + StartupTimeoutAction: LogWarning, +} // descheduleThreshold is the amount of time scheduling needs to be off before the entire wait period // is discounted from task's last update time. It's set high enough that small scheduling delays won't @@ -61,6 +88,7 @@ type Action int const ( // LogWarning logs warning message followed by stack trace. LogWarning Action = iota + // Panic will do the same logging as LogWarning and panic(). Panic ) @@ -80,17 +108,13 @@ func (a Action) String() string { // Watchdog is the main watchdog class. It controls a goroutine that periodically // analyses all tasks and reports if any of them appear to be stuck. type Watchdog struct { + // Configuration options are embedded. + Opts + // period indicates how often to check all tasks. It's calculated based on - // 'taskTimeout'. + // opts.TaskTimeout. period time.Duration - // taskTimeout is the amount of time to allow a task to execute the same syscall - // without blocking before it's declared stuck. - taskTimeout time.Duration - - // timeoutAction indicates what action to take when a stuck tasks is detected. - timeoutAction Action - // k is where the tasks come from. k *kernel.Kernel @@ -113,8 +137,12 @@ type Watchdog struct { // mu protects the fields below. mu sync.Mutex - // started is true if the watchdog has been started before. - started bool + // running is true if the watchdog is running. + running bool + + // startCalled is true if Start has ever been called. It remains true + // even if Stop is called. + startCalled bool } type offender struct { @@ -122,58 +150,81 @@ type offender struct { } // New creates a new watchdog. -func New(k *kernel.Kernel, taskTimeout time.Duration, a Action) *Watchdog { - // 4 is arbitrary, just don't want to prolong 'taskTimeout' too much. 
- period := taskTimeout / 4 - return &Watchdog{ - k: k, - period: period, - taskTimeout: taskTimeout, - timeoutAction: a, - offenders: make(map[*kernel.Task]*offender), - stop: make(chan struct{}), - done: make(chan struct{}), +func New(k *kernel.Kernel, opts Opts) *Watchdog { + // 4 is arbitrary, just don't want to prolong 'TaskTimeout' too much. + period := opts.TaskTimeout / 4 + w := &Watchdog{ + Opts: opts, + k: k, + period: period, + offenders: make(map[*kernel.Task]*offender), + stop: make(chan struct{}), + done: make(chan struct{}), + } + + // Handle StartupTimeout if it exists. + if w.StartupTimeout > 0 { + log.Infof("Watchdog waiting %v for startup", w.StartupTimeout) + go w.waitForStart() // S/R-SAFE: watchdog is stopped during save and restarted after restore. } + + return w } // Start starts the watchdog. func (w *Watchdog) Start() { - if w.taskTimeout == 0 { - log.Infof("Watchdog disabled") - return - } - w.mu.Lock() defer w.mu.Unlock() - if w.started { + w.startCalled = true + + if w.running { return } + if w.TaskTimeout == 0 { + log.Infof("Watchdog task timeout disabled") + return + } w.lastRun = w.k.MonotonicClock().Now() - log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.taskTimeout, w.timeoutAction) + log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.TaskTimeout, w.TaskTimeoutAction) go w.loop() // S/R-SAFE: watchdog is stopped during save and restarted after restore. - w.started = true + w.running = true } // Stop requests the watchdog to stop and wait for it. func (w *Watchdog) Stop() { - if w.taskTimeout == 0 { + if w.TaskTimeout == 0 { return } w.mu.Lock() defer w.mu.Unlock() - if !w.started { + if !w.running { return } log.Infof("Stopping watchdog") w.stop <- struct{}{} <-w.done - w.started = false + w.running = false log.Infof("Watchdog stopped") } +// waitForStart waits for Start to be called and takes action if it does not +// happen within the startup timeout. +func (w *Watchdog) waitForStart() { + <-time.After(w.StartupTimeout) + w.mu.Lock() + defer w.mu.Unlock() + if w.startCalled { + // We are fine. + return + } + var buf bytes.Buffer + buf.WriteString(fmt.Sprintf("Watchdog.Start() not called within %s:\n", w.StartupTimeout)) + w.doAction(w.StartupTimeoutAction, false, &buf) +} + // loop is the main watchdog routine. It only returns when 'Stop()' is called. func (w *Watchdog) loop() { // Loop until someone stops it. @@ -202,7 +253,7 @@ func (w *Watchdog) runTurn() { select { case <-done: - case <-time.After(w.taskTimeout): + case <-time.After(w.TaskTimeout): // Report if the watchdog is not making progress. // No one is watching the watchdog watcher though. w.reportStuckWatchdog() @@ -231,7 +282,7 @@ func (w *Watchdog) runTurn() { if tsched.State == kernel.TaskGoroutineRunningSys { lastUpdateTime := ktime.FromNanoseconds(int64(tsched.Timestamp * uint64(linux.ClockTick))) elapsed := now.Sub(lastUpdateTime) - discount - if elapsed > w.taskTimeout { + if elapsed > w.TaskTimeout { tc, ok := w.offenders[t] if !ok { // New stuck task detected.
@@ -261,28 +312,34 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo tid := w.k.TaskSet().Root.IDOfTask(t) buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime))) } + buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine") - w.onStuckTask(newTaskFound, &buf) + + // Skip the stack dump unless a new task was detected or enough time + // has passed since the last stack dump was generated. + skipStack := !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod + w.doAction(w.TaskTimeoutAction, skipStack, &buf) } func (w *Watchdog) reportStuckWatchdog() { var buf bytes.Buffer buf.WriteString("Watchdog goroutine is stuck:\n") - w.onStuckTask(true, &buf) + w.doAction(w.TaskTimeoutAction, false, &buf) } -func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) { - switch w.timeoutAction { +// doAction will take the given action. If the action is LogWarning and +// skipStack is true, then the stack printing will be skipped. +func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) { + switch action { case LogWarning: - // Dump stack only if a new task is detected or if it sometime has passed since - // the last time a stack dump was generated. - if !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod { + if skipStack { msg.WriteString("\n...[stack dump skipped]...") log.Warningf(msg.String()) - } else { - log.TracebackAll(msg.String()) - w.lastStackDump = time.Now() + return + } + log.TracebackAll(msg.String()) + w.lastStackDump = time.Now() case Panic: // Panic will skip over running tasks, which is likely the culprit here. So manually @@ -301,5 +358,8 @@ func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) { case <-time.After(1 * time.Second): } panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String())) + default: + panic(fmt.Sprintf("Unknown watchdog action %v", action)) + } } diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 47e6b878a..590c241a3 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -16,6 +16,7 @@ package state import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -133,6 +134,9 @@ func (os *objectState) findCycle() []*objectState { // to ensure that all callbacks are executed, otherwise the callback graph was // not acyclic. type decodeState struct { + // ctx is the decode context. + ctx context.Context + // objectByID is the set of objects in progress. objectsByID map[uint64]*objectState diff --git a/pkg/state/encode.go b/pkg/state/encode.go index 5d9409a45..c5118d3a9 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -16,6 +16,7 @@ package state import ( "container/list" + "context" "encoding/binary" "fmt" "io" @@ -38,6 +39,9 @@ type queuedObject struct { // The encoding process is a breadth-first traversal of the object graph. The // inherent races and dependencies are much simpler than the decode case. type encodeState struct { + // ctx is the encode context. + ctx context.Context + // lastID is the last object ID. // // See idsByObject for context.
Because of the special zero encoding diff --git a/pkg/state/map.go b/pkg/state/map.go index 7e6fefed4..4f3ebb0da 100644 --- a/pkg/state/map.go +++ b/pkg/state/map.go @@ -15,6 +15,7 @@ package state import ( + "context" "fmt" "reflect" "sort" @@ -219,3 +220,13 @@ func (m Map) AfterLoad(fn func()) { // data dependencies have been cleared. m.os.callbacks = append(m.os.callbacks, fn) } + +// Context returns the current context object. +func (m Map) Context() context.Context { + if m.es != nil { + return m.es.ctx + } else if m.ds != nil { + return m.ds.ctx + } + return context.Background() // No context. +} diff --git a/pkg/state/state.go b/pkg/state/state.go index d408ff84a..dbe507ab4 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -50,6 +50,7 @@ package state import ( + "context" "fmt" "io" "reflect" @@ -86,9 +87,10 @@ func UnwrapErrState(err error) error { } // Save saves the given object state. -func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { +func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error { // Create the encoding state. es := &encodeState{ + ctx: ctx, idsByObject: make(map[uintptr]uint64), w: w, stats: stats, @@ -101,9 +103,10 @@ func Save(w io.Writer, rootPtr interface{}, stats *Stats) error { } // Load loads a checkpoint. -func Load(r io.Reader, rootPtr interface{}, stats *Stats) error { +func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error { // Create the decoding state. ds := &decodeState{ + ctx: ctx, objectsByID: make(map[uint64]*objectState), deferred: make(map[uint64]*pb.Object), r: r, diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go index 7c24bbcda..d7221e9e8 100644 --- a/pkg/state/state_test.go +++ b/pkg/state/state_test.go @@ -16,6 +16,7 @@ package state import ( "bytes" + "context" "io/ioutil" "math" "reflect" @@ -46,7 +47,7 @@ func runTest(t *testing.T, tests []TestCase) { saveBuffer := &bytes.Buffer{} saveObjectPtr := reflect.New(reflect.TypeOf(root)) saveObjectPtr.Elem().Set(reflect.ValueOf(root)) - if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { + if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail { t.Errorf(" FAIL: Save failed unexpectedly: %v", err) continue } else if err != nil { @@ -56,7 +57,7 @@ func runTest(t *testing.T, tests []TestCase) { // Load a new copy of the object. 
loadObjectPtr := reflect.New(reflect.TypeOf(root)) - if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { + if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail { t.Errorf(" FAIL: Load failed unexpectedly: %v", err) continue } else if err != nil { @@ -624,7 +625,7 @@ func BenchmarkEncoding(b *testing.B) { bs := buildObject(b.N) var stats Stats b.StartTimer() - if err := Save(ioutil.Discard, bs, &stats); err != nil { + if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil { b.Errorf("save failed: %v", err) } b.StopTimer() @@ -638,12 +639,12 @@ func BenchmarkDecoding(b *testing.B) { bs := buildObject(b.N) var newBS benchStruct buf := &bytes.Buffer{} - if err := Save(buf, bs, nil); err != nil { + if err := Save(context.Background(), buf, bs, nil); err != nil { b.Errorf("save failed: %v", err) } var stats Stats b.StartTimer() - if err := Load(buf, &newBS, &stats); err != nil { + if err := Load(context.Background(), buf, &newBS, &stats); err != nil { b.Errorf("load failed: %v", err) } b.StopTimer() diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go index 59378b96c..e7c05ca4f 100644 --- a/pkg/tcpip/link/fdbased/endpoint_test.go +++ b/pkg/tcpip/link/fdbased/endpoint_test.go @@ -45,7 +45,7 @@ const ( type packetInfo struct { raddr tcpip.LinkAddress proto tcpip.NetworkProtocolNumber - contents buffer.View + contents buffer.VectorisedView linkHeader buffer.View } @@ -94,7 +94,7 @@ func (c *context) cleanup() { } func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) { - c.ch <- packetInfo{remote, protocol, vv.ToView(), linkHeader} + c.ch <- packetInfo{remote, protocol, vv, linkHeader} } func TestNoEthernetProperties(t *testing.T) { @@ -319,13 +319,17 @@ func TestDeliverPacket(t *testing.T) { want := packetInfo{ raddr: raddr, proto: proto, - contents: b, + contents: buffer.View(b).ToVectorisedView(), linkHeader: buffer.View(hdr), } if !eth { want.proto = header.IPv4ProtocolNumber want.raddr = "" } + // want.contents will be a single view, + // so make pi do the same for the + // DeepEqual check. 
+ pi.contents = pi.contents.ToView().ToVectorisedView() if !reflect.DeepEqual(want, pi) { t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) } diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index 12168a1dc..3331b6453 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(n, BufConfig) - vv := buffer.NewVectorisedView(n, d.views[:used]) + vv := buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...)) vv.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) @@ -293,7 +293,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) { } used := d.capViews(k, int(n), BufConfig) - vv := buffer.NewVectorisedView(int(n), d.views[k][:used]) + vv := buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...)) vv.TrimFront(d.e.hdrSize) d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth)) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a01a208b8..fe8f83d58 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -762,7 +762,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link } n.mu.RUnlock() for _, ep := range packetEPs { - ep.HandlePacket(n.id, local, protocol, vv, linkHeader) + ep.HandlePacket(n.id, local, protocol, vv.Clone(nil), linkHeader) } if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 0360187b8..d7c124e81 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -60,12 +60,19 @@ const ( // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { + // UniqueID returns a unique ID for this transport endpoint. + UniqueID() uint64 + // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. + // + // HandlePacket takes ownership of vv. HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) // HandleControlPacket is called by the stack when new control (e.g., // ICMP) packets arrive to this transport endpoint. + // + // HandleControlPacket takes ownership of vv. HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) // Close puts the endpoint in a closed state and frees all resources @@ -91,6 +98,8 @@ type RawTransportEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. The packet contains all data from the link // layer up. + // + // HandlePacket takes ownership of packet and netHeader. HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView) } @@ -107,6 +116,8 @@ type PacketEndpoint interface { // HandlePacket is called by the stack when new packets arrive to // this transport endpoint. // // linkHeader may have a length of 0, in which case the PacketEndpoint // should construct its own ethernet header for applications. + // + // HandlePacket takes ownership of packet and linkHeader.
HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View) } @@ -157,10 +168,14 @@ type TransportDispatcher interface { // DeliverTransportPacket delivers packets to the appropriate // transport protocol endpoint. It also returns the network layer // header for the endpoint to inspect or pass up the stack. + // + // DeliverTransportPacket takes ownership of vv and netHeader. DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. + // + // DeliverTransportControlPacket takes ownership of vv. DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) } @@ -234,6 +249,8 @@ type NetworkEndpoint interface { // HandlePacket is called by the link layer when new packets arrive to // this network endpoint. + // + // HandlePacket takes ownership of vv. HandlePacket(r *Route, vv buffer.VectorisedView) // Close is called when the endpoint is removed from a stack. @@ -279,6 +296,8 @@ type NetworkDispatcher interface { // DeliverNetworkPacket finds the appropriate network protocol endpoint // and hands the packet over for further processing. linkHeader may have // length 0 when the caller does not have ethernet data. + // + // DeliverNetworkPacket takes ownership of vv and linkHeader. DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 6d6ddc0ff..115a6fcb8 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "encoding/binary" "sync" + "sync/atomic" "time" "golang.org/x/time/rate" @@ -344,6 +345,13 @@ type ResumableEndpoint interface { Resume(*Stack) } +// uniqueIDGenerator is a default unique ID generator. +type uniqueIDGenerator uint64 +
+func (u *uniqueIDGenerator) UniqueID() uint64 { + return atomic.AddUint64((*uint64)(u), 1) +} + // Stack is a networking stack, with all supported protocols, NICs, and route // table. type Stack struct { @@ -411,6 +419,14 @@ type Stack struct { // ndpDisp is the NDP event dispatcher that is used to send the netstack // integrator NDP related events. ndpDisp NDPDispatcher + + // uniqueIDGenerator is a generator of unique identifiers. + uniqueIDGenerator UniqueID +} + +// UniqueID is an abstract generator of unique identifiers. +type UniqueID interface { + UniqueID() uint64 } // Options contains optional Stack configuration. @@ -434,6 +450,9 @@ type Options struct { // stack (false). HandleLocal bool + // UniqueID is an optional generator of unique identifiers. + UniqueID UniqueID + // NDPConfigs is the default NDP configurations used by interfaces. // // By default, NDPConfigs will have a zero value for its @@ -506,6 +525,10 @@ func New(opts Options) *Stack { clock = &tcpip.StdClock{} } + if opts.UniqueID == nil { + opts.UniqueID = new(uniqueIDGenerator) + } + // Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate() @@ -524,6 +547,7 @@ func New(opts Options) *Stack { portSeed: generateRandUint32(), ndpConfigs: opts.NDPConfigs, autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal, + uniqueIDGenerator: opts.UniqueID, ndpDisp: opts.NDPDisp, } @@ -551,6 +575,11 @@ func New(opts Options) *Stack { return s } +// UniqueID returns a unique identifier. +func (s *Stack) UniqueID() uint64 { + return s.uniqueIDGenerator.UniqueID() +} + // SetNetworkProtocolOption allows configuring individual protocol level // options. This method returns an error if the protocol is not supported or // option is not supported by the protocol implementation or the provided value diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index f633632f0..ccd3d030e 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "math/rand" + "sort" "sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -310,6 +311,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + + // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints + // can be restored in the same order. + sort.Slice(ep.endpointsArr, func(i, j int) bool { + return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID() + }) + for i, e := range ep.endpointsArr { + ep.endpointsMap[e] = i + } return nil } diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index ae6fda3a9..203e79f56 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -43,6 +43,7 @@ type fakeTransportEndpoint struct { proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route + uniqueID uint64 // acceptQueue is non-nil iff bound. 
acceptQueue []fakeTransportEndpoint @@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { - return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto} +func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} } func (f *fakeTransportEndpoint) Close() { @@ -144,6 +145,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return nil } +func (f *fakeTransportEndpoint) UniqueID() uint64 { + return f.uniqueID +} + func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -251,7 +256,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { } func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newFakeTransportEndpoint(stack, f, netProto), nil + return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil } func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index d0dd383fd..33405eb7d 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -31,9 +31,6 @@ type icmpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } type endpointState int @@ -58,6 +55,7 @@ type endpoint struct { // immutable. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -90,9 +88,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: stateInitial, + uniqueID: s.UniqueID(), }, nil } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -760,7 +764,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv }, } - pkt.data = vv.Clone(pkt.views[:]) + pkt.data = vv e.rcvList.PushBack(pkt) e.rcvBufSize += pkt.data.Size() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 73cdaa265..ead83b83d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -41,10 +41,6 @@ type packet struct { // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. 
- views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -310,7 +306,7 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, if ep.cooked { // Cooked packets can simply be queued. - packet.data = vv.Clone(packet.views[:]) + packet.data = vv } else { // Raw packets need their ethernet headers prepended before // queueing. @@ -328,7 +324,7 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, } combinedVV := buffer.View(ethHeader).ToVectorisedView() combinedVV.Append(vv) - packet.data = combinedVV.Clone(packet.views[:]) + packet.data = combinedVV } packet.timestampNS = ep.stack.NowNanoseconds() diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 951d317ed..23922a30e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -42,10 +42,6 @@ type rawPacket struct { // data holds the actual packet data, including any headers and // payload. data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - // views is pre-allocated space to back data. As long as the packet is - // made up of fewer than 8 buffer.Views, no extra allocation is - // necessary to store packet data. - views [8]buffer.View `state:"nosave"` // timestampNS is the unix time at which the packet was received. timestampNS int64 // senderAddr is the network address of the sender. @@ -609,7 +605,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu combinedVV := netHeader.ToVectorisedView() combinedVV.Append(vv) - pkt.data = combinedVV.Clone(pkt.views[:]) + pkt.data = combinedVV pkt.timestampNS = e.stack.NowNanoseconds() e.rcvList.PushBack(pkt) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8a3ca0f1b..a1efd8d55 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -287,6 +287,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue `state:"wait"` + uniqueID uint64 // lastError represents the last error that the endpoint reported; // access to it is protected by the following mutex. @@ -504,6 +505,11 @@ type endpoint struct { stats Stats `state:"nosave"` } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // calculateAdvertisedMSS calculates the MSS to advertise. // // If userMSS is non-zero and is not greater than the maximum possible MSS for @@ -565,6 +571,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue interval: 75 * time.Second, count: 9, }, + uniqueID: s.UniqueID(), } var ss SendBufferSizeOption diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index cda302bb7..03bd5c8fd 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -31,9 +31,6 @@ type udpPacket struct { senderAddress tcpip.FullAddress data buffer.VectorisedView `state:".(buffer.VectorisedView)"` timestamp int64 - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` } // EndpointState represents the state of a UDP endpoint. @@ -80,6 +77,7 @@ type endpoint struct { // change throughout the lifetime of the endpoint. 
stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue + uniqueID uint64 // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -160,9 +158,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, state: StateInitial, + uniqueID: s.UniqueID(), } } +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { @@ -1195,7 +1199,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv Port: hdr.SourcePort(), }, } - pkt.data = vv.Clone(pkt.views[:]) + pkt.data = vv e.rcvList.PushBack(pkt) e.rcvBufSize += vv.Size()
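The newly exported flipcall.ShutdownError lets callers tell an orderly teardown apart from a real failure, which is what the p9 server change above does. A minimal, self-contained sketch of that pattern (standalone Go with illustrative names, not the gvisor sources):

```go
package main

import (
	"fmt"
	"log"
)

// ShutdownError mirrors the newly exported flipcall type: an empty struct
// whose identity, not its contents, carries the meaning.
type ShutdownError struct{}

// Error implements error.
func (ShutdownError) Error() string { return "flipcall connection shutdown" }

// service stands in for the channel service loop; here it always reports
// an orderly shutdown.
func service() error { return ShutdownError{} }

func main() {
	// The p9 server's pattern: a type assertion filters out the error
	// that is expected during shutdown, so only real failures are logged.
	if err := service(); err != nil {
		if _, ok := err.(ShutdownError); !ok {
			log.Printf("p9.channel.service: %v", err)
		} else {
			fmt.Println("shutdown error suppressed")
		}
	}
}
```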
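The context change makes the sentry's Context a superset of the standard context.Context by adding no-op Deadline, Done, and Err implementations. A self-contained sketch of why returning a nil Done channel is a safe way to mean "never cancelled" (illustrative type, not the sentry's):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// noopContext mirrors how NoopSleeper satisfies the new context.Context
// requirement: no deadline, a nil Done channel, and a nil Err.
type noopContext struct{}

func (noopContext) Deadline() (time.Time, bool)       { return time.Time{}, false }
func (noopContext) Done() <-chan struct{}             { return nil }
func (noopContext) Err() error                        { return nil }
func (noopContext) Value(key interface{}) interface{} { return nil }

// Compile-time check that the interface is satisfied.
var _ context.Context = noopContext{}

func main() {
	ctx := noopContext{}
	// Receiving from a nil channel blocks forever, so a select against
	// Done() never observes cancellation.
	select {
	case <-ctx.Done():
		fmt.Println("cancelled")
	default:
		fmt.Println("never cancelled")
	}
}
```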
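state.Save and state.Load now accept a context that is stashed on the encode/decode state and surfaced to per-object callbacks through Map.Context(). A minimal mimic of that plumbing (the types here only mirror the diff; this is not the real pkg/state API):

```go
package main

import (
	"context"
	"fmt"
)

// encodeState mirrors the diff: the context handed to Save is stashed on
// the encoder so that object callbacks can reach it later.
type encodeState struct {
	ctx context.Context
}

// Map mirrors state.Map just enough to show the new Context accessor.
type Map struct {
	es *encodeState
}

// Context returns the encode context, falling back to Background when the
// Map is not attached to an encode in progress (as in the diff).
func (m Map) Context() context.Context {
	if m.es != nil {
		return m.es.ctx
	}
	return context.Background() // No context.
}

type ctxKey struct{}

func main() {
	// A caller-provided value travels from the Save call site down to the
	// per-object Map, which is how MemoryFile.SaveTo can now see it.
	ctx := context.WithValue(context.Background(), ctxKey{}, "supervisor")
	m := Map{es: &encodeState{ctx: ctx}}
	fmt.Println(m.Context().Value(ctxKey{})) // prints: supervisor
}
```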
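The KVM changes thread a flags argument down to KVM_SET_USER_MEMORY_REGION so that arm64 can register host read-only regions as read-only guest slots. A sketch of the flag encoding (the struct layout mirrors the diff's userMemoryRegion; the addresses are hypothetical):

```go
package main

import "fmt"

// Flags for kvm_userspace_memory_region, matching the constants the diff
// adds: bit 0 requests dirty-page logging, bit 1 makes the slot read-only.
const (
	kvmMemLogDirtyPages = uint32(1) << 0
	kvmMemReadonly      = uint32(1) << 1
	kvmMemFlagsNone     = uint32(0)
)

// userMemoryRegion mirrors the struct the diff passes to the
// KVM_SET_USER_MEMORY_REGION ioctl.
type userMemoryRegion struct {
	slot          uint32
	flags         uint32
	guestPhysAddr uint64
	memorySize    uint64
	userspaceAddr uint64
}

func main() {
	// On arm64 the machine maps host read-only regions into the guest
	// with the read-only flag, so guest writes fault instead of writing
	// through to memory the host mapped read-only.
	region := userMemoryRegion{
		slot:          0,
		flags:         kvmMemReadonly,
		guestPhysAddr: 0x40000000,     // hypothetical guest physical base
		memorySize:    0x10000,        // hypothetical length
		userspaceAddr: 0x7f0000000000, // hypothetical host mapping
	}
	fmt.Printf("slot %d flags %#x (read-only=%v)\n",
		region.slot, region.flags, region.flags&kvmMemReadonly != 0)
}
```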
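The netlink socket now honors SO_PASSCRED, storing the option as a mutex-guarded bool parsed from a 32-bit integer. A self-contained sketch of that convention (binary.LittleEndian stands in for usermem.ByteOrder; names are illustrative):

```go
package main

import (
	"encoding/binary"
	"fmt"
	"sync"
)

// credSocket mirrors the diff's netlink Socket just enough to show the
// SO_PASSCRED handling: the option arrives as a 32-bit integer and any
// nonzero value enables credential passing.
type credSocket struct {
	mu       sync.Mutex
	passcred bool
}

// SetPasscred parses the setsockopt payload, rejecting short buffers the
// way the diff returns syserr.ErrInvalidArgument.
func (s *credSocket) SetPasscred(opt []byte) error {
	if len(opt) < 4 {
		return fmt.Errorf("EINVAL: option too short")
	}
	v := binary.LittleEndian.Uint32(opt)
	s.mu.Lock()
	s.passcred = v != 0
	s.mu.Unlock()
	return nil
}

// Passcred reads the flag under the lock, as the diff does.
func (s *credSocket) Passcred() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.passcred
}

func main() {
	var s credSocket
	opt := make([]byte, 4)
	binary.LittleEndian.PutUint32(opt, 1)
	if err := s.SetPasscred(opt); err != nil {
		panic(err)
	}
	fmt.Println("passcred enabled:", s.Passcred()) // true
}
```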
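The watchdog's positional New(k, timeout, action) arguments become an embedded Opts struct with separate task and startup timeouts. A sketch of how a caller would copy DefaultOpts and override a field (mirrors the diff's types, trimmed of kernel dependencies):

```go
package main

import (
	"fmt"
	"time"
)

// Action and Opts mirror the watchdog's new options struct: the timeouts
// are grouped, and a startup timeout guards against Start never running.
type Action int

const (
	LogWarning Action = iota
	Panic
)

type Opts struct {
	TaskTimeout          time.Duration
	TaskTimeoutAction    Action
	StartupTimeout       time.Duration
	StartupTimeoutAction Action
}

// DefaultOpts matches the defaults the diff introduces.
var DefaultOpts = Opts{
	TaskTimeout:          3 * time.Minute,
	TaskTimeoutAction:    LogWarning,
	StartupTimeout:       30 * time.Second,
	StartupTimeoutAction: LogWarning,
}

func main() {
	// Callers copy the defaults and override selectively, rather than
	// threading positional arguments through New.
	opts := DefaultOpts
	opts.TaskTimeoutAction = Panic
	fmt.Printf("task: %v/%v startup: %v/%v\n",
		opts.TaskTimeout, opts.TaskTimeoutAction,
		opts.StartupTimeout, opts.StartupTimeoutAction)
}
```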
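Transport endpoints gain a UniqueID, generated by default from an atomic counter on the Stack, and the demuxer sorts its endpoint array by that ID so that multi-port (SO_REUSEPORT) groups restore in a deterministic order. A self-contained sketch of both pieces:

```go
package main

import (
	"fmt"
	"sort"
	"sync/atomic"
)

// uniqueIDGenerator mirrors the stack's default generator: an atomically
// incremented counter behind a one-method interface.
type uniqueIDGenerator uint64

func (u *uniqueIDGenerator) UniqueID() uint64 {
	return atomic.AddUint64((*uint64)(u), 1)
}

// endpoint carries the ID it was assigned at creation time.
type endpoint struct{ id uint64 }

func (e *endpoint) UniqueID() uint64 { return e.id }

func main() {
	var gen uniqueIDGenerator
	eps := []*endpoint{{gen.UniqueID()}, {gen.UniqueID()}, {gen.UniqueID()}}

	// Registration order is arbitrary across save/restore...
	eps[0], eps[2] = eps[2], eps[0]

	// ...so the demuxer re-sorts by unique ID, keeping the array (and
	// the port-distribution hash over it) stable.
	sort.Slice(eps, func(i, j int) bool {
		return eps[i].UniqueID() < eps[j].UniqueID()
	})
	for _, e := range eps {
		fmt.Println(e.UniqueID())
	}
}
```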
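Several HandlePacket and Deliver* methods are now documented to take ownership of the passed VectorisedView, which is why nic.go clones the view before handing it to each packet endpoint. A toy model of why the clone matters (the real buffer.VectorisedView has more machinery; this only shows the aliasing):

```go
package main

import "fmt"

// vectorisedView is a toy stand-in: a list of byte slices. Under the new
// rule, HandlePacket owns the view it receives, so anything delivered to
// several endpoints must be cloned first (vv.Clone(nil) in nic.go).
type vectorisedView struct{ views [][]byte }

// Clone copies the slice of views so each consumer gets its own headers,
// while the underlying bytes stay shared.
func (v vectorisedView) Clone() vectorisedView {
	return vectorisedView{views: append([][]byte(nil), v.views...)}
}

// trimFront drops bytes from the front by re-slicing the first view; done
// on a shared view, this would change what the other consumer sees.
func (v *vectorisedView) trimFront(n int) { v.views[0] = v.views[0][n:] }

func main() {
	vv := vectorisedView{views: [][]byte{[]byte("hdr|payload")}}
	for i := 0; i < 2; i++ {
		c := vv.Clone() // each endpoint trims its own copy
		c.trimFront(4)
		fmt.Printf("endpoint %d sees %q\n", i, c.views[0])
	}
}
```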