summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go205
2 files changed, 206 insertions, 1 deletions
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 9ce8fcae9..90e493978 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -477,7 +477,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) {
// segments. This ensures that we always leave some space for the inorder
// segments to arrive allowing pending segments to be processed and
// delivered to the user.
- if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 {
+ if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && (r.PendingBufUsed+int(segLen)) < int(rcvBufSize)>>2 {
r.ep.rcvQueueInfo.rcvQueueMu.Lock()
r.PendingBufUsed += s.segMemSize()
r.ep.rcvQueueInfo.rcvQueueMu.Unlock()
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 90b74a2a7..bc8708a5b 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -2128,6 +2128,211 @@ func TestFullWindowReceive(t *testing.T) {
)
}
+func TestSmallReceiveBufferReadiness(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol},
+ })
+
+ ep := loopback.New()
+ if testing.Verbose() {
+ ep = sniffer.New(ep)
+ }
+
+ const nicID = 1
+ nicOpts := stack.NICOptions{Name: "nic1"}
+ if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil {
+ t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err)
+ }
+
+ addr := tcpip.AddressWithPrefix{
+ Address: tcpip.Address("\x7f\x00\x00\x01"),
+ PrefixLen: 8,
+ }
+ if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil {
+ t.Fatalf("AddAddressWithPrefix(_, _, %s) failed: %s", addr, err)
+ }
+
+ {
+ subnet, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00")
+ if err != nil {
+ t.Fatalf("tcpip.NewSubnet failed: %s", err)
+ }
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: subnet,
+ NIC: nicID,
+ },
+ })
+ }
+
+ listenerEntry, listenerCh := waiter.NewChannelEntry(nil)
+ var listenerWQ waiter.Queue
+ listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer listener.Close()
+ listenerWQ.EventRegister(&listenerEntry, waiter.ReadableEvents)
+ defer listenerWQ.EventUnregister(&listenerEntry)
+
+ if err := listener.Bind(tcpip.FullAddress{}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ if err := listener.Listen(1); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ localAddress, err := listener.GetLocalAddress()
+ if err != nil {
+ t.Fatalf("GetLocalAddress failed: %s", err)
+ }
+
+ for i := 8; i > 0; i /= 2 {
+ size := int64(i << 10)
+ t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) {
+ var clientWQ waiter.Queue
+ client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer client.Close()
+ switch err := client.Connect(localAddress).(type) {
+ case nil:
+ t.Fatal("Connect returned nil error")
+ case *tcpip.ErrConnectStarted:
+ default:
+ t.Fatalf("Connect failed: %s", err)
+ }
+
+ <-listenerCh
+ server, serverWQ, err := listener.Accept(nil)
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+ defer server.Close()
+
+ client.SocketOptions().SetReceiveBufferSize(size, true)
+ // Send buffer size doesn't seem to affect this test.
+ // server.SocketOptions().SetSendBufferSize(size, true)
+
+ clientEntry, clientCh := waiter.NewChannelEntry(nil)
+ clientWQ.EventRegister(&clientEntry, waiter.ReadableEvents)
+ defer clientWQ.EventUnregister(&clientEntry)
+
+ serverEntry, serverCh := waiter.NewChannelEntry(nil)
+ serverWQ.EventRegister(&serverEntry, waiter.WritableEvents)
+ defer serverWQ.EventUnregister(&serverEntry)
+
+ var total int64
+ for {
+ var b [64 << 10]byte
+ var r bytes.Reader
+ r.Reset(b[:])
+ switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
+ case nil:
+ t.Logf("wrote %d bytes", n)
+ total += n
+ continue
+ case *tcpip.ErrWouldBlock:
+ select {
+ case <-serverCh:
+ continue
+ case <-time.After(100 * time.Millisecond):
+ // Well and truly full.
+ t.Logf("send and receive queues are full")
+ }
+ default:
+ t.Fatalf("Write failed: %s", err)
+ }
+ break
+ }
+ t.Logf("wrote %d bytes in total", total)
+
+ var wg sync.WaitGroup
+ defer wg.Wait()
+
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+
+ var b [64 << 10]byte
+ var r bytes.Reader
+ r.Reset(b[:])
+ if err := func() error {
+ var total int64
+ defer t.Logf("wrote %d bytes in total", total)
+ for r.Len() != 0 {
+ switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) {
+ case nil:
+ t.Logf("wrote %d bytes", n)
+ total += n
+ case *tcpip.ErrWouldBlock:
+ for {
+ t.Logf("waiting on server")
+ select {
+ case <-serverCh:
+ case <-time.After(time.Second):
+ if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 {
+ t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness)
+ }
+ continue
+ }
+ break
+ }
+ default:
+ return fmt.Errorf("server.Write failed: %s", err)
+ }
+ }
+ if err := server.Shutdown(tcpip.ShutdownWrite); err != nil {
+ return fmt.Errorf("server.Shutdown failed: %s", err)
+ }
+ t.Logf("server end shutdown done")
+ return nil
+ }(); err != nil {
+ t.Error(err)
+ }
+ }()
+
+ go func() {
+ defer wg.Done()
+
+ if err := func() error {
+ total := 0
+ defer t.Logf("read %d bytes in total", total)
+ for {
+ switch res, err := client.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) {
+ case nil:
+ t.Logf("read %d bytes", res.Count)
+ total += res.Count
+ t.Logf("read total %d bytes till now", total)
+ case *tcpip.ErrClosedForReceive:
+ return nil
+ case *tcpip.ErrWouldBlock:
+ for {
+ t.Logf("waiting on client")
+ select {
+ case <-clientCh:
+ case <-time.After(time.Second):
+ if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 {
+ return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness)
+ }
+ continue
+ }
+ break
+ }
+ default:
+ return fmt.Errorf("client.Write failed: %s", err)
+ }
+ }
+ }(); err != nil {
+ t.Error(err)
+ }
+ }()
+ })
+ }
+}
+
// Test the stack receive window advertisement on receiving segments smaller than
// segment overhead. It tests for the right edge of the window to not grow when
// the endpoint is not being read from.