diff options
Diffstat (limited to 'pkg/sentry/socket/unix')
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectioned.go | 4 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectionless.go | 30 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 17 |
3 files changed, 28 insertions, 23 deletions
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index e4c416233..73d2df15d 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -143,7 +143,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E } q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} + q1.EnableLeakCheck("transport.queue") q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} + q2.EnableLeakCheck("transport.queue") if stype == linux.SOCK_STREAM { a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} @@ -294,12 +296,14 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn } readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} + readQueue.EnableLeakCheck("transport.queue") ne.connected = &connectedEndpoint{ endpoint: ce, writeQueue: readQueue, } writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} + writeQueue.EnableLeakCheck("transport.queue") if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index e987519f0..c7f7c5b16 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -41,7 +41,9 @@ var ( // NewConnectionless creates a new unbound dgram endpoint. func NewConnectionless(ctx context.Context) Endpoint { ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}} - ep.receiver = &queueReceiver{readQueue: &queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}} + q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} + q.EnableLeakCheck("transport.queue") + ep.receiver = &queueReceiver{readQueue: &q} return ep } @@ -52,29 +54,24 @@ func (e *connectionlessEndpoint) isBound() bool { // Close puts the endpoint in a closed state and frees all resources associated // with it. -// -// The socket will be a fresh state after a call to close and may be reused. -// That is, close may be used to "unbind" or "disconnect" the socket in error -// paths. func (e *connectionlessEndpoint) Close() { e.Lock() - var r Receiver - if e.Connected() { - e.receiver.CloseRecv() - r = e.receiver - e.receiver = nil - + if e.connected != nil { e.connected.Release() e.connected = nil } + if e.isBound() { e.path = "" } + + e.receiver.CloseRecv() + r := e.receiver + e.receiver = nil e.Unlock() - if r != nil { - r.CloseNotify() - r.Release() - } + + r.CloseNotify() + r.Release() } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. @@ -137,6 +134,9 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi } e.Lock() + if e.connected != nil { + e.connected.Release() + } e.connected = connected e.Unlock() diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 6190de0c5..637168714 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -28,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" @@ -69,10 +68,13 @@ func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) // NewWithDirent creates a new unix socket using an existing dirent. func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File { - return fs.NewFile(ctx, d, flags, &SocketOperations{ + s := SocketOperations{ ep: ep, stype: stype, - }) + } + s.EnableLeakCheck("unix.SocketOperations") + + return fs.NewFile(ctx, d, flags, &s) } // DecRef implements RefCounter.DecRef. @@ -108,7 +110,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr) + addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { return "", err } @@ -191,7 +193,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, * // Accept implements the linux syscall accept(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) { +func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) { // Issue the accept request to get the new endpoint. ep, err := s.ep.Accept() if err != nil { @@ -226,10 +228,9 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - fdFlags := kernel.FDFlags{ + fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0, - } - fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits()) + }) if e != nil { return 0, nil, 0, syserr.FromError(e) } |