summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/unix
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/unix')
-rw-r--r--pkg/sentry/socket/unix/device.go20
-rw-r--r--pkg/sentry/socket/unix/io.go93
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go460
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned_state.go53
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go196
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go210
-rwxr-xr-xpkg/sentry/socket/unix/transport/transport_message_list.go173
-rwxr-xr-xpkg/sentry/socket/unix/transport/transport_state_autogen.go191
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go973
-rw-r--r--pkg/sentry/socket/unix/unix.go650
-rwxr-xr-xpkg/sentry/socket/unix/unix_state_autogen.go28
11 files changed, 3047 insertions, 0 deletions
diff --git a/pkg/sentry/socket/unix/device.go b/pkg/sentry/socket/unix/device.go
new file mode 100644
index 000000000..734d39ee6
--- /dev/null
+++ b/pkg/sentry/socket/unix/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// unixSocketDevice is the unix socket virtual device.
+var unixSocketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
new file mode 100644
index 000000000..5a1475ec2
--- /dev/null
+++ b/pkg/sentry/socket/unix/io.go
@@ -0,0 +1,93 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// EndpointWriter implements safemem.Writer that writes to a transport.Endpoint.
+//
+// EndpointWriter is not thread-safe.
+type EndpointWriter struct {
+ // Endpoint is the transport.Endpoint to write to.
+ Endpoint transport.Endpoint
+
+ // Control is the control messages to send.
+ Control transport.ControlMessages
+
+ // To is the endpoint to send to. May be nil.
+ To transport.BoundEndpoint
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (w *EndpointWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ return safemem.FromVecWriterFunc{func(bufs [][]byte) (int64, error) {
+ n, err := w.Endpoint.SendMsg(bufs, w.Control, w.To)
+ if err != nil {
+ return int64(n), err.ToError()
+ }
+ return int64(n), nil
+ }}.WriteFromBlocks(srcs)
+}
+
+// EndpointReader implements safemem.Reader that reads from a
+// transport.Endpoint.
+//
+// EndpointReader is not thread-safe.
+type EndpointReader struct {
+ // Endpoint is the transport.Endpoint to read from.
+ Endpoint transport.Endpoint
+
+ // Creds indicates if credential control messages are requested.
+ Creds bool
+
+ // NumRights is the number of SCM_RIGHTS FDs requested.
+ NumRights uintptr
+
+ // Peek indicates that the data should not be consumed from the
+ // endpoint.
+ Peek bool
+
+ // MsgSize is the size of the message that was read from. For stream
+ // sockets, it is the amount read.
+ MsgSize uintptr
+
+ // From, if not nil, will be set with the address read from.
+ From *tcpip.FullAddress
+
+ // Control contains the received control messages.
+ Control transport.ControlMessages
+
+ // ControlTrunc indicates that SCM_RIGHTS FDs were discarded based on
+ // the value of NumRights.
+ ControlTrunc bool
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
+ n, ms, c, ct, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return int64(n), err.ToError()
+ }
+ return int64(n), nil
+ }}.ReadToBlocks(dsts)
+}
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
new file mode 100644
index 000000000..18e492862
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -0,0 +1,460 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// UniqueIDProvider generates a sequence of unique identifiers useful for,
+// among other things, lock ordering.
+type UniqueIDProvider interface {
+ // UniqueID returns a new unique identifier.
+ UniqueID() uint64
+}
+
+// A ConnectingEndpoint is a connectioned unix endpoint that is attempting to
+// establish a bidirectional connection with a BoundEndpoint.
+type ConnectingEndpoint interface {
+ // ID returns the endpoint's globally unique identifier. This identifier
+ // must be used to determine locking order if more than one endpoint is
+ // to be locked in the same codepath. The endpoint with the smaller
+ // identifier must be locked before endpoints with larger identifiers.
+ ID() uint64
+
+ // Passcred implements socket.Credentialer.Passcred.
+ Passcred() bool
+
+ // Type returns the socket type, typically either SockStream or
+ // SockSeqpacket. The connection attempt must be aborted if this
+ // value doesn't match the ConnectableEndpoint's type.
+ Type() SockType
+
+ // GetLocalAddress returns the bound path.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Locker protects the following methods. While locked, only the holder of
+ // the lock can change the return value of the protected methods.
+ sync.Locker
+
+ // Connected returns true iff the ConnectingEndpoint is in the connected
+ // state. ConnectingEndpoints can only be connected to a single endpoint,
+ // so the connection attempt must be aborted if this returns true.
+ Connected() bool
+
+ // Listening returns true iff the ConnectingEndpoint is in the listening
+ // state. ConnectingEndpoints cannot make connections while listening, so
+ // the connection attempt must be aborted if this returns true.
+ Listening() bool
+
+ // WaiterQueue returns a pointer to the endpoint's waiter queue.
+ WaiterQueue() *waiter.Queue
+}
+
+// connectionedEndpoint is a Unix-domain connected or connectable endpoint and implements
+// ConnectingEndpoint, ConnectableEndpoint and tcpip.Endpoint.
+//
+// connectionedEndpoints must be in connected state in order to transfer data.
+//
+// This implementation includes STREAM and SEQPACKET Unix sockets created with
+// socket(2), accept(2) or socketpair(2) and dgram unix sockets created with
+// socketpair(2). See unix_connectionless.go for the implementation of DGRAM
+// Unix sockets created with socket(2).
+//
+// The state is much simpler than a TCP endpoint, so it is not encoded
+// explicitly. Instead we enforce the following invariants:
+//
+// receiver != nil, connected != nil => connected.
+// path != "" && acceptedChan == nil => bound, not listening.
+// path != "" && acceptedChan != nil => bound and listening.
+//
+// Only one of these will be true at any moment.
+//
+// +stateify savable
+type connectionedEndpoint struct {
+ baseEndpoint
+
+ // id is the unique endpoint identifier. This is used exclusively for
+ // lock ordering within connect.
+ id uint64
+
+ // idGenerator is used to generate new unique endpoint identifiers.
+ idGenerator UniqueIDProvider
+
+ // stype is used by connecting sockets to ensure that they are the
+ // same type. The value is typically either tcpip.SockSeqpacket or
+ // tcpip.SockStream.
+ stype SockType
+
+ // acceptedChan is per the TCP endpoint implementation. Note that the
+ // sockets in this channel are _already in the connected state_, and
+ // have another associated connectionedEndpoint.
+ //
+ // If nil, then no listen call has been made.
+ acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"`
+}
+
+// NewConnectioned creates a new unbound connectionedEndpoint.
+func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint {
+ return &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+}
+
+// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
+func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
+ a := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+ b := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+
+ q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
+ q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
+
+ if stype == SockStream {
+ a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
+ b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}}
+ } else {
+ a.receiver = &queueReceiver{q1}
+ b.receiver = &queueReceiver{q2}
+ }
+
+ q2.IncRef()
+ a.connected = &connectedEndpoint{
+ endpoint: b,
+ writeQueue: q2,
+ }
+ q1.IncRef()
+ b.connected = &connectedEndpoint{
+ endpoint: a,
+ writeQueue: q1,
+ }
+
+ return a, b
+}
+
+// NewExternal creates a new externally backed Endpoint. It behaves like a
+// socketpair.
+func NewExternal(stype SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
+ return &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+}
+
+// ID implements ConnectingEndpoint.ID.
+func (e *connectionedEndpoint) ID() uint64 {
+ return e.id
+}
+
+// Type implements ConnectingEndpoint.Type and Endpoint.Type.
+func (e *connectionedEndpoint) Type() SockType {
+ return e.stype
+}
+
+// WaiterQueue implements ConnectingEndpoint.WaiterQueue.
+func (e *connectionedEndpoint) WaiterQueue() *waiter.Queue {
+ return e.Queue
+}
+
+// isBound returns true iff the connectionedEndpoint is bound (but not
+// listening).
+func (e *connectionedEndpoint) isBound() bool {
+ return e.path != "" && e.acceptedChan == nil
+}
+
+// Listening implements ConnectingEndpoint.Listening.
+func (e *connectionedEndpoint) Listening() bool {
+ return e.acceptedChan != nil
+}
+
+// Close puts the connectionedEndpoint in a closed state and frees all
+// resources associated with it.
+//
+// The socket will be a fresh state after a call to close and may be reused.
+// That is, close may be used to "unbind" or "disconnect" the socket in error
+// paths.
+func (e *connectionedEndpoint) Close() {
+ e.Lock()
+ var c ConnectedEndpoint
+ var r Receiver
+ switch {
+ case e.Connected():
+ e.connected.CloseSend()
+ e.receiver.CloseRecv()
+ c = e.connected
+ r = e.receiver
+ e.connected = nil
+ e.receiver = nil
+ case e.isBound():
+ e.path = ""
+ case e.Listening():
+ close(e.acceptedChan)
+ for n := range e.acceptedChan {
+ n.Close()
+ }
+ e.acceptedChan = nil
+ e.path = ""
+ }
+ e.Unlock()
+ if c != nil {
+ c.CloseNotify()
+ c.Release()
+ }
+ if r != nil {
+ r.CloseNotify()
+ r.Release()
+ }
+}
+
+// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
+func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
+ if ce.Type() != e.stype {
+ return syserr.ErrConnectionRefused
+ }
+
+ // Check if ce is e to avoid a deadlock.
+ if ce, ok := ce.(*connectionedEndpoint); ok && ce == e {
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Do a dance to safely acquire locks on both endpoints.
+ if e.id < ce.ID() {
+ e.Lock()
+ ce.Lock()
+ } else {
+ ce.Lock()
+ e.Lock()
+ }
+
+ // Check connecting state.
+ if ce.Connected() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrAlreadyConnected
+ }
+ if ce.Listening() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Check bound state.
+ if !e.Listening() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrConnectionRefused
+ }
+
+ // Create a newly bound connectionedEndpoint.
+ ne := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{
+ path: e.path,
+ Queue: &waiter.Queue{},
+ },
+ id: e.idGenerator.UniqueID(),
+ idGenerator: e.idGenerator,
+ stype: e.stype,
+ }
+
+ readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
+ ne.connected = &connectedEndpoint{
+ endpoint: ce,
+ writeQueue: readQueue,
+ }
+
+ writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
+ if e.stype == SockStream {
+ ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
+ } else {
+ ne.receiver = &queueReceiver{readQueue: writeQueue}
+ }
+
+ select {
+ case e.acceptedChan <- ne:
+ // Commit state.
+ writeQueue.IncRef()
+ connected := &connectedEndpoint{
+ endpoint: ne,
+ writeQueue: writeQueue,
+ }
+ readQueue.IncRef()
+ if e.stype == SockStream {
+ returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected)
+ } else {
+ returnConnect(&queueReceiver{readQueue: readQueue}, connected)
+ }
+
+ // Notify can deadlock if we are holding these locks.
+ e.Unlock()
+ ce.Unlock()
+
+ // Notify on both ends.
+ e.Notify(waiter.EventIn)
+ ce.WaiterQueue().Notify(waiter.EventOut)
+
+ return nil
+ default:
+ // Busy; return ECONNREFUSED per spec.
+ ne.Close()
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrConnectionRefused
+ }
+}
+
+// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
+func (e *connectionedEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error) {
+ return nil, syserr.ErrConnectionRefused
+}
+
+// Connect attempts to directly connect to another Endpoint.
+// Implements Endpoint.Connect.
+func (e *connectionedEndpoint) Connect(server BoundEndpoint) *syserr.Error {
+ returnConnect := func(r Receiver, ce ConnectedEndpoint) {
+ e.receiver = r
+ e.connected = ce
+ }
+
+ return server.BidirectionalConnect(e, returnConnect)
+}
+
+// Listen starts listening on the connection.
+func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.Listening() {
+ // Adjust the size of the channel iff we can fix existing
+ // pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return syserr.ErrInvalidEndpointState
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *connectionedEndpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
+ return nil
+ }
+ if !e.isBound() {
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Normal case.
+ e.acceptedChan = make(chan *connectionedEndpoint, backlog)
+ return nil
+}
+
+// Accept accepts a new connection.
+func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) {
+ e.Lock()
+ defer e.Unlock()
+
+ if !e.Listening() {
+ return nil, syserr.ErrInvalidEndpointState
+ }
+
+ select {
+ case ne := <-e.acceptedChan:
+ return ne, nil
+
+ default:
+ // Nothing left.
+ return nil, syserr.ErrWouldBlock
+ }
+}
+
+// Bind binds the connection.
+//
+// For Unix connectionedEndpoints, this _only sets the address associated with
+// the socket_. Work associated with sockets in the filesystem or finding those
+// sockets must be done by a higher level.
+//
+// Bind will fail only if the socket is connected, bound or the passed address
+// is invalid (the empty string).
+func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.isBound() || e.Listening() {
+ return syserr.ErrAlreadyBound
+ }
+ if addr.Addr == "" {
+ // The empty string is not permitted.
+ return syserr.ErrBadLocalAddress
+ }
+ if commit != nil {
+ if err := commit(); err != nil {
+ return err
+ }
+ }
+
+ // Save the bound address.
+ e.path = string(addr.Addr)
+ return nil
+}
+
+// SendMsg writes data and a control message to the endpoint's peer.
+// This method does not block if the data cannot be written.
+func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+ // Stream sockets do not support specifying the endpoint. Seqpacket
+ // sockets ignore the passed endpoint.
+ if e.stype == SockStream && to != nil {
+ return 0, syserr.ErrNotSupported
+ }
+ return e.baseEndpoint.SendMsg(data, c, to)
+}
+
+// Readiness returns the current readiness of the connectionedEndpoint. For
+// example, if waiter.EventIn is set, the connectionedEndpoint is immediately
+// readable.
+func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ e.Lock()
+ defer e.Unlock()
+
+ ready := waiter.EventMask(0)
+ switch {
+ case e.Connected():
+ if mask&waiter.EventIn != 0 && e.receiver.Readable() {
+ ready |= waiter.EventIn
+ }
+ if mask&waiter.EventOut != 0 && e.connected.Writable() {
+ ready |= waiter.EventOut
+ }
+ case e.Listening():
+ if mask&waiter.EventIn != 0 && len(e.acceptedChan) > 0 {
+ ready |= waiter.EventIn
+ }
+ }
+
+ return ready
+}
diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go
new file mode 100644
index 000000000..7e02a5db8
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectioned_state.go
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+// saveAcceptedChan is invoked by stateify.
+func (e *connectionedEndpoint) saveAcceptedChan() []*connectionedEndpoint {
+ // If acceptedChan is nil (i.e. we are not listening) then we will save nil.
+ // Otherwise we create a (possibly empty) slice of the values in acceptedChan and
+ // save that.
+ var acceptedSlice []*connectionedEndpoint
+ if e.acceptedChan != nil {
+ // Swap out acceptedChan with a new empty channel of the same capacity.
+ saveChan := e.acceptedChan
+ e.acceptedChan = make(chan *connectionedEndpoint, cap(saveChan))
+
+ // Create a new slice with the same len and capacity as the channel.
+ acceptedSlice = make([]*connectionedEndpoint, len(saveChan), cap(saveChan))
+ // Drain acceptedChan into saveSlice, and fill up the new acceptChan at the
+ // same time.
+ for i := range acceptedSlice {
+ ep := <-saveChan
+ acceptedSlice[i] = ep
+ e.acceptedChan <- ep
+ }
+ close(saveChan)
+ }
+ return acceptedSlice
+}
+
+// loadAcceptedChan is invoked by stateify.
+func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEndpoint) {
+ // If acceptedSlice is nil, then acceptedChan should also be nil.
+ if acceptedSlice != nil {
+ // Otherwise, create a new channel with the same capacity as acceptedSlice.
+ e.acceptedChan = make(chan *connectionedEndpoint, cap(acceptedSlice))
+ // Seed the channel with values from acceptedSlice.
+ for _, ep := range acceptedSlice {
+ e.acceptedChan <- ep
+ }
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
new file mode 100644
index 000000000..43ff875e4
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -0,0 +1,196 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// connectionlessEndpoint is a unix endpoint for unix sockets that support operating in
+// a conectionless fashon.
+//
+// Specifically, this means datagram unix sockets not created with
+// socketpair(2).
+//
+// +stateify savable
+type connectionlessEndpoint struct {
+ baseEndpoint
+}
+
+// NewConnectionless creates a new unbound dgram endpoint.
+func NewConnectionless() Endpoint {
+ ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
+ ep.receiver = &queueReceiver{readQueue: &queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}}
+ return ep
+}
+
+// isBound returns true iff the endpoint is bound.
+func (e *connectionlessEndpoint) isBound() bool {
+ return e.path != ""
+}
+
+// Close puts the endpoint in a closed state and frees all resources associated
+// with it.
+//
+// The socket will be a fresh state after a call to close and may be reused.
+// That is, close may be used to "unbind" or "disconnect" the socket in error
+// paths.
+func (e *connectionlessEndpoint) Close() {
+ e.Lock()
+ var r Receiver
+ if e.Connected() {
+ e.receiver.CloseRecv()
+ r = e.receiver
+ e.receiver = nil
+
+ e.connected.Release()
+ e.connected = nil
+ }
+ if e.isBound() {
+ e.path = ""
+ }
+ e.Unlock()
+ if r != nil {
+ r.CloseNotify()
+ r.Release()
+ }
+}
+
+// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
+func (e *connectionlessEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
+ return syserr.ErrConnectionRefused
+}
+
+// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
+func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error) {
+ e.Lock()
+ r := e.receiver
+ e.Unlock()
+ if r == nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+ q := r.(*queueReceiver).readQueue
+ if !q.TryIncRef() {
+ return nil, syserr.ErrConnectionRefused
+ }
+ return &connectedEndpoint{
+ endpoint: e,
+ writeQueue: q,
+ }, nil
+}
+
+// SendMsg writes data and a control message to the specified endpoint.
+// This method does not block if the data cannot be written.
+func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+ if to == nil {
+ return e.baseEndpoint.SendMsg(data, c, nil)
+ }
+
+ connected, err := to.UnidirectionalConnect()
+ if err != nil {
+ return 0, syserr.ErrInvalidEndpointState
+ }
+ defer connected.Release()
+
+ e.Lock()
+ n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ e.Unlock()
+
+ if notify {
+ connected.SendNotify()
+ }
+
+ return n, err
+}
+
+// Type implements Endpoint.Type.
+func (e *connectionlessEndpoint) Type() SockType {
+ return SockDgram
+}
+
+// Connect attempts to connect directly to server.
+func (e *connectionlessEndpoint) Connect(server BoundEndpoint) *syserr.Error {
+ connected, err := server.UnidirectionalConnect()
+ if err != nil {
+ return err
+ }
+
+ e.Lock()
+ e.connected = connected
+ e.Unlock()
+
+ return nil
+}
+
+// Listen starts listening on the connection.
+func (e *connectionlessEndpoint) Listen(int) *syserr.Error {
+ return syserr.ErrNotSupported
+}
+
+// Accept accepts a new connection.
+func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) {
+ return nil, syserr.ErrNotSupported
+}
+
+// Bind binds the connection.
+//
+// For Unix endpoints, this _only sets the address associated with the socket_.
+// Work associated with sockets in the filesystem or finding those sockets must
+// be done by a higher level.
+//
+// Bind will fail only if the socket is connected, bound or the passed address
+// is invalid (the empty string).
+func (e *connectionlessEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.isBound() {
+ return syserr.ErrAlreadyBound
+ }
+ if addr.Addr == "" {
+ // The empty string is not permitted.
+ return syserr.ErrBadLocalAddress
+ }
+ if commit != nil {
+ if err := commit(); err != nil {
+ return err
+ }
+ }
+
+ // Save the bound address.
+ e.path = string(addr.Addr)
+ return nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ e.Lock()
+ defer e.Unlock()
+
+ ready := waiter.EventMask(0)
+ if mask&waiter.EventIn != 0 && e.receiver.Readable() {
+ ready |= waiter.EventIn
+ }
+
+ if e.Connected() {
+ if mask&waiter.EventOut != 0 && e.connected.Writable() {
+ ready |= waiter.EventOut
+ }
+ }
+
+ return ready
+}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
new file mode 100644
index 000000000..b650caae7
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -0,0 +1,210 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// queue is a buffer queue.
+//
+// +stateify savable
+type queue struct {
+ refs.AtomicRefCount
+
+ ReaderQueue *waiter.Queue
+ WriterQueue *waiter.Queue
+
+ mu sync.Mutex `state:"nosave"`
+ closed bool
+ used int64
+ limit int64
+ dataList messageList
+}
+
+// Close closes q for reading and writing. It is immediately not writable and
+// will become unreadable when no more data is pending.
+//
+// Both the read and write queues must be notified after closing:
+// q.ReaderQueue.Notify(waiter.EventIn)
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Close() {
+ q.mu.Lock()
+ q.closed = true
+ q.mu.Unlock()
+}
+
+// Reset empties the queue and Releases all of the Entries.
+//
+// Both the read and write queues must be notified after resetting:
+// q.ReaderQueue.Notify(waiter.EventIn)
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Reset() {
+ q.mu.Lock()
+ for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
+ cur.Release()
+ }
+ q.dataList.Reset()
+ q.used = 0
+ q.mu.Unlock()
+}
+
+// DecRef implements RefCounter.DecRef with destructor q.Reset.
+func (q *queue) DecRef() {
+ q.DecRefWithDestructor(q.Reset)
+ // We don't need to notify after resetting because no one cares about
+ // this queue after all references have been dropped.
+}
+
+// IsReadable determines if q is currently readable.
+func (q *queue) IsReadable() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.closed || q.dataList.Front() != nil
+}
+
+// bufWritable returns true if there is space for writing.
+//
+// N.B. Linux only considers a unix socket "writable" if >75% of the buffer is
+// free.
+//
+// See net/unix/af_unix.c:unix_writeable.
+func (q *queue) bufWritable() bool {
+ return 4*q.used < q.limit
+}
+
+// IsWritable determines if q is currently writable.
+func (q *queue) IsWritable() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.closed || q.bufWritable()
+}
+
+// Enqueue adds an entry to the data queue if room is available.
+//
+// If truncate is true, Enqueue may truncate the message beforing enqueuing it.
+// Otherwise, the entire message must fit. If n < e.Length(), err indicates why.
+//
+// If notify is true, ReaderQueue.Notify must be called:
+// q.ReaderQueue.Notify(waiter.EventIn)
+func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *syserr.Error) {
+ q.mu.Lock()
+
+ if q.closed {
+ q.mu.Unlock()
+ return 0, false, syserr.ErrClosedForSend
+ }
+
+ free := q.limit - q.used
+
+ l = e.Length()
+
+ if l > free && truncate {
+ if free == 0 {
+ // Message can't fit right now.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrWouldBlock
+ }
+
+ e.Truncate(free)
+ l = e.Length()
+ err = syserr.ErrWouldBlock
+ }
+
+ if l > q.limit {
+ // Message is too big to ever fit.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrMessageTooLong
+ }
+
+ if l > free {
+ // Message can't fit right now.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrWouldBlock
+ }
+
+ notify = q.dataList.Front() == nil
+ q.used += l
+ q.dataList.PushBack(e)
+
+ q.mu.Unlock()
+
+ return l, notify, err
+}
+
+// Dequeue removes the first entry in the data queue, if one exists.
+//
+// If notify is true, WriterQueue.Notify must be called:
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) {
+ q.mu.Lock()
+
+ if q.dataList.Front() == nil {
+ err := syserr.ErrWouldBlock
+ if q.closed {
+ err = syserr.ErrClosedForReceive
+ }
+ q.mu.Unlock()
+
+ return nil, false, err
+ }
+
+ notify = !q.bufWritable()
+
+ e = q.dataList.Front()
+ q.dataList.Remove(e)
+ q.used -= e.Length()
+
+ notify = notify && q.bufWritable()
+
+ q.mu.Unlock()
+
+ return e, notify, nil
+}
+
+// Peek returns the first entry in the data queue, if one exists.
+func (q *queue) Peek() (*message, *syserr.Error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if q.dataList.Front() == nil {
+ err := syserr.ErrWouldBlock
+ if q.closed {
+ err = syserr.ErrClosedForReceive
+ }
+ return nil, err
+ }
+
+ return q.dataList.Front().Peek(), nil
+}
+
+// QueuedSize returns the number of bytes currently in the queue, that is, the
+// number of readable bytes.
+func (q *queue) QueuedSize() int64 {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ return q.used
+}
+
+// MaxQueueSize returns the maximum number of bytes storable in the queue.
+func (q *queue) MaxQueueSize() int64 {
+ return q.limit
+}
diff --git a/pkg/sentry/socket/unix/transport/transport_message_list.go b/pkg/sentry/socket/unix/transport/transport_message_list.go
new file mode 100755
index 000000000..6d394860c
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/transport_message_list.go
@@ -0,0 +1,173 @@
+package transport
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type messageElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (messageElementMapper) linkerFor(elem *message) *message { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type messageList struct {
+ head *message
+ tail *message
+}
+
+// Reset resets list l to the empty state.
+func (l *messageList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *messageList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *messageList) Front() *message {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *messageList) Back() *message {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *messageList) PushFront(e *message) {
+ messageElementMapper{}.linkerFor(e).SetNext(l.head)
+ messageElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ messageElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *messageList) PushBack(e *message) {
+ messageElementMapper{}.linkerFor(e).SetNext(nil)
+ messageElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ messageElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *messageList) PushBackList(m *messageList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ messageElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ messageElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *messageList) InsertAfter(b, e *message) {
+ a := messageElementMapper{}.linkerFor(b).Next()
+ messageElementMapper{}.linkerFor(e).SetNext(a)
+ messageElementMapper{}.linkerFor(e).SetPrev(b)
+ messageElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ messageElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *messageList) InsertBefore(a, e *message) {
+ b := messageElementMapper{}.linkerFor(a).Prev()
+ messageElementMapper{}.linkerFor(e).SetNext(a)
+ messageElementMapper{}.linkerFor(e).SetPrev(b)
+ messageElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ messageElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *messageList) Remove(e *message) {
+ prev := messageElementMapper{}.linkerFor(e).Prev()
+ next := messageElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ messageElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ messageElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type messageEntry struct {
+ next *message
+ prev *message
+}
+
+// Next returns the entry that follows e in the list.
+func (e *messageEntry) Next() *message {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *messageEntry) Prev() *message {
+ return e.prev
+}
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *messageEntry) SetNext(elem *message) {
+ e.next = elem
+}
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *messageEntry) SetPrev(elem *message) {
+ e.prev = elem
+}
diff --git a/pkg/sentry/socket/unix/transport/transport_state_autogen.go b/pkg/sentry/socket/unix/transport/transport_state_autogen.go
new file mode 100755
index 000000000..8bb1b5b31
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/transport_state_autogen.go
@@ -0,0 +1,191 @@
+// automatically generated by stateify.
+
+package transport
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *connectionedEndpoint) beforeSave() {}
+func (x *connectionedEndpoint) save(m state.Map) {
+ x.beforeSave()
+ var acceptedChan []*connectionedEndpoint = x.saveAcceptedChan()
+ m.SaveValue("acceptedChan", acceptedChan)
+ m.Save("baseEndpoint", &x.baseEndpoint)
+ m.Save("id", &x.id)
+ m.Save("idGenerator", &x.idGenerator)
+ m.Save("stype", &x.stype)
+}
+
+func (x *connectionedEndpoint) afterLoad() {}
+func (x *connectionedEndpoint) load(m state.Map) {
+ m.Load("baseEndpoint", &x.baseEndpoint)
+ m.Load("id", &x.id)
+ m.Load("idGenerator", &x.idGenerator)
+ m.Load("stype", &x.stype)
+ m.LoadValue("acceptedChan", new([]*connectionedEndpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*connectionedEndpoint)) })
+}
+
+func (x *connectionlessEndpoint) beforeSave() {}
+func (x *connectionlessEndpoint) save(m state.Map) {
+ x.beforeSave()
+ m.Save("baseEndpoint", &x.baseEndpoint)
+}
+
+func (x *connectionlessEndpoint) afterLoad() {}
+func (x *connectionlessEndpoint) load(m state.Map) {
+ m.Load("baseEndpoint", &x.baseEndpoint)
+}
+
+func (x *queue) beforeSave() {}
+func (x *queue) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("ReaderQueue", &x.ReaderQueue)
+ m.Save("WriterQueue", &x.WriterQueue)
+ m.Save("closed", &x.closed)
+ m.Save("used", &x.used)
+ m.Save("limit", &x.limit)
+ m.Save("dataList", &x.dataList)
+}
+
+func (x *queue) afterLoad() {}
+func (x *queue) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("ReaderQueue", &x.ReaderQueue)
+ m.Load("WriterQueue", &x.WriterQueue)
+ m.Load("closed", &x.closed)
+ m.Load("used", &x.used)
+ m.Load("limit", &x.limit)
+ m.Load("dataList", &x.dataList)
+}
+
+func (x *messageList) beforeSave() {}
+func (x *messageList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *messageList) afterLoad() {}
+func (x *messageList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *messageEntry) beforeSave() {}
+func (x *messageEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *messageEntry) afterLoad() {}
+func (x *messageEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func (x *ControlMessages) beforeSave() {}
+func (x *ControlMessages) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Rights", &x.Rights)
+ m.Save("Credentials", &x.Credentials)
+}
+
+func (x *ControlMessages) afterLoad() {}
+func (x *ControlMessages) load(m state.Map) {
+ m.Load("Rights", &x.Rights)
+ m.Load("Credentials", &x.Credentials)
+}
+
+func (x *message) beforeSave() {}
+func (x *message) save(m state.Map) {
+ x.beforeSave()
+ m.Save("messageEntry", &x.messageEntry)
+ m.Save("Data", &x.Data)
+ m.Save("Control", &x.Control)
+ m.Save("Address", &x.Address)
+}
+
+func (x *message) afterLoad() {}
+func (x *message) load(m state.Map) {
+ m.Load("messageEntry", &x.messageEntry)
+ m.Load("Data", &x.Data)
+ m.Load("Control", &x.Control)
+ m.Load("Address", &x.Address)
+}
+
+func (x *queueReceiver) beforeSave() {}
+func (x *queueReceiver) save(m state.Map) {
+ x.beforeSave()
+ m.Save("readQueue", &x.readQueue)
+}
+
+func (x *queueReceiver) afterLoad() {}
+func (x *queueReceiver) load(m state.Map) {
+ m.Load("readQueue", &x.readQueue)
+}
+
+func (x *streamQueueReceiver) beforeSave() {}
+func (x *streamQueueReceiver) save(m state.Map) {
+ x.beforeSave()
+ m.Save("queueReceiver", &x.queueReceiver)
+ m.Save("buffer", &x.buffer)
+ m.Save("control", &x.control)
+ m.Save("addr", &x.addr)
+}
+
+func (x *streamQueueReceiver) afterLoad() {}
+func (x *streamQueueReceiver) load(m state.Map) {
+ m.Load("queueReceiver", &x.queueReceiver)
+ m.Load("buffer", &x.buffer)
+ m.Load("control", &x.control)
+ m.Load("addr", &x.addr)
+}
+
+func (x *connectedEndpoint) beforeSave() {}
+func (x *connectedEndpoint) save(m state.Map) {
+ x.beforeSave()
+ m.Save("endpoint", &x.endpoint)
+ m.Save("writeQueue", &x.writeQueue)
+}
+
+func (x *connectedEndpoint) afterLoad() {}
+func (x *connectedEndpoint) load(m state.Map) {
+ m.Load("endpoint", &x.endpoint)
+ m.Load("writeQueue", &x.writeQueue)
+}
+
+func (x *baseEndpoint) beforeSave() {}
+func (x *baseEndpoint) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Queue", &x.Queue)
+ m.Save("passcred", &x.passcred)
+ m.Save("receiver", &x.receiver)
+ m.Save("connected", &x.connected)
+ m.Save("path", &x.path)
+}
+
+func (x *baseEndpoint) afterLoad() {}
+func (x *baseEndpoint) load(m state.Map) {
+ m.Load("Queue", &x.Queue)
+ m.Load("passcred", &x.passcred)
+ m.Load("receiver", &x.receiver)
+ m.Load("connected", &x.connected)
+ m.Load("path", &x.path)
+}
+
+func init() {
+ state.Register("transport.connectionedEndpoint", (*connectionedEndpoint)(nil), state.Fns{Save: (*connectionedEndpoint).save, Load: (*connectionedEndpoint).load})
+ state.Register("transport.connectionlessEndpoint", (*connectionlessEndpoint)(nil), state.Fns{Save: (*connectionlessEndpoint).save, Load: (*connectionlessEndpoint).load})
+ state.Register("transport.queue", (*queue)(nil), state.Fns{Save: (*queue).save, Load: (*queue).load})
+ state.Register("transport.messageList", (*messageList)(nil), state.Fns{Save: (*messageList).save, Load: (*messageList).load})
+ state.Register("transport.messageEntry", (*messageEntry)(nil), state.Fns{Save: (*messageEntry).save, Load: (*messageEntry).load})
+ state.Register("transport.ControlMessages", (*ControlMessages)(nil), state.Fns{Save: (*ControlMessages).save, Load: (*ControlMessages).load})
+ state.Register("transport.message", (*message)(nil), state.Fns{Save: (*message).save, Load: (*message).load})
+ state.Register("transport.queueReceiver", (*queueReceiver)(nil), state.Fns{Save: (*queueReceiver).save, Load: (*queueReceiver).load})
+ state.Register("transport.streamQueueReceiver", (*streamQueueReceiver)(nil), state.Fns{Save: (*streamQueueReceiver).save, Load: (*streamQueueReceiver).load})
+ state.Register("transport.connectedEndpoint", (*connectedEndpoint)(nil), state.Fns{Save: (*connectedEndpoint).save, Load: (*connectedEndpoint).load})
+ state.Register("transport.baseEndpoint", (*baseEndpoint)(nil), state.Fns{Save: (*baseEndpoint).save, Load: (*baseEndpoint).load})
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
new file mode 100644
index 000000000..b734b4c20
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -0,0 +1,973 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package transport contains the implementation of Unix endpoints.
+package transport
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// initialLimit is the starting limit for the socket buffers.
+const initialLimit = 16 * 1024
+
+// A SockType is a type (as opposed to family) of sockets. These are enumerated
+// in the syscall package as syscall.SOCK_* constants.
+type SockType int
+
+const (
+ // SockStream corresponds to syscall.SOCK_STREAM.
+ SockStream SockType = 1
+ // SockDgram corresponds to syscall.SOCK_DGRAM.
+ SockDgram SockType = 2
+ // SockRaw corresponds to syscall.SOCK_RAW.
+ SockRaw SockType = 3
+ // SockSeqpacket corresponds to syscall.SOCK_SEQPACKET.
+ SockSeqpacket SockType = 5
+)
+
+// A RightsControlMessage is a control message containing FDs.
+type RightsControlMessage interface {
+ // Clone returns a copy of the RightsControlMessage.
+ Clone() RightsControlMessage
+
+ // Release releases any resources owned by the RightsControlMessage.
+ Release()
+}
+
+// A CredentialsControlMessage is a control message containing Unix credentials.
+type CredentialsControlMessage interface {
+ // Equals returns true iff the two messages are equal.
+ Equals(CredentialsControlMessage) bool
+}
+
+// A ControlMessages represents a collection of socket control messages.
+//
+// +stateify savable
+type ControlMessages struct {
+ // Rights is a control message containing FDs.
+ Rights RightsControlMessage
+
+ // Credentials is a control message containing Unix credentials.
+ Credentials CredentialsControlMessage
+}
+
+// Empty returns true iff the ControlMessages does not contain either
+// credentials or rights.
+func (c *ControlMessages) Empty() bool {
+ return c.Rights == nil && c.Credentials == nil
+}
+
+// Clone clones both the credentials and the rights.
+func (c *ControlMessages) Clone() ControlMessages {
+ cm := ControlMessages{}
+ if c.Rights != nil {
+ cm.Rights = c.Rights.Clone()
+ }
+ cm.Credentials = c.Credentials
+ return cm
+}
+
+// Release releases both the credentials and the rights.
+func (c *ControlMessages) Release() {
+ if c.Rights != nil {
+ c.Rights.Release()
+ }
+ *c = ControlMessages{}
+}
+
+// Endpoint is the interface implemented by Unix transport protocol
+// implementations that expose functionality like sendmsg, recvmsg, connect,
+// etc. to Unix socket implementations.
+type Endpoint interface {
+ Credentialer
+ waiter.Waitable
+
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it.
+ Close()
+
+ // RecvMsg reads data and a control message from the endpoint. This method
+ // does not block if there is no data pending.
+ //
+ // creds indicates if credential control messages are requested by the
+ // caller. This is useful for determining if control messages can be
+ // coalesced. creds is a hint and can be safely ignored by the
+ // implementation if no coalescing is possible. It is fine to return
+ // credential control messages when none were requested or to not return
+ // credential control messages when they were requested.
+ //
+ // numRights is the number of SCM_RIGHTS FDs requested by the caller. This
+ // is useful if one must allocate a buffer to receive a SCM_RIGHTS message
+ // or determine if control messages can be coalesced. numRights is a hint
+ // and can be safely ignored by the implementation if the number of
+ // available SCM_RIGHTS FDs is known and no coalescing is possible. It is
+ // fine for the returned number of SCM_RIGHTS FDs to be either higher or
+ // lower than the requested number.
+ //
+ // If peek is true, no data should be consumed from the Endpoint. Any and
+ // all data returned from a peek should be available in the next call to
+ // RecvMsg.
+ //
+ // recvLen is the number of bytes copied into data.
+ //
+ // msgLen is the length of the read message consumed for datagram Endpoints.
+ // msgLen is always the same as recvLen for stream Endpoints.
+ //
+ // CMTruncated indicates that the numRights hint was used to receive fewer
+ // than the total available SCM_RIGHTS FDs. Additional truncation may be
+ // required by the caller.
+ RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error)
+
+ // SendMsg writes data and a control message to the endpoint's peer.
+ // This method does not block if the data cannot be written.
+ //
+ // SendMsg does not take ownership of any of its arguments on error.
+ SendMsg([][]byte, ControlMessages, BoundEndpoint) (uintptr, *syserr.Error)
+
+ // Connect connects this endpoint directly to another.
+ //
+ // This should be called on the client endpoint, and the (bound)
+ // endpoint passed in as a parameter.
+ //
+ // The error codes are the same as Connect.
+ Connect(server BoundEndpoint) *syserr.Error
+
+ // Shutdown closes the read and/or write end of the endpoint connection
+ // to its peer.
+ Shutdown(flags tcpip.ShutdownFlags) *syserr.Error
+
+ // Listen puts the endpoint in "listen" mode, which allows it to accept
+ // new connections.
+ Listen(backlog int) *syserr.Error
+
+ // Accept returns a new endpoint if a peer has established a connection
+ // to an endpoint previously set to listen mode. This method does not
+ // block if no new connections are available.
+ //
+ // The returned Queue is the wait queue for the newly created endpoint.
+ Accept() (Endpoint, *syserr.Error)
+
+ // Bind binds the endpoint to a specific local address and port.
+ // Specifying a NIC is optional.
+ //
+ // An optional commit function will be executed atomically with respect
+ // to binding the endpoint. If this returns an error, the bind will not
+ // occur and the error will be propagated back to the caller.
+ Bind(address tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error
+
+ // Type return the socket type, typically either SockStream, SockDgram
+ // or SockSeqpacket.
+ Type() SockType
+
+ // GetLocalAddress returns the address to which the endpoint is bound.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // GetRemoteAddress returns the address to which the endpoint is
+ // connected.
+ GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // SetSockOpt sets a socket option. opt should be one of the tcpip.*Option
+ // types.
+ SetSockOpt(opt interface{}) *tcpip.Error
+
+ // GetSockOpt gets a socket option. opt should be a pointer to one of the
+ // tcpip.*Option types.
+ GetSockOpt(opt interface{}) *tcpip.Error
+}
+
+// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
+// option.
+type Credentialer interface {
+ // Passcred returns whether or not the SO_PASSCRED socket option is
+ // enabled on this end.
+ Passcred() bool
+
+ // ConnectedPasscred returns whether or not the SO_PASSCRED socket option
+ // is enabled on the connected end.
+ ConnectedPasscred() bool
+}
+
+// A BoundEndpoint is a unix endpoint that can be connected to.
+type BoundEndpoint interface {
+ // BidirectionalConnect establishes a bi-directional connection between two
+ // unix endpoints in an all-or-nothing manner. If an error occurs during
+ // connecting, the state of neither endpoint should be modified.
+ //
+ // In order for an endpoint to establish such a bidirectional connection
+ // with a BoundEndpoint, the endpoint calls the BidirectionalConnect method
+ // on the BoundEndpoint and sends a representation of itself (the
+ // ConnectingEndpoint) and a callback (returnConnect) to receive the
+ // connection information (Receiver and ConnectedEndpoint) upon a
+ // successful connect. The callback should only be called on a successful
+ // connect.
+ //
+ // For a connection attempt to be successful, the ConnectingEndpoint must
+ // be unconnected and not listening and the BoundEndpoint whose
+ // BidirectionalConnect method is being called must be listening.
+ //
+ // This method will return syserr.ErrConnectionRefused on endpoints with a
+ // type that isn't SockStream or SockSeqpacket.
+ BidirectionalConnect(ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error
+
+ // UnidirectionalConnect establishes a write-only connection to a unix
+ // endpoint.
+ //
+ // An endpoint which calls UnidirectionalConnect and supports it itself must
+ // not hold its own lock when calling UnidirectionalConnect.
+ //
+ // This method will return syserr.ErrConnectionRefused on a non-SockDgram
+ // endpoint.
+ UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error)
+
+ // Release releases any resources held by the BoundEndpoint. It must be
+ // called before dropping all references to a BoundEndpoint returned by a
+ // function.
+ Release()
+}
+
+// message represents a message passed over a Unix domain socket.
+//
+// +stateify savable
+type message struct {
+ messageEntry
+
+ // Data is the Message payload.
+ Data buffer.View
+
+ // Control is auxiliary control message data that goes along with the
+ // data.
+ Control ControlMessages
+
+ // Address is the bound address of the endpoint that sent the message.
+ //
+ // If the endpoint that sent the message is not bound, the Address is
+ // the empty string.
+ Address tcpip.FullAddress
+}
+
+// Length returns number of bytes stored in the message.
+func (m *message) Length() int64 {
+ return int64(len(m.Data))
+}
+
+// Release releases any resources held by the message.
+func (m *message) Release() {
+ m.Control.Release()
+}
+
+// Peek returns a copy of the message.
+func (m *message) Peek() *message {
+ return &message{Data: m.Data, Control: m.Control.Clone(), Address: m.Address}
+}
+
+// Truncate reduces the length of the message payload to n bytes.
+//
+// Preconditions: n <= m.Length().
+func (m *message) Truncate(n int64) {
+ m.Data.CapLength(int(n))
+}
+
+// A Receiver can be used to receive Messages.
+type Receiver interface {
+ // Recv receives a single message. This method does not block.
+ //
+ // See Endpoint.RecvMsg for documentation on shared arguments.
+ //
+ // notify indicates if RecvNotify should be called.
+ Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
+
+ // RecvNotify notifies the Receiver of a successful Recv. This must not be
+ // called while holding any endpoint locks.
+ RecvNotify()
+
+ // CloseRecv prevents the receiving of additional Messages.
+ //
+ // After CloseRecv is called, CloseNotify must also be called.
+ CloseRecv()
+
+ // CloseNotify notifies the Receiver of recv being closed. This must not be
+ // called while holding any endpoint locks.
+ CloseNotify()
+
+ // Readable returns if messages should be attempted to be received. This
+ // includes when read has been shutdown.
+ Readable() bool
+
+ // RecvQueuedSize returns the total amount of data currently receivable.
+ // RecvQueuedSize should return -1 if the operation isn't supported.
+ RecvQueuedSize() int64
+
+ // RecvMaxQueueSize returns maximum value for RecvQueuedSize.
+ // RecvMaxQueueSize should return -1 if the operation isn't supported.
+ RecvMaxQueueSize() int64
+
+ // Release releases any resources owned by the Receiver. It should be
+ // called before droping all references to a Receiver.
+ Release()
+}
+
+// queueReceiver implements Receiver for datagram sockets.
+//
+// +stateify savable
+type queueReceiver struct {
+ readQueue *queue
+}
+
+// Recv implements Receiver.Recv.
+func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ var m *message
+ var notify bool
+ var err *syserr.Error
+ if peek {
+ m, err = q.readQueue.Peek()
+ } else {
+ m, notify, err = q.readQueue.Dequeue()
+ }
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
+ }
+ src := []byte(m.Data)
+ var copied uintptr
+ for i := 0; i < len(data) && len(src) > 0; i++ {
+ n := copy(data[i], src)
+ copied += uintptr(n)
+ src = src[n:]
+ }
+ return copied, uintptr(len(m.Data)), m.Control, false, m.Address, notify, nil
+}
+
+// RecvNotify implements Receiver.RecvNotify.
+func (q *queueReceiver) RecvNotify() {
+ q.readQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseNotify implements Receiver.CloseNotify.
+func (q *queueReceiver) CloseNotify() {
+ q.readQueue.ReaderQueue.Notify(waiter.EventIn)
+ q.readQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseRecv implements Receiver.CloseRecv.
+func (q *queueReceiver) CloseRecv() {
+ q.readQueue.Close()
+}
+
+// Readable implements Receiver.Readable.
+func (q *queueReceiver) Readable() bool {
+ return q.readQueue.IsReadable()
+}
+
+// RecvQueuedSize implements Receiver.RecvQueuedSize.
+func (q *queueReceiver) RecvQueuedSize() int64 {
+ return q.readQueue.QueuedSize()
+}
+
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
+func (q *queueReceiver) RecvMaxQueueSize() int64 {
+ return q.readQueue.MaxQueueSize()
+}
+
+// Release implements Receiver.Release.
+func (q *queueReceiver) Release() {
+ q.readQueue.DecRef()
+}
+
+// streamQueueReceiver implements Receiver for stream sockets.
+//
+// +stateify savable
+type streamQueueReceiver struct {
+ queueReceiver
+
+ mu sync.Mutex `state:"nosave"`
+ buffer []byte
+ control ControlMessages
+ addr tcpip.FullAddress
+}
+
+func vecCopy(data [][]byte, buf []byte) (uintptr, [][]byte, []byte) {
+ var copied uintptr
+ for len(data) > 0 && len(buf) > 0 {
+ n := copy(data[0], buf)
+ copied += uintptr(n)
+ buf = buf[n:]
+ data[0] = data[0][n:]
+ if len(data[0]) == 0 {
+ data = data[1:]
+ }
+ }
+ return copied, data, buf
+}
+
+// Readable implements Receiver.Readable.
+func (q *streamQueueReceiver) Readable() bool {
+ q.mu.Lock()
+ bl := len(q.buffer)
+ r := q.readQueue.IsReadable()
+ q.mu.Unlock()
+ // We're readable if we have data in our buffer or if the queue receiver is
+ // readable.
+ return bl > 0 || r
+}
+
+// RecvQueuedSize implements Receiver.RecvQueuedSize.
+func (q *streamQueueReceiver) RecvQueuedSize() int64 {
+ q.mu.Lock()
+ bl := len(q.buffer)
+ qs := q.readQueue.QueuedSize()
+ q.mu.Unlock()
+ return int64(bl) + qs
+}
+
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
+func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
+ // The RecvMaxQueueSize() is the readQueue's MaxQueueSize() plus the largest
+ // message we can buffer which is also the largest message we can receive.
+ return 2 * q.readQueue.MaxQueueSize()
+}
+
+// Recv implements Receiver.Recv.
+func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ var notify bool
+
+ // If we have no data in the endpoint, we need to get some.
+ if len(q.buffer) == 0 {
+ // Load the next message into a buffer, even if we are peeking. Peeking
+ // won't consume the message, so it will be still available to be read
+ // the next time Recv() is called.
+ m, n, err := q.readQueue.Dequeue()
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
+ }
+ notify = n
+ q.buffer = []byte(m.Data)
+ q.control = m.Control
+ q.addr = m.Address
+ }
+
+ var copied uintptr
+ if peek {
+ // Don't consume control message if we are peeking.
+ c := q.control.Clone()
+
+ // Don't consume data since we are peeking.
+ copied, data, _ = vecCopy(data, q.buffer)
+
+ return copied, copied, c, false, q.addr, notify, nil
+ }
+
+ // Consume data and control message since we are not peeking.
+ copied, data, q.buffer = vecCopy(data, q.buffer)
+
+ // Save the original state of q.control.
+ c := q.control
+
+ // Remove rights from q.control and leave behind just the creds.
+ q.control.Rights = nil
+ if !wantCreds {
+ c.Credentials = nil
+ }
+
+ var cmTruncated bool
+ if c.Rights != nil && numRights == 0 {
+ c.Rights.Release()
+ c.Rights = nil
+ cmTruncated = true
+ }
+
+ haveRights := c.Rights != nil
+
+ // If we have more capacity for data and haven't received any usable
+ // rights.
+ //
+ // Linux never coalesces rights control messages.
+ for !haveRights && len(data) > 0 {
+ // Get a message from the readQueue.
+ m, n, err := q.readQueue.Dequeue()
+ if err != nil {
+ // We already got some data, so ignore this error. This will
+ // manifest as a short read to the user, which is what Linux
+ // does.
+ break
+ }
+ notify = notify || n
+ q.buffer = []byte(m.Data)
+ q.control = m.Control
+ q.addr = m.Address
+
+ if wantCreds {
+ if (q.control.Credentials == nil) != (c.Credentials == nil) {
+ // One message has credentials, the other does not.
+ break
+ }
+
+ if q.control.Credentials != nil && c.Credentials != nil && !q.control.Credentials.Equals(c.Credentials) {
+ // Both messages have credentials, but they don't match.
+ break
+ }
+ }
+
+ if numRights != 0 && c.Rights != nil && q.control.Rights != nil {
+ // Both messages have rights.
+ break
+ }
+
+ var cpd uintptr
+ cpd, data, q.buffer = vecCopy(data, q.buffer)
+ copied += cpd
+
+ if cpd == 0 {
+ // data was actually full.
+ break
+ }
+
+ if q.control.Rights != nil {
+ // Consume rights.
+ if numRights == 0 {
+ cmTruncated = true
+ q.control.Rights.Release()
+ } else {
+ c.Rights = q.control.Rights
+ haveRights = true
+ }
+ q.control.Rights = nil
+ }
+ }
+ return copied, copied, c, cmTruncated, q.addr, notify, nil
+}
+
+// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
+type ConnectedEndpoint interface {
+ // Passcred implements Endpoint.Passcred.
+ Passcred() bool
+
+ // GetLocalAddress implements Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Send sends a single message. This method does not block.
+ //
+ // notify indicates if SendNotify should be called.
+ //
+ // syserr.ErrWouldBlock can be returned along with a partial write if
+ // the caller should block to send the rest of the data.
+ Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n uintptr, notify bool, err *syserr.Error)
+
+ // SendNotify notifies the ConnectedEndpoint of a successful Send. This
+ // must not be called while holding any endpoint locks.
+ SendNotify()
+
+ // CloseSend prevents the sending of additional Messages.
+ //
+ // After CloseSend is call, CloseNotify must also be called.
+ CloseSend()
+
+ // CloseNotify notifies the ConnectedEndpoint of send being closed. This
+ // must not be called while holding any endpoint locks.
+ CloseNotify()
+
+ // Writable returns if messages should be attempted to be sent. This
+ // includes when write has been shutdown.
+ Writable() bool
+
+ // EventUpdate lets the ConnectedEndpoint know that event registrations
+ // have changed.
+ EventUpdate()
+
+ // SendQueuedSize returns the total amount of data currently queued for
+ // sending. SendQueuedSize should return -1 if the operation isn't
+ // supported.
+ SendQueuedSize() int64
+
+ // SendMaxQueueSize returns maximum value for SendQueuedSize.
+ // SendMaxQueueSize should return -1 if the operation isn't supported.
+ SendMaxQueueSize() int64
+
+ // Release releases any resources owned by the ConnectedEndpoint. It should
+ // be called before droping all references to a ConnectedEndpoint.
+ Release()
+}
+
+// +stateify savable
+type connectedEndpoint struct {
+ // endpoint represents the subset of the Endpoint functionality needed by
+ // the connectedEndpoint. It is implemented by both connectionedEndpoint
+ // and connectionlessEndpoint and allows the use of types which don't
+ // fully implement Endpoint.
+ endpoint interface {
+ // Passcred implements Endpoint.Passcred.
+ Passcred() bool
+
+ // GetLocalAddress implements Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Type implements Endpoint.Type.
+ Type() SockType
+ }
+
+ writeQueue *queue
+}
+
+// Passcred implements ConnectedEndpoint.Passcred.
+func (e *connectedEndpoint) Passcred() bool {
+ return e.endpoint.Passcred()
+}
+
+// GetLocalAddress implements ConnectedEndpoint.GetLocalAddress.
+func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return e.endpoint.GetLocalAddress()
+}
+
+// Send implements ConnectedEndpoint.Send.
+func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) {
+ var l int64
+ for _, d := range data {
+ l += int64(len(d))
+ }
+
+ truncate := false
+ if e.endpoint.Type() == SockStream {
+ // Since stream sockets don't preserve message boundaries, we
+ // can write only as much of the message as fits in the queue.
+ truncate = true
+
+ // Discard empty stream packets. Since stream sockets don't
+ // preserve message boundaries, sending zero bytes is a no-op.
+ // In Linux, the receiver actually uses a zero-length receive
+ // as an indication that the stream was closed.
+ if l == 0 {
+ controlMessages.Release()
+ return 0, false, nil
+ }
+ }
+
+ v := make([]byte, 0, l)
+ for _, d := range data {
+ v = append(v, d...)
+ }
+
+ l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate)
+ return uintptr(l), notify, err
+}
+
+// SendNotify implements ConnectedEndpoint.SendNotify.
+func (e *connectedEndpoint) SendNotify() {
+ e.writeQueue.ReaderQueue.Notify(waiter.EventIn)
+}
+
+// CloseNotify implements ConnectedEndpoint.CloseNotify.
+func (e *connectedEndpoint) CloseNotify() {
+ e.writeQueue.ReaderQueue.Notify(waiter.EventIn)
+ e.writeQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseSend implements ConnectedEndpoint.CloseSend.
+func (e *connectedEndpoint) CloseSend() {
+ e.writeQueue.Close()
+}
+
+// Writable implements ConnectedEndpoint.Writable.
+func (e *connectedEndpoint) Writable() bool {
+ return e.writeQueue.IsWritable()
+}
+
+// EventUpdate implements ConnectedEndpoint.EventUpdate.
+func (*connectedEndpoint) EventUpdate() {}
+
+// SendQueuedSize implements ConnectedEndpoint.SendQueuedSize.
+func (e *connectedEndpoint) SendQueuedSize() int64 {
+ return e.writeQueue.QueuedSize()
+}
+
+// SendMaxQueueSize implements ConnectedEndpoint.SendMaxQueueSize.
+func (e *connectedEndpoint) SendMaxQueueSize() int64 {
+ return e.writeQueue.MaxQueueSize()
+}
+
+// Release implements ConnectedEndpoint.Release.
+func (e *connectedEndpoint) Release() {
+ e.writeQueue.DecRef()
+}
+
+// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
+// unix domain socket Endpoint implementations.
+//
+// Not to be used on its own.
+//
+// +stateify savable
+type baseEndpoint struct {
+ *waiter.Queue
+
+ // passcred specifies whether SCM_CREDENTIALS socket control messages are
+ // enabled on this endpoint. Must be accessed atomically.
+ passcred int32
+
+ // Mutex protects the below fields.
+ sync.Mutex `state:"nosave"`
+
+ // receiver allows Messages to be received.
+ receiver Receiver
+
+ // connected allows messages to be sent and state information about the
+ // connected endpoint to be read.
+ connected ConnectedEndpoint
+
+ // path is not empty if the endpoint has been bound,
+ // or may be used if the endpoint is connected.
+ path string
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
+ e.Queue.EventRegister(we, mask)
+ e.Lock()
+ if e.connected != nil {
+ e.connected.EventUpdate()
+ }
+ e.Unlock()
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
+ e.Queue.EventUnregister(we)
+ e.Lock()
+ if e.connected != nil {
+ e.connected.EventUpdate()
+ }
+ e.Unlock()
+}
+
+// Passcred implements Credentialer.Passcred.
+func (e *baseEndpoint) Passcred() bool {
+ return atomic.LoadInt32(&e.passcred) != 0
+}
+
+// ConnectedPasscred implements Credentialer.ConnectedPasscred.
+func (e *baseEndpoint) ConnectedPasscred() bool {
+ e.Lock()
+ defer e.Unlock()
+ return e.connected != nil && e.connected.Passcred()
+}
+
+func (e *baseEndpoint) setPasscred(pc bool) {
+ if pc {
+ atomic.StoreInt32(&e.passcred, 1)
+ } else {
+ atomic.StoreInt32(&e.passcred, 0)
+ }
+}
+
+// Connected implements ConnectingEndpoint.Connected.
+func (e *baseEndpoint) Connected() bool {
+ return e.receiver != nil && e.connected != nil
+}
+
+// RecvMsg reads data and a control message from the endpoint.
+func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) {
+ e.Lock()
+
+ if e.receiver == nil {
+ e.Unlock()
+ return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected
+ }
+
+ recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
+ e.Unlock()
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, err
+ }
+
+ if notify {
+ e.receiver.RecvNotify()
+ }
+
+ if addr != nil {
+ *addr = a
+ }
+ return recvLen, msgLen, cms, cmt, nil
+}
+
+// SendMsg writes data and a control message to the endpoint's peer.
+// This method does not block if the data cannot be written.
+func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return 0, syserr.ErrNotConnected
+ }
+ if to != nil {
+ e.Unlock()
+ return 0, syserr.ErrAlreadyConnected
+ }
+
+ n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ e.Unlock()
+
+ if notify {
+ e.connected.SendNotify()
+ }
+
+ return n, err
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.PasscredOption:
+ e.setPasscred(v != 0)
+ return nil
+ }
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendQueueSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.ReceiveQueueSizeOption(e.receiver.RecvQueuedSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
+
+ case *tcpip.PasscredOption:
+ if e.Passcred() {
+ *o = tcpip.PasscredOption(1)
+ } else {
+ *o = tcpip.PasscredOption(0)
+ }
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.Lock()
+ if e.receiver == nil {
+ e.Unlock()
+ return tcpip.ErrNotConnected
+ }
+ qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
+ e.Unlock()
+ if qs < 0 {
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ *o = qs
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection to its
+// peer.
+func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error {
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return syserr.ErrNotConnected
+ }
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.receiver.CloseRecv()
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ e.connected.CloseSend()
+ }
+
+ e.Unlock()
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.receiver.CloseNotify()
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ e.connected.CloseNotify()
+ }
+
+ return nil
+}
+
+// GetLocalAddress returns the bound path.
+func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.Lock()
+ defer e.Unlock()
+ return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil
+}
+
+// GetRemoteAddress returns the local address of the connected endpoint (if
+// available).
+func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.Lock()
+ c := e.connected
+ e.Unlock()
+ if c != nil {
+ return c.GetLocalAddress()
+ }
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Release implements BoundEndpoint.Release.
+func (*baseEndpoint) Release() {
+ // Binding a baseEndpoint doesn't take a reference.
+}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
new file mode 100644
index 000000000..1414be0c6
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix.go
@@ -0,0 +1,650 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package unix provides an implementation of the socket.Socket interface for
+// the AF_UNIX protocol family.
+package unix
+
+import (
+ "strings"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/epsocket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// SocketOperations is a Unix socket. It is similar to an epsocket, except it
+// is backed by a transport.Endpoint instead of a tcpip.Endpoint.
+//
+// +stateify savable
+type SocketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+ refs.AtomicRefCount
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ isPacket bool
+}
+
+// New creates a new unix socket.
+func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
+ dirent := socket.NewDirent(ctx, unixSocketDevice)
+ defer dirent.DecRef()
+ return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
+}
+
+// NewWithDirent creates a new unix socket using an existing dirent.
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
+ return fs.NewFile(ctx, d, flags, &SocketOperations{
+ ep: ep,
+ isPacket: isPacket,
+ })
+}
+
+// DecRef implements RefCounter.DecRef.
+func (s *SocketOperations) DecRef() {
+ s.DecRefWithDestructor(func() {
+ s.ep.Close()
+ })
+}
+
+// Release implemements fs.FileOperations.Release.
+func (s *SocketOperations) Release() {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef()
+}
+
+// Endpoint extracts the transport.Endpoint.
+func (s *SocketOperations) Endpoint() transport.Endpoint {
+ return s.ep
+}
+
+// extractPath extracts and validates the address.
+func extractPath(sockaddr []byte) (string, *syserr.Error) {
+ addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr)
+ if err != nil {
+ return "", err
+ }
+
+ // The address is trimmed by GetAddress.
+ p := string(addr.Addr)
+ if p == "" {
+ // Not allowed.
+ return "", syserr.ErrInvalidArgument
+ }
+ if p[len(p)-1] == '/' {
+ // Weird, they tried to bind '/a/b/c/'?
+ return "", syserr.ErrIsDir
+ }
+
+ return p, nil
+}
+
+// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ addr, err := s.ep.GetRemoteAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr)
+ return a, l, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ addr, err := s.ep.GetLocalAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr)
+ return a, l, nil
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *SocketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return epsocket.Ioctl(ctx, s.ep, io, args)
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) {
+ return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// Listen implements the linux syscall listen(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return s.ep.Listen(backlog)
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accept, it will block until one becomes ready.
+func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns := New(t, ep, s.isPacket)
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ flags := ns.Flags()
+ flags.NonBlocking = true
+ ns.SetFlags(flags.Settable())
+ }
+
+ var addr interface{}
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fdFlags := kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ }
+ fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocket(ns, linux.AF_UNIX)
+
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ } else {
+ // The parent and name.
+ var d *fs.Dirent
+ var name string
+
+ cwd := t.FSContext().WorkingDirectory()
+ defer cwd.DecRef()
+
+ // Is there no slash at all?
+ if !strings.Contains(p, "/") {
+ d = cwd
+ name = p
+ } else {
+ root := t.FSContext().RootDirectory()
+ defer root.DecRef()
+ // Find the last path component, we know that something follows
+ // that final slash, otherwise extractPath() would have failed.
+ lastSlash := strings.LastIndex(p, "/")
+ subPath := p[:lastSlash]
+ if subPath == "" {
+ // Fix up subpath in case file is in root.
+ subPath = "/"
+ }
+ var err error
+ remainingTraversals := uint(fs.DefaultTraversalLimit)
+ d, err = t.MountNamespace().FindInode(t, root, cwd, subPath, &remainingTraversals)
+ if err != nil {
+ // No path available.
+ return syserr.ErrNoSuchFile
+ }
+ defer d.DecRef()
+ name = p[lastSlash+1:]
+ }
+
+ // Create the socket.
+ childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}})
+ if err != nil {
+ return syserr.ErrPortInUse
+ }
+ childDir.DecRef()
+ }
+
+ return nil
+ })
+}
+
+// extractEndpoint retrieves the transport.BoundEndpoint associated with a Unix
+// socket path. The Release must be called on the transport.BoundEndpoint when
+// the caller is done with it.
+func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, *syserr.Error) {
+ path, err := extractPath(sockaddr)
+ if err != nil {
+ return nil, err
+ }
+
+ // Is it abstract?
+ if path[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ ep := t.AbstractSockets().BoundEndpoint(path[1:])
+ if ep == nil {
+ // No socket found.
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ return ep, nil
+ }
+
+ // Find the node in the filesystem.
+ root := t.FSContext().RootDirectory()
+ cwd := t.FSContext().WorkingDirectory()
+ remainingTraversals := uint(fs.DefaultTraversalLimit)
+ d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals)
+ cwd.DecRef()
+ root.DecRef()
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+
+ // Extract the endpoint if one is there.
+ ep := d.Inode.BoundEndpoint(path)
+ d.DecRef()
+ if ep == nil {
+ // No socket!
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ return ep, nil
+}
+
+// Connect implements the linux syscall connect(2) for unix sockets.
+func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ ep, err := extractEndpoint(t, sockaddr)
+ if err != nil {
+ return err
+ }
+ defer ep.Release()
+
+ // Connect the server endpoint.
+ return s.ep.Connect(ep)
+}
+
+// Writev implements fs.FileOperations.Write.
+func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg([][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ w := EndpointWriter{
+ Endpoint: s.ep,
+ Control: controlMessages.Unix,
+ To: nil,
+ }
+ if len(to) > 0 {
+ ep, err := extractEndpoint(t, to)
+ if err != nil {
+ return 0, err
+ }
+ defer ep.Release()
+ w.To = ep
+ }
+
+ n, err := src.CopyInTo(t, &w)
+ if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ return int(n), syserr.FromError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst64(n)
+
+ n, err = src.CopyInTo(t, &w)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+
+ return int(total), syserr.FromError(err)
+}
+
+// Passcred implements transport.Credentialer.Passcred.
+func (s *SocketOperations) Passcred() bool {
+ return s.ep.Passcred()
+}
+
+// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
+func (s *SocketOperations) ConnectedPasscred() bool {
+ return s.ep.ConnectedPasscred()
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.ep.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.ep.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketOperations) EventUnregister(e *waiter.Entry) {
+ s.ep.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return epsocket.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// Shutdown implements the linux syscall shutdown(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ f, err := epsocket.ConvertShutdown(how)
+ if err != nil {
+ return err
+ }
+
+ // Issue shutdown request.
+ return s.ep.Shutdown(f)
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ trunc := flags&linux.MSG_TRUNC != 0
+ peek := flags&linux.MSG_PEEK != 0
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ waitAll := flags&linux.MSG_WAITALL != 0
+
+ // Calculate the number of FDs for which we have space and if we are
+ // requesting credentials.
+ var wantCreds bool
+ rightsLen := int(controlDataLen) - syscall.SizeofCmsghdr
+ if s.Passcred() {
+ // Credentials take priority if they are enabled and there is space.
+ wantCreds = rightsLen > 0
+ if !wantCreds {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+ credLen := syscall.CmsgSpace(syscall.SizeofUcred)
+ rightsLen -= credLen
+ }
+ // FDs are 32 bit (4 byte) ints.
+ numRights := rightsLen / 4
+ if numRights < 0 {
+ numRights = 0
+ }
+
+ r := EndpointReader{
+ Endpoint: s.ep,
+ Creds: wantCreds,
+ NumRights: uintptr(numRights),
+ Peek: peek,
+ }
+ if senderRequested {
+ r.From = &tcpip.FullAddress{}
+ }
+ var total int64
+ if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
+ var from interface{}
+ var fromLen uint32
+ if r.From != nil {
+ from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
+ }
+
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
+ if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ if s.isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+
+ return int(n), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst64(n)
+ total += n
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock {
+ var from interface{}
+ var fromLen uint32
+ if r.From != nil {
+ from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
+ }
+
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
+ if trunc {
+ // n and r.MsgSize are the same for streams.
+ total += int64(r.MsgSize)
+ } else {
+ total += n
+ }
+
+ if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ if total > 0 {
+ err = nil
+ }
+ if s.isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst64(n)
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if total > 0 {
+ err = nil
+ }
+ if err == syserror.ETIMEDOUT {
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+// provider is a unix domain socket provider.
+type provider struct{}
+
+// Socket returns a new unix domain socket.
+func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ var isPacket bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ isPacket = true
+ ep = transport.NewConnectionless()
+ case linux.SOCK_SEQPACKET:
+ isPacket = true
+ fallthrough
+ case linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return New(t, ep, isPacket), nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ var isPacket bool
+ switch stype {
+ case linux.SOCK_STREAM:
+ case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ isPacket = true
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(stype, t.Kernel())
+ s1 := New(t, ep1, isPacket)
+ s2 := New(t, ep2, isPacket)
+
+ return s1, s2, nil
+}
+
+func init() {
+ socket.RegisterProvider(linux.AF_UNIX, &provider{})
+}
diff --git a/pkg/sentry/socket/unix/unix_state_autogen.go b/pkg/sentry/socket/unix/unix_state_autogen.go
new file mode 100755
index 000000000..6f8d24b44
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix_state_autogen.go
@@ -0,0 +1,28 @@
+// automatically generated by stateify.
+
+package unix
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *SocketOperations) beforeSave() {}
+func (x *SocketOperations) save(m state.Map) {
+ x.beforeSave()
+ m.Save("AtomicRefCount", &x.AtomicRefCount)
+ m.Save("SendReceiveTimeout", &x.SendReceiveTimeout)
+ m.Save("ep", &x.ep)
+ m.Save("isPacket", &x.isPacket)
+}
+
+func (x *SocketOperations) afterLoad() {}
+func (x *SocketOperations) load(m state.Map) {
+ m.Load("AtomicRefCount", &x.AtomicRefCount)
+ m.Load("SendReceiveTimeout", &x.SendReceiveTimeout)
+ m.Load("ep", &x.ep)
+ m.Load("isPacket", &x.isPacket)
+}
+
+func init() {
+ state.Register("unix.SocketOperations", (*SocketOperations)(nil), state.Fns{Save: (*SocketOperations).save, Load: (*SocketOperations).load})
+}