diff options
Diffstat (limited to 'pkg/sentry')
-rw-r--r-- | pkg/sentry/socket/rpcinet/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/rpcinet/socket.go | 64 |
2 files changed, 59 insertions, 6 deletions
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index b0351b363..8973453f9 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -34,6 +34,7 @@ go_library( "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", + "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/transport/unix", "//pkg/unet", diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index b4b380ac6..f641f25df 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -33,6 +33,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" + "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix" "gvisor.googlesource.com/gvisor/pkg/waiter" @@ -52,6 +53,11 @@ type socketOperations struct { wq *waiter.Queue rpcConn *conn.RPCConnection notifier *notifier.Notifier + + // shState is the state of the connection with respect to shutdown. Because + // we're mixing non-blocking semantics on the other side we have to adapt for + // some strange differences between blocking and non-blocking sockets. + shState tcpip.ShutdownFlags } // Verify that we actually implement socket.Socket. @@ -96,6 +102,31 @@ func translateIOSyscallError(err error) error { return err } +// setShutdownFlags will set the shutdown flag so we can handle blocking reads +// after a read shutdown. +func (s *socketOperations) setShutdownFlags(how int) { + switch how { + case linux.SHUT_RD: + s.shState |= tcpip.ShutdownRead + case linux.SHUT_WR: + s.shState |= tcpip.ShutdownWrite + case linux.SHUT_RDWR: + s.shState |= tcpip.ShutdownWrite | tcpip.ShutdownRead + } +} + +func (s *socketOperations) resetShutdownFlags() { + s.shState = 0 +} + +func (s *socketOperations) isShutRdSet() bool { + return s.shState&tcpip.ShutdownRead != 0 +} + +func (s *socketOperations) isShutWrSet() bool { + return s.shState&tcpip.ShutdownWrite != 0 +} + // Release implements fs.FileOperations.Release. func (s *socketOperations) Release() { s.notifier.RemoveFD(s.fd) @@ -191,7 +222,12 @@ func rpcConnect(t *kernel.Task, fd uint32, sockaddr []byte) *syserr.Error { // Connect implements socket.Socket.Connect. func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { if !blocking { - return rpcConnect(t, s.fd, sockaddr) + e := rpcConnect(t, s.fd, sockaddr) + if e == nil { + // Reset the shutdown state on new connects. + s.resetShutdownFlags() + } + return e } // Register for notification when the endpoint becomes writable, then @@ -201,6 +237,10 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo defer s.EventUnregister(&e) for { if err := rpcConnect(t, s.fd, sockaddr); err == nil || err != syserr.ErrInProgress && err != syserr.ErrAlreadyInProgress { + if err == nil { + // Reset the shutdown state on new connects. + s.resetShutdownFlags() + } return err } @@ -314,6 +354,11 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Shutdown).Shutdown.ErrorNumber; e != 0 { return syserr.FromHost(syscall.Errno(e)) } + + // We save the shutdown state because of strange differences on linux + // related to recvs on blocking vs. non-blocking sockets after a SHUT_RD. + // We need to emulate that behavior on the blocking side. + s.setShutdownFlags(how) return nil } @@ -511,11 +556,12 @@ func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_Re // RecvMsg implements socket.Socket.RecvMsg. func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{ - Fd: s.fd, - Length: uint32(dst.NumBytes()), - Sender: senderRequested, - Trunc: flags&linux.MSG_TRUNC != 0, - Peek: flags&linux.MSG_PEEK != 0, + Fd: s.fd, + Length: uint32(dst.NumBytes()), + Sender: senderRequested, + Trunc: flags&linux.MSG_TRUNC != 0, + Peek: flags&linux.MSG_PEEK != 0, + CmsgLength: uint32(controlDataLen), }} res, err := rpcRecvMsg(t, req) @@ -559,6 +605,12 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, nil, 0, socket.ControlMessages{}, err } + if s.isShutRdSet() { + // Blocking would have caused us to block indefinitely so we return 0, + // this is the same behavior as Linux. + return 0, nil, 0, socket.ControlMessages{}, nil + } + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if err == syserror.ETIMEDOUT { return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain |