From 840a133c64dff95b25bf3ad6019cb5bd16f0999b Mon Sep 17 00:00:00 2001
From: Dean Deng <deandeng@google.com>
Date: Mon, 16 Nov 2020 10:41:20 -0800
Subject: Automated rollback of changelist 340274194

PiperOrigin-RevId: 342669574
---
 pkg/tcpip/transport/tcp/accept.go         | 143 ++++++++++++++++++++----------
 pkg/tcpip/transport/tcp/connect.go        | 115 ++++++++++++++----------
 pkg/tcpip/transport/tcp/endpoint.go       |  71 ++++++++++++++-
 pkg/tcpip/transport/tcp/endpoint_state.go |   1 +
 pkg/tcpip/transport/tcp/forwarder.go      |   2 +-
 pkg/tcpip/transport/tcp/tcp_test.go       |  44 +++++++++
 pkg/tcpip/transport/tcp/timer.go          |   4 +
 7 files changed, 284 insertions(+), 96 deletions(-)

(limited to 'pkg/tcpip')

diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 47982ca41..6e5adc383 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -235,11 +235,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
 	return n, nil
 }
 
-// createEndpointAndPerformHandshake creates a new endpoint in connected state
-// and then performs the TCP 3-way handshake.
+// startHandshake creates a new endpoint in connecting state and then sends
+// the SYN-ACK for the TCP 3-way handshake. It returns the state of the
+// handshake in progress, which includes the new endpoint in the SYN-RCVD
+// state.
 //
-// The new endpoint is returned with e.mu held.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+// On success, a handshake h is returned with h.ep.mu held.
+//
+// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
+func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) {
 	// Create new endpoint.
 	irs := s.sequenceNumber
 	isn := generateSecureISN(s.id, l.stack.Seed())
@@ -257,10 +261,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
 	// listenEP is nil when listenContext is used by tcp.Forwarder.
 	deferAccept := time.Duration(0)
 	if l.listenEP != nil {
-		l.listenEP.mu.Lock()
 		if l.listenEP.EndpointState() != StateListen {
 
-			l.listenEP.mu.Unlock()
 			// Ensure we release any registrations done by the newly
 			// created endpoint.
 			ep.mu.Unlock()
@@ -278,16 +280,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
 			ep.mu.Unlock()
 			ep.Close()
 
-			if l.listenEP != nil {
-				l.removePendingEndpoint(ep)
-				l.listenEP.mu.Unlock()
-			}
+			l.removePendingEndpoint(ep)
 
 			return nil, tcpip.ErrConnectionAborted
 		}
 
 		deferAccept = l.listenEP.deferAccept
-		l.listenEP.mu.Unlock()
 	}
 
 	// Register new endpoint so that packets are routed to it.
@@ -306,28 +304,33 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
 
 	ep.isRegistered = true
 
-	// Perform the 3-way handshake.
-	h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
-	if err := h.execute(); err != nil {
-		ep.mu.Unlock()
-		ep.Close()
-		ep.notifyAborted()
-
-		if l.listenEP != nil {
-			l.removePendingEndpoint(ep)
-		}
-
-		ep.drainClosingSegmentQueue()
-
+	// Initialize and start the handshake.
+	h := ep.newPassiveHandshake(isn, irs, opts, deferAccept)
+	if err := h.start(); err != nil {
+		l.cleanupFailedHandshake(h)
 		return nil, err
 	}
-	ep.isConnectNotified = true
+	return h, nil
+}
 
-	// Update the receive window scaling. We can't do it before the
-	// handshake because it's possible that the peer doesn't support window
-	// scaling.
-	ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
+// performHandshake performs a TCP 3-way handshake. On success, the new
+// established endpoint is returned with e.mu held.
+//
+// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
+func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+	h, err := l.startHandshake(s, opts, queue, owner)
+	if err != nil {
+		return nil, err
+	}
+	ep := h.ep
 
+	if err := h.complete(); err != nil {
+		ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+		ep.stats.FailedConnectionAttempts.Increment()
+		l.cleanupFailedHandshake(h)
+		return nil, err
+	}
+	l.cleanupCompletedHandshake(h)
 	return ep, nil
 }
 
@@ -354,6 +357,39 @@ func (l *listenContext) closeAllPendingEndpoints() {
 	l.pending.Wait()
 }
 
+// Precondition: h.ep.mu must be held.
+func (l *listenContext) cleanupFailedHandshake(h *handshake) {
+	e := h.ep
+	e.mu.Unlock()
+	e.Close()
+	e.notifyAborted()
+	if l.listenEP != nil {
+		l.removePendingEndpoint(e)
+	}
+	e.drainClosingSegmentQueue()
+	e.h = nil
+}
+
+// cleanupCompletedHandshake transfers any state from the completed handshake to
+// the new endpoint.
+//
+// Precondition: h.ep.mu must be held.
+func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
+	e := h.ep
+	if l.listenEP != nil {
+		l.removePendingEndpoint(e)
+	}
+	e.isConnectNotified = true
+
+	// Update the receive window scaling. We can't do it before the
+	// handshake because it's possible that the peer doesn't support window
+	// scaling.
+	e.rcv.rcvWndScale = e.h.effectiveRcvWndScale()
+
+	// Clean up handshake state stored in the endpoint so that it can be GCed.
+	e.h = nil
+}
+
 // deliverAccepted delivers the newly-accepted endpoint to the listener. If the
 // endpoint has transitioned out of the listen state (acceptedChan is nil),
 // the new endpoint is closed instead.
@@ -433,23 +469,40 @@ func (e *endpoint) notifyAborted() {
 //
 // A limited number of these goroutines are allowed before TCP starts using SYN
 // cookies to accept connections.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
-	defer ctx.synRcvdCount.dec()
+//
+// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error {
 	defer s.decRef()
 
-	n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
+	h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
 	if err != nil {
 		e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
 		e.stats.FailedConnectionAttempts.Increment()
-		e.decSynRcvdCount()
-		return
+		e.synRcvdCount--
+		return err
 	}
-	ctx.removePendingEndpoint(n)
-	e.decSynRcvdCount()
-	n.startAcceptedLoop()
-	e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
 
-	e.deliverAccepted(n)
+	go func() {
+		defer ctx.synRcvdCount.dec()
+		if err := h.complete(); err != nil {
+			e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+			e.stats.FailedConnectionAttempts.Increment()
+			ctx.cleanupFailedHandshake(h)
+			e.mu.Lock()
+			e.synRcvdCount--
+			e.mu.Unlock()
+			return
+		}
+		ctx.cleanupCompletedHandshake(h)
+		e.mu.Lock()
+		e.synRcvdCount--
+		e.mu.Unlock()
+		h.ep.startAcceptedLoop()
+		e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+		e.deliverAccepted(h.ep)
+	}() // S/R-SAFE: synRcvdCount is the barrier.
+
+	return nil
 }
 
 func (e *endpoint) incSynRcvdCount() bool {
@@ -462,12 +515,6 @@ func (e *endpoint) incSynRcvdCount() bool {
 	return canInc
 }
 
-func (e *endpoint) decSynRcvdCount() {
-	e.mu.Lock()
-	e.synRcvdCount--
-	e.mu.Unlock()
-}
-
 func (e *endpoint) acceptQueueIsFull() bool {
 	e.acceptMu.Lock()
 	full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
@@ -477,6 +524,8 @@ func (e *endpoint) acceptQueueIsFull() bool {
 
 // handleListenSegment is called when a listening endpoint receives a segment
 // and needs to handle it.
+//
+// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
 func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error {
 	e.rcvListMu.Lock()
 	rcvClosed := e.rcvClosed
@@ -500,7 +549,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
 			//     backlog.
 			if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
 				s.incRef()
-				go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
+				_ = e.handleSynSegment(ctx, s, &opts)
 				return nil
 			}
 			ctx.synRcvdCount.dec()
@@ -712,7 +761,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
 		// to the endpoint.
 		e.setEndpointState(StateClose)
 
-		// close any endpoints in SYN-RCVD state.
+		// Close any endpoints in SYN-RCVD state.
 		ctx.closeAllPendingEndpoints()
 
 		// Do cleanup if needed.
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 6e9015be1..ac6d879a7 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -102,21 +102,26 @@ type handshake struct {
 	// been received. This is required to stop retransmitting the
 	// original SYN-ACK when deferAccept is enabled.
 	acked bool
+
+	// sendSYNOpts is the cached values for the SYN options to be sent.
+	sendSYNOpts header.TCPSynOptions
 }
 
-func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
-	h := handshake{
-		ep:          ep,
+func (e *endpoint) newHandshake() *handshake {
+	h := &handshake{
+		ep:          e,
 		active:      true,
-		rcvWnd:      rcvWnd,
-		rcvWndScale: ep.rcvWndScaleForHandshake(),
+		rcvWnd:      seqnum.Size(e.initialReceiveWindow()),
+		rcvWndScale: e.rcvWndScaleForHandshake(),
 	}
 	h.resetState()
+	// Store reference to handshake state in endpoint.
+	e.h = h
 	return h
 }
 
-func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
-	h := newHandshake(ep, rcvWnd)
+func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake {
+	h := e.newHandshake()
 	h.resetToSynRcvd(isn, irs, opts, deferAccept)
 	return h
 }
@@ -502,8 +507,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
 	}
 }
 
-// execute executes the TCP 3-way handshake.
-func (h *handshake) execute() *tcpip.Error {
+// start resolves the route if necessary and sends the first
+// SYN/SYN-ACK.
+func (h *handshake) start() *tcpip.Error {
 	if h.ep.route.IsResolutionRequired() {
 		if err := h.resolveRoute(); err != nil {
 			return err
@@ -511,19 +517,7 @@ func (h *handshake) execute() *tcpip.Error {
 	}
 
 	h.startTime = time.Now()
-	// Initialize the resend timer.
-	resendWaker := sleep.Waker{}
-	timeOut := time.Duration(time.Second)
-	rt := time.AfterFunc(timeOut, resendWaker.Assert)
-	defer rt.Stop()
-
-	// Set up the wakers.
-	s := sleep.Sleeper{}
-	s.AddWaker(&resendWaker, wakerForResend)
-	s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
-	s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
-	defer s.Done()
-
+	h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
 	var sackEnabled tcpip.TCPSACKEnabled
 	if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
 		// If stack returned an error when checking for SACKEnabled
@@ -531,10 +525,6 @@ func (h *handshake) execute() *tcpip.Error {
 		sackEnabled = false
 	}
 
-	// Send the initial SYN segment and loop until the handshake is
-	// completed.
-	h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
-
 	synOpts := header.TCPSynOptions{
 		WS:            h.rcvWndScale,
 		TS:            true,
@@ -544,9 +534,8 @@ func (h *handshake) execute() *tcpip.Error {
 		MSS:           h.ep.amss,
 	}
 
-	// Execute is also called in a listen context so we want to make sure we
-	// only send the TS/SACK option when we received the TS/SACK in the
-	// initial SYN.
+	// start() is also called in a listen context so we want to make sure we only
+	// send the TS/SACK option when we received the TS/SACK in the initial SYN.
 	if h.state == handshakeSynRcvd {
 		synOpts.TS = h.ep.sendTSOk
 		synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
@@ -557,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error {
 		}
 	}
 
+	h.sendSYNOpts = synOpts
 	h.ep.sendSynTCP(&h.ep.route, tcpFields{
 		id:     h.ep.ID,
 		ttl:    h.ep.ttl,
@@ -566,6 +556,25 @@ func (h *handshake) execute() *tcpip.Error {
 		ack:    h.ackNum,
 		rcvWnd: h.rcvWnd,
 	}, synOpts)
+	return nil
+}
+
+// complete completes the TCP 3-way handshake initiated by h.start().
+func (h *handshake) complete() *tcpip.Error {
+	// Set up the wakers.
+	s := sleep.Sleeper{}
+	resendWaker := sleep.Waker{}
+	s.AddWaker(&resendWaker, wakerForResend)
+	s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+	s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+	defer s.Done()
+
+	// Initialize the resend timer.
+	timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert)
+	if err != nil {
+		return err
+	}
+	defer timer.stop()
 
 	for h.state != handshakeCompleted {
 		// Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
@@ -576,11 +585,9 @@ func (h *handshake) execute() *tcpip.Error {
 		switch index {
 
 		case wakerForResend:
-			timeOut *= 2
-			if timeOut > MaxRTO {
-				return tcpip.ErrTimeout
+			if err := timer.reset(); err != nil {
+				return err
 			}
-			rt.Reset(timeOut)
 			// Resend the SYN/SYN-ACK only if the following conditions hold.
 			//  - It's an active handshake (deferAccept does not apply)
 			//  - It's a passive handshake and we have not yet got the final-ACK.
@@ -598,7 +605,7 @@ func (h *handshake) execute() *tcpip.Error {
 					seq:    h.iss,
 					ack:    h.ackNum,
 					rcvWnd: h.rcvWnd,
-				}, synOpts)
+				}, h.sendSYNOpts)
 			}
 
 		case wakerForNotification:
@@ -637,6 +644,34 @@ func (h *handshake) execute() *tcpip.Error {
 	return nil
 }
 
+type backoffTimer struct {
+	timeout    time.Duration
+	maxTimeout time.Duration
+	t          *time.Timer
+}
+
+func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) {
+	if timeout > maxTimeout {
+		return nil, tcpip.ErrTimeout
+	}
+	bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout}
+	bt.t = time.AfterFunc(timeout, f)
+	return bt, nil
+}
+
+func (bt *backoffTimer) reset() *tcpip.Error {
+	bt.timeout *= 2
+	if bt.timeout > MaxRTO {
+		return tcpip.ErrTimeout
+	}
+	bt.t.Reset(bt.timeout)
+	return nil
+}
+
+func (bt *backoffTimer) stop() {
+	bt.t.Stop()
+}
+
 func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
 	synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
 	if synOpts.TS {
@@ -1342,14 +1377,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
 	}
 
 	if handshake {
-		// This is an active connection, so we must initiate the 3-way
-		// handshake, and then inform potential waiters about its
-		// completion.
-		initialRcvWnd := e.initialReceiveWindow()
-		h := newHandshake(e, seqnum.Size(initialRcvWnd))
-		h.ep.setEndpointState(StateSynSent)
-
-		if err := h.execute(); err != nil {
+		if err := e.h.complete(); err != nil {
 			e.lastErrorMu.Lock()
 			e.lastError = err
 			e.lastErrorMu.Unlock()
@@ -1364,9 +1392,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
 		}
 	}
 
-	e.keepalive.timer.init(&e.keepalive.waker)
-	defer e.keepalive.timer.cleanup()
-
 	drained := e.drainDone != nil
 	if drained {
 		close(e.drainDone)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index e78138415..4f4f4c65e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -441,6 +441,11 @@ type endpoint struct {
 	v6only            bool
 	isConnectNotified bool
 
+	// h stores a reference to the current handshake state if the endpoint is in
+	// the SYN-SENT or SYN-RECV states, in which case endpoint == endpoint.h.ep.
+	// nil otherwise.
+	h *handshake `state:"nosave"`
+
 	// portFlags stores the current values of port related flags.
 	portFlags ports.Flags
 
@@ -922,6 +927,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
 	e.segmentQueue.ep = e
 	e.tsOffset = timeStampOffset()
 	e.acceptCond = sync.NewCond(&e.acceptMu)
+	e.keepalive.timer.init(&e.keepalive.waker)
 
 	return e
 }
@@ -1146,6 +1152,7 @@ func (e *endpoint) cleanupLocked() {
 	// Close all endpoints that might have been accepted by TCP but not by
 	// the client.
 	e.closePendingAcceptableConnectionsLocked()
+	e.keepalive.timer.cleanup()
 
 	e.workerCleanup = false
 
@@ -2175,6 +2182,8 @@ func (*endpoint) Disconnect() *tcpip.Error {
 func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
 	err := e.connect(addr, true, true)
 	if err != nil && !err.IgnoreStats() {
+		// Connect failed. Let's wake up any waiters.
+		e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
 		e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
 		e.stats.FailedConnectionAttempts.Increment()
 	}
@@ -2387,14 +2396,70 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
 	}
 
 	if run {
-		e.workerRunning = true
-		e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
-		go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+		if err := e.startMainLoop(handshake); err != nil {
+			return err
+		}
 	}
 
 	return tcpip.ErrConnectStarted
 }
 
+// startMainLoop sends the initial SYN and starts the main loop for the
+// endpoint.
+func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error {
+	preloop := func() *tcpip.Error {
+		if handshake {
+			h := e.newHandshake()
+			e.setEndpointState(StateSynSent)
+			if err := h.start(); err != nil {
+				e.lastErrorMu.Lock()
+				e.lastError = err
+				e.lastErrorMu.Unlock()
+
+				e.setEndpointState(StateError)
+				e.HardError = err
+
+				// Call cleanupLocked to free up any reservations.
+				e.cleanupLocked()
+				return err
+			}
+		}
+		e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+		return nil
+	}
+
+	if e.route.IsResolutionRequired() {
+		// If the endpoint is closed between releasing e.mu and the goroutine below
+		// acquiring it, make sure that cleanup is deferred to the new goroutine.
+		e.workerRunning = true
+
+		// Sending the initial SYN may block due to route resolution; do it in a
+		// separate goroutine to avoid blocking the syscall goroutine.
+		go func() { // S/R-SAFE: will be drained before save.
+			e.mu.Lock()
+			if err := preloop(); err != nil {
+				e.workerRunning = false
+				e.mu.Unlock()
+				return
+			}
+			e.mu.Unlock()
+			_ = e.protocolMainLoop(handshake, nil)
+		}()
+		return nil
+	}
+
+	// No route resolution is required, so we can send the initial SYN here without
+	// blocking. This will hopefully reduce overall latency by overlapping time
+	// spent waiting for a SYN-ACK and time spent spinning up a new goroutine
+	// for the main loop.
+	if err := preloop(); err != nil {
+		return err
+	}
+	e.workerRunning = true
+	go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+	return nil
+}
+
 // ConnectEndpoint is not supported.
 func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
 	return tcpip.ErrInvalidEndpointState
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 2bcc5e1c2..bb901c0f8 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -172,6 +172,7 @@ func (e *endpoint) afterLoad() {
 	// Condition variables and mutexs are not S/R'ed so reinitialize
 	// acceptCond with e.acceptMu.
 	e.acceptCond = sync.NewCond(&e.acceptMu)
+	e.keepalive.timer.init(&e.keepalive.waker)
 	stack.StackFromEnv.RegisterRestoredEndpoint(e)
 }
 
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 0664789da..596178625 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
 	}
 
 	f := r.forwarder
-	ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
+	ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{
 		MSS:           r.synOptions.MSS,
 		WS:            r.synOptions.WS,
 		TS:            r.synOptions.TS,
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index fcc3c5000..9f0fb41e3 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5717,6 +5717,50 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
 	}
 }
 
+func TestSYNRetransmit(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Create TCP endpoint.
+	var err *tcpip.Error
+	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+
+	// Bind to wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+
+	// Start listening.
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+
+	// Send the same SYN packet multiple times. We should still get a valid SYN-ACK
+	// reply.
+	irs := seqnum.Value(789)
+	for i := 0; i < 5; i++ {
+		c.SendPacket(nil, &context.Headers{
+			SrcPort: context.TestPort,
+			DstPort: context.StackPort,
+			Flags:   header.TCPFlagSyn,
+			SeqNum:  irs,
+			RcvWnd:  30000,
+		})
+	}
+
+	// Receive the SYN-ACK reply.
+	tcpCheckers := []checker.TransportChecker{
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+		checker.TCPAckNum(uint32(irs) + 1),
+	}
+	checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...))
+}
+
 func TestSynRcvdBadSeqNumber(t *testing.T) {
 	c := context.New(t, defaultMTU)
 	defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index 7981d469b..38a335840 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -84,6 +84,10 @@ func (t *timer) init(w *sleep.Waker) {
 
 // cleanup frees all resources associated with the timer.
 func (t *timer) cleanup() {
+	if t.timer == nil {
+		// No cleanup needed.
+		return
+	}
 	t.timer.Stop()
 	*t = timer{}
 }
-- 
cgit v1.2.3