From 6c0e1d9cfe6adbfbb32e7020d6426608ac63ad37 Mon Sep 17 00:00:00 2001
From: Tamir Duberstein <tamird@google.com>
Date: Fri, 22 Jan 2021 12:24:20 -0800
Subject: Define tcpip.Payloader in terms of io.Reader

Fixes #1509.

PiperOrigin-RevId: 353295589
---
 pkg/sentry/socket/netstack/netstack.go | 176 +++++++++++----------------------
 1 file changed, 56 insertions(+), 120 deletions(-)

(limited to 'pkg/sentry/socket/netstack/netstack.go')

diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 4c9d335c0..65111154b 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -19,7 +19,7 @@
 // be used to expose certain endpoints to the sentry while leaving others out,
 // for example, TCP endpoints and Unix-domain endpoints.
 //
-// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside
+// Lock ordering: netstack => mm: ioSequenceReadWriter copies user memory inside
 // tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during
 // this operation.
 package netstack
@@ -55,7 +55,6 @@ import (
 	"gvisor.dev/gvisor/pkg/syserr"
 	"gvisor.dev/gvisor/pkg/syserror"
 	"gvisor.dev/gvisor/pkg/tcpip"
-	"gvisor.dev/gvisor/pkg/tcpip/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip/header"
 	"gvisor.dev/gvisor/pkg/tcpip/stack"
 	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -440,45 +439,10 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
 	return int64(res.Count), nil
 }
 
-// ioSequencePayload implements tcpip.Payload.
-//
-// t copies user memory bytes on demand based on the requested size.
-type ioSequencePayload struct {
-	ctx context.Context
-	src usermem.IOSequence
-}
-
-// FullPayload implements tcpip.Payloader.FullPayload
-func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) {
-	return i.Payload(int(i.src.NumBytes()))
-}
-
-// Payload implements tcpip.Payloader.Payload.
-func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
-	if max := int(i.src.NumBytes()); size > max {
-		size = max
-	}
-	v := buffer.NewView(size)
-	if _, err := i.src.CopyIn(i.ctx, v); err != nil {
-		// EOF can be returned only if src is a file and this means it
-		// is in a splice syscall and the error has to be ignored.
-		if err == io.EOF {
-			return v, nil
-		}
-		return nil, tcpip.ErrBadAddress
-	}
-	return v, nil
-}
-
-// DropFirst drops the first n bytes from underlying src.
-func (i *ioSequencePayload) DropFirst(n int) {
-	i.src = i.src.DropFirst(int(n))
-}
-
 // Write implements fs.FileOperations.Write.
 func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
-	f := &ioSequencePayload{ctx: ctx, src: src}
-	n, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+	r := src.Reader(ctx)
+	n, err := s.Endpoint.Write(r, tcpip.WriteOptions{})
 	if err == tcpip.ErrWouldBlock {
 		return 0, syserror.ErrWouldBlock
 	}
@@ -486,69 +450,40 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
 		return 0, syserr.TranslateNetstackError(err).ToError()
 	}
 
-	if int64(n) < src.NumBytes() {
-		return int64(n), syserror.ErrWouldBlock
+	if n < src.NumBytes() {
+		return n, syserror.ErrWouldBlock
 	}
 
-	return int64(n), nil
+	return n, nil
 }
 
-// readerPayload implements tcpip.Payloader.
-//
-// It allocates a view and reads from a reader on-demand, based on available
-// capacity in the endpoint.
-type readerPayload struct {
-	ctx   context.Context
-	r     io.Reader
-	count int64
-	err   error
-}
+var _ tcpip.Payloader = (*limitedPayloader)(nil)
 
-// FullPayload implements tcpip.Payloader.FullPayload.
-func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) {
-	return r.Payload(int(r.count))
+type limitedPayloader struct {
+	io.LimitedReader
 }
 
-// Payload implements tcpip.Payloader.Payload.
-func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
-	if size > int(r.count) {
-		size = int(r.count)
-	}
-	v := buffer.NewView(size)
-	n, err := r.r.Read(v)
-	if n > 0 {
-		// We ignore the error here. It may re-occur on subsequent
-		// reads, but for now we can enqueue some amount of data.
-		r.count -= int64(n)
-		return v[:n], nil
-	}
-	if err == syserror.ErrWouldBlock {
-		return nil, tcpip.ErrWouldBlock
-	} else if err != nil {
-		r.err = err // Save for propation.
-		return nil, tcpip.ErrBadAddress
-	}
-
-	// There is no data and no error. Return an error, which will propagate
-	// r.err, which will be nil. This is the desired result: (0, nil).
-	return nil, tcpip.ErrBadAddress
+func (l limitedPayloader) Len() int {
+	return int(l.N)
 }
 
 // ReadFrom implements fs.FileOperations.ReadFrom.
 func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
-	f := &readerPayload{ctx: ctx, r: r, count: count}
-	n, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+	f := limitedPayloader{
+		LimitedReader: io.LimitedReader{
+			R: r,
+			N: count,
+		},
+	}
+	n, err := s.Endpoint.Write(&f, tcpip.WriteOptions{
 		// Reads may be destructive but should be very fast,
 		// so we can't release the lock while copying data.
 		Atomic: true,
 	})
-	if err == tcpip.ErrWouldBlock {
-		return n, syserror.ErrWouldBlock
-	} else if err != nil {
-		return int64(n), f.err // Propagate error.
+	if err == tcpip.ErrBadBuffer {
+		err = nil
 	}
-
-	return int64(n), nil
+	return n, syserr.TranslateNetstackError(err).ToError()
 }
 
 // Readiness returns a mask of ready events for socket s.
@@ -2836,45 +2771,46 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
 		EndOfRecord: flags&linux.MSG_EOR != 0,
 	}
 
-	v := &ioSequencePayload{t, src}
-	n, err := s.Endpoint.Write(v, opts)
-	dontWait := flags&linux.MSG_DONTWAIT != 0
-	if err == nil && (n >= v.src.NumBytes() || dontWait) {
-		// Complete write.
-		return int(n), nil
-	}
-	if err != nil && (err != tcpip.ErrWouldBlock || dontWait) {
-		return int(n), syserr.TranslateNetstackError(err)
-	}
-
-	// We'll have to block. Register for notification and keep trying to
-	// send all the data.
-	e, ch := waiter.NewChannelEntry(nil)
-	s.EventRegister(&e, waiter.EventOut)
-	defer s.EventUnregister(&e)
-
-	v.DropFirst(int(n))
-	total := n
+	r := src.Reader(t)
+	var (
+		total int64
+		entry waiter.Entry
+		ch    <-chan struct{}
+	)
 	for {
-		n, err = s.Endpoint.Write(v, opts)
-		v.DropFirst(int(n))
+		n, err := s.Endpoint.Write(r, opts)
 		total += n
-
-		if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
-			return 0, syserr.TranslateNetstackError(err)
-		}
-
-		if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock {
-			return int(total), nil
+		if flags&linux.MSG_DONTWAIT != 0 {
+			return int(total), syserr.TranslateNetstackError(err)
 		}
-
-		if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
-			if err == syserror.ETIMEDOUT {
-				return int(total), syserr.ErrTryAgain
+		switch err {
+		case nil:
+			if total == src.NumBytes() {
+				break
+			}
+			fallthrough
+		case tcpip.ErrWouldBlock:
+			if ch == nil {
+				// We'll have to block. Register for notification and keep trying to
+				// send all the data.
+				entry, ch = waiter.NewChannelEntry(nil)
+				s.EventRegister(&entry, waiter.EventOut)
+				defer s.EventUnregister(&entry)
+			} else {
+				// Don't wait immediately after registration in case more data
+				// became available between when we last checked and when we setup
+				// the notification.
+				if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+					if err == syserror.ETIMEDOUT {
+						return int(total), syserr.ErrTryAgain
+					}
+					// handleIOError will consume errors from t.Block if needed.
+					return int(total), syserr.FromError(err)
+				}
 			}
-			// handleIOError will consume errors from t.Block if needed.
-			return int(total), syserr.FromError(err)
+			continue
 		}
+		return int(total), syserr.TranslateNetstackError(err)
 	}
 }
 
-- 
cgit v1.2.3