summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD1
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go64
2 files changed, 59 insertions, 6 deletions
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
index b0351b363..8973453f9 100644
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -34,6 +34,7 @@ go_library(
"//pkg/sentry/usermem",
"//pkg/syserr",
"//pkg/syserror",
+ "//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/transport/unix",
"//pkg/unet",
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index b4b380ac6..f641f25df 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -33,6 +33,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
"gvisor.googlesource.com/gvisor/pkg/waiter"
@@ -52,6 +53,11 @@ type socketOperations struct {
wq *waiter.Queue
rpcConn *conn.RPCConnection
notifier *notifier.Notifier
+
+ // shState is the state of the connection with respect to shutdown. Because
+ // we're mixing non-blocking semantics on the other side we have to adapt for
+ // some strange differences between blocking and non-blocking sockets.
+ shState tcpip.ShutdownFlags
}
// Verify that we actually implement socket.Socket.
@@ -96,6 +102,31 @@ func translateIOSyscallError(err error) error {
return err
}
+// setShutdownFlags will set the shutdown flag so we can handle blocking reads
+// after a read shutdown.
+func (s *socketOperations) setShutdownFlags(how int) {
+ switch how {
+ case linux.SHUT_RD:
+ s.shState |= tcpip.ShutdownRead
+ case linux.SHUT_WR:
+ s.shState |= tcpip.ShutdownWrite
+ case linux.SHUT_RDWR:
+ s.shState |= tcpip.ShutdownWrite | tcpip.ShutdownRead
+ }
+}
+
+func (s *socketOperations) resetShutdownFlags() {
+ s.shState = 0
+}
+
+func (s *socketOperations) isShutRdSet() bool {
+ return s.shState&tcpip.ShutdownRead != 0
+}
+
+func (s *socketOperations) isShutWrSet() bool {
+ return s.shState&tcpip.ShutdownWrite != 0
+}
+
// Release implements fs.FileOperations.Release.
func (s *socketOperations) Release() {
s.notifier.RemoveFD(s.fd)
@@ -191,7 +222,12 @@ func rpcConnect(t *kernel.Task, fd uint32, sockaddr []byte) *syserr.Error {
// Connect implements socket.Socket.Connect.
func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
if !blocking {
- return rpcConnect(t, s.fd, sockaddr)
+ e := rpcConnect(t, s.fd, sockaddr)
+ if e == nil {
+ // Reset the shutdown state on new connects.
+ s.resetShutdownFlags()
+ }
+ return e
}
// Register for notification when the endpoint becomes writable, then
@@ -201,6 +237,10 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
defer s.EventUnregister(&e)
for {
if err := rpcConnect(t, s.fd, sockaddr); err == nil || err != syserr.ErrInProgress && err != syserr.ErrAlreadyInProgress {
+ if err == nil {
+ // Reset the shutdown state on new connects.
+ s.resetShutdownFlags()
+ }
return err
}
@@ -314,6 +354,11 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Shutdown).Shutdown.ErrorNumber; e != 0 {
return syserr.FromHost(syscall.Errno(e))
}
+
+ // We save the shutdown state because of strange differences on linux
+ // related to recvs on blocking vs. non-blocking sockets after a SHUT_RD.
+ // We need to emulate that behavior on the blocking side.
+ s.setShutdownFlags(how)
return nil
}
@@ -511,11 +556,12 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re
// RecvMsg implements socket.Socket.RecvMsg.
func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
- Fd: s.fd,
- Length: uint32(dst.NumBytes()),
- Sender: senderRequested,
- Trunc: flags&linux.MSG_TRUNC != 0,
- Peek: flags&linux.MSG_PEEK != 0,
+ Fd: s.fd,
+ Length: uint32(dst.NumBytes()),
+ Sender: senderRequested,
+ Trunc: flags&linux.MSG_TRUNC != 0,
+ Peek: flags&linux.MSG_PEEK != 0,
+ CmsgLength: uint32(controlDataLen),
}}
res, err := rpcRecvMsg(t, req)
@@ -559,6 +605,12 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
return 0, nil, 0, socket.ControlMessages{}, err
}
+ if s.isShutRdSet() {
+ // Blocking would have caused us to block indefinitely so we return 0,
+ // this is the same behavior as Linux.
+ return 0, nil, 0, socket.ControlMessages{}, nil
+ }
+
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if err == syserror.ETIMEDOUT {
return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain