diff options
author | Ian Gudger <igudger@google.com> | 2018-12-10 17:55:45 -0800 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-12-10 17:56:34 -0800 |
commit | 5d87d8865f8771c00b84717d40f27f8f93dda7ca (patch) | |
tree | 86362c16f38874bf4bf3b98e7c47fc1cbba7e396 /pkg/sentry/socket/epsocket | |
parent | d3bc79bc8438206ac6a14fde4eaa288fc07eee82 (diff) |
Implement MSG_WAITALL
MSG_WAITALL requests that recv family calls do not perform short reads. It only
has an effect for SOCK_STREAM sockets, other types ignore it.
PiperOrigin-RevId: 224918540
Change-Id: Id97fbf972f1f7cbd4e08eec0138f8cbdf1c94fe7
Diffstat (limited to 'pkg/sentry/socket/epsocket')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 30 |
1 files changed, 27 insertions, 3 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index e1cda78c4..b49ef21ad 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -1300,6 +1300,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 + dontWait := flags&linux.MSG_DONTWAIT != 0 + waitAll := flags&linux.MSG_WAITALL != 0 if senderRequested && !s.isPacketBased() { // Stream sockets ignore the sender address. senderRequested = false @@ -1311,10 +1313,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } - if err != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + if err != nil && (err != syserr.ErrWouldBlock || dontWait) { + // Read failed and we should not retry. + return 0, nil, 0, socket.ControlMessages{}, err + } + + if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) { + // We got all the data we need. return } + // Don't overwrite any data we received. + dst = dst.DropFirst(n) + // We'll have to block. Register for notifications and keep trying to // send all the data. e, ch := waiter.NewChannelEntry(nil) @@ -1322,10 +1333,23 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) - if err != syserr.ErrWouldBlock { + var rn int + rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) + n += rn + if err != nil && err != syserr.ErrWouldBlock { + // Always stop on errors other than would block as we generally + // won't be able to get any more data. Eat the error if we got + // any data. + if n > 0 { + err = nil + } + return + } + if err == nil && (s.isPacketBased() || !waitAll || int64(rn) >= dst.NumBytes()) { + // We got all the data we need. return } + dst = dst.DropFirst(rn) if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if err == syserror.ETIMEDOUT { |