From 12d9790833cc2f6a9b197066a5ecbeb434f74164 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Fri, 15 Jan 2021 15:47:13 -0800 Subject: Remove count argument from tcpip.Endpoint.Read The same intent can be specified via the io.Writer. PiperOrigin-RevId: 352098747 --- pkg/sentry/socket/netstack/netstack.go | 17 ++- pkg/tcpip/BUILD | 1 + pkg/tcpip/adapters/gonet/gonet.go | 4 +- pkg/tcpip/buffer/BUILD | 7 +- pkg/tcpip/buffer/view.go | 14 +-- pkg/tcpip/buffer/view_test.go | 127 +++++++++++---------- pkg/tcpip/network/ipv4/ipv4_test.go | 5 +- pkg/tcpip/network/ipv6/ipv6_test.go | 12 +- pkg/tcpip/sample/tun_tcp_connect/main.go | 3 +- pkg/tcpip/sample/tun_tcp_echo/main.go | 37 +++++- pkg/tcpip/stack/transport_demuxer_test.go | 2 +- pkg/tcpip/stack/transport_test.go | 2 +- pkg/tcpip/tcpip.go | 32 +++++- pkg/tcpip/tcpip_test.go | 34 ++++++ pkg/tcpip/tests/integration/forward_test.go | 2 +- .../tests/integration/link_resolution_test.go | 2 +- pkg/tcpip/tests/integration/loopback_test.go | 4 +- .../tests/integration/multicast_broadcast_test.go | 10 +- pkg/tcpip/tests/integration/route_test.go | 9 +- pkg/tcpip/transport/icmp/endpoint.go | 4 +- pkg/tcpip/transport/packet/endpoint.go | 4 +- pkg/tcpip/transport/raw/endpoint.go | 4 +- pkg/tcpip/transport/tcp/endpoint.go | 6 +- pkg/tcpip/transport/tcp/tcp_test.go | 65 ++++++----- pkg/tcpip/transport/tcp/tcp_timestamp_test.go | 13 ++- pkg/tcpip/transport/udp/endpoint.go | 4 +- pkg/tcpip/transport/udp/udp_test.go | 6 +- 27 files changed, 265 insertions(+), 165 deletions(-) (limited to 'pkg') diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 03749a8bf..22e128b96 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -425,8 +425,13 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write s.readMu.Lock() defer s.readMu.Unlock() + w := tcpip.LimitedWriter{ + W: dst, + N: count, + } + // This may return a blocking error. - res, err := s.Endpoint.Read(dst, int(count), tcpip.ReadOptions{ + res, err := s.Endpoint.Read(&w, tcpip.ReadOptions{ Peek: dup, }) if err != nil { @@ -2579,7 +2584,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq // caller-supplied buffer. var w io.Writer if !isPacket && trunc { - w = ioutil.Discard + w = &tcpip.LimitedWriter{ + W: ioutil.Discard, + N: dst.NumBytes(), + } } else { w = dst.Writer(ctx) } @@ -2587,7 +2595,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq s.readMu.Lock() defer s.readMu.Unlock() - res, err := s.Endpoint.Read(w, int(dst.NumBytes()), readOptions) + res, err := s.Endpoint.Read(w, readOptions) + if err == tcpip.ErrBadBuffer && dst.NumBytes() == 0 { + err = nil + } if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD index 89b765f1b..e7924e5c2 100644 --- a/pkg/tcpip/BUILD +++ b/pkg/tcpip/BUILD @@ -37,6 +37,7 @@ go_test( size = "small", srcs = ["tcpip_test.go"], library = ":tcpip", + deps = ["@com_github_google_go_cmp//cmp:go_default_library"], ) go_test( diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 85a0b8b90..fdeec12d3 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -295,7 +295,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s w := tcpip.SliceWriter(b) opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} - res, err := ep.Read(&w, len(b), opts) + res, err := ep.Read(&w, opts) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. @@ -303,7 +303,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s wq.EventRegister(&waitEntry, waiter.EventIn) defer wq.EventUnregister(&waitEntry) for { - res, err = ep.Read(&w, len(b), opts) + res, err = ep.Read(&w, opts) if err != tcpip.ErrWouldBlock { break } diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD index c326fab54..c9bcf9326 100644 --- a/pkg/tcpip/buffer/BUILD +++ b/pkg/tcpip/buffer/BUILD @@ -12,10 +12,13 @@ go_library( ) go_test( - name = "buffer_test", + name = "buffer_x_test", size = "small", srcs = [ "view_test.go", ], - library = ":buffer", + deps = [ + ":buffer", + "//pkg/tcpip", + ], ) diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 09d3dac66..91cc62cc8 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -148,23 +148,13 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int // ReadTo reads up to count bytes from vv to dst. It also removes them from vv // unless peek is true. -func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) { +func (vv *VectorisedView) ReadTo(dst io.Writer, peek bool) (int, error) { var err error done := 0 for _, v := range vv.Views() { - remaining := count - done - if remaining <= 0 { - break - } - if len(v) > remaining { - v = v[:remaining] - } - var n int n, err = dst.Write(v) - if n > 0 { - done += n - } + done += n if err != nil { break } diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index e0ef8a94d..e7f7cc9f1 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -12,42 +12,43 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package buffer_test contains tests for the VectorisedView type. -package buffer +// Package buffer_test contains tests for the buffer.VectorisedView type. +package buffer_test import ( "bytes" + "io" "reflect" "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" ) // copy returns a deep-copy of the vectorised view. -func (vv VectorisedView) copy() VectorisedView { - uu := VectorisedView{ - views: make([]View, 0, len(vv.views)), - size: vv.size, - } - for _, v := range vv.views { - uu.views = append(uu.views, append(View(nil), v...)) +func copyVV(vv buffer.VectorisedView) buffer.VectorisedView { + views := make([]buffer.View, 0, len(vv.Views())) + for _, v := range vv.Views() { + views = append(views, append(buffer.View(nil), v...)) } - return uu + return buffer.NewVectorisedView(vv.Size(), views) } -// vv is an helper to build VectorisedView from different strings. -func vv(size int, pieces ...string) VectorisedView { - views := make([]View, len(pieces)) +// vv is an helper to build buffer.VectorisedView from different strings. +func vv(size int, pieces ...string) buffer.VectorisedView { + views := make([]buffer.View, len(pieces)) for i, p := range pieces { views[i] = []byte(p) } - return NewVectorisedView(size, views) + return buffer.NewVectorisedView(size, views) } var capLengthTestCases = []struct { comment string - in VectorisedView + in buffer.VectorisedView length int - want VectorisedView + want buffer.VectorisedView }{ { comment: "Simple case", @@ -89,7 +90,7 @@ var capLengthTestCases = []struct { func TestCapLength(t *testing.T) { for _, c := range capLengthTestCases { - orig := c.in.copy() + orig := copyVV(c.in) c.in.CapLength(c.length) if !reflect.DeepEqual(c.in, c.want) { t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v", @@ -100,9 +101,9 @@ func TestCapLength(t *testing.T) { var trimFrontTestCases = []struct { comment string - in VectorisedView + in buffer.VectorisedView count int - want VectorisedView + want buffer.VectorisedView }{ { comment: "Simple case", @@ -150,7 +151,7 @@ var trimFrontTestCases = []struct { func TestTrimFront(t *testing.T) { for _, c := range trimFrontTestCases { - orig := c.in.copy() + orig := copyVV(c.in) c.in.TrimFront(c.count) if !reflect.DeepEqual(c.in, c.want) { t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v", @@ -161,8 +162,8 @@ func TestTrimFront(t *testing.T) { var toViewCases = []struct { comment string - in VectorisedView - want View + in buffer.VectorisedView + want buffer.View }{ { comment: "Simple case", @@ -193,28 +194,28 @@ func TestToView(t *testing.T) { var toCloneCases = []struct { comment string - inView VectorisedView - inBuffer []View + inView buffer.VectorisedView + inBuffer []buffer.View }{ { comment: "Simple case", inView: vv(1, "1"), - inBuffer: make([]View, 1), + inBuffer: make([]buffer.View, 1), }, { comment: "Case with multiple views", inView: vv(2, "1", "2"), - inBuffer: make([]View, 2), + inBuffer: make([]buffer.View, 2), }, { comment: "Case with buffer too small", inView: vv(2, "1", "2"), - inBuffer: make([]View, 1), + inBuffer: make([]buffer.View, 1), }, { comment: "Case with buffer larger than needed", inView: vv(1, "1"), - inBuffer: make([]View, 2), + inBuffer: make([]buffer.View, 2), }, { comment: "Case with nil buffer", @@ -237,10 +238,10 @@ func TestToClone(t *testing.T) { type readToTestCases struct { comment string - vv VectorisedView + vv buffer.VectorisedView bytesToRead int wantBytes string - leftVV VectorisedView + leftVV buffer.VectorisedView } func createReadToTestCases() []readToTestCases { @@ -286,7 +287,7 @@ func createReadToTestCases() []readToTestCases { func TestVVReadToVV(t *testing.T) { for _, tc := range createReadToTestCases() { t.Run(tc.comment, func(t *testing.T) { - var readTo VectorisedView + var readTo buffer.VectorisedView inSize := tc.vv.Size() copied := tc.vv.ReadToVV(&readTo, tc.bytesToRead) if got, want := copied, len(tc.wantBytes); got != want { @@ -308,13 +309,17 @@ func TestVVReadToVV(t *testing.T) { func TestVVReadTo(t *testing.T) { for _, tc := range createReadToTestCases() { t.Run(tc.comment, func(t *testing.T) { - var dst bytes.Buffer + b := make([]byte, tc.bytesToRead) + dst := tcpip.SliceWriter(b) origSize := tc.vv.Size() - copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, false /* peek */) - if got, want := copied, len(tc.wantBytes); err != nil || got != want { - t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want) + copied, err := tc.vv.ReadTo(&dst, false /* peek */) + if err != nil && err != io.ErrShortWrite { + t.Errorf("got ReadTo(&dst, false) = (_, %s); want nil or io.ErrShortWrite", err) + } + if got, want := copied, len(tc.wantBytes); got != want { + t.Errorf("got ReadTo(&dst, false) = (%d, _); want %d", got, want) } - if got, want := string(dst.Bytes()), tc.wantBytes; got != want { + if got, want := string(b[:copied]), tc.wantBytes; got != want { t.Errorf("got dst = %q, want %q", got, want) } if got, want := tc.vv.Size(), origSize-copied; got != want { @@ -330,14 +335,18 @@ func TestVVReadTo(t *testing.T) { func TestVVReadToPeek(t *testing.T) { for _, tc := range createReadToTestCases() { t.Run(tc.comment, func(t *testing.T) { - var dst bytes.Buffer + b := make([]byte, tc.bytesToRead) + dst := tcpip.SliceWriter(b) origSize := tc.vv.Size() origData := string(tc.vv.ToView()) - copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, true /* peek */) - if got, want := copied, len(tc.wantBytes); err != nil || got != want { - t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want) + copied, err := tc.vv.ReadTo(&dst, true /* peek */) + if err != nil && err != io.ErrShortWrite { + t.Errorf("got ReadTo(&dst, true) = (_, %s); want nil or io.ErrShortWrite", err) + } + if got, want := copied, len(tc.wantBytes); got != want { + t.Errorf("got ReadTo(&dst, true) = (%d, _); want %d", got, want) } - if got, want := string(dst.Bytes()), tc.wantBytes; got != want { + if got, want := string(b[:copied]), tc.wantBytes; got != want { t.Errorf("got dst = %q, want %q", got, want) } // Expect tc.vv is unchanged. @@ -354,7 +363,7 @@ func TestVVReadToPeek(t *testing.T) { func TestVVRead(t *testing.T) { testCases := []struct { comment string - vv VectorisedView + vv buffer.VectorisedView bytesToRead int readBytes string leftBytes string @@ -399,7 +408,7 @@ func TestVVRead(t *testing.T) { for _, tc := range testCases { t.Run(tc.comment, func(t *testing.T) { - readTo := NewView(tc.bytesToRead) + readTo := buffer.NewView(tc.bytesToRead) inSize := tc.vv.Size() copied, err := tc.vv.Read(readTo) if !tc.wantError && err != nil { @@ -424,10 +433,10 @@ func TestVVRead(t *testing.T) { var pullUpTestCases = []struct { comment string - in VectorisedView + in buffer.VectorisedView count int want []byte - result VectorisedView + result buffer.VectorisedView ok bool }{ { @@ -521,7 +530,7 @@ func TestPullUp(t *testing.T) { t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got an ok of %t. Want %t", c.comment, c.count, c.in, ok, c.ok) } - if bytes.Compare(got, View(c.want)) != 0 { + if bytes.Compare(got, buffer.View(c.want)) != 0 { t.Errorf("Test %q failed when calling PullUp(%d) on %v. Got %v. Want %v", c.comment, c.count, c.in, got, c.want) } @@ -536,12 +545,12 @@ func TestPullUp(t *testing.T) { func TestToVectorisedView(t *testing.T) { testCases := []struct { - in View - want VectorisedView + in buffer.View + want buffer.VectorisedView }{ - {nil, VectorisedView{}}, - {View{}, VectorisedView{}}, - {View{'a'}, VectorisedView{size: 1, views: []View{{'a'}}}}, + {nil, buffer.VectorisedView{}}, + {buffer.View{}, buffer.VectorisedView{}}, + {buffer.View{'a'}, buffer.NewVectorisedView(1, []buffer.View{{'a'}})}, } for _, tc := range testCases { if got, want := tc.in.ToVectorisedView(), tc.want; !reflect.DeepEqual(got, want) { @@ -552,15 +561,15 @@ func TestToVectorisedView(t *testing.T) { func TestAppendView(t *testing.T) { testCases := []struct { - vv VectorisedView - in View - want VectorisedView + vv buffer.VectorisedView + in buffer.View + want buffer.VectorisedView }{ - {VectorisedView{}, nil, VectorisedView{}}, - {VectorisedView{}, View{}, VectorisedView{}}, - {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, nil, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}}, - {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}}, - {VectorisedView{[]View{{'a', 'b', 'c', 'd'}}, 4}, View{'e'}, VectorisedView{[]View{{'a', 'b', 'c', 'd'}, {'e'}}, 5}}, + {buffer.VectorisedView{}, nil, buffer.VectorisedView{}}, + {buffer.VectorisedView{}, buffer.View{}, buffer.VectorisedView{}}, + {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), nil, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})}, + {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{}, buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}})}, + {buffer.NewVectorisedView(4, []buffer.View{{'a', 'b', 'c', 'd'}}), buffer.View{'e'}, buffer.NewVectorisedView(5, []buffer.View{{'a', 'b', 'c', 'd'}, {'e'}})}, } for _, tc := range testCases { tc.vv.AppendView(tc.in) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 1c4919b1e..a9e137c24 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -2410,10 +2410,9 @@ func TestReceiveFragments(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) } - const rcvSize = 65536 // Account for reassembled packets. for i, expectedPayload := range test.expectedPayloads { var buf bytes.Buffer - result, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{}) + result, err := ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Fatalf("(i=%d) Read: %s", i, err) } @@ -2428,7 +2427,7 @@ func TestReceiveFragments(t *testing.T) { } } - if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) } }) diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 360025b20..b65c9d060 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -846,14 +846,13 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, } - const mtu = header.IPv6MinimumMTU for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(1, mtu, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -983,7 +982,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = 1", got) } var buf bytes.Buffer - result, err := ep.Read(&buf, mtu, tcpip.ReadOptions{}) + result, err := ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Read: %s", err) } @@ -998,7 +997,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { } // Should not have any more UDP packets. - if res, err := ep.Read(ioutil.Discard, mtu, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) } }) @@ -1979,10 +1978,9 @@ func TestReceiveIPv6Fragments(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) } - const rcvSize = 65536 // Account for reassembled packets. for i, p := range test.expectedPayloads { var buf bytes.Buffer - _, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{}) + _, err := ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Fatalf("(i=%d) Read: %s", i, err) } @@ -1991,7 +1989,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { } } - if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + if res, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) } }) diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index a7da9dcd9..3b4f900e3 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -44,7 +44,6 @@ import ( "bufio" "fmt" "log" - "math" "math/rand" "net" "os" @@ -201,7 +200,7 @@ func main() { // connection from its side. wq.EventRegister(&waitEntry, waiter.EventIn) for { - _, err := ep.Read(os.Stdout, math.MaxUint16, tcpip.ReadOptions{}) + _, err := ep.Read(os.Stdout, tcpip.ReadOptions{}) if err != nil { if err == tcpip.ErrClosedForReceive { break diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index a80fa0474..3ac562756 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -20,10 +20,9 @@ package main import ( - "bytes" "flag" + "io" "log" - "math" "math/rand" "net" "os" @@ -46,6 +45,31 @@ import ( var tap = flag.Bool("tap", false, "use tap istead of tun") var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device") +type endpointWriter struct { + ep tcpip.Endpoint +} + +type tcpipError struct { + inner *tcpip.Error +} + +func (e *tcpipError) Error() string { + return e.inner.String() +} + +func (e *endpointWriter) Write(p []byte) (int, error) { + n, err := e.ep.Write(tcpip.SlicePayload(p), tcpip.WriteOptions{}) + if err != nil { + return int(n), &tcpipError{ + inner: err, + } + } + if n != int64(len(p)) { + return int(n), io.ErrShortWrite + } + return int(n), nil +} + func echo(wq *waiter.Queue, ep tcpip.Endpoint) { defer ep.Close() @@ -55,9 +79,12 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) { wq.EventRegister(&waitEntry, waiter.EventIn) defer wq.EventUnregister(&waitEntry) + w := endpointWriter{ + ep: ep, + } + for { - var buf bytes.Buffer - _, err := ep.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}) + _, err := ep.Read(&w, tcpip.ReadOptions{}) if err != nil { if err == tcpip.ErrWouldBlock { <-notifyCh @@ -66,8 +93,6 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) { return } - - ep.Write(tcpip.SlicePayload(buf.Bytes()), tcpip.WriteOptions{}) } } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 859278f0b..57e1f8354 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -352,7 +352,7 @@ func TestBindToDeviceDistribution(t *testing.T) { } ep := <-pollChannel - if _, err := ep.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != nil { + if _, err := ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != nil { t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) } stats[ep]++ diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index a2ab7537c..9d39533a1 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -86,7 +86,7 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return mask } -func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (*fakeTransportEndpoint) Read(io.Writer, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { return tcpip.ReadResult{}, nil } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 49d4912ad..56aac093c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -505,10 +505,34 @@ type SliceWriter []byte func (s *SliceWriter) Write(b []byte) (int, error) { n := copy(*s, b) *s = (*s)[n:] - if n < len(b) { - return n, io.ErrShortWrite + var err error + if n != len(b) { + err = io.ErrShortWrite } - return n, nil + return n, err +} + +var _ io.Writer = (*LimitedWriter)(nil) + +// A LimitedWriter writes to W but limits the amount of data copied to just N +// bytes. Each call to Write updates N to reflect the new amount remaining. +type LimitedWriter struct { + W io.Writer + N int64 +} + +func (l *LimitedWriter) Write(p []byte) (int, error) { + pLen := int64(len(p)) + if pLen > l.N { + p = p[:l.N] + } + n, err := l.W.Write(p) + n64 := int64(n) + if err == nil && n64 != pLen { + err = io.ErrShortWrite + } + l.N -= n64 + return n, err } // A ControlMessages contains socket control messages for IP sockets. @@ -623,7 +647,7 @@ type Endpoint interface { // If non-zero number of bytes are successfully read and written to dst, err // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer // should be returned. - Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error) + Read(dst io.Writer, opts ReadOptions) (res ReadResult, err *Error) // Write writes data to the endpoint's peer. This method does not block if // the data cannot be written. diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go index 9bd563c46..269081ff8 100644 --- a/pkg/tcpip/tcpip_test.go +++ b/pkg/tcpip/tcpip_test.go @@ -15,12 +15,46 @@ package tcpip import ( + "bytes" "fmt" + "io" "net" "strings" "testing" + + "github.com/google/go-cmp/cmp" ) +func TestLimitedWriter_Write(t *testing.T) { + var b bytes.Buffer + l := LimitedWriter{ + W: &b, + N: 5, + } + if n, err := l.Write([]byte{0, 1, 2}); err != nil { + t.Errorf("got l.Write(3/5) = (_, %s), want nil", err) + } else if n != 3 { + t.Errorf("got l.Write(3/5) = (%d, _), want 3", n) + } + if n, err := l.Write([]byte{3, 4, 5}); err != io.ErrShortWrite { + t.Errorf("got l.Write(3/2) = (_, %s), want io.ErrShortWrite", err) + } else if n != 2 { + t.Errorf("got l.Write(3/2) = (%d, _), want 2", n) + } + if l.N != 0 { + t.Errorf("got l.N = %d, want 0", l.N) + } + l.N = 1 + if n, err := l.Write([]byte{5}); err != nil { + t.Errorf("got l.Write(1/1) = (_, %s), want nil", err) + } else if n != 1 { + t.Errorf("got l.Write(1/1) = (%d, _), want 1", n) + } + if diff := cmp.Diff(b.Bytes(), []byte{0, 1, 2, 3, 4, 5}); diff != "" { + t.Errorf("%T wrote incorrect data: (-want +got):\n%s", l, diff) + } +} + func TestSubnetContains(t *testing.T) { tests := []struct { s Address diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 49acd504e..ac9670f9a 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -457,7 +457,7 @@ func TestForwarding(t *testing.T) { <-ch var buf bytes.Buffer opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} - res, err := ep.Read(&buf, len(data), opts) + res, err := ep.Read(&buf, opts) if err != nil { t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index ed00c90d4..3f06c2145 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -218,7 +218,7 @@ func TestPing(t *testing.T) { var buf bytes.Buffer opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, len(icmpBuf), opts) + res, err := ep.Read(&buf, opts) if err != nil { t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err) } diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index a59f25cc3..3b13ba04d 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -242,9 +242,9 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { var buf bytes.Buffer opts := tcpip.ReadOptions{NeedRemoteAddr: true} - if res, err := rep.Read(&buf, len(data), opts); test.expectRx { + if res, err := rep.Read(&buf, opts); test.expectRx { if err != nil { - t.Fatalf("rep.Read(_, %d, %#v): %s", len(data), opts, err) + t.Fatalf("rep.Read(_, %#v): %s", opts, err) } if diff := cmp.Diff(tcpip.ReadResult{ Count: buf.Len(), diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index eabc87938..ce7c16bd1 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -466,9 +466,9 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { test.rxUDP(e, test.remoteAddr, test.dstAddr, data) var buf bytes.Buffer var opts tcpip.ReadOptions - if res, err := ep.Read(&buf, len(data), opts); test.expectRx { + if res, err := ep.Read(&buf, opts); test.expectRx { if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + t.Fatalf("ep.Read(_, %#v): %s", opts, err) } if diff := cmp.Diff(tcpip.ReadResult{ Count: buf.Len(), @@ -598,7 +598,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { <-rep.ch var buf bytes.Buffer - result, err := rep.ep.Read(&buf, len(data), tcpip.ReadOptions{}) + result, err := rep.ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err) continue @@ -738,7 +738,7 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { } test.rxUDP(e, test.remoteAddr, test.multicastAddr, data) var buf bytes.Buffer - result, err := ep.Read(&buf, len(data), tcpip.ReadOptions{}) + result, err := ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Fatalf("ep.Read: %s", err) } else { @@ -759,7 +759,7 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { if err := ep.SetSockOpt(&removeOpt); err != nil { t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) } - if _, err := ep.Read(&buf, 1, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + if _, err := ep.Read(&buf, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock) } }) diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 76f7f54c6..b222d2b05 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -16,7 +16,6 @@ package integration_test import ( "bytes" - "math" "testing" "github.com/google/go-cmp/cmp" @@ -208,9 +207,9 @@ func TestLocalPing(t *testing.T) { var buf bytes.Buffer opts := tcpip.ReadOptions{NeedRemoteAddr: true} - res, err := ep.Read(&buf, math.MaxUint16, opts) + res, err := ep.Read(&buf, opts) if err != nil { - t.Fatalf("ep.Read(_, %d, %#v): %s", math.MaxUint16, opts, err) + t.Fatalf("ep.Read(_, %#v): %s", opts, err) } if diff := cmp.Diff(tcpip.ReadResult{ Count: buf.Len(), @@ -351,7 +350,7 @@ func TestLocalUDP(t *testing.T) { var clientAddr tcpip.FullAddress var readBuf bytes.Buffer - if read, err := server.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { + if read, err := server.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { t.Fatalf("server.Read(_): %s", err) } else { clientAddr = read.RemoteAddr @@ -393,7 +392,7 @@ func TestLocalUDP(t *testing.T) { <-clientCH readBuf.Reset() - if read, err := client.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { + if read, err := client.Read(&readBuf, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { t.Fatalf("client.Read(_): %s", err) } else { if diff := cmp.Diff(tcpip.ReadResult{ diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 87277fbd3..256e19296 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -154,7 +154,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvMu.Lock() if e.rcvList.Empty() { @@ -186,7 +186,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip res.RemoteAddr = p.senderAddress } - n, err := p.data.ReadTo(dst, count, opts.Peek) + n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { return res, tcpip.ErrBadBuffer } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index c3b3b8d34..c0d6fb442 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -162,7 +162,7 @@ func (ep *endpoint) Close() { func (ep *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -199,7 +199,7 @@ func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpi res.LinkPacketInfo = packet.packetInfo } - n, err := packet.data.ReadTo(dst, count, opts.Peek) + n, err := packet.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { return res, tcpip.ErrBadBuffer } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 425bcf3ee..ae743f75e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -191,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -225,7 +225,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip res.RemoteAddr = pkt.senderAddr } - n, err := pkt.data.ReadTo(dst, count, opts.Peek) + n, err := pkt.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { return res, tcpip.ErrBadBuffer } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a4508e871..ea509ac73 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvReadMu.Lock() defer e.rcvReadMu.Unlock() @@ -1346,9 +1346,9 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip var err error done := 0 s := first - for s != nil && done < count { + for s != nil { var n int - n, err = s.data.ReadTo(dst, count-done, opts.Peek) + n, err = s.data.ReadTo(dst, opts.Peek) // Book keeping first then error handling. done += n diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 729bf7ef5..93683b921 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -50,7 +50,7 @@ type endpointTester struct { // CheckReadError issues a read to the endpoint and checking for an error. func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) { t.Helper() - res, got := e.ep.Read(ioutil.Discard, 1, tcpip.ReadOptions{}) + res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{}) if got != want { t.Fatalf("ep.Read = %s, want %s", got, want) } @@ -61,10 +61,10 @@ func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) { // CheckRead issues a read to the endpoint and checking for a success, returning // the data read. -func (e *endpointTester) CheckRead(t *testing.T, count int) []byte { +func (e *endpointTester) CheckRead(t *testing.T) []byte { t.Helper() var buf bytes.Buffer - res, err := e.ep.Read(&buf, count, tcpip.ReadOptions{}) + res, err := e.ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Fatalf("ep.Read = _, %s; want _, nil", err) } @@ -81,9 +81,12 @@ func (e *endpointTester) CheckRead(t *testing.T, count int) []byte { func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte { t.Helper() var buf bytes.Buffer - var done int - for done < count { - res, err := e.ep.Read(&buf, count-done, tcpip.ReadOptions{}) + w := tcpip.LimitedWriter{ + W: &buf, + N: int64(count), + } + for w.N != 0 { + _, err := e.ep.Read(&w, tcpip.ReadOptions{}) if err == tcpip.ErrWouldBlock { // Wait for receive to be notified. select { @@ -95,7 +98,6 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } else if err != nil { t.Fatalf("ep.Read = _, %s; want _, nil", err) } - done += res.Count } return buf.Bytes() } @@ -820,7 +822,7 @@ func TestSimpleReceive(t *testing.T) { } // Receive data. - v := ept.CheckRead(t, defaultMTU) + v := ept.CheckRead(t) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } @@ -1928,7 +1930,7 @@ func TestFullWindowReceive(t *testing.T) { ) // Receive data and check it. - v := ept.CheckRead(t, defaultMTU) + v := ept.CheckRead(t) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } @@ -2015,7 +2017,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { // Read the data so that the subsequent ACK from the endpoint // grows the right edge of the window. var buf bytes.Buffer - if _, err := c.EP.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}); err != nil { + if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil { t.Fatalf("c.EP.Read: %s", err) } @@ -2075,7 +2077,7 @@ func TestNoWindowShrinking(t *testing.T) { } // Read the 1 byte payload we just sent. - if got, want := payload, ept.CheckRead(t, 1); !bytes.Equal(got, want) { + if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) { t.Fatalf("got data: %v, want: %v", got, want) } @@ -2570,13 +2572,16 @@ func TestZeroScaledWindowReceive(t *testing.T) { // update to be sent. For 1MSS worth of window to be available we need to // read at least 128KB. Since our segments above were 50KB each it means // we need to read at 3 packets. - sz := 0 - for sz < defaultMTU*2 { - res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) + w := tcpip.LimitedWriter{ + W: ioutil.Discard, + N: defaultMTU * 2, + } + for w.N != 0 { + res, err := c.EP.Read(&w, tcpip.ReadOptions{}) + t.Logf("err=%v res=%#v", err, res) if err != nil { t.Fatalf("Read failed: %s", err) } - sz += res.Count } checker.IPv4(t, c.GetPacket(), @@ -3271,12 +3276,12 @@ func TestReceiveOnResetConnection(t *testing.T) { loop: for { - switch _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err { + switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err { case tcpip.ErrWouldBlock: select { case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, err := c.EP.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset { + if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset { t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset) } break loop @@ -4224,7 +4229,7 @@ func TestReadAfterClosedState(t *testing.T) { // Check that peek works. var peekBuf bytes.Buffer - res, err := c.EP.Read(&peekBuf, 10, tcpip.ReadOptions{Peek: true}) + res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true}) if err != nil { t.Fatalf("Peek failed: %s", err) } @@ -4237,7 +4242,7 @@ func TestReadAfterClosedState(t *testing.T) { } // Receive data. - v := ept.CheckRead(t, defaultMTU) + v := ept.CheckRead(t) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } @@ -4246,8 +4251,8 @@ func TestReadAfterClosedState(t *testing.T) { // right error code. ept.CheckReadError(t, tcpip.ErrClosedForReceive) var buf bytes.Buffer - if _, err := c.EP.Read(&buf, 1, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive { - t.Fatalf("c.EP.Read(_, _, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive) + if _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive { + t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive) } } @@ -6205,7 +6210,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Now read all the data from the endpoint and verify that advertised // window increases to the full available buffer size. for { - _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) + _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) if err == tcpip.ErrWouldBlock { break } @@ -6329,7 +6334,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // to happen before we measure the new window. totalCopied := 0 for { - res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) + res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) if err == tcpip.ErrWouldBlock { break } @@ -7387,15 +7392,17 @@ func TestIncreaseWindowOnRead(t *testing.T) { // We now have < 1 MSS in the buffer space. Read at least > 2 MSS // worth of data as receive buffer space - read := 0 - // defaultMTU is a good enough estimate for the MSS used for this - // connection. - for read < defaultMTU*2 { - res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) + w := tcpip.LimitedWriter{ + W: ioutil.Discard, + // defaultMTU is a good enough estimate for the MSS used for this + // connection. + N: defaultMTU * 2, + } + for w.N != 0 { + _, err := c.EP.Read(&w, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Read failed: %s", err) } - read += res.Count } // After reading > MSS worth of data, we surely crossed MSS. See the ack: diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 88fb054bb..b65091c3c 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -106,18 +106,19 @@ func TestTimeStampEnabledConnect(t *testing.T) { // There should be 5 views to read and each of them should // contain the same data. for i := 0; i < 5; i++ { - var buf bytes.Buffer - result, err := c.EP.Read(&buf, len(data), tcpip.ReadOptions{}) + buf := make([]byte, len(data)) + w := tcpip.SliceWriter(buf) + result, err := c.EP.Read(&w, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Unexpected error from Read: %v", err) } if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), + Count: len(buf), + Total: len(buf), }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { t.Errorf("Read: unexpected result (-want +got):\n%s", diff) } - if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { + if got, want := buf, data; bytes.Compare(got, want) != 0 { t.Fatalf("Data is different: got: %v, want: %v", got, want) } } @@ -295,7 +296,7 @@ func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { // Issue a read and we should data. var buf bytes.Buffer - result, err := c.EP.Read(&buf, defaultMTU, tcpip.ReadOptions{}) + result, err := c.EP.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Unexpected error from Read: %v", err) } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 520a0ac9d..9f9b3d510 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -284,7 +284,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { if err := e.LastError(); err != nil { return tcpip.ReadResult{}, err } @@ -340,7 +340,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip res.RemoteAddr = p.senderAddress } - n, err := p.data.ReadTo(dst, count, opts.Peek) + n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { return res, tcpip.ErrBadBuffer } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 52403ed78..4e2123fe9 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -598,12 +598,12 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() var buf bytes.Buffer - res, err := c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true}) + res, err := c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - res, err = c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true}) + res, err = c.ep.Read(&buf, tcpip.ReadOptions{NeedRemoteAddr: true}) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -839,7 +839,7 @@ func TestV4ReadSelfSource(t *testing.T) { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) } - if _, err := c.ep.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err != tt.wantErr { + if _, err := c.ep.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tt.wantErr { t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr) } }) -- cgit v1.2.3