summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/fs/host/socket.go4
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go30
-rw-r--r--pkg/sentry/socket/unix/unix.go51
-rw-r--r--pkg/sentry/syscalls/linux/sys_socket.go6
-rw-r--r--test/syscalls/linux/socket_generic.cc11
-rw-r--r--test/syscalls/linux/socket_non_stream_blocking.cc2
-rw-r--r--test/syscalls/linux/socket_stream_blocking.cc2
7 files changed, 79 insertions, 27 deletions
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 506be3056..b9e2aa705 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -169,7 +169,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
- return unixsocket.NewWithDirent(ctx, d, ep, flags), nil
+ return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil
}
// newSocket allocates a new unix socket with host endpoint.
@@ -201,7 +201,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error)
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
- return unixsocket.New(ctx, ep), nil
+ return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil
}
// Send implements transport.ConnectedEndpoint.Send.
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index e1cda78c4..b49ef21ad 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -1300,6 +1300,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ waitAll := flags&linux.MSG_WAITALL != 0
if senderRequested && !s.isPacketBased() {
// Stream sockets ignore the sender address.
senderRequested = false
@@ -1311,10 +1313,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
- if err != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ if err != nil && (err != syserr.ErrWouldBlock || dontWait) {
+ // Read failed and we should not retry.
+ return 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) {
+ // We got all the data we need.
return
}
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst(n)
+
// We'll have to block. Register for notifications and keep trying to
// send all the data.
e, ch := waiter.NewChannelEntry(nil)
@@ -1322,10 +1333,23 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)
for {
- n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
- if err != syserr.ErrWouldBlock {
+ var rn int
+ rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ n += rn
+ if err != nil && err != syserr.ErrWouldBlock {
+ // Always stop on errors other than would block as we generally
+ // won't be able to get any more data. Eat the error if we got
+ // any data.
+ if n > 0 {
+ err = nil
+ }
+ return
+ }
+ if err == nil && (s.isPacketBased() || !waitAll || int64(rn) >= dst.NumBytes()) {
+ // We got all the data we need.
return
}
+ dst = dst.DropFirst(rn)
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if err == syserror.ETIMEDOUT {
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4379486cf..11cad411d 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -53,19 +53,21 @@ type SocketOperations struct {
fsutil.NoopFlush `state:"nosave"`
fsutil.NoMMap `state:"nosave"`
ep transport.Endpoint
+ isPacket bool
}
// New creates a new unix socket.
-func New(ctx context.Context, endpoint transport.Endpoint) *fs.File {
+func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
dirent := socket.NewDirent(ctx, unixSocketDevice)
defer dirent.DecRef()
- return NewWithDirent(ctx, dirent, endpoint, fs.FileFlags{Read: true, Write: true})
+ return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
}
// NewWithDirent creates a new unix socket using an existing dirent.
-func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, flags fs.FileFlags) *fs.File {
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
return fs.NewFile(ctx, d, flags, &SocketOperations{
- ep: ep,
+ ep: ep,
+ isPacket: isPacket,
})
}
@@ -188,7 +190,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- ns := New(t, ep)
+ ns := New(t, ep, s.isPacket)
defer ns.DecRef()
if flags&linux.SOCK_NONBLOCK != 0 {
@@ -471,6 +473,8 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ waitAll := flags&linux.MSG_WAITALL != 0
// Calculate the number of FDs for which we have space and if we are
// requesting credentials.
@@ -497,7 +501,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if senderRequested {
r.From = &tcpip.FullAddress{}
}
- if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ var total int64
+ if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
var from interface{}
var fromLen uint32
if r.From != nil {
@@ -506,7 +511,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if trunc {
n = int64(r.MsgSize)
}
- return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst64(n)
+ total += n
}
// We'll have to block. Register for notification and keep trying to
@@ -525,7 +536,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if trunc {
n = int64(r.MsgSize)
}
- return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ total += n
+ if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ return int(total), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst64(n)
}
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
@@ -549,16 +566,21 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int)
// Create the endpoint and socket.
var ep transport.Endpoint
+ var isPacket bool
switch stype {
case linux.SOCK_DGRAM:
+ isPacket = true
ep = transport.NewConnectionless()
- case linux.SOCK_STREAM, linux.SOCK_SEQPACKET:
+ case linux.SOCK_SEQPACKET:
+ isPacket = true
+ fallthrough
+ case linux.SOCK_STREAM:
ep = transport.NewConnectioned(stype, t.Kernel())
default:
return nil, syserr.ErrInvalidArgument
}
- return New(t, ep), nil
+ return New(t, ep, isPacket), nil
}
// Pair creates a new pair of AF_UNIX connected sockets.
@@ -568,16 +590,19 @@ func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*
return nil, nil, syserr.ErrInvalidArgument
}
+ var isPacket bool
switch stype {
- case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ case linux.SOCK_STREAM:
+ case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ isPacket = true
default:
return nil, nil, syserr.ErrInvalidArgument
}
// Create the endpoints and sockets.
ep1, ep2 := transport.NewPair(stype, t.Kernel())
- s1 := New(t, ep1)
- s2 := New(t, ep2)
+ s1 := New(t, ep1, isPacket)
+ s2 := New(t, ep2, isPacket)
return s1, s2, nil
}
diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go
index 0a7551742..1165d4566 100644
--- a/pkg/sentry/syscalls/linux/sys_socket.go
+++ b/pkg/sentry/syscalls/linux/sys_socket.go
@@ -602,7 +602,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
}
// Reject flags that we don't handle yet.
- if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE|linux.MSG_WAITALL) != 0 {
return 0, nil, syscall.EINVAL
}
@@ -635,7 +635,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
}
// Reject flags that we don't handle yet.
- if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE|linux.MSG_WAITALL) != 0 {
return 0, nil, syscall.EINVAL
}
@@ -791,7 +791,7 @@ func recvFrom(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, f
}
// Reject flags that we don't handle yet.
- if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CONFIRM) != 0 {
+ if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CONFIRM|linux.MSG_WAITALL) != 0 {
return 0, syscall.EINVAL
}
diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc
index fbc3bebed..fdc346d4d 100644
--- a/test/syscalls/linux/socket_generic.cc
+++ b/test/syscalls/linux/socket_generic.cc
@@ -383,8 +383,6 @@ TEST_P(AllSocketPairTest, RecvmsgTimeoutOneSecondSucceeds) {
}
TEST_P(AllSocketPairTest, RecvWaitAll) {
- SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL.
-
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
char sent_data[100];
@@ -399,5 +397,14 @@ TEST_P(AllSocketPairTest, RecvWaitAll) {
SyscallSucceedsWithValue(sizeof(sent_data)));
}
+TEST_P(AllSocketPairTest, RecvWaitAllDontWait) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ char data[100] = {};
+ ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), data, sizeof(data),
+ MSG_WAITALL | MSG_DONTWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_non_stream_blocking.cc b/test/syscalls/linux/socket_non_stream_blocking.cc
index d64b181c9..9e92628c3 100644
--- a/test/syscalls/linux/socket_non_stream_blocking.cc
+++ b/test/syscalls/linux/socket_non_stream_blocking.cc
@@ -31,8 +31,6 @@ namespace gvisor {
namespace testing {
TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) {
- SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL.
-
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
char sent_data[100];
diff --git a/test/syscalls/linux/socket_stream_blocking.cc b/test/syscalls/linux/socket_stream_blocking.cc
index dd209c67c..3fbbe54d8 100644
--- a/test/syscalls/linux/socket_stream_blocking.cc
+++ b/test/syscalls/linux/socket_stream_blocking.cc
@@ -99,8 +99,6 @@ TEST_P(BlockingStreamSocketPairTest, RecvLessThanBuffer) {
}
TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll) {
- SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL.
-
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
char sent_data[100];