summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/fs/host/socket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/fs/host/socket.go')
-rw-r--r--pkg/sentry/fs/host/socket.go62
1 files changed, 28 insertions, 34 deletions
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 3ed137006..e4ec0f62c 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -15,6 +15,7 @@
package host
import (
+ "fmt"
"sync"
"syscall"
@@ -51,20 +52,6 @@ type ConnectedEndpoint struct {
// ref keeps track of references to a connectedEndpoint.
ref refs.AtomicRefCount
- // mu protects fd, readClosed and writeClosed.
- mu sync.RWMutex `state:"nosave"`
-
- // file is an *fd.FD containing the FD backing this endpoint. It must be
- // set to nil if it has been closed.
- file *fd.FD `state:"nosave"`
-
- // readClosed is true if the FD has read shutdown or if it has been closed.
- readClosed bool
-
- // writeClosed is true if the FD has write shutdown or if it has been
- // closed.
- writeClosed bool
-
// If srfd >= 0, it is the host FD that file was imported from.
srfd int `state:"wait"`
@@ -78,6 +65,13 @@ type ConnectedEndpoint struct {
// prevent lots of small messages from filling the real send buffer
// size on the host.
sndbuf int `state:"nosave"`
+
+ // mu protects the fields below.
+ mu sync.RWMutex `state:"nosave"`
+
+ // file is an *fd.FD containing the FD backing this endpoint. It must be
+ // set to nil if it has been closed.
+ file *fd.FD `state:"nosave"`
}
// init performs initialization required for creating new ConnectedEndpoints and
@@ -208,9 +202,6 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error)
func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.ControlMessages, from tcpip.FullAddress) (uintptr, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.writeClosed {
- return 0, false, syserr.ErrClosedForSend
- }
if !controlMessages.Empty() {
return 0, false, syserr.ErrInvalidEndpointState
@@ -244,8 +235,13 @@ func (c *ConnectedEndpoint) SendNotify() {}
// CloseSend implements transport.ConnectedEndpoint.CloseSend.
func (c *ConnectedEndpoint) CloseSend() {
c.mu.Lock()
- c.writeClosed = true
- c.mu.Unlock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_WR); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed write shutdown on host socket %+v: %v", c, err))
+ }
}
// CloseNotify implements transport.ConnectedEndpoint.CloseNotify.
@@ -255,9 +251,7 @@ func (c *ConnectedEndpoint) CloseNotify() {}
func (c *ConnectedEndpoint) Writable() bool {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.writeClosed {
- return true
- }
+
return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventOut)&waiter.EventOut != 0
}
@@ -285,9 +279,6 @@ func (c *ConnectedEndpoint) EventUpdate() {
func (c *ConnectedEndpoint) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, transport.ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.readClosed {
- return 0, 0, transport.ControlMessages{}, false, tcpip.FullAddress{}, false, syserr.ErrClosedForReceive
- }
var cm unet.ControlMessage
if numRights > 0 {
@@ -344,31 +335,34 @@ func (c *ConnectedEndpoint) RecvNotify() {}
// CloseRecv implements transport.Receiver.CloseRecv.
func (c *ConnectedEndpoint) CloseRecv() {
c.mu.Lock()
- c.readClosed = true
- c.mu.Unlock()
+ defer c.mu.Unlock()
+
+ if err := syscall.Shutdown(c.file.FD(), syscall.SHUT_RD); err != nil {
+ // A well-formed UDS shutdown can't fail. See
+ // net/unix/af_unix.c:unix_shutdown.
+ panic(fmt.Sprintf("failed read shutdown on host socket %+v: %v", c, err))
+ }
}
// Readable implements transport.Receiver.Readable.
func (c *ConnectedEndpoint) Readable() bool {
c.mu.RLock()
defer c.mu.RUnlock()
- if c.readClosed {
- return true
- }
+
return fdnotifier.NonBlockingPoll(int32(c.file.FD()), waiter.EventIn)&waiter.EventIn != 0
}
// SendQueuedSize implements transport.Receiver.SendQueuedSize.
func (c *ConnectedEndpoint) SendQueuedSize() int64 {
- // SendQueuedSize isn't supported for host sockets because we don't allow the
- // sentry to call ioctl(2).
+ // TODO(gvisor.dev/issue/273): SendQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
return -1
}
// RecvQueuedSize implements transport.Receiver.RecvQueuedSize.
func (c *ConnectedEndpoint) RecvQueuedSize() int64 {
- // RecvQueuedSize isn't supported for host sockets because we don't allow the
- // sentry to call ioctl(2).
+ // TODO(gvisor.dev/issue/273): RecvQueuedSize isn't supported for host
+ // sockets because we don't allow the sentry to call ioctl(2).
return -1
}