summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/epsocket/epsocket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/epsocket/epsocket.go')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go30
1 files changed, 27 insertions, 3 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index e1cda78c4..b49ef21ad 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -1300,6 +1300,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ waitAll := flags&linux.MSG_WAITALL != 0
if senderRequested && !s.isPacketBased() {
// Stream sockets ignore the sender address.
senderRequested = false
@@ -1311,10 +1313,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
- if err != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ if err != nil && (err != syserr.ErrWouldBlock || dontWait) {
+ // Read failed and we should not retry.
+ return 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) {
+ // We got all the data we need.
return
}
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst(n)
+
// We'll have to block. Register for notifications and keep trying to
// send all the data.
e, ch := waiter.NewChannelEntry(nil)
@@ -1322,10 +1333,23 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)
for {
- n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
- if err != syserr.ErrWouldBlock {
+ var rn int
+ rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ n += rn
+ if err != nil && err != syserr.ErrWouldBlock {
+ // Always stop on errors other than would block as we generally
+ // won't be able to get any more data. Eat the error if we got
+ // any data.
+ if n > 0 {
+ err = nil
+ }
+ return
+ }
+ if err == nil && (s.isPacketBased() || !waitAll || int64(rn) >= dst.NumBytes()) {
+ // We got all the data we need.
return
}
+ dst = dst.DropFirst(rn)
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if err == syserror.ETIMEDOUT {