diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 143 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 115 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 71 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/timer.go | 4 |
6 files changed, 240 insertions, 96 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 47982ca41..6e5adc383 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -235,11 +235,15 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i return n, nil } -// createEndpointAndPerformHandshake creates a new endpoint in connected state -// and then performs the TCP 3-way handshake. +// startHandshake creates a new endpoint in connecting state and then sends +// the SYN-ACK for the TCP 3-way handshake. It returns the state of the +// handshake in progress, which includes the new endpoint in the SYN-RCVD +// state. // -// The new endpoint is returned with e.mu held. -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { +// On success, a handshake h is returned with h.ep.mu held. +// +// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. +func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -257,10 +261,8 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // listenEP is nil when listenContext is used by tcp.Forwarder. deferAccept := time.Duration(0) if l.listenEP != nil { - l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { - l.listenEP.mu.Unlock() // Ensure we release any registrations done by the newly // created endpoint. 
ep.mu.Unlock() @@ -278,16 +280,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.mu.Unlock() ep.Close() - if l.listenEP != nil { - l.removePendingEndpoint(ep) - l.listenEP.mu.Unlock() - } + l.removePendingEndpoint(ep) return nil, tcpip.ErrConnectionAborted } deferAccept = l.listenEP.deferAccept - l.listenEP.mu.Unlock() } // Register new endpoint so that packets are routed to it. @@ -306,28 +304,33 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.isRegistered = true - // Perform the 3-way handshake. - h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) - if err := h.execute(); err != nil { - ep.mu.Unlock() - ep.Close() - ep.notifyAborted() - - if l.listenEP != nil { - l.removePendingEndpoint(ep) - } - - ep.drainClosingSegmentQueue() - + // Initialize and start the handshake. + h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) + if err := h.start(); err != nil { + l.cleanupFailedHandshake(h) return nil, err } - ep.isConnectNotified = true + return h, nil +} - // Update the receive window scaling. We can't do it before the - // handshake because it's possible that the peer doesn't support window - // scaling. - ep.rcv.rcvWndScale = h.effectiveRcvWndScale() +// performHandshake performs a TCP 3-way handshake. On success, the new +// established endpoint is returned with e.mu held. +// +// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. 
+func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { + h, err := l.startHandshake(s, opts, queue, owner) + if err != nil { + return nil, err + } + ep := h.ep + if err := h.complete(); err != nil { + ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() + ep.stats.FailedConnectionAttempts.Increment() + l.cleanupFailedHandshake(h) + return nil, err + } + l.cleanupCompletedHandshake(h) return ep, nil } @@ -354,6 +357,39 @@ func (l *listenContext) closeAllPendingEndpoints() { l.pending.Wait() } +// Precondition: h.ep.mu must be held. +func (l *listenContext) cleanupFailedHandshake(h *handshake) { + e := h.ep + e.mu.Unlock() + e.Close() + e.notifyAborted() + if l.listenEP != nil { + l.removePendingEndpoint(e) + } + e.drainClosingSegmentQueue() + e.h = nil +} + +// cleanupCompletedHandshake transfers any state from the completed handshake to +// the new endpoint. +// +// Precondition: h.ep.mu must be held. +func (l *listenContext) cleanupCompletedHandshake(h *handshake) { + e := h.ep + if l.listenEP != nil { + l.removePendingEndpoint(e) + } + e.isConnectNotified = true + + // Update the receive window scaling. We can't do it before the + // handshake because it's possible that the peer doesn't support window + // scaling. + e.rcv.rcvWndScale = e.h.effectiveRcvWndScale() + + // Clean up handshake state stored in the endpoint so that it can be GCed. + e.h = nil +} + // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state (acceptedChan is nil), // the new endpoint is closed instead. @@ -433,23 +469,40 @@ func (e *endpoint) notifyAborted() { // // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. 
-func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { - defer ctx.synRcvdCount.dec() +// +// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) *tcpip.Error { defer s.decRef() - n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) + h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - e.decSynRcvdCount() - return + e.synRcvdCount-- + return err } - ctx.removePendingEndpoint(n) - e.decSynRcvdCount() - n.startAcceptedLoop() - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(n) + go func() { + defer ctx.synRcvdCount.dec() + if err := h.complete(); err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + ctx.cleanupFailedHandshake(h) + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() + return + } + ctx.cleanupCompletedHandshake(h) + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() + h.ep.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + e.deliverAccepted(h.ep) + }() // S/R-SAFE: synRcvdCount is the barrier. + + return nil } func (e *endpoint) incSynRcvdCount() bool { @@ -462,12 +515,6 @@ func (e *endpoint) incSynRcvdCount() bool { return canInc } -func (e *endpoint) decSynRcvdCount() { - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() -} - func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) @@ -477,6 +524,8 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. +// +// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. 
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed @@ -500,7 +549,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Er // backlog. if !e.acceptQueueIsFull() && e.incSynRcvdCount() { s.incRef() - go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. + _ = e.handleSynSegment(ctx, s, &opts) return nil } ctx.synRcvdCount.dec() @@ -712,7 +761,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { // to the endpoint. e.setEndpointState(StateClose) - // close any endpoints in SYN-RCVD state. + // Close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() // Do cleanup if needed. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6e9015be1..ac6d879a7 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -102,21 +102,26 @@ type handshake struct { // been received. This is required to stop retransmitting the // original SYN-ACK when deferAccept is enabled. acked bool + + // sendSYNOpts is the cached values for the SYN options to be sent. + sendSYNOpts header.TCPSynOptions } -func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { - h := handshake{ - ep: ep, +func (e *endpoint) newHandshake() *handshake { + h := &handshake{ + ep: e, active: true, - rcvWnd: rcvWnd, - rcvWndScale: ep.rcvWndScaleForHandshake(), + rcvWnd: seqnum.Size(e.initialReceiveWindow()), + rcvWndScale: e.rcvWndScaleForHandshake(), } h.resetState() + // Store reference to handshake state in endpoint. 
+ e.h = h return h } -func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { - h := newHandshake(ep, rcvWnd) +func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake { + h := e.newHandshake() h.resetToSynRcvd(isn, irs, opts, deferAccept) return h } @@ -502,8 +507,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { } } -// execute executes the TCP 3-way handshake. -func (h *handshake) execute() *tcpip.Error { +// start resolves the route if necessary and sends the first +// SYN/SYN-ACK. +func (h *handshake) start() *tcpip.Error { if h.ep.route.IsResolutionRequired() { if err := h.resolveRoute(); err != nil { return err @@ -511,19 +517,7 @@ func (h *handshake) execute() *tcpip.Error { } h.startTime = time.Now() - // Initialize the resend timer. - resendWaker := sleep.Waker{} - timeOut := time.Duration(time.Second) - rt := time.AfterFunc(timeOut, resendWaker.Assert) - defer rt.Stop() - - // Set up the wakers. - s := sleep.Sleeper{} - s.AddWaker(&resendWaker, wakerForResend) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) - defer s.Done() - + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled @@ -531,10 +525,6 @@ func (h *handshake) execute() *tcpip.Error { sackEnabled = false } - // Send the initial SYN segment and loop until the handshake is - // completed. 
- h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) - synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: true, @@ -544,9 +534,8 @@ func (h *handshake) execute() *tcpip.Error { MSS: h.ep.amss, } - // Execute is also called in a listen context so we want to make sure we - // only send the TS/SACK option when we received the TS/SACK in the - // initial SYN. + // start() is also called in a listen context so we want to make sure we only + // send the TS/SACK option when we received the TS/SACK in the initial SYN. if h.state == handshakeSynRcvd { synOpts.TS = h.ep.sendTSOk synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) @@ -557,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error { } } + h.sendSYNOpts = synOpts h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, @@ -566,6 +556,25 @@ func (h *handshake) execute() *tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, synOpts) + return nil +} + +// complete completes the TCP 3-way handshake initiated by h.start(). +func (h *handshake) complete() *tcpip.Error { + // Set up the wakers. + s := sleep.Sleeper{} + resendWaker := sleep.Waker{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + + // Initialize the resend timer. + timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) + if err != nil { + return err + } + defer timer.stop() for h.state != handshakeCompleted { // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held @@ -576,11 +585,9 @@ func (h *handshake) execute() *tcpip.Error { switch index { case wakerForResend: - timeOut *= 2 - if timeOut > MaxRTO { - return tcpip.ErrTimeout + if err := timer.reset(); err != nil { + return err } - rt.Reset(timeOut) // Resend the SYN/SYN-ACK only if the following conditions hold. 
// - It's an active handshake (deferAccept does not apply) // - It's a passive handshake and we have not yet got the final-ACK. @@ -598,7 +605,7 @@ func (h *handshake) execute() *tcpip.Error { seq: h.iss, ack: h.ackNum, rcvWnd: h.rcvWnd, - }, synOpts) + }, h.sendSYNOpts) } case wakerForNotification: @@ -637,6 +644,34 @@ func (h *handshake) execute() *tcpip.Error { return nil } +type backoffTimer struct { + timeout time.Duration + maxTimeout time.Duration + t *time.Timer +} + +func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { + if timeout > maxTimeout { + return nil, tcpip.ErrTimeout + } + bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} + bt.t = time.AfterFunc(timeout, f) + return bt, nil +} + +func (bt *backoffTimer) reset() *tcpip.Error { + bt.timeout *= 2 + if bt.timeout > bt.maxTimeout { // cap is the maxTimeout given to newBackoffTimer, not the global MaxRTO + return tcpip.ErrTimeout + } + bt.t.Reset(bt.timeout) + return nil +} + +func (bt *backoffTimer) stop() { + bt.t.Stop() +} + func parseSynSegmentOptions(s *segment) header.TCPSynOptions { synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) if synOpts.TS { @@ -1342,14 +1377,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if handshake { - // This is an active connection, so we must initiate the 3-way - // handshake, and then inform potential waiters about its - // completion. 
- initialRcvWnd := e.initialReceiveWindow() - h := newHandshake(e, seqnum.Size(initialRcvWnd)) - h.ep.setEndpointState(StateSynSent) - - if err := h.execute(); err != nil { + if err := e.h.complete(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -1364,9 +1392,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } - e.keepalive.timer.init(&e.keepalive.waker) - defer e.keepalive.timer.cleanup() - drained := e.drainDone != nil if drained { close(e.drainDone) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index e78138415..4f4f4c65e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -441,6 +441,11 @@ type endpoint struct { v6only bool isConnectNotified bool + // h stores a reference to the current handshake state if the endpoint is in + // the SYN-SENT or SYN-RECV states, in which case endpoint == endpoint.h.ep. + // nil otherwise. + h *handshake `state:"nosave"` + // portFlags stores the current values of port related flags. portFlags ports.Flags @@ -922,6 +927,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) + e.keepalive.timer.init(&e.keepalive.waker) return e } @@ -1146,6 +1152,7 @@ func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. e.closePendingAcceptableConnectionsLocked() + e.keepalive.timer.cleanup() e.workerCleanup = false @@ -2175,6 +2182,8 @@ func (*endpoint) Disconnect() *tcpip.Error { func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { err := e.connect(addr, true, true) if err != nil && !err.IgnoreStats() { + // Connect failed. Let's wake up any waiters. 
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() } @@ -2387,14 +2396,70 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if run { - e.workerRunning = true - e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() - go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. + if err := e.startMainLoop(handshake); err != nil { + return err + } } return tcpip.ErrConnectStarted } +// startMainLoop sends the initial SYN and starts the main loop for the +// endpoint. +func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error { + preloop := func() *tcpip.Error { + if handshake { + h := e.newHandshake() + e.setEndpointState(StateSynSent) + if err := h.start(); err != nil { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + e.setEndpointState(StateError) + e.HardError = err + + // Call cleanupLocked to free up any reservations. + e.cleanupLocked() + return err + } + } + e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() + return nil + } + + if e.route.IsResolutionRequired() { + // If the endpoint is closed between releasing e.mu and the goroutine below + // acquiring it, make sure that cleanup is deferred to the new goroutine. + e.workerRunning = true + + // Sending the initial SYN may block due to route resolution; do it in a + // separate goroutine to avoid blocking the syscall goroutine. + go func() { // S/R-SAFE: will be drained before save. + e.mu.Lock() + if err := preloop(); err != nil { + e.workerRunning = false + e.mu.Unlock() + return + } + e.mu.Unlock() + _ = e.protocolMainLoop(handshake, nil) + }() + return nil + } + + // No route resolution is required, so we can send the initial SYN here without + // blocking. 
This will hopefully reduce overall latency by overlapping time + // spent waiting for a SYN-ACK and time spent spinning up a new goroutine + // for the main loop. + if err := preloop(); err != nil { + return err + } + e.workerRunning = true + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. + return nil +} + // ConnectEndpoint is not supported. func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { return tcpip.ErrInvalidEndpointState diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 2bcc5e1c2..bb901c0f8 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -172,6 +172,7 @@ func (e *endpoint) afterLoad() { // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. e.acceptCond = sync.NewCond(&e.acceptMu) + e.keepalive.timer.init(&e.keepalive.waker) stack.StackFromEnv.RegisterRestoredEndpoint(e) } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 0664789da..596178625 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -152,7 +152,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, } f := r.forwarder - ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{ + ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{ MSS: r.synOptions.MSS, WS: r.synOptions.WS, TS: r.synOptions.TS, diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index 7981d469b..38a335840 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -84,6 +84,10 @@ func (t *timer) init(w *sleep.Waker) { // cleanup frees all resources associated with the timer. func (t *timer) cleanup() { + if t.timer == nil { + // No cleanup needed. + return + } t.timer.Stop() *t = timer{} } |