diff options
Diffstat (limited to 'pkg/sentry/socket/unix')
-rw-r--r-- | pkg/sentry/socket/unix/BUILD | 29 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/io.go | 17 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/BUILD | 7 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectioned.go | 20 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectionless.go | 14 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/queue.go | 52 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 141 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 215 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix_vfs2.go | 376 |
9 files changed, 710 insertions, 161 deletions
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 5b6a154f6..cb953e4dc 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -1,35 +1,54 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "socket_refs", + out = "socket_refs.go", + package = "unix", + prefix = "socketOpsCommon", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "socketOpsCommon", + }, +) + go_library( name = "unix", srcs = [ "device.go", "io.go", + "socket_refs.go", "unix.go", + "unix_vfs2.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", + "//pkg/fspath", + "//pkg/log", "//pkg/refs", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", + "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", + "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", + "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go index 2ec1a662d..129949990 100644 --- a/pkg/sentry/socket/unix/io.go +++ b/pkg/sentry/socket/unix/io.go @@ -15,8 +15,8 @@ package unix import ( - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -83,6 +83,19 @@ type EndpointReader struct { ControlTrunc bool } +// Truncate calls RecvMsg on the endpoint without writing to a destination. +func (r *EndpointReader) Truncate() error { + // Ignore bytes read since it will always be zero. + _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From) + r.Control = c + r.ControlTrunc = ct + r.MsgSize = ms + if err != nil { + return err.ToError() + } + return nil +} + // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) { diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 788ad70d2..c708b6030 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -25,13 +25,14 @@ go_library( "transport_message_list.go", "unix.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/ilist", + "//pkg/log", "//pkg/refs", - "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index dea11e253..c67b602f0 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -15,10 +15,9 @@ package transport import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" @@ -212,7 +211,7 @@ func (e *connectionedEndpoint) Listening() bool { // The socket will be a fresh state after a call to close and may be reused. // That is, close may be used to "unbind" or "disconnect" the socket in error // paths. -func (e *connectionedEndpoint) Close() { +func (e *connectionedEndpoint) Close(ctx context.Context) { e.Lock() var c ConnectedEndpoint var r Receiver @@ -234,7 +233,7 @@ func (e *connectionedEndpoint) Close() { case e.Listening(): close(e.acceptedChan) for n := range e.acceptedChan { - n.Close() + n.Close(ctx) } e.acceptedChan = nil e.path = "" @@ -242,18 +241,18 @@ func (e *connectionedEndpoint) Close() { e.Unlock() if c != nil { c.CloseNotify() - c.Release() + c.Release(ctx) } if r != nil { r.CloseNotify() - r.Release() + r.Release(ctx) } } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { if ce.Type() != e.stype { - return syserr.ErrConnectionRefused + return syserr.ErrWrongProtocolForSocket } // Check if ce is e to avoid a deadlock. @@ -341,7 +340,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn return nil default: // Busy; return ECONNREFUSED per spec. - ne.Close() + ne.Close(ctx) e.Unlock() ce.Unlock() return syserr.ErrConnectionRefused @@ -477,6 +476,9 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask // State implements socket.Socket.State. func (e *connectionedEndpoint) State() uint32 { + e.Lock() + defer e.Unlock() + if e.Connected() { return linux.SS_CONNECTED } diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 0322dec0b..70ee8f9b8 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -16,7 +16,7 @@ package transport import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" @@ -54,10 +54,10 @@ func (e *connectionlessEndpoint) isBound() bool { // Close puts the endpoint in a closed state and frees all resources associated // with it. -func (e *connectionlessEndpoint) Close() { +func (e *connectionlessEndpoint) Close(ctx context.Context) { e.Lock() if e.connected != nil { - e.connected.Release() + e.connected.Release(ctx) e.connected = nil } @@ -71,7 +71,7 @@ func (e *connectionlessEndpoint) Close() { e.Unlock() r.CloseNotify() - r.Release() + r.Release(ctx) } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. @@ -108,10 +108,10 @@ func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c C if err != nil { return 0, syserr.ErrInvalidEndpointState } - defer connected.Release() + defer connected.Release(ctx) e.Lock() - n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + n, notify, err := connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() if notify { @@ -135,7 +135,7 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi e.Lock() if e.connected != nil { - e.connected.Release() + e.connected.Release(ctx) } e.connected = connected e.Unlock() diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index e27b1c714..ef6043e19 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -15,10 +15,12 @@ package transport import ( - "sync" - + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/waiter" ) @@ -56,10 +58,10 @@ func (q *queue) Close() { // Both the read and write queues must be notified after resetting: // q.ReaderQueue.Notify(waiter.EventIn) // q.WriterQueue.Notify(waiter.EventOut) -func (q *queue) Reset() { +func (q *queue) Reset(ctx context.Context) { q.mu.Lock() for cur := q.dataList.Front(); cur != nil; cur = cur.Next() { - cur.Release() + cur.Release(ctx) } q.dataList.Reset() q.used = 0 @@ -67,8 +69,8 @@ func (q *queue) Reset() { } // DecRef implements RefCounter.DecRef with destructor q.Reset. -func (q *queue) DecRef() { - q.DecRefWithDestructor(q.Reset) +func (q *queue) DecRef(ctx context.Context) { + q.DecRefWithDestructor(ctx, q.Reset) // We don't need to notify after resetting because no one cares about // this queue after all references have been dropped. } @@ -101,12 +103,16 @@ func (q *queue) IsWritable() bool { // Enqueue adds an entry to the data queue if room is available. // +// If discardEmpty is true and there are zero bytes of data, the packet is +// dropped. +// // If truncate is true, Enqueue may truncate the message before enqueuing it. -// Otherwise, the entire message must fit. If n < e.Length(), err indicates why. +// Otherwise, the entire message must fit. If l is less than the size of data, +// err indicates why. // // If notify is true, ReaderQueue.Notify must be called: // q.ReaderQueue.Notify(waiter.EventIn) -func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *syserr.Error) { +func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) { q.mu.Lock() if q.closed { @@ -114,9 +120,16 @@ func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *s return 0, false, syserr.ErrClosedForSend } - free := q.limit - q.used + for _, d := range data { + l += int64(len(d)) + } + if discardEmpty && l == 0 { + q.mu.Unlock() + c.Release(ctx) + return 0, false, nil + } - l = e.Length() + free := q.limit - q.used if l > free && truncate { if free == 0 { @@ -125,8 +138,7 @@ func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *s return 0, false, syserr.ErrWouldBlock } - e.Truncate(free) - l = e.Length() + l = free err = syserr.ErrWouldBlock } @@ -137,14 +149,26 @@ func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *s } if l > free { - // Message can't fit right now. + // Message can't fit right now, and could not be truncated. q.mu.Unlock() return 0, false, syserr.ErrWouldBlock } + // Aggregate l bytes of data. This will truncate the data if l is less than + // the total bytes held in data. + v := make([]byte, l) + for i, b := 0, v; i < len(data) && len(b) > 0; i++ { + n := copy(b, data[i]) + b = b[n:] + } + notify = q.dataList.Front() == nil q.used += l - q.dataList.PushBack(e) + q.dataList.PushBack(&message{ + Data: buffer.View(v), + Control: c, + Address: from, + }) q.mu.Unlock() diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 529a7a7a9..475d7177e 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,11 +16,12 @@ package transport import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -36,7 +37,7 @@ type RightsControlMessage interface { Clone() RightsControlMessage // Release releases any resources owned by the RightsControlMessage. - Release() + Release(ctx context.Context) } // A CredentialsControlMessage is a control message containing Unix credentials. @@ -73,9 +74,9 @@ func (c *ControlMessages) Clone() ControlMessages { } // Release releases both the credentials and the rights. -func (c *ControlMessages) Release() { +func (c *ControlMessages) Release(ctx context.Context) { if c.Rights != nil { - c.Rights.Release() + c.Rights.Release(ctx) } *c = ControlMessages{} } @@ -89,7 +90,7 @@ type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources // associated with it. - Close() + Close(ctx context.Context) // RecvMsg reads data and a control message from the endpoint. This method // does not block if there is no data pending. @@ -175,17 +176,25 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptBool sets a socket option for simple cases when a value has + // the int type. + SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has // the int type. - SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error + // GetSockOptBool gets a socket option for simple cases when a return + // value has the int type. + GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) + // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. - GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) // State returns the current state of the socket, as represented by Linux in // procfs. @@ -243,7 +252,7 @@ type BoundEndpoint interface { // Release releases any resources held by the BoundEndpoint. It must be // called before dropping all references to a BoundEndpoint returned by a // function. - Release() + Release(ctx context.Context) } // message represents a message passed over a Unix domain socket. @@ -272,8 +281,8 @@ func (m *message) Length() int64 { } // Release releases any resources held by the message. -func (m *message) Release() { - m.Control.Release() +func (m *message) Release(ctx context.Context) { + m.Control.Release(ctx) } // Peek returns a copy of the message. @@ -295,7 +304,7 @@ type Receiver interface { // See Endpoint.RecvMsg for documentation on shared arguments. // // notify indicates if RecvNotify should be called. - Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) + Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) // RecvNotify notifies the Receiver of a successful Recv. This must not be // called while holding any endpoint locks. @@ -324,7 +333,7 @@ type Receiver interface { // Release releases any resources owned by the Receiver. It should be // called before droping all references to a Receiver. - Release() + Release(ctx context.Context) } // queueReceiver implements Receiver for datagram sockets. @@ -335,7 +344,7 @@ type queueReceiver struct { } // Recv implements Receiver.Recv. -func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *queueReceiver) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { var m *message var notify bool var err *syserr.Error @@ -389,8 +398,8 @@ func (q *queueReceiver) RecvMaxQueueSize() int64 { } // Release implements Receiver.Release. -func (q *queueReceiver) Release() { - q.readQueue.DecRef() +func (q *queueReceiver) Release(ctx context.Context) { + q.readQueue.DecRef(ctx) } // streamQueueReceiver implements Receiver for stream sockets. @@ -447,7 +456,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 { } // Recv implements Receiver.Recv. -func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { q.mu.Lock() defer q.mu.Unlock() @@ -493,7 +502,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, var cmTruncated bool if c.Rights != nil && numRights == 0 { - c.Rights.Release() + c.Rights.Release(ctx) c.Rights = nil cmTruncated = true } @@ -548,7 +557,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, // Consume rights. if numRights == 0 { cmTruncated = true - q.control.Rights.Release() + q.control.Rights.Release(ctx) } else { c.Rights = q.control.Rights haveRights = true @@ -573,7 +582,7 @@ type ConnectedEndpoint interface { // // syserr.ErrWouldBlock can be returned along with a partial write if // the caller should block to send the rest of the data. - Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) + Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) // SendNotify notifies the ConnectedEndpoint of a successful Send. This // must not be called while holding any endpoint locks. @@ -607,7 +616,7 @@ type ConnectedEndpoint interface { // Release releases any resources owned by the ConnectedEndpoint. It should // be called before droping all references to a ConnectedEndpoint. - Release() + Release(ctx context.Context) // CloseUnread sets the fact that this end is closed with unread data to // the peer socket. @@ -645,35 +654,22 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) } // Send implements ConnectedEndpoint.Send. -func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { - var l int64 - for _, d := range data { - l += int64(len(d)) - } - +func (e *connectedEndpoint) Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { + discardEmpty := false truncate := false if e.endpoint.Type() == linux.SOCK_STREAM { - // Since stream sockets don't preserve message boundaries, we - // can write only as much of the message as fits in the queue. - truncate = true - // Discard empty stream packets. Since stream sockets don't // preserve message boundaries, sending zero bytes is a no-op. // In Linux, the receiver actually uses a zero-length receive // as an indication that the stream was closed. - if l == 0 { - controlMessages.Release() - return 0, false, nil - } - } + discardEmpty = true - v := make([]byte, 0, l) - for _, d := range data { - v = append(v, d...) + // Since stream sockets don't preserve message boundaries, we + // can write only as much of the message as fits in the queue. + truncate = true } - l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate) - return int64(l), notify, err + return e.writeQueue.Enqueue(ctx, data, c, from, discardEmpty, truncate) } // SendNotify implements ConnectedEndpoint.SendNotify. @@ -711,8 +707,8 @@ func (e *connectedEndpoint) SendMaxQueueSize() int64 { } // Release implements ConnectedEndpoint.Release. -func (e *connectedEndpoint) Release() { - e.writeQueue.DecRef() +func (e *connectedEndpoint) Release(ctx context.Context) { + e.writeQueue.DecRef(ctx) } // CloseUnread implements ConnectedEndpoint.CloseUnread. @@ -802,7 +798,7 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, n return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected } - recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) + recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(ctx, data, creds, numRights, peek) e.Unlock() if err != nil { return 0, 0, ControlMessages{}, false, err @@ -831,7 +827,7 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess return 0, syserr.ErrAlreadyConnected } - n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + n, notify, err := e.connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() if notify { @@ -843,19 +839,46 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. Currently not supported. func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { + return nil +} + +func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.BroadcastOption: case tcpip.PasscredOption: - e.setPasscred(v != 0) - return nil + e.setPasscred(v) + case tcpip.ReuseAddressOption: + default: + log.Warningf("Unsupported socket option: %d", opt) } return nil } -func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.SendBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: + default: + log.Warningf("Unsupported socket option: %d", opt) + } return nil } -func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.PasscredOption: + return e.Passcred(), nil + + default: + log.Warningf("Unsupported socket option: %d", opt) + return false, tcpip.ErrUnknownProtocolOption + } +} + +func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -911,29 +934,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return int(v), nil default: + log.Warningf("Unsupported socket option: %d", opt) return -1, tcpip.ErrUnknownProtocolOption } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) - } - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: + log.Warningf("Unsupported socket option: %T", opt) return tcpip.ErrUnknownProtocolOption } } @@ -988,6 +1001,6 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } // Release implements BoundEndpoint.Release. -func (*baseEndpoint) Release() { +func (*baseEndpoint) Release(context.Context) { // Binding a baseEndpoint doesn't take a reference. } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 1aaae8487..b7e8e4325 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -22,9 +22,9 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -33,11 +33,13 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -52,17 +54,14 @@ type SocketOperations struct { fsutil.FileNoSplice `state:"nosave"` fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` - refs.AtomicRefCount - socket.SendReceiveTimeout - ep transport.Endpoint - stype linux.SockType + socketOpsCommon } // New creates a new unix socket. func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true}) } @@ -75,29 +74,51 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty } s := SocketOperations{ - ep: ep, - stype: stype, + socketOpsCommon: socketOpsCommon{ + ep: ep, + stype: stype, + }, } - s.EnableLeakCheck("unix.SocketOperations") + s.EnableLeakCheck() return fs.NewFile(ctx, d, flags, &s) } +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { + socketOpsCommonRefs + socket.SendReceiveTimeout + + ep transport.Endpoint + stype linux.SockType + + // abstractName and abstractNamespace indicate the name and namespace of the + // socket if it is bound to an abstract socket namespace. Once the socket is + // bound, they cannot be modified. + abstractName string + abstractNamespace *kernel.AbstractSocketNamespace +} + // DecRef implements RefCounter.DecRef. -func (s *SocketOperations) DecRef() { - s.DecRefWithDestructor(func() { - s.ep.Close() +func (s *socketOpsCommon) DecRef(ctx context.Context) { + s.socketOpsCommonRefs.DecRef(func() { + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } }) } // Release implemements fs.FileOperations.Release. -func (s *SocketOperations) Release() { +func (s *socketOpsCommon) Release(ctx context.Context) { // Release only decrements a reference on s because s may be referenced in // the abstract socket namespace. - s.DecRef() + s.DecRef(ctx) } -func (s *SocketOperations) isPacket() bool { +func (s *socketOpsCommon) isPacket() bool { switch s.stype { case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: return true @@ -110,16 +131,22 @@ func (s *SocketOperations) isPacket() bool { } // Endpoint extracts the transport.Endpoint. -func (s *SocketOperations) Endpoint() transport.Endpoint { +func (s *socketOpsCommon) Endpoint() transport.Endpoint { return s.ep } // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) + addr, family, err := netstack.AddressAndFamily(sockaddr) if err != nil { + if err == syserr.ErrAddressFamilyNotSupported { + err = syserr.ErrInvalidArgument + } return "", err } + if family != linux.AF_UNIX { + return "", syserr.ErrInvalidArgument + } // The address is trimmed by GetAddress. p := string(addr.Addr) @@ -137,7 +164,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) { // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -149,7 +176,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, // GetSockName implements the linux syscall getsockname(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -166,13 +193,13 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } // Listen implements the linux syscall listen(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return s.ep.Listen(backlog) } @@ -215,7 +242,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } ns := New(t, ep, s.stype) - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() @@ -265,17 +292,21 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if t.IsNetworkNamespaced() { return syserr.ErrInvalidEndpointState } - if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + asn := t.AbstractSockets() + name := p[1:] + if err := asn.Bind(t, name, bep, s); err != nil { // syserr.ErrPortInUse corresponds to EADDRINUSE. return syserr.ErrPortInUse } + s.abstractName = name + s.abstractNamespace = asn } else { // The parent and name. var d *fs.Dirent var name string cwd := t.FSContext().WorkingDirectory() - defer cwd.DecRef() + defer cwd.DecRef(t) // Is there no slash at all? if !strings.Contains(p, "/") { @@ -283,7 +314,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { name = p } else { root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) // Find the last path component, we know that something follows // that final slash, otherwise extractPath() would have failed. lastSlash := strings.LastIndex(p, "/") @@ -299,16 +330,21 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // No path available. return syserr.ErrNoSuchFile } - defer d.DecRef() + defer d.DecRef(t) name = p[lastSlash+1:] } // Create the socket. + // + // Note that the file permissions here are not set correctly (see + // gvisor.dev/issue/2324). There is no convenient way to get permissions + // on the socket referred to by s, so we will leave this discrepancy + // unresolved until VFS2 replaces this code. childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}}) if err != nil { return syserr.ErrPortInUse } - childDir.DecRef() + childDir.DecRef(t) } return nil @@ -339,41 +375,76 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, return ep, nil } + if kernel.VFS2Enabled { + p := fspath.Parse(path) + root := t.FSContext().RootDirectoryVFS2() + start := root + relPath := !p.Absolute + if relPath { + start = t.FSContext().WorkingDirectoryVFS2() + } + pop := vfs.PathOperation{ + Root: root, + Start: start, + Path: p, + FollowFinalSymlink: true, + } + ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path}) + root.DecRef(t) + if relPath { + start.DecRef(t) + } + if e != nil { + return nil, syserr.FromError(e) + } + return ep, nil + } + // Find the node in the filesystem. root := t.FSContext().RootDirectory() cwd := t.FSContext().WorkingDirectory() remainingTraversals := uint(fs.DefaultTraversalLimit) d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals) - cwd.DecRef() - root.DecRef() + cwd.DecRef(t) + root.DecRef(t) if e != nil { return nil, syserr.FromError(e) } // Extract the endpoint if one is there. ep := d.Inode.BoundEndpoint(path) - d.DecRef() + d.DecRef(t) if ep == nil { // No socket! return nil, syserr.ErrConnectionRefused } - return ep, nil } // Connect implements the linux syscall connect(2) for unix sockets. -func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { ep, err := extractEndpoint(t, sockaddr) if err != nil { return err } - defer ep.Release() + defer ep.Release(t) // Connect the server endpoint. - return s.ep.Connect(t, ep) + err = s.ep.Connect(t, ep) + + if err == syserr.ErrWrongProtocolForSocket { + // Linux for abstract sockets returns ErrConnectionRefused + // instead of ErrWrongProtocolForSocket. + path, _ := extractPath(sockaddr) + if len(path) > 0 && path[0] == 0 { + err = syserr.ErrConnectionRefused + } + } + + return err } -// Writev implements fs.FileOperations.Write. +// Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { t := kernel.TaskFromContext(ctx) ctrl := control.New(t, s.ep, nil) @@ -393,7 +464,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO // SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by // a transport.Endpoint. -func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { w := EndpointWriter{ Ctx: t, Endpoint: s.ep, @@ -401,15 +472,25 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] To: nil, } if len(to) > 0 { - ep, err := extractEndpoint(t, to) - if err != nil { - return 0, err - } - defer ep.Release() - w.To = ep + switch s.stype { + case linux.SOCK_SEQPACKET: + to = nil + case linux.SOCK_STREAM: + if s.State() == linux.SS_CONNECTED { + return 0, syserr.ErrAlreadyConnected + } + return 0, syserr.ErrNotSupported + default: + ep, err := extractEndpoint(t, to) + if err != nil { + return 0, err + } + defer ep.Release(t) + w.To = ep - if ep.Passcred() && w.Control.Credentials == nil { - w.Control.Credentials = control.MakeCreds(t) + if ep.Passcred() && w.Control.Credentials == nil { + w.Control.Credentials = control.MakeCreds(t) + } } } @@ -447,27 +528,27 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } // Passcred implements transport.Credentialer.Passcred. -func (s *SocketOperations) Passcred() bool { +func (s *socketOpsCommon) Passcred() bool { return s.ep.Passcred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. -func (s *SocketOperations) ConnectedPasscred() bool { +func (s *socketOpsCommon) ConnectedPasscred() bool { return s.ep.ConnectedPasscred() } // Readiness implements waiter.Waitable.Readiness. -func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return s.ep.Readiness(mask) } // EventRegister implements waiter.Waitable.EventRegister. -func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.ep.EventRegister(e, mask) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (s *SocketOperations) EventUnregister(e *waiter.Entry) { +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } @@ -479,7 +560,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa // Shutdown implements the linux syscall shutdown(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { f, err := netstack.ConvertShutdown(how) if err != nil { return err @@ -505,7 +586,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -541,8 +622,27 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if senderRequested { r.From = &tcpip.FullAddress{} } + + doRead := func() (int64, error) { + return dst.CopyOutFrom(t, &r) + } + + // If MSG_TRUNC is set with a zero byte destination then we still need + // to read the message and discard it, or in the case where MSG_PEEK is + // set, leave it be. In both cases the full message length must be + // returned. + if trunc && dst.Addrs.NumBytes() == 0 { + doRead = func() (int64, error) { + err := r.Truncate() + // Always return zero for bytes read since the destination size is + // zero. + return 0, err + } + + } + var total int64 - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait { + if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait { var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { @@ -577,7 +677,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != syserror.ErrWouldBlock { var from linux.SockAddr var fromLen uint32 if r.From != nil { @@ -623,12 +723,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } // State implements socket.Socket.State. -func (s *SocketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { return s.ep.State() } // Type implements socket.Socket.Type. -func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { // Unix domain sockets always have a protocol of 0. return linux.AF_UNIX, s.stype, 0 } @@ -681,4 +781,5 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F func init() { socket.RegisterProvider(linux.AF_UNIX, &provider{}) + socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{}) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go new file mode 100644 index 000000000..d066ef8ab --- /dev/null +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -0,0 +1,376 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unix + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/control" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" +) + +// SocketVFS2 implements socket.SocketVFS2 (and by extension, +// vfs.FileDescriptionImpl) for Unix sockets. +type SocketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD + + socketOpsCommon +} + +var _ = socket.SocketVFS2(&SocketVFS2{}) + +// NewSockfsFile creates a new socket file in the global sockfs mount and +// returns a corresponding file description. +func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { + mnt := t.Kernel().SocketMount() + d := sockfs.NewDentry(t.Credentials(), mnt) + + fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) + if err != nil { + return nil, syserr.FromError(err) + } + return fd, nil +} + +// NewFileDescription creates and returns a socket file description +// corresponding to the given mount and dentry. +func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) { + // You can create AF_UNIX, SOCK_RAW sockets. They're the same as + // SOCK_DGRAM and don't require CAP_NET_RAW. + if stype == linux.SOCK_RAW { + stype = linux.SOCK_DGRAM + } + + sock := &SocketVFS2{ + socketOpsCommon: socketOpsCommon{ + ep: ep, + stype: stype, + }, + } + sock.LockFD.Init(locks) + vfsfd := &sock.vfsfd + if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }); err != nil { + return nil, err + } + return vfsfd, nil +} + +// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) +} + +// blockingAccept implements a blocking version of accept(2), that is, if no +// connections are ready to be accept, it will block until one becomes ready. +func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { + // Register for notifications. + e, ch := waiter.NewChannelEntry(nil) + s.socketOpsCommon.EventRegister(&e, waiter.EventIn) + defer s.socketOpsCommon.EventUnregister(&e) + + // Try to accept the connection; if it fails, then wait until we get a + // notification. + for { + if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + return ep, err + } + + if err := t.Block(ch); err != nil { + return nil, syserr.FromError(err) + } + } +} + +// Accept implements the linux syscall accept(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { + // Issue the accept request to get the new endpoint. + ep, err := s.ep.Accept() + if err != nil { + if err != syserr.ErrWouldBlock || !blocking { + return 0, nil, 0, err + } + + var err *syserr.Error + ep, err = s.blockingAccept(t) + if err != nil { + return 0, nil, 0, err + } + } + + ns, err := NewSockfsFile(t, ep, s.stype) + if err != nil { + return 0, nil, 0, err + } + defer ns.DecRef(t) + + if flags&linux.SOCK_NONBLOCK != 0 { + ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK) + } + + var addr linux.SockAddr + var addrLen uint32 + if peerRequested { + // Get address of the peer. + var err *syserr.Error + addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) + if err != nil { + return 0, nil, 0, err + } + } + + fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ + CloseOnExec: flags&linux.SOCK_CLOEXEC != 0, + }) + if e != nil { + return 0, nil, 0, syserr.FromError(e) + } + + t.Kernel().RecordSocketVFS2(ns) + return fd, addr, addrLen, nil +} + +// Bind implements the linux syscall bind(2) for unix sockets. +func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { + p, e := extractPath(sockaddr) + if e != nil { + return e + } + + bep, ok := s.ep.(transport.BoundEndpoint) + if !ok { + // This socket can't be bound. + return syserr.ErrInvalidArgument + } + + return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error { + // Is it abstract? + if p[0] == 0 { + if t.IsNetworkNamespaced() { + return syserr.ErrInvalidEndpointState + } + asn := t.AbstractSockets() + name := p[1:] + if err := asn.Bind(t, name, bep, s); err != nil { + // syserr.ErrPortInUse corresponds to EADDRINUSE. + return syserr.ErrPortInUse + } + s.abstractName = name + s.abstractNamespace = asn + } else { + path := fspath.Parse(p) + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef(t) + start := root + relPath := !path.Absolute + if relPath { + start = t.FSContext().WorkingDirectoryVFS2() + defer start.DecRef(t) + } + pop := vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + } + stat, err := s.vfsfd.Stat(t, vfs.StatOptions{Mask: linux.STATX_MODE}) + if err != nil { + return syserr.FromError(err) + } + err = t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{ + // File permissions correspond to net/unix/af_unix.c:unix_bind. + Mode: linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()), + Endpoint: bep, + }) + if err == syserror.EEXIST { + return syserr.ErrAddressInUse + } + return syserr.FromError(err) + } + + return nil + }) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return netstack.Ioctl(ctx, s.ep, uio, args) +} + +// PRead implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + if dst.NumBytes() == 0 { + return 0, nil + } + return dst.CopyOutFrom(ctx, &EndpointReader{ + Ctx: ctx, + Endpoint: s.ep, + NumRights: 0, + Peek: false, + From: nil, + }) +} + +// PWrite implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + t := kernel.TaskFromContext(ctx) + ctrl := control.New(t, s.ep, nil) + + if src.NumBytes() == 0 { + nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil) + return int64(nInt), err.ToError() + } + + return src.CopyInTo(ctx, &EndpointWriter{ + Ctx: ctx, + Endpoint: s.ep, + Control: ctrl, + To: nil, + }) +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SocketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { + return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} + +// providerVFS2 is a unix domain socket provider for VFS2. +type providerVFS2 struct{} + +func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Check arguments. + if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { + return nil, syserr.ErrProtocolNotSupported + } + + // Create the endpoint and socket. + var ep transport.Endpoint + switch stype { + case linux.SOCK_DGRAM, linux.SOCK_RAW: + ep = transport.NewConnectionless(t) + case linux.SOCK_SEQPACKET, linux.SOCK_STREAM: + ep = transport.NewConnectioned(t, stype, t.Kernel()) + default: + return nil, syserr.ErrInvalidArgument + } + + f, err := NewSockfsFile(t, ep, stype) + if err != nil { + ep.Close(t) + return nil, err + } + return f, nil +} + +// Pair creates a new pair of AF_UNIX connected sockets. +func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + // Check arguments. + if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { + return nil, nil, syserr.ErrProtocolNotSupported + } + + switch stype { + case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW: + // Ok + default: + return nil, nil, syserr.ErrInvalidArgument + } + + // Create the endpoints and sockets. + ep1, ep2 := transport.NewPair(t, stype, t.Kernel()) + s1, err := NewSockfsFile(t, ep1, stype) + if err != nil { + ep1.Close(t) + ep2.Close(t) + return nil, nil, err + } + s2, err := NewSockfsFile(t, ep2, stype) + if err != nil { + s1.DecRef(t) + ep2.Close(t) + return nil, nil, err + } + + return s1, s2, nil +} |