diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 15 |
2 files changed, 22 insertions, 3 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 91ee3b0be..9d4dce826 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -1516,6 +1516,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ // Main loop. Handle segments until both send and receive ends of the // connection have completed. cleanupOnError := func(err *tcpip.Error) { + e.stack.Stats().TCP.CurrentConnected.Decrement() e.workerCleanup = true if err != nil { e.resetConnectionLocked(err) @@ -1568,11 +1569,14 @@ loop: reuseTW = e.doTimeWait() } - // Mark endpoint as closed. - if e.EndpointState() != StateError { - e.transitionToStateCloseLocked() + // Handle any StateError transition from StateTimeWait. + if e.EndpointState() == StateError { + cleanupOnError(nil) + return nil } + e.transitionToStateCloseLocked() + // Lock released below. epilogue() diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index c6ffa7a9d..0668cedc9 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -2967,6 +2967,9 @@ loop: if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } } func TestSendOnResetConnection(t *testing.T) { @@ -3050,6 +3053,9 @@ func TestMaxRetransmitsTimeout(t *testing.T) { if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } } // TestMaxRTO tests if the retransmit interval caps to MaxRTO. @@ -4754,6 +4760,9 @@ func TestKeepalive(t *testing.T) { if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { @@ -6771,6 +6780,9 @@ func TestTCPUserTimeout(t *testing.T) { if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } } func TestKeepaliveWithUserTimeout(t *testing.T) { @@ -6842,6 +6854,9 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } } func TestIncreaseWindowOnReceive(t *testing.T) { |