summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorDean Deng <deandeng@google.com>2020-10-30 14:56:34 -0700
committergVisor bot <gvisor-bot@google.com>2020-10-30 14:58:43 -0700
commitba05c6845d3d285be8c2d12d9f10a5745f722e58 (patch)
tree6704c57db2a0a2e1df2929408ac7e9b803f437ae /pkg
parent9ad864628dc3f36a24a5eb6a51d7e13471598517 (diff)
Automated rollback of changelist 339750876
PiperOrigin-RevId: 339945377
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/nic.go2
-rw-r--r--pkg/tcpip/stack/route.go2
-rw-r--r--pkg/tcpip/transport/tcp/accept.go143
-rw-r--r--pkg/tcpip/transport/tcp/connect.go125
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go67
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go1
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go2
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go47
-rw-r--r--pkg/tcpip/transport/tcp/timer.go4
9 files changed, 289 insertions, 104 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 17f2e6b46..ff02c7c65 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -895,7 +895,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep
}
// isValidForOutgoing returns true if the endpoint can be used to send out a
-// packet. It requires the endpoint to not be marked expired (i.e., its address)
+// packet. It requires the endpoint to not be marked expired (i.e., its address
// has been removed) unless the NIC is in spoofing mode, or temporary.
func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool {
n.mu.RLock()
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index b76e2d37b..87f7008f7 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -191,7 +191,7 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) {
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
-// the link address before the this route can be written to.
+// the link address before r can be written to.
//
// The NIC r uses must not be locked.
func (r *Route) IsResolutionRequired() bool {
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 6b3238d6b..7534aa89f 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -228,11 +228,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i
return n
}
-// createEndpointAndPerformHandshake creates a new endpoint in connected state
-// and then performs the TCP 3-way handshake.
+// startHandshake creates a new endpoint in connecting state and then sends
+// the SYN-ACK for the TCP 3-way handshake. It returns the state of the
+// handshake in progress, which includes the new endpoint in the SYN-RCVD
+// state.
//
-// The new endpoint is returned with e.mu held.
-func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+// On success, a handshake h is returned with h.ep.mu held.
+//
+// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
+func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
isn := generateSecureISN(s.id, l.stack.Seed())
@@ -247,10 +251,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
// listenEP is nil when listenContext is used by tcp.Forwarder.
deferAccept := time.Duration(0)
if l.listenEP != nil {
- l.listenEP.mu.Lock()
if l.listenEP.EndpointState() != StateListen {
- l.listenEP.mu.Unlock()
// Ensure we release any registrations done by the newly
// created endpoint.
ep.mu.Unlock()
@@ -268,16 +270,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- l.listenEP.mu.Unlock()
- }
+ l.removePendingEndpoint(ep)
return nil, tcpip.ErrConnectionAborted
}
deferAccept = l.listenEP.deferAccept
- l.listenEP.mu.Unlock()
}
// Register new endpoint so that packets are routed to it.
@@ -296,28 +294,33 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
ep.isRegistered = true
- // Perform the 3-way handshake.
- h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
- if err := h.execute(); err != nil {
- ep.mu.Unlock()
- ep.Close()
- ep.notifyAborted()
-
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
- ep.drainClosingSegmentQueue()
-
+ // Initialize and start the handshake.
+ h := ep.newPassiveHandshake(isn, irs, opts, deferAccept)
+ if err := h.start(); err != nil {
+ l.cleanupFailedHandshake(h)
return nil, err
}
- ep.isConnectNotified = true
+ return h, nil
+}
- // Update the receive window scaling. We can't do it before the
- // handshake because it's possible that the peer doesn't support window
- // scaling.
- ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
+// performHandshake performs a TCP 3-way handshake. On success, the new
+// established endpoint is returned with e.mu held.
+//
+// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked.
+func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+ h, err := l.startHandshake(s, opts, queue, owner)
+ if err != nil {
+ return nil, err
+ }
+ ep := h.ep
+ if err := h.complete(); err != nil {
+ ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ ep.stats.FailedConnectionAttempts.Increment()
+ l.cleanupFailedHandshake(h)
+ return nil, err
+ }
+ l.cleanupCompletedHandshake(h)
return ep, nil
}
@@ -344,6 +347,39 @@ func (l *listenContext) closeAllPendingEndpoints() {
l.pending.Wait()
}
+// Precondition: h.ep.mu must be held.
+func (l *listenContext) cleanupFailedHandshake(h *handshake) {
+ e := h.ep
+ e.mu.Unlock()
+ e.Close()
+ e.notifyAborted()
+ if l.listenEP != nil {
+ l.removePendingEndpoint(e)
+ }
+ e.drainClosingSegmentQueue()
+ e.h = nil
+}
+
+// cleanupCompletedHandshake transfers any state from the completed handshake to
+// the new endpoint.
+//
+// Precondition: h.ep.mu must be held.
+func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
+ e := h.ep
+ if l.listenEP != nil {
+ l.removePendingEndpoint(e)
+ }
+ e.isConnectNotified = true
+
+ // Update the receive window scaling. We can't do it before the
+ // handshake because it's possible that the peer doesn't support window
+ // scaling.
+ e.rcv.rcvWndScale = e.h.effectiveRcvWndScale()
+
+ // Clean up handshake state stored in the endpoint so that it can be GCed.
+ e.h = nil
+}
+
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state (acceptedChan is nil),
// the new endpoint is closed instead.
@@ -423,23 +459,40 @@ func (e *endpoint) notifyAborted() {
//
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
- defer ctx.synRcvdCount.dec()
+//
+// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error {
defer s.decRef()
- n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
+ h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- e.decSynRcvdCount()
- return
+ e.synRcvdCount--
+ return err
}
- ctx.removePendingEndpoint(n)
- e.decSynRcvdCount()
- n.startAcceptedLoop()
- e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
- e.deliverAccepted(n)
+ go func() {
+ defer ctx.synRcvdCount.dec()
+ if err := h.complete(); err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ ctx.cleanupFailedHandshake(h)
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+ return
+ }
+ ctx.cleanupCompletedHandshake(h)
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+ h.ep.startAcceptedLoop()
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+ e.deliverAccepted(h.ep)
+ }() // S/R-SAFE: synRcvdCount is the barrier.
+
+ return nil
}
func (e *endpoint) incSynRcvdCount() bool {
@@ -452,12 +505,6 @@ func (e *endpoint) incSynRcvdCount() bool {
return canInc
}
-func (e *endpoint) decSynRcvdCount() {
- e.mu.Lock()
- e.synRcvdCount--
- e.mu.Unlock()
-}
-
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
@@ -467,6 +514,8 @@ func (e *endpoint) acceptQueueIsFull() bool {
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
+//
+// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
e.rcvListMu.Lock()
rcvClosed := e.rcvClosed
@@ -491,7 +540,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// backlog.
if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
s.incRef()
- go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
+ _ = e.handleSynSegment(ctx, s, &opts)
return
}
ctx.synRcvdCount.dec()
@@ -686,7 +735,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
// to the endpoint.
e.setEndpointState(StateClose)
- // close any endpoints in SYN-RCVD state.
+ // Close any endpoints in SYN-RCVD state.
ctx.closeAllPendingEndpoints()
// Do cleanup if needed.
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 0aaef495d..fd5373ed4 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -102,21 +102,26 @@ type handshake struct {
// been received. This is required to stop retransmitting the
// original SYN-ACK when deferAccept is enabled.
acked bool
+
+ // sendSYNOpts is the cached values for the SYN options to be sent.
+ sendSYNOpts header.TCPSynOptions
}
-func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
- h := handshake{
- ep: ep,
+func (e *endpoint) newHandshake() *handshake {
+ h := &handshake{
+ ep: e,
active: true,
- rcvWnd: rcvWnd,
- rcvWndScale: ep.rcvWndScaleForHandshake(),
+ rcvWnd: seqnum.Size(e.initialReceiveWindow()),
+ rcvWndScale: e.rcvWndScaleForHandshake(),
}
h.resetState()
+ // Store reference to handshake state in endpoint.
+ e.h = h
return h
}
-func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
- h := newHandshake(ep, rcvWnd)
+func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake {
+ h := e.newHandshake()
h.resetToSynRcvd(isn, irs, opts, deferAccept)
return h
}
@@ -496,12 +501,13 @@ func (h *handshake) resolveRoute() *tcpip.Error {
}
// Wait for notification.
- index, _ = s.Fetch(true)
+ index, _ = s.Fetch(true /* block */)
}
}
-// execute executes the TCP 3-way handshake.
-func (h *handshake) execute() *tcpip.Error {
+// start resolves the route if necessary and sends the first
+// SYN/SYN-ACK.
+func (h *handshake) start() *tcpip.Error {
if h.ep.route.IsResolutionRequired() {
if err := h.resolveRoute(); err != nil {
return err
@@ -509,19 +515,7 @@ func (h *handshake) execute() *tcpip.Error {
}
h.startTime = time.Now()
- // Initialize the resend timer.
- resendWaker := sleep.Waker{}
- timeOut := time.Duration(time.Second)
- rt := time.AfterFunc(timeOut, resendWaker.Assert)
- defer rt.Stop()
-
- // Set up the wakers.
- s := sleep.Sleeper{}
- s.AddWaker(&resendWaker, wakerForResend)
- s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
- s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
- defer s.Done()
-
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
var sackEnabled tcpip.TCPSACKEnabled
if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
// If stack returned an error when checking for SACKEnabled
@@ -529,10 +523,6 @@ func (h *handshake) execute() *tcpip.Error {
sackEnabled = false
}
- // Send the initial SYN segment and loop until the handshake is
- // completed.
- h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
-
synOpts := header.TCPSynOptions{
WS: h.rcvWndScale,
TS: true,
@@ -542,9 +532,8 @@ func (h *handshake) execute() *tcpip.Error {
MSS: h.ep.amss,
}
- // Execute is also called in a listen context so we want to make sure we
- // only send the TS/SACK option when we received the TS/SACK in the
- // initial SYN.
+ // start() is also called in a listen context so we want to make sure we only
+ // send the TS/SACK option when we received the TS/SACK in the initial SYN.
if h.state == handshakeSynRcvd {
synOpts.TS = h.ep.sendTSOk
synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
@@ -555,6 +544,7 @@ func (h *handshake) execute() *tcpip.Error {
}
}
+ h.sendSYNOpts = synOpts
h.ep.sendSynTCP(&h.ep.route, tcpFields{
id: h.ep.ID,
ttl: h.ep.ttl,
@@ -564,19 +554,38 @@ func (h *handshake) execute() *tcpip.Error {
ack: h.ackNum,
rcvWnd: h.rcvWnd,
}, synOpts)
+ return nil
+}
+
+// complete completes the TCP 3-way handshake initiated by h.start().
+func (h *handshake) complete() *tcpip.Error {
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ resendWaker := sleep.Waker{}
+ s.AddWaker(&resendWaker, wakerForResend)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+ defer s.Done()
+
+ // Initialize the resend timer.
+ timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert)
+ if err != nil {
+ return err
+ }
+ defer timer.stop()
for h.state != handshakeCompleted {
+ // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
+ // throughout handshake processing).
h.ep.mu.Unlock()
- index, _ := s.Fetch(true)
+ index, _ := s.Fetch(true /* block */)
h.ep.mu.Lock()
switch index {
case wakerForResend:
- timeOut *= 2
- if timeOut > MaxRTO {
- return tcpip.ErrTimeout
+ if err := timer.reset(); err != nil {
+ return err
}
- rt.Reset(timeOut)
// Resend the SYN/SYN-ACK only if the following conditions hold.
// - It's an active handshake (deferAccept does not apply)
// - It's a passive handshake and we have not yet got the final-ACK.
@@ -594,7 +603,7 @@ func (h *handshake) execute() *tcpip.Error {
seq: h.iss,
ack: h.ackNum,
rcvWnd: h.rcvWnd,
- }, synOpts)
+ }, h.sendSYNOpts)
}
case wakerForNotification:
@@ -633,6 +642,34 @@ func (h *handshake) execute() *tcpip.Error {
return nil
}
+type backoffTimer struct {
+ timeout time.Duration
+ maxTimeout time.Duration
+ t *time.Timer
+}
+
+func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) {
+ if timeout > maxTimeout {
+ return nil, tcpip.ErrTimeout
+ }
+ bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout}
+ bt.t = time.AfterFunc(timeout, f)
+ return bt, nil
+}
+
+func (bt *backoffTimer) reset() *tcpip.Error {
+ bt.timeout *= 2
+ if bt.timeout > MaxRTO {
+ return tcpip.ErrTimeout
+ }
+ bt.t.Reset(bt.timeout)
+ return nil
+}
+
+func (bt *backoffTimer) stop() {
+ bt.t.Stop()
+}
+
func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
if synOpts.TS {
@@ -1338,14 +1375,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
if handshake {
- // This is an active connection, so we must initiate the 3-way
- // handshake, and then inform potential waiters about its
- // completion.
- initialRcvWnd := e.initialReceiveWindow()
- h := newHandshake(e, seqnum.Size(initialRcvWnd))
- h.ep.setEndpointState(StateSynSent)
-
- if err := h.execute(); err != nil {
+ if err := e.h.complete(); err != nil {
e.lastErrorMu.Lock()
e.lastError = err
e.lastErrorMu.Unlock()
@@ -1360,9 +1390,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
}
}
- e.keepalive.timer.init(&e.keepalive.waker)
- defer e.keepalive.timer.cleanup()
-
drained := e.drainDone != nil
if drained {
close(e.drainDone)
@@ -1535,7 +1562,7 @@ loop:
}
e.mu.Unlock()
- v, _ := s.Fetch(true)
+ v, _ := s.Fetch(true /* block */)
e.mu.Lock()
// We need to double check here because the notification may be
@@ -1683,7 +1710,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) {
for {
e.mu.Unlock()
- v, _ := s.Fetch(true)
+ v, _ := s.Fetch(true /* block */)
e.mu.Lock()
switch v {
case newSegment:
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index bfe26e460..3e4403bee 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -440,6 +440,11 @@ type endpoint struct {
ttl uint8
v6only bool
isConnectNotified bool
+ // h stores a reference to the current handshake state if the endpoint is in
+ // the SYN-SENT or SYN-RECV states, in which case endpoint == endpoint.h.ep.
+ // nil otherwise.
+ h *handshake `state:"nosave"`
+
// TCP should never broadcast but Linux nevertheless supports enabling/
// disabling SO_BROADCAST, albeit as a NOOP.
broadcast bool
@@ -721,9 +726,9 @@ func (e *endpoint) LockUser() {
for {
// Try first if the sock is locked then check if it's owned
// by another user goroutine if not then we spin, otherwise
- // we just goto sleep on the Lock() and wait.
+ // we just go to sleep on the Lock() and wait.
if !e.mu.TryLock() {
- // If socket is owned by the user then just goto sleep
+ // If socket is owned by the user then just go to sleep
// as the lock could be held for a reasonably long time.
if atomic.LoadUint32(&e.ownedByUser) == 1 {
e.mu.Lock()
@@ -922,6 +927,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.segmentQueue.ep = e
e.tsOffset = timeStampOffset()
e.acceptCond = sync.NewCond(&e.acceptMu)
+ e.keepalive.timer.init(&e.keepalive.waker)
return e
}
@@ -1143,6 +1149,7 @@ func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
e.closePendingAcceptableConnectionsLocked()
+ e.keepalive.timer.cleanup()
e.workerCleanup = false
@@ -2182,6 +2189,8 @@ func (*endpoint) Disconnect() *tcpip.Error {
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
err := e.connect(addr, true, true)
if err != nil && !err.IgnoreStats() {
+ // Connect failed. Let's wake up any waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
}
@@ -2395,14 +2404,62 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
}
if run {
- e.workerRunning = true
- e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
- go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+ if err := e.startMainLoop(handshake); err != nil {
+ return err
+ }
}
return tcpip.ErrConnectStarted
}
+// startMainLoop sends the initial SYN and starts the main loop for the
+// endpoint.
+func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error {
+ preloop := func() *tcpip.Error {
+ if handshake {
+ h := e.newHandshake()
+ e.setEndpointState(StateSynSent)
+ if err := h.start(); err != nil {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ e.setEndpointState(StateError)
+ e.HardError = err
+
+ // Call cleanupLocked to free up any reservations.
+ e.cleanupLocked()
+ return err
+ }
+ }
+ e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+ e.workerRunning = true
+ return nil
+ }
+
+ if e.route.IsResolutionRequired() {
+ // Sending the initial SYN may block due to route resolution; do it in a
+ // separate goroutine to avoid blocking the syscall goroutine.
+ go func() { // S/R-SAFE: will be drained before save.
+ if err := preloop(); err != nil {
+ return
+ }
+ e.protocolMainLoop(handshake, nil)
+ }()
+ return nil
+ }
+
+ // No route resolution is required, so we can send the initial SYN here without
+ // blocking. This will hopefully reduce overall latency by overlapping time
+ // spent waiting for a SYN-ACK and time spent spinning up a new goroutine
+ // for the main loop.
+ if err := preloop(); err != nil {
+ return err
+ }
+ go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save.
+ return nil
+}
+
// ConnectEndpoint is not supported.
func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index b25431467..5e7962794 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -172,6 +172,7 @@ func (e *endpoint) afterLoad() {
// Condition variables and mutexs are not S/R'ed so reinitialize
// acceptCond with e.acceptMu.
e.acceptCond = sync.NewCond(&e.acceptMu)
+ e.keepalive.timer.init(&e.keepalive.waker)
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
index 070b634b4..8c334c97b 100644
--- a/pkg/tcpip/transport/tcp/forwarder.go
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -150,7 +150,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint,
}
f := r.forwarder
- ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
+ ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{
MSS: r.synOptions.MSS,
WS: r.synOptions.WS,
TS: r.synOptions.TS,
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 5f05608e2..49e19a4af 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -5717,6 +5717,53 @@ func TestListenBacklogFullSynCookieInUse(t *testing.T) {
}
}
+func TestSYNRetransmit(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ // Start listening.
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send the same SYN packet multiple times. We should still get a valid SYN-ACK
+ // reply.
+ irs := seqnum.Value(789)
+ for i := 0; i < 5; i++ {
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Receive the SYN-ACK reply.
+ tcpCheckers := []checker.TransportChecker{
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+ checker.TCPAckNum(uint32(irs) + 1),
+ }
+ checker.IPv4(t, c.GetPacket(), checker.TCP(tcpCheckers...))
+ if p := c.GetPacketWithTimeout(1 * time.Second); p != nil {
+ t.Fatalf("Unexpected packet received: %#v", p)
+ }
+}
+
func TestSynRcvdBadSeqNumber(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
index 7981d469b..38a335840 100644
--- a/pkg/tcpip/transport/tcp/timer.go
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -84,6 +84,10 @@ func (t *timer) init(w *sleep.Waker) {
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
+ if t.timer == nil {
+ // No cleanup needed.
+ return
+ }
t.timer.Stop()
*t = timer{}
}