From 5d87d8865f8771c00b84717d40f27f8f93dda7ca Mon Sep 17 00:00:00 2001
From: Ian Gudger <igudger@google.com>
Date: Mon, 10 Dec 2018 17:55:45 -0800
Subject: Implement MSG_WAITALL

MSG_WAITALL requests that recv family calls do not perform short reads. It only
has an effect for SOCK_STREAM sockets, other types ignore it.

PiperOrigin-RevId: 224918540
Change-Id: Id97fbf972f1f7cbd4e08eec0138f8cbdf1c94fe7
---
 pkg/sentry/socket/unix/unix.go | 51 +++++++++++++++++++++++++++++++-----------
 1 file changed, 38 insertions(+), 13 deletions(-)

(limited to 'pkg/sentry/socket/unix')

diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4379486cf..11cad411d 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -53,19 +53,21 @@ type SocketOperations struct {
 	fsutil.NoopFlush     `state:"nosave"`
 	fsutil.NoMMap        `state:"nosave"`
 	ep                   transport.Endpoint
+	isPacket             bool
 }
 
 // New creates a new unix socket.
-func New(ctx context.Context, endpoint transport.Endpoint) *fs.File {
+func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
 	dirent := socket.NewDirent(ctx, unixSocketDevice)
 	defer dirent.DecRef()
-	return NewWithDirent(ctx, dirent, endpoint, fs.FileFlags{Read: true, Write: true})
+	return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
 }
 
 // NewWithDirent creates a new unix socket using an existing dirent.
-func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, flags fs.FileFlags) *fs.File {
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
 	return fs.NewFile(ctx, d, flags, &SocketOperations{
-		ep: ep,
+		ep:       ep,
+		isPacket: isPacket,
 	})
 }
 
@@ -188,7 +190,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
 		}
 	}
 
-	ns := New(t, ep)
+	ns := New(t, ep, s.isPacket)
 	defer ns.DecRef()
 
 	if flags&linux.SOCK_NONBLOCK != 0 {
@@ -471,6 +473,8 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
 func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
 	trunc := flags&linux.MSG_TRUNC != 0
 	peek := flags&linux.MSG_PEEK != 0
+	dontWait := flags&linux.MSG_DONTWAIT != 0
+	waitAll := flags&linux.MSG_WAITALL != 0
 
 	// Calculate the number of FDs for which we have space and if we are
 	// requesting credentials.
@@ -497,7 +501,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
 	if senderRequested {
 		r.From = &tcpip.FullAddress{}
 	}
-	if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+	var total int64
+	if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
 		var from interface{}
 		var fromLen uint32
 		if r.From != nil {
@@ -506,7 +511,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
 		if trunc {
 			n = int64(r.MsgSize)
 		}
-		return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+		if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
+			return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+		}
+
+		// Don't overwrite any data we received.
+		dst = dst.DropFirst64(n)
+		total += n
 	}
 
 	// We'll have to block. Register for notification and keep trying to
@@ -525,7 +536,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
 			if trunc {
 				n = int64(r.MsgSize)
 			}
-			return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+			total += n
+			if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
+				return int(total), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+			}
+
+			// Don't overwrite any data we received.
+			dst = dst.DropFirst64(n)
 		}
 
 		if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
@@ -549,16 +566,21 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int)
 
 	// Create the endpoint and socket.
 	var ep transport.Endpoint
+	var isPacket bool
 	switch stype {
 	case linux.SOCK_DGRAM:
+		isPacket = true
 		ep = transport.NewConnectionless()
-	case linux.SOCK_STREAM, linux.SOCK_SEQPACKET:
+	case linux.SOCK_SEQPACKET:
+		isPacket = true
+		fallthrough
+	case linux.SOCK_STREAM:
 		ep = transport.NewConnectioned(stype, t.Kernel())
 	default:
 		return nil, syserr.ErrInvalidArgument
 	}
 
-	return New(t, ep), nil
+	return New(t, ep, isPacket), nil
 }
 
 // Pair creates a new pair of AF_UNIX connected sockets.
@@ -568,16 +590,19 @@ func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*
 		return nil, nil, syserr.ErrInvalidArgument
 	}
 
+	var isPacket bool
 	switch stype {
-	case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+	case linux.SOCK_STREAM:
+	case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+		isPacket = true
 	default:
 		return nil, nil, syserr.ErrInvalidArgument
 	}
 
 	// Create the endpoints and sockets.
 	ep1, ep2 := transport.NewPair(stype, t.Kernel())
-	s1 := New(t, ep1)
-	s2 := New(t, ep2)
+	s1 := New(t, ep1, isPacket)
+	s2 := New(t, ep2, isPacket)
 
 	return s1, s2, nil
 }
-- 
cgit v1.2.3