summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- pkg/sentry/socket/epsocket/epsocket.go 34
-rw-r--r-- pkg/sentry/socket/socket.go 2
-rw-r--r-- pkg/tcpip/tcpip.go 5
-rw-r--r-- pkg/tcpip/transport/ping/endpoint.go 6
-rw-r--r-- pkg/tcpip/transport/tcp/endpoint.go 6
5 files changed, 38 insertions, 15 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index c5ce289b5..8c5db6af8 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -323,20 +323,27 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
f := &ioSequencePayload{ctx: ctx, src: src}
n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
if err == tcpip.ErrWouldBlock {
- return int64(n), syserror.ErrWouldBlock
+ return 0, syserror.ErrWouldBlock
}
if resCh != nil {
t := ctx.(*kernel.Task)
if err := t.Block(resCh); err != nil {
- return int64(n), syserr.FromError(err).ToError()
+ return 0, syserr.FromError(err).ToError()
}
n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
- return int64(n), syserr.TranslateNetstackError(err).ToError()
}
- return int64(n), syserr.TranslateNetstackError(err).ToError()
+ if err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if int64(n) < src.NumBytes() {
+ return int64(n), syserror.ErrWouldBlock
+ }
+
+ return int64(n), nil
}
// Readiness returns a mask of ready events for socket s.
@@ -1343,11 +1350,16 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
if resCh != nil {
if err := t.Block(resCh); err != nil {
- return int(n), syserr.FromError(err)
+ return 0, syserr.FromError(err)
}
n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
}
- if err != tcpip.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ if err == nil && (n >= uintptr(len(v)) || dontWait) {
+ // Complete write.
+ return int(n), nil
+ }
+ if err != nil && (err != tcpip.ErrWouldBlock || dontWait) {
return int(n), syserr.TranslateNetstackError(err)
}
@@ -1363,11 +1375,17 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
v.TrimFront(int(n))
total += n
- if err != tcpip.ErrWouldBlock {
- return int(total), syserr.TranslateNetstackError(err)
+
+ if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
+ return 0, syserr.TranslateNetstackError(err)
+ }
+
+ if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ return int(total), nil
}
if err := t.Block(ch); err != nil {
+ // handleIOError will consume errors from t.Block if needed.
return int(total), syserr.FromError(err)
}
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index b1dcbf7b0..f31729819 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -90,6 +90,8 @@ type Socket interface {
// SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take
// ownership of the ControlMessage on error.
+ //
+ // If n > 0, err will either be nil or an error from t.Block.
SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages ControlMessages) (n int, err *syserr.Error)
// SetRecvTimeout sets the timeout (in ns) for recv operations. Zero means
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 8e2fe70ee..dc6339173 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -312,7 +312,10 @@ type Endpoint interface {
// the caller should not use data[:n] after Write returns.
//
// Note that unlike io.Writer.Write, it is not an error for Write to
- // perform a partial write.
+ // perform a partial write (if n > 0, no error may be returned). Only
+ // stream (TCP) Endpoints may return partial writes, and even then only
+ // in the case where writing additional data would block. Other Endpoints
+ // will either write the entire message or return an error.
//
// For UDP and Ping sockets if address resolution is required,
// ErrNoLinkAddress and a notification channel is returned for the caller to
diff --git a/pkg/tcpip/transport/ping/endpoint.go b/pkg/tcpip/transport/ping/endpoint.go
index b3f54cfe0..10d4d138e 100644
--- a/pkg/tcpip/transport/ping/endpoint.go
+++ b/pkg/tcpip/transport/ping/endpoint.go
@@ -299,7 +299,11 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
err = sendPing6(route, e.id.LocalPort, v)
}
- return uintptr(len(v)), nil, err
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 1649dbc97..6034ba90b 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -554,10 +554,6 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
return 0, nil, perr
}
- var err *tcpip.Error
- if p.Size() > avail {
- err = tcpip.ErrWouldBlock
- }
l := len(v)
s := newSegmentFromView(&e.route, e.id, v)
@@ -576,7 +572,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
// Let the protocol goroutine do the work.
e.sndWaker.Assert()
}
- return uintptr(l), nil, err
+ return uintptr(l), nil, nil
}
// Peek reads data without consuming it from the endpoint.