diff options
Diffstat (limited to 'pkg/sentry/fs/host/socket.go')
-rw-r--r-- | pkg/sentry/fs/host/socket.go | 319 |
1 files changed, 79 insertions, 240 deletions
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 68ebf6402..577e9e272 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -25,6 +25,8 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control" unixsocket "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix" + "gvisor.googlesource.com/gvisor/pkg/sentry/uniqueid" + "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile" @@ -39,88 +41,35 @@ import ( // N.B. 8MB is the default maximum on Linux (2 * sysctl_wmem_max). const maxSendBufferSize = 8 << 20 -// endpoint encapsulates the state needed to represent a host Unix socket. -// -// TODO: Remove/merge with ConnectedEndpoint. -// -// +stateify savable -type endpoint struct { - queue waiter.Queue `state:"zerovalue"` - - // fd is the host fd backing this file. - fd int `state:"nosave"` - - // If srfd >= 0, it is the host fd that fd was imported from. - srfd int `state:"wait"` - - // stype is the type of Unix socket. - stype unix.SockType `state:"nosave"` - - // sndbuf is the size of the send buffer. - sndbuf int `state:"nosave"` -} - -func (e *endpoint) init() error { - family, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_DOMAIN) - if err != nil { - return err - } - - if family != syscall.AF_UNIX { - // We only allow Unix sockets. - return syserror.EINVAL - } - - stype, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_TYPE) - if err != nil { - return err - } - e.stype = unix.SockType(stype) - - e.sndbuf, err = syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF) - if err != nil { - return err - } - if e.sndbuf > maxSendBufferSize { - log.Warningf("Socket send buffer too large: %d", e.sndbuf) - return syserror.EINVAL - } - - if err := syscall.SetNonblock(e.fd, true); err != nil { - return err - } - - return fdnotifier.AddFD(int32(e.fd), &e.queue) -} - -// newEndpoint creates a new host endpoint. -func newEndpoint(fd int, srfd int) (*endpoint, error) { - ep := &endpoint{fd: fd, srfd: srfd} - if err := ep.init(); err != nil { - return nil, err - } - return ep, nil -} - // newSocket allocates a new unix socket with host endpoint. -func newSocket(ctx context.Context, fd int, saveable bool) (*fs.File, error) { - ownedfd := fd +func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) { + ownedfd := orgfd srfd := -1 if saveable { var err error - ownedfd, err = syscall.Dup(fd) + ownedfd, err = syscall.Dup(orgfd) if err != nil { return nil, err } - srfd = fd + srfd = orgfd } - ep, err := newEndpoint(ownedfd, srfd) + f := fd.New(ownedfd) + var q waiter.Queue + e, err := NewConnectedEndpoint(f, &q, "" /* path */) if err != nil { if saveable { - syscall.Close(ownedfd) + f.Close() + } else { + f.Release() } - return nil, err + return nil, syserr.TranslateNetstackError(err).ToError() } + + e.srfd = srfd + e.Init() + + ep := unix.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) + return unixsocket.New(ctx, ep), nil } @@ -130,139 +79,22 @@ func newSocket(ctx context.Context, fd int, saveable bool) (*fs.File, error) { // // NewSocketWithDirent takes ownership of f on success. func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.FileFlags) (*fs.File, error) { - ep, err := newEndpoint(f.FD(), -1) + f2 := fd.New(f.FD()) + var q waiter.Queue + e, err := NewConnectedEndpoint(f2, &q, "" /* path */) if err != nil { - return nil, err + f2.Release() + return nil, syserr.TranslateNetstackError(err).ToError() } // Take ownship of the FD. f.Release() - return unixsocket.NewWithDirent(ctx, d, ep, flags), nil -} - -// Close implements unix.Endpoint.Close. -func (e *endpoint) Close() { - fdnotifier.RemoveFD(int32(e.fd)) - syscall.Close(e.fd) - e.fd = -1 -} - -// EventRegister implements waiter.Waitable.EventRegister. -func (e *endpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) { - e.queue.EventRegister(we, mask) - fdnotifier.UpdateFD(int32(e.fd)) -} - -// EventUnregister implements waiter.Waitable.EventUnregister. -func (e *endpoint) EventUnregister(we *waiter.Entry) { - e.queue.EventUnregister(we) - fdnotifier.UpdateFD(int32(e.fd)) -} - -// Readiness implements unix.Endpoint.Readiness. -func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { - return fdnotifier.NonBlockingPoll(int32(e.fd), mask) -} - -// Type implements unix.Endpoint.Type. -func (e *endpoint) Type() unix.SockType { - return e.stype -} - -// Connect implements unix.Endpoint.Connect. -func (e *endpoint) Connect(server unix.BoundEndpoint) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - -// Bind implements unix.Endpoint.Bind. -func (e *endpoint) Bind(address tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - -// Listen implements unix.Endpoint.Listen. -func (e *endpoint) Listen(backlog int) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - -// Accept implements unix.Endpoint.Accept. -func (e *endpoint) Accept() (unix.Endpoint, *tcpip.Error) { - return nil, tcpip.ErrInvalidEndpointState -} - -// Shutdown implements unix.Endpoint.Shutdown. -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - return tcpip.ErrInvalidEndpointState -} - -// GetSockOpt implements unix.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - _, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_ERROR) - return translateError(err) - case *tcpip.PasscredOption: - // We don't support passcred on host sockets. - *o = 0 - return nil - case *tcpip.SendBufferSizeOption: - *o = tcpip.SendBufferSizeOption(e.sndbuf) - return nil - case *tcpip.ReceiveBufferSizeOption: - // N.B. Unix sockets don't use the receive buffer. We'll claim it is - // the same size as the send buffer. - *o = tcpip.ReceiveBufferSizeOption(e.sndbuf) - return nil - case *tcpip.ReuseAddressOption: - v, err := syscall.GetsockoptInt(e.fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR) - *o = tcpip.ReuseAddressOption(v) - return translateError(err) - case *tcpip.ReceiveQueueSizeOption: - return tcpip.ErrQueueSizeNotSupported - } - return tcpip.ErrInvalidEndpointState -} - -// SetSockOpt implements unix.Endpoint.SetSockOpt. -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return nil -} - -// GetLocalAddress implements unix.Endpoint.GetLocalAddress. -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -// GetRemoteAddress implements unix.Endpoint.GetRemoteAddress. -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, nil -} - -// Passcred returns whether or not the SO_PASSCRED socket option is -// enabled on this end. -func (e *endpoint) Passcred() bool { - // We don't support credential passing for host sockets. - return false -} - -// ConnectedPasscred returns whether or not the SO_PASSCRED socket option -// is enabled on the connected end. -func (e *endpoint) ConnectedPasscred() bool { - // We don't support credential passing for host sockets. - return false -} + e.Init() -// SendMsg implements unix.Endpoint.SendMsg. -func (e *endpoint) SendMsg(data [][]byte, controlMessages unix.ControlMessages, to unix.BoundEndpoint) (uintptr, *tcpip.Error) { - if to != nil { - return 0, tcpip.ErrInvalidEndpointState - } - - // Since stream sockets don't preserve message boundaries, we can write - // only as much of the message as fits in the send buffer. - truncate := e.stype == unix.SockStream + ep := unix.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return sendMsg(e.fd, data, controlMessages, e.sndbuf, truncate) + return unixsocket.NewWithDirent(ctx, d, ep, flags), nil } func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages, maxlen int, truncate bool) (uintptr, *tcpip.Error) { @@ -278,18 +110,6 @@ func sendMsg(fd int, data [][]byte, controlMessages unix.ControlMessages, maxlen return n, translateError(err) } -// RecvMsg implements unix.Endpoint.RecvMsg. -func (e *endpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) { - // N.B. Unix sockets don't have a receive buffer, the send buffer - // serves both purposes. - rl, ml, cm, err := recvMsg(e.fd, data, numRights, peek, addr, e.sndbuf) - if rl > 0 && err == tcpip.ErrWouldBlock { - // Message did not fill buffer; that's fine, no need to block. - err = nil - } - return rl, ml, cm, err -} - func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.FullAddress, maxlen int) (uintptr, uintptr, unix.ControlMessages, *tcpip.Error) { var cm unet.ControlMessage if numRights > 0 { @@ -328,42 +148,21 @@ func recvMsg(fd int, data [][]byte, numRights uintptr, peek bool, addr *tcpip.Fu // be called twice because host.ConnectedEndpoint is both a unix.Receiver and // unix.ConnectedEndpoint. func NewConnectedEndpoint(file *fd.FD, queue *waiter.Queue, path string) (*ConnectedEndpoint, *tcpip.Error) { - family, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_DOMAIN) - if err != nil { - return nil, translateError(err) - } - - if family != syscall.AF_UNIX { - // We only allow Unix sockets. - return nil, tcpip.ErrInvalidEndpointState - } - - stype, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_TYPE) - if err != nil { - return nil, translateError(err) - } - - sndbuf, err := syscall.GetsockoptInt(file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF) - if err != nil { - return nil, translateError(err) - } - if sndbuf > maxSendBufferSize { - log.Warningf("Socket send buffer too large: %d", sndbuf) - return nil, tcpip.ErrInvalidEndpointState + e := ConnectedEndpoint{ + path: path, + queue: queue, + file: file, + srfd: -1, } - e := &ConnectedEndpoint{ - path: path, - queue: queue, - file: file, - stype: unix.SockType(stype), - sndbuf: sndbuf, + if err := e.init(); err != nil { + return nil, err } // AtomicRefCounters start off with a single reference. We need two. e.ref.IncRef() - return e, nil + return &e, nil } // Init will do initialization required without holding other locks. @@ -376,7 +175,7 @@ func (c *ConnectedEndpoint) Init() { // ConnectedEndpoint is a host FD backed implementation of // unix.ConnectedEndpoint and unix.Receiver. // -// ConnectedEndpoint does not support save/restore for now. +// +stateify savable type ConnectedEndpoint struct { queue *waiter.Queue path string @@ -385,11 +184,11 @@ type ConnectedEndpoint struct { ref refs.AtomicRefCount // mu protects fd, readClosed and writeClosed. - mu sync.RWMutex + mu sync.RWMutex `state:"nosave"` // file is an *fd.FD containing the FD backing this endpoint. It must be // set to nil if it has been closed. - file *fd.FD + file *fd.FD `state:"nosave"` // readClosed is true if the FD has read shutdown or if it has been closed. readClosed bool @@ -398,6 +197,9 @@ type ConnectedEndpoint struct { // closed. writeClosed bool + // If srfd >= 0, it is the host FD that file was imported from. + srfd int `state:"wait"` + // stype is the type of Unix socket. stype unix.SockType @@ -407,7 +209,44 @@ type ConnectedEndpoint struct { // GetSockOpt and message splitting/rejection in SendMsg, but do not // prevent lots of small messages from filling the real send buffer // size on the host. - sndbuf int + sndbuf int `state:"nosave"` +} + +// init performs initialization required for creating new ConnectedEndpoints and +// for restoring them. +func (c *ConnectedEndpoint) init() *tcpip.Error { + family, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_DOMAIN) + if err != nil { + return translateError(err) + } + + if family != syscall.AF_UNIX { + // We only allow Unix sockets. + return tcpip.ErrInvalidEndpointState + } + + stype, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_TYPE) + if err != nil { + return translateError(err) + } + + if err := syscall.SetNonblock(c.file.FD(), true); err != nil { + return translateError(err) + } + + sndbuf, err := syscall.GetsockoptInt(c.file.FD(), syscall.SOL_SOCKET, syscall.SO_SNDBUF) + if err != nil { + return translateError(err) + } + if sndbuf > maxSendBufferSize { + log.Warningf("Socket send buffer too large: %d", sndbuf) + return tcpip.ErrInvalidEndpointState + } + + c.stype = unix.SockType(stype) + c.sndbuf = sndbuf + + return nil } // Send implements unix.ConnectedEndpoint.Send. |