summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/unix
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/unix')
-rw-r--r--pkg/sentry/socket/unix/BUILD5
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD3
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go35
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go3
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go99
-rw-r--r--pkg/sentry/socket/unix/unix.go15
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go5
7 files changed, 61 insertions, 104 deletions
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cc7408698..cce0acc33 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -8,7 +8,7 @@ go_template_instance(
out = "socket_refs.go",
package = "unix",
prefix = "socketOperations",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SocketOperations",
},
@@ -19,7 +19,7 @@ go_template_instance(
out = "socket_vfs2_refs.go",
package = "unix",
prefix = "socketVFS2",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "SocketVFS2",
},
@@ -43,6 +43,7 @@ go_library(
"//pkg/log",
"//pkg/marshal",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 26c3a51b9..3ebbd28b0 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -20,7 +20,7 @@ go_template_instance(
out = "queue_refs.go",
package = "transport",
prefix = "queue",
- template = "//pkg/refs_vfs2:refs_template",
+ template = "//pkg/refsvfs2:refs_template",
types = {
"T": "queue",
},
@@ -44,6 +44,7 @@ go_library(
"//pkg/ilist",
"//pkg/log",
"//pkg/refs",
+ "//pkg/refsvfs2",
"//pkg/sync",
"//pkg/syserr",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index aa4f3c04d..9f7aca305 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -118,33 +118,29 @@ var (
// NewConnectioned creates a new unbound connectionedEndpoint.
func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint {
- return &connectionedEndpoint{
+ return newConnectioned(ctx, stype, uid)
+}
+
+func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint {
+ ep := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
idGenerator: uid,
stype: stype,
}
+ ep.ops.InitHandler(ep)
+ return ep
}
// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
- a := &connectionedEndpoint{
- baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
- id: uid.UniqueID(),
- idGenerator: uid,
- stype: stype,
- }
- b := &connectionedEndpoint{
- baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
- id: uid.UniqueID(),
- idGenerator: uid,
- stype: stype,
- }
+ a := newConnectioned(ctx, stype, uid)
+ b := newConnectioned(ctx, stype, uid)
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
- q1.EnableLeakCheck()
+ q1.InitRefs()
q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
- q2.EnableLeakCheck()
+ q2.InitRefs()
if stype == linux.SOCK_STREAM {
a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
@@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E
// NewExternal creates a new externally backed Endpoint. It behaves like a
// socketpair.
func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
- return &connectionedEndpoint{
+ ep := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
id: uid.UniqueID(),
idGenerator: uid,
stype: stype,
}
+ ep.ops.InitHandler(ep)
+ return ep
}
// ID implements ConnectingEndpoint.ID.
@@ -298,16 +296,17 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
idGenerator: e.idGenerator,
stype: e.stype,
}
+ ne.ops.InitHandler(ne)
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
- readQueue.EnableLeakCheck()
+ readQueue.InitRefs()
ne.connected = &connectedEndpoint{
endpoint: ce,
writeQueue: readQueue,
}
writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
- writeQueue.EnableLeakCheck()
+ writeQueue.InitRefs()
if e.stype == linux.SOCK_STREAM {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
} else {
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index f8aacca13..0813ad87d 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -42,8 +42,9 @@ var (
func NewConnectionless(ctx context.Context) Endpoint {
ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
- q.EnableLeakCheck()
+ q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
+ ep.ops.InitHandler(ep)
return ep
}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index d6fc03520..099a56281 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -16,8 +16,6 @@
package transport
import (
- "sync/atomic"
-
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
@@ -32,6 +30,8 @@ import (
const initialLimit = 16 * 1024
// A RightsControlMessage is a control message containing FDs.
+//
+// +stateify savable
type RightsControlMessage interface {
// Clone returns a copy of the RightsControlMessage.
Clone() RightsControlMessage
@@ -178,10 +178,6 @@ type Endpoint interface {
// SetSockOpt sets a socket option.
SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error
- // SetSockOptBool sets a socket option for simple cases when a value has
- // the int type.
- SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
-
// SetSockOptInt sets a socket option for simple cases when a value has
// the int type.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
@@ -189,10 +185,6 @@ type Endpoint interface {
// GetSockOpt gets a socket option.
GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error
- // GetSockOptBool gets a socket option for simple cases when a return
- // value has the int type.
- GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
-
// GetSockOptInt gets a socket option for simple cases when a return
// value has the int type.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
@@ -201,8 +193,12 @@ type Endpoint interface {
// procfs.
State() uint32
- // LastError implements tcpip.Endpoint.LastError.
+ // LastError clears and returns the last error reported by the endpoint.
LastError() *tcpip.Error
+
+ // SocketOptions returns the structure which contains all the socket
+ // level options.
+ SocketOptions() *tcpip.SocketOptions
}
// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
@@ -336,7 +332,7 @@ type Receiver interface {
RecvMaxQueueSize() int64
// Release releases any resources owned by the Receiver. It should be
- // called before droping all references to a Receiver.
+ // called before dropping all references to a Receiver.
Release(ctx context.Context)
}
@@ -487,7 +483,7 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds
c := q.control.Clone()
// Don't consume data since we are peeking.
- copied, data, _ = vecCopy(data, q.buffer)
+ copied, _, _ = vecCopy(data, q.buffer)
return copied, copied, c, false, q.addr, notify, nil
}
@@ -572,6 +568,12 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds
return copied, copied, c, cmTruncated, q.addr, notify, nil
}
+// Release implements Receiver.Release.
+func (q *streamQueueReceiver) Release(ctx context.Context) {
+ q.queueReceiver.Release(ctx)
+ q.control.Release(ctx)
+}
+
// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
type ConnectedEndpoint interface {
// Passcred implements Endpoint.Passcred.
@@ -619,7 +621,7 @@ type ConnectedEndpoint interface {
SendMaxQueueSize() int64
// Release releases any resources owned by the ConnectedEndpoint. It should
- // be called before droping all references to a ConnectedEndpoint.
+ // be called before dropping all references to a ConnectedEndpoint.
Release(ctx context.Context)
// CloseUnread sets the fact that this end is closed with unread data to
@@ -728,10 +730,7 @@ func (e *connectedEndpoint) CloseUnread() {
// +stateify savable
type baseEndpoint struct {
*waiter.Queue
-
- // passcred specifies whether SCM_CREDENTIALS socket control messages are
- // enabled on this endpoint. Must be accessed atomically.
- passcred int32
+ tcpip.DefaultSocketOptionsHandler
// Mutex protects the below fields.
sync.Mutex `state:"nosave"`
@@ -747,8 +746,8 @@ type baseEndpoint struct {
// or may be used if the endpoint is connected.
path string
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
+ // ops is used to get socket level options.
+ ops tcpip.SocketOptions
}
// EventRegister implements waiter.Waitable.EventRegister.
@@ -773,7 +772,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
// Passcred implements Credentialer.Passcred.
func (e *baseEndpoint) Passcred() bool {
- return atomic.LoadInt32(&e.passcred) != 0
+ return e.SocketOptions().GetPassCred()
}
// ConnectedPasscred implements Credentialer.ConnectedPasscred.
@@ -783,14 +782,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool {
return e.connected != nil && e.connected.Passcred()
}
-func (e *baseEndpoint) setPasscred(pc bool) {
- if pc {
- atomic.StoreInt32(&e.passcred, 1)
- } else {
- atomic.StoreInt32(&e.passcred, 0)
- }
-}
-
// Connected implements ConnectingEndpoint.Connected.
func (e *baseEndpoint) Connected() bool {
return e.receiver != nil && e.connected != nil
@@ -846,24 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option.
func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- e.linger = *v
- e.Unlock()
- }
- return nil
-}
-
-func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
- switch opt {
- case tcpip.BroadcastOption:
- case tcpip.PasscredOption:
- e.setPasscred(v)
- case tcpip.ReuseAddressOption:
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- }
return nil
}
@@ -877,20 +850,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
return nil
}
-func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- switch opt {
- case tcpip.KeepaliveEnabledOption:
- return false, nil
-
- case tcpip.PasscredOption:
- return e.Passcred(), nil
-
- default:
- log.Warningf("Unsupported socket option: %d", opt)
- return false, tcpip.ErrUnknownProtocolOption
- }
-}
-
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -954,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- *o = e.linger
- e.Unlock()
- return nil
-
- default:
- log.Warningf("Unsupported socket option: %T", opt)
- return tcpip.ErrUnknownProtocolOption
- }
+ log.Warningf("Unsupported socket option: %T", opt)
+ return tcpip.ErrUnknownProtocolOption
}
// LastError implements Endpoint.LastError.
@@ -972,6 +922,11 @@ func (*baseEndpoint) LastError() *tcpip.Error {
return nil
}
+// SocketOptions implements Endpoint.SocketOptions.
+func (e *baseEndpoint) SocketOptions() *tcpip.SocketOptions {
+ return &e.ops
+}
+
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error {
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index a4a76d0a3..c59297c80 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -80,8 +80,7 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
stype: stype,
},
}
- s.EnableLeakCheck()
-
+ s.InitRefs()
return fs.NewFile(ctx, d, flags, &s)
}
@@ -137,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, family, err := netstack.AddressAndFamily(sockaddr)
+ addr, family, err := socket.AddressAndFamily(sockaddr)
if err != nil {
if err == syserr.ErrAddressFamilyNotSupported {
err = syserr.ErrInvalidArgument
@@ -170,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -182,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := socket.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -256,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -648,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
@@ -683,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
- from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index 678355fb9..27f705bb2 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -55,7 +55,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{})
// returns a corresponding file description.
func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
mnt := t.Kernel().SocketMount()
- d := sockfs.NewDentry(t.Credentials(), mnt)
+ d := sockfs.NewDentry(t, mnt)
defer d.DecRef(t)
fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
@@ -80,6 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
stype: stype,
},
}
+ sock.InitRefs()
sock.LockFD.Init(locks)
vfsfd := &sock.vfsfd
if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
@@ -171,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
if peerAddr != nil {
- addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
+ addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{