diff options
author | Tamir Duberstein <tamird@google.com> | 2021-01-22 12:24:20 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-01-22 12:26:09 -0800 |
commit | 6c0e1d9cfe6adbfbb32e7020d6426608ac63ad37 (patch) | |
tree | 3ddb770b9965ccfcdb6ecb73062e2d466c379439 | |
parent | 527ef5fc0307102fa7cc0b32bcc2eb8cca3e21a8 (diff) |
Define tcpip.Payloader in terms of io.Reader
Fixes #1509.
PiperOrigin-RevId: 353295589
30 files changed, 372 insertions, 373 deletions
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go index 9419f2e95..ecbe8f920 100644 --- a/pkg/sentry/kernel/task_block.go +++ b/pkg/sentry/kernel/task_block.go @@ -69,7 +69,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time. // syserror.ErrInterrupted if t is interrupted. // // Preconditions: The caller must be running on the task goroutine. -func (t *Task) BlockWithDeadline(C chan struct{}, haveDeadline bool, deadline ktime.Time) error { +func (t *Task) BlockWithDeadline(C <-chan struct{}, haveDeadline bool, deadline ktime.Time) error { if !haveDeadline { return t.block(C, nil) } diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index f820467c5..915134b41 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -41,7 +41,6 @@ go_library( "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 4c9d335c0..65111154b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -19,7 +19,7 @@ // be used to expose certain endpoints to the sentry while leaving others out, // for example, TCP endpoints and Unix-domain endpoints. // -// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside +// Lock ordering: netstack => mm: ioSequenceReadWriter copies user memory inside // tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during // this operation. package netstack @@ -55,7 +55,6 @@ import ( "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -440,45 +439,10 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write return int64(res.Count), nil } -// ioSequencePayload implements tcpip.Payload. -// -// t copies user memory bytes on demand based on the requested size. -type ioSequencePayload struct { - ctx context.Context - src usermem.IOSequence -} - -// FullPayload implements tcpip.Payloader.FullPayload -func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { - return i.Payload(int(i.src.NumBytes())) -} - -// Payload implements tcpip.Payloader.Payload. -func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { - if max := int(i.src.NumBytes()); size > max { - size = max - } - v := buffer.NewView(size) - if _, err := i.src.CopyIn(i.ctx, v); err != nil { - // EOF can be returned only if src is a file and this means it - // is in a splice syscall and the error has to be ignored. - if err == io.EOF { - return v, nil - } - return nil, tcpip.ErrBadAddress - } - return v, nil -} - -// DropFirst drops the first n bytes from underlying src. -func (i *ioSequencePayload) DropFirst(n int) { - i.src = i.src.DropFirst(int(n)) -} - // Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - f := &ioSequencePayload{ctx: ctx, src: src} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + r := src.Reader(ctx) + n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } @@ -486,69 +450,40 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return 0, syserr.TranslateNetstackError(err).ToError() } - if int64(n) < src.NumBytes() { - return int64(n), syserror.ErrWouldBlock + if n < src.NumBytes() { + return n, syserror.ErrWouldBlock } - return int64(n), nil + return n, nil } -// readerPayload implements tcpip.Payloader. -// -// It allocates a view and reads from a reader on-demand, based on available -// capacity in the endpoint. -type readerPayload struct { - ctx context.Context - r io.Reader - count int64 - err error -} +var _ tcpip.Payloader = (*limitedPayloader)(nil) -// FullPayload implements tcpip.Payloader.FullPayload. -func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { - return r.Payload(int(r.count)) +type limitedPayloader struct { + io.LimitedReader } -// Payload implements tcpip.Payloader.Payload. -func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { - if size > int(r.count) { - size = int(r.count) - } - v := buffer.NewView(size) - n, err := r.r.Read(v) - if n > 0 { - // We ignore the error here. It may re-occur on subsequent - // reads, but for now we can enqueue some amount of data. - r.count -= int64(n) - return v[:n], nil - } - if err == syserror.ErrWouldBlock { - return nil, tcpip.ErrWouldBlock - } else if err != nil { - r.err = err // Save for propation. - return nil, tcpip.ErrBadAddress - } - - // There is no data and no error. Return an error, which will propagate - // r.err, which will be nil. This is the desired result: (0, nil). - return nil, tcpip.ErrBadAddress +func (l limitedPayloader) Len() int { + return int(l.N) } // ReadFrom implements fs.FileOperations.ReadFrom. func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { - f := &readerPayload{ctx: ctx, r: r, count: count} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + f := limitedPayloader{ + LimitedReader: io.LimitedReader{ + R: r, + N: count, + }, + } + n, err := s.Endpoint.Write(&f, tcpip.WriteOptions{ // Reads may be destructive but should be very fast, // so we can't release the lock while copying data. Atomic: true, }) - if err == tcpip.ErrWouldBlock { - return n, syserror.ErrWouldBlock - } else if err != nil { - return int64(n), f.err // Propagate error. + if err == tcpip.ErrBadBuffer { + err = nil } - - return int64(n), nil + return n, syserr.TranslateNetstackError(err).ToError() } // Readiness returns a mask of ready events for socket s. @@ -2836,45 +2771,46 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b EndOfRecord: flags&linux.MSG_EOR != 0, } - v := &ioSequencePayload{t, src} - n, err := s.Endpoint.Write(v, opts) - dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= v.src.NumBytes() || dontWait) { - // Complete write. - return int(n), nil - } - if err != nil && (err != tcpip.ErrWouldBlock || dontWait) { - return int(n), syserr.TranslateNetstackError(err) - } - - // We'll have to block. Register for notification and keep trying to - // send all the data. - e, ch := waiter.NewChannelEntry(nil) - s.EventRegister(&e, waiter.EventOut) - defer s.EventUnregister(&e) - - v.DropFirst(int(n)) - total := n + r := src.Reader(t) + var ( + total int64 + entry waiter.Entry + ch <-chan struct{} + ) for { - n, err = s.Endpoint.Write(v, opts) - v.DropFirst(int(n)) + n, err := s.Endpoint.Write(r, opts) total += n - - if err != nil && err != tcpip.ErrWouldBlock && total == 0 { - return 0, syserr.TranslateNetstackError(err) - } - - if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { - return int(total), nil + if flags&linux.MSG_DONTWAIT != 0 { + return int(total), syserr.TranslateNetstackError(err) } - - if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { - return int(total), syserr.ErrTryAgain + switch err { + case nil: + if total == src.NumBytes() { + break + } + fallthrough + case tcpip.ErrWouldBlock: + if ch == nil { + // We'll have to block. Register for notification and keep trying to + // send all the data. + entry, ch = waiter.NewChannelEntry(nil) + s.EventRegister(&entry, waiter.EventOut) + defer s.EventUnregister(&entry) + } else { + // Don't wait immediately after registration in case more data + // became available between when we last checked and when we setup + // the notification. + if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { + if err == syserror.ETIMEDOUT { + return int(total), syserr.ErrTryAgain + } + // handleIOError will consume errors from t.Block if needed. + return int(total), syserr.FromError(err) + } } - // handleIOError will consume errors from t.Block if needed. - return int(total), syserr.FromError(err) + continue } + return int(total), syserr.TranslateNetstackError(err) } } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 9600a64e1..3bbdf552e 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -128,8 +128,8 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return 0, syserror.EOPNOTSUPP } - f := &ioSequencePayload{ctx: ctx, src: src} - n, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + r := src.Reader(ctx) + n, err := s.Endpoint.Write(r, tcpip.WriteOptions{}) if err == tcpip.ErrWouldBlock { return 0, syserror.ErrWouldBlock } @@ -137,11 +137,11 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs return 0, syserr.TranslateNetstackError(err).ToError() } - if int64(n) < src.NumBytes() { - return int64(n), syserror.ErrWouldBlock + if n < src.NumBytes() { + return n, syserror.ErrWouldBlock } - return int64(n), nil + return n, nil } // Accept implements the linux syscall accept(2) for sockets backed by diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index fdeec12d3..7c7495c30 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -16,6 +16,7 @@ package gonet import ( + "bytes" "context" "errors" "io" @@ -354,8 +355,6 @@ func (c *TCPConn) Write(b []byte) (int, error) { default: } - v := buffer.NewViewFromBytes(b) - // We must handle two soft failure conditions simultaneously: // 1. Write may write nothing and return tcpip.ErrWouldBlock. // If this happens, we need to register for notifications if we have @@ -368,22 +367,23 @@ func (c *TCPConn) Write(b []byte) (int, error) { // There is no guarantee that all of the condition #1s will occur before // all of the condition #2s or visa-versa. var ( - err *tcpip.Error - nbytes int - reg bool - notifyCh chan struct{} + r bytes.Reader + nbytes int + entry waiter.Entry + ch <-chan struct{} ) - for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { - if err == tcpip.ErrWouldBlock { - if !reg { - // Only register once. - reg = true - - // Create wait queue entry that notifies a channel. - var waitEntry waiter.Entry - waitEntry, notifyCh = waiter.NewChannelEntry(nil) - c.wq.EventRegister(&waitEntry, waiter.EventOut) - defer c.wq.EventUnregister(&waitEntry) + for nbytes != len(b) { + r.Reset(b[nbytes:]) + n, err := c.ep.Write(&r, tcpip.WriteOptions{}) + nbytes += int(n) + switch err { + case nil: + case tcpip.ErrWouldBlock: + if ch == nil { + entry, ch = waiter.NewChannelEntry(nil) + + c.wq.EventRegister(&entry, waiter.EventOut) + defer c.wq.EventUnregister(&entry) } else { // Don't wait immediately after registration in case more data // became available between when we last checked and when we setup @@ -391,22 +391,15 @@ func (c *TCPConn) Write(b []byte) (int, error) { select { case <-deadline: return nbytes, c.newOpError("write", &timeoutError{}) - case <-notifyCh: + case <-ch: + continue } } + default: + return nbytes, c.newOpError("write", errors.New(err.String())) } - - var n int64 - n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - } - - if err == nil { - return nbytes, nil } - - return nbytes, c.newOpError("write", errors.New(err.String())) + return nbytes, nil } // Close implements net.Conn.Close. @@ -644,16 +637,18 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { } // If we're being called by Write, there is no addr - wopts := tcpip.WriteOptions{} + writeOptions := tcpip.WriteOptions{} if addr != nil { ua := addr.(*net.UDPAddr) - wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} + writeOptions.To = &tcpip.FullAddress{ + Addr: tcpip.Address(ua.IP), + Port: uint16(ua.Port), + } } - v := buffer.NewView(len(b)) - copy(v, b) - - n, err := c.ep.Write(tcpip.SlicePayload(v), wopts) + var r bytes.Reader + r.Reset(b) + n, err := c.ep.Write(&r, writeOptions) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -666,7 +661,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { case <-notifyCh: } - n, err = c.ep.Write(tcpip.SlicePayload(v), wopts) + n, err = c.ep.Write(&r, writeOptions) if err != tcpip.ErrWouldBlock { break } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index defea46b0..3fd1bacae 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -15,6 +15,7 @@ package ipv6 import ( + "bytes" "context" "net" "reflect" @@ -638,7 +639,6 @@ func TestLinkResolution(t *testing.T) { pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) pkt.SetType(header.ICMPv6EchoRequest) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - payload := tcpip.SlicePayload(hdr.View()) // We can't send our payload directly over the route because that // doesn't provoke NDP discovery. @@ -648,8 +648,12 @@ func TestLinkResolution(t *testing.T) { t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err) } - if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { - t.Fatalf("ep.Write(_): %s", err) + { + var r bytes.Reader + r.Reset(hdr.View()) + if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil { + t.Fatalf("ep.Write(_): %s", err) + } } for _, args := range []routeArgs{ {src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))}, diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD index cf0a5fefe..db9b91815 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/BUILD +++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD @@ -8,7 +8,6 @@ go_binary( visibility = ["//:sandbox"], deps = [ "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/fdbased", "//pkg/tcpip/link/rawfile", diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 3b4f900e3..3d9954c84 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -41,7 +41,7 @@ package main import ( - "bufio" + "bytes" "fmt" "log" "math/rand" @@ -51,7 +51,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" "gvisor.dev/gvisor/pkg/tcpip/link/rawfile" @@ -71,24 +70,21 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) { close(ch) }() - r := bufio.NewReader(os.Stdin) - for { - v := buffer.NewView(1024) - n, err := r.Read(v) - if err != nil { - return - } - - v.CapLength(n) - for len(v) > 0 { - n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - if err != nil { - fmt.Println("Write failed:", err) - return + var b bytes.Buffer + if err := func() error { + for { + if _, err := b.ReadFrom(os.Stdin); err != nil { + return fmt.Errorf("b.ReadFrom failed: %w", err) } - v.TrimFront(int(n)) + for b.Len() != 0 { + if _, err := ep.Write(&b, tcpip.WriteOptions{Atomic: true}); err != nil { + return fmt.Errorf("ep.Write failed: %s", err) + } + } } + }(); err != nil { + fmt.Println(err) } } diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 3ac562756..ae9cf44e7 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -20,6 +20,7 @@ package main import ( + "bytes" "flag" "io" "log" @@ -58,7 +59,9 @@ func (e *tcpipError) Error() string { } func (e *endpointWriter) Write(p []byte) (int, error) { - n, err := e.ep.Write(tcpip.SlicePayload(p), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(p) + n, err := e.ep.Write(&r, tcpip.WriteOptions{}) if err != nil { return int(n), &tcpipError{ inner: err, diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 9d39533a1..dbf8b4db1 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "bytes" "io" "testing" @@ -95,10 +96,11 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return 0, tcpip.ErrNoRoute } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, tcpip.ErrBadBuffer } + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen, Data: buffer.View(v).ToVectorisedView(), @@ -520,8 +522,10 @@ func TestTransportSend(t *testing.T) { } // Create buffer that will hold the payload. - view := buffer.NewView(30) - if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + b := make([]byte, 30) + var r bytes.Reader + r.Reset(b) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("write failed: %v", err) } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 4f59e4ff7..fe01029ad 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -29,6 +29,7 @@ package tcpip import ( + "bytes" "errors" "fmt" "io" @@ -471,30 +472,15 @@ type FullAddress struct { // This interface allows the endpoint to request the amount of data it needs // based on internal buffers without exposing them. type Payloader interface { - // FullPayload returns all available bytes. - FullPayload() ([]byte, *Error) + io.Reader - // Payload returns a slice containing at most size bytes. - Payload(size int) ([]byte, *Error) + // Len returns the number of bytes of the unread portion of the + // Reader. + Len() int } -// SlicePayload implements Payloader for slices. -// -// This is typically used for tests. -type SlicePayload []byte - -// FullPayload implements Payloader.FullPayload. -func (s SlicePayload) FullPayload() ([]byte, *Error) { - return s, nil -} - -// Payload implements Payloader.Payload. -func (s SlicePayload) Payload(size int) ([]byte, *Error) { - if size > len(s) { - size = len(s) - } - return s[:size], nil -} +var _ Payloader = (*bytes.Buffer)(nil) +var _ Payloader = (*bytes.Reader)(nil) var _ io.Writer = (*SliceWriter)(nil) diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index ac9670f9a..aedf1845e 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -436,9 +436,10 @@ func TestForwarding(t *testing.T) { write := func(ep tcpip.Endpoint, data []byte) { t.Helper() - dataPayload := tcpip.SlicePayload(data) + var r bytes.Reader + r.Reset(data) var wOpts tcpip.WriteOptions - n, err := ep.Write(dataPayload, wOpts) + n, err := ep.Write(&r, wOpts) if err != nil { t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) } @@ -486,7 +487,7 @@ func TestForwarding(t *testing.T) { read(serverCH, serverEP, data, clientAddr) - data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12}) + data = []byte{5, 6, 7, 8, 9, 10, 11, 12} write(serverEP, data) read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) }) diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 1e13fd6d6..d0d0e4ef2 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -209,8 +209,10 @@ func TestPing(t *testing.T) { defer ep.Close() icmpBuf := test.icmpBuf(t) + var r bytes.Reader + r.Reset(icmpBuf) wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}} - if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil { + if n, err := ep.Write(&r, wOpts); err != nil { t.Fatalf("ep.Write(_, _): %s", err) } else if want := int64(len(icmpBuf)); n != want { t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want) @@ -360,9 +362,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) { // Wait for an error due to link resolution failing, or the endpoint to be // writable. <-ch + var r bytes.Reader + r.Reset([]byte{0}) var wOpts tcpip.WriteOptions - if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr { - t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) + if n, err := clientEP.Write(&r, wOpts); err != test.expectedWriteErr { + t.Errorf("got clientEP.Write(_, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr) } if test.expectedWriteErr == nil { diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index 3b13ba04d..761283b66 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -232,7 +232,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { Port: localPort, }, } - n, err := sep.Write(tcpip.SlicePayload(data), wopts) + var r bytes.Reader + r.Reset(data) + n, err := sep.Write(&r, wopts) if err != nil { t.Fatalf("sep.Write(_, _): %s", err) } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index ce7c16bd1..9cc12fa58 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -586,8 +586,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) { Port: localPort, }, } - data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) - if n, err := wep.ep.Write(data, writeOpts); err != nil { + data := []byte{byte(i), 2, 3, 4} + var r bytes.Reader + r.Reset(data) + if n, err := wep.ep.Write(&r, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want) diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index b222d2b05..35ee7437a 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -194,9 +194,11 @@ func TestLocalPing(t *testing.T) { return } - payload := tcpip.SlicePayload(test.icmpBuf(t)) + payload := test.icmpBuf(t) + var r bytes.Reader + r.Reset(payload) var wOpts tcpip.WriteOptions - if n, err := ep.Write(payload, wOpts); err != nil { + if n, err := ep.Write(&r, wOpts); err != nil { t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) } else if n != int64(len(payload)) { t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload)) @@ -329,12 +331,14 @@ func TestLocalUDP(t *testing.T) { Port: 80, } - clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + clientPayload := []byte{1, 2, 3, 4} { + var r bytes.Reader + r.Reset(clientPayload) wOpts := tcpip.WriteOptions{ To: &serverAddr, } - if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr { t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) } else if subTest.expectedWriteErr != nil { // Nothing else to test if we expected not to be able to send the @@ -376,12 +380,14 @@ func TestLocalUDP(t *testing.T) { } } - serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + serverPayload := []byte{1, 2, 3, 4} { + var r bytes.Reader + r.Reset(serverPayload) wOpts := tcpip.WriteOptions{ To: &clientAddr, } - if n, err := server.Write(serverPayload, wOpts); err != nil { + if n, err := server.Write(&r, wOpts); err != nil { t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) } else if n != int64(len(serverPayload)) { t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload)) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 256e19296..af00ed548 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -313,11 +313,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc route = r } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, tcpip.ErrBadBuffer } + var err *tcpip.Error switch e.NetProto { case header.IPv4ProtocolNumber: err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index c0d6fb442..6fd116a98 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -207,7 +207,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul return res, nil } -func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. return 0, tcpip.ErrInvalidOptionValue } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ae743f75e..2dacf5a64 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -280,9 +280,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return 0, tcpip.ErrInvalidEndpointState } - payloadBytes, err := p.FullPayload() - if err != nil { - return 0, err + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, tcpip.ErrBadBuffer } // If this is an unassociated socket and callee provided a nonzero diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 7e81203ba..fcdd032c5 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -99,7 +99,6 @@ go_test( "//pkg/rand", "//pkg/sync", "//pkg/tcpip", - "//pkg/tcpip/buffer", "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/loopback", diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 1d1b01a6c..809c88732 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -15,11 +15,11 @@ package tcp_test import ( + "strings" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -415,8 +415,10 @@ func testV4Accept(t *testing.T, c *context.Context) { t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) } + var r strings.Reader data := "Don't panic" - nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + r.Reset(data) + nep.Write(&r, tcpip.WriteOptions{}) b = c.GetPacket() tcp = header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ea509ac73..8d27d43c2 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1534,14 +1534,19 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } // Fetch data. - v, perr := p.Payload(avail) - if perr != nil || len(v) == 0 { - // Note that perr may be nil if len(v) == 0. + if l := p.Len(); l < avail { + avail = l + } + if avail == 0 { + return 0, nil + } + v := make([]byte, avail) + if _, err := io.ReadFull(p, v); err != nil { if opts.Atomic { e.sndBufMu.Unlock() e.UnlockUser() } - return 0, perr + return 0, tcpip.ErrBadBuffer } if !opts.Atomic { diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go index f7aaee23f..ced3a9c58 100644 --- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go +++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go @@ -21,13 +21,13 @@ package tcp_test import ( + "bytes" "fmt" "math" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" @@ -42,14 +42,16 @@ func TestFastRecovery(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -207,14 +209,16 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -249,14 +253,16 @@ func TestCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -353,15 +359,16 @@ func TestCubicCongestionAvoidance(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) - + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -462,19 +469,20 @@ func TestRetransmit(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) const iterations = 3 - data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in two shots. Packets will only be written at the // MTU size though. - half := data[:len(data)/2] - if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data[:len(data)/2]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } - half = data[len(data)/2:] - if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + r.Reset(data[len(data)/2:]) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go index 342eb5eb8..af915203b 100644 --- a/pkg/tcpip/transport/tcp/tcp_rack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go @@ -15,11 +15,11 @@ package tcp_test import ( + "bytes" "testing" "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -61,14 +61,16 @@ func TestRACKUpdate(t *testing.T) { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(maxPayload) + data := make([]byte, maxPayload) for i := range data { data[i] = byte(i) } // Write the data. xmitTime = time.Now() - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -114,13 +116,15 @@ func TestRACKDetectReorder(t *testing.T) { }) setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(ackNumToVerify * maxPayload) + data := make([]byte, ackNumToVerify*maxPayload) for i := range data { data[i] = byte(i) } // Write the data. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -141,17 +145,19 @@ func TestRACKDetectReorder(t *testing.T) { <-probeDone } -func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View { +func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte { setStackSACKPermitted(t, c, true) createConnectedWithSACKAndTS(c) - data := buffer.NewView(numPackets * maxPayload) + data := make([]byte, numPackets*maxPayload) for i := range data { data[i] = byte(i) } // Write the data. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 6635bb815..5024bc925 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -15,6 +15,7 @@ package tcp_test import ( + "bytes" "fmt" "log" "reflect" @@ -22,7 +23,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -395,14 +395,16 @@ func TestSACKRecovery(t *testing.T) { createConnectedWithSACKAndTS(c) const iterations = 3 - data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1))) for i := range data { data[i] = byte(i) } // Write all the data in one shot. Packets will only be written at the // MTU size though. - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 93683b921..db9dc4e25 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io/ioutil" "math" + "strings" "testing" "time" @@ -26,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -1347,10 +1347,9 @@ func TestTOSV4(t *testing.T) { testV4Connect(t, c, checker.TOS(tos, 0)) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -1396,10 +1395,9 @@ func TestTrafficClassV6(t *testing.T) { testV6Connect(t, c, checker.TOS(tos, 0)) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2176,10 +2174,9 @@ func TestSimpleSend(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2217,10 +2214,9 @@ func TestZeroWindowSend(t *testing.T) { c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2285,10 +2281,9 @@ func TestScaledWindowConnect(t *testing.T) { }) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2317,10 +2312,9 @@ func TestNonScaledWindowConnect(t *testing.T) { c.CreateConnected(789, 30000, 65535*3) data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2391,10 +2385,9 @@ func TestScaledWindowAccept(t *testing.T) { } data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2465,10 +2458,9 @@ func TestNonScaledWindowAccept(t *testing.T) { } data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -2632,9 +2624,10 @@ func TestSegmentMerging(t *testing.T) { // Send tcp.InitialCwnd number of segments to fill up // InitialWindow but don't ACK. That should prevent // anymore packets from going out. + var r bytes.Reader for i := 0; i < tcp.InitialCwnd; i++ { - view := buffer.NewViewFromBytes([]byte{0}) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset([]byte{0}) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2644,8 +2637,8 @@ func TestSegmentMerging(t *testing.T) { var allData []byte for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2714,8 +2707,9 @@ func TestDelay(t *testing.T) { var allData []byte for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { allData = append(allData, data...) - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2761,8 +2755,9 @@ func TestUndelay(t *testing.T) { allData := [][]byte{{0}, {1, 2, 3}} for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2845,8 +2840,9 @@ func TestMSSNotDelayed(t *testing.T) { allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} for i, data := range allData { - view := buffer.NewViewFromBytes(data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write #%d failed: %s", i+1, err) } } @@ -2894,10 +2890,9 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { data[i] = byte(i) } - view := buffer.NewView(len(data)) - copy(view, data) - - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3328,8 +3323,9 @@ func TestSendOnResetConnection(t *testing.T) { time.Sleep(1 * time.Second) // Try to write. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { + var r bytes.Reader + r.Reset(make([]byte, 10)) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) } } @@ -3352,7 +3348,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { c.WQ.EventRegister(&waitEntry, waiter.EventHUp) defer c.WQ.EventUnregister(&waitEntry) - _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3409,7 +3407,9 @@ func TestMaxRTO(t *testing.T) { c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) - _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + var r bytes.Reader + r.Reset(make([]byte, 1)) + _, err := c.EP.Write(&r, tcpip.WriteOptions{}) if err != nil { t.Fatalf("Write failed: %s", err) } @@ -3458,7 +3458,9 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) { t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) } - if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(make([]byte, tc.size)) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } pkt := c.GetPacket() @@ -3595,8 +3597,10 @@ func TestFinWithNoPendingData(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and have it acknowledged. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3667,9 +3671,11 @@ func TestFinWithPendingDataCwndFull(t *testing.T) { // Write enough segments to fill the congestion window before ACK'ing // any of them. - view := buffer.NewView(10) + view := make([]byte, 10) + var r bytes.Reader for i := tcp.InitialCwnd; i > 0; i-- { - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } } @@ -3754,8 +3760,10 @@ func TestFinWithPendingData(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) // Write something out, and acknowledge it to get cwnd to 2. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3781,7 +3789,8 @@ func TestFinWithPendingData(t *testing.T) { }) // Write new data, but don't acknowledge it. - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3841,8 +3850,10 @@ func TestFinWithPartialAck(t *testing.T) { // Write something out, and acknowledge it to get cwnd to 2. Also send // FIN from the test side. - view := buffer.NewView(10) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 10) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3879,7 +3890,8 @@ func TestFinWithPartialAck(t *testing.T) { ) // Write new data, but don't acknowledge it. - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -3985,8 +3997,10 @@ func scaledSendWindow(t *testing.T, scale uint8) { }) // Send some data. Check that it's capped by the window size. - view := buffer.NewView(65535) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 65535) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4610,9 +4624,9 @@ func TestSelfConnect(t *testing.T) { // Write something. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -4785,12 +4799,13 @@ func TestPathMTUDiscovery(t *testing.T) { // Send 3200 bytes of data. const writeSize = 3200 - data := buffer.NewView(writeSize) + data := make([]byte, writeSize) for i := range data { data[i] = byte(i) } - - if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5078,8 +5093,10 @@ func TestKeepalive(t *testing.T) { // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. - view := buffer.NewView(3) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -5358,7 +5375,9 @@ func TestListenBacklogFull(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) b := c.GetPacket() tcp := header.TCP(header.IPv4(b).Payload()) if string(tcp.Payload()) != data { @@ -5674,7 +5693,9 @@ func TestListenSynRcvdQueueFull(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}) + var r strings.Reader + r.Reset(data) + newEP.Write(&r, tcpip.WriteOptions{}) pkt := c.GetPacket() tcp = header.TCP(header.IPv4(pkt).Payload()) if string(tcp.Payload()) != data { @@ -5908,7 +5929,9 @@ func TestSynRcvdBadSeqNumber(t *testing.T) { // Now verify that the TCP socket is usable and in a connected state. data := "Don't panic" - if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil { + var r strings.Reader + r.Reset(data) + if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7103,10 +7126,10 @@ func TestTCPCloseWithData(t *testing.T) { // Now write a few bytes and then close the endpoint. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } @@ -7204,8 +7227,10 @@ func TestTCPUserTimeout(t *testing.T) { } // Send some data and wait before ACKing it. - view := buffer.NewView(3) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + view := make([]byte, 3) + var r bytes.Reader + r.Reset(view) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Write failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index b65091c3c..5a9745ad7 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -22,7 +22,6 @@ import ( "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" @@ -152,10 +151,10 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS // Now send some data and validate that timestamp is echoed correctly in the ACK. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } @@ -215,10 +214,10 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd // Now send some data with the accepted connection endpoint and validate // that no timestamp option is sent in the TCP segment. data := []byte{1, 2, 3} - view := buffer.NewView(len(data)) - copy(view, data) - if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + var r bytes.Reader + r.Reset(data) + if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil { t.Fatalf("Unexpected error from Write: %s", err) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9f9b3d510..8544fcb08 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -514,9 +514,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return 0, tcpip.ErrBroadcastDisabled } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, tcpip.ErrBadBuffer } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 4e2123fe9..c4794e876 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -966,8 +966,9 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) { h := flow.header4Tuple(outgoing) writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr) - payload := buffer.View(newPayload()) - _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + var r bytes.Reader + r.Reset(newPayload()) + _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, }) c.checkEndpointWriteStats(1, epstats, gotErr) @@ -1007,8 +1008,10 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port}, } } - payload := buffer.View(newPayload()) - n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + var r bytes.Reader + payload := newPayload() + r.Reset(payload) + n, err := c.ep.Write(&r, writeOpts) if err != nil { c.t.Fatalf("Write failed: %s", err) } @@ -1183,8 +1186,10 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) { writeOpts := tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort}, } - payload := buffer.View(newPayload()) - n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) + var r bytes.Reader + payload := newPayload() + r.Reset(payload) + n, err := c.ep.Write(&r, writeOpts) if err != nil { c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err) } @@ -2497,7 +2502,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { } defer ep.Close() - data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + var r bytes.Reader + data := []byte{1, 2, 3, 4} to := tcpip.FullAddress{ Addr: test.remoteAddr, Port: 80, @@ -2508,19 +2514,22 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { expectedErrWithoutBcastOpt = nil } - if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + r.Reset(data) + if n, err := ep.Write(&r, opts); err != expectedErrWithoutBcastOpt { t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) } ep.SocketOptions().SetBroadcast(true) - if n, err := ep.Write(data, opts); err != nil { + r.Reset(data) + if n, err := ep.Write(&r, opts); err != nil { t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err) } ep.SocketOptions().SetBroadcast(false) - if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + r.Reset(data) + if n, err := ep.Write(&r, opts); err != expectedErrWithoutBcastOpt { t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt) } }) diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 79db8895b..dc2571154 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -517,28 +517,29 @@ func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, er // Reader returns an io.Reader that reads from s. Reads beyond the end of s // return io.EOF. The preconditions that apply to s.CopyIn also apply to the // returned io.Reader.Read. -func (s IOSequence) Reader(ctx context.Context) io.Reader { - return &ioSequenceReadWriter{ctx, s} +func (s IOSequence) Reader(ctx context.Context) *IOSequenceReadWriter { + return &IOSequenceReadWriter{ctx, s} } // Writer returns an io.Writer that writes to s. Writes beyond the end of s // return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also // apply to the returned io.Writer.Write. -func (s IOSequence) Writer(ctx context.Context) io.Writer { - return &ioSequenceReadWriter{ctx, s} +func (s IOSequence) Writer(ctx context.Context) *IOSequenceReadWriter { + return &IOSequenceReadWriter{ctx, s} } // ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when // attempting to write beyond the end of the IOSequence. var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence") -type ioSequenceReadWriter struct { +// IOSequenceReadWriter implements io.Reader and io.Writer for an IOSequence. +type IOSequenceReadWriter struct { ctx context.Context s IOSequence } // Read implements io.Reader.Read. -func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { +func (rw *IOSequenceReadWriter) Read(dst []byte) (int, error) { n, err := rw.s.CopyIn(rw.ctx, dst) rw.s = rw.s.DropFirst(n) if err == nil && rw.s.NumBytes() == 0 { @@ -547,8 +548,13 @@ func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { return n, err } +// Len implements tcpip.Payloader. +func (rw *IOSequenceReadWriter) Len() int { + return int(rw.s.NumBytes()) +} + // Write implements io.Writer.Write. -func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) { +func (rw *IOSequenceReadWriter) Write(src []byte) (int, error) { n, err := rw.s.CopyOut(rw.ctx, src) rw.s = rw.s.DropFirst(n) if err == nil && n < len(src) { |