diff options
author | Tamir Duberstein <tamird@google.com> | 2021-01-22 12:24:20 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-01-22 12:26:09 -0800 |
commit | 6c0e1d9cfe6adbfbb32e7020d6426608ac63ad37 (patch) | |
tree | 3ddb770b9965ccfcdb6ecb73062e2d466c379439 /pkg/sentry/socket/netstack/netstack.go | |
parent | 527ef5fc0307102fa7cc0b32bcc2eb8cca3e21a8 (diff) |
Define tcpip.Payloader in terms of io.Reader
Fixes #1509.
PiperOrigin-RevId: 353295589
Diffstat (limited to 'pkg/sentry/socket/netstack/netstack.go')
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 176 |
1 files changed, 56 insertions, 120 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 4c9d335c0..65111154b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -19,7 +19,7 @@ // be used to expose certain endpoints to the sentry while leaving others out, // for example, TCP endpoints and Unix-domain endpoints. // -// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside +// Lock ordering: netstack => mm: ioSequenceReadWriter copies user memory inside // tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during // this operation. package netstack @@ -55,7 +55,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -440,45 +439,10 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write return int64(res.Count), nil } -// ioSequencePayload implements tcpip.Payload. -// -// t copies user memory bytes on demand based on the requested size. -type ioSequencePayload struct { - ctx context.Context - src usermem.IOSequence -} - -// FullPayload implements tcpip.Payloader.FullPayload -func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { - return i.Payload(int(i.src.NumBytes())) -} - -// Payload implements tcpip.Payloader.Payload. -func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { - if max := int(i.src.NumBytes()); size > max { - size = max - } - v := buffer.NewView(size) - if _, err := i.src.CopyIn(i.ctx, v); err != nil { - // EOF can be returned only if src is a file and this means it - // is in a splice syscall and the error has to be ignored. - if err == io.EOF { - return v, nil - } - return nil, tcpip.ErrBadAddress - } - return v, nil -} - -// DropFirst drops the first n bytes from underlying src. -func (i *ioSequencePayload) DropFirst(n int) { - i.src = i.src.DropFirst(int(n)) -} - // Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - f := &ioSequencePayload{ctx: ctx, src: src} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + r := src.Reader(ctx) + n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } @@ -486,69 +450,40 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return 0, syserr.TranslateNetstackError(err).ToError() } - if int64(n) < src.NumBytes() { - return int64(n), syserror.ErrWouldBlock + if n < src.NumBytes() { + return n, syserror.ErrWouldBlock } - return int64(n), nil + return n, nil } -// readerPayload implements tcpip.Payloader. -// -// It allocates a view and reads from a reader on-demand, based on available -// capacity in the endpoint. -type readerPayload struct { - ctx context.Context - r io.Reader - count int64 - err error -} +var _ tcpip.Payloader = (*limitedPayloader)(nil) -// FullPayload implements tcpip.Payloader.FullPayload. -func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { - return r.Payload(int(r.count)) +type limitedPayloader struct { + io.LimitedReader } -// Payload implements tcpip.Payloader.Payload. -func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { - if size > int(r.count) { - size = int(r.count) - } - v := buffer.NewView(size) - n, err := r.r.Read(v) - if n > 0 { - // We ignore the error here. It may re-occur on subsequent - // reads, but for now we can enqueue some amount of data. - r.count -= int64(n) - return v[:n], nil - } - if err == syserror.ErrWouldBlock { - return nil, tcpip.ErrWouldBlock - } else if err != nil { - r.err = err // Save for propation. - return nil, tcpip.ErrBadAddress - } - - // There is no data and no error. Return an error, which will propagate - // r.err, which will be nil. This is the desired result: (0, nil). - return nil, tcpip.ErrBadAddress +func (l limitedPayloader) Len() int { + return int(l.N) } // ReadFrom implements fs.FileOperations.ReadFrom. func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { - f := &readerPayload{ctx: ctx, r: r, count: count} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + f := limitedPayloader{ + LimitedReader: io.LimitedReader{ + R: r, + N: count, + }, + } + n, err := s.Endpoint.Write(&f, tcpip.WriteOptions{ // Reads may be destructive but should be very fast, // so we can't release the lock while copying data. Atomic: true, }) - if err == tcpip.ErrWouldBlock { - return n, syserror.ErrWouldBlock - } else if err != nil { - return int64(n), f.err // Propagate error. + if err == tcpip.ErrBadBuffer { + err = nil } - - return int64(n), nil + return n, syserr.TranslateNetstackError(err).ToError() } // Readiness returns a mask of ready events for socket s. @@ -2836,45 +2771,46 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b EndOfRecord: flags&linux.MSG_EOR != 0, } - v := &ioSequencePayload{t, src} - n, err := s.Endpoint.Write(v, opts) - dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= v.src.NumBytes() || dontWait) { - // Complete write. - return int(n), nil - } - if err != nil && (err != tcpip.ErrWouldBlock || dontWait) { - return int(n), syserr.TranslateNetstackError(err) - } - - // We'll have to block. Register for notification and keep trying to - // send all the data. - e, ch := waiter.NewChannelEntry(nil) - s.EventRegister(&e, waiter.EventOut) - defer s.EventUnregister(&e) - - v.DropFirst(int(n)) - total := n + r := src.Reader(t) + var ( + total int64 + entry waiter.Entry + ch <-chan struct{} + ) for { - n, err = s.Endpoint.Write(v, opts) - v.DropFirst(int(n)) + n, err := s.Endpoint.Write(r, opts) total += n - - if err != nil && err != tcpip.ErrWouldBlock && total == 0 { - return 0, syserr.TranslateNetstackError(err) - } - - if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { - return int(total), nil + if flags&linux.MSG_DONTWAIT != 0 { + return int(total), syserr.TranslateNetstackError(err) } - - if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { - return int(total), syserr.ErrTryAgain + switch err { + case nil: + if total == src.NumBytes() { + break + } + fallthrough + case tcpip.ErrWouldBlock: + if ch == nil { + // We'll have to block. Register for notification and keep trying to + // send all the data. + entry, ch = waiter.NewChannelEntry(nil) + s.EventRegister(&entry, waiter.EventOut) + defer s.EventUnregister(&entry) + } else { + // Don't wait immediately after registration in case more data + // became available between when we last checked and when we setup + // the notification. + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + return int(total), syserr.ErrTryAgain + } + // handleIOError will consume errors from t.Block if needed. + return int(total), syserr.FromError(err) + } } - // handleIOError will consume errors from t.Block if needed. - return int(total), syserr.FromError(err) + continue } + return int(total), syserr.TranslateNetstackError(err) } } |