diff options
Diffstat (limited to 'pkg/sentry/fs/host/socket.go')
-rw-r--r-- | pkg/sentry/fs/host/socket.go | 73 |
1 files changed, 34 insertions, 39 deletions
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 3ed137006..305eea718 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -15,9 +15,11 @@ package host import ( + "fmt" "sync" "syscall" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/fd" "gvisor.googlesource.com/gvisor/pkg/fdnotifier" "gvisor.googlesource.com/gvisor/pkg/log" @@ -51,25 +53,11 @@ type ConnectedEndpoint struct { // ref keeps track of references to a connectedEndpoint. ref refs.AtomicRefCount - // mu protects fd, readClosed and writeClosed. - mu sync.RWMutex `state:"nosave"` - - // file is an *fd.FD containing the FD backing this endpoint. It must be - // set to nil if it has been closed. - file *fd.FD `state:"nosave"` - - // readClosed is true if the FD has read shutdown or if it has been closed. - readClosed bool - - // writeClosed is true if the FD has write shutdown or if it has been - // closed. - writeClosed bool - // If srfd >= 0, it is the host FD that file was imported from. srfd int `state:"wait"` // stype is the type of Unix socket. - stype transport.SockType + stype linux.SockType // sndbuf is the size of the send buffer. // @@ -78,6 +66,13 @@ type ConnectedEndpoint struct { // prevent lots of small messages from filling the real send buffer // size on the host. sndbuf int `state:"nosave"` + + // mu protects the fields below. + mu sync.RWMutex `state:"nosave"` + + // file is an *fd.FD containing the FD backing this endpoint. It must be + // set to nil if it has been closed. + file *fd.FD `state:"nosave"` } // init performs initialization required for creating new ConnectedEndpoints and @@ -111,7 +106,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error { return syserr.ErrInvalidEndpointState } - c.stype = transport.SockType(stype) + c.stype = linux.SockType(stype) c.sndbuf = sndbuf return nil @@ -169,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil + return unixsocket.NewWithDirent(ctx, d, ep, e.stype, flags), nil } // newSocket allocates a new unix socket with host endpoint. @@ -201,16 +196,13 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil + return unixsocket.New(ctx, ep, e.stype), nil } // Send implements transport.ConnectedEndpoint.Send. func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() - if c.writeClosed { - return 0, false, syserr.ErrClosedForSend - } if !controlMessages.Empty() { return 0, false, syserr.ErrInvalidEndpointState @@ -218,7 +210,7 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.Contro // Since stream sockets don't preserve message boundaries, we can write // only as much of the message as fits in the send buffer. - truncate := c.stype == transport.SockStream + truncate := c.stype == linux.SOCK_STREAM n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate) if n < totalLen && err == nil { @@ -244,8 +236,13 @@ func (c *ConnectedEndpoint) SendNotify() {} // CloseSend implements transport.ConnectedEndpoint.CloseSend. func (c *ConnectedEndpoint) CloseSend() { c.mu.Lock() - c.writeClosed = true - c.mu.Unlock() + defer c.mu.Unlock() + + if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_WR); err != nil { + // A well-formed UDS shutdown can't fail. See + // net/unix/af_unix.c:unix_shutdown. + panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err)) + } } // CloseNotify implements transport.ConnectedEndpoint.CloseNotify. @@ -255,9 +252,7 @@ func (c *ConnectedEndpoint) CloseNotify() {} func (c *ConnectedEndpoint) Writable() bool { c.mu.RLock() defer c.mu.RUnlock() - if c.writeClosed { - return true - } + return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventOut)&waiter.EventOut != 0 } @@ -285,9 +280,6 @@ func (c *ConnectedEndpoint) EventUpdate() { func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { c.mu.RLock() defer c.mu.RUnlock() - if c.readClosed { - return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive - } var cm unet.ControlMessage if numRights > 0 { @@ -344,31 +336,34 @@ func (c *ConnectedEndpoint) RecvNotify() {} // CloseRecv implements transport.Receiver.CloseRecv. func (c *ConnectedEndpoint) CloseRecv() { c.mu.Lock() - c.readClosed = true - c.mu.Unlock() + defer c.mu.Unlock() + + if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_RD); err != nil { + // A well-formed UDS shutdown can't fail. See + // net/unix/af_unix.c:unix_shutdown. + panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err)) + } } // Readable implements transport.Receiver.Readable. func (c *ConnectedEndpoint) Readable() bool { c.mu.RLock() defer c.mu.RUnlock() - if c.readClosed { - return true - } + return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventIn)&waiter.EventIn != 0 } // SendQueuedSize implements transport.Receiver.SendQueuedSize. func (c *ConnectedEndpoint) SendQueuedSize() int64 { - // SendQueuedSize isn't supported for host sockets because we don't allow the - // sentry to call ioctl(2). + // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host + // sockets because we don't allow the sentry to call ioctl(2). return -1 } // RecvQueuedSize implements transport.Receiver.RecvQueuedSize. func (c *ConnectedEndpoint) RecvQueuedSize() int64 { - // RecvQueuedSize isn't supported for host sockets because we don't allow the - // sentry to call ioctl(2). + // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host + // sockets because we don't allow the sentry to call ioctl(2). return -1 } |