diff options
author | Tamir Duberstein <tamird@google.com> | 2021-01-15 15:47:13 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-01-15 15:49:15 -0800 |
commit | 12d9790833cc2f6a9b197066a5ecbeb434f74164 (patch) | |
tree | e9eec8e4c755c33c5a30c1912422b28380ed1f53 /pkg/tcpip/transport/tcp | |
parent | f37ace6661dfed8acae7e22ed0eb9ad78bdeab34 (diff) |
Remove count argument from tcpip.Endpoint.Read
The same intent can be specified via the io.Writer.
PiperOrigin-RevId: 352098747
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 65 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_timestamp_test.go | 13 |
3 files changed, 46 insertions, 38 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a4508e871..ea509ac73 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvReadMu.Lock() defer e.rcvReadMu.Unlock() @@ -1346,9 +1346,9 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip var err error done := 0 s := first - for s != nil && done < count { + for s != nil { var n int - n, err = s.data.ReadTo(dst, count-done, opts.Peek) + n, err = s.data.ReadTo(dst, opts.Peek) // Book keeping first then error handling. done += n diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 729bf7ef5..93683b921 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -50,7 +50,7 @@ type endpointTester struct { // CheckReadError issues a read to the endpoint and checking for an error. func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) { t.Helper() - res, got := e.ep.Read(ioutil.Discard, 1, tcpip.ReadOptions{}) + res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{}) if got != want { t.Fatalf("ep.Read = %s, want %s", got, want) } @@ -61,10 +61,10 @@ func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) { // CheckRead issues a read to the endpoint and checking for a success, returning // the data read. -func (e *endpointTester) CheckRead(t *testing.T, count int) []byte { +func (e *endpointTester) CheckRead(t *testing.T) []byte { t.Helper() var buf bytes.Buffer - res, err := e.ep.Read(&buf, count, tcpip.ReadOptions{}) + res, err := e.ep.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Fatalf("ep.Read = _, %s; want _, nil", err) } @@ -81,9 +81,12 @@ func (e *endpointTester) CheckRead(t *testing.T, count int) []byte { func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte { t.Helper() var buf bytes.Buffer - var done int - for done < count { - res, err := e.ep.Read(&buf, count-done, tcpip.ReadOptions{}) + w := tcpip.LimitedWriter{ + W: &buf, + N: int64(count), + } + for w.N != 0 { + _, err := e.ep.Read(&w, tcpip.ReadOptions{}) if err == tcpip.ErrWouldBlock { // Wait for receive to be notified. select { @@ -95,7 +98,6 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha } else if err != nil { t.Fatalf("ep.Read = _, %s; want _, nil", err) } - done += res.Count } return buf.Bytes() } @@ -820,7 +822,7 @@ func TestSimpleReceive(t *testing.T) { } // Receive data. - v := ept.CheckRead(t, defaultMTU) + v := ept.CheckRead(t) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } @@ -1928,7 +1930,7 @@ func TestFullWindowReceive(t *testing.T) { ) // Receive data and check it. - v := ept.CheckRead(t, defaultMTU) + v := ept.CheckRead(t) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } @@ -2015,7 +2017,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { // Read the data so that the subsequent ACK from the endpoint // grows the right edge of the window. var buf bytes.Buffer - if _, err := c.EP.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}); err != nil { + if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil { t.Fatalf("c.EP.Read: %s", err) } @@ -2075,7 +2077,7 @@ func TestNoWindowShrinking(t *testing.T) { } // Read the 1 byte payload we just sent. - if got, want := payload, ept.CheckRead(t, 1); !bytes.Equal(got, want) { + if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) { t.Fatalf("got data: %v, want: %v", got, want) } @@ -2570,13 +2572,16 @@ func TestZeroScaledWindowReceive(t *testing.T) { // update to be sent. For 1MSS worth of window to be available we need to // read at least 128KB. Since our segments above were 50KB each it means // we need to read at 3 packets. - sz := 0 - for sz < defaultMTU*2 { - res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) + w := tcpip.LimitedWriter{ + W: ioutil.Discard, + N: defaultMTU * 2, + } + for w.N != 0 { + res, err := c.EP.Read(&w, tcpip.ReadOptions{}) + t.Logf("err=%v res=%#v", err, res) if err != nil { t.Fatalf("Read failed: %s", err) } - sz += res.Count } checker.IPv4(t, c.GetPacket(), @@ -3271,12 +3276,12 @@ func TestReceiveOnResetConnection(t *testing.T) { loop: for { - switch _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err { + switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err { case tcpip.ErrWouldBlock: select { case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, err := c.EP.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset { + if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset { t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset) } break loop @@ -4224,7 +4229,7 @@ func TestReadAfterClosedState(t *testing.T) { // Check that peek works. var peekBuf bytes.Buffer - res, err := c.EP.Read(&peekBuf, 10, tcpip.ReadOptions{Peek: true}) + res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true}) if err != nil { t.Fatalf("Peek failed: %s", err) } @@ -4237,7 +4242,7 @@ func TestReadAfterClosedState(t *testing.T) { } // Receive data. - v := ept.CheckRead(t, defaultMTU) + v := ept.CheckRead(t) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } @@ -4246,8 +4251,8 @@ func TestReadAfterClosedState(t *testing.T) { // right error code. ept.CheckReadError(t, tcpip.ErrClosedForReceive) var buf bytes.Buffer - if _, err := c.EP.Read(&buf, 1, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive { - t.Fatalf("c.EP.Read(_, _, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive) + if _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive { + t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive) } } @@ -6205,7 +6210,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Now read all the data from the endpoint and verify that advertised // window increases to the full available buffer size. for { - _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) + _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) if err == tcpip.ErrWouldBlock { break } @@ -6329,7 +6334,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // to happen before we measure the new window. totalCopied := 0 for { - res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) + res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}) if err == tcpip.ErrWouldBlock { break } @@ -7387,15 +7392,17 @@ func TestIncreaseWindowOnRead(t *testing.T) { // We now have < 1 MSS in the buffer space. Read at least > 2 MSS // worth of data as receive buffer space - read := 0 - // defaultMTU is a good enough estimate for the MSS used for this - // connection. - for read < defaultMTU*2 { - res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) + w := tcpip.LimitedWriter{ + W: ioutil.Discard, + // defaultMTU is a good enough estimate for the MSS used for this + // connection. + N: defaultMTU * 2, + } + for w.N != 0 { + _, err := c.EP.Read(&w, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Read failed: %s", err) } - read += res.Count } // After reading > MSS worth of data, we surely crossed MSS. See the ack: diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 88fb054bb..b65091c3c 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -106,18 +106,19 @@ func TestTimeStampEnabledConnect(t *testing.T) { // There should be 5 views to read and each of them should // contain the same data. for i := 0; i < 5; i++ { - var buf bytes.Buffer - result, err := c.EP.Read(&buf, len(data), tcpip.ReadOptions{}) + buf := make([]byte, len(data)) + w := tcpip.SliceWriter(buf) + result, err := c.EP.Read(&w, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Unexpected error from Read: %v", err) } if diff := cmp.Diff(tcpip.ReadResult{ - Count: buf.Len(), - Total: buf.Len(), + Count: len(buf), + Total: len(buf), }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { t.Errorf("Read: unexpected result (-want +got):\n%s", diff) } - if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { + if got, want := buf, data; bytes.Compare(got, want) != 0 { t.Fatalf("Data is different: got: %v, want: %v", got, want) } } @@ -295,7 +296,7 @@ func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { // Issue a read and we should data. var buf bytes.Buffer - result, err := c.EP.Read(&buf, defaultMTU, tcpip.ReadOptions{}) + result, err := c.EP.Read(&buf, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Unexpected error from Read: %v", err) } |