summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/unix/unix.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/unix/unix.go')
-rw-r--r--pkg/sentry/socket/unix/unix.go31
1 files changed, 21 insertions, 10 deletions
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 92411c901..01efd24d3 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -477,7 +477,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -515,11 +515,17 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if r.From != nil {
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
}
- if trunc {
- n = int64(r.MsgSize)
- }
+
if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
- return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ if s.isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+
+ return int(n), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
}
// Don't overwrite any data we received.
@@ -541,14 +547,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
}
if trunc {
- n = int64(r.MsgSize)
+ // n and r.MsgSize are the same for streams.
+ total += int64(r.MsgSize)
+ } else {
+ total += n
}
- total += n
if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
if total > 0 {
err = nil
}
- return int(total), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ if s.isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
}
// Don't overwrite any data we received.
@@ -560,9 +571,9 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
err = nil
}
if err == syserror.ETIMEDOUT {
- return int(total), nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
- return int(total), nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
}
}