summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/rpcinet/socket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/rpcinet/socket.go')
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go39
1 files changed, 33 insertions, 6 deletions
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 7328661ab..90844f10f 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -212,6 +212,11 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
n, err := rpcWrite(t, &pb.SyscallRequest_Write{&pb.WriteRequest{Fd: s.fd, Data: v}})
+ if n > 0 && n < uint32(src.NumBytes()) {
+ // The FileOperations.Write interface expects us to return ErrWouldBlock in
+ // the event of a partial write.
+ return int64(n), syserror.ErrWouldBlock
+ }
return int64(n), err.ToError()
}
@@ -735,19 +740,24 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
// TODO: this needs to change to map directly to a SendMsg syscall
// in the RPC.
- req := &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
+ totalWritten := 0
+ n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
Fd: uint32(s.fd),
Data: v,
Address: to,
More: flags&linux.MSG_MORE != 0,
EndOfRecord: flags&linux.MSG_EOR != 0,
- }}
+ }})
- n, err := rpcSendMsg(t, req)
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 {
return int(n), err
}
+ if n > 0 {
+ totalWritten += int(n)
+ v.TrimFront(int(n))
+ }
+
// We'll have to block. Register for notification and keep trying to
// send all the data.
e, ch := waiter.NewChannelEntry(nil)
@@ -755,13 +765,30 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
defer s.EventUnregister(&e)
for {
- n, err := rpcSendMsg(t, req)
+ n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
+ Fd: uint32(s.fd),
+ Data: v,
+ Address: to,
+ More: flags&linux.MSG_MORE != 0,
+ EndOfRecord: flags&linux.MSG_EOR != 0,
+ }})
+
+ if n > 0 {
+ totalWritten += int(n)
+ v.TrimFront(int(n))
+
+ if err == nil && totalWritten < int(src.NumBytes()) {
+ continue
+ }
+ }
+
if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain {
- return int(n), err
+ // We eat the error in this situation.
+ return int(totalWritten), nil
}
if err := t.Block(ch); err != nil {
- return 0, syserr.FromError(err)
+ return int(totalWritten), syserr.FromError(err)
}
}
}