summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/socket/netstack/netstack.go17
-rw-r--r--pkg/tcpip/BUILD1
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go4
-rw-r--r--pkg/tcpip/buffer/BUILD7
-rw-r--r--pkg/tcpip/buffer/view.go14
-rw-r--r--pkg/tcpip/buffer/view_test.go127
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go5
-rw-r--r--pkg/tcpip/network/ipv6/ipv6_test.go12
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go3
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go37
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go2
-rw-r--r--pkg/tcpip/stack/transport_test.go2
-rw-r--r--pkg/tcpip/tcpip.go32
-rw-r--r--pkg/tcpip/tcpip_test.go34
-rw-r--r--pkg/tcpip/tests/integration/forward_test.go2
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go2
-rw-r--r--pkg/tcpip/tests/integration/loopback_test.go4
-rw-r--r--pkg/tcpip/tests/integration/multicast_broadcast_test.go10
-rw-r--r--pkg/tcpip/tests/integration/route_test.go9
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go4
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go4
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go65
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go13
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go6
27 files changed, 265 insertions, 165 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 03749a8bf..22e128b96 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -425,8 +425,13 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
s.readMu.Lock()
defer s.readMu.Unlock()
+ w := tcpip.LimitedWriter{
+ W: dst,
+ N: count,
+ }
+
// This may return a blocking error.
- res, err := s.Endpoint.Read(dst, int(count), tcpip.ReadOptions{
+ res, err := s.Endpoint.Read(&w, tcpip.ReadOptions{
Peek: dup,
})
if err != nil {
@@ -2579,7 +2584,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// caller-supplied buffer.
var w io.Writer
if !isPacket && trunc {
- w = ioutil.Discard
+ w = &tcpip.LimitedWriter{
+ W: ioutil.Discard,
+ N: dst.NumBytes(),
+ }
} else {
w = dst.Writer(ctx)
}
@@ -2587,7 +2595,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
s.readMu.Lock()
defer s.readMu.Unlock()
- res, err := s.Endpoint.Read(w, int(dst.NumBytes()), readOptions)
+ res, err := s.Endpoint.Read(w, readOptions)
+ if err == tcpip.ErrBadBuffer && dst.NumBytes() == 0 {
+ err = nil
+ }
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
index 89b765f1b..e7924e5c2 100644
--- a/pkg/tcpip/BUILD
+++ b/pkg/tcpip/BUILD
@@ -37,6 +37,7 @@ go_test(
size = "small",
srcs = ["tcpip_test.go"],
library = ":tcpip",
+ deps = ["@com_github_google_go_cmp//cmp:go_default_library"],
)
go_test(
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 85a0b8b90..fdeec12d3 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -295,7 +295,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s
w := tcpip.SliceWriter(b)
opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
- res, err := ep.Read(&w, len(b), opts)
+ res, err := ep.Read(&w, opts)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
@@ -303,7 +303,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
- res, err = ep.Read(&w, len(b), opts)
+ res, err = ep.Read(&w, opts)
if err != tcpip.ErrWouldBlock {
break
}
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
index c326fab54..c9bcf9326 100644
--- a/pkg/tcpip/buffer/BUILD
+++ b/pkg/tcpip/buffer/BUILD
@@ -12,10 +12,13 @@ go_library(
)
go_test(
- name = "buffer_test",
+ name = "buffer_x_test",
size = "small",
srcs = [
"view_test.go",
],
- library = ":buffer",
+ deps = [
+ ":buffer",
+ "//pkg/tcpip",
+ ],
)
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 09d3dac66..91cc62cc8 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -148,23 +148,13 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int
// ReadTo reads up to count bytes from vv to dst. It also removes them from vv
// unless peek is true.
-func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) {
+func (vv *VectorisedView) ReadTo(dst io.Writer, peek bool) (int, error) {
var err error
done := 0
for _, v := range vv.Views() {
- remaining := count - done
- if remaining <= 0 {
- break
- }
- if len(v) > remaining {
- v = v[:remaining]
- }
-
var n int
n, err = dst.Write(v)
- if n > 0 {
- done += n
- }
+ done += n
if err != nil {
break
}
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
index e0ef8a94d..e7f7cc9f1 100644
--- a/pkg/tcpip/buffer/view_test.go
+++ b/pkg/tcpip/buffer/view_test.go
@@ -12,42 +12,43 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package buffer_test contains tests for the VectorisedView type.
-package buffer
+// Package buffer_test contains tests for the buffer.VectorisedView type.
+package buffer_test
import (
"bytes"
+ "io"
"reflect"
"testing"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
)
// copy returns a deep-copy of the vectorised view.
-func (vv VectorisedView) copy() VectorisedView {
- uu := VectorisedView{
- views: make([]View, 0, len(vv.views)),
- size: vv.size,
- }
- for _, v := range vv.views {
- uu.views = append(uu.views, append(View(nil), v...))
+func copyVV(vv buffer.VectorisedView) buffer.VectorisedView {
+ views := make([]buffer.View, 0, len(vv.Views()))
+ for _, v := range vv.Views() {
+ views = append(views, append(buffer.View(nil), v...))
}
- return uu
+ return buffer.NewVectorisedView(vv.Size(), views)
}
-// vv is an helper to build VectorisedView from different strings.
-func vv(size int, pieces ...string) VectorisedView {
- views := make([]View, len(pieces))
+// vv is an helper to build buffer.VectorisedView from different strings.
+func vv(size int, pieces ...string) buffer.VectorisedView {
+ views := make([]buffer.View, len(pieces))
for i, p := range pieces {
views[i] = []byte(p)
}
- return NewVectorisedView(size, views)
+ return buffer.NewVectorisedView(size, views)
}
var capLengthTestCases = []struct {
comment string
- in VectorisedView
+ in buffer.VectorisedView
length int
- want VectorisedView
+ want buffer.VectorisedView
}{
{
comment: "Simple case",
@@ -89,7 +90,7 @@ var capLengthTestCases = []struct {
func TestCapLength(t *testing.T) {
for _, c := range capLengthTestCases {
- orig := c.in.copy()
+ orig := copyVV(c.in)
c.in.CapLength(c.length)
if !reflect.DeepEqual(c.in, c.want) {
t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v",
@@ -100,9 +101,9 @@ func TestCapLength(t *testing.T) {
var trimFrontTestCases = []struct {
comment string
- in VectorisedView
+ in buffer.VectorisedView
count int
- want VectorisedView
+ want buffer.VectorisedView
}{
{
comment: "Simple case",
@@ -150,7 +151,7 @@ var trimFrontTestCases = []struct {
func TestTrimFront(t *testing.T) {
for _, c := range trimFrontTestCases {
- orig := c.in.copy()
+ orig := copyVV(c.in)
c.in.TrimFront(c.count)
if !reflect.DeepEqual(c.in, c.want) {
t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v",
@@ -161,8 +162,8 @@ func TestTrimFront(t *testing.T) {
var toViewCases = []struct {
comment string
- in VectorisedView
- want View
+ in buffer.VectorisedView
+ want buffer.View
}{
{
comment: "Simple case",
@@ -193,28 +194,28 @@ func TestToView(t *testing.T) {
var toCloneCases = []struct {
comment string
- inView VectorisedView
- inBuffer []View
+ inView buffer.VectorisedView
+ inBuffer []buffer.View
}{
{
comment: "Simple case",
inView: vv(1, "1"),
- inBuffer: make([]View, 1),
+ inBuffer: make([]buffer.View, 1),
},
{
comment: "Case with multiple views",
inView: vv(2, "1", "2"),
- inBuffer: make([]View, 2),
+ inBuffer: make([]buffer.View, 2),
},
{
comment: "Case with buffer too small",
inView: vv(2, "1", "2"),
- inBuffer: make([]View, 1),
+ inBuffer: make([]buffer.View, 1),
},
{
comment: "Case with buffer larger than needed",
inView: vv(1, "1"),
- inBuffer: make([]View, 2),
+ inBuffer: make([]buffer.View, 2),
},
{
comment: "Case with nil buffer",
@@ -237,10 +238,10 @@ func TestToClone(t *testing.T) {
type readToTestCases struct {
comment string
- vv VectorisedView
+ vv buffer.VectorisedView
bytesToRead int
wantBytes string
- leftVV VectorisedView
+ leftVV buffer.VectorisedView
}
func createReadToTestCases() []readToTestCases {
@@ -286,7 +287,7 @@ func createReadToTestCases() []readToTestCases {
func TestVVReadToVV(t *testing.T) {
for _, tc := range createReadToTestCases() {
t.Run(tc.comment, func(t *testing.T) {
- var readTo VectorisedView
+ var readTo buffer.VectorisedView
inSize := tc.vv.Size()
copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead)
if got, want := copied, len(tc.wantBytes); got != want {
@@ -308,13 +309,17 @@ func TestVVReadToVV(t *testing.T) {
func TestVVReadTo(t *testing.T) {
for _, tc := range createReadToTestCases() {
t.Run(tc.comment, func(t *testing.T) {
- var dst bytes.Buffer
+ b := make([]byte, tc.bytesToRead)
+ dst := tcpip.SliceWriter(b)
origSize := tc.vv.Size()
- copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, false /* peek */)
- if got, want := copied, len(tc.wantBytes); err != nil || got != want {
- t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want)
+ copied, err := tc.vv.ReadTo(&dst, false /* peek */)
+ if err != nil && err != io.ErrShortWrite {
+ t.Errorf("got ReadTo(&dst, false) = (_, %s); want nil or io.ErrShortWrite", err)
+ }
+ if got, want := copied, len(tc.wantBytes); got != want {
+ t.Errorf("got ReadTo(&dst, false) = (%d, _); want %d", got, want)
}
- if got, want := string(dst.Bytes()), tc.wantBytes; got != want {
+ if got, want := string(b[:copied]), tc.wantBytes; got != want {
t.Errorf("got dst = %q, want %q", got, want)
}
if got, want := tc.vv.Size(), origSize-copied; got != want {
@@ -330,14 +335,18 @@ func TestVVReadTo(t *testing.T) {
func TestVVReadToPeek(t *testing.T) {
for _, tc := range createReadToTestCases() {
t.Run(tc.comment, func(t *testing.T) {
- var dst bytes.Buffer
+ b := make([]byte, tc.bytesToRead)
+ dst := tcpip.SliceWriter(b)
origSize := tc.vv.Size()
origData := string(tc.vv.ToView())
- copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, true /* peek */)
- if got, want := copied, len(tc.wantBytes); err != nil || got != want {
- t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want)
+ copied, err := tc.vv.ReadTo(&dst, true /* peek */)
+ if err != nil && err != io.ErrShortWrite {
+ t.Errorf("got ReadTo(&dst, true) = (_, %s); want nil or io.ErrShortWrite", err)
+ }
+ if got, want := copied, len(tc.wantBytes); got != want {
+ t.Errorf("got ReadTo(&dst, true) = (%d, _); want %d", got, want)
}
- if got, want := string(dst.Bytes()), tc.wantBytes; got != want {
+ if got, want := string(b[:copied]), tc.wantBytes; got != want {
t.Errorf("got dst = %q, want %q", got, want)
}
// Expect tc.vv is unchanged.
@@ -354,7 +363,7 @@ func TestVVReadToPeek(t *testing.T) {
func TestVVRead(t *testing.T) {
testCases := []struct {
comment string
- vv VectorisedView
+ vv buffer.VectorisedView
bytesToRead int
readBytes string
leftBytes string
@@ -399,7 +408,7 @@ func TestVVRead(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.comment, func(t *testing.T) {
- readTo := NewView(tc.bytesToRead)
+ readTo := buffer.NewView(tc.bytesToRead)
inSize := tc.vv.Size()
copied, err := tc.vv.Read(readTo)
if !tc.wantError && err != nil {
@@ -424,10 +433,10 @@ func TestVVRead(t *testing.T) {
var pullUpTestCases = []struct {
comment string
- in VectorisedView
+ in buffer.VectorisedView
count int
want []byte
- result VectorisedView
+ result buffer.VectorisedView
ok bool
}{
{
@@ -521,7 +530,7 @@ func TestPullUp(t *testing.T) {
t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t",
c.comment, c.count, c.in, ok, c.ok)
}
- if bytes.Compare(got, View(c.want)) != 0 {
+ if bytes.Compare(got, buffer.View(c.want)) != 0 {
t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v",
c.comment, c.count, c.in, got, c.want)
}
@@ -536,12 +545,12 @@ func TestPullUp(t *testing.T) {
func TestToVectorisedView(t *testing.T) {
testCases := []struct {
- in View
- want VectorisedView
+ in buffer.View
+ want buffer.VectorisedView
}{
- {nil, VectorisedView{}},
- {View{}, VectorisedView{}},
- {View{'a'}, VectorisedView{size: 1, views: []View{{'a'}}}},
+ {nil, buffer.VectorisedView{}},
+ {buffer.View{}, buffer.VectorisedView{}},
+ {buffer.View{'a'}, buffer.NewVectorisedView(1, []buffer.View{{'a'}})},
}
for _, tc := range testCases {
if got, want := tc.in.ToVectorisedView(), tc.want; !reflect.DeepEqual(got, want) {
@@ -552,15 +561,15 @@ func TestToVectorisedView(t *testing.T) {
func TestAppendView(t *testing.T) {
testCases := []struct {
- vv VectorisedView
- in View
- want VectorisedView
+ vv buffer.VectorisedView
+ in buffer.View
+ want buffer.VectorisedView
}{
- {VectorisedView{}, nil, VectorisedView{}},
- {VectorisedView{}, View{}, VectorisedView{}},
- {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, nil, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}},
- {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}},
- {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{'e'}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}, {'e'}}, 5}},
+ {buffer.VectorisedView{}, nil, buffer.VectorisedView{}},
+ {buffer.VectorisedView{}, buffer.View{}, buffer.VectorisedView{}},
+ {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), nil, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})},
+ {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{}, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})},
+ {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{'e'}, buffer.NewVectorisedView(5, []buffer.View{{'a', 'b', 'c', 'd'}, {'e'}})},
}
for _, tc := range testCases {
tc.vv.AppendView(tc.in)
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 1c4919b1e..a9e137c24 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -2410,10 +2410,9 @@ func TestReceiveFragments(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
}
- const rcvSize = 65536 // Account for reassembled packets.
for i, expectedPayload := range test.expectedPayloads {
var buf bytes.Buffer
- result, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{})
+ result, err := ep.Read(&buf, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("(i=%d) Read: %s", i, err)
}
@@ -2428,7 +2427,7 @@ func TestReceiveFragments(t *testing.T) {
}
}
- if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go
index 360025b20..b65c9d060 100644
--- a/pkg/tcpip/network/ipv6/ipv6_test.go
+++ b/pkg/tcpip/network/ipv6/ipv6_test.go
@@ -846,14 +846,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
},
}
- const mtu = header.IPv6MinimumMTU
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
})
- e := channel.New(1, mtu, linkAddr1)
+ e := channel.New(1, header.IPv6MinimumMTU, linkAddr1)
if err := s.CreateNIC(nicID, e); err != nil {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
@@ -983,7 +982,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = 1", got)
}
var buf bytes.Buffer
- result, err := ep.Read(&buf, mtu, tcpip.ReadOptions{})
+ result, err := ep.Read(&buf, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read: %s", err)
}
@@ -998,7 +997,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) {
}
// Should not have any more UDP packets.
- if res, err := ep.Read(ioutil.Discard, mtu, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
@@ -1979,10 +1978,9 @@ func TestReceiveIPv6Fragments(t *testing.T) {
t.Errorf("got UDP Rx Packets = %d, want = %d", got, want)
}
- const rcvSize = 65536 // Account for reassembled packets.
for i, p := range test.expectedPayloads {
var buf bytes.Buffer
- _, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{})
+ _, err := ep.Read(&buf, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("(i=%d) Read: %s", i, err)
}
@@ -1991,7 +1989,7 @@ func TestReceiveIPv6Fragments(t *testing.T) {
}
}
- if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock)
}
})
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index a7da9dcd9..3b4f900e3 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -44,7 +44,6 @@ import (
"bufio"
"fmt"
"log"
- "math"
"math/rand"
"net"
"os"
@@ -201,7 +200,7 @@ func main() {
// connection from its side.
wq.EventRegister(&waitEntry, waiter.EventIn)
for {
- _, err := ep.Read(os.Stdout, math.MaxUint16, tcpip.ReadOptions{})
+ _, err := ep.Read(os.Stdout, tcpip.ReadOptions{})
if err != nil {
if err == tcpip.ErrClosedForReceive {
break
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
index a80fa0474..3ac562756 100644
--- a/pkg/tcpip/sample/tun_tcp_echo/main.go
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -20,10 +20,9 @@
package main
import (
- "bytes"
"flag"
+ "io"
"log"
- "math"
"math/rand"
"net"
"os"
@@ -46,6 +45,31 @@ import (
var tap = flag.Bool("tap", false, "use tap istead of tun")
var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
+type endpointWriter struct {
+ ep tcpip.Endpoint
+}
+
+type tcpipError struct {
+ inner *tcpip.Error
+}
+
+func (e *tcpipError) Error() string {
+ return e.inner.String()
+}
+
+func (e *endpointWriter) Write(p []byte) (int, error) {
+ n, err := e.ep.Write(tcpip.SlicePayload(p), tcpip.WriteOptions{})
+ if err != nil {
+ return int(n), &tcpipError{
+ inner: err,
+ }
+ }
+ if n != int64(len(p)) {
+ return int(n), io.ErrShortWrite
+ }
+ return int(n), nil
+}
+
func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
defer ep.Close()
@@ -55,9 +79,12 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
+ w := endpointWriter{
+ ep: ep,
+ }
+
for {
- var buf bytes.Buffer
- _, err := ep.Read(&buf, math.MaxUint16, tcpip.ReadOptions{})
+ _, err := ep.Read(&w, tcpip.ReadOptions{})
if err != nil {
if err == tcpip.ErrWouldBlock {
<-notifyCh
@@ -66,8 +93,6 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
return
}
-
- ep.Write(tcpip.SlicePayload(buf.Bytes()), tcpip.WriteOptions{})
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 859278f0b..57e1f8354 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -352,7 +352,7 @@ func TestBindToDeviceDistribution(t *testing.T) {
}
ep := <-pollChannel
- if _, err := ep.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
+ if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil {
t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err)
}
stats[ep]++
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index a2ab7537c..9d39533a1 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -86,7 +86,7 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
return mask
}
-func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
return tcpip.ReadResult{}, nil
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 49d4912ad..56aac093c 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -505,10 +505,34 @@ type SliceWriter []byte
func (s *SliceWriter) Write(b []byte) (int, error) {
n := copy(*s, b)
*s = (*s)[n:]
- if n < len(b) {
- return n, io.ErrShortWrite
+ var err error
+ if n != len(b) {
+ err = io.ErrShortWrite
}
- return n, nil
+ return n, err
+}
+
+var _ io.Writer = (*LimitedWriter)(nil)
+
+// A LimitedWriter writes to W but limits the amount of data copied to just N
+// bytes. Each call to Write updates N to reflect the new amount remaining.
+type LimitedWriter struct {
+ W io.Writer
+ N int64
+}
+
+func (l *LimitedWriter) Write(p []byte) (int, error) {
+ pLen := int64(len(p))
+ if pLen > l.N {
+ p = p[:l.N]
+ }
+ n, err := l.W.Write(p)
+ n64 := int64(n)
+ if err == nil && n64 != pLen {
+ err = io.ErrShortWrite
+ }
+ l.N -= n64
+ return n, err
}
// A ControlMessages contains socket control messages for IP sockets.
@@ -623,7 +647,7 @@ type Endpoint interface {
// If non-zero number of bytes are successfully read and written to dst, err
// must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer
// should be returned.
- Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error)
+ Read(dst io.Writer, opts ReadOptions) (res ReadResult, err *Error)
// Write writes data to the endpoint's peer. This method does not block if
// the data cannot be written.
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
index 9bd563c46..269081ff8 100644
--- a/pkg/tcpip/tcpip_test.go
+++ b/pkg/tcpip/tcpip_test.go
@@ -15,12 +15,46 @@
package tcpip
import (
+ "bytes"
"fmt"
+ "io"
"net"
"strings"
"testing"
+
+ "github.com/google/go-cmp/cmp"
)
+func TestLimitedWriter_Write(t *testing.T) {
+ var b bytes.Buffer
+ l := LimitedWriter{
+ W: &b,
+ N: 5,
+ }
+ if n, err := l.Write([]byte{0, 1, 2}); err != nil {
+ t.Errorf("got l.Write(3/5) = (_, %s), want nil", err)
+ } else if n != 3 {
+ t.Errorf("got l.Write(3/5) = (%d, _), want 3", n)
+ }
+ if n, err := l.Write([]byte{3, 4, 5}); err != io.ErrShortWrite {
+ t.Errorf("got l.Write(3/2) = (_, %s), want io.ErrShortWrite", err)
+ } else if n != 2 {
+ t.Errorf("got l.Write(3/2) = (%d, _), want 2", n)
+ }
+ if l.N != 0 {
+ t.Errorf("got l.N = %d, want 0", l.N)
+ }
+ l.N = 1
+ if n, err := l.Write([]byte{5}); err != nil {
+ t.Errorf("got l.Write(1/1) = (_, %s), want nil", err)
+ } else if n != 1 {
+ t.Errorf("got l.Write(1/1) = (%d, _), want 1", n)
+ }
+ if diff := cmp.Diff(b.Bytes(), []byte{0, 1, 2, 3, 4, 5}); diff != "" {
+ t.Errorf("%T wrote incorrect data: (-want +got):\n%s", l, diff)
+ }
+}
+
func TestSubnetContains(t *testing.T) {
tests := []struct {
s Address
diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go
index 49acd504e..ac9670f9a 100644
--- a/pkg/tcpip/tests/integration/forward_test.go
+++ b/pkg/tcpip/tests/integration/forward_test.go
@@ -457,7 +457,7 @@ func TestForwarding(t *testing.T) {
<-ch
var buf bytes.Buffer
opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr}
- res, err := ep.Read(&buf, len(data), opts)
+ res, err := ep.Read(&buf, opts)
if err != nil {
t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index ed00c90d4..3f06c2145 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -218,7 +218,7 @@ func TestPing(t *testing.T) {
var buf bytes.Buffer
opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- res, err := ep.Read(&buf, len(icmpBuf), opts)
+ res, err := ep.Read(&buf, opts)
if err != nil {
t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err)
}
diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go
index a59f25cc3..3b13ba04d 100644
--- a/pkg/tcpip/tests/integration/loopback_test.go
+++ b/pkg/tcpip/tests/integration/loopback_test.go
@@ -242,9 +242,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) {
var buf bytes.Buffer
opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- if res, err := rep.Read(&buf, len(data), opts); test.expectRx {
+ if res, err := rep.Read(&buf, opts); test.expectRx {
if err != nil {
- t.Fatalf("rep.Read(_, %d, %#v): %s", len(data), opts, err)
+ t.Fatalf("rep.Read(_, %#v): %s", opts, err)
}
if diff := cmp.Diff(tcpip.ReadResult{
Count: buf.Len(),
diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
index eabc87938..ce7c16bd1 100644
--- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go
+++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go
@@ -466,9 +466,9 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) {
test.rxUDP(e, test.remoteAddr, test.dstAddr, data)
var buf bytes.Buffer
var opts tcpip.ReadOptions
- if res, err := ep.Read(&buf, len(data), opts); test.expectRx {
+ if res, err := ep.Read(&buf, opts); test.expectRx {
if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err)
+ t.Fatalf("ep.Read(_, %#v): %s", opts, err)
}
if diff := cmp.Diff(tcpip.ReadResult{
Count: buf.Len(),
@@ -598,7 +598,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) {
<-rep.ch
var buf bytes.Buffer
- result, err := rep.ep.Read(&buf, len(data), tcpip.ReadOptions{})
+ result, err := rep.ep.Read(&buf, tcpip.ReadOptions{})
if err != nil {
t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err)
continue
@@ -738,7 +738,7 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
}
test.rxUDP(e, test.remoteAddr, test.multicastAddr, data)
var buf bytes.Buffer
- result, err := ep.Read(&buf, len(data), tcpip.ReadOptions{})
+ result, err := ep.Read(&buf, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("ep.Read: %s", err)
} else {
@@ -759,7 +759,7 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) {
if err := ep.SetSockOpt(&removeOpt); err != nil {
t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err)
}
- if _, err := ep.Read(&buf, 1, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
+ if _, err := ep.Read(&buf, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock {
t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock)
}
})
diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go
index 76f7f54c6..b222d2b05 100644
--- a/pkg/tcpip/tests/integration/route_test.go
+++ b/pkg/tcpip/tests/integration/route_test.go
@@ -16,7 +16,6 @@ package integration_test
import (
"bytes"
- "math"
"testing"
"github.com/google/go-cmp/cmp"
@@ -208,9 +207,9 @@ func TestLocalPing(t *testing.T) {
var buf bytes.Buffer
opts := tcpip.ReadOptions{NeedRemoteAddr: true}
- res, err := ep.Read(&buf, math.MaxUint16, opts)
+ res, err := ep.Read(&buf, opts)
if err != nil {
- t.Fatalf("ep.Read(_, %d, %#v): %s", math.MaxUint16, opts, err)
+ t.Fatalf("ep.Read(_, %#v): %s", opts, err)
}
if diff := cmp.Diff(tcpip.ReadResult{
Count: buf.Len(),
@@ -351,7 +350,7 @@ func TestLocalUDP(t *testing.T) {
var clientAddr tcpip.FullAddress
var readBuf bytes.Buffer
- if read, err := server.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
+ if read, err := server.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
t.Fatalf("server.Read(_): %s", err)
} else {
clientAddr = read.RemoteAddr
@@ -393,7 +392,7 @@ func TestLocalUDP(t *testing.T) {
<-clientCH
readBuf.Reset()
- if read, err := client.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
+ if read, err := client.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil {
t.Fatalf("client.Read(_): %s", err)
} else {
if diff := cmp.Diff(tcpip.ReadResult{
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 87277fbd3..256e19296 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -154,7 +154,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -186,7 +186,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip
res.RemoteAddr = p.senderAddress
}
- n, err := p.data.ReadTo(dst, count, opts.Peek)
+ n, err := p.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
return res, tcpip.ErrBadBuffer
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index c3b3b8d34..c0d6fb442 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -162,7 +162,7 @@ func (ep *endpoint) Close() {
func (ep *endpoint) ModerateRecvBuf(copied int) {}
// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -199,7 +199,7 @@ func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpi
res.LinkPacketInfo = packet.packetInfo
}
- n, err := packet.data.ReadTo(dst, count, opts.Peek)
+ n, err := packet.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
return res, tcpip.ErrBadBuffer
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 425bcf3ee..ae743f75e 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -191,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -225,7 +225,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip
res.RemoteAddr = pkt.senderAddr
}
- n, err := pkt.data.ReadTo(dst, count, opts.Peek)
+ n, err := pkt.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
return res, tcpip.ErrBadBuffer
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a4508e871..ea509ac73 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvReadMu.Lock()
defer e.rcvReadMu.Unlock()
@@ -1346,9 +1346,9 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip
var err error
done := 0
s := first
- for s != nil && done < count {
+ for s != nil {
var n int
- n, err = s.data.ReadTo(dst, count-done, opts.Peek)
+ n, err = s.data.ReadTo(dst, opts.Peek)
// Book keeping first then error handling.
done += n
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 729bf7ef5..93683b921 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -50,7 +50,7 @@ type endpointTester struct {
// CheckReadError issues a read to the endpoint and checking for an error.
func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) {
t.Helper()
- res, got := e.ep.Read(ioutil.Discard, 1, tcpip.ReadOptions{})
+ res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{})
if got != want {
t.Fatalf("ep.Read = %s, want %s", got, want)
}
@@ -61,10 +61,10 @@ func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) {
// CheckRead issues a read to the endpoint and checking for a success, returning
// the data read.
-func (e *endpointTester) CheckRead(t *testing.T, count int) []byte {
+func (e *endpointTester) CheckRead(t *testing.T) []byte {
t.Helper()
var buf bytes.Buffer
- res, err := e.ep.Read(&buf, count, tcpip.ReadOptions{})
+ res, err := e.ep.Read(&buf, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("ep.Read = _, %s; want _, nil", err)
}
@@ -81,9 +81,12 @@ func (e *endpointTester) CheckRead(t *testing.T, count int) []byte {
func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte {
t.Helper()
var buf bytes.Buffer
- var done int
- for done < count {
- res, err := e.ep.Read(&buf, count-done, tcpip.ReadOptions{})
+ w := tcpip.LimitedWriter{
+ W: &buf,
+ N: int64(count),
+ }
+ for w.N != 0 {
+ _, err := e.ep.Read(&w, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
// Wait for receive to be notified.
select {
@@ -95,7 +98,6 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha
} else if err != nil {
t.Fatalf("ep.Read = _, %s; want _, nil", err)
}
- done += res.Count
}
return buf.Bytes()
}
@@ -820,7 +822,7 @@ func TestSimpleReceive(t *testing.T) {
}
// Receive data.
- v := ept.CheckRead(t, defaultMTU)
+ v := ept.CheckRead(t)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -1928,7 +1930,7 @@ func TestFullWindowReceive(t *testing.T) {
)
// Receive data and check it.
- v := ept.CheckRead(t, defaultMTU)
+ v := ept.CheckRead(t)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -2015,7 +2017,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
// Read the data so that the subsequent ACK from the endpoint
// grows the right edge of the window.
var buf bytes.Buffer
- if _, err := c.EP.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
+ if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil {
t.Fatalf("c.EP.Read: %s", err)
}
@@ -2075,7 +2077,7 @@ func TestNoWindowShrinking(t *testing.T) {
}
// Read the 1 byte payload we just sent.
- if got, want := payload, ept.CheckRead(t, 1); !bytes.Equal(got, want) {
+ if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) {
t.Fatalf("got data: %v, want: %v", got, want)
}
@@ -2570,13 +2572,16 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// update to be sent. For 1MSS worth of window to be available we need to
// read at least 128KB. Since our segments above were 50KB each it means
// we need to read at 3 packets.
- sz := 0
- for sz < defaultMTU*2 {
- res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
+ w := tcpip.LimitedWriter{
+ W: ioutil.Discard,
+ N: defaultMTU * 2,
+ }
+ for w.N != 0 {
+ res, err := c.EP.Read(&w, tcpip.ReadOptions{})
+ t.Logf("err=%v res=%#v", err, res)
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- sz += res.Count
}
checker.IPv4(t, c.GetPacket(),
@@ -3271,12 +3276,12 @@ func TestReceiveOnResetConnection(t *testing.T) {
loop:
for {
- switch _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err {
+ switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err {
case tcpip.ErrWouldBlock:
select {
case <-ch:
// Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, err := c.EP.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset {
+ if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset {
t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset)
}
break loop
@@ -4224,7 +4229,7 @@ func TestReadAfterClosedState(t *testing.T) {
// Check that peek works.
var peekBuf bytes.Buffer
- res, err := c.EP.Read(&peekBuf, 10, tcpip.ReadOptions{Peek: true})
+ res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true})
if err != nil {
t.Fatalf("Peek failed: %s", err)
}
@@ -4237,7 +4242,7 @@ func TestReadAfterClosedState(t *testing.T) {
}
// Receive data.
- v := ept.CheckRead(t, defaultMTU)
+ v := ept.CheckRead(t)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -4246,8 +4251,8 @@ func TestReadAfterClosedState(t *testing.T) {
// right error code.
ept.CheckReadError(t, tcpip.ErrClosedForReceive)
var buf bytes.Buffer
- if _, err := c.EP.Read(&buf, 1, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("c.EP.Read(_, _, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive)
+ if _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive)
}
}
@@ -6205,7 +6210,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Now read all the data from the endpoint and verify that advertised
// window increases to the full available buffer size.
for {
- _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
+ _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
@@ -6329,7 +6334,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// to happen before we measure the new window.
totalCopied := 0
for {
- res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
+ res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
@@ -7387,15 +7392,17 @@ func TestIncreaseWindowOnRead(t *testing.T) {
// We now have < 1 MSS in the buffer space. Read at least > 2 MSS
// worth of data as receive buffer space
- read := 0
- // defaultMTU is a good enough estimate for the MSS used for this
- // connection.
- for read < defaultMTU*2 {
- res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
+ w := tcpip.LimitedWriter{
+ W: ioutil.Discard,
+ // defaultMTU is a good enough estimate for the MSS used for this
+ // connection.
+ N: defaultMTU * 2,
+ }
+ for w.N != 0 {
+ _, err := c.EP.Read(&w, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- read += res.Count
}
// After reading > MSS worth of data, we surely crossed MSS. See the ack:
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 88fb054bb..b65091c3c 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -106,18 +106,19 @@ func TestTimeStampEnabledConnect(t *testing.T) {
// There should be 5 views to read and each of them should
// contain the same data.
for i := 0; i < 5; i++ {
- var buf bytes.Buffer
- result, err := c.EP.Read(&buf, len(data), tcpip.ReadOptions{})
+ buf := make([]byte, len(data))
+ w := tcpip.SliceWriter(buf)
+ result, err := c.EP.Read(&w, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
+ Count: len(buf),
+ Total: len(buf),
}, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
}
- if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
+ if got, want := buf, data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
@@ -295,7 +296,7 @@ func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) {
// Issue a read and we should data.
var buf bytes.Buffer
- result, err := c.EP.Read(&buf, defaultMTU, tcpip.ReadOptions{})
+ result, err := c.EP.Read(&buf, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 520a0ac9d..9f9b3d510 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -284,7 +284,7 @@ func (e *endpoint) Close() {
func (e *endpoint) ModerateRecvBuf(copied int) {}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
if err := e.LastError(); err != nil {
return tcpip.ReadResult{}, err
}
@@ -340,7 +340,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip
res.RemoteAddr = p.senderAddress
}
- n, err := p.data.ReadTo(dst, count, opts.Peek)
+ n, err := p.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
return res, tcpip.ErrBadBuffer
}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 52403ed78..4e2123fe9 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -598,12 +598,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
var buf bytes.Buffer
- res, err := c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
+ res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- res, err = c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
+ res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true})
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -839,7 +839,7 @@ func TestV4ReadSelfSource(t *testing.T) {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
}
- if _, err := c.ep.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err != tt.wantErr {
+ if _, err := c.ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr {
t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr)
}
})