diff options
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r-- | pkg/sentry/socket/rpcinet/socket.go | 39 |
1 files changed, 33 insertions, 6 deletions
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 7328661ab..90844f10f 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -212,6 +212,11 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } n, err := rpcWrite(t, &pb.SyscallRequest_Write{&pb.WriteRequest{Fd: s.fd, Data: v}}) + if n > 0 && n < uint32(src.NumBytes()) { + // The FileOperations.Write interface expects us to return ErrWouldBlock in + // the event of a partial write. + return int64(n), syserror.ErrWouldBlock + } return int64(n), err.ToError() } @@ -735,19 +740,24 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // TODO: this needs to change to map directly to a SendMsg syscall // in the RPC. - req := &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{ + totalWritten := 0 + n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{ Fd: uint32(s.fd), Data: v, Address: to, More: flags&linux.MSG_MORE != 0, EndOfRecord: flags&linux.MSG_EOR != 0, - }} + }}) - n, err := rpcSendMsg(t, req) if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 { return int(n), err } + if n > 0 { + totalWritten += int(n) + v.TrimFront(int(n)) + } + // We'll have to block. Register for notification and keep trying to // send all the data. e, ch := waiter.NewChannelEntry(nil) @@ -755,13 +765,30 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] defer s.EventUnregister(&e) for { - n, err := rpcSendMsg(t, req) + n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{ + Fd: uint32(s.fd), + Data: v, + Address: to, + More: flags&linux.MSG_MORE != 0, + EndOfRecord: flags&linux.MSG_EOR != 0, + }}) + + if n > 0 { + totalWritten += int(n) + v.TrimFront(int(n)) + + if err == nil && totalWritten < int(src.NumBytes()) { + continue + } + } + if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain { - return int(n), err + // We eat the error in this situation. + return int(totalWritten), nil } if err := t.Block(ch); err != nil { - return 0, syserr.FromError(err) + return int(totalWritten), syserr.FromError(err) } } } |