diff options
-rw-r--r-- | pkg/sentry/kernel/task_block.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 176 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack_vfs2.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 67 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 28 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 6 | ||||
-rw-r--r-- | pkg/usermem/usermem.go | 20 |
11 files changed, 133 insertions, 204 deletions
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index 9419f2e95..ecbe8f920 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -69,7 +69,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time. // syserror.ErrInterrupted if t is interrupted. // // Preconditions: The caller must be running on the task goroutine. -func (t *Task) BlockWithDeadline(C chan struct{}, haveDeadline bool, deadline ktime.Time) error { +func (t *Task) BlockWithDeadline(C <-chan struct{}, haveDeadline bool, deadline ktime.Time) error { if !haveDeadline { return t.block(C, nil) } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 4c9d335c0..65111154b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -19,7 +19,7 @@ // be used to expose certain endpoints to the sentry while leaving others out, // for example, TCP endpoints and Unix-domain endpoints. // -// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside +// Lock ordering: netstack => mm: ioSequenceReadWriter copies user memory inside // tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during // this operation. package netstack @@ -55,7 +55,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -440,45 +439,10 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write return int64(res.Count), nil } -// ioSequencePayload implements tcpip.Payload. -// -// t copies user memory bytes on demand based on the requested size. -type ioSequencePayload struct { - ctx context.Context - src usermem.IOSequence -} - -// FullPayload implements tcpip.Payloader.FullPayload -func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { - return i.Payload(int(i.src.NumBytes())) -} - -// Payload implements tcpip.Payloader.Payload. -func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { - if max := int(i.src.NumBytes()); size > max { - size = max - } - v := buffer.NewView(size) - if _, err := i.src.CopyIn(i.ctx, v); err != nil { - // EOF can be returned only if src is a file and this means it - // is in a splice syscall and the error has to be ignored. - if err == io.EOF { - return v, nil - } - return nil, tcpip.ErrBadAddress - } - return v, nil -} - -// DropFirst drops the first n bytes from underlying src. -func (i *ioSequencePayload) DropFirst(n int) { - i.src = i.src.DropFirst(int(n)) -} - // Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - f := &ioSequencePayload{ctx: ctx, src: src} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + r := src.Reader(ctx) + n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } @@ -486,69 +450,40 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return 0, syserr.TranslateNetstackError(err).ToError() } - if int64(n) < src.NumBytes() { - return int64(n), syserror.ErrWouldBlock + if n < src.NumBytes() { + return n, syserror.ErrWouldBlock } - return int64(n), nil + return n, nil } -// readerPayload implements tcpip.Payloader. -// -// It allocates a view and reads from a reader on-demand, based on available -// capacity in the endpoint. -type readerPayload struct { - ctx context.Context - r io.Reader - count int64 - err error -} +var _ tcpip.Payloader = (*limitedPayloader)(nil) -// FullPayload implements tcpip.Payloader.FullPayload. -func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { - return r.Payload(int(r.count)) +type limitedPayloader struct { + io.LimitedReader } -// Payload implements tcpip.Payloader.Payload. -func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { - if size > int(r.count) { - size = int(r.count) - } - v := buffer.NewView(size) - n, err := r.r.Read(v) - if n > 0 { - // We ignore the error here. It may re-occur on subsequent - // reads, but for now we can enqueue some amount of data. - r.count -= int64(n) - return v[:n], nil - } - if err == syserror.ErrWouldBlock { - return nil, tcpip.ErrWouldBlock - } else if err != nil { - r.err = err // Save for propation. - return nil, tcpip.ErrBadAddress - } - - // There is no data and no error. Return an error, which will propagate - // r.err, which will be nil. This is the desired result: (0, nil). - return nil, tcpip.ErrBadAddress +func (l limitedPayloader) Len() int { + return int(l.N) } // ReadFrom implements fs.FileOperations.ReadFrom. func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { - f := &readerPayload{ctx: ctx, r: r, count: count} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + f := limitedPayloader{ + LimitedReader: io.LimitedReader{ + R: r, + N: count, + }, + } + n, err := s.Endpoint.Write(&f, tcpip.WriteOptions{ // Reads may be destructive but should be very fast, // so we can't release the lock while copying data. Atomic: true, }) - if err == tcpip.ErrWouldBlock { - return n, syserror.ErrWouldBlock - } else if err != nil { - return int64(n), f.err // Propagate error. + if err == tcpip.ErrBadBuffer { + err = nil } - - return int64(n), nil + return n, syserr.TranslateNetstackError(err).ToError() } // Readiness returns a mask of ready events for socket s. @@ -2836,45 +2771,46 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b EndOfRecord: flags&linux.MSG_EOR != 0, } - v := &ioSequencePayload{t, src} - n, err := s.Endpoint.Write(v, opts) - dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= v.src.NumBytes() || dontWait) { - // Complete write. - return int(n), nil - } - if err != nil && (err != tcpip.ErrWouldBlock || dontWait) { - return int(n), syserr.TranslateNetstackError(err) - } - - // We'll have to block. Register for notification and keep trying to - // send all the data. - e, ch := waiter.NewChannelEntry(nil) - s.EventRegister(&e, waiter.EventOut) - defer s.EventUnregister(&e) - - v.DropFirst(int(n)) - total := n + r := src.Reader(t) + var ( + total int64 + entry waiter.Entry + ch <-chan struct{} + ) for { - n, err = s.Endpoint.Write(v, opts) - v.DropFirst(int(n)) + n, err := s.Endpoint.Write(r, opts) total += n - - if err != nil && err != tcpip.ErrWouldBlock && total == 0 { - return 0, syserr.TranslateNetstackError(err) - } - - if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { - return int(total), nil + if flags&linux.MSG_DONTWAIT != 0 { + return int(total), syserr.TranslateNetstackError(err) } - - if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { - return int(total), syserr.ErrTryAgain + switch err { + case nil: + if total == src.NumBytes() { + break + } + fallthrough + case tcpip.ErrWouldBlock: + if ch == nil { + // We'll have to block. Register for notification and keep trying to + // send all the data. + entry, ch = waiter.NewChannelEntry(nil) + s.EventRegister(&entry, waiter.EventOut) + defer s.EventUnregister(&entry) + } else { + // Don't wait immediately after registration in case more data + // became available between when we last checked and when we setup + // the notification. + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + return int(total), syserr.ErrTryAgain + } + // handleIOError will consume errors from t.Block if needed. + return int(total), syserr.FromError(err) + } } - // handleIOError will consume errors from t.Block if needed. - return int(total), syserr.FromError(err) + continue } + return int(total), syserr.TranslateNetstackError(err) } } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 9600a64e1..3bbdf552e 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -128,8 +128,8 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return 0, syserror.EOPNOTSUPP } - f := &ioSequencePayload{ctx: ctx, src: src} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + r := src.Reader(ctx) + n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } @@ -137,11 +137,11 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return 0, syserr.TranslateNetstackError(err).ToError() } - if int64(n) < src.NumBytes() { - return int64(n), syserror.ErrWouldBlock + if n < src.NumBytes() { + return n, syserror.ErrWouldBlock } - return int64(n), nil + return n, nil } // Accept implements the linux syscall accept(2) for sockets backed by diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index fdeec12d3..7c7495c30 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -16,6 +16,7 @@ package gonet import ( + "bytes" "context" "errors" "io" @@ -354,8 +355,6 @@ func (c *TCPConn) Write(b []byte) (int, error) { default: } - v := buffer.NewViewFromBytes(b) - // We must handle two soft failure conditions simultaneously: // 1. Write may write nothing and return tcpip.ErrWouldBlock. // If this happens, we need to register for notifications if we have @@ -368,22 +367,23 @@ func (c *TCPConn) Write(b []byte) (int, error) { // There is no guarantee that all of the condition #1s will occur before // all of the condition #2s or visa-versa. var ( - err *tcpip.Error - nbytes int - reg bool - notifyCh chan struct{} + r bytes.Reader + nbytes int + entry waiter.Entry + ch <-chan struct{} ) - for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { - if err == tcpip.ErrWouldBlock { - if !reg { - // Only register once. - reg = true - - // Create wait queue entry that notifies a channel. - var waitEntry waiter.Entry - waitEntry, notifyCh = waiter.NewChannelEntry(nil) - c.wq.EventRegister(&waitEntry, waiter.EventOut) - defer c.wq.EventUnregister(&waitEntry) + for nbytes != len(b) { + r.Reset(b[nbytes:]) + n, err := c.ep.Write(&r, tcpip.WriteOptions{}) + nbytes += int(n) + switch err { + case nil: + case tcpip.ErrWouldBlock: + if ch == nil { + entry, ch = waiter.NewChannelEntry(nil) + + c.wq.EventRegister(&entry, waiter.EventOut) + defer c.wq.EventUnregister(&entry) } else { // Don't wait immediately after registration in case more data // became available between when we last checked and when we setup @@ -391,22 +391,15 @@ func (c *TCPConn) Write(b []byte) (int, error) { select { case <-deadline: return nbytes, c.newOpError("write", &timeoutError{}) - case <-notifyCh: + case <-ch: + continue } } + default: + return nbytes, c.newOpError("write", errors.New(err.String())) } - - var n int64 - n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - } - - if err == nil { - return nbytes, nil } - - return nbytes, c.newOpError("write", errors.New(err.String())) + return nbytes, nil } // Close implements net.Conn.Close. @@ -644,16 +637,18 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { } // If we're being called by Write, there is no addr - wopts := tcpip.WriteOptions{} + writeOptions := tcpip.WriteOptions{} if addr != nil { ua := addr.(*net.UDPAddr) - wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} + writeOptions.To = &tcpip.FullAddress{ + Addr: tcpip.Address(ua.IP), + Port: uint16(ua.Port), + } } - v := buffer.NewView(len(b)) - copy(v, b) - - n, err := c.ep.Write(tcpip.SlicePayload(v), wopts) + var r bytes.Reader + r.Reset(b) + n, err := c.ep.Write(&r, writeOptions) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -666,7 +661,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { case <-notifyCh: } - n, err = c.ep.Write(tcpip.SlicePayload(v), wopts) + n, err = c.ep.Write(&r, writeOptions) if err != tcpip.ErrWouldBlock { break } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 4f59e4ff7..fe01029ad 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -29,6 +29,7 @@ package tcpip import ( + "bytes" "errors" "fmt" "io" @@ -471,30 +472,15 @@ type FullAddress struct { // This interface allows the endpoint to request the amount of data it needs // based on internal buffers without exposing them. type Payloader interface { - // FullPayload returns all available bytes. - FullPayload() ([]byte, *Error) + io.Reader - // Payload returns a slice containing at most size bytes. - Payload(size int) ([]byte, *Error) + // Len returns the number of bytes of the unread portion of the + // Reader. + Len() int } -// SlicePayload implements Payloader for slices. -// -// This is typically used for tests. -type SlicePayload []byte - -// FullPayload implements Payloader.FullPayload. -func (s SlicePayload) FullPayload() ([]byte, *Error) { - return s, nil -} - -// Payload implements Payloader.Payload. -func (s SlicePayload) Payload(size int) ([]byte, *Error) { - if size > len(s) { - size = len(s) - } - return s[:size], nil -} +var _ Payloader = (*bytes.Buffer)(nil) +var _ Payloader = (*bytes.Reader)(nil) var _ io.Writer = (*SliceWriter)(nil) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 256e19296..af00ed548 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -313,11 +313,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc route = r } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, tcpip.ErrBadBuffer } + var err *tcpip.Error switch e.NetProto { case header.IPv4ProtocolNumber: err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index c0d6fb442..6fd116a98 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -207,7 +207,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul return res, nil } -func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. return 0, tcpip.ErrInvalidOptionValue } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ae743f75e..2dacf5a64 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -280,9 +280,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return 0, tcpip.ErrInvalidEndpointState } - payloadBytes, err := p.FullPayload() - if err != nil { - return 0, err + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, tcpip.ErrBadBuffer } // If this is an unassociated socket and callee provided a nonzero diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ea509ac73..8d27d43c2 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1534,14 +1534,19 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } // Fetch data. - v, perr := p.Payload(avail) - if perr != nil || len(v) == 0 { - // Note that perr may be nil if len(v) == 0. + if l := p.Len(); l < avail { + avail = l + } + if avail == 0 { + return 0, nil + } + v := make([]byte, avail) + if _, err := io.ReadFull(p, v); err != nil { if opts.Atomic { e.sndBufMu.Unlock() e.UnlockUser() } - return 0, perr + return 0, tcpip.ErrBadBuffer } if !opts.Atomic { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9f9b3d510..8544fcb08 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -514,9 +514,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return 0, tcpip.ErrBroadcastDisabled } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, tcpip.ErrBadBuffer } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 79db8895b..dc2571154 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -517,28 +517,29 @@ func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, er // Reader returns an io.Reader that reads from s. Reads beyond the end of s // return io.EOF. The preconditions that apply to s.CopyIn also apply to the // returned io.Reader.Read. -func (s IOSequence) Reader(ctx context.Context) io.Reader { - return &ioSequenceReadWriter{ctx, s} +func (s IOSequence) Reader(ctx context.Context) *IOSequenceReadWriter { + return &IOSequenceReadWriter{ctx, s} } // Writer returns an io.Writer that writes to s. Writes beyond the end of s // return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also // apply to the returned io.Writer.Write. -func (s IOSequence) Writer(ctx context.Context) io.Writer { - return &ioSequenceReadWriter{ctx, s} +func (s IOSequence) Writer(ctx context.Context) *IOSequenceReadWriter { + return &IOSequenceReadWriter{ctx, s} } // ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when // attempting to write beyond the end of the IOSequence. var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence") -type ioSequenceReadWriter struct { +// IOSequenceReadWriter implements io.Reader and io.Writer for an IOSequence. +type IOSequenceReadWriter struct { ctx context.Context s IOSequence } // Read implements io.Reader.Read. -func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { +func (rw *IOSequenceReadWriter) Read(dst []byte) (int, error) { n, err := rw.s.CopyIn(rw.ctx, dst) rw.s = rw.s.DropFirst(n) if err == nil && rw.s.NumBytes() == 0 { @@ -547,8 +548,13 @@ func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { return n, err } +// Len implements tcpip.Payloader. +func (rw *IOSequenceReadWriter) Len() int { + return int(rw.s.NumBytes()) +} + // Write implements io.Writer.Write. -func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) { +func (rw *IOSequenceReadWriter) Write(src []byte) (int, error) { n, err := rw.s.CopyOut(rw.ctx, src) rw.s = rw.s.DropFirst(n) if err == nil && n < len(src) { |