summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp
diff options
context:
space:
mode:
authorTamir Duberstein <tamird@google.com>2021-01-15 15:47:13 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-15 15:49:15 -0800
commit12d9790833cc2f6a9b197066a5ecbeb434f74164 (patch)
treee9eec8e4c755c33c5a30c1912422b28380ed1f53 /pkg/tcpip/transport/tcp
parentf37ace6661dfed8acae7e22ed0eb9ad78bdeab34 (diff)
Remove count argument from tcpip.Endpoint.Read
The same intent can be specified via the io.Writer. PiperOrigin-RevId: 352098747
Diffstat (limited to 'pkg/tcpip/transport/tcp')
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go65
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go13
3 files changed, 46 insertions, 38 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a4508e871..ea509ac73 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvReadMu.Lock()
defer e.rcvReadMu.Unlock()
@@ -1346,9 +1346,9 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip
var err error
done := 0
s := first
- for s != nil && done < count {
+ for s != nil {
var n int
- n, err = s.data.ReadTo(dst, count-done, opts.Peek)
+ n, err = s.data.ReadTo(dst, opts.Peek)
// Book keeping first then error handling.
done += n
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 729bf7ef5..93683b921 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -50,7 +50,7 @@ type endpointTester struct {
// CheckReadError issues a read to the endpoint and checking for an error.
func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) {
t.Helper()
- res, got := e.ep.Read(ioutil.Discard, 1, tcpip.ReadOptions{})
+ res, got := e.ep.Read(ioutil.Discard, tcpip.ReadOptions{})
if got != want {
t.Fatalf("ep.Read = %s, want %s", got, want)
}
@@ -61,10 +61,10 @@ func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) {
// CheckRead issues a read to the endpoint and checking for a success, returning
// the data read.
-func (e *endpointTester) CheckRead(t *testing.T, count int) []byte {
+func (e *endpointTester) CheckRead(t *testing.T) []byte {
t.Helper()
var buf bytes.Buffer
- res, err := e.ep.Read(&buf, count, tcpip.ReadOptions{})
+ res, err := e.ep.Read(&buf, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("ep.Read = _, %s; want _, nil", err)
}
@@ -81,9 +81,12 @@ func (e *endpointTester) CheckRead(t *testing.T, count int) []byte {
func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte {
t.Helper()
var buf bytes.Buffer
- var done int
- for done < count {
- res, err := e.ep.Read(&buf, count-done, tcpip.ReadOptions{})
+ w := tcpip.LimitedWriter{
+ W: &buf,
+ N: int64(count),
+ }
+ for w.N != 0 {
+ _, err := e.ep.Read(&w, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
// Wait for receive to be notified.
select {
@@ -95,7 +98,6 @@ func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-cha
} else if err != nil {
t.Fatalf("ep.Read = _, %s; want _, nil", err)
}
- done += res.Count
}
return buf.Bytes()
}
@@ -820,7 +822,7 @@ func TestSimpleReceive(t *testing.T) {
}
// Receive data.
- v := ept.CheckRead(t, defaultMTU)
+ v := ept.CheckRead(t)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -1928,7 +1930,7 @@ func TestFullWindowReceive(t *testing.T) {
)
// Receive data and check it.
- v := ept.CheckRead(t, defaultMTU)
+ v := ept.CheckRead(t)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -2015,7 +2017,7 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
// Read the data so that the subsequent ACK from the endpoint
// grows the right edge of the window.
var buf bytes.Buffer
- if _, err := c.EP.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
+ if _, err := c.EP.Read(&buf, tcpip.ReadOptions{}); err != nil {
t.Fatalf("c.EP.Read: %s", err)
}
@@ -2075,7 +2077,7 @@ func TestNoWindowShrinking(t *testing.T) {
}
// Read the 1 byte payload we just sent.
- if got, want := payload, ept.CheckRead(t, 1); !bytes.Equal(got, want) {
+ if got, want := payload, ept.CheckRead(t); !bytes.Equal(got, want) {
t.Fatalf("got data: %v, want: %v", got, want)
}
@@ -2570,13 +2572,16 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// update to be sent. For 1MSS worth of window to be available we need to
// read at least 128KB. Since our segments above were 50KB each it means
// we need to read at 3 packets.
- sz := 0
- for sz < defaultMTU*2 {
- res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
+ w := tcpip.LimitedWriter{
+ W: ioutil.Discard,
+ N: defaultMTU * 2,
+ }
+ for w.N != 0 {
+ res, err := c.EP.Read(&w, tcpip.ReadOptions{})
+ t.Logf("err=%v res=%#v", err, res)
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- sz += res.Count
}
checker.IPv4(t, c.GetPacket(),
@@ -3271,12 +3276,12 @@ func TestReceiveOnResetConnection(t *testing.T) {
loop:
for {
- switch _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err {
+ switch _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err {
case tcpip.ErrWouldBlock:
select {
case <-ch:
// Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, err := c.EP.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset {
+ if _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset {
t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset)
}
break loop
@@ -4224,7 +4229,7 @@ func TestReadAfterClosedState(t *testing.T) {
// Check that peek works.
var peekBuf bytes.Buffer
- res, err := c.EP.Read(&peekBuf, 10, tcpip.ReadOptions{Peek: true})
+ res, err := c.EP.Read(&peekBuf, tcpip.ReadOptions{Peek: true})
if err != nil {
t.Fatalf("Peek failed: %s", err)
}
@@ -4237,7 +4242,7 @@ func TestReadAfterClosedState(t *testing.T) {
}
// Receive data.
- v := ept.CheckRead(t, defaultMTU)
+ v := ept.CheckRead(t)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -4246,8 +4251,8 @@ func TestReadAfterClosedState(t *testing.T) {
// right error code.
ept.CheckReadError(t, tcpip.ErrClosedForReceive)
var buf bytes.Buffer
- if _, err := c.EP.Read(&buf, 1, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("c.EP.Read(_, _, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive)
+ if _, err := c.EP.Read(&buf, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("c.EP.Read(_, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive)
}
}
@@ -6205,7 +6210,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Now read all the data from the endpoint and verify that advertised
// window increases to the full available buffer size.
for {
- _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
+ _, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
@@ -6329,7 +6334,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// to happen before we measure the new window.
totalCopied := 0
for {
- res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
+ res, err := c.EP.Read(ioutil.Discard, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
@@ -7387,15 +7392,17 @@ func TestIncreaseWindowOnRead(t *testing.T) {
// We now have < 1 MSS in the buffer space. Read at least > 2 MSS
// worth of data as receive buffer space
- read := 0
- // defaultMTU is a good enough estimate for the MSS used for this
- // connection.
- for read < defaultMTU*2 {
- res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
+ w := tcpip.LimitedWriter{
+ W: ioutil.Discard,
+ // defaultMTU is a good enough estimate for the MSS used for this
+ // connection.
+ N: defaultMTU * 2,
+ }
+ for w.N != 0 {
+ _, err := c.EP.Read(&w, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- read += res.Count
}
// After reading > MSS worth of data, we surely crossed MSS. See the ack:
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 88fb054bb..b65091c3c 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -106,18 +106,19 @@ func TestTimeStampEnabledConnect(t *testing.T) {
// There should be 5 views to read and each of them should
// contain the same data.
for i := 0; i < 5; i++ {
- var buf bytes.Buffer
- result, err := c.EP.Read(&buf, len(data), tcpip.ReadOptions{})
+ buf := make([]byte, len(data))
+ w := tcpip.SliceWriter(buf)
+ result, err := c.EP.Read(&w, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
if diff := cmp.Diff(tcpip.ReadResult{
- Count: buf.Len(),
- Total: buf.Len(),
+ Count: len(buf),
+ Total: len(buf),
}, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
}
- if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
+ if got, want := buf, data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
@@ -295,7 +296,7 @@ func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) {
// Issue a read and we should data.
var buf bytes.Buffer
- result, err := c.EP.Read(&buf, defaultMTU, tcpip.ReadOptions{})
+ result, err := c.EP.Read(&buf, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}