summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--kokoro/runtime_tests.cfg1
-rw-r--r--pkg/flipcall/ctrl_futex.go2
-rw-r--r--pkg/flipcall/flipcall.go10
-rw-r--r--pkg/flipcall/futex_linux.go6
-rw-r--r--pkg/p9/server.go6
-rw-r--r--pkg/sentry/context/context.go63
-rw-r--r--pkg/sentry/kernel/context.go32
-rw-r--r--pkg/sentry/kernel/kernel.go13
-rw-r--r--pkg/sentry/pgalloc/save_restore.go13
-rw-r--r--pkg/sentry/platform/kvm/BUILD1
-rw-r--r--pkg/sentry/platform/kvm/address_space.go2
-rw-r--r--pkg/sentry/platform/kvm/allocator.go2
-rw-r--r--pkg/sentry/platform/kvm/bluepill_fault.go10
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go12
-rw-r--r--pkg/sentry/platform/kvm/kvm_const.go7
-rw-r--r--pkg/sentry/platform/kvm/machine.go26
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64.go10
-rw-r--r--pkg/sentry/platform/kvm/machine_amd64_unsafe.go24
-rw-r--r--pkg/sentry/platform/kvm/machine_arm64.go61
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go24
-rw-r--r--pkg/sentry/socket/netlink/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/socket.go76
-rw-r--r--pkg/sentry/watchdog/watchdog.go152
-rw-r--r--pkg/state/decode.go4
-rw-r--r--pkg/state/encode.go4
-rw-r--r--pkg/state/map.go11
-rw-r--r--pkg/state/state.go7
-rw-r--r--pkg/state/state_test.go11
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go10
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go4
-rw-r--r--pkg/tcpip/stack/nic.go2
-rw-r--r--pkg/tcpip/stack/registration.go19
-rw-r--r--pkg/tcpip/stack/stack.go29
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go10
-rw-r--r--pkg/tcpip/stack/transport_test.go11
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go12
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go8
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go6
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go12
-rw-r--r--runsc/boot/controller.go8
-rw-r--r--runsc/boot/loader.go9
-rw-r--r--runsc/container/BUILD1
-rw-r--r--runsc/container/container.go344
-rw-r--r--runsc/container/state_file.go185
-rw-r--r--runsc/main.go8
-rwxr-xr-xscripts/runtime_tests.sh25
-rw-r--r--test/runtimes/BUILD10
-rw-r--r--test/runtimes/build_defs.bzl41
-rw-r--r--test/syscalls/linux/semaphore.cc5
-rw-r--r--test/syscalls/linux/socket_inet_loopback.cc107
-rw-r--r--test/syscalls/linux/socket_ip_tcp_generic.cc6
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc110
-rwxr-xr-xtools/go_branch.sh26
54 files changed, 1183 insertions, 423 deletions
diff --git a/kokoro/runtime_tests.cfg b/kokoro/runtime_tests.cfg
new file mode 100644
index 000000000..7d56d5aca
--- /dev/null
+++ b/kokoro/runtime_tests.cfg
@@ -0,0 +1 @@
+build_file: "repo/scripts/runtime_tests.sh"
diff --git a/pkg/flipcall/ctrl_futex.go b/pkg/flipcall/ctrl_futex.go
index 8390915a2..e7c3a3a0b 100644
--- a/pkg/flipcall/ctrl_futex.go
+++ b/pkg/flipcall/ctrl_futex.go
@@ -113,7 +113,7 @@ func (ep *Endpoint) enterFutexWait() error {
return nil
case epsBlocked | epsShutdown:
atomic.AddInt32(&ep.ctrl.state, -epsBlocked)
- return shutdownError{}
+ return ShutdownError{}
default:
// Most likely due to ep.enterFutexWait() being called concurrently
// from multiple goroutines.
diff --git a/pkg/flipcall/flipcall.go b/pkg/flipcall/flipcall.go
index 386cee42c..3cdb576e1 100644
--- a/pkg/flipcall/flipcall.go
+++ b/pkg/flipcall/flipcall.go
@@ -136,8 +136,8 @@ func (ep *Endpoint) unmapPacket() {
// Shutdown causes concurrent and future calls to ep.Connect(), ep.SendRecv(),
// ep.RecvFirst(), and ep.SendLast(), as well as the same calls in the peer
-// Endpoint, to unblock and return errors. It does not wait for concurrent
-// calls to return. Successive calls to Shutdown have no effect.
+// Endpoint, to unblock and return ShutdownErrors. It does not wait for
+// concurrent calls to return. Successive calls to Shutdown have no effect.
//
// Shutdown is the only Endpoint method that may be called concurrently with
// other methods on the same Endpoint.
@@ -154,10 +154,12 @@ func (ep *Endpoint) isShutdownLocally() bool {
return atomic.LoadUint32(&ep.shutdown) != 0
}
-type shutdownError struct{}
+// ShutdownError is returned by most Endpoint methods after Endpoint.Shutdown()
+// has been called.
+type ShutdownError struct{}
// Error implements error.Error.
-func (shutdownError) Error() string {
+func (ShutdownError) Error() string {
return "flipcall connection shutdown"
}
diff --git a/pkg/flipcall/futex_linux.go b/pkg/flipcall/futex_linux.go
index b127a2bbb..168c1ccff 100644
--- a/pkg/flipcall/futex_linux.go
+++ b/pkg/flipcall/futex_linux.go
@@ -61,7 +61,7 @@ func (ep *Endpoint) futexSwitchToPeer() error {
if !atomic.CompareAndSwapUint32(ep.connState(), ep.activeState, ep.inactiveState) {
switch cs := atomic.LoadUint32(ep.connState()); cs {
case csShutdown:
- return shutdownError{}
+ return ShutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAKE: %v", cs)
}
@@ -81,14 +81,14 @@ func (ep *Endpoint) futexSwitchFromPeer() error {
return nil
case ep.inactiveState:
if ep.isShutdownLocally() {
- return shutdownError{}
+ return ShutdownError{}
}
if err := ep.futexWaitConnState(ep.inactiveState); err != nil {
return fmt.Errorf("failed to FUTEX_WAIT for peer Endpoint: %v", err)
}
continue
case csShutdown:
- return shutdownError{}
+ return ShutdownError{}
default:
return fmt.Errorf("unexpected connection state before FUTEX_WAIT: %v", cs)
}
diff --git a/pkg/p9/server.go b/pkg/p9/server.go
index e717e6161..40b8fa023 100644
--- a/pkg/p9/server.go
+++ b/pkg/p9/server.go
@@ -453,7 +453,11 @@ func (cs *connState) initializeChannels() (err error) {
go func() { // S/R-SAFE: Server side.
defer cs.channelWg.Done()
if err := res.service(cs); err != nil {
- log.Warningf("p9.channel.service: %v", err)
+ // Don't log flipcall.ShutdownErrors, which we expect to be
+ // returned during server shutdown.
+ if _, ok := err.(flipcall.ShutdownError); !ok {
+ log.Warningf("p9.channel.service: %v", err)
+ }
}
}()
}
diff --git a/pkg/sentry/context/context.go b/pkg/sentry/context/context.go
index dfd62cbdb..23e009ef3 100644
--- a/pkg/sentry/context/context.go
+++ b/pkg/sentry/context/context.go
@@ -12,10 +12,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package context defines the sentry's Context type.
+// Package context defines an internal context type.
+//
+// The given Context conforms to the standard Go context, but mandates
+// additional methods that are specific to the kernel internals. Note however,
+// that the Context described by this package carries additional constraints
+// regarding concurrent access and retaining beyond the scope of a call.
+//
+// See the Context type for complete details.
package context
import (
+ "context"
+ "time"
+
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/log"
)
@@ -59,6 +69,7 @@ func ThreadGroupIDFromContext(ctx Context) (tgid int32, ok bool) {
type Context interface {
log.Logger
amutex.Sleeper
+ context.Context
// UninterruptibleSleepStart indicates the beginning of an uninterruptible
// sleep state (equivalent to Linux's TASK_UNINTERRUPTIBLE). If deactivate
@@ -72,19 +83,36 @@ type Context interface {
// AddressSpace is activated. Normally activate is the same value as the
// deactivate parameter passed to UninterruptibleSleepStart.
UninterruptibleSleepFinish(activate bool)
+}
+
+// NoopSleeper is a noop implementation of amutex.Sleeper and UninterruptibleSleep
+// methods for anonymous embedding in other types that do not implement sleeps.
+type NoopSleeper struct {
+ amutex.NoopSleeper
+}
+
+// UninterruptibleSleepStart does nothing.
+func (NoopSleeper) UninterruptibleSleepStart(bool) {}
+
+// UninterruptibleSleepFinish does nothing.
+func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
+
+// Deadline returns zero values, meaning no deadline.
+func (NoopSleeper) Deadline() (time.Time, bool) {
+ return time.Time{}, false
+}
+
+// Done returns nil.
+func (NoopSleeper) Done() <-chan struct{} {
+ return nil
+}
- // Value returns the value associated with this Context for key, or nil if
- // no value is associated with key. Successive calls to Value with the same
- // key returns the same result.
- //
- // A key identifies a specific value in a Context. Functions that wish to
- // retrieve values from Context typically allocate a key in a global
- // variable then use that key as the argument to Context.Value. A key can
- // be any type that supports equality; packages should define keys as an
- // unexported type to avoid collisions.
- Value(key interface{}) interface{}
+// Err returns nil.
+func (NoopSleeper) Err() error {
+ return nil
}
+// logContext implements basic logging.
type logContext struct {
log.Logger
NoopSleeper
@@ -95,19 +123,6 @@ func (logContext) Value(key interface{}) interface{} {
return nil
}
-// NoopSleeper is a noop implementation of amutex.Sleeper and
-// Context.UninterruptibleSleep* methods for anonymous embedding in other types
-// that do not want to notify kernel.Task about sleeps.
-type NoopSleeper struct {
- amutex.NoopSleeper
-}
-
-// UninterruptibleSleepStart does nothing.
-func (NoopSleeper) UninterruptibleSleepStart(bool) {}
-
-// UninterruptibleSleepFinish does nothing.
-func (NoopSleeper) UninterruptibleSleepFinish(bool) {}
-
// bgContext is the context returned by context.Background.
var bgContext = &logContext{Logger: log.Log()}
diff --git a/pkg/sentry/kernel/context.go b/pkg/sentry/kernel/context.go
index e3f5b0d83..3c9dceaba 100644
--- a/pkg/sentry/kernel/context.go
+++ b/pkg/sentry/kernel/context.go
@@ -15,6 +15,8 @@
package kernel
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/context"
)
@@ -97,6 +99,21 @@ func TaskFromContext(ctx context.Context) *Task {
return nil
}
+// Deadline implements context.Context.Deadline.
+func (*Task) Deadline() (time.Time, bool) {
+ return time.Time{}, false
+}
+
+// Done implements context.Context.Done.
+func (*Task) Done() <-chan struct{} {
+ return nil
+}
+
+// Err implements context.Context.Err.
+func (*Task) Err() error {
+ return nil
+}
+
// AsyncContext returns a context.Context that may be used by goroutines that
// do work on behalf of t and therefore share its contextual values, but are
// not t's task goroutine (e.g. asynchronous I/O).
@@ -129,6 +146,21 @@ func (ctx taskAsyncContext) IsLogging(level log.Level) bool {
return ctx.t.IsLogging(level)
}
+// Deadline implements context.Context.Deadline.
+func (ctx taskAsyncContext) Deadline() (time.Time, bool) {
+ return ctx.t.Deadline()
+}
+
+// Done implements context.Context.Done.
+func (ctx taskAsyncContext) Done() <-chan struct{} {
+ return ctx.t.Done()
+}
+
+// Err implements context.Context.Err.
+func (ctx taskAsyncContext) Err() error {
+ return ctx.t.Err()
+}
+
// Value implements context.Context.Value.
func (ctx taskAsyncContext) Value(key interface{}) interface{} {
return ctx.t.Value(key)
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index e64d648e2..28ba950bd 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -391,7 +391,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
//
// N.B. This will also be saved along with the full kernel save below.
cpuidStart := time.Now()
- if err := state.Save(w, k.FeatureSet(), nil); err != nil {
+ if err := state.Save(k.SupervisorContext(), w, k.FeatureSet(), nil); err != nil {
return err
}
log.Infof("CPUID save took [%s].", time.Since(cpuidStart))
@@ -399,7 +399,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
// Save the kernel state.
kernelStart := time.Now()
var stats state.Stats
- if err := state.Save(w, k, &stats); err != nil {
+ if err := state.Save(k.SupervisorContext(), w, k, &stats); err != nil {
return err
}
log.Infof("Kernel save stats: %s", &stats)
@@ -407,7 +407,7 @@ func (k *Kernel) SaveTo(w io.Writer) error {
// Save the memory file's state.
memoryStart := time.Now()
- if err := k.mf.SaveTo(w); err != nil {
+ if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil {
return err
}
log.Infof("Memory save took [%s].", time.Since(memoryStart))
@@ -542,7 +542,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// don't need to explicitly install it in the Kernel.
cpuidStart := time.Now()
var features cpuid.FeatureSet
- if err := state.Load(r, &features, nil); err != nil {
+ if err := state.Load(k.SupervisorContext(), r, &features, nil); err != nil {
return err
}
log.Infof("CPUID load took [%s].", time.Since(cpuidStart))
@@ -558,7 +558,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// Load the kernel state.
kernelStart := time.Now()
var stats state.Stats
- if err := state.Load(r, k, &stats); err != nil {
+ if err := state.Load(k.SupervisorContext(), r, k, &stats); err != nil {
return err
}
log.Infof("Kernel load stats: %s", &stats)
@@ -566,7 +566,7 @@ func (k *Kernel) LoadFrom(r io.Reader, net inet.Stack, clocks sentrytime.Clocks)
// Load the memory file's state.
memoryStart := time.Now()
- if err := k.mf.LoadFrom(r); err != nil {
+ if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil {
return err
}
log.Infof("Memory load took [%s].", time.Since(memoryStart))
@@ -1322,6 +1322,7 @@ func (k *Kernel) ListSockets() []*SocketEntry {
return socks
}
+// supervisorContext is a privileged context.
type supervisorContext struct {
context.NoopSleeper
log.Logger
diff --git a/pkg/sentry/pgalloc/save_restore.go b/pkg/sentry/pgalloc/save_restore.go
index 1effc7735..aafce1d00 100644
--- a/pkg/sentry/pgalloc/save_restore.go
+++ b/pkg/sentry/pgalloc/save_restore.go
@@ -16,6 +16,7 @@ package pgalloc
import (
"bytes"
+ "context"
"fmt"
"io"
"runtime"
@@ -29,7 +30,7 @@ import (
)
// SaveTo writes f's state to the given stream.
-func (f *MemoryFile) SaveTo(w io.Writer) error {
+func (f *MemoryFile) SaveTo(ctx context.Context, w io.Writer) error {
// Wait for reclaim.
f.mu.Lock()
defer f.mu.Unlock()
@@ -78,10 +79,10 @@ func (f *MemoryFile) SaveTo(w io.Writer) error {
}
// Save metadata.
- if err := state.Save(w, &f.fileSize, nil); err != nil {
+ if err := state.Save(ctx, w, &f.fileSize, nil); err != nil {
return err
}
- if err := state.Save(w, &f.usage, nil); err != nil {
+ if err := state.Save(ctx, w, &f.usage, nil); err != nil {
return err
}
@@ -114,9 +115,9 @@ func (f *MemoryFile) SaveTo(w io.Writer) error {
}
// LoadFrom loads MemoryFile state from the given stream.
-func (f *MemoryFile) LoadFrom(r io.Reader) error {
+func (f *MemoryFile) LoadFrom(ctx context.Context, r io.Reader) error {
// Load metadata.
- if err := state.Load(r, &f.fileSize, nil); err != nil {
+ if err := state.Load(ctx, r, &f.fileSize, nil); err != nil {
return err
}
if err := f.file.Truncate(f.fileSize); err != nil {
@@ -124,7 +125,7 @@ func (f *MemoryFile) LoadFrom(r io.Reader) error {
}
newMappings := make([]uintptr, f.fileSize>>chunkShift)
f.mappings.Store(newMappings)
- if err := state.Load(r, &f.usage, nil); err != nil {
+ if err := state.Load(ctx, r, &f.usage, nil); err != nil {
return err
}
diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD
index 31fa48ec5..6803d488c 100644
--- a/pkg/sentry/platform/kvm/BUILD
+++ b/pkg/sentry/platform/kvm/BUILD
@@ -23,6 +23,7 @@ go_library(
"machine.go",
"machine_amd64.go",
"machine_amd64_unsafe.go",
+ "machine_arm64.go",
"machine_unsafe.go",
"physical_map.go",
"virtual_map.go",
diff --git a/pkg/sentry/platform/kvm/address_space.go b/pkg/sentry/platform/kvm/address_space.go
index acd41f73d..ea8b9632e 100644
--- a/pkg/sentry/platform/kvm/address_space.go
+++ b/pkg/sentry/platform/kvm/address_space.go
@@ -127,7 +127,7 @@ func (as *addressSpace) mapHost(addr usermem.Addr, m hostMapEntry, at usermem.Ac
// not have physical mappings, the KVM module may inject
// spurious exceptions when emulation fails (i.e. it tries to
// emulate because the RIP is pointed at those pages).
- as.machine.mapPhysical(physical, length)
+ as.machine.mapPhysical(physical, length, physicalRegions, _KVM_MEM_FLAGS_NONE)
// Install the page table mappings. Note that the ordering is
// important; if the pagetable mappings were installed before
diff --git a/pkg/sentry/platform/kvm/allocator.go b/pkg/sentry/platform/kvm/allocator.go
index 80942e9c9..3f35414bb 100644
--- a/pkg/sentry/platform/kvm/allocator.go
+++ b/pkg/sentry/platform/kvm/allocator.go
@@ -54,7 +54,7 @@ func (a allocator) PhysicalFor(ptes *pagetables.PTEs) uintptr {
//
//go:nosplit
func (a allocator) LookupPTEs(physical uintptr) *pagetables.PTEs {
- virtualStart, physicalStart, _, ok := calculateBluepillFault(physical)
+ virtualStart, physicalStart, _, ok := calculateBluepillFault(physical, physicalRegions)
if !ok {
panic(fmt.Sprintf("LookupPTEs failed for 0x%x", physical))
}
diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go
index b97476053..f6459cda9 100644
--- a/pkg/sentry/platform/kvm/bluepill_fault.go
+++ b/pkg/sentry/platform/kvm/bluepill_fault.go
@@ -46,9 +46,9 @@ func yield() {
// calculateBluepillFault calculates the fault address range.
//
//go:nosplit
-func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, length uintptr, ok bool) {
+func calculateBluepillFault(physical uintptr, phyRegions []physicalRegion) (virtualStart, physicalStart, length uintptr, ok bool) {
alignedPhysical := physical &^ uintptr(usermem.PageSize-1)
- for _, pr := range physicalRegions {
+ for _, pr := range phyRegions {
end := pr.physical + pr.length
if physical < pr.physical || physical >= end {
continue
@@ -77,12 +77,12 @@ func calculateBluepillFault(physical uintptr) (virtualStart, physicalStart, leng
// The corresponding virtual address is returned. This may throw on error.
//
//go:nosplit
-func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) {
+func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegion, flags uint32) (uintptr, bool) {
// Paging fault: we need to map the underlying physical pages for this
// fault. This all has to be done in this function because we're in a
// signal handler context. (We can't call any functions that might
// split the stack.)
- virtualStart, physicalStart, length, ok := calculateBluepillFault(physical)
+ virtualStart, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
if !ok {
return 0, false
}
@@ -96,7 +96,7 @@ func handleBluepillFault(m *machine, physical uintptr) (uintptr, bool) {
yield() // Race with another call.
slot = atomic.SwapUint32(&m.nextSlot, ^uint32(0))
}
- errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart)
+ errno := m.setMemoryRegion(int(slot), physicalStart, length, virtualStart, flags)
if errno == 0 {
// Successfully added region; we can increment nextSlot and
// allow another set to proceed here.
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index ee730ad70..ca011ef78 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -80,13 +80,17 @@ func bluepillHandler(context unsafe.Pointer) {
// interrupted KVM. Since we're in a signal handler
// currently, all signals are masked and the signal
// must have been delivered directly to this thread.
+ timeout := syscall.Timespec{}
sig, _, errno := syscall.RawSyscall6(
syscall.SYS_RT_SIGTIMEDWAIT,
uintptr(unsafe.Pointer(&bounceSignalMask)),
- 0, // siginfo.
- 0, // timeout.
- 8, // sigset size.
+ 0, // siginfo.
+ uintptr(unsafe.Pointer(&timeout)), // timeout.
+ 8, // sigset size.
0, 0)
+ if errno == syscall.EAGAIN {
+ continue
+ }
if errno != 0 {
throw("error waiting for pending signal")
}
@@ -162,7 +166,7 @@ func bluepillHandler(context unsafe.Pointer) {
// For MMIO, the physical address is the first data item.
physical := uintptr(c.runData.data[0])
- virtual, ok := handleBluepillFault(c.machine, physical)
+ virtual, ok := handleBluepillFault(c.machine, physical, physicalRegions, _KVM_MEM_FLAGS_NONE)
if !ok {
c.die(bluepillArchContext(context), "invalid physical address")
return
diff --git a/pkg/sentry/platform/kvm/kvm_const.go b/pkg/sentry/platform/kvm/kvm_const.go
index d05f05c29..766131d60 100644
--- a/pkg/sentry/platform/kvm/kvm_const.go
+++ b/pkg/sentry/platform/kvm/kvm_const.go
@@ -62,3 +62,10 @@ const (
_KVM_NR_INTERRUPTS = 0x100
_KVM_NR_CPUID_ENTRIES = 0x100
)
+
+// KVM kvm_memory_region::flags.
+const (
+ _KVM_MEM_LOG_DIRTY_PAGES = uint32(1) << 0
+ _KVM_MEM_READONLY = uint32(1) << 1
+ _KVM_MEM_FLAGS_NONE = 0
+)
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index cc6c138b2..7d02ebf19 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -215,6 +215,17 @@ func newMachine(vm int) (*machine, error) {
return true // Keep iterating.
})
+ var physicalRegionsReadOnly []physicalRegion
+ var physicalRegionsAvailable []physicalRegion
+
+ physicalRegionsReadOnly = rdonlyRegionsForSetMem()
+ physicalRegionsAvailable = availableRegionsForSetMem()
+
+ // Map all read-only regions.
+ for _, r := range physicalRegionsReadOnly {
+ m.mapPhysical(r.physical, r.length, physicalRegionsReadOnly, _KVM_MEM_READONLY)
+ }
+
// Ensure that the currently mapped virtual regions are actually
// available in the VM. Note that this doesn't guarantee no future
// faults, however it should guarantee that everything is available to
@@ -223,6 +234,13 @@ func newMachine(vm int) (*machine, error) {
if excludeVirtualRegion(vr) {
return // skip region.
}
+
+ for _, r := range physicalRegionsReadOnly {
+ if vr.virtual == r.virtual {
+ return
+ }
+ }
+
for virtual := vr.virtual; virtual < vr.virtual+vr.length; {
physical, length, ok := translateToPhysical(virtual)
if !ok {
@@ -236,7 +254,7 @@ func newMachine(vm int) (*machine, error) {
}
// Ensure the physical range is mapped.
- m.mapPhysical(physical, length)
+ m.mapPhysical(physical, length, physicalRegionsAvailable, _KVM_MEM_FLAGS_NONE)
virtual += length
}
})
@@ -256,9 +274,9 @@ func newMachine(vm int) (*machine, error) {
// not available. This attempts to be efficient for calls in the hot path.
//
// This panics on error.
-func (m *machine) mapPhysical(physical, length uintptr) {
+func (m *machine) mapPhysical(physical, length uintptr, phyRegions []physicalRegion, flags uint32) {
for end := physical + length; physical < end; {
- _, physicalStart, length, ok := calculateBluepillFault(physical)
+ _, physicalStart, length, ok := calculateBluepillFault(physical, phyRegions)
if !ok {
// Should never happen.
panic("mapPhysical on unknown physical address")
@@ -266,7 +284,7 @@ func (m *machine) mapPhysical(physical, length uintptr) {
if _, ok := m.mappingCache.LoadOrStore(physicalStart, true); !ok {
// Not present in the cache; requires setting the slot.
- if _, ok := handleBluepillFault(m, physical); !ok {
+ if _, ok := handleBluepillFault(m, physical, phyRegions, flags); !ok {
panic("handleBluepillFault failed")
}
}
diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go
index c1cbe33be..b99fe425e 100644
--- a/pkg/sentry/platform/kvm/machine_amd64.go
+++ b/pkg/sentry/platform/kvm/machine_amd64.go
@@ -355,3 +355,13 @@ func (m *machine) retryInGuest(fn func()) {
}
}
}
+
+// On x86 platform, the flags for "setMemoryRegion" can always be set as 0.
+// There is no need to return read-only physicalRegions.
+func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
+ return nil
+}
+
+func availableRegionsForSetMem() (phyRegions []physicalRegion) {
+ return physicalRegions
+}
diff --git a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
index 506ec9af1..61227cafb 100644
--- a/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_amd64_unsafe.go
@@ -26,30 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/time"
)
-// setMemoryRegion initializes a region.
-//
-// This may be called from bluepillHandler, and therefore returns an errno
-// directly (instead of wrapping in an error) to avoid allocations.
-//
-//go:nosplit
-func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr) syscall.Errno {
- userRegion := userMemoryRegion{
- slot: uint32(slot),
- flags: 0,
- guestPhysAddr: uint64(physical),
- memorySize: uint64(length),
- userspaceAddr: uint64(virtual),
- }
-
- // Set the region.
- _, _, errno := syscall.RawSyscall(
- syscall.SYS_IOCTL,
- uintptr(m.fd),
- _KVM_SET_USER_MEMORY_REGION,
- uintptr(unsafe.Pointer(&userRegion)))
- return errno
-}
-
// loadSegments copies the current segments.
//
// This may be called from within the signal context and throws on error.
diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go
new file mode 100644
index 000000000..b7e2cfb9d
--- /dev/null
+++ b/pkg/sentry/platform/kvm/machine_arm64.go
@@ -0,0 +1,61 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package kvm
+
+// Get all read-only physicalRegions.
+func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) {
+ var rdonlyRegions []region
+
+ applyVirtualRegions(func(vr virtualRegion) {
+ if excludeVirtualRegion(vr) {
+ return
+ }
+
+ if !vr.accessType.Write && vr.accessType.Read {
+ rdonlyRegions = append(rdonlyRegions, vr.region)
+ }
+ })
+
+ for _, r := range rdonlyRegions {
+ physical, _, ok := translateToPhysical(r.virtual)
+ if !ok {
+ continue
+ }
+
+ phyRegions = append(phyRegions, physicalRegion{
+ region: region{
+ virtual: r.virtual,
+ length: r.length,
+ },
+ physical: physical,
+ })
+ }
+
+ return phyRegions
+}
+
+// Get all available physicalRegions.
+func availableRegionsForSetMem() (phyRegions []physicalRegion) {
+ var excludeRegions []region
+ applyVirtualRegions(func(vr virtualRegion) {
+ if !vr.accessType.Write {
+ excludeRegions = append(excludeRegions, vr.region)
+ }
+ })
+
+ phyRegions = computePhysicalRegions(excludeRegions)
+
+ return phyRegions
+}
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index e00c7ae40..ed9433311 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -35,6 +35,30 @@ func entersyscall()
//go:linkname exitsyscall runtime.exitsyscall
func exitsyscall()
+// setMemoryRegion initializes a region.
+//
+// This may be called from bluepillHandler, and therefore returns an errno
+// directly (instead of wrapping in an error) to avoid allocations.
+//
+//go:nosplit
+func (m *machine) setMemoryRegion(slot int, physical, length, virtual uintptr, flags uint32) syscall.Errno {
+ userRegion := userMemoryRegion{
+ slot: uint32(slot),
+ flags: uint32(flags),
+ guestPhysAddr: uint64(physical),
+ memorySize: uint64(length),
+ userspaceAddr: uint64(virtual),
+ }
+
+ // Set the region.
+ _, _, errno := syscall.RawSyscall(
+ syscall.SYS_IOCTL,
+ uintptr(m.fd),
+ _KVM_SET_USER_MEMORY_REGION,
+ uintptr(unsafe.Pointer(&userRegion)))
+ return errno
+}
+
// mapRunData maps the vCPU run data.
func mapRunData(fd int) (*runData, error) {
r, _, errno := syscall.RawSyscall6(
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index f95803f91..79589e3c8 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -20,6 +20,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index b2732ca29..05dac4f0a 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -27,6 +27,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -61,7 +62,7 @@ var netlinkSocketDevice = device.NewAnonDevice()
// This implementation only supports userspace sending and receiving messages
// to/from the kernel.
//
-// Socket implements socket.Socket.
+// Socket implements socket.Socket and transport.Credentialer.
//
// +stateify savable
type Socket struct {
@@ -104,9 +105,13 @@ type Socket struct {
// sendBufferSize is the send buffer "size". We don't actually have a
// fixed buffer but only consume this many bytes.
sendBufferSize uint32
+
+ // passcred indicates if this socket wants SCM credentials.
+ passcred bool
}
var _ socket.Socket = (*Socket)(nil)
+var _ transport.Credentialer = (*Socket)(nil)
// NewSocket creates a new Socket.
func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
@@ -172,6 +177,22 @@ func (s *Socket) EventUnregister(e *waiter.Entry) {
s.ep.EventUnregister(e)
}
+// Passcred implements transport.Credentialer.Passcred.
+func (s *Socket) Passcred() bool {
+ s.mu.Lock()
+ passcred := s.passcred
+ s.mu.Unlock()
+ return passcred
+}
+
+// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
+func (s *Socket) ConnectedPasscred() bool {
+ // This socket is connected to the kernel, which doesn't need creds.
+ //
+ // This is arbitrary, as ConnectedPasscred on this type has no callers.
+ return false
+}
+
// Ioctl implements fs.FileOperations.Ioctl.
func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) {
// TODO(b/68878065): no ioctls supported.
@@ -309,9 +330,20 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.
// We don't have limit on receiving size.
return int32(math.MaxInt32), nil
+ case linux.SO_PASSCRED:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ var passcred int32
+ if s.Passcred() {
+ passcred = 1
+ }
+ return passcred, nil
+
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
}
+
case linux.SOL_NETLINK:
switch name {
case linux.NETLINK_BROADCAST_ERROR,
@@ -348,6 +380,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
s.sendBufferSize = size
s.mu.Unlock()
return nil
+
case linux.SO_RCVBUF:
if len(opt) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -355,6 +388,18 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
// We don't have limit on receiving size. So just accept anything as
// valid for compatibility.
return nil
+
+ case linux.SO_PASSCRED:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ passcred := usermem.ByteOrder.Uint32(opt)
+
+ s.mu.Lock()
+ s.passcred = passcred != 0
+ s.mu.Unlock()
+ return nil
+
default:
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
@@ -483,6 +528,26 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _
})
}
+// kernelSCM implements control.SCMCredentials with credentials that represent
+// the kernel itself rather than a Task.
+//
+// +stateify savable
+type kernelSCM struct{}
+
+// Equals implements transport.CredentialsControlMessage.Equals.
+func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool {
+ _, ok := oc.(kernelSCM)
+ return ok
+}
+
+// Credentials implements control.SCMCredentials.Credentials.
+func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
+ return 0, auth.RootUID, auth.RootGID
+}
+
+// kernelCreds is the concrete version of kernelSCM used in all creds.
+var kernelCreds = &kernelSCM{}
+
// sendResponse sends the response messages in ms back to userspace.
func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
// Linux combines multiple netlink messages into a single datagram.
@@ -491,10 +556,15 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
bufs = append(bufs, m.Finalize())
}
+ // All messages are from the kernel.
+ cms := transport.ControlMessages{
+ Credentials: kernelCreds,
+ }
+
if len(bufs) > 0 {
// RecvMsg never receives the address, so we don't need to send
// one.
- _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{})
+ _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{})
// If the buffer is full, we simply drop messages, just like
// Linux.
if err != nil && err != syserr.ErrWouldBlock {
@@ -521,7 +591,7 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
// Add the dump_done_errno payload.
m.Put(int64(0))
- _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{})
+ _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
if err != nil && err != syserr.ErrWouldBlock {
return err
}
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 145102c0d..ecce6c69f 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -42,8 +42,35 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
)
-// DefaultTimeout is a resonable timeout value for most applications.
-const DefaultTimeout = 3 * time.Minute
+// Opts configures the watchdog.
+type Opts struct {
+ // TaskTimeout is the amount of time to allow a task to execute the
+ // same syscall without blocking before it's declared stuck.
+ TaskTimeout time.Duration
+
+ // TaskTimeoutAction indicates what action to take when a stuck tasks
+ // is detected.
+ TaskTimeoutAction Action
+
+ // StartupTimeout is the amount of time to allow between watchdog
+ // creation and calling watchdog.Start.
+ StartupTimeout time.Duration
+
+ // StartupTimeoutAction indicates what action to take when
+ // watchdog.Start is not called within the timeout.
+ StartupTimeoutAction Action
+}
+
+// DefaultOpts is a default set of options for the watchdog.
+var DefaultOpts = Opts{
+ // Task timeout.
+ TaskTimeout: 3 * time.Minute,
+ TaskTimeoutAction: LogWarning,
+
+ // Startup timeout.
+ StartupTimeout: 30 * time.Second,
+ StartupTimeoutAction: LogWarning,
+}
// descheduleThreshold is the amount of time scheduling needs to be off before the entire wait period
// is discounted from task's last update time. It's set high enough that small scheduling delays won't
@@ -61,6 +88,7 @@ type Action int
const (
// LogWarning logs warning message followed by stack trace.
LogWarning Action = iota
+
// Panic will do the same logging as LogWarning and panic().
Panic
)
@@ -80,17 +108,13 @@ func (a Action) String() string {
// Watchdog is the main watchdog class. It controls a goroutine that periodically
// analyses all tasks and reports if any of them appear to be stuck.
type Watchdog struct {
+ // Configuration options are embedded.
+ Opts
+
// period indicates how often to check all tasks. It's calculated based on
- // 'taskTimeout'.
+ // opts.TaskTimeout.
period time.Duration
- // taskTimeout is the amount of time to allow a task to execute the same syscall
- // without blocking before it's declared stuck.
- taskTimeout time.Duration
-
- // timeoutAction indicates what action to take when a stuck tasks is detected.
- timeoutAction Action
-
// k is where the tasks come from.
k *kernel.Kernel
@@ -113,8 +137,12 @@ type Watchdog struct {
// mu protects the fields below.
mu sync.Mutex
- // started is true if the watchdog has been started before.
- started bool
+ // running is true if the watchdog is running.
+ running bool
+
+ // startCalled is true if Start has ever been called. It remains true
+ // even if Stop is called.
+ startCalled bool
}
type offender struct {
@@ -122,58 +150,81 @@ type offender struct {
}
// New creates a new watchdog.
-func New(k *kernel.Kernel, taskTimeout time.Duration, a Action) *Watchdog {
- // 4 is arbitrary, just don't want to prolong 'taskTimeout' too much.
- period := taskTimeout / 4
- return &Watchdog{
- k: k,
- period: period,
- taskTimeout: taskTimeout,
- timeoutAction: a,
- offenders: make(map[*kernel.Task]*offender),
- stop: make(chan struct{}),
- done: make(chan struct{}),
+func New(k *kernel.Kernel, opts Opts) *Watchdog {
+ // 4 is arbitrary, just don't want to prolong 'TaskTimeout' too much.
+ period := opts.TaskTimeout / 4
+ w := &Watchdog{
+ Opts: opts,
+ k: k,
+ period: period,
+ offenders: make(map[*kernel.Task]*offender),
+ stop: make(chan struct{}),
+ done: make(chan struct{}),
+ }
+
+ // Handle StartupTimeout if it exists.
+ if w.StartupTimeout > 0 {
+ log.Infof("Watchdog waiting %v for startup", w.StartupTimeout)
+ go w.waitForStart() // S/R-SAFE: watchdog is stopped buring save and restarted after restore.
}
+
+ return w
}
// Start starts the watchdog.
func (w *Watchdog) Start() {
- if w.taskTimeout == 0 {
- log.Infof("Watchdog disabled")
- return
- }
-
w.mu.Lock()
defer w.mu.Unlock()
- if w.started {
+ w.startCalled = true
+
+ if w.running {
return
}
+ if w.TaskTimeout == 0 {
+ log.Infof("Watchdog task timeout disabled")
+ return
+ }
w.lastRun = w.k.MonotonicClock().Now()
- log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.taskTimeout, w.timeoutAction)
+ log.Infof("Starting watchdog, period: %v, timeout: %v, action: %v", w.period, w.TaskTimeout, w.TaskTimeoutAction)
go w.loop() // S/R-SAFE: watchdog is stopped during save and restarted after restore.
- w.started = true
+ w.running = true
}
// Stop requests the watchdog to stop and wait for it.
func (w *Watchdog) Stop() {
- if w.taskTimeout == 0 {
+ if w.TaskTimeout == 0 {
return
}
w.mu.Lock()
defer w.mu.Unlock()
- if !w.started {
+ if !w.running {
return
}
log.Infof("Stopping watchdog")
w.stop <- struct{}{}
<-w.done
- w.started = false
+ w.running = false
log.Infof("Watchdog stopped")
}
+// waitForStart waits for Start to be called and takes action if it does not
+// happen within the startup timeout.
+func (w *Watchdog) waitForStart() {
+ <-time.After(w.StartupTimeout)
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if w.startCalled {
+ // We are fine.
+ return
+ }
+ var buf bytes.Buffer
+ buf.WriteString("Watchdog.Start() not called within %s:\n")
+ w.doAction(w.StartupTimeoutAction, false, &buf)
+}
+
// loop is the main watchdog routine. It only returns when 'Stop()' is called.
func (w *Watchdog) loop() {
// Loop until someone stops it.
@@ -202,7 +253,7 @@ func (w *Watchdog) runTurn() {
select {
case <-done:
- case <-time.After(w.taskTimeout):
+ case <-time.After(w.TaskTimeout):
// Report if the watchdog is not making progress.
// No one is wathching the watchdog watcher though.
w.reportStuckWatchdog()
@@ -231,7 +282,7 @@ func (w *Watchdog) runTurn() {
if tsched.State == kernel.TaskGoroutineRunningSys {
lastUpdateTime := ktime.FromNanoseconds(int64(tsched.Timestamp * uint64(linux.ClockTick)))
elapsed := now.Sub(lastUpdateTime) - discount
- if elapsed > w.taskTimeout {
+ if elapsed > w.TaskTimeout {
tc, ok := w.offenders[t]
if !ok {
// New stuck task detected.
@@ -261,28 +312,34 @@ func (w *Watchdog) report(offenders map[*kernel.Task]*offender, newTaskFound boo
tid := w.k.TaskSet().Root.IDOfTask(t)
buf.WriteString(fmt.Sprintf("\tTask tid: %v (%#x), entered RunSys state %v ago.\n", tid, uint64(tid), now.Sub(o.lastUpdateTime)))
}
+
buf.WriteString("Search for '(*Task).run(0x..., 0x<tid>)' in the stack dump to find the offending goroutine")
- w.onStuckTask(newTaskFound, &buf)
+
+ // Dump stack only if a new task is detected or if it sometime has
+ // passed since the last time a stack dump was generated.
+ skipStack := newTaskFound || time.Since(w.lastStackDump) >= stackDumpSameTaskPeriod
+ w.doAction(w.TaskTimeoutAction, skipStack, &buf)
}
func (w *Watchdog) reportStuckWatchdog() {
var buf bytes.Buffer
buf.WriteString("Watchdog goroutine is stuck:\n")
- w.onStuckTask(true, &buf)
+ w.doAction(w.TaskTimeoutAction, false, &buf)
}
-func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) {
- switch w.timeoutAction {
+// doAction will take the given action. If the action is LogWarnind and
+// skipStack is true, then the stack printing will be skipped.
+func (w *Watchdog) doAction(action Action, skipStack bool, msg *bytes.Buffer) {
+ switch action {
case LogWarning:
- // Dump stack only if a new task is detected or if it sometime has passed since
- // the last time a stack dump was generated.
- if !newTaskFound && time.Since(w.lastStackDump) < stackDumpSameTaskPeriod {
+ if skipStack {
msg.WriteString("\n...[stack dump skipped]...")
log.Warningf(msg.String())
- } else {
- log.TracebackAll(msg.String())
- w.lastStackDump = time.Now()
+ return
+
}
+ log.TracebackAll(msg.String())
+ w.lastStackDump = time.Now()
case Panic:
// Panic will skip over running tasks, which is likely the culprit here. So manually
@@ -301,5 +358,8 @@ func (w *Watchdog) onStuckTask(newTaskFound bool, msg *bytes.Buffer) {
case <-time.After(1 * time.Second):
}
panic(fmt.Sprintf("Stack for running G's are skipped while panicking.\n%s", msg.String()))
+ default:
+ panic(fmt.Sprintf("Unknown watchdog action %v", action))
+
}
}
diff --git a/pkg/state/decode.go b/pkg/state/decode.go
index 47e6b878a..590c241a3 100644
--- a/pkg/state/decode.go
+++ b/pkg/state/decode.go
@@ -16,6 +16,7 @@ package state
import (
"bytes"
+ "context"
"encoding/binary"
"errors"
"fmt"
@@ -133,6 +134,9 @@ func (os *objectState) findCycle() []*objectState {
// to ensure that all callbacks are executed, otherwise the callback graph was
// not acyclic.
type decodeState struct {
+ // ctx is the decode context.
+ ctx context.Context
+
// objectByID is the set of objects in progress.
objectsByID map[uint64]*objectState
diff --git a/pkg/state/encode.go b/pkg/state/encode.go
index 5d9409a45..c5118d3a9 100644
--- a/pkg/state/encode.go
+++ b/pkg/state/encode.go
@@ -16,6 +16,7 @@ package state
import (
"container/list"
+ "context"
"encoding/binary"
"fmt"
"io"
@@ -38,6 +39,9 @@ type queuedObject struct {
// The encoding process is a breadth-first traversal of the object graph. The
// inherent races and dependencies are much simpler than the decode case.
type encodeState struct {
+ // ctx is the encode context.
+ ctx context.Context
+
// lastID is the last object ID.
//
// See idsByObject for context. Because of the special zero encoding
diff --git a/pkg/state/map.go b/pkg/state/map.go
index 7e6fefed4..4f3ebb0da 100644
--- a/pkg/state/map.go
+++ b/pkg/state/map.go
@@ -15,6 +15,7 @@
package state
import (
+ "context"
"fmt"
"reflect"
"sort"
@@ -219,3 +220,13 @@ func (m Map) AfterLoad(fn func()) {
// data dependencies have been cleared.
m.os.callbacks = append(m.os.callbacks, fn)
}
+
+// Context returns the current context object.
+func (m Map) Context() context.Context {
+ if m.es != nil {
+ return m.es.ctx
+ } else if m.ds != nil {
+ return m.ds.ctx
+ }
+ return context.Background() // No context.
+}
diff --git a/pkg/state/state.go b/pkg/state/state.go
index d408ff84a..dbe507ab4 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -50,6 +50,7 @@
package state
import (
+ "context"
"fmt"
"io"
"reflect"
@@ -86,9 +87,10 @@ func UnwrapErrState(err error) error {
}
// Save saves the given object state.
-func Save(w io.Writer, rootPtr interface{}, stats *Stats) error {
+func Save(ctx context.Context, w io.Writer, rootPtr interface{}, stats *Stats) error {
// Create the encoding state.
es := &encodeState{
+ ctx: ctx,
idsByObject: make(map[uintptr]uint64),
w: w,
stats: stats,
@@ -101,9 +103,10 @@ func Save(w io.Writer, rootPtr interface{}, stats *Stats) error {
}
// Load loads a checkpoint.
-func Load(r io.Reader, rootPtr interface{}, stats *Stats) error {
+func Load(ctx context.Context, r io.Reader, rootPtr interface{}, stats *Stats) error {
// Create the decoding state.
ds := &decodeState{
+ ctx: ctx,
objectsByID: make(map[uint64]*objectState),
deferred: make(map[uint64]*pb.Object),
r: r,
diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go
index 7c24bbcda..d7221e9e8 100644
--- a/pkg/state/state_test.go
+++ b/pkg/state/state_test.go
@@ -16,6 +16,7 @@ package state
import (
"bytes"
+ "context"
"io/ioutil"
"math"
"reflect"
@@ -46,7 +47,7 @@ func runTest(t *testing.T, tests []TestCase) {
saveBuffer := &bytes.Buffer{}
saveObjectPtr := reflect.New(reflect.TypeOf(root))
saveObjectPtr.Elem().Set(reflect.ValueOf(root))
- if err := Save(saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail {
+ if err := Save(context.Background(), saveBuffer, saveObjectPtr.Interface(), nil); err != nil && !test.Fail {
t.Errorf(" FAIL: Save failed unexpectedly: %v", err)
continue
} else if err != nil {
@@ -56,7 +57,7 @@ func runTest(t *testing.T, tests []TestCase) {
// Load a new copy of the object.
loadObjectPtr := reflect.New(reflect.TypeOf(root))
- if err := Load(bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail {
+ if err := Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface(), nil); err != nil && !test.Fail {
t.Errorf(" FAIL: Load failed unexpectedly: %v", err)
continue
} else if err != nil {
@@ -624,7 +625,7 @@ func BenchmarkEncoding(b *testing.B) {
bs := buildObject(b.N)
var stats Stats
b.StartTimer()
- if err := Save(ioutil.Discard, bs, &stats); err != nil {
+ if err := Save(context.Background(), ioutil.Discard, bs, &stats); err != nil {
b.Errorf("save failed: %v", err)
}
b.StopTimer()
@@ -638,12 +639,12 @@ func BenchmarkDecoding(b *testing.B) {
bs := buildObject(b.N)
var newBS benchStruct
buf := &bytes.Buffer{}
- if err := Save(buf, bs, nil); err != nil {
+ if err := Save(context.Background(), buf, bs, nil); err != nil {
b.Errorf("save failed: %v", err)
}
var stats Stats
b.StartTimer()
- if err := Load(buf, &newBS, &stats); err != nil {
+ if err := Load(context.Background(), buf, &newBS, &stats); err != nil {
b.Errorf("load failed: %v", err)
}
b.StopTimer()
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
index 59378b96c..e7c05ca4f 100644
--- a/pkg/tcpip/link/fdbased/endpoint_test.go
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -45,7 +45,7 @@ const (
type packetInfo struct {
raddr tcpip.LinkAddress
proto tcpip.NetworkProtocolNumber
- contents buffer.View
+ contents buffer.VectorisedView
linkHeader buffer.View
}
@@ -94,7 +94,7 @@ func (c *context) cleanup() {
}
func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remote tcpip.LinkAddress, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View) {
- c.ch <- packetInfo{remote, protocol, vv.ToView(), linkHeader}
+ c.ch <- packetInfo{remote, protocol, vv, linkHeader}
}
func TestNoEthernetProperties(t *testing.T) {
@@ -319,13 +319,17 @@ func TestDeliverPacket(t *testing.T) {
want := packetInfo{
raddr: raddr,
proto: proto,
- contents: b,
+ contents: buffer.View(b).ToVectorisedView(),
linkHeader: buffer.View(hdr),
}
if !eth {
want.proto = header.IPv4ProtocolNumber
want.raddr = ""
}
+ // want.contents will be a single view,
+ // so make pi do the same for the
+ // DeepEqual check.
+ pi.contents = pi.contents.ToView().ToVectorisedView()
if !reflect.DeepEqual(want, pi) {
t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
}
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index 12168a1dc..3331b6453 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -139,7 +139,7 @@ func (d *readVDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(n, BufConfig)
- vv := buffer.NewVectorisedView(n, d.views[:used])
+ vv := buffer.NewVectorisedView(n, append([]buffer.View(nil), d.views[:used]...))
vv.TrimFront(d.e.hdrSize)
d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth))
@@ -293,7 +293,7 @@ func (d *recvMMsgDispatcher) dispatch() (bool, *tcpip.Error) {
}
used := d.capViews(k, int(n), BufConfig)
- vv := buffer.NewVectorisedView(int(n), d.views[k][:used])
+ vv := buffer.NewVectorisedView(int(n), append([]buffer.View(nil), d.views[k][:used]...))
vv.TrimFront(d.e.hdrSize)
d.e.dispatcher.DeliverNetworkPacket(d.e, remote, local, p, vv, buffer.View(eth))
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index a01a208b8..fe8f83d58 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -762,7 +762,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.Link
}
n.mu.RUnlock()
for _, ep := range packetEPs {
- ep.HandlePacket(n.id, local, protocol, vv, linkHeader)
+ ep.HandlePacket(n.id, local, protocol, vv.Clone(nil), linkHeader)
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
index 0360187b8..d7c124e81 100644
--- a/pkg/tcpip/stack/registration.go
+++ b/pkg/tcpip/stack/registration.go
@@ -60,12 +60,19 @@ const (
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
+ // UniqueID returns an unique ID for this transport endpoint.
+ UniqueID() uint64
+
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint.
+ //
+ // HandlePacket takes ownership of vv.
HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
// HandleControlPacket is called by the stack when new control (e.g.,
// ICMP) packets arrive to this transport endpoint.
+ //
+ // HandleControlPacket takes ownership of vv.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
// Close puts the endpoint in a closed state and frees all resources
@@ -91,6 +98,8 @@ type RawTransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint. The packet contains all data from the link
// layer up.
+ //
+ // HandlePacket takes ownership of packet and netHeader.
HandlePacket(r *Route, netHeader buffer.View, packet buffer.VectorisedView)
}
@@ -107,6 +116,8 @@ type PacketEndpoint interface {
//
// linkHeader may have a length of 0, in which case the PacketEndpoint
// should construct its own ethernet header for applications.
+ //
+ // HandlePacket takes ownership of packet and linkHeader.
HandlePacket(nicid tcpip.NICID, addr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, packet buffer.VectorisedView, linkHeader buffer.View)
}
@@ -157,10 +168,14 @@ type TransportDispatcher interface {
// DeliverTransportPacket delivers packets to the appropriate
// transport protocol endpoint. It also returns the network layer
// header for the enpoint to inspect or pass up the stack.
+ //
+ // DeliverTransportPacket takes ownership of vv and netHeader.
DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
+ //
+ // DeliverTransportControlPacket takes ownership of vv.
DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
}
@@ -234,6 +249,8 @@ type NetworkEndpoint interface {
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint.
+ //
+ // HandlePacket takes ownership of vv.
HandlePacket(r *Route, vv buffer.VectorisedView)
// Close is called when the endpoint is reomved from a stack.
@@ -279,6 +296,8 @@ type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol endpoint
// and hands the packet over for further processing. linkHeader may have
// length 0 when the caller does not have ethernet data.
+ //
+ // DeliverNetworkPacket takes ownership of vv and linkHeader.
DeliverNetworkPacket(linkEP LinkEndpoint, remote, local tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, linkHeader buffer.View)
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 6d6ddc0ff..115a6fcb8 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -22,6 +22,7 @@ package stack
import (
"encoding/binary"
"sync"
+ "sync/atomic"
"time"
"golang.org/x/time/rate"
@@ -344,6 +345,13 @@ type ResumableEndpoint interface {
Resume(*Stack)
}
+// uniqueIDGenerator is a default unique ID generator.
+type uniqueIDGenerator uint64
+
+func (u *uniqueIDGenerator) UniqueID() uint64 {
+ return atomic.AddUint64((*uint64)(u), 1)
+}
+
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
@@ -411,6 +419,14 @@ type Stack struct {
// ndpDisp is the NDP event dispatcher that is used to send the netstack
// integrator NDP related events.
ndpDisp NDPDispatcher
+
+ // uniqueIDGenerator is a generator of unique identifiers.
+ uniqueIDGenerator UniqueID
+}
+
+// UniqueID is an abstract generator of unique identifiers.
+type UniqueID interface {
+ UniqueID() uint64
}
// Options contains optional Stack configuration.
@@ -434,6 +450,9 @@ type Options struct {
// stack (false).
HandleLocal bool
+ // UniqueID is an optional generator of unique identifiers.
+ UniqueID UniqueID
+
// NDPConfigs is the default NDP configurations used by interfaces.
//
// By default, NDPConfigs will have a zero value for its
@@ -506,6 +525,10 @@ func New(opts Options) *Stack {
clock = &tcpip.StdClock{}
}
+ if opts.UniqueID == nil {
+ opts.UniqueID = new(uniqueIDGenerator)
+ }
+
// Make sure opts.NDPConfigs contains valid values only.
opts.NDPConfigs.validate()
@@ -524,6 +547,7 @@ func New(opts Options) *Stack {
portSeed: generateRandUint32(),
ndpConfigs: opts.NDPConfigs,
autoGenIPv6LinkLocal: opts.AutoGenIPv6LinkLocal,
+ uniqueIDGenerator: opts.UniqueID,
ndpDisp: opts.NDPDisp,
}
@@ -551,6 +575,11 @@ func New(opts Options) *Stack {
return s
}
+// UniqueID returns a unique identifier.
+func (s *Stack) UniqueID() uint64 {
+ return s.uniqueIDGenerator.UniqueID()
+}
+
// SetNetworkProtocolOption allows configuring individual protocol level
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index f633632f0..ccd3d030e 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -17,6 +17,7 @@ package stack
import (
"fmt"
"math/rand"
+ "sort"
"sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -310,6 +311,15 @@ func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePo
// endpointsMap. This will allow us to remove endpoint from the array fast.
ep.endpointsMap[t] = len(ep.endpointsArr)
ep.endpointsArr = append(ep.endpointsArr, t)
+
+ // ep.endpointsArr is sorted by endpoint unique IDs, so that endpoints
+ // can be restored in the same order.
+ sort.Slice(ep.endpointsArr, func(i, j int) bool {
+ return ep.endpointsArr[i].UniqueID() < ep.endpointsArr[j].UniqueID()
+ })
+ for i, e := range ep.endpointsArr {
+ ep.endpointsMap[e] = i
+ }
return nil
}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index ae6fda3a9..203e79f56 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -43,6 +43,7 @@ type fakeTransportEndpoint struct {
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
+ uniqueID uint64
// acceptQueue is non-nil iff bound.
acceptQueue []fakeTransportEndpoint
@@ -56,8 +57,8 @@ func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
return nil
}
-func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
- return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto}
+func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
+ return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
}
func (f *fakeTransportEndpoint) Close() {
@@ -144,6 +145,10 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return nil
}
+func (f *fakeTransportEndpoint) UniqueID() uint64 {
+ return f.uniqueID
+}
+
func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
@@ -251,7 +256,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
}
func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(stack, f, netProto), nil
+ return newFakeTransportEndpoint(stack, f, netProto, stack.UniqueID()), nil
}
func (f *fakeTransportProtocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d0dd383fd..33405eb7d 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -31,9 +31,6 @@ type icmpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- // views is used as buffer for data when its length is large
- // enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
}
type endpointState int
@@ -58,6 +55,7 @@ type endpoint struct {
// immutable.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
+ uniqueID uint64
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -90,9 +88,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
state: stateInitial,
+ uniqueID: s.UniqueID(),
}, nil
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -760,7 +764,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
},
}
- pkt.data = vv.Clone(pkt.views[:])
+ pkt.data = vv
e.rcvList.PushBack(pkt)
e.rcvBufSize += pkt.data.Size()
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 73cdaa265..ead83b83d 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -41,10 +41,6 @@ type packet struct {
// data holds the actual packet data, including any headers and
// payload.
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // views is pre-allocated space to back data. As long as the packet is
- // made up of fewer than 8 buffer.Views, no extra allocation is
- // necessary to store packet data.
- views [8]buffer.View `state:"nosave"`
// timestampNS is the unix time at which the packet was received.
timestampNS int64
// senderAddr is the network address of the sender.
@@ -310,7 +306,7 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress,
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = vv.Clone(packet.views[:])
+ packet.data = vv
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
@@ -328,7 +324,7 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress,
}
combinedVV := buffer.View(ethHeader).ToVectorisedView()
combinedVV.Append(vv)
- packet.data = combinedVV.Clone(packet.views[:])
+ packet.data = combinedVV
}
packet.timestampNS = ep.stack.NowNanoseconds()
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 951d317ed..23922a30e 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -42,10 +42,6 @@ type rawPacket struct {
// data holds the actual packet data, including any headers and
// payload.
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // views is pre-allocated space to back data. As long as the packet is
- // made up of fewer than 8 buffer.Views, no extra allocation is
- // necessary to store packet data.
- views [8]buffer.View `state:"nosave"`
// timestampNS is the unix time at which the packet was received.
timestampNS int64
// senderAddr is the network address of the sender.
@@ -609,7 +605,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv bu
combinedVV := netHeader.ToVectorisedView()
combinedVV.Append(vv)
- pkt.data = combinedVV.Clone(pkt.views[:])
+ pkt.data = combinedVV
pkt.timestampNS = e.stack.NowNanoseconds()
e.rcvList.PushBack(pkt)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 8a3ca0f1b..a1efd8d55 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -287,6 +287,7 @@ type endpoint struct {
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue `state:"wait"`
+ uniqueID uint64
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
@@ -504,6 +505,11 @@ type endpoint struct {
stats Stats `state:"nosave"`
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
// calculateAdvertisedMSS calculates the MSS to advertise.
//
// If userMSS is non-zero and is not greater than the maximum possible MSS for
@@ -565,6 +571,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
interval: 75 * time.Second,
count: 9,
},
+ uniqueID: s.UniqueID(),
}
var ss SendBufferSizeOption
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index cda302bb7..03bd5c8fd 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -31,9 +31,6 @@ type udpPacket struct {
senderAddress tcpip.FullAddress
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
- // views is used as buffer for data when its length is large
- // enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
}
// EndpointState represents the state of a UDP endpoint.
@@ -80,6 +77,7 @@ type endpoint struct {
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
waiterQueue *waiter.Queue
+ uniqueID uint64
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
@@ -160,9 +158,15 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
state: StateInitial,
+ uniqueID: s.UniqueID(),
}
}
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
func (e *endpoint) Close() {
@@ -1195,7 +1199,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
Port: hdr.SourcePort(),
},
}
- pkt.data = vv.Clone(pkt.views[:])
+ pkt.data = vv
e.rcvList.PushBack(pkt)
e.rcvBufSize += vv.Size()
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index 5f644b57e..f62be4c59 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -51,7 +51,7 @@ const (
ContainerEvent = "containerManager.Event"
// ContainerExecuteAsync is the URPC endpoint for executing a command in a
- // container..
+ // container.
ContainerExecuteAsync = "containerManager.ExecuteAsync"
// ContainerPause pauses the container.
@@ -161,7 +161,7 @@ func newController(fd int, l *Loader) (*controller, error) {
}, nil
}
-// containerManager manages sandboes containers.
+// containerManager manages sandbox containers.
type containerManager struct {
// startChan is used to signal when the root container process should
// be started.
@@ -380,7 +380,9 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
}
// Since we have a new kernel we also must make a new watchdog.
- dog := watchdog.New(k, watchdog.DefaultTimeout, cm.l.conf.WatchdogAction)
+ dogOpts := watchdog.DefaultOpts
+ dogOpts.TaskTimeoutAction = cm.l.conf.WatchdogAction
+ dog := watchdog.New(k, dogOpts)
// Change the loader fields to reflect the changes made when restoring.
cm.l.k = k
diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go
index 0c0eba99e..4d1bd2d08 100644
--- a/runsc/boot/loader.go
+++ b/runsc/boot/loader.go
@@ -232,7 +232,7 @@ func New(args Args) (*Loader, error) {
// this point. Netns is configured before Run() is called. Netstack is
// configured using a control uRPC message. Host network is configured inside
// Run().
- networkStack, err := newEmptyNetworkStack(args.Conf, k)
+ networkStack, err := newEmptyNetworkStack(args.Conf, k, k)
if err != nil {
return nil, fmt.Errorf("creating network: %v", err)
}
@@ -300,7 +300,9 @@ func New(args Args) (*Loader, error) {
}
// Create a watchdog.
- dog := watchdog.New(k, watchdog.DefaultTimeout, args.Conf.WatchdogAction)
+ dogOpts := watchdog.DefaultOpts
+ dogOpts.TaskTimeoutAction = args.Conf.WatchdogAction
+ dog := watchdog.New(k, dogOpts)
procArgs, err := newProcess(args.ID, args.Spec, creds, k, k.RootPIDNamespace())
if err != nil {
@@ -905,7 +907,7 @@ func (l *Loader) WaitExit() kernel.ExitStatus {
return l.k.GlobalInit().ExitStatus()
}
-func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) {
+func newEmptyNetworkStack(conf *Config, clock tcpip.Clock, uniqueID stack.UniqueID) (inet.Stack, error) {
switch conf.Network {
case NetworkHost:
return hostinet.NewStack(), nil
@@ -923,6 +925,7 @@ func newEmptyNetworkStack(conf *Config, clock tcpip.Clock) (inet.Stack, error) {
// Enable raw sockets for users with sufficient
// privileges.
RawFactory: raw.EndpointFactory{},
+ UniqueID: uniqueID,
})}
// Enable SACK Recovery.
diff --git a/runsc/container/BUILD b/runsc/container/BUILD
index 26d1cd5ab..2bd12120d 100644
--- a/runsc/container/BUILD
+++ b/runsc/container/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"container.go",
"hook.go",
+ "state_file.go",
"status.go",
],
importpath = "gvisor.dev/gvisor/runsc/container",
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 32510d427..68782c4be 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -17,13 +17,11 @@ package container
import (
"context"
- "encoding/json"
"fmt"
"io/ioutil"
"os"
"os/exec"
"os/signal"
- "path/filepath"
"regexp"
"strconv"
"strings"
@@ -31,7 +29,6 @@ import (
"time"
"github.com/cenkalti/backoff"
- "github.com/gofrs/flock"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/control"
@@ -41,17 +38,6 @@ import (
"gvisor.dev/gvisor/runsc/specutils"
)
-const (
- // metadataFilename is the name of the metadata file relative to the
- // container root directory that holds sandbox metadata.
- metadataFilename = "meta.json"
-
- // metadataLockFilename is the name of a lock file in the container
- // root directory that is used to prevent concurrent modifications to
- // the container state and metadata.
- metadataLockFilename = "meta.lock"
-)
-
// validateID validates the container id.
func validateID(id string) error {
// See libcontainer/factory_linux.go.
@@ -99,11 +85,6 @@ type Container struct {
// BundleDir is the directory containing the container bundle.
BundleDir string `json:"bundleDir"`
- // Root is the directory containing the container metadata file. If this
- // container is the root container, Root and RootContainerDir will be the
- // same.
- Root string `json:"root"`
-
// CreatedAt is the time the container was created.
CreatedAt time.Time `json:"createdAt"`
@@ -121,21 +102,24 @@ type Container struct {
// be 0 if the gofer has been killed.
GoferPid int `json:"goferPid"`
+ // Sandbox is the sandbox this container is running in. It's set when the
+ // container is created and reset when the sandbox is destroyed.
+ Sandbox *sandbox.Sandbox `json:"sandbox"`
+
+ // Saver handles load from/save to the state file safely from multiple
+ // processes.
+ Saver StateFile `json:"saver"`
+
+ //
+ // Fields below this line are not saved in the state file and will not
+ // be preserved across commands.
+ //
+
// goferIsChild is set if a gofer process is a child of the current process.
//
// This field isn't saved to json, because only a creator of a gofer
// process will have it as a child process.
goferIsChild bool
-
- // Sandbox is the sandbox this container is running in. It's set when the
- // container is created and reset when the sandbox is destroyed.
- Sandbox *sandbox.Sandbox `json:"sandbox"`
-
- // RootContainerDir is the root directory containing the metadata file of the
- // sandbox root container. It's used to lock in order to serialize creating
- // and deleting this Container's metadata directory. If this container is the
- // root container, this is the same as Root.
- RootContainerDir string
}
// loadSandbox loads all containers that belong to the sandbox with the given
@@ -166,43 +150,35 @@ func loadSandbox(rootDir, id string) ([]*Container, error) {
return containers, nil
}
-// Load loads a container with the given id from a metadata file. id may be an
-// abbreviation of the full container id, in which case Load loads the
-// container to which id unambiguously refers to.
-// Returns ErrNotExist if container doesn't exist.
-func Load(rootDir, id string) (*Container, error) {
- log.Debugf("Load container %q %q", rootDir, id)
- if err := validateID(id); err != nil {
+// Load loads a container with the given id from a metadata file. partialID may
+// be an abbreviation of the full container id, in which case Load loads the
+// container to which id unambiguously refers to. Returns ErrNotExist if
+// container doesn't exist.
+func Load(rootDir, partialID string) (*Container, error) {
+ log.Debugf("Load container %q %q", rootDir, partialID)
+ if err := validateID(partialID); err != nil {
return nil, fmt.Errorf("validating id: %v", err)
}
- cRoot, err := findContainerRoot(rootDir, id)
+ id, err := findContainerID(rootDir, partialID)
if err != nil {
// Preserve error so that callers can distinguish 'not found' errors.
return nil, err
}
- // Lock the container metadata to prevent other runsc instances from
- // writing to it while we are reading it.
- unlock, err := lockContainerMetadata(cRoot)
- if err != nil {
- return nil, err
+ state := StateFile{
+ RootDir: rootDir,
+ ID: id,
}
- defer unlock()
+ defer state.close()
- // Read the container metadata file and create a new Container from it.
- metaFile := filepath.Join(cRoot, metadataFilename)
- metaBytes, err := ioutil.ReadFile(metaFile)
- if err != nil {
+ c := &Container{}
+ if err := state.load(c); err != nil {
if os.IsNotExist(err) {
// Preserve error so that callers can distinguish 'not found' errors.
return nil, err
}
- return nil, fmt.Errorf("reading container metadata file %q: %v", metaFile, err)
- }
- var c Container
- if err := json.Unmarshal(metaBytes, &c); err != nil {
- return nil, fmt.Errorf("unmarshaling container metadata from %q: %v", metaFile, err)
+ return nil, fmt.Errorf("reading container metadata file %q: %v", state.statePath(), err)
}
// If the status is "Running" or "Created", check that the sandbox
@@ -223,57 +199,37 @@ func Load(rootDir, id string) (*Container, error) {
}
}
- return &c, nil
+ return c, nil
}
-func findContainerRoot(rootDir, partialID string) (string, error) {
+func findContainerID(rootDir, partialID string) (string, error) {
// Check whether the id fully specifies an existing container.
- cRoot := filepath.Join(rootDir, partialID)
- if _, err := os.Stat(cRoot); err == nil {
- return cRoot, nil
+ stateFile := buildStatePath(rootDir, partialID)
+ if _, err := os.Stat(stateFile); err == nil {
+ return partialID, nil
}
// Now see whether id could be an abbreviation of exactly 1 of the
// container ids. If id is ambiguous (it could match more than 1
// container), it is an error.
- cRoot = ""
ids, err := List(rootDir)
if err != nil {
return "", err
}
+ rv := ""
for _, id := range ids {
if strings.HasPrefix(id, partialID) {
- if cRoot != "" {
- return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, cRoot, id)
+ if rv != "" {
+ return "", fmt.Errorf("id %q is ambiguous and could refer to multiple containers: %q, %q", partialID, rv, id)
}
- cRoot = id
+ rv = id
}
}
- if cRoot == "" {
+ if rv == "" {
return "", os.ErrNotExist
}
- log.Debugf("abbreviated id %q resolves to full id %q", partialID, cRoot)
- return filepath.Join(rootDir, cRoot), nil
-}
-
-// List returns all container ids in the given root directory.
-func List(rootDir string) ([]string, error) {
- log.Debugf("List containers %q", rootDir)
- fs, err := ioutil.ReadDir(rootDir)
- if err != nil {
- return nil, fmt.Errorf("reading dir %q: %v", rootDir, err)
- }
- var out []string
- for _, f := range fs {
- // Filter out directories that do no belong to a container.
- cid := f.Name()
- if validateID(cid) == nil {
- if _, err := os.Stat(filepath.Join(rootDir, cid, metadataFilename)); err == nil {
- out = append(out, f.Name())
- }
- }
- }
- return out, nil
+ log.Debugf("abbreviated id %q resolves to full id %q", partialID, rv)
+ return rv, nil
}
// Args is used to configure a new container.
@@ -316,44 +272,34 @@ func New(conf *boot.Config, args Args) (*Container, error) {
return nil, err
}
- unlockRoot, err := maybeLockRootContainer(args.Spec, conf.RootDir)
- if err != nil {
- return nil, err
+ if err := os.MkdirAll(conf.RootDir, 0711); err != nil {
+ return nil, fmt.Errorf("creating container root directory: %v", err)
}
- defer unlockRoot()
+
+ c := &Container{
+ ID: args.ID,
+ Spec: args.Spec,
+ ConsoleSocket: args.ConsoleSocket,
+ BundleDir: args.BundleDir,
+ Status: Creating,
+ CreatedAt: time.Now(),
+ Owner: os.Getenv("USER"),
+ Saver: StateFile{
+ RootDir: conf.RootDir,
+ ID: args.ID,
+ },
+ }
+ // The Cleanup object cleans up partially created containers when an error
+ // occurs. Any errors occurring during cleanup itself are ignored.
+ cu := specutils.MakeCleanup(func() { _ = c.Destroy() })
+ defer cu.Clean()
// Lock the container metadata file to prevent concurrent creations of
// containers with the same id.
- containerRoot := filepath.Join(conf.RootDir, args.ID)
- unlock, err := lockContainerMetadata(containerRoot)
- if err != nil {
+ if err := c.Saver.lockForNew(); err != nil {
return nil, err
}
- defer unlock()
-
- // Check if the container already exists by looking for the metadata
- // file.
- if _, err := os.Stat(filepath.Join(containerRoot, metadataFilename)); err == nil {
- return nil, fmt.Errorf("container with id %q already exists", args.ID)
- } else if !os.IsNotExist(err) {
- return nil, fmt.Errorf("looking for existing container in %q: %v", containerRoot, err)
- }
-
- c := &Container{
- ID: args.ID,
- Spec: args.Spec,
- ConsoleSocket: args.ConsoleSocket,
- BundleDir: args.BundleDir,
- Root: containerRoot,
- Status: Creating,
- CreatedAt: time.Now(),
- Owner: os.Getenv("USER"),
- RootContainerDir: conf.RootDir,
- }
- // The Cleanup object cleans up partially created containers when an error occurs.
- // Any errors occuring during cleanup itself are ignored.
- cu := specutils.MakeCleanup(func() { _ = c.Destroy() })
- defer cu.Clean()
+ defer c.Saver.unlock()
// If the metadata annotations indicate that this container should be
// started in an existing sandbox, we must do so. The metadata will
@@ -431,7 +377,7 @@ func New(conf *boot.Config, args Args) (*Container, error) {
c.changeStatus(Created)
// Save the metadata file.
- if err := c.save(); err != nil {
+ if err := c.saveLocked(); err != nil {
return nil, err
}
@@ -451,17 +397,12 @@ func New(conf *boot.Config, args Args) (*Container, error) {
func (c *Container) Start(conf *boot.Config) error {
log.Debugf("Start container %q", c.ID)
- unlockRoot, err := maybeLockRootContainer(c.Spec, c.RootContainerDir)
- if err != nil {
+ if err := c.Saver.lock(); err != nil {
return err
}
- defer unlockRoot()
+ unlock := specutils.MakeCleanup(func() { c.Saver.unlock() })
+ defer unlock.Clean()
- unlock, err := c.lock()
- if err != nil {
- return err
- }
- defer unlock()
if err := c.requireStatus("start", Created); err != nil {
return err
}
@@ -509,14 +450,15 @@ func (c *Container) Start(conf *boot.Config) error {
}
c.changeStatus(Running)
- if err := c.save(); err != nil {
+ if err := c.saveLocked(); err != nil {
return err
}
- // Adjust the oom_score_adj for sandbox. This must be done after
- // save().
- err = adjustSandboxOOMScoreAdj(c.Sandbox, c.RootContainerDir, false)
- if err != nil {
+ // Release lock before adjusting OOM score because the lock is acquired there.
+ unlock.Clean()
+
+ // Adjust the oom_score_adj for sandbox. This must be done after saveLocked().
+ if err := adjustSandboxOOMScoreAdj(c.Sandbox, c.Saver.RootDir, false); err != nil {
return err
}
@@ -529,11 +471,10 @@ func (c *Container) Start(conf *boot.Config) error {
// to restore a container from its state file.
func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile string) error {
log.Debugf("Restore container %q", c.ID)
- unlock, err := c.lock()
- if err != nil {
+ if err := c.Saver.lock(); err != nil {
return err
}
- defer unlock()
+ defer c.Saver.unlock()
if err := c.requireStatus("restore", Created); err != nil {
return err
@@ -551,7 +492,7 @@ func (c *Container) Restore(spec *specs.Spec, conf *boot.Config, restoreFile str
return err
}
c.changeStatus(Running)
- return c.save()
+ return c.saveLocked()
}
// Run is a helper that calls Create + Start + Wait.
@@ -711,11 +652,10 @@ func (c *Container) Checkpoint(f *os.File) error {
// The call only succeeds if the container's status is created or running.
func (c *Container) Pause() error {
log.Debugf("Pausing container %q", c.ID)
- unlock, err := c.lock()
- if err != nil {
+ if err := c.Saver.lock(); err != nil {
return err
}
- defer unlock()
+ defer c.Saver.unlock()
if c.Status != Created && c.Status != Running {
return fmt.Errorf("cannot pause container %q in state %v", c.ID, c.Status)
@@ -725,18 +665,17 @@ func (c *Container) Pause() error {
return fmt.Errorf("pausing container: %v", err)
}
c.changeStatus(Paused)
- return c.save()
+ return c.saveLocked()
}
// Resume unpauses the container and its kernel.
// The call only succeeds if the container's status is paused.
func (c *Container) Resume() error {
log.Debugf("Resuming container %q", c.ID)
- unlock, err := c.lock()
- if err != nil {
+ if err := c.Saver.lock(); err != nil {
return err
}
- defer unlock()
+ defer c.Saver.unlock()
if c.Status != Paused {
return fmt.Errorf("cannot resume container %q in state %v", c.ID, c.Status)
@@ -745,7 +684,7 @@ func (c *Container) Resume() error {
return fmt.Errorf("resuming container: %v", err)
}
c.changeStatus(Running)
- return c.save()
+ return c.saveLocked()
}
// State returns the metadata of the container.
@@ -773,6 +712,17 @@ func (c *Container) Processes() ([]*control.Process, error) {
func (c *Container) Destroy() error {
log.Debugf("Destroy container %q", c.ID)
+ if err := c.Saver.lock(); err != nil {
+ return err
+ }
+ defer func() {
+ c.Saver.unlock()
+ c.Saver.close()
+ }()
+
+ // Stored for later use as stop() sets c.Sandbox to nil.
+ sb := c.Sandbox
+
// We must perform the following cleanup steps:
// * stop the container and gofer processes,
// * remove the container filesystem on the host, and
@@ -782,48 +732,43 @@ func (c *Container) Destroy() error {
// do our best to perform all of the cleanups. Hence, we keep a slice
// of errors return their concatenation.
var errs []string
-
- unlock, err := maybeLockRootContainer(c.Spec, c.RootContainerDir)
- if err != nil {
- return err
- }
- defer unlock()
-
- // Stored for later use as stop() sets c.Sandbox to nil.
- sb := c.Sandbox
-
if err := c.stop(); err != nil {
err = fmt.Errorf("stopping container: %v", err)
log.Warningf("%v", err)
errs = append(errs, err.Error())
}
- if err := os.RemoveAll(c.Root); err != nil && !os.IsNotExist(err) {
- err = fmt.Errorf("deleting container root directory %q: %v", c.Root, err)
+ if err := c.Saver.destroy(); err != nil {
+ err = fmt.Errorf("deleting container state files: %v", err)
log.Warningf("%v", err)
errs = append(errs, err.Error())
}
c.changeStatus(Stopped)
- // Adjust oom_score_adj for the sandbox. This must be done after the
- // container is stopped and the directory at c.Root is removed.
- // We must test if the sandbox is nil because Destroy should be
- // idempotent.
- if sb != nil {
- if err := adjustSandboxOOMScoreAdj(sb, c.RootContainerDir, true); err != nil {
+ // Adjust oom_score_adj for the sandbox. This must be done after the container
+ // is stopped and the directory at c.Root is removed. Adjustment can be
+ // skipped if the root container is exiting, because it brings down the entire
+ // sandbox.
+ //
+ // Use 'sb' to tell whether it has been executed before because Destroy must
+ // be idempotent.
+ if sb != nil && !isRoot(c.Spec) {
+ if err := adjustSandboxOOMScoreAdj(sb, c.Saver.RootDir, true); err != nil {
errs = append(errs, err.Error())
}
}
// "If any poststop hook fails, the runtime MUST log a warning, but the
- // remaining hooks and lifecycle continue as if the hook had succeeded" -OCI spec.
- // Based on the OCI, "The post-stop hooks MUST be called after the container is
- // deleted but before the delete operation returns"
+ // remaining hooks and lifecycle continue as if the hook had
+ // succeeded" - OCI spec.
+ //
+ // Based on the OCI, "The post-stop hooks MUST be called after the container
+ // is deleted but before the delete operation returns"
// Run it here to:
// 1) Conform to the OCI.
- // 2) Make sure it only runs once, because the root has been deleted, the container
- // can't be loaded again.
+ // 2) Make sure it only runs once, because the root has been deleted, the
+ // container can't be loaded again.
if c.Spec.Hooks != nil {
executeHooksBestEffort(c.Spec.Hooks.Poststop, c.State())
}
@@ -834,18 +779,13 @@ func (c *Container) Destroy() error {
return fmt.Errorf(strings.Join(errs, "\n"))
}
-// save saves the container metadata to a file.
+// saveLocked saves the container metadata to a file.
//
// Precondition: container must be locked with container.lock().
-func (c *Container) save() error {
+func (c *Container) saveLocked() error {
log.Debugf("Save container %q", c.ID)
- metaFile := filepath.Join(c.Root, metadataFilename)
- meta, err := json.Marshal(c)
- if err != nil {
- return fmt.Errorf("invalid container metadata: %v", err)
- }
- if err := ioutil.WriteFile(metaFile, meta, 0640); err != nil {
- return fmt.Errorf("writing container metadata: %v", err)
+ if err := c.Saver.saveLocked(c); err != nil {
+ return fmt.Errorf("saving container metadata: %v", err)
}
return nil
}
@@ -1106,48 +1046,6 @@ func (c *Container) requireStatus(action string, statuses ...Status) error {
return fmt.Errorf("cannot %s container %q in state %s", action, c.ID, c.Status)
}
-// lock takes a file lock on the container metadata lock file.
-func (c *Container) lock() (func() error, error) {
- return lockContainerMetadata(filepath.Join(c.Root, c.ID))
-}
-
-// lockContainerMetadata takes a file lock on the metadata lock file in the
-// given container root directory.
-func lockContainerMetadata(containerRootDir string) (func() error, error) {
- if err := os.MkdirAll(containerRootDir, 0711); err != nil {
- return nil, fmt.Errorf("creating container root directory %q: %v", containerRootDir, err)
- }
- f := filepath.Join(containerRootDir, metadataLockFilename)
- l := flock.NewFlock(f)
- if err := l.Lock(); err != nil {
- return nil, fmt.Errorf("acquiring lock on container lock file %q: %v", f, err)
- }
- return l.Unlock, nil
-}
-
-// maybeLockRootContainer locks the sandbox root container. It is used to
-// prevent races to create and delete child container sandboxes.
-func maybeLockRootContainer(spec *specs.Spec, rootDir string) (func() error, error) {
- if isRoot(spec) {
- return func() error { return nil }, nil
- }
-
- sbid, ok := specutils.SandboxID(spec)
- if !ok {
- return nil, fmt.Errorf("no sandbox ID found when locking root container")
- }
- sb, err := Load(rootDir, sbid)
- if err != nil {
- return nil, err
- }
-
- unlock, err := sb.lock()
- if err != nil {
- return nil, err
- }
- return unlock, nil
-}
-
func isRoot(spec *specs.Spec) bool {
return specutils.SpecContainerType(spec) != specutils.ContainerTypeContainer
}
@@ -1170,7 +1068,12 @@ func runInCgroup(cg *cgroup.Cgroup, fn func() error) error {
func (c *Container) adjustGoferOOMScoreAdj() error {
if c.GoferPid != 0 && c.Spec.Process.OOMScoreAdj != nil {
if err := setOOMScoreAdj(c.GoferPid, *c.Spec.Process.OOMScoreAdj); err != nil {
- return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
+ // Ignore NotExist error because it can be returned when the sandbox
+ // exited while OOM score was being adjusted.
+ if !os.IsNotExist(err) {
+ return fmt.Errorf("setting gofer oom_score_adj for container %q: %v", c.ID, err)
+ }
+ log.Warningf("Gofer process (%d) not found setting oom_score_adj", c.GoferPid)
}
}
@@ -1252,7 +1155,12 @@ func adjustSandboxOOMScoreAdj(s *sandbox.Sandbox, rootDir string, destroy bool)
// Set the lowest of all containers oom_score_adj to the sandbox.
if err := setOOMScoreAdj(s.Pid, lowScore); err != nil {
- return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err)
+ // Ignore NotExist error because it can be returned when the sandbox
+ // exited while OOM score was being adjusted.
+ if !os.IsNotExist(err) {
+ return fmt.Errorf("setting oom_score_adj for sandbox %q: %v", s.ID, err)
+ }
+ log.Warningf("Sandbox process (%d) not found setting oom_score_adj", s.Pid)
}
return nil
diff --git a/runsc/container/state_file.go b/runsc/container/state_file.go
new file mode 100644
index 000000000..d95151ea5
--- /dev/null
+++ b/runsc/container/state_file.go
@@ -0,0 +1,185 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package container
+
+import (
+ "encoding/json"
+ "fmt"
+ "io/ioutil"
+ "os"
+ "path/filepath"
+ "sync"
+
+ "github.com/gofrs/flock"
+ "gvisor.dev/gvisor/pkg/log"
+)
+
+const stateFileExtension = ".state"
+
+// StateFile handles load from/save to container state safely from multiple
+// processes. It uses a lock file to provide synchronization between operations.
+//
+// The lock file is located at: "${s.RootDir}/${s.ID}.lock".
+// The state file is located at: "${s.RootDir}/${s.ID}.state".
+type StateFile struct {
+ // RootDir is the directory containing the container metadata file.
+ RootDir string `json:"rootDir"`
+
+ // ID is the container ID.
+ ID string `json:"id"`
+
+ //
+ // Fields below this line are not saved in the state file and will not
+ // be preserved across commands.
+ //
+
+ once sync.Once
+ flock *flock.Flock
+}
+
+// List returns all container ids in the given root directory.
+func List(rootDir string) ([]string, error) {
+ log.Debugf("List containers %q", rootDir)
+ list, err := filepath.Glob(filepath.Join(rootDir, "*"+stateFileExtension))
+ if err != nil {
+ return nil, err
+ }
+ var out []string
+ for _, path := range list {
+ // Filter out files that do no belong to a container.
+ fileName := filepath.Base(path)
+ if len(fileName) < len(stateFileExtension) {
+ panic(fmt.Sprintf("invalid file match %q", path))
+ }
+ // Remove the extension.
+ cid := fileName[:len(fileName)-len(stateFileExtension)]
+ if validateID(cid) == nil {
+ out = append(out, cid)
+ }
+ }
+ return out, nil
+}
+
+// lock globally locks all locking operations for the container.
+func (s *StateFile) lock() error {
+ s.once.Do(func() {
+ s.flock = flock.NewFlock(s.lockPath())
+ })
+
+ if err := s.flock.Lock(); err != nil {
+ return fmt.Errorf("acquiring lock on %q: %v", s.flock, err)
+ }
+ return nil
+}
+
+// lockForNew acquires the lock and checks if the state file doesn't exist. This
+// is done to ensure that more than one creation didn't race to create
+// containers with the same ID.
+func (s *StateFile) lockForNew() error {
+ if err := s.lock(); err != nil {
+ return err
+ }
+
+ // Checks if the container already exists by looking for the metadata file.
+ if _, err := os.Stat(s.statePath()); err == nil {
+ s.unlock()
+ return fmt.Errorf("container already exists")
+ } else if !os.IsNotExist(err) {
+ s.unlock()
+ return fmt.Errorf("looking for existing container: %v", err)
+ }
+ return nil
+}
+
+// unlock globally unlocks all locking operations for the container.
+func (s *StateFile) unlock() error {
+ if !s.flock.Locked() {
+ panic("unlock called without lock held")
+ }
+
+ if err := s.flock.Unlock(); err != nil {
+ log.Warningf("Error to release lock on %q: %v", s.flock, err)
+ return fmt.Errorf("releasing lock on %q: %v", s.flock, err)
+ }
+ return nil
+}
+
+// saveLocked saves 'v' to the state file.
+//
+// Preconditions: lock() must been called before.
+func (s *StateFile) saveLocked(v interface{}) error {
+ if !s.flock.Locked() {
+ panic("saveLocked called without lock held")
+ }
+
+ meta, err := json.Marshal(v)
+ if err != nil {
+ return err
+ }
+ if err := ioutil.WriteFile(s.statePath(), meta, 0640); err != nil {
+ return fmt.Errorf("writing json file: %v", err)
+ }
+ return nil
+}
+
+func (s *StateFile) load(v interface{}) error {
+ if err := s.lock(); err != nil {
+ return err
+ }
+ defer s.unlock()
+
+ metaBytes, err := ioutil.ReadFile(s.statePath())
+ if err != nil {
+ return err
+ }
+ return json.Unmarshal(metaBytes, &v)
+}
+
+func (s *StateFile) close() error {
+ if s.flock == nil {
+ return nil
+ }
+ if s.flock.Locked() {
+ panic("Closing locked file")
+ }
+ return s.flock.Close()
+}
+
+func buildStatePath(rootDir, id string) string {
+ return filepath.Join(rootDir, id+stateFileExtension)
+}
+
+// statePath is the full path to the state file.
+func (s *StateFile) statePath() string {
+ return buildStatePath(s.RootDir, s.ID)
+}
+
+// lockPath is the full path to the lock file.
+func (s *StateFile) lockPath() string {
+ return filepath.Join(s.RootDir, s.ID+".lock")
+}
+
+// destroy deletes all state created by the stateFile. It may be called with the
+// lock file held. In that case, the lock file must still be unlocked and
+// properly closed after destroy returns.
+func (s *StateFile) destroy() error {
+ if err := os.Remove(s.statePath()); err != nil && !os.IsNotExist(err) {
+ return err
+ }
+ if err := os.Remove(s.lockPath()); err != nil && !os.IsNotExist(err) {
+ return err
+ }
+ return nil
+}
diff --git a/runsc/main.go b/runsc/main.go
index ae906c661..711f60d4f 100644
--- a/runsc/main.go
+++ b/runsc/main.go
@@ -46,6 +46,8 @@ var (
logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s.")
debug = flag.Bool("debug", false, "enable debug logging.")
showVersion = flag.Bool("version", false, "show version and exit.")
+ // TODO(gvisor.dev/issue/193): support systemd cgroups
+ systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.")
// These flags are unique to runsc, and are used to configure parts of the
// system that are not covered by the runtime spec.
@@ -136,6 +138,12 @@ func main() {
os.Exit(0)
}
+ // TODO(gvisor.dev/issue/193): support systemd cgroups
+ if *systemdCgroup {
+ fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193")
+ os.Exit(1)
+ }
+
var errorLogger io.Writer
if *logFD > -1 {
errorLogger = os.NewFile(uintptr(*logFD), "error log file")
diff --git a/scripts/runtime_tests.sh b/scripts/runtime_tests.sh
new file mode 100755
index 000000000..9ee991e42
--- /dev/null
+++ b/scripts/runtime_tests.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+# Copyright 2019 The gVisor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+source $(dirname $0)/common.sh
+
+if [ ! -v RUNTIME_TEST_NAME ]; then
+ echo 'Must set $RUNTIME_TEST_NAME' >&2
+ exit 1
+fi
+
+install_runsc_for_test runtimes
+test_runsc "//test/runtimes:${RUNTIME_TEST_NAME}_test"
diff --git a/test/runtimes/BUILD b/test/runtimes/BUILD
index 2e125525b..367295206 100644
--- a/test/runtimes/BUILD
+++ b/test/runtimes/BUILD
@@ -16,32 +16,32 @@ go_binary(
)
runtime_test(
+ name = "go1.12",
blacklist_file = "blacklist_go1.12.csv",
- image = "gcr.io/gvisor-presubmit/go1.12",
lang = "go",
)
runtime_test(
+ name = "java11",
blacklist_file = "blacklist_java11.csv",
- image = "gcr.io/gvisor-presubmit/java11",
lang = "java",
)
runtime_test(
+ name = "nodejs12.4.0",
blacklist_file = "blacklist_nodejs12.4.0.csv",
- image = "gcr.io/gvisor-presubmit/nodejs12.4.0",
lang = "nodejs",
)
runtime_test(
+ name = "php7.3.6",
blacklist_file = "blacklist_php7.3.6.csv",
- image = "gcr.io/gvisor-presubmit/php7.3.6",
lang = "php",
)
runtime_test(
+ name = "python3.7.3",
blacklist_file = "blacklist_python3.7.3.csv",
- image = "gcr.io/gvisor-presubmit/python3.7.3",
lang = "python",
)
diff --git a/test/runtimes/build_defs.bzl b/test/runtimes/build_defs.bzl
index 7c11624b4..6f84ca852 100644
--- a/test/runtimes/build_defs.bzl
+++ b/test/runtimes/build_defs.bzl
@@ -2,32 +2,48 @@
load("@io_bazel_rules_go//go:def.bzl", "go_test")
-# runtime_test is a macro that will create targets to run the given test target
-# with different runtime options.
def runtime_test(
+ name,
lang,
- image,
+ image_repo = "gcr.io/gvisor-presubmit",
+ image_name = None,
+ blacklist_file = None,
shard_count = 50,
- size = "enormous",
- blacklist_file = ""):
+ size = "enormous"):
+ """Generates sh_test and blacklist test targets for a given runtime.
+
+ Args:
+ name: The name of the runtime being tested. Typically, the lang + version.
+ This is used in the names of the generated test targets.
+ lang: The language being tested.
+ image_repo: The docker repository containing the proctor image to run.
+ i.e., the prefix to the fully qualified docker image id.
+ image_name: The name of the image in the image_repo.
+ Defaults to the test name.
+ blacklist_file: A test blacklist to pass to the runtime test's runner.
+ shard_count: See Bazel common test attributes.
+ size: See Bazel common test attributes.
+ """
+ if image_name == None:
+ image_name = name
args = [
"--lang",
lang,
"--image",
- image,
+ "/".join([image_repo, image_name]),
]
data = [
":runner",
]
- if blacklist_file != "":
+ if blacklist_file:
args += ["--blacklist_file", "test/runtimes/" + blacklist_file]
data += [blacklist_file]
# Add a test that the blacklist parses correctly.
- blacklist_test(lang, blacklist_file)
+ blacklist_test(name, blacklist_file)
sh_test(
- name = lang + "_test",
+ name = name + "_test",
srcs = ["runner.sh"],
args = args,
data = data,
@@ -35,15 +51,16 @@ def runtime_test(
shard_count = shard_count,
tags = [
# Requires docker and runsc to be configured before the test runs.
- "manual",
"local",
+ # Don't include test target in wildcard target patterns.
+ "manual",
],
)
-def blacklist_test(lang, blacklist_file):
+def blacklist_test(name, blacklist_file):
"""Test that a blacklist parses correctly."""
go_test(
- name = lang + "_blacklist_test",
+ name = name + "_blacklist_test",
embed = [":runner"],
srcs = ["blacklist_test.go"],
args = ["--blacklist_file", "test/runtimes/" + blacklist_file],
diff --git a/test/syscalls/linux/semaphore.cc b/test/syscalls/linux/semaphore.cc
index 40c57f543..e9b131ca9 100644
--- a/test/syscalls/linux/semaphore.cc
+++ b/test/syscalls/linux/semaphore.cc
@@ -447,9 +447,8 @@ TEST(SemaphoreTest, SemCtlGetPidFork) {
const pid_t child_pid = fork();
if (child_pid == 0) {
- ASSERT_THAT(semctl(sem.get(), 0, SETVAL, 1), SyscallSucceeds());
- ASSERT_THAT(semctl(sem.get(), 0, GETPID),
- SyscallSucceedsWithValue(getpid()));
+ TEST_PCHECK(semctl(sem.get(), 0, SETVAL, 1) == 0);
+ TEST_PCHECK(semctl(sem.get(), 0, GETPID) == getpid());
_exit(0);
}
diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc
index 322ee07ad..ab375aaaf 100644
--- a/test/syscalls/linux/socket_inet_loopback.cc
+++ b/test/syscalls/linux/socket_inet_loopback.cc
@@ -16,6 +16,7 @@
#include <netinet/in.h>
#include <poll.h>
#include <string.h>
+#include <sys/epoll.h>
#include <sys/socket.h>
#include <atomic>
@@ -516,6 +517,112 @@ TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThread) {
EquivalentWithin((kConnectAttempts / kThreadCount), 0.10));
}
+TEST_P(SocketInetReusePortTest, UdpPortReuseMultiThreadShort) {
+ auto const& param = GetParam();
+
+ TestAddress const& listener = param.listener;
+ TestAddress const& connector = param.connector;
+ sockaddr_storage listen_addr = listener.addr;
+ sockaddr_storage conn_addr = connector.addr;
+ constexpr int kThreadCount = 3;
+
+ // TODO(b/141211329): endpointsByNic.seed has to be saved/restored.
+ const DisableSave ds141211329;
+
+ // Create listening sockets.
+ FileDescriptor listener_fds[kThreadCount];
+ for (int i = 0; i < kThreadCount; i++) {
+ listener_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0));
+ int fd = listener_fds[i].get();
+
+ ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+ ASSERT_THAT(
+ bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len),
+ SyscallSucceeds());
+
+ // On the first bind we need to determine which port was bound.
+ if (i != 0) {
+ continue;
+ }
+
+ // Get the port bound by the listening socket.
+ socklen_t addrlen = listener.addr_len;
+ ASSERT_THAT(
+ getsockname(listener_fds[0].get(),
+ reinterpret_cast<sockaddr*>(&listen_addr), &addrlen),
+ SyscallSucceeds());
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr));
+ ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port));
+ ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port));
+ }
+
+ constexpr int kConnectAttempts = 10;
+ FileDescriptor client_fds[kConnectAttempts];
+
+ // Do the first run without save/restore.
+ DisableSave ds;
+ for (int i = 0; i < kConnectAttempts; i++) {
+ client_fds[i] =
+ ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0));
+ EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+ ds.reset();
+
+ // Check that a mapping of client and server sockets has
+ // not been change after save/restore.
+ for (int i = 0; i < kConnectAttempts; i++) {
+ EXPECT_THAT(RetryEINTR(sendto)(client_fds[i].get(), &i, sizeof(i), 0,
+ reinterpret_cast<sockaddr*>(&conn_addr),
+ connector.addr_len),
+ SyscallSucceedsWithValue(sizeof(i)));
+ }
+
+ int epollfd;
+ ASSERT_THAT(epollfd = epoll_create1(0), SyscallSucceeds());
+
+ for (int i = 0; i < kThreadCount; i++) {
+ int fd = listener_fds[i].get();
+ struct epoll_event ev;
+ ev.data.fd = fd;
+ ev.events = EPOLLIN;
+ ASSERT_THAT(epoll_ctl(epollfd, EPOLL_CTL_ADD, fd, &ev), SyscallSucceeds());
+ }
+
+ std::map<uint16_t, int> portToFD;
+
+ for (int i = 0; i < kConnectAttempts * 2; i++) {
+ struct sockaddr_storage addr = {};
+ socklen_t addrlen = sizeof(addr);
+ struct epoll_event ev;
+ int data, fd;
+
+ ASSERT_THAT(epoll_wait(epollfd, &ev, 1, -1), SyscallSucceedsWithValue(1));
+
+ fd = ev.data.fd;
+ EXPECT_THAT(RetryEINTR(recvfrom)(fd, &data, sizeof(data), 0,
+ reinterpret_cast<struct sockaddr*>(&addr),
+ &addrlen),
+ SyscallSucceedsWithValue(sizeof(data)));
+ uint16_t const port =
+ ASSERT_NO_ERRNO_AND_VALUE(AddrPort(connector.family(), addr));
+ auto prev_port = portToFD.find(port);
+ // Check that all packets from one client have been delivered to the same
+ // server socket.
+ if (prev_port == portToFD.end()) {
+ portToFD[port] = ev.data.fd;
+ } else {
+ EXPECT_EQ(portToFD[port], ev.data.fd);
+ }
+ }
+}
+
INSTANTIATE_TEST_SUITE_P(
All, SocketInetReusePortTest,
::testing::Values(
diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc
index 7e0deda05..592448289 100644
--- a/test/syscalls/linux/socket_ip_tcp_generic.cc
+++ b/test/syscalls/linux/socket_ip_tcp_generic.cc
@@ -30,7 +30,7 @@
namespace gvisor {
namespace testing {
-TEST_P(TCPSocketPairTest, TcpInfoSucceedes) {
+TEST_P(TCPSocketPairTest, TcpInfoSucceeds) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct tcp_info opt = {};
@@ -39,7 +39,7 @@ TEST_P(TCPSocketPairTest, TcpInfoSucceedes) {
SyscallSucceeds());
}
-TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) {
+TEST_P(TCPSocketPairTest, ShortTcpInfoSucceeds) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct tcp_info opt = {};
@@ -48,7 +48,7 @@ TEST_P(TCPSocketPairTest, ShortTcpInfoSucceedes) {
SyscallSucceeds());
}
-TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceedes) {
+TEST_P(TCPSocketPairTest, ZeroTcpInfoSucceeds) {
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
struct tcp_info opt = {};
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
index dd4a11655..be0dadcd6 100644
--- a/test/syscalls/linux/socket_netlink_route.cc
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -195,7 +195,8 @@ INSTANTIATE_TEST_SUITE_P(
std::make_tuple(SO_DOMAIN, IsEqual(AF_NETLINK),
absl::StrFormat("AF_NETLINK (%d)", AF_NETLINK)),
std::make_tuple(SO_PROTOCOL, IsEqual(NETLINK_ROUTE),
- absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE))));
+ absl::StrFormat("NETLINK_ROUTE (%d)", NETLINK_ROUTE)),
+ std::make_tuple(SO_PASSCRED, IsEqual(0), "0")));
// Validates the reponses to RTM_GETLINK + NLM_F_DUMP.
void CheckGetLinkResponse(const struct nlmsghdr* hdr, int seq, int port) {
@@ -692,6 +693,113 @@ TEST(NetlinkRouteTest, RecvmsgTruncPeek) {
} while (type != NLMSG_DONE && type != NLMSG_ERROR);
}
+// No SCM_CREDENTIALS are received without SO_PASSCRED set.
+TEST(NetlinkRouteTest, NoPasscredNoCreds) {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+
+ ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceeds());
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ constexpr uint32_t kSeq = 12345;
+
+ struct request req;
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ iov.iov_base = NULL;
+ iov.iov_len = 0;
+
+ char control[CMSG_SPACE(sizeof(struct ucred))] = {};
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ // Note: This test assumes at least one message is returned by the
+ // RTM_GETADDR request.
+ ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ // No control messages.
+ EXPECT_EQ(CMSG_FIRSTHDR(&msg), nullptr);
+}
+
+// SCM_CREDENTIALS are received with SO_PASSCRED set.
+TEST(NetlinkRouteTest, PasscredCreds) {
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket());
+
+ ASSERT_THAT(setsockopt(fd.get(), SOL_SOCKET, SO_PASSCRED, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct rtgenmsg rgm;
+ };
+
+ constexpr uint32_t kSeq = 12345;
+
+ struct request req;
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETADDR;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
+ req.hdr.nlmsg_seq = kSeq;
+ req.rgm.rtgen_family = AF_UNSPEC;
+
+ struct iovec iov = {};
+ iov.iov_base = &req;
+ iov.iov_len = sizeof(req);
+
+ struct msghdr msg = {};
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ iov.iov_base = NULL;
+ iov.iov_len = 0;
+
+ char control[CMSG_SPACE(sizeof(struct ucred))] = {};
+ msg.msg_control = control;
+ msg.msg_controllen = sizeof(control);
+
+ // Note: This test assumes at least one message is returned by the
+ // RTM_GETADDR request.
+ ASSERT_THAT(RetryEINTR(recvmsg)(fd.get(), &msg, 0), SyscallSucceeds());
+
+ struct ucred creds;
+ struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
+ ASSERT_NE(cmsg, nullptr);
+ ASSERT_EQ(cmsg->cmsg_len, CMSG_LEN(sizeof(creds)));
+ ASSERT_EQ(cmsg->cmsg_level, SOL_SOCKET);
+ ASSERT_EQ(cmsg->cmsg_type, SCM_CREDENTIALS);
+
+ memcpy(&creds, CMSG_DATA(cmsg), sizeof(creds));
+
+ // The peer is the kernel, which is "PID" 0.
+ EXPECT_EQ(creds.pid, 0);
+ // The kernel identifies as root. Also allow nobody in case this test is
+ // running in a userns without root mapped.
+ EXPECT_THAT(creds.uid, AnyOf(Eq(0), Eq(65534)));
+ EXPECT_THAT(creds.gid, AnyOf(Eq(0), Eq(65534)));
+}
+
} // namespace
} // namespace testing
diff --git a/tools/go_branch.sh b/tools/go_branch.sh
index ddb9b6e7b..0ac16e266 100755
--- a/tools/go_branch.sh
+++ b/tools/go_branch.sh
@@ -17,9 +17,9 @@
set -eo pipefail
# Discovery the package name from the go.mod file.
-declare -r gomod="$(pwd)/go.mod"
-declare -r module=$(cat "${gomod}" | grep -E "^module" | cut -d' ' -f2)
-declare -r gosum="$(pwd)/go.sum"
+declare -r module=$(cat go.mod | grep -E "^module" | cut -d' ' -f2)
+declare -r origpwd=$(pwd)
+declare -r othersrc=("go.mod" "go.sum" "AUTHORS" "LICENSE")
# Check that gopath has been built.
declare -r gopath_dir="$(pwd)/bazel-bin/gopath/src/${module}"
@@ -65,10 +65,22 @@ git checkout -b go "${go_branch}"
git merge --no-commit --strategy ours ${head} || \
git merge --allow-unrelated-histories --no-commit --strategy ours ${head}
-# Sync the entire gopath_dir and go.mod.
-rsync --recursive --verbose --delete --exclude .git --exclude README.md -L "${gopath_dir}/" .
-cp "${gomod}" .
-cp "${gosum}" .
+# Sync the entire gopath_dir.
+rsync --recursive --verbose --delete --exclude .git -L "${gopath_dir}/" .
+
+# Add additional files.
+for file in "${othersrc[@]}"; do
+ cp "${origpwd}"/"${file}" .
+done
+
+# Construct a new README.md.
+cat > README.md <<EOF
+# gVisor
+
+This branch is a synthetic branch, containing only Go sources, that is
+compatible with standard Go tools. See the `master` branch for authoritative
+sources and tests.
+EOF
# There are a few solitary files that can get left behind due to the way bazel
# constructs the gopath target. Note that we don't find all Go files here