summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@google.com>2021-01-22 12:24:20 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-22 12:26:09 -0800
commit6c0e1d9cfe6adbfbb32e7020d6426608ac63ad37 (patch)
tree3ddb770b9965ccfcdb6ecb73062e2d466c379439
parent527ef5fc0307102fa7cc0b32bcc2eb8cca3e21a8 (diff)
Define tcpip.Payloader in terms of io.Reader
Fixes #1509. PiperOrigin-RevId: 353295589
-rw-r--r--pkg/sentry/kernel/task_block.go2
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go176
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go10
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go67
-rw-r--r--pkg/tcpip/network/ipv6/icmp_test.go10
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD1
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go30
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go5
-rw-r--r--pkg/tcpip/stack/transport_test.go14
-rw-r--r--pkg/tcpip/tcpip.go28
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go7
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go10
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go4
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go6
-rw-r--r--pkg/tcpip/tests/integration/route_test.go18
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go7
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go6
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go6
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go13
-rw-r--r--pkg/tcpip/transport/tcp/tcp_noracedetector_test.go38
-rw-r--r--pkg/tcpip/transport/tcp/tcp_rack_test.go22
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go8
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go185
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go13
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go29
-rw-r--r--pkg/usermem/usermem.go20
30 files changed, 372 insertions, 373 deletions
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index 9419f2e95..ecbe8f920 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -69,7 +69,7 @@ func (t *Task) BlockWithTimeout(C chan struct{}, haveTimeout bool, timeout time.
// syserror.ErrInterrupted if t is interrupted.
//
// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) BlockWithDeadline(C chan struct{}, haveDeadline bool, deadline ktime.Time) error {
+func (t *Task) BlockWithDeadline(C <-chan struct{}, haveDeadline bool, deadline ktime.Time) error {
if !haveDeadline {
return t.block(C, nil)
}
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index f820467c5..915134b41 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -41,7 +41,6 @@ go_library(
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 4c9d335c0..65111154b 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -19,7 +19,7 @@
// be used to expose certain endpoints to the sentry while leaving others out,
// for example, TCP endpoints and Unix-domain endpoints.
//
-// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside
+// Lock ordering: netstack => mm: ioSequenceReadWriter copies user memory inside
// tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during
// this operation.
package netstack
@@ -55,7 +55,6 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -440,45 +439,10 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
return int64(res.Count), nil
}
-// ioSequencePayload implements tcpip.Payload.
-//
-// t copies user memory bytes on demand based on the requested size.
-type ioSequencePayload struct {
- ctx context.Context
- src usermem.IOSequence
-}
-
-// FullPayload implements tcpip.Payloader.FullPayload
-func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) {
- return i.Payload(int(i.src.NumBytes()))
-}
-
-// Payload implements tcpip.Payloader.Payload.
-func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
- if max := int(i.src.NumBytes()); size > max {
- size = max
- }
- v := buffer.NewView(size)
- if _, err := i.src.CopyIn(i.ctx, v); err != nil {
- // EOF can be returned only if src is a file and this means it
- // is in a splice syscall and the error has to be ignored.
- if err == io.EOF {
- return v, nil
- }
- return nil, tcpip.ErrBadAddress
- }
- return v, nil
-}
-
-// DropFirst drops the first n bytes from underlying src.
-func (i *ioSequencePayload) DropFirst(n int) {
- i.src = i.src.DropFirst(int(n))
-}
-
// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
- f := &ioSequencePayload{ctx: ctx, src: src}
- n, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ r := src.Reader(ctx)
+ n, err := s.Endpoint.Write(r, tcpip.WriteOptions{})
if err == tcpip.ErrWouldBlock {
return 0, syserror.ErrWouldBlock
}
@@ -486,69 +450,40 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
return 0, syserr.TranslateNetstackError(err).ToError()
}
- if int64(n) < src.NumBytes() {
- return int64(n), syserror.ErrWouldBlock
+ if n < src.NumBytes() {
+ return n, syserror.ErrWouldBlock
}
- return int64(n), nil
+ return n, nil
}
-// readerPayload implements tcpip.Payloader.
-//
-// It allocates a view and reads from a reader on-demand, based on available
-// capacity in the endpoint.
-type readerPayload struct {
- ctx context.Context
- r io.Reader
- count int64
- err error
-}
+var _ tcpip.Payloader = (*limitedPayloader)(nil)
-// FullPayload implements tcpip.Payloader.FullPayload.
-func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) {
- return r.Payload(int(r.count))
+type limitedPayloader struct {
+ io.LimitedReader
}
-// Payload implements tcpip.Payloader.Payload.
-func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
- if size > int(r.count) {
- size = int(r.count)
- }
- v := buffer.NewView(size)
- n, err := r.r.Read(v)
- if n > 0 {
- // We ignore the error here. It may re-occur on subsequent
- // reads, but for now we can enqueue some amount of data.
- r.count -= int64(n)
- return v[:n], nil
- }
- if err == syserror.ErrWouldBlock {
- return nil, tcpip.ErrWouldBlock
- } else if err != nil {
- r.err = err // Save for propation.
- return nil, tcpip.ErrBadAddress
- }
-
- // There is no data and no error. Return an error, which will propagate
- // r.err, which will be nil. This is the desired result: (0, nil).
- return nil, tcpip.ErrBadAddress
+func (l limitedPayloader) Len() int {
+ return int(l.N)
}
// ReadFrom implements fs.FileOperations.ReadFrom.
func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
- f := &readerPayload{ctx: ctx, r: r, count: count}
- n, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ f := limitedPayloader{
+ LimitedReader: io.LimitedReader{
+ R: r,
+ N: count,
+ },
+ }
+ n, err := s.Endpoint.Write(&f, tcpip.WriteOptions{
// Reads may be destructive but should be very fast,
// so we can't release the lock while copying data.
Atomic: true,
})
- if err == tcpip.ErrWouldBlock {
- return n, syserror.ErrWouldBlock
- } else if err != nil {
- return int64(n), f.err // Propagate error.
+ if err == tcpip.ErrBadBuffer {
+ err = nil
}
-
- return int64(n), nil
+ return n, syserr.TranslateNetstackError(err).ToError()
}
// Readiness returns a mask of ready events for socket s.
@@ -2836,45 +2771,46 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
EndOfRecord: flags&linux.MSG_EOR != 0,
}
- v := &ioSequencePayload{t, src}
- n, err := s.Endpoint.Write(v, opts)
- dontWait := flags&linux.MSG_DONTWAIT != 0
- if err == nil && (n >= v.src.NumBytes() || dontWait) {
- // Complete write.
- return int(n), nil
- }
- if err != nil && (err != tcpip.ErrWouldBlock || dontWait) {
- return int(n), syserr.TranslateNetstackError(err)
- }
-
- // We'll have to block. Register for notification and keep trying to
- // send all the data.
- e, ch := waiter.NewChannelEntry(nil)
- s.EventRegister(&e, waiter.EventOut)
- defer s.EventUnregister(&e)
-
- v.DropFirst(int(n))
- total := n
+ r := src.Reader(t)
+ var (
+ total int64
+ entry waiter.Entry
+ ch <-chan struct{}
+ )
for {
- n, err = s.Endpoint.Write(v, opts)
- v.DropFirst(int(n))
+ n, err := s.Endpoint.Write(r, opts)
total += n
-
- if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
- return 0, syserr.TranslateNetstackError(err)
- }
-
- if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock {
- return int(total), nil
+ if flags&linux.MSG_DONTWAIT != 0 {
+ return int(total), syserr.TranslateNetstackError(err)
}
-
- if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
- if err == syserror.ETIMEDOUT {
- return int(total), syserr.ErrTryAgain
+ switch err {
+ case nil:
+ if total == src.NumBytes() {
+ break
+ }
+ fallthrough
+ case tcpip.ErrWouldBlock:
+ if ch == nil {
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ entry, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&entry, waiter.EventOut)
+ defer s.EventUnregister(&entry)
+ } else {
+ // Don't wait immediately after registration in case more data
+ // became available between when we last checked and when we setup
+ // the notification.
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(total), syserr.ErrTryAgain
+ }
+ // handleIOError will consume errors from t.Block if needed.
+ return int(total), syserr.FromError(err)
+ }
}
- // handleIOError will consume errors from t.Block if needed.
- return int(total), syserr.FromError(err)
+ continue
}
+ return int(total), syserr.TranslateNetstackError(err)
}
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 9600a64e1..3bbdf552e 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -128,8 +128,8 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
return 0, syserror.EOPNOTSUPP
}
- f := &ioSequencePayload{ctx: ctx, src: src}
- n, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ r := src.Reader(ctx)
+ n, err := s.Endpoint.Write(r, tcpip.WriteOptions{})
if err == tcpip.ErrWouldBlock {
return 0, syserror.ErrWouldBlock
}
@@ -137,11 +137,11 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
return 0, syserr.TranslateNetstackError(err).ToError()
}
- if int64(n) < src.NumBytes() {
- return int64(n), syserror.ErrWouldBlock
+ if n < src.NumBytes() {
+ return n, syserror.ErrWouldBlock
}
- return int64(n), nil
+ return n, nil
}
// Accept implements the linux syscall accept(2) for sockets backed by
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index fdeec12d3..7c7495c30 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -16,6 +16,7 @@
package gonet
import (
+ "bytes"
"context"
"errors"
"io"
@@ -354,8 +355,6 @@ func (c *TCPConn) Write(b []byte) (int, error) {
default:
}
- v := buffer.NewViewFromBytes(b)
-
// We must handle two soft failure conditions simultaneously:
// 1. Write may write nothing and return tcpip.ErrWouldBlock.
// If this happens, we need to register for notifications if we have
@@ -368,22 +367,23 @@ func (c *TCPConn) Write(b []byte) (int, error) {
// There is no guarantee that all of the condition #1s will occur before
// all of the condition #2s or visa-versa.
var (
- err *tcpip.Error
- nbytes int
- reg bool
- notifyCh chan struct{}
+ r bytes.Reader
+ nbytes int
+ entry waiter.Entry
+ ch <-chan struct{}
)
- for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) {
- if err == tcpip.ErrWouldBlock {
- if !reg {
- // Only register once.
- reg = true
-
- // Create wait queue entry that notifies a channel.
- var waitEntry waiter.Entry
- waitEntry, notifyCh = waiter.NewChannelEntry(nil)
- c.wq.EventRegister(&waitEntry, waiter.EventOut)
- defer c.wq.EventUnregister(&waitEntry)
+ for nbytes != len(b) {
+ r.Reset(b[nbytes:])
+ n, err := c.ep.Write(&r, tcpip.WriteOptions{})
+ nbytes += int(n)
+ switch err {
+ case nil:
+ case tcpip.ErrWouldBlock:
+ if ch == nil {
+ entry, ch = waiter.NewChannelEntry(nil)
+
+ c.wq.EventRegister(&entry, waiter.EventOut)
+ defer c.wq.EventUnregister(&entry)
} else {
// Don't wait immediately after registration in case more data
// became available between when we last checked and when we setup
@@ -391,22 +391,15 @@ func (c *TCPConn) Write(b []byte) (int, error) {
select {
case <-deadline:
return nbytes, c.newOpError("write", &timeoutError{})
- case <-notifyCh:
+ case <-ch:
+ continue
}
}
+ default:
+ return nbytes, c.newOpError("write", errors.New(err.String()))
}
-
- var n int64
- n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- nbytes += int(n)
- v.TrimFront(int(n))
- }
-
- if err == nil {
- return nbytes, nil
}
-
- return nbytes, c.newOpError("write", errors.New(err.String()))
+ return nbytes, nil
}
// Close implements net.Conn.Close.
@@ -644,16 +637,18 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
}
// If we're being called by Write, there is no addr
- wopts := tcpip.WriteOptions{}
+ writeOptions := tcpip.WriteOptions{}
if addr != nil {
ua := addr.(*net.UDPAddr)
- wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)}
+ writeOptions.To = &tcpip.FullAddress{
+ Addr: tcpip.Address(ua.IP),
+ Port: uint16(ua.Port),
+ }
}
- v := buffer.NewView(len(b))
- copy(v, b)
-
- n, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
+ var r bytes.Reader
+ r.Reset(b)
+ n, err := c.ep.Write(&r, writeOptions)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
@@ -666,7 +661,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
case <-notifyCh:
}
- n, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
+ n, err = c.ep.Write(&r, writeOptions)
if err != tcpip.ErrWouldBlock {
break
}
diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go
index defea46b0..3fd1bacae 100644
--- a/pkg/tcpip/network/ipv6/icmp_test.go
+++ b/pkg/tcpip/network/ipv6/icmp_test.go
@@ -15,6 +15,7 @@
package ipv6
import (
+ "bytes"
"context"
"net"
"reflect"
@@ -638,7 +639,6 @@ func TestLinkResolution(t *testing.T) {
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
pkt.SetType(header.ICMPv6EchoRequest)
pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
- payload := tcpip.SlicePayload(hdr.View())
// We can't send our payload directly over the route because that
// doesn't provoke NDP discovery.
@@ -648,8 +648,12 @@ func TestLinkResolution(t *testing.T) {
t.Fatalf("NewEndpoint(_) = (_, %s), want = (_, nil)", err)
}
- if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil {
- t.Fatalf("ep.Write(_): %s", err)
+ {
+ var r bytes.Reader
+ r.Reset(hdr.View())
+ if _, err := ep.Write(&r, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: nicID, Addr: lladdr1}}); err != nil {
+ t.Fatalf("ep.Write(_): %s", err)
+ }
}
for _, args := range []routeArgs{
{src: c.linkEP0, dst: c.linkEP1, typ: header.ICMPv6NeighborSolicit, remoteLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.SolicitedNodeAddr(lladdr1))},
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
index cf0a5fefe..db9b91815 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/BUILD
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -8,7 +8,6 @@ go_binary(
visibility = ["//:sandbox"],
deps = [
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/fdbased",
"//pkg/tcpip/link/rawfile",
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 3b4f900e3..3d9954c84 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -41,7 +41,7 @@
package main
import (
- "bufio"
+ "bytes"
"fmt"
"log"
"math/rand"
@@ -51,7 +51,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
@@ -71,24 +70,21 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) {
close(ch)
}()
- r := bufio.NewReader(os.Stdin)
- for {
- v := buffer.NewView(1024)
- n, err := r.Read(v)
- if err != nil {
- return
- }
-
- v.CapLength(n)
- for len(v) > 0 {
- n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
- if err != nil {
- fmt.Println("Write failed:", err)
- return
+ var b bytes.Buffer
+ if err := func() error {
+ for {
+ if _, err := b.ReadFrom(os.Stdin); err != nil {
+ return fmt.Errorf("b.ReadFrom failed: %w", err)
}
- v.TrimFront(int(n))
+ for b.Len() != 0 {
+ if _, err := ep.Write(&b, tcpip.WriteOptions{Atomic: true}); err != nil {
+ return fmt.Errorf("ep.Write failed: %s", err)
+ }
+ }
}
+ }(); err != nil {
+ fmt.Println(err)
}
}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index 3ac562756..ae9cf44e7 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -20,6 +20,7 @@
package main
import (
+ "bytes"
"flag"
"io"
"log"
@@ -58,7 +59,9 @@ func (e *tcpipError) Error() string {
}
func (e *endpointWriter) Write(p []byte) (int, error) {
- n, err := e.ep.Write(tcpip.SlicePayload(p), tcpip.WriteOptions{})
+ var r bytes.Reader
+ r.Reset(p)
+ n, err := e.ep.Write(&r, tcpip.WriteOptions{})
if err != nil {
return int(n), &tcpipError{
inner: err,
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 9d39533a1..dbf8b4db1 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "bytes"
"io"
"testing"
@@ -95,10 +96,11 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return 0, tcpip.ErrNoRoute
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, tcpip.ErrBadBuffer
}
+
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen,
Data: buffer.View(v).ToVectorisedView(),
@@ -520,8 +522,10 @@ func TestTransportSend(t *testing.T) {
}
// Create buffer that will hold the payload.
- view := buffer.NewView(30)
- if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ b := make([]byte, 30)
+ var r bytes.Reader
+ r.Reset(b)
+ if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("write failed: %v", err)
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 4f59e4ff7..fe01029ad 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -29,6 +29,7 @@
package tcpip
import (
+ "bytes"
"errors"
"fmt"
"io"
@@ -471,30 +472,15 @@ type FullAddress struct {
// This interface allows the endpoint to request the amount of data it needs
// based on internal buffers without exposing them.
type Payloader interface {
- // FullPayload returns all available bytes.
- FullPayload() ([]byte, *Error)
+ io.Reader
- // Payload returns a slice containing at most size bytes.
- Payload(size int) ([]byte, *Error)
+ // Len returns the number of bytes of the unread portion of the
+ // Reader.
+ Len() int
}
-// SlicePayload implements Payloader for slices.
-//
-// This is typically used for tests.
-type SlicePayload []byte
-
-// FullPayload implements Payloader.FullPayload.
-func (s SlicePayload) FullPayload() ([]byte, *Error) {
- return s, nil
-}
-
-// Payload implements Payloader.Payload.
-func (s SlicePayload) Payload(size int) ([]byte, *Error) {
- if size > len(s) {
- size = len(s)
- }
- return s[:size], nil
-}
+var _ Payloader = (*bytes.Buffer)(nil)
+var _ Payloader = (*bytes.Reader)(nil)
var _ io.Writer = (*SliceWriter)(nil)
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index ac9670f9a..aedf1845e 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -436,9 +436,10 @@ func TestForwarding(t *testing.T) {
write := func(ep tcpip.Endpoint, data []byte) {
t.Helper()
- dataPayload := tcpip.SlicePayload(data)
+ var r bytes.Reader
+ r.Reset(data)
var wOpts tcpip.WriteOptions
- n, err := ep.Write(dataPayload, wOpts)
+ n, err := ep.Write(&r, wOpts)
if err != nil {
t.Fatalf("ep.Write(_, %#v): %s", wOpts, err)
}
@@ -486,7 +487,7 @@ func TestForwarding(t *testing.T) {
read(serverCH, serverEP, data, clientAddr)
- data = tcpip.SlicePayload([]byte{5, 6, 7, 8, 9, 10, 11, 12})
+ data = []byte{5, 6, 7, 8, 9, 10, 11, 12}
write(serverEP, data)
read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr)
})
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 1e13fd6d6..d0d0e4ef2 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -209,8 +209,10 @@ func TestPing(t *testing.T) {
defer ep.Close()
icmpBuf := test.icmpBuf(t)
+ var r bytes.Reader
+ r.Reset(icmpBuf)
wOpts := tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: test.remoteAddr}}
- if n, err := ep.Write(tcpip.SlicePayload(icmpBuf), wOpts); err != nil {
+ if n, err := ep.Write(&r, wOpts); err != nil {
t.Fatalf("ep.Write(_, _): %s", err)
} else if want := int64(len(icmpBuf)); n != want {
t.Fatalf("got ep.Write(_, _) = (%d, _), want = (%d, _)", n, want)
@@ -360,9 +362,11 @@ func TestTCPLinkResolutionFailure(t *testing.T) {
// Wait for an error due to link resolution failing, or the endpoint to be
// writable.
<-ch
+ var r bytes.Reader
+ r.Reset([]byte{0})
var wOpts tcpip.WriteOptions
- if n, err := clientEP.Write(tcpip.SlicePayload(nil), wOpts); err != test.expectedWriteErr {
- t.Errorf("got clientEP.Write(nil, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr)
+ if n, err := clientEP.Write(&r, wOpts); err != test.expectedWriteErr {
+ t.Errorf("got clientEP.Write(_, %#v) = (%d, %s), want = (_, %s)", wOpts, n, err, test.expectedWriteErr)
}
if test.expectedWriteErr == nil {
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index 3b13ba04d..761283b66 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -232,7 +232,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
Port: localPort,
},
}
- n, err := sep.Write(tcpip.SlicePayload(data), wopts)
+ var r bytes.Reader
+ r.Reset(data)
+ n, err := sep.Write(&r, wopts)
if err != nil {
t.Fatalf("sep.Write(_, _): %s", err)
}
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index ce7c16bd1..9cc12fa58 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -586,8 +586,10 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
Port: localPort,
},
}
- data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4})
- if n, err := wep.ep.Write(data, writeOpts); err != nil {
+ data := []byte{byte(i), 2, 3, 4}
+ var r bytes.Reader
+ r.Reset(data)
+ if n, err := wep.ep.Write(&r, writeOpts); err != nil {
t.Fatalf("eps[%d].Write(_, _): %s", i, err)
} else if want := int64(len(data)); n != want {
t.Fatalf("got eps[%d].Write(_, _) = (%d, nil), want = (%d, nil)", i, n, want)
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index b222d2b05..35ee7437a 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -194,9 +194,11 @@ func TestLocalPing(t *testing.T) {
return
}
- payload := tcpip.SlicePayload(test.icmpBuf(t))
+ payload := test.icmpBuf(t)
+ var r bytes.Reader
+ r.Reset(payload)
var wOpts tcpip.WriteOptions
- if n, err := ep.Write(payload, wOpts); err != nil {
+ if n, err := ep.Write(&r, wOpts); err != nil {
t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err)
} else if n != int64(len(payload)) {
t.Fatalf("got ep.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", payload, wOpts, n, len(payload))
@@ -329,12 +331,14 @@ func TestLocalUDP(t *testing.T) {
Port: 80,
}
- clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ clientPayload := []byte{1, 2, 3, 4}
{
+ var r bytes.Reader
+ r.Reset(clientPayload)
wOpts := tcpip.WriteOptions{
To: &serverAddr,
}
- if n, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr {
+ if n, err := client.Write(&r, wOpts); err != subTest.expectedWriteErr {
t.Fatalf("got client.Write(%#v, %#v) = (%d, %s), want = (_, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr)
} else if subTest.expectedWriteErr != nil {
// Nothing else to test if we expected not to be able to send the
@@ -376,12 +380,14 @@ func TestLocalUDP(t *testing.T) {
}
}
- serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ serverPayload := []byte{1, 2, 3, 4}
{
+ var r bytes.Reader
+ r.Reset(serverPayload)
wOpts := tcpip.WriteOptions{
To: &clientAddr,
}
- if n, err := server.Write(serverPayload, wOpts); err != nil {
+ if n, err := server.Write(&r, wOpts); err != nil {
t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err)
} else if n != int64(len(serverPayload)) {
t.Fatalf("got server.Write(%#v, %#v) = (%d, nil), want = (%d, nil)", serverPayload, wOpts, n, len(serverPayload))
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 256e19296..af00ed548 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -313,11 +313,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
route = r
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, tcpip.ErrBadBuffer
}
+ var err *tcpip.Error
switch e.NetProto {
case header.IPv4ProtocolNumber:
err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner)
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index c0d6fb442..6fd116a98 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -207,7 +207,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul
return res, nil
}
-func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) {
+func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, *tcpip.Error) {
// TODO(gvisor.dev/issue/173): Implement.
return 0, tcpip.ErrInvalidOptionValue
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index ae743f75e..2dacf5a64 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -280,9 +280,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
return 0, tcpip.ErrInvalidEndpointState
}
- payloadBytes, err := p.FullPayload()
- if err != nil {
- return 0, err
+ payloadBytes := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return 0, tcpip.ErrBadBuffer
}
// If this is an unassociated socket and callee provided a nonzero
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index 7e81203ba..fcdd032c5 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -99,7 +99,6 @@ go_test(
"//pkg/rand",
"//pkg/sync",
"//pkg/tcpip",
- "//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 1d1b01a6c..809c88732 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -15,11 +15,11 @@
package tcp_test
import (
+ "strings"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
@@ -415,8 +415,10 @@ func testV4Accept(t *testing.T, c *context.Context) {
t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr)
}
+ var r strings.Reader
data := "Don't panic"
- nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ r.Reset(data)
+ nep.Write(&r, tcpip.WriteOptions{})
b = c.GetPacket()
tcp = header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ea509ac73..8d27d43c2 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1534,14 +1534,19 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
}
// Fetch data.
- v, perr := p.Payload(avail)
- if perr != nil || len(v) == 0 {
- // Note that perr may be nil if len(v) == 0.
+ if l := p.Len(); l < avail {
+ avail = l
+ }
+ if avail == 0 {
+ return 0, nil
+ }
+ v := make([]byte, avail)
+ if _, err := io.ReadFull(p, v); err != nil {
if opts.Atomic {
e.sndBufMu.Unlock()
e.UnlockUser()
}
- return 0, perr
+ return 0, tcpip.ErrBadBuffer
}
if !opts.Atomic {
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
index f7aaee23f..ced3a9c58 100644
--- a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -21,13 +21,13 @@
package tcp_test
import (
+ "bytes"
"fmt"
"math"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
@@ -42,14 +42,16 @@ func TestFastRecovery(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -207,14 +209,16 @@ func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -249,14 +253,16 @@ func TestCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -353,15 +359,16 @@ func TestCubicCongestionAvoidance(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
-
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -462,19 +469,20 @@ func TestRetransmit(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
const iterations = 3
- data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in two shots. Packets will only be written at the
// MTU size though.
- half := data[:len(data)/2]
- if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data[:len(data)/2])
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
- half = data[len(data)/2:]
- if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
+ r.Reset(data[len(data)/2:])
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_rack_test.go b/pkg/tcpip/transport/tcp/tcp_rack_test.go
index 342eb5eb8..af915203b 100644
--- a/pkg/tcpip/transport/tcp/tcp_rack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_rack_test.go
@@ -15,11 +15,11 @@
package tcp_test
import (
+ "bytes"
"testing"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -61,14 +61,16 @@ func TestRACKUpdate(t *testing.T) {
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(maxPayload)
+ data := make([]byte, maxPayload)
for i := range data {
data[i] = byte(i)
}
// Write the data.
xmitTime = time.Now()
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -114,13 +116,15 @@ func TestRACKDetectReorder(t *testing.T) {
})
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(ackNumToVerify * maxPayload)
+ data := make([]byte, ackNumToVerify*maxPayload)
for i := range data {
data[i] = byte(i)
}
// Write the data.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -141,17 +145,19 @@ func TestRACKDetectReorder(t *testing.T) {
<-probeDone
}
-func sendAndReceive(t *testing.T, c *context.Context, numPackets int) buffer.View {
+func sendAndReceive(t *testing.T, c *context.Context, numPackets int) []byte {
setStackSACKPermitted(t, c, true)
createConnectedWithSACKAndTS(c)
- data := buffer.NewView(numPackets * maxPayload)
+ data := make([]byte, numPackets*maxPayload)
for i := range data {
data[i] = byte(i)
}
// Write the data.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
index 6635bb815..5024bc925 100644
--- a/pkg/tcpip/transport/tcp/tcp_sack_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -15,6 +15,7 @@
package tcp_test
import (
+ "bytes"
"fmt"
"log"
"reflect"
@@ -22,7 +23,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/seqnum"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -395,14 +395,16 @@ func TestSACKRecovery(t *testing.T) {
createConnectedWithSACKAndTS(c)
const iterations = 3
- data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+ data := make([]byte, 2*maxPayload*(tcp.InitialCwnd<<(iterations+1)))
for i := range data {
data[i] = byte(i)
}
// Write all the data in one shot. Packets will only be written at the
// MTU size though.
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 93683b921..db9dc4e25 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -19,6 +19,7 @@ import (
"fmt"
"io/ioutil"
"math"
+ "strings"
"testing"
"time"
@@ -26,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
@@ -1347,10 +1347,9 @@ func TestTOSV4(t *testing.T) {
testV4Connect(t, c, checker.TOS(tos, 0))
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -1396,10 +1395,9 @@ func TestTrafficClassV6(t *testing.T) {
testV6Connect(t, c, checker.TOS(tos, 0))
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2176,10 +2174,9 @@ func TestSimpleSend(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2217,10 +2214,9 @@ func TestZeroWindowSend(t *testing.T) {
c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */)
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2285,10 +2281,9 @@ func TestScaledWindowConnect(t *testing.T) {
})
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2317,10 +2312,9 @@ func TestNonScaledWindowConnect(t *testing.T) {
c.CreateConnected(789, 30000, 65535*3)
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2391,10 +2385,9 @@ func TestScaledWindowAccept(t *testing.T) {
}
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2465,10 +2458,9 @@ func TestNonScaledWindowAccept(t *testing.T) {
}
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -2632,9 +2624,10 @@ func TestSegmentMerging(t *testing.T) {
// Send tcp.InitialCwnd number of segments to fill up
// InitialWindow but don't ACK. That should prevent
// anymore packets from going out.
+ var r bytes.Reader
for i := 0; i < tcp.InitialCwnd; i++ {
- view := buffer.NewViewFromBytes([]byte{0})
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset([]byte{0})
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2644,8 +2637,8 @@ func TestSegmentMerging(t *testing.T) {
var allData []byte
for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2714,8 +2707,9 @@ func TestDelay(t *testing.T) {
var allData []byte
for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
allData = append(allData, data...)
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2761,8 +2755,9 @@ func TestUndelay(t *testing.T) {
allData := [][]byte{{0}, {1, 2, 3}}
for i, data := range allData {
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2845,8 +2840,9 @@ func TestMSSNotDelayed(t *testing.T) {
allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)}
for i, data := range allData {
- view := buffer.NewViewFromBytes(data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write #%d failed: %s", i+1, err)
}
}
@@ -2894,10 +2890,9 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
data[i] = byte(i)
}
- view := buffer.NewView(len(data))
- copy(view, data)
-
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3328,8 +3323,9 @@ func TestSendOnResetConnection(t *testing.T) {
time.Sleep(1 * time.Second)
// Try to write.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
+ var r bytes.Reader
+ r.Reset(make([]byte, 10))
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset {
t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
@@ -3352,7 +3348,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) {
c.WQ.EventRegister(&waitEntry, waiter.EventHUp)
defer c.WQ.EventUnregister(&waitEntry)
- _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ var r bytes.Reader
+ r.Reset(make([]byte, 1))
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
if err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3409,7 +3407,9 @@ func TestMaxRTO(t *testing.T) {
c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
- _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ var r bytes.Reader
+ r.Reset(make([]byte, 1))
+ _, err := c.EP.Write(&r, tcpip.WriteOptions{})
if err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3458,7 +3458,9 @@ func TestRetransmitIPv4IDUniqueness(t *testing.T) {
t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
}
- if _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(make([]byte, tc.size))
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
pkt := c.GetPacket()
@@ -3595,8 +3597,10 @@ func TestFinWithNoPendingData(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and have it acknowledged.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 10)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3667,9 +3671,11 @@ func TestFinWithPendingDataCwndFull(t *testing.T) {
// Write enough segments to fill the congestion window before ACK'ing
// any of them.
- view := buffer.NewView(10)
+ view := make([]byte, 10)
+ var r bytes.Reader
for i := tcp.InitialCwnd; i > 0; i-- {
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
}
@@ -3754,8 +3760,10 @@ func TestFinWithPendingData(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
// Write something out, and acknowledge it to get cwnd to 2.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 10)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3781,7 +3789,8 @@ func TestFinWithPendingData(t *testing.T) {
})
// Write new data, but don't acknowledge it.
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3841,8 +3850,10 @@ func TestFinWithPartialAck(t *testing.T) {
// Write something out, and acknowledge it to get cwnd to 2. Also send
// FIN from the test side.
- view := buffer.NewView(10)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 10)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3879,7 +3890,8 @@ func TestFinWithPartialAck(t *testing.T) {
)
// Write new data, but don't acknowledge it.
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -3985,8 +3997,10 @@ func scaledSendWindow(t *testing.T, scale uint8) {
})
// Send some data. Check that it's capped by the window size.
- view := buffer.NewView(65535)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 65535)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -4610,9 +4624,9 @@ func TestSelfConnect(t *testing.T) {
// Write something.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -4785,12 +4799,13 @@ func TestPathMTUDiscovery(t *testing.T) {
// Send 3200 bytes of data.
const writeSize = 3200
- data := buffer.NewView(writeSize)
+ data := make([]byte, writeSize)
for i := range data {
data[i] = byte(i)
}
-
- if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -5078,8 +5093,10 @@ func TestKeepalive(t *testing.T) {
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
- view := buffer.NewView(3)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 3)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -5358,7 +5375,9 @@ func TestListenBacklogFull(t *testing.T) {
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
- newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ var r strings.Reader
+ r.Reset(data)
+ newEP.Write(&r, tcpip.WriteOptions{})
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
if string(tcp.Payload()) != data {
@@ -5674,7 +5693,9 @@ func TestListenSynRcvdQueueFull(t *testing.T) {
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
- newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ var r strings.Reader
+ r.Reset(data)
+ newEP.Write(&r, tcpip.WriteOptions{})
pkt := c.GetPacket()
tcp = header.TCP(header.IPv4(pkt).Payload())
if string(tcp.Payload()) != data {
@@ -5908,7 +5929,9 @@ func TestSynRcvdBadSeqNumber(t *testing.T) {
// Now verify that the TCP socket is usable and in a connected state.
data := "Don't panic"
- if _, err := newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{}); err != nil {
+ var r strings.Reader
+ r.Reset(data)
+ if _, err := newEP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -7103,10 +7126,10 @@ func TestTCPCloseWithData(t *testing.T) {
// Now write a few bytes and then close the endpoint.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
@@ -7204,8 +7227,10 @@ func TestTCPUserTimeout(t *testing.T) {
}
// Send some data and wait before ACKing it.
- view := buffer.NewView(3)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ view := make([]byte, 3)
+ var r bytes.Reader
+ r.Reset(view)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Write failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index b65091c3c..5a9745ad7 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -22,7 +22,6 @@ import (
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
@@ -152,10 +151,10 @@ func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndS
// Now send some data and validate that timestamp is echoed correctly in the ACK.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %s", err)
}
@@ -215,10 +214,10 @@ func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wnd
// Now send some data with the accepted connection endpoint and validate
// that no timestamp option is sent in the TCP segment.
data := []byte{1, 2, 3}
- view := buffer.NewView(len(data))
- copy(view, data)
- if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ var r bytes.Reader
+ r.Reset(data)
+ if _, err := c.EP.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %s", err)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 9f9b3d510..8544fcb08 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -514,9 +514,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc
return 0, tcpip.ErrBroadcastDisabled
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, tcpip.ErrBadBuffer
}
if len(v) > header.UDPMaximumPacketSize {
// Payload can't possibly fit in a packet.
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4e2123fe9..c4794e876 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -966,8 +966,9 @@ func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
h := flow.header4Tuple(outgoing)
writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
- payload := buffer.View(newPayload())
- _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+ var r bytes.Reader
+ r.Reset(newPayload())
+ _, gotErr := c.ep.Write(&r, tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
})
c.checkEndpointWriteStats(1, epstats, gotErr)
@@ -1007,8 +1008,10 @@ func testWriteNoVerify(c *testContext, flow testFlow, setDest bool) buffer.View
To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
}
}
- payload := buffer.View(newPayload())
- n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ var r bytes.Reader
+ payload := newPayload()
+ r.Reset(payload)
+ n, err := c.ep.Write(&r, writeOpts)
if err != nil {
c.t.Fatalf("Write failed: %s", err)
}
@@ -1183,8 +1186,10 @@ func TestWriteOnConnectedInvalidPort(t *testing.T) {
writeOpts := tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: stackAddr, Port: invalidPort},
}
- payload := buffer.View(newPayload())
- n, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+ var r bytes.Reader
+ payload := newPayload()
+ r.Reset(payload)
+ n, err := c.ep.Write(&r, writeOpts)
if err != nil {
c.t.Fatalf("c.ep.Write(...) = %+s, want nil", err)
}
@@ -2497,7 +2502,8 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}
defer ep.Close()
- data := tcpip.SlicePayload([]byte{1, 2, 3, 4})
+ var r bytes.Reader
+ data := []byte{1, 2, 3, 4}
to := tcpip.FullAddress{
Addr: test.remoteAddr,
Port: 80,
@@ -2508,19 +2514,22 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
expectedErrWithoutBcastOpt = nil
}
- if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ r.Reset(data)
+ if n, err := ep.Write(&r, opts); err != expectedErrWithoutBcastOpt {
t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt)
}
ep.SocketOptions().SetBroadcast(true)
- if n, err := ep.Write(data, opts); err != nil {
+ r.Reset(data)
+ if n, err := ep.Write(&r, opts); err != nil {
t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, nil)", opts, n, err)
}
ep.SocketOptions().SetBroadcast(false)
- if n, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt {
+ r.Reset(data)
+ if n, err := ep.Write(&r, opts); err != expectedErrWithoutBcastOpt {
t.Fatalf("got ep.Write(_, %#v) = (%d, %s), want = (_, %s)", opts, n, err, expectedErrWithoutBcastOpt)
}
})
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 79db8895b..dc2571154 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -517,28 +517,29 @@ func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, er
// Reader returns an io.Reader that reads from s. Reads beyond the end of s
// return io.EOF. The preconditions that apply to s.CopyIn also apply to the
// returned io.Reader.Read.
-func (s IOSequence) Reader(ctx context.Context) io.Reader {
- return &ioSequenceReadWriter{ctx, s}
+func (s IOSequence) Reader(ctx context.Context) *IOSequenceReadWriter {
+ return &IOSequenceReadWriter{ctx, s}
}
// Writer returns an io.Writer that writes to s. Writes beyond the end of s
// return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also
// apply to the returned io.Writer.Write.
-func (s IOSequence) Writer(ctx context.Context) io.Writer {
- return &ioSequenceReadWriter{ctx, s}
+func (s IOSequence) Writer(ctx context.Context) *IOSequenceReadWriter {
+ return &IOSequenceReadWriter{ctx, s}
}
// ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when
// attempting to write beyond the end of the IOSequence.
var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence")
-type ioSequenceReadWriter struct {
+// IOSequenceReadWriter implements io.Reader and io.Writer for an IOSequence.
+type IOSequenceReadWriter struct {
ctx context.Context
s IOSequence
}
// Read implements io.Reader.Read.
-func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) {
+func (rw *IOSequenceReadWriter) Read(dst []byte) (int, error) {
n, err := rw.s.CopyIn(rw.ctx, dst)
rw.s = rw.s.DropFirst(n)
if err == nil && rw.s.NumBytes() == 0 {
@@ -547,8 +548,13 @@ func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) {
return n, err
}
+// Len implements tcpip.Payloader.
+func (rw *IOSequenceReadWriter) Len() int {
+ return int(rw.s.NumBytes())
+}
+
// Write implements io.Writer.Write.
-func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) {
+func (rw *IOSequenceReadWriter) Write(src []byte) (int, error) {
n, err := rw.s.CopyOut(rw.ctx, src)
rw.s = rw.s.DropFirst(n)
if err == nil && n < len(src) {