summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/fs/host/socket.go
diff options
context:
space:
mode:
authorGoogler <noreply@google.com>2018-04-27 10:37:02 -0700
committerAdin Scannell <ascannell@google.com>2018-04-28 01:44:26 -0400
commitd02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch)
tree54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/sentry/fs/host/socket.go
parentf70210e742919f40aa2f0934a22f1c9ba6dada62 (diff)
Check in gVisor.
PiperOrigin-RevId: 194583126 Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/sentry/fs/host/socket.go')
-rw-r--r--pkg/sentry/fs/host/socket.go471
1 files changed, 471 insertions, 0 deletions
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
new file mode 100644
index 000000000..8e36ed7ee
--- /dev/null
+++ b/pkg/sentry/fs/host/socket.go
@@ -0,0 +1,471 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package host
+
+import (
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/fd"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control"
+ unixsocket "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier"
+)
+
+// endpoint encapsulates the state needed to represent a host Unix socket.
+type endpoint struct {
+ queue waiter.Queue `state:"nosave"`
+
+ // stype is the type of Unix socket. (Ex: unix.SockStream,
+ // unix.SockSeqpacket, unix.SockDgram)
+ stype unix.SockType `state:"nosave"`
+
+ // fd is the host fd backing this file.
+ fd int `state:"nosave"`
+
+ // If srfd >= 0, it is the host fd that fd was imported from.
+ srfd int `state:"wait"`
+}
+
+func (e *endpoint) init() error {
+ family, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_DOMAIN)
+ if err != nil {
+ return err
+ }
+
+ if family != syscall.AF_UNIX {
+ // We only allow Unix sockets.
+ return syserror.EINVAL
+ }
+
+ stype, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
+ if err != nil {
+ return err
+ }
+
+ if err := syscall.SetNonblock(e.fd, true); err != nil {
+ return err
+ }
+
+ e.stype = unix.SockType(stype)
+ if err := fdnotifier.AddFD(int32(e.fd), &e.queue); err != nil {
+ return err
+ }
+ return nil
+}
+
+// newEndpoint creates a new host endpoint.
+func newEndpoint(fd int, srfd int) (*endpoint, error) {
+ ep := &endpoint{fd: fd, srfd: srfd}
+ if err := ep.init(); err != nil {
+ return nil, err
+ }
+ return ep, nil
+}
+
+// newSocket allocates a new unix socket with host endpoint.
+func newSocket(ctx context.Context, fd int, saveable bool) (*fs.File, error) {
+ ownedfd := fd
+ srfd := -1
+ if saveable {
+ var err error
+ ownedfd, err = syscall.Dup(fd)
+ if err != nil {
+ return nil, err
+ }
+ srfd = fd
+ }
+ ep, err := newEndpoint(ownedfd, srfd)
+ if err != nil {
+ if saveable {
+ syscall.Close(ownedfd)
+ }
+ return nil, err
+ }
+ return unixsocket.New(ctx, ep), nil
+}
+
+// NewSocketWithDirent allocates a new unix socket with host endpoint.
+//
+// This is currently only used by unsaveable Gofer nodes.
+//
+// NewSocketWithDirent takes ownership of f on success.
+func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.FileFlags) (*fs.File, error) {
+ ep, err := newEndpoint(f.FD(), -1)
+ if err != nil {
+ return nil, err
+ }
+
+ // Take ownship of the FD.
+ f.Release()
+
+ return unixsocket.NewWithDirent(ctx, d, ep, flags), nil
+}
+
+// Close implements unix.Endpoint.Close.
+func (e *endpoint) Close() {
+ fdnotifier.RemoveFD(int32(e.fd))
+ syscall.Close(e.fd)
+ e.fd = -1
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (e *endpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
+ e.queue.EventRegister(we, mask)
+ fdnotifier.UpdateFD(int32(e.fd))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (e *endpoint) EventUnregister(we *waiter.Entry) {
+ e.queue.EventUnregister(we)
+ fdnotifier.UpdateFD(int32(e.fd))
+}
+
+// Readiness implements unix.Endpoint.Readiness.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fdnotifier.NonBlockingPoll(int32(e.fd), mask)
+}
+
+// Type implements unix.Endpoint.Type.
+func (e *endpoint) Type() unix.SockType {
+ return e.stype
+}
+
+// Connect implements unix.Endpoint.Connect.
+func (e *endpoint) Connect(server unix.BoundEndpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Bind implements unix.Endpoint.Bind.
+func (e *endpoint) Bind(address tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Listen implements unix.Endpoint.Listen.
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Accept implements unix.Endpoint.Accept.
+func (e *endpoint) Accept() (unix.Endpoint, *tcpip.Error) {
+ return nil, tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown implements unix.Endpoint.Shutdown.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// GetSockOpt implements unix.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ _, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_ERROR)
+ return translateError(err)
+ case *tcpip.PasscredOption:
+ // We don't support passcred on host sockets.
+ *o = 0
+ return nil
+ case *tcpip.SendBufferSizeOption:
+ v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ *o = tcpip.SendBufferSizeOption(v)
+ return translateError(err)
+ case *tcpip.ReceiveBufferSizeOption:
+ v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF)
+ *o = tcpip.ReceiveBufferSizeOption(v)
+ return translateError(err)
+ case *tcpip.ReuseAddressOption:
+ v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR)
+ *o = tcpip.ReuseAddressOption(v)
+ return translateError(err)
+ case *tcpip.ReceiveQueueSizeOption:
+ return tcpip.ErrQueueSizeNotSupported
+ }
+ return tcpip.ErrInvalidEndpointState
+}
+
+// SetSockOpt implements unix.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+// GetLocalAddress implements unix.Endpoint.GetLocalAddress.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, nil
+}
+
+// GetRemoteAddress implements unix.Endpoint.GetRemoteAddress.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, nil
+}
+
+// Passcred returns whether or not the SO_PASSCRED socket option is
+// enabled on this end.
+func (e *endpoint) Passcred() bool {
+ // We don't support credential passing for host sockets.
+ return false
+}
+
+// ConnectedPasscred returns whether or not the SO_PASSCRED socket option
+// is enabled on the connected end.
+func (e *endpoint) ConnectedPasscred() bool {
+ // We don't support credential passing for host sockets.
+ return false
+}
+
+// SendMsg implements unix.Endpoint.SendMsg.
+func (e *endpoint) SendMsg(data [][]byte, controlMessages unix.ControlMessages, to unix.BoundEndpoint) (uintptr, *tcpip.Error) {
+ if to != nil {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+ return sendMsg(e.fd, data, controlMessages)
+}
+
+func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages) (uintptr, *tcpip.Error) {
+ if !controlMessages.Empty() {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+ n, err := fdWriteVec(fd, data)
+ return n, translateError(err)
+}
+
+// RecvMsg implements unix.Endpoint.RecvMsg.
+func (e *endpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
+ return recvMsg(e.fd, data, numRights, peek, addr)
+}
+
+func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) {
+ var cm unet.ControlMessage
+ if numRights > 0 {
+ cm.EnableFDs(int(numRights))
+ }
+ rl, ml, cl, err := fdReadVec(fd, data, []byte(cm), peek)
+ if err == syscall.EAGAIN {
+ return 0, 0, unix.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+ if err != nil {
+ return 0, 0, unix.ControlMessages{}, translateError(err)
+ }
+
+ // Trim the control data if we received less than the full amount.
+ if cl < uint64(len(cm)) {
+ cm = cm[:cl]
+ }
+
+ // Avoid extra allocations in the case where there isn't any control data.
+ if len(cm) == 0 {
+ return rl, ml, unix.ControlMessages{}, nil
+ }
+
+ fds, err := cm.ExtractFDs()
+ if err != nil {
+ return 0, 0, unix.ControlMessages{}, translateError(err)
+ }
+
+ if len(fds) == 0 {
+ return rl, ml, unix.ControlMessages{}, nil
+ }
+ return rl, ml, control.New(nil, nil, newSCMRights(fds)), nil
+}
+
+// NewConnectedEndpoint creates a new unix.Receiver and unix.ConnectedEndpoint
+// backed by a host FD that will pretend to be bound at a given sentry path.
+func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (unix.Receiver, unix.ConnectedEndpoint, *tcpip.Error) {
+ if err := fdnotifier.AddFD(int32(file.FD()), queue); err != nil {
+ return nil, nil, translateError(err)
+ }
+
+ e := &connectedEndpoint{path: path, queue: queue, file: file}
+
+ // AtomicRefCounters start off with a single reference. We need two.
+ e.ref.IncRef()
+
+ return e, e, nil
+}
+
+// connectedEndpoint is a host FD backed implementation of
+// unix.ConnectedEndpoint and unix.Receiver.
+//
+// connectedEndpoint does not support save/restore for now.
+type connectedEndpoint struct {
+ queue *waiter.Queue
+ path string
+
+ // ref keeps track of references to a connectedEndpoint.
+ ref refs.AtomicRefCount
+
+ // mu protects fd, readClosed and writeClosed.
+ mu sync.RWMutex
+
+ // file is an *fd.FD containing the FD backing this endpoint. It must be
+ // set to nil if it has been closed.
+ file *fd.FD
+
+ // readClosed is true if the FD has read shutdown or if it has been closed.
+ readClosed bool
+
+ // writeClosed is true if the FD has write shutdown or if it has been
+ // closed.
+ writeClosed bool
+}
+
+// Send implements unix.ConnectedEndpoint.Send.
+func (c *connectedEndpoint) Send(data [][]byte, controlMessages unix.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.writeClosed {
+ return 0, false, tcpip.ErrClosedForSend
+ }
+ n, err := sendMsg(c.file.FD(), data, controlMessages)
+ // There is no need for the callee to call SendNotify because sendMsg uses
+ // the host's sendmsg(2) and the host kernel's queue.
+ return n, false, err
+}
+
+// SendNotify implements unix.ConnectedEndpoint.SendNotify.
+func (c *connectedEndpoint) SendNotify() {}
+
+// CloseSend implements unix.ConnectedEndpoint.CloseSend.
+func (c *connectedEndpoint) CloseSend() {
+ c.mu.Lock()
+ c.writeClosed = true
+ c.mu.Unlock()
+}
+
+// CloseNotify implements unix.ConnectedEndpoint.CloseNotify.
+func (c *connectedEndpoint) CloseNotify() {}
+
+// Writable implements unix.ConnectedEndpoint.Writable.
+func (c *connectedEndpoint) Writable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.writeClosed {
+ return true
+ }
+ return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventOut)&waiter.EventOut != 0
+}
+
+// Passcred implements unix.ConnectedEndpoint.Passcred.
+func (c *connectedEndpoint) Passcred() bool {
+ // We don't support credential passing for host sockets.
+ return false
+}
+
+// GetLocalAddress implements unix.ConnectedEndpoint.GetLocalAddress.
+func (c *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{Addr: tcpip.Address(c.path)}, nil
+}
+
+// EventUpdate implements unix.ConnectedEndpoint.EventUpdate.
+func (c *connectedEndpoint) EventUpdate() {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.file.FD() != -1 {
+ fdnotifier.UpdateFD(int32(c.file.FD()))
+ }
+}
+
+// Recv implements unix.Receiver.Recv.
+func (c *connectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, unix.ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.readClosed {
+ return 0, 0, unix.ControlMessages{}, tcpip.FullAddress{}, false, tcpip.ErrClosedForReceive
+ }
+ rl, ml, cm, err := recvMsg(c.file.FD(), data, numRights, peek, nil)
+ // There is no need for the callee to call RecvNotify because recvMsg uses
+ // the host's recvmsg(2) and the host kernel's queue.
+ return rl, ml, cm, tcpip.FullAddress{Addr: tcpip.Address(c.path)}, false, err
+}
+
+// close releases all resources related to the endpoint.
+func (c *connectedEndpoint) close() {
+ fdnotifier.RemoveFD(int32(c.file.FD()))
+ c.file.Close()
+ c.file = nil
+}
+
+// RecvNotify implements unix.Receiver.RecvNotify.
+func (c *connectedEndpoint) RecvNotify() {}
+
+// CloseRecv implements unix.Receiver.CloseRecv.
+func (c *connectedEndpoint) CloseRecv() {
+ c.mu.Lock()
+ c.readClosed = true
+ c.mu.Unlock()
+}
+
+// Readable implements unix.Receiver.Readable.
+func (c *connectedEndpoint) Readable() bool {
+ c.mu.RLock()
+ defer c.mu.RUnlock()
+ if c.readClosed {
+ return true
+ }
+ return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventIn)&waiter.EventIn != 0
+}
+
+// SendQueuedSize implements unix.Receiver.SendQueuedSize.
+func (c *connectedEndpoint) SendQueuedSize() int64 {
+ // SendQueuedSize isn't supported for host sockets because we don't allow the
+ // sentry to call ioctl(2).
+ return -1
+}
+
+// RecvQueuedSize implements unix.Receiver.RecvQueuedSize.
+func (c *connectedEndpoint) RecvQueuedSize() int64 {
+ // RecvQueuedSize isn't supported for host sockets because we don't allow the
+ // sentry to call ioctl(2).
+ return -1
+}
+
+// SendMaxQueueSize implements unix.Receiver.SendMaxQueueSize.
+func (c *connectedEndpoint) SendMaxQueueSize() int64 {
+ v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF)
+ if err != nil {
+ return -1
+ }
+ return int64(v)
+}
+
+// RecvMaxQueueSize implements unix.Receiver.RecvMaxQueueSize.
+func (c *connectedEndpoint) RecvMaxQueueSize() int64 {
+ v, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_RCVBUF)
+ if err != nil {
+ return -1
+ }
+ return int64(v)
+}
+
+// Release implements unix.ConnectedEndpoint.Release and unix.Receiver.Release.
+func (c *connectedEndpoint) Release() {
+ c.ref.DecRefWithDestructor(c.close)
+}
+
+func translateError(err error) *tcpip.Error {
+ if err == nil {
+ return nil
+ }
+ return rawfile.TranslateErrno(err.(syscall.Errno))
+}