summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/unix
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/unix')
-rw-r--r--pkg/sentry/socket/unix/BUILD3
-rw-r--r--pkg/sentry/socket/unix/device.go2
-rw-r--r--pkg/sentry/socket/unix/io.go15
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD3
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go36
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go62
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go8
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go25
-rw-r--r--pkg/sentry/socket/unix/unix.go79
9 files changed, 133 insertions, 100 deletions
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index fe6871cc6..da9977fde 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -9,7 +9,7 @@ go_library(
"io.go",
"unix.go",
],
- importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix",
+ importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix",
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
@@ -20,7 +20,6 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
- "//pkg/sentry/kernel/kdefs",
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
diff --git a/pkg/sentry/socket/unix/device.go b/pkg/sentry/socket/unix/device.go
index 734d39ee6..db01ac4c9 100644
--- a/pkg/sentry/socket/unix/device.go
+++ b/pkg/sentry/socket/unix/device.go
@@ -14,7 +14,7 @@
package unix
-import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+import "gvisor.dev/gvisor/pkg/sentry/device"
// unixSocketDevice is the unix socket virtual device.
var unixSocketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
index 5a1475ec2..760c7beab 100644
--- a/pkg/sentry/socket/unix/io.go
+++ b/pkg/sentry/socket/unix/io.go
@@ -15,15 +15,18 @@
package unix
import (
- "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/tcpip"
)
// EndpointWriter implements safemem.Writer that writes to a transport.Endpoint.
//
// EndpointWriter is not thread-safe.
type EndpointWriter struct {
+ Ctx context.Context
+
// Endpoint is the transport.Endpoint to write to.
Endpoint transport.Endpoint
@@ -37,7 +40,7 @@ type EndpointWriter struct {
// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
func (w *EndpointWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
return safemem.FromVecWriterFunc{func(bufs [][]byte) (int64, error) {
- n, err := w.Endpoint.SendMsg(bufs, w.Control, w.To)
+ n, err := w.Endpoint.SendMsg(w.Ctx, bufs, w.Control, w.To)
if err != nil {
return int64(n), err.ToError()
}
@@ -50,6 +53,8 @@ func (w *EndpointWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error)
//
// EndpointReader is not thread-safe.
type EndpointReader struct {
+ Ctx context.Context
+
// Endpoint is the transport.Endpoint to read from.
Endpoint transport.Endpoint
@@ -81,7 +86,7 @@ type EndpointReader struct {
// ReadToBlocks implements safemem.Reader.ReadToBlocks.
func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
- n, ms, c, ct, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From)
+ n, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, bufs, r.Creds, r.NumRights, r.Peek, r.From)
r.Control = c
r.ControlTrunc = ct
r.MsgSize = ms
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 52f324eed..0b0240336 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -25,12 +25,13 @@ go_library(
"transport_message_list.go",
"unix.go",
],
- importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport",
+ importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport",
visibility = ["//:sandbox"],
deps = [
"//pkg/abi/linux",
"//pkg/ilist",
"//pkg/refs",
+ "//pkg/sentry/context",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index db79ac904..73d2df15d 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -17,10 +17,11 @@ package transport
import (
"sync"
- "gvisor.googlesource.com/gvisor/pkg/abi/linux"
- "gvisor.googlesource.com/gvisor/pkg/syserr"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// UniqueIDProvider generates a sequence of unique identifiers useful for,
@@ -111,8 +112,13 @@ type connectionedEndpoint struct {
acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"`
}
+var (
+ _ = BoundEndpoint((*connectionedEndpoint)(nil))
+ _ = Endpoint((*connectionedEndpoint)(nil))
+)
+
// NewConnectioned creates a new unbound connectionedEndpoint.
-func NewConnectioned(stype linux.SockType, uid UniqueIDProvider) Endpoint {
+func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint {
return &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
@@ -122,7 +128,7 @@ func NewConnectioned(stype linux.SockType, uid UniqueIDProvider) Endpoint {
}
// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
-func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
+func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
a := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
@@ -137,7 +143,9 @@ func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
}
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
+ q1.EnableLeakCheck("transport.queue")
q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
+ q2.EnableLeakCheck("transport.queue")
if stype == linux.SOCK_STREAM {
a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
@@ -163,7 +171,7 @@ func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
// NewExternal creates a new externally backed Endpoint. It behaves like a
// socketpair.
-func NewExternal(stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
+func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
return &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
id: uid.UniqueID(),
@@ -238,7 +246,7 @@ func (e *connectionedEndpoint) Close() {
}
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
-func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
+func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
if ce.Type() != e.stype {
return syserr.ErrConnectionRefused
}
@@ -288,12 +296,14 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur
}
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
+ readQueue.EnableLeakCheck("transport.queue")
ne.connected = &connectedEndpoint{
endpoint: ce,
writeQueue: readQueue,
}
writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
+ writeQueue.EnableLeakCheck("transport.queue")
if e.stype == linux.SOCK_STREAM {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
} else {
@@ -334,19 +344,19 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur
}
// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
-func (e *connectionedEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error) {
+func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
return nil, syserr.ErrConnectionRefused
}
// Connect attempts to directly connect to another Endpoint.
// Implements Endpoint.Connect.
-func (e *connectionedEndpoint) Connect(server BoundEndpoint) *syserr.Error {
+func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
returnConnect := func(r Receiver, ce ConnectedEndpoint) {
e.receiver = r
e.connected = ce
}
- return server.BidirectionalConnect(e, returnConnect)
+ return server.BidirectionalConnect(ctx, e, returnConnect)
}
// Listen starts listening on the connection.
@@ -426,13 +436,13 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
-func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
// Stream sockets do not support specifying the endpoint. Seqpacket
// sockets ignore the passed endpoint.
if e.stype == linux.SOCK_STREAM && to != nil {
return 0, syserr.ErrNotSupported
}
- return e.baseEndpoint.SendMsg(data, c, to)
+ return e.baseEndpoint.SendMsg(ctx, data, c, to)
}
// Readiness returns the current readiness of the connectionedEndpoint. For
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 81ebfba10..c7f7c5b16 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -15,14 +15,15 @@
package transport
import (
- "gvisor.googlesource.com/gvisor/pkg/abi/linux"
- "gvisor.googlesource.com/gvisor/pkg/syserr"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// connectionlessEndpoint is a unix endpoint for unix sockets that support operating in
-// a conectionless fashon.
+// a connectionless fashon.
//
// Specifically, this means datagram unix sockets not created with
// socketpair(2).
@@ -32,10 +33,17 @@ type connectionlessEndpoint struct {
baseEndpoint
}
+var (
+ _ = BoundEndpoint((*connectionlessEndpoint)(nil))
+ _ = Endpoint((*connectionlessEndpoint)(nil))
+)
+
// NewConnectionless creates a new unbound dgram endpoint.
-func NewConnectionless() Endpoint {
+func NewConnectionless(ctx context.Context) Endpoint {
ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
- ep.receiver = &queueReceiver{readQueue: &queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}}
+ q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
+ q.EnableLeakCheck("transport.queue")
+ ep.receiver = &queueReceiver{readQueue: &q}
return ep
}
@@ -46,38 +54,33 @@ func (e *connectionlessEndpoint) isBound() bool {
// Close puts the endpoint in a closed state and frees all resources associated
// with it.
-//
-// The socket will be a fresh state after a call to close and may be reused.
-// That is, close may be used to "unbind" or "disconnect" the socket in error
-// paths.
func (e *connectionlessEndpoint) Close() {
e.Lock()
- var r Receiver
- if e.Connected() {
- e.receiver.CloseRecv()
- r = e.receiver
- e.receiver = nil
-
+ if e.connected != nil {
e.connected.Release()
e.connected = nil
}
+
if e.isBound() {
e.path = ""
}
+
+ e.receiver.CloseRecv()
+ r := e.receiver
+ e.receiver = nil
e.Unlock()
- if r != nil {
- r.CloseNotify()
- r.Release()
- }
+
+ r.CloseNotify()
+ r.Release()
}
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
-func (e *connectionlessEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
+func (e *connectionlessEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
return syserr.ErrConnectionRefused
}
// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
-func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error) {
+func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
e.Lock()
r := e.receiver
e.Unlock()
@@ -96,12 +99,12 @@ func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *sy
// SendMsg writes data and a control message to the specified endpoint.
// This method does not block if the data cannot be written.
-func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
if to == nil {
- return e.baseEndpoint.SendMsg(data, c, nil)
+ return e.baseEndpoint.SendMsg(ctx, data, c, nil)
}
- connected, err := to.UnidirectionalConnect()
+ connected, err := to.UnidirectionalConnect(ctx)
if err != nil {
return 0, syserr.ErrInvalidEndpointState
}
@@ -124,13 +127,16 @@ func (e *connectionlessEndpoint) Type() linux.SockType {
}
// Connect attempts to connect directly to server.
-func (e *connectionlessEndpoint) Connect(server BoundEndpoint) *syserr.Error {
- connected, err := server.UnidirectionalConnect()
+func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
+ connected, err := server.UnidirectionalConnect(ctx)
if err != nil {
return err
}
e.Lock()
+ if e.connected != nil {
+ e.connected.Release()
+ }
e.connected = connected
e.Unlock()
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index b650caae7..0415fae9a 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -17,9 +17,9 @@ package transport
import (
"sync"
- "gvisor.googlesource.com/gvisor/pkg/refs"
- "gvisor.googlesource.com/gvisor/pkg/syserr"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// queue is a buffer queue.
@@ -100,7 +100,7 @@ func (q *queue) IsWritable() bool {
// Enqueue adds an entry to the data queue if room is available.
//
-// If truncate is true, Enqueue may truncate the message beforing enqueuing it.
+// If truncate is true, Enqueue may truncate the message before enqueuing it.
// Otherwise, the entire message must fit. If n < e.Length(), err indicates why.
//
// If notify is true, ReaderQueue.Notify must be called:
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 5c55c529e..b0765ba55 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -19,11 +19,12 @@ import (
"sync"
"sync/atomic"
- "gvisor.googlesource.com/gvisor/pkg/abi/linux"
- "gvisor.googlesource.com/gvisor/pkg/syserr"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// initialLimit is the starting limit for the socket buffers.
@@ -120,13 +121,13 @@ type Endpoint interface {
// CMTruncated indicates that the numRights hint was used to receive fewer
// than the total available SCM_RIGHTS FDs. Additional truncation may be
// required by the caller.
- RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error)
+ RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error)
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
//
// SendMsg does not take ownership of any of its arguments on error.
- SendMsg([][]byte, ControlMessages, BoundEndpoint) (uintptr, *syserr.Error)
+ SendMsg(context.Context, [][]byte, ControlMessages, BoundEndpoint) (uintptr, *syserr.Error)
// Connect connects this endpoint directly to another.
//
@@ -134,7 +135,7 @@ type Endpoint interface {
// endpoint passed in as a parameter.
//
// The error codes are the same as Connect.
- Connect(server BoundEndpoint) *syserr.Error
+ Connect(ctx context.Context, server BoundEndpoint) *syserr.Error
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
@@ -215,7 +216,7 @@ type BoundEndpoint interface {
//
// This method will return syserr.ErrConnectionRefused on endpoints with a
// type that isn't SockStream or SockSeqpacket.
- BidirectionalConnect(ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error
+ BidirectionalConnect(ctx context.Context, ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error
// UnidirectionalConnect establishes a write-only connection to a unix
// endpoint.
@@ -225,7 +226,7 @@ type BoundEndpoint interface {
//
// This method will return syserr.ErrConnectionRefused on a non-SockDgram
// endpoint.
- UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error)
+ UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error)
// Passcred returns whether or not the SO_PASSCRED socket option is
// enabled on this end.
@@ -776,7 +777,7 @@ func (e *baseEndpoint) Connected() bool {
}
// RecvMsg reads data and a control message from the endpoint.
-func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) {
+func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) {
e.Lock()
if e.receiver == nil {
@@ -802,7 +803,7 @@ func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, pee
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
-func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
e.Lock()
if !e.Connected() {
e.Unlock()
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index b07e8d67b..eb262ecaf 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -21,24 +21,23 @@ import (
"strings"
"syscall"
- "gvisor.googlesource.com/gvisor/pkg/abi/linux"
- "gvisor.googlesource.com/gvisor/pkg/refs"
- "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
- "gvisor.googlesource.com/gvisor/pkg/sentry/context"
- "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
- "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
- "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
- "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
- ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/epsocket"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
- "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
- "gvisor.googlesource.com/gvisor/pkg/syserr"
- "gvisor.googlesource.com/gvisor/pkg/syserror"
- "gvisor.googlesource.com/gvisor/pkg/tcpip"
- "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/epsocket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/waiter"
)
// SocketOperations is a Unix socket. It is similar to an epsocket, except it
@@ -64,15 +63,24 @@ type SocketOperations struct {
func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File {
dirent := socket.NewDirent(ctx, unixSocketDevice)
defer dirent.DecRef()
- return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true})
+ return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true})
}
// NewWithDirent creates a new unix socket using an existing dirent.
func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File {
- return fs.NewFile(ctx, d, flags, &SocketOperations{
+ // You can create AF_UNIX, SOCK_RAW sockets. They're the same as
+ // SOCK_DGRAM and don't require CAP_NET_RAW.
+ if stype == linux.SOCK_RAW {
+ stype = linux.SOCK_DGRAM
+ }
+
+ s := SocketOperations{
ep: ep,
stype: stype,
- })
+ }
+ s.EnableLeakCheck("unix.SocketOperations")
+
+ return fs.NewFile(ctx, d, flags, &s)
}
// DecRef implements RefCounter.DecRef.
@@ -108,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr)
+ addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
return "", err
}
@@ -152,7 +160,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *sy
}
// Ioctl implements fs.FileOperations.Ioctl.
-func (s *SocketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
return epsocket.Ioctl(ctx, s.ep, io, args)
}
@@ -191,7 +199,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
// Accept implements the linux syscall accept(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
ep, err := s.ep.Accept()
if err != nil {
@@ -226,10 +234,9 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- fdFlags := kernel.FDFlags{
+ fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
- }
- fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
+ })
if e != nil {
return 0, nil, 0, syserr.FromError(e)
}
@@ -363,7 +370,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
defer ep.Release()
// Connect the server endpoint.
- return s.ep.Connect(ep)
+ return s.ep.Connect(t, ep)
}
// Writev implements fs.FileOperations.Write.
@@ -372,11 +379,12 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
ctrl := control.New(t, s.ep, nil)
if src.NumBytes() == 0 {
- nInt, err := s.ep.SendMsg([][]byte{}, ctrl, nil)
+ nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
return int64(nInt), err.ToError()
}
return src.CopyInTo(ctx, &EndpointWriter{
+ Ctx: ctx,
Endpoint: s.ep,
Control: ctrl,
To: nil,
@@ -387,6 +395,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
// a transport.Endpoint.
func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
w := EndpointWriter{
+ Ctx: t,
Endpoint: s.ep,
Control: controlMessages.Unix,
To: nil,
@@ -486,6 +495,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
return 0, nil
}
return dst.CopyOutFrom(ctx, &EndpointReader{
+ Ctx: ctx,
Endpoint: s.ep,
NumRights: 0,
Peek: false,
@@ -522,6 +532,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
r := EndpointReader{
+ Ctx: t,
Endpoint: s.ep,
Creds: wantCreds,
NumRights: uintptr(numRights),
@@ -634,10 +645,10 @@ func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs
// Create the endpoint and socket.
var ep transport.Endpoint
switch stype {
- case linux.SOCK_DGRAM:
- ep = transport.NewConnectionless()
+ case linux.SOCK_DGRAM, linux.SOCK_RAW:
+ ep = transport.NewConnectionless(t)
case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
- ep = transport.NewConnectioned(stype, t.Kernel())
+ ep = transport.NewConnectioned(t, stype, t.Kernel())
default:
return nil, syserr.ErrInvalidArgument
}
@@ -653,14 +664,14 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F
}
switch stype {
- case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
// Ok
default:
return nil, nil, syserr.ErrInvalidArgument
}
// Create the endpoints and sockets.
- ep1, ep2 := transport.NewPair(stype, t.Kernel())
+ ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
s1 := New(t, ep1, stype)
s2 := New(t, ep2, stype)