summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go23
1 files changed, 20 insertions, 3 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 4d32f7a31..550569b4c 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -276,10 +276,21 @@ func (i *ioSequencePayload) Size() int {
// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
f := &ioSequencePayload{ctx: ctx, src: src}
- n, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
if err == tcpip.ErrWouldBlock {
return int64(n), syserror.ErrWouldBlock
}
+
+ if resCh != nil {
+ t := ctx.(*kernel.Task)
+ if err := t.Block(resCh); err != nil {
+ return int64(n), syserr.FromError(err).ToError()
+ }
+
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
+ return int64(n), syserr.TranslateNetstackError(err).ToError()
+ }
+
return int64(n), syserr.TranslateNetstackError(err).ToError()
}
@@ -1016,7 +1027,13 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
EndOfRecord: flags&linux.MSG_EOR != 0,
}
- n, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ if resCh != nil {
+ if err := t.Block(resCh); err != nil {
+ return int(n), syserr.FromError(err)
+ }
+ n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ }
if err != tcpip.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
return int(n), syserr.TranslateNetstackError(err)
}
@@ -1030,7 +1047,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
v.TrimFront(int(n))
total := n
for {
- n, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
v.TrimFront(int(n))
total += n
if err != tcpip.ErrWouldBlock {