diff options
author | Brian Geffon <bgeffon@google.com> | 2018-12-04 18:14:17 -0800 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-12-04 18:15:10 -0800 |
commit | ffcbda0c8bd772c9019977775daf1d86891c3f28 (patch) | |
tree | d1ba294f99df79f301e01d82f6916286899e0789 /pkg/sentry/socket/rpcinet | |
parent | d209f71b9f1b6ab57684240112553aa8c700f929 (diff) |
Partial writes should loop in rpcinet.
FileOperations.Write should return ErrWouldBlock to allow the upper
layer to loop and sendmsg should continue writing where it left off
on a partial write.
PiperOrigin-RevId: 224081631
Change-Id: Ic61f6943ea6b7abbd82e4279decea215347eac48
Diffstat (limited to 'pkg/sentry/socket/rpcinet')
-rw-r--r-- | pkg/sentry/socket/rpcinet/socket.go | 39 |
1 files changed, 33 insertions, 6 deletions
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 7328661ab..90844f10f 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -212,6 +212,11 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } n, err := rpcWrite(t, &pb.SyscallRequest_Write{&pb.WriteRequest{Fd: s.fd, Data: v}}) + if n > 0 && n < uint32(src.NumBytes()) { + // The FileOperations.Write interface expects us to return ErrWouldBlock in + // the event of a partial write. + return int64(n), syserror.ErrWouldBlock + } return int64(n), err.ToError() } @@ -735,19 +740,24 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // TODO: this needs to change to map directly to a SendMsg syscall // in the RPC. - req := &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{ + totalWritten := 0 + n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{ Fd: uint32(s.fd), Data: v, Address: to, More: flags&linux.MSG_MORE != 0, EndOfRecord: flags&linux.MSG_EOR != 0, - }} + }}) - n, err := rpcSendMsg(t, req) if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 { return int(n), err } + if n > 0 { + totalWritten += int(n) + v.TrimFront(int(n)) + } + // We'll have to block. Register for notification and keep trying to // send all the data. e, ch := waiter.NewChannelEntry(nil) @@ -755,13 +765,30 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] defer s.EventUnregister(&e) for { - n, err := rpcSendMsg(t, req) + n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{ + Fd: uint32(s.fd), + Data: v, + Address: to, + More: flags&linux.MSG_MORE != 0, + EndOfRecord: flags&linux.MSG_EOR != 0, + }}) + + if n > 0 { + totalWritten += int(n) + v.TrimFront(int(n)) + + if err == nil && totalWritten < int(src.NumBytes()) { + continue + } + } + if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain { - return int(n), err + // We eat the error in this situation. + return int(totalWritten), nil } if err := t.Block(ch); err != nil { - return 0, syserr.FromError(err) + return int(totalWritten), syserr.FromError(err) } } } |