diff options
author | Ian Gudger <igudger@google.com> | 2018-12-06 11:40:39 -0800 |
---|---|---|
committer | Shentubot <shentubot@google.com> | 2018-12-06 11:41:33 -0800 |
commit | 000fa84a3bb1aebeda235c56545c942d7c29003d (patch) | |
tree | 9026936e4d865c118b6903f3cd1c32dc4ea701e8 /pkg/sentry/socket/epsocket | |
parent | 685eaf119ffa6c44c4dcaec0e083bbdc0271231a (diff) |
Fix tcpip.Endpoint.Write contract regarding short writes
* Clarify tcpip.Endpoint.Write contract regarding short writes.
* Enforce tcpip.Endpoint.Write contract regarding short writes.
* Update relevant users of tcpip.Endpoint.Write.
PiperOrigin-RevId: 224377586
Change-Id: I24299ecce902eb11317ee13dae3b8d8a7c5b097d
Diffstat (limited to 'pkg/sentry/socket/epsocket')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 34 |
1 files changed, 26 insertions, 8 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index c5ce289b5..8c5db6af8 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -323,20 +323,27 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO f := &ioSequencePayload{ctx: ctx, src: src} n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { - return int64(n), syserror.ErrWouldBlock + return 0, syserror.ErrWouldBlock } if resCh != nil { t := ctx.(*kernel.Task) if err := t.Block(resCh); err != nil { - return int64(n), syserr.FromError(err).ToError() + return 0, syserr.FromError(err).ToError() } n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) - return int64(n), syserr.TranslateNetstackError(err).ToError() } - return int64(n), syserr.TranslateNetstackError(err).ToError() + if err != nil { + return 0, syserr.TranslateNetstackError(err).ToError() + } + + if int64(n) < src.NumBytes() { + return int64(n), syserror.ErrWouldBlock + } + + return int64(n), nil } // Readiness returns a mask of ready events for socket s. @@ -1343,11 +1350,16 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts) if resCh != nil { if err := t.Block(resCh); err != nil { - return int(n), syserr.FromError(err) + return 0, syserr.FromError(err) } n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) } - if err != tcpip.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + dontWait := flags&linux.MSG_DONTWAIT != 0 + if err == nil && (n >= uintptr(len(v)) || dontWait) { + // Complete write. + return int(n), nil + } + if err != nil && (err != tcpip.ErrWouldBlock || dontWait) { return int(n), syserr.TranslateNetstackError(err) } @@ -1363,11 +1375,17 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) v.TrimFront(int(n)) total += n - if err != tcpip.ErrWouldBlock { - return int(total), syserr.TranslateNetstackError(err) + + if err != nil && err != tcpip.ErrWouldBlock && total == 0 { + return 0, syserr.TranslateNetstackError(err) + } + + if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock { + return int(total), nil } if err := t.Block(ch); err != nil { + // handleIOError will consume errors from t.Block if needed. return int(total), syserr.FromError(err) } } |