summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/epsocket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/epsocket')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go45
1 files changed, 25 insertions, 20 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index f370b803b..23138d874 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -376,7 +376,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
if dst.NumBytes() == 0 {
return 0, nil
}
- n, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
+ n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
if err == syserr.ErrWouldBlock {
return int64(n), syserror.ErrWouldBlock
}
@@ -1696,7 +1696,7 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
// nonBlockingRead issues a non-blocking read.
//
// TODO: Support timestamps for stream sockets.
-func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) {
isPacket := s.isPacketBased()
// Fast path for regular reads from stream (e.g., TCP) endpoints. Note
@@ -1712,14 +1712,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
s.readMu.Unlock()
- return n, nil, 0, socket.ControlMessages{}, err
+ return n, 0, nil, 0, socket.ControlMessages{}, err
}
s.readMu.Lock()
defer s.readMu.Unlock()
if err := s.fetchReadView(); err != nil {
- return 0, nil, 0, socket.ControlMessages{}, err
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
}
if !isPacket && peek && trunc {
@@ -1727,14 +1727,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
// amount that could be read.
var rql tcpip.ReceiveQueueSizeOption
if err := s.Endpoint.GetSockOpt(&rql); err != nil {
- return 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
available := len(s.readView) + int(rql)
bufLen := int(dst.NumBytes())
if available < bufLen {
- return available, nil, 0, socket.ControlMessages{}, nil
+ return available, 0, nil, 0, socket.ControlMessages{}, nil
}
- return bufLen, nil, 0, socket.ControlMessages{}, nil
+ return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
}
n, err := dst.CopyOut(ctx, s.readView)
@@ -1751,11 +1751,11 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
if peek {
if l := len(s.readView); trunc && l > n {
// isPacket must be true.
- return l, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err)
}
if isPacket || err != nil {
- return int(n), addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err)
}
// We need to peek beyond the first message.
@@ -1773,7 +1773,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
// We got some data, so no need to return an error.
err = nil
}
- return int(n), nil, 0, s.controlMessages(), syserr.FromError(err)
+ return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err)
}
var msgLen int
@@ -1785,11 +1785,16 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readView.TrimFront(int(n))
}
+ var flags int
+ if msgLen > int(n) {
+ flags |= linux.MSG_TRUNC
+ }
+
if trunc {
- return msgLen, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ n = msgLen
}
- return int(n), addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err)
}
func (s *SocketOperations) controlMessages() socket.ControlMessages {
@@ -1810,7 +1815,7 @@ func (s *SocketOperations) updateTimestamp() {
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -1819,16 +1824,16 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// Stream sockets ignore the sender address.
senderRequested = false
}
- n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ n, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
if s.isPacketBased() && err == syserr.ErrClosedForReceive && flags&linux.MSG_DONTWAIT != 0 {
// In this situation we should return EAGAIN.
- return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
if err != nil && (err != syserr.ErrWouldBlock || dontWait) {
// Read failed and we should not retry.
- return 0, nil, 0, socket.ControlMessages{}, err
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
}
if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) {
@@ -1847,7 +1852,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
for {
var rn int
- rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ rn, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
n += rn
if err != nil && err != syserr.ErrWouldBlock {
// Always stop on errors other than would block as we generally
@@ -1866,12 +1871,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if n > 0 {
- return n, senderAddr, senderAddrLen, controlMessages, nil
+ return n, msgFlags, senderAddr, senderAddrLen, controlMessages, nil
}
if err == syserror.ETIMEDOUT {
- return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
}
- return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
}
}