diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 205 |
2 files changed, 206 insertions, 1 deletions
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 9ce8fcae9..90e493978 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -477,7 +477,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err tcpip.Error) { // segments. This ensures that we always leave some space for the inorder // segments to arrive allowing pending segments to be processed and // delivered to the user. - if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && r.PendingBufUsed < int(rcvBufSize)>>2 { + if rcvBufSize := r.ep.ops.GetReceiveBufferSize(); rcvBufSize > 0 && (r.PendingBufUsed+int(segLen)) < int(rcvBufSize)>>2 { r.ep.rcvQueueInfo.rcvQueueMu.Lock() r.PendingBufUsed += s.segMemSize() r.ep.rcvQueueInfo.rcvQueueMu.Unlock() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 90b74a2a7..bc8708a5b 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2128,6 +2128,211 @@ func TestFullWindowReceive(t *testing.T) { ) } +func TestSmallReceiveBufferReadiness(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + }) + + ep := loopback.New() + if testing.Verbose() { + ep = sniffer.New(ep) + } + + const nicID = 1 + nicOpts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(nicID, ep, nicOpts); err != nil { + t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %s", nicOpts, err) + } + + addr := tcpip.AddressWithPrefix{ + Address: tcpip.Address("\x7f\x00\x00\x01"), + PrefixLen: 8, + } + if err := s.AddAddressWithPrefix(nicID, ipv4.ProtocolNumber, addr); err != nil { + t.Fatalf("AddAddressWithPrefix(_, _, %s) failed: %s", addr, err) + } + + { + subnet, err := tcpip.NewSubnet("\x7f\x00\x00\x00", "\xff\x00\x00\x00") + if err != nil { + t.Fatalf("tcpip.NewSubnet failed: %s", err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: subnet, + NIC: nicID, + }, + }) + } + + listenerEntry, listenerCh := waiter.NewChannelEntry(nil) + var listenerWQ waiter.Queue + listener, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &listenerWQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer listener.Close() + listenerWQ.EventRegister(&listenerEntry, waiter.ReadableEvents) + defer listenerWQ.EventUnregister(&listenerEntry) + + if err := listener.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := listener.Listen(1); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + localAddress, err := listener.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress failed: %s", err) + } + + for i := 8; i > 0; i /= 2 { + size := int64(i << 10) + t.Run(fmt.Sprintf("size=%d", size), func(t *testing.T) { + var clientWQ waiter.Queue + client, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &clientWQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer client.Close() + switch err := client.Connect(localAddress).(type) { + case nil: + t.Fatal("Connect returned nil error") + case *tcpip.ErrConnectStarted: + default: + t.Fatalf("Connect failed: %s", err) + } + + <-listenerCh + server, serverWQ, err := listener.Accept(nil) + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + defer server.Close() + + client.SocketOptions().SetReceiveBufferSize(size, true) + // Send buffer size doesn't seem to affect this test. + // server.SocketOptions().SetSendBufferSize(size, true) + + clientEntry, clientCh := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientEntry, waiter.ReadableEvents) + defer clientWQ.EventUnregister(&clientEntry) + + serverEntry, serverCh := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverEntry, waiter.WritableEvents) + defer serverWQ.EventUnregister(&serverEntry) + + var total int64 + for { + var b [64 << 10]byte + var r bytes.Reader + r.Reset(b[:]) + switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) { + case nil: + t.Logf("wrote %d bytes", n) + total += n + continue + case *tcpip.ErrWouldBlock: + select { + case <-serverCh: + continue + case <-time.After(100 * time.Millisecond): + // Well and truly full. + t.Logf("send and receive queues are full") + } + default: + t.Fatalf("Write failed: %s", err) + } + break + } + t.Logf("wrote %d bytes in total", total) + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(2) + go func() { + defer wg.Done() + + var b [64 << 10]byte + var r bytes.Reader + r.Reset(b[:]) + if err := func() error { + var total int64 + defer t.Logf("wrote %d bytes in total", total) + for r.Len() != 0 { + switch n, err := server.Write(&r, tcpip.WriteOptions{}); err.(type) { + case nil: + t.Logf("wrote %d bytes", n) + total += n + case *tcpip.ErrWouldBlock: + for { + t.Logf("waiting on server") + select { + case <-serverCh: + case <-time.After(time.Second): + if readiness := server.Readiness(waiter.WritableEvents); readiness != 0 { + t.Logf("server.Readiness(%b) = %b but channel not signaled", waiter.WritableEvents, readiness) + } + continue + } + break + } + default: + return fmt.Errorf("server.Write failed: %s", err) + } + } + if err := server.Shutdown(tcpip.ShutdownWrite); err != nil { + return fmt.Errorf("server.Shutdown failed: %s", err) + } + t.Logf("server end shutdown done") + return nil + }(); err != nil { + t.Error(err) + } + }() + + go func() { + defer wg.Done() + + if err := func() error { + total := 0 + defer t.Logf("read %d bytes in total", total) + for { + switch res, err := client.Read(ioutil.Discard, tcpip.ReadOptions{}); err.(type) { + case nil: + t.Logf("read %d bytes", res.Count) + total += res.Count + t.Logf("read total %d bytes till now", total) + case *tcpip.ErrClosedForReceive: + return nil + case *tcpip.ErrWouldBlock: + for { + t.Logf("waiting on client") + select { + case <-clientCh: + case <-time.After(time.Second): + if readiness := client.Readiness(waiter.ReadableEvents); readiness != 0 { + return fmt.Errorf("client.Readiness(%b) = %b but channel not signaled", waiter.ReadableEvents, readiness) + } + continue + } + break + } + default: + return fmt.Errorf("client.Write failed: %s", err) + } + } + }(); err != nil { + t.Error(err) + } + }() + }) + } +} + // Test the stack receive window advertisement on receiving segments smaller than // segment overhead. It tests for the right edge of the window to not grow when // the endpoint is not being read from. |