summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDean Deng <deandeng@google.com>2020-11-16 10:41:20 -0800
committergVisor bot <gvisor-bot@google.com>2020-11-16 10:43:37 -0800
commit840a133c64dff95b25bf3ad6019cb5bd16f0999b (patch)
tree18f15a3f85b7ba9c913c0212111c3d8b67252b67
parent39f712f1d8a53db9d1ccbf7a894d9190edab076d (diff)
Automated rollback of changelist 340274194
PiperOrigin-RevId: 342669574
-rw-r--r--pkg/tcpip/transport/tcp/accept.go143
-rw-r--r--pkg/tcpip/transport/tcp/connect.go115
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go71
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go1
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go44
-rw-r--r--pkg/tcpip/transport/tcp/timer.go4
7 files changed, 284 insertions, 96 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 47982ca41..6e5adc383 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -235,11 +235,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
return n, nil
}
-// createEndpointAndPerformHandshake creates a new endpoint in connected state
-// and then performs the TCP 3-way handshake.
+// startHandshake creates a new endpoint in connecting state and then sends
+// the SYN-ACK for the TCP 3-way handshake. It returns the state of the
+// handshake in progress, which includes the new endpoint in the SYN-RCVD
+// state.
//
-// The new endpoint is returned with e.mu held.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+// On success, a handshake h is returned with h.ep.mu held.
+//
+// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
+func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
@@ -257,10 +261,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// listenEP is nil when listenContext is used by tcp.Forwarder.
deferAccept := time.Duration(0)
if l.listenEP != nil {
- l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
- l.listenEP.mu.Unlock()
// Ensure we release any registrations done by the newly
// created endpoint.
ep.mu.Unlock()
@@ -278,16 +280,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- l.listenEP.mu.Unlock()
- }
+ l.removePendingEndpoint(ep)
return nil, tcpip.ErrConnectionAborted
}
deferAccept = l.listenEP.deferAccept
- l.listenEP.mu.Unlock()
}
// Register new endpoint so that packets are routed to it.
@@ -306,28 +304,33 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
ep.isRegistered = true
- // Perform the 3-way handshake.
- h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
- if err := h.execute(); err != nil {
- ep.mu.Unlock()
- ep.Close()
- ep.notifyAborted()
-
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
- ep.drainClosingSegmentQueue()
-
+ // Initialize and start the handshake.
+ h := ep.newPassiveHandshake(isn, irs, opts, deferAccept)
+ if err := h.start(); err != nil {
+ l.cleanupFailedHandshake(h)
return nil, err
}
- ep.isConnectNotified = true
+ return h, nil
+}
- // Update the receive window scaling. We can't do it before the
- // handshake because it's possible that the peer doesn't support window
- // scaling.
- ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
+// performHandshake performs a TCP 3-way handshake. On success, the new
+// established endpoint is returned with e.mu held.
+//
+// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
+func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+ h, err := l.startHandshake(s, opts, queue, owner)
+ if err != nil {
+ return nil, err
+ }
+ ep := h.ep
+ if err := h.complete(); err != nil {
+ ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ ep.stats.FailedConnectionAttempts.Increment()
+ l.cleanupFailedHandshake(h)
+ return nil, err
+ }
+ l.cleanupCompletedHandshake(h)
return ep, nil
}
@@ -354,6 +357,39 @@ func (l *listenContext) closeAllPendingEndpoints() {
l.pending.Wait()
}
+// Precondition: h.ep.mu must be held.
+func (l *listenContext) cleanupFailedHandshake(h *handshake) {
+ e := h.ep
+ e.mu.Unlock()
+ e.Close()
+ e.notifyAborted()
+ if l.listenEP != nil {
+ l.removePendingEndpoint(e)
+ }
+ e.drainClosingSegmentQueue()
+ e.h = nil
+}
+
+// cleanupCompletedHandshake transfers any state from the completed handshake to
+// the new endpoint.
+//
+// Precondition: h.ep.mu must be held.
+func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
+ e := h.ep
+ if l.listenEP != nil {
+ l.removePendingEndpoint(e)
+ }
+ e.isConnectNotified = true
+
+ // Update the receive window scaling. We can't do it before the
+ // handshake because it's possible that the peer doesn't support window
+ // scaling.
+ e.rcv.rcvWndScale = e.h.effectiveRcvWndScale()
+
+ // Clean up handshake state stored in the endpoint so that it can be GCed.
+ e.h = nil
+}
+
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state (acceptedChan is nil),
// the new endpoint is closed instead.
@@ -433,23 +469,40 @@ func (e *endpoint) notifyAborted() {
//
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
- defer ctx.synRcvdCount.dec()
+//
+// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error {
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
+ h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- e.decSynRcvdCount()
- return
+ e.synRcvdCount--
+ return err
}
- ctx.removePendingEndpoint(n)
- e.decSynRcvdCount()
- n.startAcceptedLoop()
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(n)
+ go func() {
+ defer ctx.synRcvdCount.dec()
+ if err := h.complete(); err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ ctx.cleanupFailedHandshake(h)
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+ return
+ }
+ ctx.cleanupCompletedHandshake(h)
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+ h.ep.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+ e.deliverAccepted(h.ep)
+ }() // S/R-SAFE: synRcvdCount is the barrier.
+
+ return nil
}
func (e *endpoint) incSynRcvdCount() bool {
@@ -462,12 +515,6 @@ func (e *endpoint) incSynRcvdCount() bool {
return canInc
}
-func (e *endpoint) decSynRcvdCount() {
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
-}
-
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
@@ -477,6 +524,8 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
+//
+// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error {
e.rcvListMu.Lock()
rcvClosed := e.rcvClosed
@@ -500,7 +549,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er
// backlog.
if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
s.incRef()
- go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
+ _ = e.handleSynSegment(ctx, s, &opts)
return nil
}
ctx.synRcvdCount.dec()
@@ -712,7 +761,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
// to the endpoint.
e.setEndpointState(StateClose)
- // close any endpoints in SYN-RCVD state.
+ // Close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
// Do cleanup if needed.
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 6e9015be1..ac6d879a7 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -102,21 +102,26 @@ type handshake struct {
// been received. This is required to stop retransmitting the
// original SYN-ACK when deferAccept is enabled.
acked bool
+
+ // sendSYNOpts is the cached values for the SYN options to be sent.
+ sendSYNOpts header.TCPSynOptions
}
-func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
- h := handshake{
- ep: ep,
+func (e *endpoint) newHandshake() *handshake {
+ h := &handshake{
+ ep: e,
active: true,
- rcvWnd: rcvWnd,
- rcvWndScale: ep.rcvWndScaleForHandshake(),
+ rcvWnd: seqnum.Size(e.initialReceiveWindow()),
+ rcvWndScale: e.rcvWndScaleForHandshake(),
}
h.resetState()
+ // Store reference to handshake state in endpoint.
+ e.h = h
return h
}
-func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
- h := newHandshake(ep, rcvWnd)
+func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake {
+ h := e.newHandshake()
h.resetToSynRcvd(isn, irs, opts, deferAccept)
return h
}
@@ -502,8 +507,9 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
}
-// execute executes the TCP 3-way handshake.
-func (h *handshake) execute() *tcpip.Error {
+// start resolves the route if necessary and sends the first
+// SYN/SYN-ACK.
+func (h *handshake) start() *tcpip.Error {
if h.ep.route.IsResolutionRequired() {
if err := h.resolveRoute(); err != nil {
return err
@@ -511,19 +517,7 @@ func (h *handshake) execute() *tcpip.Error {
}
h.startTime = time.Now()
- // Initialize the resend timer.
- resendWaker := sleep.Waker{}
- timeOut := time.Duration(time.Second)
- rt := time.AfterFunc(timeOut, resendWaker.Assert)
- defer rt.Stop()
-
- // Set up the wakers.
- s := sleep.Sleeper{}
- s.AddWaker(&resendWaker, wakerForResend)
- s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
- s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
- defer s.Done()
-
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
var sackEnabled tcpip.TCPSACKEnabled
if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
// If stack returned an error when checking for SACKEnabled
@@ -531,10 +525,6 @@ func (h *handshake) execute() *tcpip.Error {
sackEnabled = false
}
- // Send the initial SYN segment and loop until the handshake is
- // completed.
- h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
-
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
@@ -544,9 +534,8 @@ func (h *handshake) execute() *tcpip.Error {
MSS: h.ep.amss,
}
- // Execute is also called in a listen context so we want to make sure we
- // only send the TS/SACK option when we received the TS/SACK in the
- // initial SYN.
+ // start() is also called in a listen context so we want to make sure we only
+ // send the TS/SACK option when we received the TS/SACK in the initial SYN.
if h.state == handshakeSynRcvd {
synOpts.TS = h.ep.sendTSOk
synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
@@ -557,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error {
}
}
+ h.sendSYNOpts = synOpts
h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
@@ -566,6 +556,25 @@ func (h *handshake) execute() *tcpip.Error {
ack: h.ackNum,
rcvWnd: h.rcvWnd,
}, synOpts)
+ return nil
+}
+
+// complete completes the TCP 3-way handshake initiated by h.start().
+func (h *handshake) complete() *tcpip.Error {
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ resendWaker := sleep.Waker{}
+ s.AddWaker(&resendWaker, wakerForResend)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+ defer s.Done()
+
+ // Initialize the resend timer.
+ timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert)
+ if err != nil {
+ return err
+ }
+ defer timer.stop()
for h.state != handshakeCompleted {
// Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
@@ -576,11 +585,9 @@ func (h *handshake) execute() *tcpip.Error {
switch index {
case wakerForResend:
- timeOut *= 2
- if timeOut > MaxRTO {
- return tcpip.ErrTimeout
+ if err := timer.reset(); err != nil {
+ return err
}
- rt.Reset(timeOut)
// Resend the SYN/SYN-ACK only if the following conditions hold.
// - It's an active handshake (deferAccept does not apply)
// - It's a passive handshake and we have not yet got the final-ACK.
@@ -598,7 +605,7 @@ func (h *handshake) execute() *tcpip.Error {
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
- }, synOpts)
+ }, h.sendSYNOpts)
}
case wakerForNotification:
@@ -637,6 +644,34 @@ func (h *handshake) execute() *tcpip.Error {
return nil
}
+type backoffTimer struct {
+ timeout time.Duration
+ maxTimeout time.Duration
+ t *time.Timer
+}
+
+func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) {
+ if timeout > maxTimeout {
+ return nil, tcpip.ErrTimeout
+ }
+ bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout}
+ bt.t = time.AfterFunc(timeout, f)
+ return bt, nil
+}
+
+func (bt *backoffTimer) reset() *tcpip.Error {
+ bt.timeout *= 2
+ if bt.timeout > MaxRTO {
+ return tcpip.ErrTimeout
+ }
+ bt.t.Reset(bt.timeout)
+ return nil
+}
+
+func (bt *backoffTimer) stop() {
+ bt.t.Stop()
+}
+
func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
if synOpts.TS {
@@ -1342,14 +1377,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if handshake {
- // This is an active connection, so we must initiate the 3-way
- // handshake, and then inform potential waiters about its
- // completion.
- initialRcvWnd := e.initialReceiveWindow()
- h := newHandshake(e, seqnum.Size(initialRcvWnd))
- h.ep.setEndpointState(StateSynSent)
-
- if err := h.execute(); err != nil {
+ if err := e.h.complete(); err != nil {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
@@ -1364,9 +1392,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
}
- e.keepalive.timer.init(&e.keepalive.waker)
- defer e.keepalive.timer.cleanup()
-
drained := e.drainDone != nil
if drained {
close(e.drainDone)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index e78138415..4f4f4c65e 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -441,6 +441,11 @@ type endpoint struct {
v6only bool
isConnectNotified bool
+ // h stores a reference to the current handshake state if the endpoint is in
+ // the SYN-SENT or SYN-RECV states, in which case endpoint == endpoint.h.ep.
+ // nil otherwise.
+ h *handshake `state:"nosave"`
+
// portFlags stores the current values of port related flags.
portFlags ports.Flags
@@ -922,6 +927,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.segmentQueue.ep = e
e.tsOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
+ e.keepalive.timer.init(&e.keepalive.waker)
return e
}
@@ -1146,6 +1152,7 @@ func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
e.closePendingAcceptableConnectionsLocked()
+ e.keepalive.timer.cleanup()
e.workerCleanup = false
@@ -2175,6 +2182,8 @@ func (*endpoint) Disconnect() *tcpip.Error {
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
err := e.connect(addr, true, true)
if err != nil && !err.IgnoreStats() {
+ // Connect failed. Let's wake up any waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
}
@@ -2387,14 +2396,70 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
if run {
- e.workerRunning = true
- e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
- go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+ if err := e.startMainLoop(handshake); err != nil {
+ return err
+ }
}
return tcpip.ErrConnectStarted
}
+// startMainLoop sends the initial SYN and starts the main loop for the
+// endpoint.
+func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error {
+ preloop := func() *tcpip.Error {
+ if handshake {
+ h := e.newHandshake()
+ e.setEndpointState(StateSynSent)
+ if err := h.start(); err != nil {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ e.setEndpointState(StateError)
+ e.HardError = err
+
+ // Call cleanupLocked to free up any reservations.
+ e.cleanupLocked()
+ return err
+ }
+ }
+ e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+ return nil
+ }
+
+ if e.route.IsResolutionRequired() {
+ // If the endpoint is closed between releasing e.mu and the goroutine below
+ // acquiring it, make sure that cleanup is deferred to the new goroutine.
+ e.workerRunning = true
+
+ // Sending the initial SYN may block due to route resolution; do it in a
+ // separate goroutine to avoid blocking the syscall goroutine.
+ go func() { // S/R-SAFE: will be drained before save.
+ e.mu.Lock()
+ if err := preloop(); err != nil {
+ e.workerRunning = false
+ e.mu.Unlock()
+ return
+ }
+ e.mu.Unlock()
+ _ = e.protocolMainLoop(handshake, nil)
+ }()
+ return nil
+ }
+
+ // No route resolution is required, so we can send the initial SYN here without
+ // blocking. This will hopefully reduce overall latency by overlapping time
+ // spent waiting for a SYN-ACK and time spent spinning up a new goroutine
+ // for the main loop.
+ if err := preloop(); err != nil {
+ return err
+ }
+ e.workerRunning = true
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+ return nil
+}
+
// ConnectEndpoint is not supported.
func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 2bcc5e1c2..bb901c0f8 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -172,6 +172,7 @@ func (e *endpoint) afterLoad() {
// Condition variables and mutexs are not S/R'ed so reinitialize
// acceptCond with e.acceptMu.
e.acceptCond = sync.NewCond(&e.acceptMu)
+ e.keepalive.timer.init(&e.keepalive.waker)
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 0664789da..596178625 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
}
f := r.forwarder
- ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
+ ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{
MSS: r.synOptions.MSS,
WS: r.synOptions.WS,
TS: r.synOptions.TS,
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index fcc3c5000..9f0fb41e3 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5717,6 +5717,50 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
}
+func TestSYNRetransmit(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Start listening.
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send the same SYN packet multiple times. We should still get a valid SYN-ACK
+ // reply.
+ irs := seqnum.Value(789)
+ for i := 0; i < 5; i++ {
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Receive the SYN-ACK reply.
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.TCPAckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...))
+}
+
func TestSynRcvdBadSeqNumber(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index 7981d469b..38a335840 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -84,6 +84,10 @@ func (t *timer) init(w *sleep.Waker) {
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
+ if t.timer == nil {
+ // No cleanup needed.
+ return
+ }
t.timer.Stop()
*t = timer{}
}