diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/connect.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 125 |
1 files changed, 49 insertions, 76 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index fd5373ed4..0aaef495d 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -102,26 +102,21 @@ type handshake struct { // been received. This is required to stop retransmitting the // original SYN-ACK when deferAccept is enabled. acked bool - - // sendSYNOpts is the cached values for the SYN options to be sent. - sendSYNOpts header.TCPSynOptions } -func (e *endpoint) newHandshake() *handshake { - h := &handshake{ - ep: e, +func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { + h := handshake{ + ep: ep, active: true, - rcvWnd: seqnum.Size(e.initialReceiveWindow()), - rcvWndScale: e.rcvWndScaleForHandshake(), + rcvWnd: rcvWnd, + rcvWndScale: ep.rcvWndScaleForHandshake(), } h.resetState() - // Store reference to handshake state in endpoint. - e.h = h return h } -func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake { - h := e.newHandshake() +func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { + h := newHandshake(ep, rcvWnd) h.resetToSynRcvd(isn, irs, opts, deferAccept) return h } @@ -501,13 +496,12 @@ func (h *handshake) resolveRoute() *tcpip.Error { } // Wait for notification. - index, _ = s.Fetch(true /* block */) + index, _ = s.Fetch(true) } } -// start resolves the route if necessary and sends the first -// SYN/SYN-ACK. -func (h *handshake) start() *tcpip.Error { +// execute executes the TCP 3-way handshake. +func (h *handshake) execute() *tcpip.Error { if h.ep.route.IsResolutionRequired() { if err := h.resolveRoute(); err != nil { return err @@ -515,7 +509,19 @@ func (h *handshake) start() *tcpip.Error { } h.startTime = time.Now() - h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) + // Initialize the resend timer. + resendWaker := sleep.Waker{} + timeOut := time.Duration(time.Second) + rt := time.AfterFunc(timeOut, resendWaker.Assert) + defer rt.Stop() + + // Set up the wakers. + s := sleep.Sleeper{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled @@ -523,6 +529,10 @@ func (h *handshake) start() *tcpip.Error { sackEnabled = false } + // Send the initial SYN segment and loop until the handshake is + // completed. + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) + synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: true, @@ -532,8 +542,9 @@ func (h *handshake) start() *tcpip.Error { MSS: h.ep.amss, } - // start() is also called in a listen context so we want to make sure we only - // send the TS/SACK option when we received the TS/SACK in the initial SYN. + // Execute is also called in a listen context so we want to make sure we + // only send the TS/SACK option when we received the TS/SACK in the + // initial SYN. if h.state == handshakeSynRcvd { synOpts.TS = h.ep.sendTSOk synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) @@ -544,7 +555,6 @@ func (h *handshake) start() *tcpip.Error { } } - h.sendSYNOpts = synOpts h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, @@ -554,38 +564,19 @@ func (h *handshake) start() *tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, synOpts) - return nil -} - -// complete completes the TCP 3-way handshake initiated by h.start(). -func (h *handshake) complete() *tcpip.Error { - // Set up the wakers. - s := sleep.Sleeper{} - resendWaker := sleep.Waker{} - s.AddWaker(&resendWaker, wakerForResend) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) - defer s.Done() - - // Initialize the resend timer. - timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) - if err != nil { - return err - } - defer timer.stop() for h.state != handshakeCompleted { - // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held - // throughout handshake processing). h.ep.mu.Unlock() - index, _ := s.Fetch(true /* block */) + index, _ := s.Fetch(true) h.ep.mu.Lock() switch index { case wakerForResend: - if err := timer.reset(); err != nil { - return err + timeOut *= 2 + if timeOut > MaxRTO { + return tcpip.ErrTimeout } + rt.Reset(timeOut) // Resend the SYN/SYN-ACK only if the following conditions hold. // - It's an active handshake (deferAccept does not apply) // - It's a passive handshake and we have not yet got the final-ACK. @@ -603,7 +594,7 @@ func (h *handshake) complete() *tcpip.Error { seq: h.iss, ack: h.ackNum, rcvWnd: h.rcvWnd, - }, h.sendSYNOpts) + }, synOpts) } case wakerForNotification: @@ -642,34 +633,6 @@ func (h *handshake) complete() *tcpip.Error { return nil } -type backoffTimer struct { - timeout time.Duration - maxTimeout time.Duration - t *time.Timer -} - -func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { - if timeout > maxTimeout { - return nil, tcpip.ErrTimeout - } - bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} - bt.t = time.AfterFunc(timeout, f) - return bt, nil -} - -func (bt *backoffTimer) reset() *tcpip.Error { - bt.timeout *= 2 - if bt.timeout > MaxRTO { - return tcpip.ErrTimeout - } - bt.t.Reset(bt.timeout) - return nil -} - -func (bt *backoffTimer) stop() { - bt.t.Stop() -} - func parseSynSegmentOptions(s *segment) header.TCPSynOptions { synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) if synOpts.TS { @@ -1375,7 +1338,14 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if handshake { - if err := e.h.complete(); err != nil { + // This is an active connection, so we must initiate the 3-way + // handshake, and then inform potential waiters about its + // completion. + initialRcvWnd := e.initialReceiveWindow() + h := newHandshake(e, seqnum.Size(initialRcvWnd)) + h.ep.setEndpointState(StateSynSent) + + if err := h.execute(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -1390,6 +1360,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } + e.keepalive.timer.init(&e.keepalive.waker) + defer e.keepalive.timer.cleanup() + drained := e.drainDone != nil if drained { close(e.drainDone) @@ -1562,7 +1535,7 @@ loop: } e.mu.Unlock() - v, _ := s.Fetch(true /* block */) + v, _ := s.Fetch(true) e.mu.Lock() // We need to double check here because the notification may be @@ -1710,7 +1683,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { for { e.mu.Unlock() - v, _ := s.Fetch(true /* block */) + v, _ := s.Fetch(true) e.mu.Lock() switch v { case newSegment: |