diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 52 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_test.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/testing/context/context.go | 3 |
4 files changed, 39 insertions, 22 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index f543a6105..74df3edfb 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -298,8 +298,6 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head return nil, err } ep.mu.Lock() - ep.stack.Stats().TCP.CurrentEstablished.Increment() - ep.state = StateEstablished ep.isConnectNotified = true ep.mu.Unlock() @@ -546,6 +544,8 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. + // We do not use transitionToStateEstablishedLocked here as there is + // no handshake state available when doing a SYN cookie based accept. n.stack.Stats().TCP.CurrentEstablished.Increment() n.state = StateEstablished n.isConnectNotified = true diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 16f8aea12..2975a1c3c 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -252,6 +252,11 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // and the handshake is completed. if s.flagIsSet(header.TCPFlagAck) { h.state = handshakeCompleted + + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) return nil } @@ -352,6 +357,10 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) } h.state = handshakeCompleted + h.ep.mu.Lock() + h.ep.transitionToStateEstablishedLocked(h) + h.ep.mu.Unlock() + return nil } @@ -880,6 +889,30 @@ func (e *endpoint) completeWorkerLocked() { } } +// transitionToStateEstablisedLocked transitions a given endpoint +// to an established state using the handshake parameters provided. +// It also initializes sender/receiver if required. +func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) { + if e.snd == nil { + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + } + if e.rcv == nil { + rcvBufSize := seqnum.Size(e.receiveBufferSize()) + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) + // Bootstrap the auto tuning algorithm. Starting at zero will + // result in a really large receive window after the first auto + // tuning adjustment. + e.rcvAutoParams.prevCopied = int(h.rcvWnd) + e.rcvListMu.Unlock() + } + h.ep.stack.Stats().TCP.CurrentEstablished.Increment() + e.state = StateEstablished +} + // transitionToStateCloseLocked ensures that the endpoint is // cleaned up from the transport demuxer, "before" moving to // StateClose. This will ensure that no packet will be @@ -1156,25 +1189,6 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } - - // Transfer handshake state to TCP connection. We disable - // receive window scaling if the peer doesn't support it - // (indicated by a negative send window scale). - e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) - - rcvBufSize := seqnum.Size(e.receiveBufferSize()) - e.rcvListMu.Lock() - e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize) - // boot strap the auto tuning algorithm. Starting at zero will - // result in a large step function on the first proper causing - // the window to just go to a really large value after the first - // RTT itself. - e.rcvAutoParams.prevCopied = initialRcvWnd - e.rcvListMu.Unlock() - e.stack.Stats().TCP.CurrentEstablished.Increment() - e.mu.Lock() - e.state = StateEstablished - e.mu.Unlock() } e.keepalive.timer.init(&e.keepalive.waker) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index d1f0d6ce7..52c2fa7e3 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -541,7 +541,7 @@ func TestClosingWithEnqueuedSegments(t *testing.T) { ep.(interface{ ResumeWork() }).ResumeWork() // Wait for the protocolMainLoop to resume and update state. - time.Sleep(1 * time.Millisecond) + time.Sleep(10 * time.Millisecond) // Expect the endpoint to be closed. if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6cb66c1af..b0a376eba 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -231,6 +231,7 @@ func (c *Context) CheckNoPacket(errMsg string) { // addresses. It will fail with an error if no packet is received for // 2 seconds. func (c *Context) GetPacket() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { @@ -259,6 +260,7 @@ func (c *Context) GetPacket() []byte { // and destination address. If no packet is available it will return // nil immediately. func (c *Context) GetPacketNonBlocking() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv4.ProtocolNumber { @@ -483,6 +485,7 @@ func (c *Context) CreateV6Endpoint(v6only bool) { // GetV6Packet reads a single packet from the link layer endpoint of the context // and asserts that it is an IPv6 Packet with the expected src/dest addresses. func (c *Context) GetV6Packet() []byte { + c.t.Helper() select { case p := <-c.linkEP.C: if p.Proto != ipv6.ProtocolNumber { |