summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/epsocket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/epsocket')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go34
1 files changed, 26 insertions, 8 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index c5ce289b5..8c5db6af8 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -323,20 +323,27 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
f := &ioSequencePayload{ctx: ctx, src: src}
n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
if err == tcpip.ErrWouldBlock {
- return int64(n), syserror.ErrWouldBlock
+ return 0, syserror.ErrWouldBlock
}
if resCh != nil {
t := ctx.(*kernel.Task)
if err := t.Block(resCh); err != nil {
- return int64(n), syserr.FromError(err).ToError()
+ return 0, syserr.FromError(err).ToError()
}
n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
- return int64(n), syserr.TranslateNetstackError(err).ToError()
}
- return int64(n), syserr.TranslateNetstackError(err).ToError()
+ if err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if int64(n) < src.NumBytes() {
+ return int64(n), syserror.ErrWouldBlock
+ }
+
+ return int64(n), nil
}
// Readiness returns a mask of ready events for socket s.
@@ -1343,11 +1350,16 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
if resCh != nil {
if err := t.Block(resCh); err != nil {
- return int(n), syserr.FromError(err)
+ return 0, syserr.FromError(err)
}
n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
}
- if err != tcpip.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ if err == nil && (n >= uintptr(len(v)) || dontWait) {
+ // Complete write.
+ return int(n), nil
+ }
+ if err != nil && (err != tcpip.ErrWouldBlock || dontWait) {
return int(n), syserr.TranslateNetstackError(err)
}
@@ -1363,11 +1375,17 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
v.TrimFront(int(n))
total += n
- if err != tcpip.ErrWouldBlock {
- return int(total), syserr.TranslateNetstackError(err)
+
+ if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
+ return 0, syserr.TranslateNetstackError(err)
+ }
+
+ if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ return int(total), nil
}
if err := t.Block(ch); err != nil {
+ // handleIOError will consume errors from t.Block if needed.
return int(total), syserr.FromError(err)
}
}