summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
authorDean Deng <deandeng@google.com>2020-08-17 11:40:08 -0700
committergVisor bot <gvisor-bot@google.com>2020-08-17 11:42:20 -0700
commit3bd066d5032c297e501f5c71be301ffa2fe9ed34 (patch)
treef689175cdbabf2ca52c76d111fafe0e3450f5a0c /pkg/sentry/socket
parent58154194b3e472cf476be6bf0530bf271d1d29f8 (diff)
Remove weak references from unix sockets.
The abstract socket namespace no longer holds any references on sockets. Instead, TryIncRef() is used when a socket is being retrieved in BoundEndpoint(). Abstract sockets are now responsible for removing themselves from the namespace they are in, when they are destroyed. Updates #1486. PiperOrigin-RevId: 327064173
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/unix/BUILD14
-rw-r--r--pkg/sentry/socket/unix/unix.go22
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go6
3 files changed, 36 insertions, 6 deletions
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index 061a689a9..cb953e4dc 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -1,12 +1,25 @@
load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
package(licenses = ["notice"])
+go_template_instance(
+ name = "socket_refs",
+ out = "socket_refs.go",
+ package = "unix",
+ prefix = "socketOpsCommon",
+ template = "//pkg/refs_vfs2:refs_template",
+ types = {
+ "T": "socketOpsCommon",
+ },
+)
+
go_library(
name = "unix",
srcs = [
"device.go",
"io.go",
+ "socket_refs.go",
"unix.go",
"unix_vfs2.go",
],
@@ -15,6 +28,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/fspath",
+ "//pkg/log",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 2b8454edb..b7e8e4325 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -24,7 +24,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
- "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -80,7 +79,7 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
stype: stype,
},
}
- s.EnableLeakCheck("unix.SocketOperations")
+ s.EnableLeakCheck()
return fs.NewFile(ctx, d, flags, &s)
}
@@ -89,17 +88,26 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
//
// +stateify savable
type socketOpsCommon struct {
- refs.AtomicRefCount
+ socketOpsCommonRefs
socket.SendReceiveTimeout
ep transport.Endpoint
stype linux.SockType
+
+ // abstractName and abstractNamespace indicate the name and namespace of the
+ // socket if it is bound to an abstract socket namespace. Once the socket is
+ // bound, they cannot be modified.
+ abstractName string
+ abstractNamespace *kernel.AbstractSocketNamespace
}
// DecRef implements RefCounter.DecRef.
func (s *socketOpsCommon) DecRef(ctx context.Context) {
- s.DecRefWithDestructor(ctx, func(context.Context) {
+ s.socketOpsCommonRefs.DecRef(func() {
s.ep.Close(ctx)
+ if s.abstractNamespace != nil {
+ s.abstractNamespace.Remove(s.abstractName, s)
+ }
})
}
@@ -284,10 +292,14 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if t.IsNetworkNamespaced() {
return syserr.ErrInvalidEndpointState
}
- if err := t.AbstractSockets().Bind(t, p[1:], bep, s); err != nil {
+ asn := t.AbstractSockets()
+ name := p[1:]
+ if err := asn.Bind(t, name, bep, s); err != nil {
// syserr.ErrPortInUse corresponds to EADDRINUSE.
return syserr.ErrPortInUse
}
+ s.abstractName = name
+ s.abstractNamespace = asn
} else {
// The parent and name.
var d *fs.Dirent
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index dfa25241a..d066ef8ab 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -183,10 +183,14 @@ func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
if t.IsNetworkNamespaced() {
return syserr.ErrInvalidEndpointState
}
- if err := t.AbstractSockets().Bind(t, p[1:], bep, s); err != nil {
+ asn := t.AbstractSockets()
+ name := p[1:]
+ if err := asn.Bind(t, name, bep, s); err != nil {
// syserr.ErrPortInUse corresponds to EADDRINUSE.
return syserr.ErrPortInUse
}
+ s.abstractName = name
+ s.abstractNamespace = asn
} else {
path := fspath.Parse(p)
root := t.FSContext().RootDirectoryVFS2()