summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go34
-rw-r--r--pkg/sentry/socket/socket.go2
2 files changed, 28 insertions, 8 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index c5ce289b5..8c5db6af8 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -323,20 +323,27 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
f := &ioSequencePayload{ctx: ctx, src: src}
n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
if err == tcpip.ErrWouldBlock {
- return int64(n), syserror.ErrWouldBlock
+ return 0, syserror.ErrWouldBlock
}
if resCh != nil {
t := ctx.(*kernel.Task)
if err := t.Block(resCh); err != nil {
- return int64(n), syserr.FromError(err).ToError()
+ return 0, syserr.FromError(err).ToError()
}
n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
- return int64(n), syserr.TranslateNetstackError(err).ToError()
}
- return int64(n), syserr.TranslateNetstackError(err).ToError()
+ if err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if int64(n) < src.NumBytes() {
+ return int64(n), syserror.ErrWouldBlock
+ }
+
+ return int64(n), nil
}
// Readiness returns a mask of ready events for socket s.
@@ -1343,11 +1350,16 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
if resCh != nil {
if err := t.Block(resCh); err != nil {
- return int(n), syserr.FromError(err)
+ return 0, syserr.FromError(err)
}
n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
}
- if err != tcpip.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ if err == nil && (n >= uintptr(len(v)) || dontWait) {
+ // Complete write.
+ return int(n), nil
+ }
+ if err != nil && (err != tcpip.ErrWouldBlock || dontWait) {
return int(n), syserr.TranslateNetstackError(err)
}
@@ -1363,11 +1375,17 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
v.TrimFront(int(n))
total += n
- if err != tcpip.ErrWouldBlock {
- return int(total), syserr.TranslateNetstackError(err)
+
+ if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
+ return 0, syserr.TranslateNetstackError(err)
+ }
+
+ if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ return int(total), nil
}
if err := t.Block(ch); err != nil {
+ // handleIOError will consume errors from t.Block if needed.
return int(total), syserr.FromError(err)
}
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index b1dcbf7b0..f31729819 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -90,6 +90,8 @@ type Socket interface {
// SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take
// ownership of the ControlMessage on error.
+ //
+ // If n > 0, err will either be nil or an error from t.Block.
SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages ControlMessages) (n int, err *syserr.Error)
// SetRecvTimeout sets the timeout (in ns) for recv operations. Zero means