summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go30
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go6
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go139
3 files changed, 161 insertions, 14 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index fe629aa40..5d42f8045 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -955,10 +955,20 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
}
e.rcvBufUsed -= len(v)
+
+ avail := e.receiveBufferAvailableLocked()
+
+ // If the window was small before, lower than MSS, send immediate
+ // ack. Without this, the sender might be stuck waiting for the
+ // window to grow, while we think the window is non-zero so we
+ // don't need to send acks. To avoid silly window syndrome, send
+ // ack only when the window grows above one MSS.
+ crossedMSS := avail-len(v) < int(e.amss) && avail >= int(e.amss)
+
// If the window was zero before this read and if the read freed up
// enough buffer space for the scaled window to be non-zero then notify
// the protocol goroutine to send a window update.
- if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) {
+ if (e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale)) || crossedMSS {
e.zeroWindow = false
e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
}
@@ -1138,11 +1148,8 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
//
// It must be called with rcvListMu held.
func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
- if e.rcvBufUsed >= e.rcvBufSize {
- return true
- }
-
- return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
+ avail := e.receiveBufferAvailableLocked()
+ return (avail >> scale) == 0
}
// SetSockOptInt sets a socket option.
@@ -1181,9 +1188,16 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
size = math.MaxInt32 / 2
}
+ availBefore := e.receiveBufferAvailableLocked()
e.rcvBufSize = size
+ availAfter := e.receiveBufferAvailableLocked()
+
e.rcvAutoParams.disabled = true
- if e.zeroWindow && !e.zeroReceiveWindow(scale) {
+
+ // Immediatelly send ACK in two cases: when the buffer
+ // grows so that it leaves zero-window state, when the
+ // buffer grows from small < MSS to >= MSS.
+ if (e.zeroWindow && !e.zeroReceiveWindow(scale)) || (availBefore < int(e.amss) && availAfter >= int(e.amss)) {
e.zeroWindow = false
mask |= notifyNonZeroReceiveWindow
}
@@ -2229,7 +2243,7 @@ func (e *endpoint) readyToRead(s *segment) {
// we set the zero window before we deliver the segment to ensure
// that a subsequent read of the segment will correctly trigger
// a non-zero notification.
- if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 {
+ if e.zeroReceiveWindow(e.rcv.rcvWndScale) {
e.stats.ReceiveErrors.ZeroRcvWindowState.Increment()
e.zeroWindow = true
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
index 0a5534959..05c8488f8 100644
--- a/pkg/tcpip/transport/tcp/rcv.go
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -98,12 +98,6 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// in such cases we may need to send an ack to indicate to our peer that it can
// resume sending data.
func (r *receiver) nonZeroWindow() {
- if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 {
- // We never got around to announcing a zero window size, so we
- // don't need to immediately announce a nonzero one.
- return
- }
-
// Immediately send an ack.
r.ep.snd.sendAck()
}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index e8fe4dab5..a05365c6a 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -6561,3 +6561,142 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want)
}
}
+
+func TestIncreaseWindowOnReceive(t *testing.T) {
+ // This test ensures that the endpoint sends an ack,
+ // after recv() when the window grows to more than 1 MSS.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const rcvBuf = 65535 * 10
+ c.CreateConnected(789, 30000, rcvBuf)
+
+ // Write chunks of ~30000 bytes. It's important that two
+ // payloads make it equal or longer than MSS.
+ remain := rcvBuf
+ sent := 0
+ data := make([]byte, defaultMTU/2)
+ lastWnd := uint16(0)
+
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+
+ lastWnd = uint16(remain)
+ if remain > 0xffff {
+ lastWnd = 0xffff
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(lastWnd),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ if lastWnd == 0xffff || lastWnd == 0 {
+ t.Fatalf("expected small, non-zero window: %d", lastWnd)
+ }
+
+ // We now have < 1 MSS in the buffer space. Read the data! An
+ // ack should be sent in response to that. The window was not
+ // zero, but it grew to larger than MSS.
+ _, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ _, _, err = c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+
+ // After reading two packets, we surely crossed MSS. See the ack:
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(0xffff)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestIncreaseWindowOnBufferResize(t *testing.T) {
+ // This test ensures that the endpoint sends an ack,
+ // after available recv buffer grows to more than 1 MSS.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ const rcvBuf = 65535 * 10
+ c.CreateConnected(789, 30000, rcvBuf)
+
+ // Write chunks of ~30000 bytes. It's important that two
+ // payloads make it equal or longer than MSS.
+ remain := rcvBuf
+ sent := 0
+ data := make([]byte, defaultMTU/2)
+ lastWnd := uint16(0)
+
+ for remain > len(data) {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + sent),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ sent += len(data)
+ remain -= len(data)
+
+ lastWnd = uint16(remain)
+ if remain > 0xffff {
+ lastWnd = 0xffff
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(lastWnd),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ }
+
+ if lastWnd == 0xffff || lastWnd == 0 {
+ t.Fatalf("expected small, non-zero window: %d", lastWnd)
+ }
+
+ // Increasing the buffer from should generate an ACK,
+ // since window grew from small value to larger equal MSS
+ c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
+
+ // After reading two packets, we surely crossed MSS. See the ack:
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(0xffff)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}