diff options
-rw-r--r-- | pkg/sentry/kernel/kernel.go | 83 | ||||
-rw-r--r-- | pkg/sentry/kernel/kernel_state_autogen.go | 103 | ||||
-rw-r--r-- | pkg/sentry/kernel/socket_list.go | 32 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket_vfs2.go | 7 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket_vfs2.go | 7 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack_vfs2.go | 7 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/socket_refs.go | 24 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/socket_vfs2_refs.go | 118 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 36 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix_state_autogen.go | 77 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix_vfs2.go | 20 |
11 files changed, 381 insertions, 133 deletions
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 08bb5bd12..d6c21adb7 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -220,13 +220,18 @@ type Kernel struct { // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` - // sockets is the list of all network sockets the system. Protected by - // extMu. + // sockets is the list of all network sockets in the system. + // Protected by extMu. + // TODO(gvisor.dev/issue/1624): Only used by VFS1. sockets socketList - // nextSocketEntry is the next entry number to use in sockets. Protected + // socketsVFS2 records all network sockets in the system. Protected by + // extMu. + socketsVFS2 map[*vfs.FileDescription]*SocketRecord + + // nextSocketRecord is the next entry number to use in sockets. Protected // by extMu. - nextSocketEntry uint64 + nextSocketRecord uint64 // deviceRegistry is used to save/restore device.SimpleDevices. deviceRegistry struct{} `state:".(*device.Registry)"` @@ -414,6 +419,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { return fmt.Errorf("failed to create sockfs mount: %v", err) } k.socketMount = socketMount + + k.socketsVFS2 = make(map[*vfs.FileDescription]*SocketRecord) } return nil @@ -1512,20 +1519,27 @@ func (k *Kernel) SupervisorContext() context.Context { } } -// SocketEntry represents a socket recorded in Kernel.sockets. It implements +// SocketRecord represents a socket recorded in Kernel.socketsVFS2. +// +// +stateify savable +type SocketRecord struct { + k *Kernel + Sock *refs.WeakRef // TODO(gvisor.dev/issue/1624): Only used by VFS1. + SockVFS2 *vfs.FileDescription // Only used by VFS2. + ID uint64 // Socket table entry number. +} + +// SocketRecordVFS1 represents a socket recorded in Kernel.sockets. It implements // refs.WeakRefUser for sockets stored in the socket table. // // +stateify savable -type SocketEntry struct { +type SocketRecordVFS1 struct { socketEntry - k *Kernel - Sock *refs.WeakRef - SockVFS2 *vfs.FileDescription - ID uint64 // Socket table entry number. + SocketRecord } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (s *SocketEntry) WeakRefGone(context.Context) { +func (s *SocketRecordVFS1) WeakRefGone(context.Context) { s.k.extMu.Lock() s.k.sockets.Remove(s) s.k.extMu.Unlock() @@ -1536,9 +1550,14 @@ func (s *SocketEntry) WeakRefGone(context.Context) { // Precondition: Caller must hold a reference to sock. func (k *Kernel) RecordSocket(sock *fs.File) { k.extMu.Lock() - id := k.nextSocketEntry - k.nextSocketEntry++ - s := &SocketEntry{k: k, ID: id} + id := k.nextSocketRecord + k.nextSocketRecord++ + s := &SocketRecordVFS1{ + SocketRecord: SocketRecord{ + k: k, + ID: id, + }, + } s.Sock = refs.NewWeakRef(sock, s) k.sockets.PushBack(s) k.extMu.Unlock() @@ -1550,29 +1569,45 @@ func (k *Kernel) RecordSocket(sock *fs.File) { // Precondition: Caller must hold a reference to sock. // // Note that the socket table will not hold a reference on the -// vfs.FileDescription, because we do not support weak refs on VFS2 files. +// vfs.FileDescription. func (k *Kernel) RecordSocketVFS2(sock *vfs.FileDescription) { k.extMu.Lock() - id := k.nextSocketEntry - k.nextSocketEntry++ - s := &SocketEntry{ + if _, ok := k.socketsVFS2[sock]; ok { + panic(fmt.Sprintf("Socket %p added twice", sock)) + } + id := k.nextSocketRecord + k.nextSocketRecord++ + s := &SocketRecord{ k: k, ID: id, SockVFS2: sock, } - k.sockets.PushBack(s) + k.socketsVFS2[sock] = s + k.extMu.Unlock() +} + +// DeleteSocketVFS2 removes a VFS2 socket from the system-wide socket table. +func (k *Kernel) DeleteSocketVFS2(sock *vfs.FileDescription) { + k.extMu.Lock() + delete(k.socketsVFS2, sock) k.extMu.Unlock() } // ListSockets returns a snapshot of all sockets. // -// Callers of ListSockets() in VFS2 should use SocketEntry.SockVFS2.TryIncRef() +// Callers of ListSockets() in VFS2 should use SocketRecord.SockVFS2.TryIncRef() // to get a reference on a socket in the table. -func (k *Kernel) ListSockets() []*SocketEntry { +func (k *Kernel) ListSockets() []*SocketRecord { k.extMu.Lock() - var socks []*SocketEntry - for s := k.sockets.Front(); s != nil; s = s.Next() { - socks = append(socks, s) + var socks []*SocketRecord + if VFS2Enabled { + for _, s := range k.socketsVFS2 { + socks = append(socks, s) + } + } else { + for s := k.sockets.Front(); s != nil; s = s.Next() { + socks = append(socks, &s.SocketRecord) + } } k.extMu.Unlock() return socks diff --git a/pkg/sentry/kernel/kernel_state_autogen.go b/pkg/sentry/kernel/kernel_state_autogen.go index d0ff135d7..f20800960 100644 --- a/pkg/sentry/kernel/kernel_state_autogen.go +++ b/pkg/sentry/kernel/kernel_state_autogen.go @@ -297,7 +297,8 @@ func (x *Kernel) StateFields() []string { "netlinkPorts", "danglingEndpoints", "sockets", - "nextSocketEntry", + "socketsVFS2", + "nextSocketRecord", "deviceRegistry", "DirentCacheLimiter", "SpecialOpts", @@ -317,7 +318,7 @@ func (x *Kernel) StateSave(m state.Sink) { var danglingEndpoints []tcpip.Endpoint = x.saveDanglingEndpoints() m.SaveValue(24, danglingEndpoints) var deviceRegistry *device.Registry = x.saveDeviceRegistry() - m.SaveValue(27, deviceRegistry) + m.SaveValue(28, deviceRegistry) m.Save(0, &x.featureSet) m.Save(1, &x.timekeeper) m.Save(2, &x.tasks) @@ -343,15 +344,16 @@ func (x *Kernel) StateSave(m state.Sink) { m.Save(22, &x.nextInotifyCookie) m.Save(23, &x.netlinkPorts) m.Save(25, &x.sockets) - m.Save(26, &x.nextSocketEntry) - m.Save(28, &x.DirentCacheLimiter) - m.Save(29, &x.SpecialOpts) - m.Save(30, &x.vfs) - m.Save(31, &x.hostMount) - m.Save(32, &x.pipeMount) - m.Save(33, &x.shmMount) - m.Save(34, &x.socketMount) - m.Save(35, &x.SleepForAddressSpaceActivation) + m.Save(26, &x.socketsVFS2) + m.Save(27, &x.nextSocketRecord) + m.Save(29, &x.DirentCacheLimiter) + m.Save(30, &x.SpecialOpts) + m.Save(31, &x.vfs) + m.Save(32, &x.hostMount) + m.Save(33, &x.pipeMount) + m.Save(34, &x.shmMount) + m.Save(35, &x.socketMount) + m.Save(36, &x.SleepForAddressSpaceActivation) } func (x *Kernel) afterLoad() {} @@ -382,26 +384,26 @@ func (x *Kernel) StateLoad(m state.Source) { m.Load(22, &x.nextInotifyCookie) m.Load(23, &x.netlinkPorts) m.Load(25, &x.sockets) - m.Load(26, &x.nextSocketEntry) - m.Load(28, &x.DirentCacheLimiter) - m.Load(29, &x.SpecialOpts) - m.Load(30, &x.vfs) - m.Load(31, &x.hostMount) - m.Load(32, &x.pipeMount) - m.Load(33, &x.shmMount) - m.Load(34, &x.socketMount) - m.Load(35, &x.SleepForAddressSpaceActivation) + m.Load(26, &x.socketsVFS2) + m.Load(27, &x.nextSocketRecord) + m.Load(29, &x.DirentCacheLimiter) + m.Load(30, &x.SpecialOpts) + m.Load(31, &x.vfs) + m.Load(32, &x.hostMount) + m.Load(33, &x.pipeMount) + m.Load(34, &x.shmMount) + m.Load(35, &x.socketMount) + m.Load(36, &x.SleepForAddressSpaceActivation) m.LoadValue(24, new([]tcpip.Endpoint), func(y interface{}) { x.loadDanglingEndpoints(y.([]tcpip.Endpoint)) }) - m.LoadValue(27, new(*device.Registry), func(y interface{}) { x.loadDeviceRegistry(y.(*device.Registry)) }) + m.LoadValue(28, new(*device.Registry), func(y interface{}) { x.loadDeviceRegistry(y.(*device.Registry)) }) } -func (x *SocketEntry) StateTypeName() string { - return "pkg/sentry/kernel.SocketEntry" +func (x *SocketRecord) StateTypeName() string { + return "pkg/sentry/kernel.SocketRecord" } -func (x *SocketEntry) StateFields() []string { +func (x *SocketRecord) StateFields() []string { return []string{ - "socketEntry", "k", "Sock", "SockVFS2", @@ -409,25 +411,49 @@ func (x *SocketEntry) StateFields() []string { } } -func (x *SocketEntry) beforeSave() {} +func (x *SocketRecord) beforeSave() {} -func (x *SocketEntry) StateSave(m state.Sink) { +func (x *SocketRecord) StateSave(m state.Sink) { + x.beforeSave() + m.Save(0, &x.k) + m.Save(1, &x.Sock) + m.Save(2, &x.SockVFS2) + m.Save(3, &x.ID) +} + +func (x *SocketRecord) afterLoad() {} + +func (x *SocketRecord) StateLoad(m state.Source) { + m.Load(0, &x.k) + m.Load(1, &x.Sock) + m.Load(2, &x.SockVFS2) + m.Load(3, &x.ID) +} + +func (x *SocketRecordVFS1) StateTypeName() string { + return "pkg/sentry/kernel.SocketRecordVFS1" +} + +func (x *SocketRecordVFS1) StateFields() []string { + return []string{ + "socketEntry", + "SocketRecord", + } +} + +func (x *SocketRecordVFS1) beforeSave() {} + +func (x *SocketRecordVFS1) StateSave(m state.Sink) { x.beforeSave() m.Save(0, &x.socketEntry) - m.Save(1, &x.k) - m.Save(2, &x.Sock) - m.Save(3, &x.SockVFS2) - m.Save(4, &x.ID) + m.Save(1, &x.SocketRecord) } -func (x *SocketEntry) afterLoad() {} +func (x *SocketRecordVFS1) afterLoad() {} -func (x *SocketEntry) StateLoad(m state.Source) { +func (x *SocketRecordVFS1) StateLoad(m state.Source) { m.Load(0, &x.socketEntry) - m.Load(1, &x.k) - m.Load(2, &x.Sock) - m.Load(3, &x.SockVFS2) - m.Load(4, &x.ID) + m.Load(1, &x.SocketRecord) } func (x *pendingSignals) StateTypeName() string { @@ -2264,7 +2290,8 @@ func init() { state.Register((*FSContextRefs)(nil)) state.Register((*IPCNamespace)(nil)) state.Register((*Kernel)(nil)) - state.Register((*SocketEntry)(nil)) + state.Register((*SocketRecord)(nil)) + state.Register((*SocketRecordVFS1)(nil)) state.Register((*pendingSignals)(nil)) state.Register((*pendingSignalQueue)(nil)) state.Register((*pendingSignal)(nil)) diff --git a/pkg/sentry/kernel/socket_list.go b/pkg/sentry/kernel/socket_list.go index d2d4307a1..246fba405 100644 --- a/pkg/sentry/kernel/socket_list.go +++ b/pkg/sentry/kernel/socket_list.go @@ -13,7 +13,7 @@ type socketElementMapper struct{} // This default implementation should be inlined. // //go:nosplit -func (socketElementMapper) linkerFor(elem *SocketEntry) *SocketEntry { return elem } +func (socketElementMapper) linkerFor(elem *SocketRecordVFS1) *SocketRecordVFS1 { return elem } // List is an intrusive list. Entries can be added to or removed from the list // in O(1) time and with no additional memory allocations. @@ -27,8 +27,8 @@ func (socketElementMapper) linkerFor(elem *SocketEntry) *SocketEntry { return el // // +stateify savable type socketList struct { - head *SocketEntry - tail *SocketEntry + head *SocketRecordVFS1 + tail *SocketRecordVFS1 } // Reset resets list l to the empty state. @@ -43,12 +43,12 @@ func (l *socketList) Empty() bool { } // Front returns the first element of list l or nil. -func (l *socketList) Front() *SocketEntry { +func (l *socketList) Front() *SocketRecordVFS1 { return l.head } // Back returns the last element of list l or nil. -func (l *socketList) Back() *SocketEntry { +func (l *socketList) Back() *SocketRecordVFS1 { return l.tail } @@ -63,7 +63,7 @@ func (l *socketList) Len() (count int) { } // PushFront inserts the element e at the front of list l. -func (l *socketList) PushFront(e *SocketEntry) { +func (l *socketList) PushFront(e *SocketRecordVFS1) { linker := socketElementMapper{}.linkerFor(e) linker.SetNext(l.head) linker.SetPrev(nil) @@ -77,7 +77,7 @@ func (l *socketList) PushFront(e *SocketEntry) { } // PushBack inserts the element e at the back of list l. -func (l *socketList) PushBack(e *SocketEntry) { +func (l *socketList) PushBack(e *SocketRecordVFS1) { linker := socketElementMapper{}.linkerFor(e) linker.SetNext(nil) linker.SetPrev(l.tail) @@ -106,7 +106,7 @@ func (l *socketList) PushBackList(m *socketList) { } // InsertAfter inserts e after b. -func (l *socketList) InsertAfter(b, e *SocketEntry) { +func (l *socketList) InsertAfter(b, e *SocketRecordVFS1) { bLinker := socketElementMapper{}.linkerFor(b) eLinker := socketElementMapper{}.linkerFor(e) @@ -124,7 +124,7 @@ func (l *socketList) InsertAfter(b, e *SocketEntry) { } // InsertBefore inserts e before a. -func (l *socketList) InsertBefore(a, e *SocketEntry) { +func (l *socketList) InsertBefore(a, e *SocketRecordVFS1) { aLinker := socketElementMapper{}.linkerFor(a) eLinker := socketElementMapper{}.linkerFor(e) @@ -141,7 +141,7 @@ func (l *socketList) InsertBefore(a, e *SocketEntry) { } // Remove removes e from l. -func (l *socketList) Remove(e *SocketEntry) { +func (l *socketList) Remove(e *SocketRecordVFS1) { linker := socketElementMapper{}.linkerFor(e) prev := linker.Prev() next := linker.Next() @@ -168,26 +168,26 @@ func (l *socketList) Remove(e *SocketEntry) { // // +stateify savable type socketEntry struct { - next *SocketEntry - prev *SocketEntry + next *SocketRecordVFS1 + prev *SocketRecordVFS1 } // Next returns the entry that follows e in the list. -func (e *socketEntry) Next() *SocketEntry { +func (e *socketEntry) Next() *SocketRecordVFS1 { return e.next } // Prev returns the entry that precedes e in the list. -func (e *socketEntry) Prev() *SocketEntry { +func (e *socketEntry) Prev() *SocketRecordVFS1 { return e.prev } // SetNext assigns 'entry' as the entry that follows e in the list. -func (e *socketEntry) SetNext(elem *SocketEntry) { +func (e *socketEntry) SetNext(elem *SocketRecordVFS1) { e.next = elem } // SetPrev assigns 'entry' as the entry that precedes e in the list. -func (e *socketEntry) SetPrev(elem *SocketEntry) { +func (e *socketEntry) SetPrev(elem *SocketRecordVFS1) { e.prev = elem } diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 87b077e68..163af329b 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -78,6 +78,13 @@ func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol in return vfsfd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *socketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index a38d25da9..c83b23242 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -82,6 +82,13 @@ func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketV return fd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index c0212ad76..4c6791fff 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -79,6 +79,13 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu return vfsfd, nil } +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.socketOpsCommon.Release(ctx) +} + // Readiness implements waiter.Waitable.Readiness. func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { return s.socketOpsCommon.Readiness(mask) diff --git a/pkg/sentry/socket/unix/socket_refs.go b/pkg/sentry/socket/unix/socket_refs.go index dababb85f..ea63dc659 100644 --- a/pkg/sentry/socket/unix/socket_refs.go +++ b/pkg/sentry/socket/unix/socket_refs.go @@ -11,7 +11,7 @@ import ( // ownerType is used to customize logging. Note that we use a pointer to T so // that we do not copy the entire object when passed as a format parameter. -var socketOpsCommonownerType *socketOpsCommon +var socketOperationsownerType *SocketOperations // Refs implements refs.RefCounter. It keeps a reference count using atomic // operations and calls the destructor when the count reaches zero. @@ -25,7 +25,7 @@ var socketOpsCommonownerType *socketOpsCommon // without growing the size of Refs. // // +stateify savable -type socketOpsCommonRefs struct { +type socketOperationsRefs struct { // refCount is composed of two fields: // // [32-bit speculative references]:[32-bit real references] @@ -36,7 +36,7 @@ type socketOpsCommonRefs struct { refCount int64 } -func (r *socketOpsCommonRefs) finalize() { +func (r *socketOperationsRefs) finalize() { var note string switch refs_vfs1.GetLeakMode() { case refs_vfs1.NoLeakChecking: @@ -45,20 +45,20 @@ func (r *socketOpsCommonRefs) finalize() { note = "(Leak checker uninitialized): " } if n := r.ReadRefs(); n != 0 { - log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, socketOpsCommonownerType, n) + log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, socketOperationsownerType, n) } } // EnableLeakCheck checks for reference leaks when Refs gets garbage collected. -func (r *socketOpsCommonRefs) EnableLeakCheck() { +func (r *socketOperationsRefs) EnableLeakCheck() { if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { - runtime.SetFinalizer(r, (*socketOpsCommonRefs).finalize) + runtime.SetFinalizer(r, (*socketOperationsRefs).finalize) } } // ReadRefs returns the current number of references. The returned count is // inherently racy and is unsafe to use without external synchronization. -func (r *socketOpsCommonRefs) ReadRefs() int64 { +func (r *socketOperationsRefs) ReadRefs() int64 { return atomic.LoadInt64(&r.refCount) + 1 } @@ -66,9 +66,9 @@ func (r *socketOpsCommonRefs) ReadRefs() int64 { // IncRef implements refs.RefCounter.IncRef. // //go:nosplit -func (r *socketOpsCommonRefs) IncRef() { +func (r *socketOperationsRefs) IncRef() { if v := atomic.AddInt64(&r.refCount, 1); v <= 0 { - panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, socketOpsCommonownerType)) + panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, socketOperationsownerType)) } } @@ -79,7 +79,7 @@ func (r *socketOpsCommonRefs) IncRef() { // other TryIncRef calls from genuine references held. // //go:nosplit -func (r *socketOpsCommonRefs) TryIncRef() bool { +func (r *socketOperationsRefs) TryIncRef() bool { const speculativeRef = 1 << 32 v := atomic.AddInt64(&r.refCount, speculativeRef) if int32(v) < 0 { @@ -104,10 +104,10 @@ func (r *socketOpsCommonRefs) TryIncRef() bool { // A: TryIncRef [transform speculative to real] // //go:nosplit -func (r *socketOpsCommonRefs) DecRef(destroy func()) { +func (r *socketOperationsRefs) DecRef(destroy func()) { switch v := atomic.AddInt64(&r.refCount, -1); { case v < -1: - panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, socketOpsCommonownerType)) + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, socketOperationsownerType)) case v == -1: diff --git a/pkg/sentry/socket/unix/socket_vfs2_refs.go b/pkg/sentry/socket/unix/socket_vfs2_refs.go new file mode 100644 index 000000000..dc55f2947 --- /dev/null +++ b/pkg/sentry/socket/unix/socket_vfs2_refs.go @@ -0,0 +1,118 @@ +package unix + +import ( + "fmt" + "runtime" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/log" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" +) + +// ownerType is used to customize logging. Note that we use a pointer to T so +// that we do not copy the entire object when passed as a format parameter. +var socketVFS2ownerType *SocketVFS2 + +// Refs implements refs.RefCounter. It keeps a reference count using atomic +// operations and calls the destructor when the count reaches zero. +// +// Note that the number of references is actually refCount + 1 so that a default +// zero-value Refs object contains one reference. +// +// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in +// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount. +// This will allow us to add stack trace information to the leak messages +// without growing the size of Refs. +// +// +stateify savable +type socketVFS2Refs struct { + // refCount is composed of two fields: + // + // [32-bit speculative references]:[32-bit real references] + // + // Speculative references are used for TryIncRef, to avoid a CompareAndSwap + // loop. See IncRef, DecRef and TryIncRef for details of how these fields are + // used. + refCount int64 +} + +func (r *socketVFS2Refs) finalize() { + var note string + switch refs_vfs1.GetLeakMode() { + case refs_vfs1.NoLeakChecking: + return + case refs_vfs1.UninitializedLeakChecking: + note = "(Leak checker uninitialized): " + } + if n := r.ReadRefs(); n != 0 { + log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, socketVFS2ownerType, n) + } +} + +// EnableLeakCheck checks for reference leaks when Refs gets garbage collected. +func (r *socketVFS2Refs) EnableLeakCheck() { + if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { + runtime.SetFinalizer(r, (*socketVFS2Refs).finalize) + } +} + +// ReadRefs returns the current number of references. The returned count is +// inherently racy and is unsafe to use without external synchronization. +func (r *socketVFS2Refs) ReadRefs() int64 { + + return atomic.LoadInt64(&r.refCount) + 1 +} + +// IncRef implements refs.RefCounter.IncRef. +// +//go:nosplit +func (r *socketVFS2Refs) IncRef() { + if v := atomic.AddInt64(&r.refCount, 1); v <= 0 { + panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, socketVFS2ownerType)) + } +} + +// TryIncRef implements refs.RefCounter.TryIncRef. +// +// To do this safely without a loop, a speculative reference is first acquired +// on the object. This allows multiple concurrent TryIncRef calls to distinguish +// other TryIncRef calls from genuine references held. +// +//go:nosplit +func (r *socketVFS2Refs) TryIncRef() bool { + const speculativeRef = 1 << 32 + v := atomic.AddInt64(&r.refCount, speculativeRef) + if int32(v) < 0 { + + atomic.AddInt64(&r.refCount, -speculativeRef) + return false + } + + atomic.AddInt64(&r.refCount, -speculativeRef+1) + return true +} + +// DecRef implements refs.RefCounter.DecRef. +// +// Note that speculative references are counted here. Since they were added +// prior to real references reaching zero, they will successfully convert to +// real references. In other words, we see speculative references only in the +// following case: +// +// A: TryIncRef [speculative increase => sees non-negative references] +// B: DecRef [real decrease] +// A: TryIncRef [transform speculative to real] +// +//go:nosplit +func (r *socketVFS2Refs) DecRef(destroy func()) { + switch v := atomic.AddInt64(&r.refCount, -1); { + case v < -1: + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, socketVFS2ownerType)) + + case v == -1: + + if destroy != nil { + destroy() + } + } +} diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 917055cea..f80011ce4 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -55,6 +55,7 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + socketOperationsRefs socketOpsCommon } @@ -84,11 +85,27 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty return fs.NewFile(ctx, d, flags, &s) } +// DecRef implements RefCounter.DecRef. +func (s *SocketOperations) DecRef(ctx context.Context) { + s.socketOperationsRefs.DecRef(func() { + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } + }) +} + +// Release implemements fs.FileOperations.Release. +func (s *SocketOperations) Release(ctx context.Context) { + // Release only decrements a reference on s because s may be referenced in + // the abstract socket namespace. + s.DecRef(ctx) +} + // socketOpsCommon contains the socket operations common to VFS1 and VFS2. // // +stateify savable type socketOpsCommon struct { - socketOpsCommonRefs socket.SendReceiveTimeout ep transport.Endpoint @@ -101,23 +118,6 @@ type socketOpsCommon struct { abstractNamespace *kernel.AbstractSocketNamespace } -// DecRef implements RefCounter.DecRef. -func (s *socketOpsCommon) DecRef(ctx context.Context) { - s.socketOpsCommonRefs.DecRef(func() { - s.ep.Close(ctx) - if s.abstractNamespace != nil { - s.abstractNamespace.Remove(s.abstractName, s) - } - }) -} - -// Release implemements fs.FileOperations.Release. -func (s *socketOpsCommon) Release(ctx context.Context) { - // Release only decrements a reference on s because s may be referenced in - // the abstract socket namespace. - s.DecRef(ctx) -} - func (s *socketOpsCommon) isPacket() bool { switch s.stype { case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: diff --git a/pkg/sentry/socket/unix/unix_state_autogen.go b/pkg/sentry/socket/unix/unix_state_autogen.go index 89d78a9ad..51fd66b78 100644 --- a/pkg/sentry/socket/unix/unix_state_autogen.go +++ b/pkg/sentry/socket/unix/unix_state_autogen.go @@ -6,26 +6,49 @@ import ( "gvisor.dev/gvisor/pkg/state" ) -func (x *socketOpsCommonRefs) StateTypeName() string { - return "pkg/sentry/socket/unix.socketOpsCommonRefs" +func (x *socketOperationsRefs) StateTypeName() string { + return "pkg/sentry/socket/unix.socketOperationsRefs" } -func (x *socketOpsCommonRefs) StateFields() []string { +func (x *socketOperationsRefs) StateFields() []string { return []string{ "refCount", } } -func (x *socketOpsCommonRefs) beforeSave() {} +func (x *socketOperationsRefs) beforeSave() {} -func (x *socketOpsCommonRefs) StateSave(m state.Sink) { +func (x *socketOperationsRefs) StateSave(m state.Sink) { x.beforeSave() m.Save(0, &x.refCount) } -func (x *socketOpsCommonRefs) afterLoad() {} +func (x *socketOperationsRefs) afterLoad() {} -func (x *socketOpsCommonRefs) StateLoad(m state.Source) { +func (x *socketOperationsRefs) StateLoad(m state.Source) { + m.Load(0, &x.refCount) +} + +func (x *socketVFS2Refs) StateTypeName() string { + return "pkg/sentry/socket/unix.socketVFS2Refs" +} + +func (x *socketVFS2Refs) StateFields() []string { + return []string{ + "refCount", + } +} + +func (x *socketVFS2Refs) beforeSave() {} + +func (x *socketVFS2Refs) StateSave(m state.Sink) { + x.beforeSave() + m.Save(0, &x.refCount) +} + +func (x *socketVFS2Refs) afterLoad() {} + +func (x *socketVFS2Refs) StateLoad(m state.Source) { m.Load(0, &x.refCount) } @@ -35,6 +58,7 @@ func (x *SocketOperations) StateTypeName() string { func (x *SocketOperations) StateFields() []string { return []string{ + "socketOperationsRefs", "socketOpsCommon", } } @@ -43,13 +67,15 @@ func (x *SocketOperations) beforeSave() {} func (x *SocketOperations) StateSave(m state.Sink) { x.beforeSave() - m.Save(0, &x.socketOpsCommon) + m.Save(0, &x.socketOperationsRefs) + m.Save(1, &x.socketOpsCommon) } func (x *SocketOperations) afterLoad() {} func (x *SocketOperations) StateLoad(m state.Source) { - m.Load(0, &x.socketOpsCommon) + m.Load(0, &x.socketOperationsRefs) + m.Load(1, &x.socketOpsCommon) } func (x *socketOpsCommon) StateTypeName() string { @@ -58,7 +84,6 @@ func (x *socketOpsCommon) StateTypeName() string { func (x *socketOpsCommon) StateFields() []string { return []string{ - "socketOpsCommonRefs", "SendReceiveTimeout", "ep", "stype", @@ -71,23 +96,21 @@ func (x *socketOpsCommon) beforeSave() {} func (x *socketOpsCommon) StateSave(m state.Sink) { x.beforeSave() - m.Save(0, &x.socketOpsCommonRefs) - m.Save(1, &x.SendReceiveTimeout) - m.Save(2, &x.ep) - m.Save(3, &x.stype) - m.Save(4, &x.abstractName) - m.Save(5, &x.abstractNamespace) + m.Save(0, &x.SendReceiveTimeout) + m.Save(1, &x.ep) + m.Save(2, &x.stype) + m.Save(3, &x.abstractName) + m.Save(4, &x.abstractNamespace) } func (x *socketOpsCommon) afterLoad() {} func (x *socketOpsCommon) StateLoad(m state.Source) { - m.Load(0, &x.socketOpsCommonRefs) - m.Load(1, &x.SendReceiveTimeout) - m.Load(2, &x.ep) - m.Load(3, &x.stype) - m.Load(4, &x.abstractName) - m.Load(5, &x.abstractNamespace) + m.Load(0, &x.SendReceiveTimeout) + m.Load(1, &x.ep) + m.Load(2, &x.stype) + m.Load(3, &x.abstractName) + m.Load(4, &x.abstractNamespace) } func (x *SocketVFS2) StateTypeName() string { @@ -100,6 +123,7 @@ func (x *SocketVFS2) StateFields() []string { "FileDescriptionDefaultImpl", "DentryMetadataFileDescriptionImpl", "LockFD", + "socketVFS2Refs", "socketOpsCommon", } } @@ -112,7 +136,8 @@ func (x *SocketVFS2) StateSave(m state.Sink) { m.Save(1, &x.FileDescriptionDefaultImpl) m.Save(2, &x.DentryMetadataFileDescriptionImpl) m.Save(3, &x.LockFD) - m.Save(4, &x.socketOpsCommon) + m.Save(4, &x.socketVFS2Refs) + m.Save(5, &x.socketOpsCommon) } func (x *SocketVFS2) afterLoad() {} @@ -122,11 +147,13 @@ func (x *SocketVFS2) StateLoad(m state.Source) { m.Load(1, &x.FileDescriptionDefaultImpl) m.Load(2, &x.DentryMetadataFileDescriptionImpl) m.Load(3, &x.LockFD) - m.Load(4, &x.socketOpsCommon) + m.Load(4, &x.socketVFS2Refs) + m.Load(5, &x.socketOpsCommon) } func init() { - state.Register((*socketOpsCommonRefs)(nil)) + state.Register((*socketOperationsRefs)(nil)) + state.Register((*socketVFS2Refs)(nil)) state.Register((*SocketOperations)(nil)) state.Register((*socketOpsCommon)(nil)) state.Register((*SocketVFS2)(nil)) diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 8b1abd922..3345124cc 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -45,6 +45,7 @@ type SocketVFS2 struct { vfs.DentryMetadataFileDescriptionImpl vfs.LockFD + socketVFS2Refs socketOpsCommon } @@ -91,6 +92,25 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 return vfsfd, nil } +// DecRef implements RefCounter.DecRef. +func (s *SocketVFS2) DecRef(ctx context.Context) { + s.socketVFS2Refs.DecRef(func() { + t := kernel.TaskFromContext(ctx) + t.Kernel().DeleteSocketVFS2(&s.vfsfd) + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } + }) +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (s *SocketVFS2) Release(ctx context.Context) { + // Release only decrements a reference on s because s may be referenced in + // the abstract socket namespace. + s.DecRef(ctx) +} + // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { |