diff options
Diffstat (limited to 'pkg/sentry/socket/unix/unix.go')
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 31 |
1 files changed, 21 insertions, 10 deletions
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 92411c901..01efd24d3 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -477,7 +477,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -515,11 +515,17 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if r.From != nil { from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) } - if trunc { - n = int64(r.MsgSize) - } + if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() { - return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) + if s.isPacket && n < int64(r.MsgSize) { + msgFlags |= linux.MSG_TRUNC + } + + if trunc { + n = int64(r.MsgSize) + } + + return int(n), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) } // Don't overwrite any data we received. @@ -541,14 +547,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) } if trunc { - n = int64(r.MsgSize) + // n and r.MsgSize are the same for streams. + total += int64(r.MsgSize) + } else { + total += n } - total += n if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() { if total > 0 { err = nil } - return int(total), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) + if s.isPacket && n < int64(r.MsgSize) { + msgFlags |= linux.MSG_TRUNC + } + return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) } // Don't overwrite any data we received. @@ -560,9 +571,9 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags err = nil } if err == syserror.ETIMEDOUT { - return int(total), nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } - return int(total), nil, 0, socket.ControlMessages{}, syserr.FromError(err) + return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } } } |