diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 30 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/rcv.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 139 |
3 files changed, 161 insertions, 14 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index fe629aa40..5d42f8045 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -955,10 +955,20 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { } e.rcvBufUsed -= len(v) + + avail := e.receiveBufferAvailableLocked() + + // If the window was small before, lower than MSS, send immediate + // ack. Without this, the sender might be stuck waiting for the + // window to grow, while we think the window is non-zero so we + // don't need to send acks. To avoid silly window syndrome, send + // ack only when the window grows above one MSS. + crossedMSS := avail-len(v) < int(e.amss) && avail >= int(e.amss) + // If the window was zero before this read and if the read freed up // enough buffer space for the scaled window to be non-zero then notify // the protocol goroutine to send a window update. - if e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale) { + if (e.zeroWindow && !e.zeroReceiveWindow(e.rcv.rcvWndScale)) || crossedMSS { e.zeroWindow = false e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) } @@ -1138,11 +1148,8 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro // // It must be called with rcvListMu held. func (e *endpoint) zeroReceiveWindow(scale uint8) bool { - if e.rcvBufUsed >= e.rcvBufSize { - return true - } - - return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 + avail := e.receiveBufferAvailableLocked() + return (avail >> scale) == 0 } // SetSockOptInt sets a socket option. @@ -1181,9 +1188,16 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { size = math.MaxInt32 / 2 } + availBefore := e.receiveBufferAvailableLocked() e.rcvBufSize = size + availAfter := e.receiveBufferAvailableLocked() + e.rcvAutoParams.disabled = true - if e.zeroWindow && !e.zeroReceiveWindow(scale) { + + // Immediatelly send ACK in two cases: when the buffer + // grows so that it leaves zero-window state, when the + // buffer grows from small < MSS to >= MSS. + if (e.zeroWindow && !e.zeroReceiveWindow(scale)) || (availBefore < int(e.amss) && availAfter >= int(e.amss)) { e.zeroWindow = false mask |= notifyNonZeroReceiveWindow } @@ -2229,7 +2243,7 @@ func (e *endpoint) readyToRead(s *segment) { // we set the zero window before we deliver the segment to ensure // that a subsequent read of the segment will correctly trigger // a non-zero notification. - if avail := e.receiveBufferAvailableLocked(); avail>>e.rcv.rcvWndScale == 0 { + if e.zeroReceiveWindow(e.rcv.rcvWndScale) { e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() e.zeroWindow = true } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index 0a5534959..05c8488f8 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -98,12 +98,6 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { // in such cases we may need to send an ack to indicate to our peer that it can // resume sending data. func (r *receiver) nonZeroWindow() { - if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 { - // We never got around to announcing a zero window size, so we - // don't need to immediately announce a nonzero one. - return - } - // Immediately send an ack. r.ep.snd.sendAck() } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index e8fe4dab5..a05365c6a 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6561,3 +6561,142 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %v, want = %v", got, want) } } + +func TestIncreaseWindowOnReceive(t *testing.T) { + // This test ensures that the endpoint sends an ack, + // after recv() when the window grows to more than 1 MSS. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const rcvBuf = 65535 * 10 + c.CreateConnected(789, 30000, rcvBuf) + + // Write chunks of ~30000 bytes. It's important that two + // payloads make it equal or longer than MSS. + remain := rcvBuf + sent := 0 + data := make([]byte, defaultMTU/2) + lastWnd := uint16(0) + + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + + lastWnd = uint16(remain) + if remain > 0xffff { + lastWnd = 0xffff + } + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(lastWnd), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + if lastWnd == 0xffff || lastWnd == 0 { + t.Fatalf("expected small, non-zero window: %d", lastWnd) + } + + // We now have < 1 MSS in the buffer space. Read the data! An + // ack should be sent in response to that. The window was not + // zero, but it grew to larger than MSS. + _, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + _, _, err = c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + // After reading two packets, we surely crossed MSS. See the ack: + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(0xffff)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestIncreaseWindowOnBufferResize(t *testing.T) { + // This test ensures that the endpoint sends an ack, + // after available recv buffer grows to more than 1 MSS. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const rcvBuf = 65535 * 10 + c.CreateConnected(789, 30000, rcvBuf) + + // Write chunks of ~30000 bytes. It's important that two + // payloads make it equal or longer than MSS. + remain := rcvBuf + sent := 0 + data := make([]byte, defaultMTU/2) + lastWnd := uint16(0) + + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + + lastWnd = uint16(remain) + if remain > 0xffff { + lastWnd = 0xffff + } + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(lastWnd), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + if lastWnd == 0xffff || lastWnd == 0 { + t.Fatalf("expected small, non-zero window: %d", lastWnd) + } + + // Increasing the buffer from should generate an ACK, + // since window grew from small value to larger equal MSS + c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2) + + // After reading two packets, we surely crossed MSS. See the ack: + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(0xffff)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} |