diff options
Diffstat (limited to 'pkg/sentry/socket/epsocket')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 45 |
1 files changed, 25 insertions, 20 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index f370b803b..23138d874 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -376,7 +376,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS if dst.NumBytes() == 0 { return 0, nil } - n, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false) + n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false) if err == syserr.ErrWouldBlock { return int64(n), syserror.ErrWouldBlock } @@ -1696,7 +1696,7 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq // nonBlockingRead issues a non-blocking read. // // TODO: Support timestamps for stream sockets. -func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { +func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, interface{}, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() // Fast path for regular reads from stream (e.g., TCP) endpoints. Note @@ -1712,14 +1712,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readMu.Lock() n, err := s.coalescingRead(ctx, dst, trunc) s.readMu.Unlock() - return n, nil, 0, socket.ControlMessages{}, err + return n, 0, nil, 0, socket.ControlMessages{}, err } s.readMu.Lock() defer s.readMu.Unlock() if err := s.fetchReadView(); err != nil { - return 0, nil, 0, socket.ControlMessages{}, err + return 0, 0, nil, 0, socket.ControlMessages{}, err } if !isPacket && peek && trunc { @@ -1727,14 +1727,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe // amount that could be read. var rql tcpip.ReceiveQueueSizeOption if err := s.Endpoint.GetSockOpt(&rql); err != nil { - return 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } available := len(s.readView) + int(rql) bufLen := int(dst.NumBytes()) if available < bufLen { - return available, nil, 0, socket.ControlMessages{}, nil + return available, 0, nil, 0, socket.ControlMessages{}, nil } - return bufLen, nil, 0, socket.ControlMessages{}, nil + return bufLen, 0, nil, 0, socket.ControlMessages{}, nil } n, err := dst.CopyOut(ctx, s.readView) @@ -1751,11 +1751,11 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe if peek { if l := len(s.readView); trunc && l > n { // isPacket must be true. - return l, addr, addrLen, s.controlMessages(), syserr.FromError(err) + return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err) } if isPacket || err != nil { - return int(n), addr, addrLen, s.controlMessages(), syserr.FromError(err) + return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err) } // We need to peek beyond the first message. @@ -1773,7 +1773,7 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe // We got some data, so no need to return an error. err = nil } - return int(n), nil, 0, s.controlMessages(), syserr.FromError(err) + return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err) } var msgLen int @@ -1785,11 +1785,16 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readView.TrimFront(int(n)) } + var flags int + if msgLen > int(n) { + flags |= linux.MSG_TRUNC + } + if trunc { - return msgLen, addr, addrLen, s.controlMessages(), syserr.FromError(err) + n = msgLen } - return int(n), addr, addrLen, s.controlMessages(), syserr.FromError(err) + return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err) } func (s *SocketOperations) controlMessages() socket.ControlMessages { @@ -1810,7 +1815,7 @@ func (s *SocketOperations) updateTimestamp() { // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -1819,16 +1824,16 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // Stream sockets ignore the sender address. senderRequested = false } - n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) + n, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) if s.isPacketBased() && err == syserr.ErrClosedForReceive && flags&linux.MSG_DONTWAIT != 0 { // In this situation we should return EAGAIN. - return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } if err != nil && (err != syserr.ErrWouldBlock || dontWait) { // Read failed and we should not retry. - return 0, nil, 0, socket.ControlMessages{}, err + return 0, 0, nil, 0, socket.ControlMessages{}, err } if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) { @@ -1847,7 +1852,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags for { var rn int - rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) + rn, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested) n += rn if err != nil && err != syserr.ErrWouldBlock { // Always stop on errors other than would block as we generally @@ -1866,12 +1871,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { if n > 0 { - return n, senderAddr, senderAddrLen, controlMessages, nil + return n, msgFlags, senderAddr, senderAddrLen, controlMessages, nil } if err == syserror.ETIMEDOUT { - return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain } - return 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } } } |