diff options
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 34 | ||||
-rw-r--r-- | pkg/sentry/socket/socket.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/transport/ping/endpoint.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 6 |
5 files changed, 38 insertions, 15 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index c5ce289b5..8c5db6af8 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -323,20 +323,27 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO f := &ioSequencePayload{ctx: ctx, src: src} n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { - return int64(n), syserror.ErrWouldBlock + return 0, syserror.ErrWouldBlock } if resCh != nil { t := ctx.(*kernel.Task) if err := t.Block(resCh); err != nil { - return int64(n), syserr.FromError(err).ToError() + return 0, syserr.FromError(err).ToError() } n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) - return int64(n), syserr.TranslateNetstackError(err).ToError() } - return int64(n), syserr.TranslateNetstackError(err).ToError() + if err != nil { + return 0, syserr.TranslateNetstackError(err).ToError() + } + + if int64(n) < src.NumBytes() { + return int64(n), syserror.ErrWouldBlock + } + + return int64(n), nil } // Readiness returns a mask of ready events for socket s. @@ -1343,11 +1350,16 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts) if resCh != nil { if err := t.Block(resCh); err != nil { - return int(n), syserr.FromError(err) + return 0, syserr.FromError(err) } n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) } - if err != tcpip.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + dontWait := flags&linux.MSG_DONTWAIT != 0 + if err == nil && (n >= uintptr(len(v)) || dontWait) { + // Complete write. + return int(n), nil + } + if err != nil && (err != tcpip.ErrWouldBlock || dontWait) { return int(n), syserr.TranslateNetstackError(err) } @@ -1363,11 +1375,17 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) v.TrimFront(int(n)) total += n - if err != tcpip.ErrWouldBlock { - return int(total), syserr.TranslateNetstackError(err) + + if err != nil && err != tcpip.ErrWouldBlock && total == 0 { + return 0, syserr.TranslateNetstackError(err) + } + + if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock { + return int(total), nil } if err := t.Block(ch); err != nil { + // handleIOError will consume errors from t.Block if needed. return int(total), syserr.FromError(err) } } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index b1dcbf7b0..f31729819 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -90,6 +90,8 @@ type Socket interface { // SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take // ownership of the ControlMessage on error. + // + // If n > 0, err will either be nil or an error from t.Block. SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages ControlMessages) (n int, err *syserr.Error) // SetRecvTimeout sets the timeout (in ns) for recv operations. Zero means diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 8e2fe70ee..dc6339173 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -312,7 +312,10 @@ type Endpoint interface { // the caller should not use data[:n] after Write returns. // // Note that unlike io.Writer.Write, it is not an error for Write to - // perform a partial write. + // perform a partial write (if n > 0, no error may be returned). Only + // stream (TCP) Endpoints may return partial writes, and even then only + // in the case where writing additional data would block. Other Endpoints + // will either write the entire message or return an error. // // For UDP and Ping sockets if address resolution is required, // ErrNoLinkAddress and a notification channel is returned for the caller to diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go index b3f54cfe0..10d4d138e 100644 --- a/pkg/tcpip/transport/ping/endpoint.go +++ b/pkg/tcpip/transport/ping/endpoint.go @@ -299,7 +299,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c err = sendPing6(route, e.id.LocalPort, v) } - return uintptr(len(v)), nil, err + if err != nil { + return 0, nil, err + } + + return uintptr(len(v)), nil, nil } // Peek only returns data from a single datagram, so do nothing here. diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 1649dbc97..6034ba90b 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -554,10 +554,6 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c return 0, nil, perr } - var err *tcpip.Error - if p.Size() > avail { - err = tcpip.ErrWouldBlock - } l := len(v) s := newSegmentFromView(&e.route, e.id, v) @@ -576,7 +572,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c // Let the protocol goroutine do the work. e.sndWaker.Assert() } - return uintptr(l), nil, err + return uintptr(l), nil, nil } // Peek reads data without consuming it from the endpoint. |