diff options
-rw-r--r-- | pkg/tcpip/stack/nic.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 126 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 125 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 65 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/forwarder.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/timer.go | 4 |
8 files changed, 94 insertions, 233 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index ff02c7c65..17f2e6b46 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -895,7 +895,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } // isValidForOutgoing returns true if the endpoint can be used to send out a -// packet. It requires the endpoint to not be marked expired (i.e., its address +// packet. It requires the endpoint to not be marked expired (i.e., its address) // has been removed) unless the NIC is in spoofing mode, or temporary. func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RLock() diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 87f7008f7..b76e2d37b 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -191,7 +191,7 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { } // IsResolutionRequired returns true if Resolve() must be called to resolve -// the link address before r can be written to. +// the link address before the this route can be written to. // // The NIC r uses must not be locked. func (r *Route) IsResolutionRequired() bool { diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index e68cd95a8..6b3238d6b 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -228,15 +228,11 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i return n } -// startHandshake creates a new endpoint in connecting state and then sends -// the SYN-ACK for the TCP 3-way handshake. It returns the state of the -// handshake in progress, which includes the new endpoint in the SYN-RCVD -// state. +// createEndpointAndPerformHandshake creates a new endpoint in connected state +// and then performs the TCP 3-way handshake. // -// On success, a handshake h is returned with h.ep.mu held. -// -// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*handshake, *tcpip.Error) { +// The new endpoint is returned with e.mu held. +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) @@ -251,8 +247,10 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q // listenEP is nil when listenContext is used by tcp.Forwarder. deferAccept := time.Duration(0) if l.listenEP != nil { + l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { + l.listenEP.mu.Unlock() // Ensure we release any registrations done by the newly // created endpoint. ep.mu.Unlock() @@ -270,12 +268,16 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q ep.mu.Unlock() ep.Close() - l.removePendingEndpoint(ep) + if l.listenEP != nil { + l.removePendingEndpoint(ep) + l.listenEP.mu.Unlock() + } return nil, tcpip.ErrConnectionAborted } deferAccept = l.listenEP.deferAccept + l.listenEP.mu.Unlock() } // Register new endpoint so that packets are routed to it. @@ -294,33 +296,28 @@ func (l *listenContext) startHandshake(s *segment, opts *header.TCPSynOptions, q ep.isRegistered = true - // Initialize and start the handshake. - h := ep.newPassiveHandshake(isn, irs, opts, deferAccept) - if err := h.start(); err != nil { - l.cleanupFailedHandshake(h) - return nil, err - } - return h, nil -} + // Perform the 3-way handshake. + h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) + if err := h.execute(); err != nil { + ep.mu.Unlock() + ep.Close() + ep.notifyAborted() -// performHandshake performs a TCP 3-way handshake. On success, the new -// established endpoint is returned with e.mu held. -// -// Precondition: if l.listenEP != nil, l.listenEP.mu must be locked. -func (l *listenContext) performHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) { - h, err := l.startHandshake(s, opts, queue, owner) - if err != nil { - return nil, err - } - ep := h.ep + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } + + ep.drainClosingSegmentQueue() - if err := h.complete(); err != nil { - ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() - ep.stats.FailedConnectionAttempts.Increment() - l.cleanupFailedHandshake(h) return nil, err } - l.cleanupCompletedHandshake(h) + ep.isConnectNotified = true + + // Update the receive window scaling. We can't do it before the + // handshake because it's possible that the peer doesn't support window + // scaling. + ep.rcv.rcvWndScale = h.effectiveRcvWndScale() + return ep, nil } @@ -347,39 +344,6 @@ func (l *listenContext) closeAllPendingEndpoints() { l.pending.Wait() } -// Precondition: h.ep.mu must be held. -func (l *listenContext) cleanupFailedHandshake(h *handshake) { - e := h.ep - e.mu.Unlock() - e.Close() - e.notifyAborted() - if l.listenEP != nil { - l.removePendingEndpoint(e) - } - e.drainClosingSegmentQueue() - e.h = nil -} - -// cleanupCompletedHandshake transfers any state from the completed handshake to -// the new endpoint. -// -// Precondition: h.ep.mu must be held. -func (l *listenContext) cleanupCompletedHandshake(h *handshake) { - e := h.ep - if l.listenEP != nil { - l.removePendingEndpoint(e) - } - e.isConnectNotified = true - - // Update the receive window scaling. We can't do it before the - // handshake because it's possible that the peer doesn't support window - // scaling. - e.rcv.rcvWndScale = e.h.effectiveRcvWndScale() - - // Clean up handshake state stored in the endpoint so that it can be GCed. - e.h = nil -} - // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state (acceptedChan is nil), // the new endpoint is closed instead. @@ -459,35 +423,23 @@ func (e *endpoint) notifyAborted() { // // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. -// -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { + defer ctx.synRcvdCount.dec() defer s.decRef() - h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner) - n := h.ep + n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() e.decSynRcvdCount() return } + ctx.removePendingEndpoint(n) + e.decSynRcvdCount() + n.startAcceptedLoop() + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go func() { - defer ctx.synRcvdCount.dec() - if err := h.complete(); err != nil { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - e.stats.FailedConnectionAttempts.Increment() - ctx.cleanupFailedHandshake(h) - e.decSynRcvdCount() - return - } - ctx.cleanupCompletedHandshake(h) - e.decSynRcvdCount() - n.startAcceptedLoop() - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(n) - }() // S/R-SAFE: synRcvdCount is the barrier. + e.deliverAccepted(n) } func (e *endpoint) incSynRcvdCount() bool { @@ -515,8 +467,6 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. -// -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.rcvListMu.Lock() rcvClosed := e.rcvClosed @@ -541,7 +491,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // backlog. if !e.acceptQueueIsFull() && e.incSynRcvdCount() { s.incRef() - e.handleSynSegment(ctx, s, &opts) + go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. return } ctx.synRcvdCount.dec() @@ -736,7 +686,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // to the endpoint. e.setEndpointState(StateClose) - // Close any endpoints in SYN-RCVD state. + // close any endpoints in SYN-RCVD state. ctx.closeAllPendingEndpoints() // Do cleanup if needed. diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index fd5373ed4..0aaef495d 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -102,26 +102,21 @@ type handshake struct { // been received. This is required to stop retransmitting the // original SYN-ACK when deferAccept is enabled. acked bool - - // sendSYNOpts is the cached values for the SYN options to be sent. - sendSYNOpts header.TCPSynOptions } -func (e *endpoint) newHandshake() *handshake { - h := &handshake{ - ep: e, +func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { + h := handshake{ + ep: ep, active: true, - rcvWnd: seqnum.Size(e.initialReceiveWindow()), - rcvWndScale: e.rcvWndScaleForHandshake(), + rcvWnd: rcvWnd, + rcvWndScale: ep.rcvWndScaleForHandshake(), } h.resetState() - // Store reference to handshake state in endpoint. - e.h = h return h } -func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake { - h := e.newHandshake() +func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { + h := newHandshake(ep, rcvWnd) h.resetToSynRcvd(isn, irs, opts, deferAccept) return h } @@ -501,13 +496,12 @@ func (h *handshake) resolveRoute() *tcpip.Error { } // Wait for notification. - index, _ = s.Fetch(true /* block */) + index, _ = s.Fetch(true) } } -// start resolves the route if necessary and sends the first -// SYN/SYN-ACK. -func (h *handshake) start() *tcpip.Error { +// execute executes the TCP 3-way handshake. +func (h *handshake) execute() *tcpip.Error { if h.ep.route.IsResolutionRequired() { if err := h.resolveRoute(); err != nil { return err @@ -515,7 +509,19 @@ func (h *handshake) start() *tcpip.Error { } h.startTime = time.Now() - h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) + // Initialize the resend timer. + resendWaker := sleep.Waker{} + timeOut := time.Duration(time.Second) + rt := time.AfterFunc(timeOut, resendWaker.Assert) + defer rt.Stop() + + // Set up the wakers. + s := sleep.Sleeper{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled @@ -523,6 +529,10 @@ func (h *handshake) start() *tcpip.Error { sackEnabled = false } + // Send the initial SYN segment and loop until the handshake is + // completed. + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) + synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: true, @@ -532,8 +542,9 @@ func (h *handshake) start() *tcpip.Error { MSS: h.ep.amss, } - // start() is also called in a listen context so we want to make sure we only - // send the TS/SACK option when we received the TS/SACK in the initial SYN. + // Execute is also called in a listen context so we want to make sure we + // only send the TS/SACK option when we received the TS/SACK in the + // initial SYN. if h.state == handshakeSynRcvd { synOpts.TS = h.ep.sendTSOk synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) @@ -544,7 +555,6 @@ func (h *handshake) start() *tcpip.Error { } } - h.sendSYNOpts = synOpts h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, @@ -554,38 +564,19 @@ func (h *handshake) start() *tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, synOpts) - return nil -} - -// complete completes the TCP 3-way handshake initiated by h.start(). -func (h *handshake) complete() *tcpip.Error { - // Set up the wakers. - s := sleep.Sleeper{} - resendWaker := sleep.Waker{} - s.AddWaker(&resendWaker, wakerForResend) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) - defer s.Done() - - // Initialize the resend timer. - timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) - if err != nil { - return err - } - defer timer.stop() for h.state != handshakeCompleted { - // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held - // throughout handshake processing). h.ep.mu.Unlock() - index, _ := s.Fetch(true /* block */) + index, _ := s.Fetch(true) h.ep.mu.Lock() switch index { case wakerForResend: - if err := timer.reset(); err != nil { - return err + timeOut *= 2 + if timeOut > MaxRTO { + return tcpip.ErrTimeout } + rt.Reset(timeOut) // Resend the SYN/SYN-ACK only if the following conditions hold. // - It's an active handshake (deferAccept does not apply) // - It's a passive handshake and we have not yet got the final-ACK. @@ -603,7 +594,7 @@ func (h *handshake) complete() *tcpip.Error { seq: h.iss, ack: h.ackNum, rcvWnd: h.rcvWnd, - }, h.sendSYNOpts) + }, synOpts) } case wakerForNotification: @@ -642,34 +633,6 @@ func (h *handshake) complete() *tcpip.Error { return nil } -type backoffTimer struct { - timeout time.Duration - maxTimeout time.Duration - t *time.Timer -} - -func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { - if timeout > maxTimeout { - return nil, tcpip.ErrTimeout - } - bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} - bt.t = time.AfterFunc(timeout, f) - return bt, nil -} - -func (bt *backoffTimer) reset() *tcpip.Error { - bt.timeout *= 2 - if bt.timeout > MaxRTO { - return tcpip.ErrTimeout - } - bt.t.Reset(bt.timeout) - return nil -} - -func (bt *backoffTimer) stop() { - bt.t.Stop() -} - func parseSynSegmentOptions(s *segment) header.TCPSynOptions { synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) if synOpts.TS { @@ -1375,7 +1338,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if handshake { - if err := e.h.complete(); err != nil { + // This is an active connection, so we must initiate the 3-way + // handshake, and then inform potential waiters about its + // completion. + initialRcvWnd := e.initialReceiveWindow() + h := newHandshake(e, seqnum.Size(initialRcvWnd)) + h.ep.setEndpointState(StateSynSent) + + if err := h.execute(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -1390,6 +1360,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } + e.keepalive.timer.init(&e.keepalive.waker) + defer e.keepalive.timer.cleanup() + drained := e.drainDone != nil if drained { close(e.drainDone) @@ -1562,7 +1535,7 @@ loop: } e.mu.Unlock() - v, _ := s.Fetch(true /* block */) + v, _ := s.Fetch(true) e.mu.Lock() // We need to double check here because the notification may be @@ -1710,7 +1683,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { for { e.mu.Unlock() - v, _ := s.Fetch(true /* block */) + v, _ := s.Fetch(true) e.mu.Lock() switch v { case newSegment: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 8f5e3a42d..bfe26e460 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -440,11 +440,6 @@ type endpoint struct { ttl uint8 v6only bool isConnectNotified bool - // h stores a reference to the current handshake state if the endpoint is in - // the SYN-SENT or SYN-RECV states, in which case endpoint == endpoint.h.ep. - // nil otherwise. - h *handshake `state:"nosave"` - // TCP should never broadcast but Linux nevertheless supports enabling/ // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool @@ -726,9 +721,9 @@ func (e *endpoint) LockUser() { for { // Try first if the sock is locked then check if it's owned // by another user goroutine if not then we spin, otherwise - // we just go to sleep on the Lock() and wait. + // we just goto sleep on the Lock() and wait. if !e.mu.TryLock() { - // If socket is owned by the user then just go to sleep + // If socket is owned by the user then just goto sleep // as the lock could be held for a reasonably long time. if atomic.LoadUint32(&e.ownedByUser) == 1 { e.mu.Lock() @@ -927,7 +922,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.segmentQueue.ep = e e.tsOffset = timeStampOffset() e.acceptCond = sync.NewCond(&e.acceptMu) - e.keepalive.timer.init(&e.keepalive.waker) return e } @@ -1149,7 +1143,6 @@ func (e *endpoint) cleanupLocked() { // Close all endpoints that might have been accepted by TCP but not by // the client. e.closePendingAcceptableConnectionsLocked() - e.keepalive.timer.cleanup() e.workerCleanup = false @@ -2189,8 +2182,6 @@ func (*endpoint) Disconnect() *tcpip.Error { func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { err := e.connect(addr, true, true) if err != nil && !err.IgnoreStats() { - // Connect failed. Let's wake up any waiters. - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() } @@ -2404,60 +2395,12 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } if run { - if err := e.startMainLoop(handshake); err != nil { - return err - } - } - - return tcpip.ErrConnectStarted -} - -// startMainLoop sends the initial SYN and starts the main loop for the -// endpoint. -func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error { - preloop := func() *tcpip.Error { - if handshake { - h := e.newHandshake() - e.setEndpointState(StateSynSent) - if err := h.start(); err != nil { - e.lastErrorMu.Lock() - e.lastError = err - e.lastErrorMu.Unlock() - - e.setEndpointState(StateError) - e.HardError = err - - // Call cleanupLocked to free up any reservations. - e.cleanupLocked() - return err - } - } - e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() e.workerRunning = true - return nil - } - - if !e.route.IsResolutionRequired() { - // No route resolution is required, so we can send the initial SYN here without - // blocking. This will hopefully reduce overall latency by overlapping time - // spent waiting for a SYN-ACK and time spent spinning up a new goroutine - // for the main loop. - if err := preloop(); err != nil { - return err - } + e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. - return nil } - // Sending the initial SYN may block due to route resolution; do it in a - // separate goroutine to avoid blocking the syscall goroutine. - go func() { // S/R-SAFE: will be drained before save. - if err := preloop(); err != nil { - return - } - e.protocolMainLoop(handshake, nil) - }() - return nil + return tcpip.ErrConnectStarted } // ConnectEndpoint is not supported. diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 5e7962794..b25431467 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -172,7 +172,6 @@ func (e *endpoint) afterLoad() { // Condition variables and mutexs are not S/R'ed so reinitialize // acceptCond with e.acceptMu. e.acceptCond = sync.NewCond(&e.acceptMu) - e.keepalive.timer.init(&e.keepalive.waker) stack.StackFromEnv.RegisterRestoredEndpoint(e) } diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 8c334c97b..070b634b4 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -150,7 +150,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, } f := r.forwarder - ep, err := f.listen.performHandshake(r.segment, &header.TCPSynOptions{ + ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{ MSS: r.synOptions.MSS, WS: r.synOptions.WS, TS: r.synOptions.TS, diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go index 38a335840..7981d469b 100644 --- a/pkg/tcpip/transport/tcp/timer.go +++ b/pkg/tcpip/transport/tcp/timer.go @@ -84,10 +84,6 @@ func (t *timer) init(w *sleep.Waker) { // cleanup frees all resources associated with the timer. func (t *timer) cleanup() { - if t.timer == nil { - // No cleanup needed. - return - } t.timer.Stop() *t = timer{} } |