summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/kernel/kernel.go83
-rw-r--r--pkg/sentry/kernel/kernel_state_autogen.go103
-rw-r--r--pkg/sentry/kernel/socket_list.go32
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go7
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go7
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go7
-rw-r--r--pkg/sentry/socket/unix/socket_refs.go24
-rw-r--r--pkg/sentry/socket/unix/socket_vfs2_refs.go118
-rw-r--r--pkg/sentry/socket/unix/unix.go36
-rw-r--r--pkg/sentry/socket/unix/unix_state_autogen.go77
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go20
11 files changed, 381 insertions, 133 deletions
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 08bb5bd12..d6c21adb7 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -220,13 +220,18 @@ type Kernel struct {
// danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
- // sockets is the list of all network sockets the system. Protected by
- // extMu.
+ // sockets is the list of all network sockets in the system.
+ // Protected by extMu.
+ // TODO(gvisor.dev/issue/1624): Only used by VFS1.
sockets socketList
- // nextSocketEntry is the next entry number to use in sockets. Protected
+ // socketsVFS2 records all network sockets in the system. Protected by
+ // extMu.
+ socketsVFS2 map[*vfs.FileDescription]*SocketRecord
+
+ // nextSocketRecord is the next entry number to use in sockets. Protected
// by extMu.
- nextSocketEntry uint64
+ nextSocketRecord uint64
// deviceRegistry is used to save/restore device.SimpleDevices.
deviceRegistry struct{} `state:".(*device.Registry)"`
@@ -414,6 +419,8 @@ func (k *Kernel) Init(args InitKernelArgs) error {
return fmt.Errorf("failed to create sockfs mount: %v", err)
}
k.socketMount = socketMount
+
+ k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord)
}
return nil
@@ -1512,20 +1519,27 @@ func (k *Kernel) SupervisorContext() context.Context {
}
}
-// SocketEntry represents a socket recorded in Kernel.sockets. It implements
+// SocketRecord represents a socket recorded in Kernel.socketsVFS2.
+//
+// +stateify savable
+type SocketRecord struct {
+ k *Kernel
+ Sock *refs.WeakRef // TODO(gvisor.dev/issue/1624): Only used by VFS1.
+ SockVFS2 *vfs.FileDescription // Only used by VFS2.
+ ID uint64 // Socket table entry number.
+}
+
+// SocketRecordVFS1 represents a socket recorded in Kernel.sockets. It implements
// refs.WeakRefUser for sockets stored in the socket table.
//
// +stateify savable
-type SocketEntry struct {
+type SocketRecordVFS1 struct {
socketEntry
- k *Kernel
- Sock *refs.WeakRef
- SockVFS2 *vfs.FileDescription
- ID uint64 // Socket table entry number.
+ SocketRecord
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (s *SocketEntry) WeakRefGone(context.Context) {
+func (s *SocketRecordVFS1) WeakRefGone(context.Context) {
s.k.extMu.Lock()
s.k.sockets.Remove(s)
s.k.extMu.Unlock()
@@ -1536,9 +1550,14 @@ func (s *SocketEntry) WeakRefGone(context.Context) {
// Precondition: Caller must hold a reference to sock.
func (k *Kernel) RecordSocket(sock *fs.File) {
k.extMu.Lock()
- id := k.nextSocketEntry
- k.nextSocketEntry++
- s := &SocketEntry{k: k, ID: id}
+ id := k.nextSocketRecord
+ k.nextSocketRecord++
+ s := &SocketRecordVFS1{
+ SocketRecord: SocketRecord{
+ k: k,
+ ID: id,
+ },
+ }
s.Sock = refs.NewWeakRef(sock, s)
k.sockets.PushBack(s)
k.extMu.Unlock()
@@ -1550,29 +1569,45 @@ func (k *Kernel) RecordSocket(sock *fs.File) {
// Precondition: Caller must hold a reference to sock.
//
// Note that the socket table will not hold a reference on the
-// vfs.FileDescription, because we do not support weak refs on VFS2 files.
+// vfs.FileDescription.
func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) {
k.extMu.Lock()
- id := k.nextSocketEntry
- k.nextSocketEntry++
- s := &SocketEntry{
+ if _, ok := k.socketsVFS2[sock]; ok {
+ panic(fmt.Sprintf("Socket %p added twice", sock))
+ }
+ id := k.nextSocketRecord
+ k.nextSocketRecord++
+ s := &SocketRecord{
k: k,
ID: id,
SockVFS2: sock,
}
- k.sockets.PushBack(s)
+ k.socketsVFS2[sock] = s
+ k.extMu.Unlock()
+}
+
+// DeleteSocketVFS2 removes a VFS2 socket from the system-wide socket table.
+func (k *Kernel) DeleteSocketVFS2(sock *vfs.FileDescription) {
+ k.extMu.Lock()
+ delete(k.socketsVFS2, sock)
k.extMu.Unlock()
}
// ListSockets returns a snapshot of all sockets.
//
-// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef()
+// Callers of ListSockets() in VFS2 should use SocketRecord.SockVFS2.TryIncRef()
// to get a reference on a socket in the table.
-func (k *Kernel) ListSockets() []*SocketEntry {
+func (k *Kernel) ListSockets() []*SocketRecord {
k.extMu.Lock()
- var socks []*SocketEntry
- for s := k.sockets.Front(); s != nil; s = s.Next() {
- socks = append(socks, s)
+ var socks []*SocketRecord
+ if VFS2Enabled {
+ for _, s := range k.socketsVFS2 {
+ socks = append(socks, s)
+ }
+ } else {
+ for s := k.sockets.Front(); s != nil; s = s.Next() {
+ socks = append(socks, &s.SocketRecord)
+ }
}
k.extMu.Unlock()
return socks
diff --git a/pkg/sentry/kernel/kernel_state_autogen.go b/pkg/sentry/kernel/kernel_state_autogen.go
index d0ff135d7..f20800960 100644
--- a/pkg/sentry/kernel/kernel_state_autogen.go
+++ b/pkg/sentry/kernel/kernel_state_autogen.go
@@ -297,7 +297,8 @@ func (x *Kernel) StateFields() []string {
"netlinkPorts",
"danglingEndpoints",
"sockets",
- "nextSocketEntry",
+ "socketsVFS2",
+ "nextSocketRecord",
"deviceRegistry",
"DirentCacheLimiter",
"SpecialOpts",
@@ -317,7 +318,7 @@ func (x *Kernel) StateSave(m state.Sink) {
var danglingEndpoints []tcpip.Endpoint = x.saveDanglingEndpoints()
m.SaveValue(24, danglingEndpoints)
var deviceRegistry *device.Registry = x.saveDeviceRegistry()
- m.SaveValue(27, deviceRegistry)
+ m.SaveValue(28, deviceRegistry)
m.Save(0, &x.featureSet)
m.Save(1, &x.timekeeper)
m.Save(2, &x.tasks)
@@ -343,15 +344,16 @@ func (x *Kernel) StateSave(m state.Sink) {
m.Save(22, &x.nextInotifyCookie)
m.Save(23, &x.netlinkPorts)
m.Save(25, &x.sockets)
- m.Save(26, &x.nextSocketEntry)
- m.Save(28, &x.DirentCacheLimiter)
- m.Save(29, &x.SpecialOpts)
- m.Save(30, &x.vfs)
- m.Save(31, &x.hostMount)
- m.Save(32, &x.pipeMount)
- m.Save(33, &x.shmMount)
- m.Save(34, &x.socketMount)
- m.Save(35, &x.SleepForAddressSpaceActivation)
+ m.Save(26, &x.socketsVFS2)
+ m.Save(27, &x.nextSocketRecord)
+ m.Save(29, &x.DirentCacheLimiter)
+ m.Save(30, &x.SpecialOpts)
+ m.Save(31, &x.vfs)
+ m.Save(32, &x.hostMount)
+ m.Save(33, &x.pipeMount)
+ m.Save(34, &x.shmMount)
+ m.Save(35, &x.socketMount)
+ m.Save(36, &x.SleepForAddressSpaceActivation)
}
func (x *Kernel) afterLoad() {}
@@ -382,26 +384,26 @@ func (x *Kernel) StateLoad(m state.Source) {
m.Load(22, &x.nextInotifyCookie)
m.Load(23, &x.netlinkPorts)
m.Load(25, &x.sockets)
- m.Load(26, &x.nextSocketEntry)
- m.Load(28, &x.DirentCacheLimiter)
- m.Load(29, &x.SpecialOpts)
- m.Load(30, &x.vfs)
- m.Load(31, &x.hostMount)
- m.Load(32, &x.pipeMount)
- m.Load(33, &x.shmMount)
- m.Load(34, &x.socketMount)
- m.Load(35, &x.SleepForAddressSpaceActivation)
+ m.Load(26, &x.socketsVFS2)
+ m.Load(27, &x.nextSocketRecord)
+ m.Load(29, &x.DirentCacheLimiter)
+ m.Load(30, &x.SpecialOpts)
+ m.Load(31, &x.vfs)
+ m.Load(32, &x.hostMount)
+ m.Load(33, &x.pipeMount)
+ m.Load(34, &x.shmMount)
+ m.Load(35, &x.socketMount)
+ m.Load(36, &x.SleepForAddressSpaceActivation)
m.LoadValue(24, new([]tcpip.Endpoint), func(y interface{}) { x.loadDanglingEndpoints(y.([]tcpip.Endpoint)) })
- m.LoadValue(27, new(*device.Registry), func(y interface{}) { x.loadDeviceRegistry(y.(*device.Registry)) })
+ m.LoadValue(28, new(*device.Registry), func(y interface{}) { x.loadDeviceRegistry(y.(*device.Registry)) })
}
-func (x *SocketEntry) StateTypeName() string {
- return "pkg/sentry/kernel.SocketEntry"
+func (x *SocketRecord) StateTypeName() string {
+ return "pkg/sentry/kernel.SocketRecord"
}
-func (x *SocketEntry) StateFields() []string {
+func (x *SocketRecord) StateFields() []string {
return []string{
- "socketEntry",
"k",
"Sock",
"SockVFS2",
@@ -409,25 +411,49 @@ func (x *SocketEntry) StateFields() []string {
}
}
-func (x *SocketEntry) beforeSave() {}
+func (x *SocketRecord) beforeSave() {}
-func (x *SocketEntry) StateSave(m state.Sink) {
+func (x *SocketRecord) StateSave(m state.Sink) {
+ x.beforeSave()
+ m.Save(0, &x.k)
+ m.Save(1, &x.Sock)
+ m.Save(2, &x.SockVFS2)
+ m.Save(3, &x.ID)
+}
+
+func (x *SocketRecord) afterLoad() {}
+
+func (x *SocketRecord) StateLoad(m state.Source) {
+ m.Load(0, &x.k)
+ m.Load(1, &x.Sock)
+ m.Load(2, &x.SockVFS2)
+ m.Load(3, &x.ID)
+}
+
+func (x *SocketRecordVFS1) StateTypeName() string {
+ return "pkg/sentry/kernel.SocketRecordVFS1"
+}
+
+func (x *SocketRecordVFS1) StateFields() []string {
+ return []string{
+ "socketEntry",
+ "SocketRecord",
+ }
+}
+
+func (x *SocketRecordVFS1) beforeSave() {}
+
+func (x *SocketRecordVFS1) StateSave(m state.Sink) {
x.beforeSave()
m.Save(0, &x.socketEntry)
- m.Save(1, &x.k)
- m.Save(2, &x.Sock)
- m.Save(3, &x.SockVFS2)
- m.Save(4, &x.ID)
+ m.Save(1, &x.SocketRecord)
}
-func (x *SocketEntry) afterLoad() {}
+func (x *SocketRecordVFS1) afterLoad() {}
-func (x *SocketEntry) StateLoad(m state.Source) {
+func (x *SocketRecordVFS1) StateLoad(m state.Source) {
m.Load(0, &x.socketEntry)
- m.Load(1, &x.k)
- m.Load(2, &x.Sock)
- m.Load(3, &x.SockVFS2)
- m.Load(4, &x.ID)
+ m.Load(1, &x.SocketRecord)
}
func (x *pendingSignals) StateTypeName() string {
@@ -2264,7 +2290,8 @@ func init() {
state.Register((*FSContextRefs)(nil))
state.Register((*IPCNamespace)(nil))
state.Register((*Kernel)(nil))
- state.Register((*SocketEntry)(nil))
+ state.Register((*SocketRecord)(nil))
+ state.Register((*SocketRecordVFS1)(nil))
state.Register((*pendingSignals)(nil))
state.Register((*pendingSignalQueue)(nil))
state.Register((*pendingSignal)(nil))
diff --git a/pkg/sentry/kernel/socket_list.go b/pkg/sentry/kernel/socket_list.go
index d2d4307a1..246fba405 100644
--- a/pkg/sentry/kernel/socket_list.go
+++ b/pkg/sentry/kernel/socket_list.go
@@ -13,7 +13,7 @@ type socketElementMapper struct{}
// This default implementation should be inlined.
//
//go:nosplit
-func (socketElementMapper) linkerFor(elem *SocketEntry) *SocketEntry { return elem }
+func (socketElementMapper) linkerFor(elem *SocketRecordVFS1) *SocketRecordVFS1 { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
@@ -27,8 +27,8 @@ func (socketElementMapper) linkerFor(elem *SocketEntry) *SocketEntry { return el
//
// +stateify savable
type socketList struct {
- head *SocketEntry
- tail *SocketEntry
+ head *SocketRecordVFS1
+ tail *SocketRecordVFS1
}
// Reset resets list l to the empty state.
@@ -43,12 +43,12 @@ func (l *socketList) Empty() bool {
}
// Front returns the first element of list l or nil.
-func (l *socketList) Front() *SocketEntry {
+func (l *socketList) Front() *SocketRecordVFS1 {
return l.head
}
// Back returns the last element of list l or nil.
-func (l *socketList) Back() *SocketEntry {
+func (l *socketList) Back() *SocketRecordVFS1 {
return l.tail
}
@@ -63,7 +63,7 @@ func (l *socketList) Len() (count int) {
}
// PushFront inserts the element e at the front of list l.
-func (l *socketList) PushFront(e *SocketEntry) {
+func (l *socketList) PushFront(e *SocketRecordVFS1) {
linker := socketElementMapper{}.linkerFor(e)
linker.SetNext(l.head)
linker.SetPrev(nil)
@@ -77,7 +77,7 @@ func (l *socketList) PushFront(e *SocketEntry) {
}
// PushBack inserts the element e at the back of list l.
-func (l *socketList) PushBack(e *SocketEntry) {
+func (l *socketList) PushBack(e *SocketRecordVFS1) {
linker := socketElementMapper{}.linkerFor(e)
linker.SetNext(nil)
linker.SetPrev(l.tail)
@@ -106,7 +106,7 @@ func (l *socketList) PushBackList(m *socketList) {
}
// InsertAfter inserts e after b.
-func (l *socketList) InsertAfter(b, e *SocketEntry) {
+func (l *socketList) InsertAfter(b, e *SocketRecordVFS1) {
bLinker := socketElementMapper{}.linkerFor(b)
eLinker := socketElementMapper{}.linkerFor(e)
@@ -124,7 +124,7 @@ func (l *socketList) InsertAfter(b, e *SocketEntry) {
}
// InsertBefore inserts e before a.
-func (l *socketList) InsertBefore(a, e *SocketEntry) {
+func (l *socketList) InsertBefore(a, e *SocketRecordVFS1) {
aLinker := socketElementMapper{}.linkerFor(a)
eLinker := socketElementMapper{}.linkerFor(e)
@@ -141,7 +141,7 @@ func (l *socketList) InsertBefore(a, e *SocketEntry) {
}
// Remove removes e from l.
-func (l *socketList) Remove(e *SocketEntry) {
+func (l *socketList) Remove(e *SocketRecordVFS1) {
linker := socketElementMapper{}.linkerFor(e)
prev := linker.Prev()
next := linker.Next()
@@ -168,26 +168,26 @@ func (l *socketList) Remove(e *SocketEntry) {
//
// +stateify savable
type socketEntry struct {
- next *SocketEntry
- prev *SocketEntry
+ next *SocketRecordVFS1
+ prev *SocketRecordVFS1
}
// Next returns the entry that follows e in the list.
-func (e *socketEntry) Next() *SocketEntry {
+func (e *socketEntry) Next() *SocketRecordVFS1 {
return e.next
}
// Prev returns the entry that precedes e in the list.
-func (e *socketEntry) Prev() *SocketEntry {
+func (e *socketEntry) Prev() *SocketRecordVFS1 {
return e.prev
}
// SetNext assigns 'entry' as the entry that follows e in the list.
-func (e *socketEntry) SetNext(elem *SocketEntry) {
+func (e *socketEntry) SetNext(elem *SocketRecordVFS1) {
e.next = elem
}
// SetPrev assigns 'entry' as the entry that precedes e in the list.
-func (e *socketEntry) SetPrev(elem *SocketEntry) {
+func (e *socketEntry) SetPrev(elem *SocketRecordVFS1) {
e.prev = elem
}
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 87b077e68..163af329b 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -78,6 +78,13 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in
return vfsfd, nil
}
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *socketVFS2) Release(ctx context.Context) {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.socketOpsCommon.Release(ctx)
+}
+
// Readiness implements waiter.Waitable.Readiness.
func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.socketOpsCommon.Readiness(mask)
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
index a38d25da9..c83b23242 100644
--- a/pkg/sentry/socket/netlink/socket_vfs2.go
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -82,6 +82,13 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV
return fd, nil
}
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *SocketVFS2) Release(ctx context.Context) {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.socketOpsCommon.Release(ctx)
+}
+
// Readiness implements waiter.Waitable.Readiness.
func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.socketOpsCommon.Readiness(mask)
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index c0212ad76..4c6791fff 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -79,6 +79,13 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu
return vfsfd, nil
}
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *SocketVFS2) Release(ctx context.Context) {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.socketOpsCommon.Release(ctx)
+}
+
// Readiness implements waiter.Waitable.Readiness.
func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.socketOpsCommon.Readiness(mask)
diff --git a/pkg/sentry/socket/unix/socket_refs.go b/pkg/sentry/socket/unix/socket_refs.go
index dababb85f..ea63dc659 100644
--- a/pkg/sentry/socket/unix/socket_refs.go
+++ b/pkg/sentry/socket/unix/socket_refs.go
@@ -11,7 +11,7 @@ import (
// ownerType is used to customize logging. Note that we use a pointer to T so
// that we do not copy the entire object when passed as a format parameter.
-var socketOpsCommonownerType *socketOpsCommon
+var socketOperationsownerType *SocketOperations
// Refs implements refs.RefCounter. It keeps a reference count using atomic
// operations and calls the destructor when the count reaches zero.
@@ -25,7 +25,7 @@ var socketOpsCommonownerType *socketOpsCommon
// without growing the size of Refs.
//
// +stateify savable
-type socketOpsCommonRefs struct {
+type socketOperationsRefs struct {
// refCount is composed of two fields:
//
// [32-bit speculative references]:[32-bit real references]
@@ -36,7 +36,7 @@ type socketOpsCommonRefs struct {
refCount int64
}
-func (r *socketOpsCommonRefs) finalize() {
+func (r *socketOperationsRefs) finalize() {
var note string
switch refs_vfs1.GetLeakMode() {
case refs_vfs1.NoLeakChecking:
@@ -45,20 +45,20 @@ func (r *socketOpsCommonRefs) finalize() {
note = "(Leak checker uninitialized): "
}
if n := r.ReadRefs(); n != 0 {
- log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, socketOpsCommonownerType, n)
+ log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, socketOperationsownerType, n)
}
}
// EnableLeakCheck checks for reference leaks when Refs gets garbage collected.
-func (r *socketOpsCommonRefs) EnableLeakCheck() {
+func (r *socketOperationsRefs) EnableLeakCheck() {
if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
- runtime.SetFinalizer(r, (*socketOpsCommonRefs).finalize)
+ runtime.SetFinalizer(r, (*socketOperationsRefs).finalize)
}
}
// ReadRefs returns the current number of references. The returned count is
// inherently racy and is unsafe to use without external synchronization.
-func (r *socketOpsCommonRefs) ReadRefs() int64 {
+func (r *socketOperationsRefs) ReadRefs() int64 {
return atomic.LoadInt64(&r.refCount) + 1
}
@@ -66,9 +66,9 @@ func (r *socketOpsCommonRefs) ReadRefs() int64 {
// IncRef implements refs.RefCounter.IncRef.
//
//go:nosplit
-func (r *socketOpsCommonRefs) IncRef() {
+func (r *socketOperationsRefs) IncRef() {
if v := atomic.AddInt64(&r.refCount, 1); v <= 0 {
- panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, socketOpsCommonownerType))
+ panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, socketOperationsownerType))
}
}
@@ -79,7 +79,7 @@ func (r *socketOpsCommonRefs) IncRef() {
// other TryIncRef calls from genuine references held.
//
//go:nosplit
-func (r *socketOpsCommonRefs) TryIncRef() bool {
+func (r *socketOperationsRefs) TryIncRef() bool {
const speculativeRef = 1 << 32
v := atomic.AddInt64(&r.refCount, speculativeRef)
if int32(v) < 0 {
@@ -104,10 +104,10 @@ func (r *socketOpsCommonRefs) TryIncRef() bool {
// A: TryIncRef [transform speculative to real]
//
//go:nosplit
-func (r *socketOpsCommonRefs) DecRef(destroy func()) {
+func (r *socketOperationsRefs) DecRef(destroy func()) {
switch v := atomic.AddInt64(&r.refCount, -1); {
case v < -1:
- panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, socketOpsCommonownerType))
+ panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, socketOperationsownerType))
case v == -1:
diff --git a/pkg/sentry/socket/unix/socket_vfs2_refs.go b/pkg/sentry/socket/unix/socket_vfs2_refs.go
new file mode 100644
index 000000000..dc55f2947
--- /dev/null
+++ b/pkg/sentry/socket/unix/socket_vfs2_refs.go
@@ -0,0 +1,118 @@
+package unix
+
+import (
+ "fmt"
+ "runtime"
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/log"
+ refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+)
+
+// ownerType is used to customize logging. Note that we use a pointer to T so
+// that we do not copy the entire object when passed as a format parameter.
+var socketVFS2ownerType *SocketVFS2
+
+// Refs implements refs.RefCounter. It keeps a reference count using atomic
+// operations and calls the destructor when the count reaches zero.
+//
+// Note that the number of references is actually refCount + 1 so that a default
+// zero-value Refs object contains one reference.
+//
+// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in
+// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount.
+// This will allow us to add stack trace information to the leak messages
+// without growing the size of Refs.
+//
+// +stateify savable
+type socketVFS2Refs struct {
+ // refCount is composed of two fields:
+ //
+ // [32-bit speculative references]:[32-bit real references]
+ //
+ // Speculative references are used for TryIncRef, to avoid a CompareAndSwap
+ // loop. See IncRef, DecRef and TryIncRef for details of how these fields are
+ // used.
+ refCount int64
+}
+
+func (r *socketVFS2Refs) finalize() {
+ var note string
+ switch refs_vfs1.GetLeakMode() {
+ case refs_vfs1.NoLeakChecking:
+ return
+ case refs_vfs1.UninitializedLeakChecking:
+ note = "(Leak checker uninitialized): "
+ }
+ if n := r.ReadRefs(); n != 0 {
+ log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, socketVFS2ownerType, n)
+ }
+}
+
+// EnableLeakCheck checks for reference leaks when Refs gets garbage collected.
+func (r *socketVFS2Refs) EnableLeakCheck() {
+ if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
+ runtime.SetFinalizer(r, (*socketVFS2Refs).finalize)
+ }
+}
+
+// ReadRefs returns the current number of references. The returned count is
+// inherently racy and is unsafe to use without external synchronization.
+func (r *socketVFS2Refs) ReadRefs() int64 {
+
+ return atomic.LoadInt64(&r.refCount) + 1
+}
+
+// IncRef implements refs.RefCounter.IncRef.
+//
+//go:nosplit
+func (r *socketVFS2Refs) IncRef() {
+ if v := atomic.AddInt64(&r.refCount, 1); v <= 0 {
+ panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, socketVFS2ownerType))
+ }
+}
+
+// TryIncRef implements refs.RefCounter.TryIncRef.
+//
+// To do this safely without a loop, a speculative reference is first acquired
+// on the object. This allows multiple concurrent TryIncRef calls to distinguish
+// other TryIncRef calls from genuine references held.
+//
+//go:nosplit
+func (r *socketVFS2Refs) TryIncRef() bool {
+ const speculativeRef = 1 << 32
+ v := atomic.AddInt64(&r.refCount, speculativeRef)
+ if int32(v) < 0 {
+
+ atomic.AddInt64(&r.refCount, -speculativeRef)
+ return false
+ }
+
+ atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ return true
+}
+
+// DecRef implements refs.RefCounter.DecRef.
+//
+// Note that speculative references are counted here. Since they were added
+// prior to real references reaching zero, they will successfully convert to
+// real references. In other words, we see speculative references only in the
+// following case:
+//
+// A: TryIncRef [speculative increase => sees non-negative references]
+// B: DecRef [real decrease]
+// A: TryIncRef [transform speculative to real]
+//
+//go:nosplit
+func (r *socketVFS2Refs) DecRef(destroy func()) {
+ switch v := atomic.AddInt64(&r.refCount, -1); {
+ case v < -1:
+ panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, socketVFS2ownerType))
+
+ case v == -1:
+
+ if destroy != nil {
+ destroy()
+ }
+ }
+}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 917055cea..f80011ce4 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -55,6 +55,7 @@ type SocketOperations struct {
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ socketOperationsRefs
socketOpsCommon
}
@@ -84,11 +85,27 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
return fs.NewFile(ctx, d, flags, &s)
}
+// DecRef implements RefCounter.DecRef.
+func (s *SocketOperations) DecRef(ctx context.Context) {
+ s.socketOperationsRefs.DecRef(func() {
+ s.ep.Close(ctx)
+ if s.abstractNamespace != nil {
+ s.abstractNamespace.Remove(s.abstractName, s)
+ }
+ })
+}
+
+// Release implemements fs.FileOperations.Release.
+func (s *SocketOperations) Release(ctx context.Context) {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef(ctx)
+}
+
// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
//
// +stateify savable
type socketOpsCommon struct {
- socketOpsCommonRefs
socket.SendReceiveTimeout
ep transport.Endpoint
@@ -101,23 +118,6 @@ type socketOpsCommon struct {
abstractNamespace *kernel.AbstractSocketNamespace
}
-// DecRef implements RefCounter.DecRef.
-func (s *socketOpsCommon) DecRef(ctx context.Context) {
- s.socketOpsCommonRefs.DecRef(func() {
- s.ep.Close(ctx)
- if s.abstractNamespace != nil {
- s.abstractNamespace.Remove(s.abstractName, s)
- }
- })
-}
-
-// Release implemements fs.FileOperations.Release.
-func (s *socketOpsCommon) Release(ctx context.Context) {
- // Release only decrements a reference on s because s may be referenced in
- // the abstract socket namespace.
- s.DecRef(ctx)
-}
-
func (s *socketOpsCommon) isPacket() bool {
switch s.stype {
case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
diff --git a/pkg/sentry/socket/unix/unix_state_autogen.go b/pkg/sentry/socket/unix/unix_state_autogen.go
index 89d78a9ad..51fd66b78 100644
--- a/pkg/sentry/socket/unix/unix_state_autogen.go
+++ b/pkg/sentry/socket/unix/unix_state_autogen.go
@@ -6,26 +6,49 @@ import (
"gvisor.dev/gvisor/pkg/state"
)
-func (x *socketOpsCommonRefs) StateTypeName() string {
- return "pkg/sentry/socket/unix.socketOpsCommonRefs"
+func (x *socketOperationsRefs) StateTypeName() string {
+ return "pkg/sentry/socket/unix.socketOperationsRefs"
}
-func (x *socketOpsCommonRefs) StateFields() []string {
+func (x *socketOperationsRefs) StateFields() []string {
return []string{
"refCount",
}
}
-func (x *socketOpsCommonRefs) beforeSave() {}
+func (x *socketOperationsRefs) beforeSave() {}
-func (x *socketOpsCommonRefs) StateSave(m state.Sink) {
+func (x *socketOperationsRefs) StateSave(m state.Sink) {
x.beforeSave()
m.Save(0, &x.refCount)
}
-func (x *socketOpsCommonRefs) afterLoad() {}
+func (x *socketOperationsRefs) afterLoad() {}
-func (x *socketOpsCommonRefs) StateLoad(m state.Source) {
+func (x *socketOperationsRefs) StateLoad(m state.Source) {
+ m.Load(0, &x.refCount)
+}
+
+func (x *socketVFS2Refs) StateTypeName() string {
+ return "pkg/sentry/socket/unix.socketVFS2Refs"
+}
+
+func (x *socketVFS2Refs) StateFields() []string {
+ return []string{
+ "refCount",
+ }
+}
+
+func (x *socketVFS2Refs) beforeSave() {}
+
+func (x *socketVFS2Refs) StateSave(m state.Sink) {
+ x.beforeSave()
+ m.Save(0, &x.refCount)
+}
+
+func (x *socketVFS2Refs) afterLoad() {}
+
+func (x *socketVFS2Refs) StateLoad(m state.Source) {
m.Load(0, &x.refCount)
}
@@ -35,6 +58,7 @@ func (x *SocketOperations) StateTypeName() string {
func (x *SocketOperations) StateFields() []string {
return []string{
+ "socketOperationsRefs",
"socketOpsCommon",
}
}
@@ -43,13 +67,15 @@ func (x *SocketOperations) beforeSave() {}
func (x *SocketOperations) StateSave(m state.Sink) {
x.beforeSave()
- m.Save(0, &x.socketOpsCommon)
+ m.Save(0, &x.socketOperationsRefs)
+ m.Save(1, &x.socketOpsCommon)
}
func (x *SocketOperations) afterLoad() {}
func (x *SocketOperations) StateLoad(m state.Source) {
- m.Load(0, &x.socketOpsCommon)
+ m.Load(0, &x.socketOperationsRefs)
+ m.Load(1, &x.socketOpsCommon)
}
func (x *socketOpsCommon) StateTypeName() string {
@@ -58,7 +84,6 @@ func (x *socketOpsCommon) StateTypeName() string {
func (x *socketOpsCommon) StateFields() []string {
return []string{
- "socketOpsCommonRefs",
"SendReceiveTimeout",
"ep",
"stype",
@@ -71,23 +96,21 @@ func (x *socketOpsCommon) beforeSave() {}
func (x *socketOpsCommon) StateSave(m state.Sink) {
x.beforeSave()
- m.Save(0, &x.socketOpsCommonRefs)
- m.Save(1, &x.SendReceiveTimeout)
- m.Save(2, &x.ep)
- m.Save(3, &x.stype)
- m.Save(4, &x.abstractName)
- m.Save(5, &x.abstractNamespace)
+ m.Save(0, &x.SendReceiveTimeout)
+ m.Save(1, &x.ep)
+ m.Save(2, &x.stype)
+ m.Save(3, &x.abstractName)
+ m.Save(4, &x.abstractNamespace)
}
func (x *socketOpsCommon) afterLoad() {}
func (x *socketOpsCommon) StateLoad(m state.Source) {
- m.Load(0, &x.socketOpsCommonRefs)
- m.Load(1, &x.SendReceiveTimeout)
- m.Load(2, &x.ep)
- m.Load(3, &x.stype)
- m.Load(4, &x.abstractName)
- m.Load(5, &x.abstractNamespace)
+ m.Load(0, &x.SendReceiveTimeout)
+ m.Load(1, &x.ep)
+ m.Load(2, &x.stype)
+ m.Load(3, &x.abstractName)
+ m.Load(4, &x.abstractNamespace)
}
func (x *SocketVFS2) StateTypeName() string {
@@ -100,6 +123,7 @@ func (x *SocketVFS2) StateFields() []string {
"FileDescriptionDefaultImpl",
"DentryMetadataFileDescriptionImpl",
"LockFD",
+ "socketVFS2Refs",
"socketOpsCommon",
}
}
@@ -112,7 +136,8 @@ func (x *SocketVFS2) StateSave(m state.Sink) {
m.Save(1, &x.FileDescriptionDefaultImpl)
m.Save(2, &x.DentryMetadataFileDescriptionImpl)
m.Save(3, &x.LockFD)
- m.Save(4, &x.socketOpsCommon)
+ m.Save(4, &x.socketVFS2Refs)
+ m.Save(5, &x.socketOpsCommon)
}
func (x *SocketVFS2) afterLoad() {}
@@ -122,11 +147,13 @@ func (x *SocketVFS2) StateLoad(m state.Source) {
m.Load(1, &x.FileDescriptionDefaultImpl)
m.Load(2, &x.DentryMetadataFileDescriptionImpl)
m.Load(3, &x.LockFD)
- m.Load(4, &x.socketOpsCommon)
+ m.Load(4, &x.socketVFS2Refs)
+ m.Load(5, &x.socketOpsCommon)
}
func init() {
- state.Register((*socketOpsCommonRefs)(nil))
+ state.Register((*socketOperationsRefs)(nil))
+ state.Register((*socketVFS2Refs)(nil))
state.Register((*SocketOperations)(nil))
state.Register((*socketOpsCommon)(nil))
state.Register((*SocketVFS2)(nil))
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 8b1abd922..3345124cc 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -45,6 +45,7 @@ type SocketVFS2 struct {
vfs.DentryMetadataFileDescriptionImpl
vfs.LockFD
+ socketVFS2Refs
socketOpsCommon
}
@@ -91,6 +92,25 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
return vfsfd, nil
}
+// DecRef implements RefCounter.DecRef.
+func (s *SocketVFS2) DecRef(ctx context.Context) {
+ s.socketVFS2Refs.DecRef(func() {
+ t := kernel.TaskFromContext(ctx)
+ t.Kernel().DeleteSocketVFS2(&s.vfsfd)
+ s.ep.Close(ctx)
+ if s.abstractNamespace != nil {
+ s.abstractNamespace.Remove(s.abstractName, s)
+ }
+ })
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *SocketVFS2) Release(ctx context.Context) {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef(ctx)
+}
+
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {