From 5d87d8865f8771c00b84717d40f27f8f93dda7ca Mon Sep 17 00:00:00 2001
From: Ian Gudger <igudger@google.com>
Date: Mon, 10 Dec 2018 17:55:45 -0800
Subject: Implement MSG_WAITALL

MSG_WAITALL requests that recv family calls do not perform short reads. It only
has an effect for SOCK_STREAM sockets, other types ignore it.

PiperOrigin-RevId: 224918540
Change-Id: Id97fbf972f1f7cbd4e08eec0138f8cbdf1c94fe7
---
 pkg/sentry/socket/epsocket/epsocket.go | 30 +++++++++++++++++++++++++++---
 1 file changed, 27 insertions(+), 3 deletions(-)

(limited to 'pkg/sentry/socket/epsocket')

diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index e1cda78c4..b49ef21ad 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -1300,6 +1300,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
 func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
 	trunc := flags&linux.MSG_TRUNC != 0
 	peek := flags&linux.MSG_PEEK != 0
+	dontWait := flags&linux.MSG_DONTWAIT != 0
+	waitAll := flags&linux.MSG_WAITALL != 0
 	if senderRequested && !s.isPacketBased() {
 		// Stream sockets ignore the sender address.
 		senderRequested = false
@@ -1311,10 +1313,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
 		return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
 	}
 
-	if err != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+	if err != nil && (err != syserr.ErrWouldBlock || dontWait) {
+		// Read failed and we should not retry.
+		return 0, nil, 0, socket.ControlMessages{}, err
+	}
+
+	if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) {
+		// We got all the data we need.
 		return
 	}
 
+	// Don't overwrite any data we received.
+	dst = dst.DropFirst(n)
+
 	// We'll have to block. Register for notifications and keep trying to
 	// send all the data.
 	e, ch := waiter.NewChannelEntry(nil)
@@ -1322,10 +1333,23 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
 	defer s.EventUnregister(&e)
 
 	for {
-		n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
-		if err != syserr.ErrWouldBlock {
+		var rn int
+		rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+		n += rn
+		if err != nil && err != syserr.ErrWouldBlock {
+			// Always stop on errors other than would block as we generally
+			// won't be able to get any more data. Eat the error if we got
+			// any data.
+			if n > 0 {
+				err = nil
+			}
+			return
+		}
+		if err == nil && (s.isPacketBased() || !waitAll || int64(rn) >= dst.NumBytes()) {
+			// We got all the data we need.
 			return
 		}
+		dst = dst.DropFirst(rn)
 
 		if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
 			if err == syserror.ETIMEDOUT {
-- 
cgit v1.2.3