diff options
Diffstat (limited to 'pkg/sentry/socket/netstack/netstack.go')
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 35 |
1 files changed, 4 insertions, 31 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 57f224120..94fb425b2 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -36,7 +36,6 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" @@ -459,18 +458,10 @@ func (i *ioSequencePayload) DropFirst(n int) { // Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { f := &ioSequencePayload{ctx: ctx, src: src} - n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } - - if resCh != nil { - if err := amutex.Block(ctx, resCh); err != nil { - return 0, err - } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) - } - if err != nil { return 0, syserr.TranslateNetstackError(err).ToError() } @@ -526,24 +517,12 @@ func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { // ReadFrom implements fs.FileOperations.ReadFrom. func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { f := &readerPayload{ctx: ctx, r: r, count: count} - n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + n, err := s.Endpoint.Write(f, tcpip.WriteOptions{ // Reads may be destructive but should be very fast, // so we can't release the lock while copying data. Atomic: true, }) if err == tcpip.ErrWouldBlock { - return 0, syserror.ErrWouldBlock - } - - if resCh != nil { - if err := amutex.Block(ctx, resCh); err != nil { - return 0, err - } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ - Atomic: true, // See above. - }) - } - if err == tcpip.ErrWouldBlock { return n, syserror.ErrWouldBlock } else if err != nil { return int64(n), f.err // Propagate error. @@ -2836,13 +2815,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b } v := &ioSequencePayload{t, src} - n, resCh, err := s.Endpoint.Write(v, opts) - if resCh != nil { - if err := t.Block(resCh); err != nil { - return 0, syserr.FromError(err) - } - n, _, err = s.Endpoint.Write(v, opts) - } + n, err := s.Endpoint.Write(v, opts) dontWait := flags&linux.MSG_DONTWAIT != 0 if err == nil && (n >= v.src.NumBytes() || dontWait) { // Complete write. @@ -2861,7 +2834,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b v.DropFirst(int(n)) total := n for { - n, _, err = s.Endpoint.Write(v, opts) + n, err = s.Endpoint.Write(v, opts) v.DropFirst(int(n)) total += n |