diff options
-rw-r--r-- | pkg/sentry/fs/host/socket.go | 4 | ||||
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 30 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 51 | ||||
-rw-r--r-- | pkg/sentry/syscalls/linux/sys_socket.go | 6 | ||||
-rw-r--r-- | test/syscalls/linux/socket_generic.cc | 11 | ||||
-rw-r--r-- | test/syscalls/linux/socket_non_stream_blocking.cc | 2 | ||||
-rw-r--r-- | test/syscalls/linux/socket_stream_blocking.cc | 2 |
7 files changed, 79 insertions, 27 deletions
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 506be3056..b9e2aa705 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -169,7 +169,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.NewWithDirent(ctx, d, ep, flags), nil + return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil } // newSocket allocates a new unix socket with host endpoint. @@ -201,7 +201,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.New(ctx, ep), nil + return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil } // Send implements transport.ConnectedEndpoint.Send. diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index e1cda78c4..b49ef21ad 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -1300,6 +1300,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 + dontWait := flags&linux.MSG_DONTWAIT != 0 + waitAll := flags&linux.MSG_WAITALL != 0 if senderRequested && !s.isPacketBased() { // Stream sockets ignore the sender address. senderRequested = false @@ -1311,10 +1313,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } - if err != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + if err != nil && (err != syserr.ErrWouldBlock || dontWait) { + // Read failed and we should not retry. + return 0, nil, 0, socket.ControlMessages{}, err + } + + if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) { + // We got all the data we need. return } + // Don't overwrite any data we received. + dst = dst.DropFirst(n) + // We'll have to block. Register for notifications and keep trying to // send all the data. e, ch := waiter.NewChannelEntry(nil) @@ -1322,10 +1333,23 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) - if err != syserr.ErrWouldBlock { + var rn int + rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) + n += rn + if err != nil && err != syserr.ErrWouldBlock { + // Always stop on errors other than would block as we generally + // won't be able to get any more data. Eat the error if we got + // any data. + if n > 0 { + err = nil + } + return + } + if err == nil && (s.isPacketBased() || !waitAll || int64(rn) >= dst.NumBytes()) { + // We got all the data we need. return } + dst = dst.DropFirst(rn) if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if err == syserror.ETIMEDOUT { diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 4379486cf..11cad411d 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -53,19 +53,21 @@ type SocketOperations struct { fsutil.NoopFlush `state:"nosave"` fsutil.NoMMap `state:"nosave"` ep transport.Endpoint + isPacket bool } // New creates a new unix socket. -func New(ctx context.Context, endpoint transport.Endpoint) *fs.File { +func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) defer dirent.DecRef() - return NewWithDirent(ctx, dirent, endpoint, fs.FileFlags{Read: true, Write: true}) + return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true}) } // NewWithDirent creates a new unix socket using an existing dirent. -func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, flags fs.FileFlags) *fs.File { +func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File { return fs.NewFile(ctx, d, flags, &SocketOperations{ - ep: ep, + ep: ep, + isPacket: isPacket, }) } @@ -188,7 +190,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns := New(t, ep) + ns := New(t, ep, s.isPacket) defer ns.DecRef() if flags&linux.SOCK_NONBLOCK != 0 { @@ -471,6 +473,8 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 + dontWait := flags&linux.MSG_DONTWAIT != 0 + waitAll := flags&linux.MSG_WAITALL != 0 // Calculate the number of FDs for which we have space and if we are // requesting credentials. @@ -497,7 +501,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if senderRequested { r.From = &tcpip.FullAddress{} } - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + var total int64 + if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait { var from interface{} var fromLen uint32 if r.From != nil { @@ -506,7 +511,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if trunc { n = int64(r.MsgSize) } - return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) + if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() { + return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) + } + + // Don't overwrite any data we received. + dst = dst.DropFirst64(n) + total += n } // We'll have to block. Register for notification and keep trying to @@ -525,7 +536,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if trunc { n = int64(r.MsgSize) } - return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) + total += n + if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() { + return int(total), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) + } + + // Don't overwrite any data we received. + dst = dst.DropFirst64(n) } if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { @@ -549,16 +566,21 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) // Create the endpoint and socket. var ep transport.Endpoint + var isPacket bool switch stype { case linux.SOCK_DGRAM: + isPacket = true ep = transport.NewConnectionless() - case linux.SOCK_STREAM, linux.SOCK_SEQPACKET: + case linux.SOCK_SEQPACKET: + isPacket = true + fallthrough + case linux.SOCK_STREAM: ep = transport.NewConnectioned(stype, t.Kernel()) default: return nil, syserr.ErrInvalidArgument } - return New(t, ep), nil + return New(t, ep, isPacket), nil } // Pair creates a new pair of AF_UNIX connected sockets. @@ -568,16 +590,19 @@ func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (* return nil, nil, syserr.ErrInvalidArgument } + var isPacket bool switch stype { - case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + case linux.SOCK_STREAM: + case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + isPacket = true default: return nil, nil, syserr.ErrInvalidArgument } // Create the endpoints and sockets. ep1, ep2 := transport.NewPair(stype, t.Kernel()) - s1 := New(t, ep1) - s2 := New(t, ep2) + s1 := New(t, ep1, isPacket) + s2 := New(t, ep2, isPacket) return s1, s2, nil } diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 0a7551742..1165d4566 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -602,7 +602,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } // Reject flags that we don't handle yet. - if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 { + if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE|linux.MSG_WAITALL) != 0 { return 0, nil, syscall.EINVAL } @@ -635,7 +635,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc } // Reject flags that we don't handle yet. - if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 { + if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE|linux.MSG_WAITALL) != 0 { return 0, nil, syscall.EINVAL } @@ -791,7 +791,7 @@ func recvFrom(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, f } // Reject flags that we don't handle yet. - if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CONFIRM) != 0 { + if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CONFIRM|linux.MSG_WAITALL) != 0 { return 0, syscall.EINVAL } diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index fbc3bebed..fdc346d4d 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -383,8 +383,6 @@ TEST_P(AllSocketPairTest, RecvmsgTimeoutOneSecondSucceeds) { } TEST_P(AllSocketPairTest, RecvWaitAll) { - SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL. - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char sent_data[100]; @@ -399,5 +397,14 @@ TEST_P(AllSocketPairTest, RecvWaitAll) { SyscallSucceedsWithValue(sizeof(sent_data))); } +TEST_P(AllSocketPairTest, RecvWaitAllDontWait) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + char data[100] = {}; + ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), data, sizeof(data), + MSG_WAITALL | MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc index d64b181c9..9e92628c3 100644 --- a/test/syscalls/linux/socket_non_stream_blocking.cc +++ b/test/syscalls/linux/socket_non_stream_blocking.cc @@ -31,8 +31,6 @@ namespace gvisor { namespace testing { TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) { - SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL. - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char sent_data[100]; diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc index dd209c67c..3fbbe54d8 100644 --- a/test/syscalls/linux/socket_stream_blocking.cc +++ b/test/syscalls/linux/socket_stream_blocking.cc @@ -99,8 +99,6 @@ TEST_P(BlockingStreamSocketPairTest, RecvLessThanBuffer) { } TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll) { - SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL. - auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char sent_data[100]; |