summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/abi/linux/linux_abi_autogen_unsafe.go32
-rw-r--r--pkg/refs_vfs2/refs.go36
-rw-r--r--pkg/refs_vfs2/refs_vfs2_state_autogen.go3
-rw-r--r--pkg/sentry/fsimpl/tmpfs/inode_refs.go5
-rw-r--r--pkg/sentry/kernel/abstract_socket_namespace.go77
-rw-r--r--pkg/sentry/kernel/kernel_state_autogen.go6
-rw-r--r--pkg/sentry/platform/ring0/defs_impl_arm64.go2
-rw-r--r--pkg/sentry/socket/unix/socket_refs.go111
-rw-r--r--pkg/sentry/socket/unix/unix.go22
-rw-r--r--pkg/sentry/socket/unix/unix_state_autogen.go36
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go6
11 files changed, 277 insertions, 59 deletions
diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go
index 92451b60e..63380dbb4 100644
--- a/pkg/abi/linux/linux_abi_autogen_unsafe.go
+++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go
@@ -167,7 +167,7 @@ func (s *Statx) MarshalUnsafe(dst []byte) {
// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
func (s *Statx) UnmarshalUnsafe(src []byte) {
- if s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() {
+ if s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
safecopy.CopyOut(unsafe.Pointer(s), src)
} else {
// Type Statx doesn't have a packed layout in memory, fallback to UnmarshalBytes.
@@ -178,7 +178,7 @@ func (s *Statx) UnmarshalUnsafe(src []byte) {
// CopyOutN implements marshal.Marshallable.CopyOutN.
//go:nosplit
func (s *Statx) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
- if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
+ if !s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes.
buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay.
s.MarshalBytes(buf) // escapes: fallback.
@@ -208,7 +208,7 @@ func (s *Statx) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
// CopyIn implements marshal.Marshallable.CopyIn.
//go:nosplit
func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- if !s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() {
+ if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to UnmarshalBytes.
buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay.
length, err := task.CopyInBytes(addr, buf) // escapes: okay.
@@ -627,12 +627,12 @@ func (f *FUSEHeaderIn) UnmarshalBytes(src []byte) {
// Packed implements marshal.Marshallable.Packed.
//go:nosplit
func (f *FUSEHeaderIn) Packed() bool {
- return f.Opcode.Packed() && f.Unique.Packed()
+ return f.Unique.Packed() && f.Opcode.Packed()
}
// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
func (f *FUSEHeaderIn) MarshalUnsafe(dst []byte) {
- if f.Opcode.Packed() && f.Unique.Packed() {
+ if f.Unique.Packed() && f.Opcode.Packed() {
safecopy.CopyIn(dst, unsafe.Pointer(f))
} else {
// Type FUSEHeaderIn doesn't have a packed layout in memory, fallback to MarshalBytes.
@@ -2213,7 +2213,7 @@ func (i *IPTIP) Packed() bool {
// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
func (i *IPTIP) MarshalUnsafe(dst []byte) {
- if i.SrcMask.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.Dst.Packed() {
+ if i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
safecopy.CopyIn(dst, unsafe.Pointer(i))
} else {
// Type IPTIP doesn't have a packed layout in memory, fallback to MarshalBytes.
@@ -2223,7 +2223,7 @@ func (i *IPTIP) MarshalUnsafe(dst []byte) {
// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
func (i *IPTIP) UnmarshalUnsafe(src []byte) {
- if i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
+ if i.SrcMask.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.Dst.Packed() {
safecopy.CopyOut(unsafe.Pointer(i), src)
} else {
// Type IPTIP doesn't have a packed layout in memory, fallback to UnmarshalBytes.
@@ -2264,7 +2264,7 @@ func (i *IPTIP) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
// CopyIn implements marshal.Marshallable.CopyIn.
//go:nosplit
func (i *IPTIP) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- if !i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
+ if !i.SrcMask.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.Dst.Packed() {
// Type IPTIP doesn't have a packed layout in memory, fall back to UnmarshalBytes.
buf := task.CopyScratchBuffer(i.SizeBytes()) // escapes: okay.
length, err := task.CopyInBytes(addr, buf) // escapes: okay.
@@ -3014,7 +3014,7 @@ func (i *IP6TEntry) MarshalUnsafe(dst []byte) {
// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
func (i *IP6TEntry) UnmarshalUnsafe(src []byte) {
- if i.IPv6.Packed() && i.Counters.Packed() {
+ if i.Counters.Packed() && i.IPv6.Packed() {
safecopy.CopyOut(unsafe.Pointer(i), src)
} else {
// Type IP6TEntry doesn't have a packed layout in memory, fallback to UnmarshalBytes.
@@ -3055,7 +3055,7 @@ func (i *IP6TEntry) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
// CopyIn implements marshal.Marshallable.CopyIn.
//go:nosplit
func (i *IP6TEntry) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- if !i.IPv6.Packed() && i.Counters.Packed() {
+ if !i.Counters.Packed() && i.IPv6.Packed() {
// Type IP6TEntry doesn't have a packed layout in memory, fall back to UnmarshalBytes.
buf := task.CopyScratchBuffer(i.SizeBytes()) // escapes: okay.
length, err := task.CopyInBytes(addr, buf) // escapes: okay.
@@ -3081,7 +3081,7 @@ func (i *IP6TEntry) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
// WriteTo implements io.WriterTo.WriteTo.
func (i *IP6TEntry) WriteTo(w io.Writer) (int64, error) {
- if !i.Counters.Packed() && i.IPv6.Packed() {
+ if !i.IPv6.Packed() && i.Counters.Packed() {
// Type IP6TEntry doesn't have a packed layout in memory, fall back to MarshalBytes.
buf := make([]byte, i.SizeBytes())
i.MarshalBytes(buf)
@@ -3196,12 +3196,12 @@ func (i *IP6TIP) UnmarshalBytes(src []byte) {
// Packed implements marshal.Marshallable.Packed.
//go:nosplit
func (i *IP6TIP) Packed() bool {
- return i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed()
+ return i.SrcMask.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.Dst.Packed()
}
// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
func (i *IP6TIP) MarshalUnsafe(dst []byte) {
- if i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
+ if i.DstMask.Packed() && i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() {
safecopy.CopyIn(dst, unsafe.Pointer(i))
} else {
// Type IP6TIP doesn't have a packed layout in memory, fallback to MarshalBytes.
@@ -3211,7 +3211,7 @@ func (i *IP6TIP) MarshalUnsafe(dst []byte) {
// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
func (i *IP6TIP) UnmarshalUnsafe(src []byte) {
- if i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() && i.Src.Packed() {
+ if i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
safecopy.CopyOut(unsafe.Pointer(i), src)
} else {
// Type IP6TIP doesn't have a packed layout in memory, fallback to UnmarshalBytes.
@@ -3252,7 +3252,7 @@ func (i *IP6TIP) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
// CopyIn implements marshal.Marshallable.CopyIn.
//go:nosplit
func (i *IP6TIP) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- if !i.SrcMask.Packed() && i.DstMask.Packed() && i.Src.Packed() && i.Dst.Packed() {
+ if !i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
// Type IP6TIP doesn't have a packed layout in memory, fall back to UnmarshalBytes.
buf := task.CopyScratchBuffer(i.SizeBytes()) // escapes: okay.
length, err := task.CopyInBytes(addr, buf) // escapes: okay.
@@ -3278,7 +3278,7 @@ func (i *IP6TIP) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
// WriteTo implements io.WriterTo.WriteTo.
func (i *IP6TIP) WriteTo(w io.Writer) (int64, error) {
- if !i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() && i.Src.Packed() {
+ if !i.Src.Packed() && i.Dst.Packed() && i.SrcMask.Packed() && i.DstMask.Packed() {
// Type IP6TIP doesn't have a packed layout in memory, fall back to MarshalBytes.
buf := make([]byte, i.SizeBytes())
i.MarshalBytes(buf)
diff --git a/pkg/refs_vfs2/refs.go b/pkg/refs_vfs2/refs.go
new file mode 100644
index 000000000..99a074e96
--- /dev/null
+++ b/pkg/refs_vfs2/refs.go
@@ -0,0 +1,36 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package refs_vfs2 defines an interface for a reference-counted object.
+package refs_vfs2
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+)
+
+// RefCounter is the interface to be implemented by objects that are reference
+// counted.
+type RefCounter interface {
+ // IncRef increments the reference counter on the object.
+ IncRef()
+
+ // DecRef decrements the object's reference count. Users of refs_template.Refs
+ // may specify a destructor to be called once the reference count reaches zero.
+ DecRef(ctx context.Context)
+
+ // TryIncRef attempts to increment the reference count, but may fail if all
+ // references have already been dropped, in which case it returns false. If
+ // true is returned, then a valid reference is now held on the object.
+ TryIncRef() bool
+}
diff --git a/pkg/refs_vfs2/refs_vfs2_state_autogen.go b/pkg/refs_vfs2/refs_vfs2_state_autogen.go
new file mode 100644
index 000000000..46925b4a4
--- /dev/null
+++ b/pkg/refs_vfs2/refs_vfs2_state_autogen.go
@@ -0,0 +1,3 @@
+// automatically generated by stateify.
+
+package refs_vfs2
diff --git a/pkg/sentry/fsimpl/tmpfs/inode_refs.go b/pkg/sentry/fsimpl/tmpfs/inode_refs.go
index 8b7ff185f..38eddde7e 100644
--- a/pkg/sentry/fsimpl/tmpfs/inode_refs.go
+++ b/pkg/sentry/fsimpl/tmpfs/inode_refs.go
@@ -1,11 +1,10 @@
package tmpfs
import (
- "runtime"
- "sync/atomic"
-
"gvisor.dev/gvisor/pkg/log"
refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "runtime"
+ "sync/atomic"
)
// ownerType is used to customize logging. Note that we use a pointer to T so
diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go
index 52ed5cea2..1b9721534 100644
--- a/pkg/sentry/kernel/abstract_socket_namespace.go
+++ b/pkg/sentry/kernel/abstract_socket_namespace.go
@@ -15,29 +15,21 @@
package kernel
import (
+ "fmt"
"syscall"
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/refs_vfs2"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sync"
)
// +stateify savable
type abstractEndpoint struct {
- ep transport.BoundEndpoint
- wr *refs.WeakRef
- name string
- ns *AbstractSocketNamespace
-}
-
-// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (e *abstractEndpoint) WeakRefGone(context.Context) {
- e.ns.mu.Lock()
- if e.ns.endpoints[e.name].ep == e.ep {
- delete(e.ns.endpoints, e.name)
- }
- e.ns.mu.Unlock()
+ ep transport.BoundEndpoint
+ socket refs_vfs2.RefCounter
+ name string
+ ns *AbstractSocketNamespace
}
// AbstractSocketNamespace is used to implement the Linux abstract socket functionality.
@@ -46,7 +38,11 @@ func (e *abstractEndpoint) WeakRefGone(context.Context) {
type AbstractSocketNamespace struct {
mu sync.Mutex `state:"nosave"`
- // Keeps mapping from name to endpoint.
+ // Keeps a mapping from name to endpoint. AbstractSocketNamespace does not hold
+ // any references on any sockets that it contains; when retrieving a socket,
+ // TryIncRef() must be called in case the socket is concurrently being
+ // destroyed. It is the responsibility of the socket to remove itself from the
+ // abstract socket namespace when it is destroyed.
endpoints map[string]abstractEndpoint
}
@@ -58,15 +54,15 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace {
}
// A boundEndpoint wraps a transport.BoundEndpoint to maintain a reference on
-// its backing object.
+// its backing socket.
type boundEndpoint struct {
transport.BoundEndpoint
- rc refs.RefCounter
+ socket refs_vfs2.RefCounter
}
// Release implements transport.BoundEndpoint.Release.
func (e *boundEndpoint) Release(ctx context.Context) {
- e.rc.DecRef(ctx)
+ e.socket.DecRef(ctx)
e.BoundEndpoint.Release(ctx)
}
@@ -81,32 +77,59 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp
return nil
}
- rc := ep.wr.Get()
- if rc == nil {
- delete(a.endpoints, name)
+ if !ep.socket.TryIncRef() {
+ // The socket has reached zero references and is being destroyed.
return nil
}
- return &boundEndpoint{ep.ep, rc}
+ return &boundEndpoint{ep.ep, ep.socket}
}
// Bind binds the given socket.
//
-// When the last reference managed by rc is dropped, ep may be removed from the
+// When the last reference managed by socket is dropped, ep may be removed from the
// namespace.
-func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, rc refs.RefCounter) error {
+func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error {
a.mu.Lock()
defer a.mu.Unlock()
+ // Check if there is already a socket (which has not yet been destroyed) bound at name.
if ep, ok := a.endpoints[name]; ok {
- if rc := ep.wr.Get(); rc != nil {
- rc.DecRef(ctx)
+ if ep.socket.TryIncRef() {
+ ep.socket.DecRef(ctx)
return syscall.EADDRINUSE
}
}
ae := abstractEndpoint{ep: ep, name: name, ns: a}
- ae.wr = refs.NewWeakRef(rc, &ae)
+ ae.socket = socket
a.endpoints[name] = ae
return nil
}
+
+// Remove removes the specified socket at name from the abstract socket
+// namespace, if it has not yet been replaced.
+func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+
+ ep, ok := a.endpoints[name]
+ if !ok {
+ // We never delete a map entry apart from a socket's destructor (although the
+ // map entry may be overwritten). Therefore, a socket should exist, even if it
+ // may not be the one we expect.
+ panic(fmt.Sprintf("expected socket to exist at '%s' in abstract socket namespace", name))
+ }
+
+ // A Bind() operation may race with callers of Remove(), e.g. in the
+ // following case:
+ // socket1 reaches zero references and begins destruction
+ // a.Bind("foo", ep, socket2) replaces socket1 with socket2
+ // socket1's destructor calls a.Remove("foo", socket1)
+ //
+ // Therefore, we need to check that the socket at name is what we expect
+ // before modifying the map.
+ if ep.socket == socket {
+ delete(a.endpoints, name)
+ }
+}
diff --git a/pkg/sentry/kernel/kernel_state_autogen.go b/pkg/sentry/kernel/kernel_state_autogen.go
index 8ab5e6f8e..106e237ec 100644
--- a/pkg/sentry/kernel/kernel_state_autogen.go
+++ b/pkg/sentry/kernel/kernel_state_autogen.go
@@ -16,7 +16,7 @@ func (x *abstractEndpoint) StateTypeName() string {
func (x *abstractEndpoint) StateFields() []string {
return []string{
"ep",
- "wr",
+ "socket",
"name",
"ns",
}
@@ -27,7 +27,7 @@ func (x *abstractEndpoint) beforeSave() {}
func (x *abstractEndpoint) StateSave(m state.Sink) {
x.beforeSave()
m.Save(0, &x.ep)
- m.Save(1, &x.wr)
+ m.Save(1, &x.socket)
m.Save(2, &x.name)
m.Save(3, &x.ns)
}
@@ -36,7 +36,7 @@ func (x *abstractEndpoint) afterLoad() {}
func (x *abstractEndpoint) StateLoad(m state.Source) {
m.Load(0, &x.ep)
- m.Load(1, &x.wr)
+ m.Load(1, &x.socket)
m.Load(2, &x.name)
m.Load(3, &x.ns)
}
diff --git a/pkg/sentry/platform/ring0/defs_impl_arm64.go b/pkg/sentry/platform/ring0/defs_impl_arm64.go
index 424b66f76..2dac9ad14 100644
--- a/pkg/sentry/platform/ring0/defs_impl_arm64.go
+++ b/pkg/sentry/platform/ring0/defs_impl_arm64.go
@@ -3,11 +3,11 @@ package ring0
import (
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables"
+ "reflect"
"fmt"
"gvisor.dev/gvisor/pkg/usermem"
"io"
- "reflect"
)
// Useful bits.
diff --git a/pkg/sentry/socket/unix/socket_refs.go b/pkg/sentry/socket/unix/socket_refs.go
new file mode 100644
index 000000000..8eb0c9327
--- /dev/null
+++ b/pkg/sentry/socket/unix/socket_refs.go
@@ -0,0 +1,111 @@
+package unix
+
+import (
+ "gvisor.dev/gvisor/pkg/log"
+ refs_vfs1 "gvisor.dev/gvisor/pkg/refs"
+ "runtime"
+ "sync/atomic"
+)
+
+// ownerType is used to customize logging. Note that we use a pointer to T so
+// that we do not copy the entire object when passed as a format parameter.
+var socketOpsCommonownerType *socketOpsCommon
+
+// Refs implements refs.RefCounter. It keeps a reference count using atomic
+// operations and calls the destructor when the count reaches zero.
+//
+// Note that the number of references is actually refCount + 1 so that a default
+// zero-value Refs object contains one reference.
+//
+// +stateify savable
+type socketOpsCommonRefs struct {
+ // refCount is composed of two fields:
+ //
+ // [32-bit speculative references]:[32-bit real references]
+ //
+ // Speculative references are used for TryIncRef, to avoid a CompareAndSwap
+ // loop. See IncRef, DecRef and TryIncRef for details of how these fields are
+ // used.
+ refCount int64
+}
+
+func (r *socketOpsCommonRefs) finalize() {
+ var note string
+ switch refs_vfs1.GetLeakMode() {
+ case refs_vfs1.NoLeakChecking:
+ return
+ case refs_vfs1.UninitializedLeakChecking:
+ note = "(Leak checker uninitialized): "
+ }
+ if n := r.ReadRefs(); n != 0 {
+ log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, socketOpsCommonownerType, n)
+ }
+}
+
+// EnableLeakCheck checks for reference leaks when Refs gets garbage collected.
+func (r *socketOpsCommonRefs) EnableLeakCheck() {
+ if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking {
+ runtime.SetFinalizer(r, (*socketOpsCommonRefs).finalize)
+ }
+}
+
+// ReadRefs returns the current number of references. The returned count is
+// inherently racy and is unsafe to use without external synchronization.
+func (r *socketOpsCommonRefs) ReadRefs() int64 {
+
+ return atomic.LoadInt64(&r.refCount) + 1
+}
+
+// IncRef implements refs.RefCounter.IncRef.
+//
+//go:nosplit
+func (r *socketOpsCommonRefs) IncRef() {
+ if v := atomic.AddInt64(&r.refCount, 1); v <= 0 {
+ panic("Incrementing non-positive ref count")
+ }
+}
+
+// TryIncRef implements refs.RefCounter.TryIncRef.
+//
+// To do this safely without a loop, a speculative reference is first acquired
+// on the object. This allows multiple concurrent TryIncRef calls to distinguish
+// other TryIncRef calls from genuine references held.
+//
+//go:nosplit
+func (r *socketOpsCommonRefs) TryIncRef() bool {
+ const speculativeRef = 1 << 32
+ v := atomic.AddInt64(&r.refCount, speculativeRef)
+ if int32(v) < 0 {
+
+ atomic.AddInt64(&r.refCount, -speculativeRef)
+ return false
+ }
+
+ atomic.AddInt64(&r.refCount, -speculativeRef+1)
+ return true
+}
+
+// DecRef implements refs.RefCounter.DecRef.
+//
+// Note that speculative references are counted here. Since they were added
+// prior to real references reaching zero, they will successfully convert to
+// real references. In other words, we see speculative references only in the
+// following case:
+//
+// A: TryIncRef [speculative increase => sees non-negative references]
+// B: DecRef [real decrease]
+// A: TryIncRef [transform speculative to real]
+//
+//go:nosplit
+func (r *socketOpsCommonRefs) DecRef(destroy func()) {
+ switch v := atomic.AddInt64(&r.refCount, -1); {
+ case v < -1:
+ panic("Decrementing non-positive ref count")
+
+ case v == -1:
+
+ if destroy != nil {
+ destroy()
+ }
+ }
+}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 2b8454edb..b7e8e4325 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
- "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -80,7 +79,7 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
stype: stype,
},
}
- s.EnableLeakCheck("unix.SocketOperations")
+ s.EnableLeakCheck()
return fs.NewFile(ctx, d, flags, &s)
}
@@ -89,17 +88,26 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
//
// +stateify savable
type socketOpsCommon struct {
- refs.AtomicRefCount
+ socketOpsCommonRefs
socket.SendReceiveTimeout
ep transport.Endpoint
stype linux.SockType
+
+ // abstractName and abstractNamespace indicate the name and namespace of the
+ // socket if it is bound to an abstract socket namespace. Once the socket is
+ // bound, they cannot be modified.
+ abstractName string
+ abstractNamespace *kernel.AbstractSocketNamespace
}
// DecRef implements RefCounter.DecRef.
func (s *socketOpsCommon) DecRef(ctx context.Context) {
- s.DecRefWithDestructor(ctx, func(context.Context) {
+ s.socketOpsCommonRefs.DecRef(func() {
s.ep.Close(ctx)
+ if s.abstractNamespace != nil {
+ s.abstractNamespace.Remove(s.abstractName, s)
+ }
})
}
@@ -284,10 +292,14 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if t.IsNetworkNamespaced() {
return syserr.ErrInvalidEndpointState
}
- if err := t.AbstractSockets().Bind(t, p[1:], bep, s); err != nil {
+ asn := t.AbstractSockets()
+ name := p[1:]
+ if err := asn.Bind(t, name, bep, s); err != nil {
// syserr.ErrPortInUse corresponds to EADDRINUSE.
return syserr.ErrPortInUse
}
+ s.abstractName = name
+ s.abstractNamespace = asn
} else {
// The parent and name.
var d *fs.Dirent
diff --git a/pkg/sentry/socket/unix/unix_state_autogen.go b/pkg/sentry/socket/unix/unix_state_autogen.go
index 4a3bbc11b..6966529c6 100644
--- a/pkg/sentry/socket/unix/unix_state_autogen.go
+++ b/pkg/sentry/socket/unix/unix_state_autogen.go
@@ -6,6 +6,29 @@ import (
"gvisor.dev/gvisor/pkg/state"
)
+func (x *socketOpsCommonRefs) StateTypeName() string {
+ return "pkg/sentry/socket/unix.socketOpsCommonRefs"
+}
+
+func (x *socketOpsCommonRefs) StateFields() []string {
+ return []string{
+ "refCount",
+ }
+}
+
+func (x *socketOpsCommonRefs) beforeSave() {}
+
+func (x *socketOpsCommonRefs) StateSave(m state.Sink) {
+ x.beforeSave()
+ m.Save(0, &x.refCount)
+}
+
+func (x *socketOpsCommonRefs) afterLoad() {}
+
+func (x *socketOpsCommonRefs) StateLoad(m state.Source) {
+ m.Load(0, &x.refCount)
+}
+
func (x *SocketOperations) StateTypeName() string {
return "pkg/sentry/socket/unix.SocketOperations"
}
@@ -35,10 +58,12 @@ func (x *socketOpsCommon) StateTypeName() string {
func (x *socketOpsCommon) StateFields() []string {
return []string{
- "AtomicRefCount",
+ "socketOpsCommonRefs",
"SendReceiveTimeout",
"ep",
"stype",
+ "abstractName",
+ "abstractNamespace",
}
}
@@ -46,22 +71,27 @@ func (x *socketOpsCommon) beforeSave() {}
func (x *socketOpsCommon) StateSave(m state.Sink) {
x.beforeSave()
- m.Save(0, &x.AtomicRefCount)
+ m.Save(0, &x.socketOpsCommonRefs)
m.Save(1, &x.SendReceiveTimeout)
m.Save(2, &x.ep)
m.Save(3, &x.stype)
+ m.Save(4, &x.abstractName)
+ m.Save(5, &x.abstractNamespace)
}
func (x *socketOpsCommon) afterLoad() {}
func (x *socketOpsCommon) StateLoad(m state.Source) {
- m.Load(0, &x.AtomicRefCount)
+ m.Load(0, &x.socketOpsCommonRefs)
m.Load(1, &x.SendReceiveTimeout)
m.Load(2, &x.ep)
m.Load(3, &x.stype)
+ m.Load(4, &x.abstractName)
+ m.Load(5, &x.abstractNamespace)
}
func init() {
+ state.Register((*socketOpsCommonRefs)(nil))
state.Register((*SocketOperations)(nil))
state.Register((*socketOpsCommon)(nil))
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index dfa25241a..d066ef8ab 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -183,10 +183,14 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if t.IsNetworkNamespaced() {
return syserr.ErrInvalidEndpointState
}
- if err := t.AbstractSockets().Bind(t, p[1:], bep, s); err != nil {
+ asn := t.AbstractSockets()
+ name := p[1:]
+ if err := asn.Bind(t, name, bep, s); err != nil {
// syserr.ErrPortInUse corresponds to EADDRINUSE.
return syserr.ErrPortInUse
}
+ s.abstractName = name
+ s.abstractNamespace = asn
} else {
path := fspath.Parse(p)
root := t.FSContext().RootDirectoryVFS2()