diff options
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 23 |
1 files changed, 20 insertions, 3 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 4d32f7a31..550569b4c 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -276,10 +276,21 @@ func (i *ioSequencePayload) Size() int { // Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { f := &ioSequencePayload{ctx: ctx, src: src} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return int64(n), syserror.ErrWouldBlock } + + if resCh != nil { + t := ctx.(*kernel.Task) + if err := t.Block(resCh); err != nil { + return int64(n), syserr.FromError(err).ToError() + } + + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) + return int64(n), syserr.TranslateNetstackError(err).ToError() + } + return int64(n), syserr.TranslateNetstackError(err).ToError() } @@ -1016,7 +1027,13 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] EndOfRecord: flags&linux.MSG_EOR != 0, } - n, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts) + n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts) + if resCh != nil { + if err := t.Block(resCh); err != nil { + return int(n), syserr.FromError(err) + } + n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) + } if err != tcpip.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { return int(n), syserr.TranslateNetstackError(err) } @@ -1030,7 +1047,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] v.TrimFront(int(n)) total := n for { - n, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) + n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts) v.TrimFront(int(n)) total += n if err != tcpip.ErrWouldBlock { |