diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/connect.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 115 |
1 files changed, 70 insertions, 45 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 6e9015be1..ac6d879a7 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -102,21 +102,26 @@ type handshake struct { // been received. This is required to stop retransmitting the // original SYN-ACK when deferAccept is enabled. acked bool + + // sendSYNOpts is the cached values for the SYN options to be sent. + sendSYNOpts header.TCPSynOptions } -func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { - h := handshake{ - ep: ep, +func (e *endpoint) newHandshake() *handshake { + h := &handshake{ + ep: e, active: true, - rcvWnd: rcvWnd, - rcvWndScale: ep.rcvWndScaleForHandshake(), + rcvWnd: seqnum.Size(e.initialReceiveWindow()), + rcvWndScale: e.rcvWndScaleForHandshake(), } h.resetState() + // Store reference to handshake state in endpoint. + e.h = h return h } -func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake { - h := newHandshake(ep, rcvWnd) +func (e *endpoint) newPassiveHandshake(isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) *handshake { + h := e.newHandshake() h.resetToSynRcvd(isn, irs, opts, deferAccept) return h } @@ -502,8 +507,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { } } -// execute executes the TCP 3-way handshake. -func (h *handshake) execute() *tcpip.Error { +// start resolves the route if necessary and sends the first +// SYN/SYN-ACK. +func (h *handshake) start() *tcpip.Error { if h.ep.route.IsResolutionRequired() { if err := h.resolveRoute(); err != nil { return err @@ -511,19 +517,7 @@ func (h *handshake) execute() *tcpip.Error { } h.startTime = time.Now() - // Initialize the resend timer. - resendWaker := sleep.Waker{} - timeOut := time.Duration(time.Second) - rt := time.AfterFunc(timeOut, resendWaker.Assert) - defer rt.Stop() - - // Set up the wakers. - s := sleep.Sleeper{} - s.AddWaker(&resendWaker, wakerForResend) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) - defer s.Done() - + h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) var sackEnabled tcpip.TCPSACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled @@ -531,10 +525,6 @@ func (h *handshake) execute() *tcpip.Error { sackEnabled = false } - // Send the initial SYN segment and loop until the handshake is - // completed. - h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route) - synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: true, @@ -544,9 +534,8 @@ func (h *handshake) execute() *tcpip.Error { MSS: h.ep.amss, } - // Execute is also called in a listen context so we want to make sure we - // only send the TS/SACK option when we received the TS/SACK in the - // initial SYN. + // start() is also called in a listen context so we want to make sure we only + // send the TS/SACK option when we received the TS/SACK in the initial SYN. if h.state == handshakeSynRcvd { synOpts.TS = h.ep.sendTSOk synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) @@ -557,6 +546,7 @@ func (h *handshake) execute() *tcpip.Error { } } + h.sendSYNOpts = synOpts h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, @@ -566,6 +556,25 @@ func (h *handshake) execute() *tcpip.Error { ack: h.ackNum, rcvWnd: h.rcvWnd, }, synOpts) + return nil +} + +// complete completes the TCP 3-way handshake initiated by h.start(). +func (h *handshake) complete() *tcpip.Error { + // Set up the wakers. + s := sleep.Sleeper{} + resendWaker := sleep.Waker{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + + // Initialize the resend timer. + timer, err := newBackoffTimer(time.Second, MaxRTO, resendWaker.Assert) + if err != nil { + return err + } + defer timer.stop() for h.state != handshakeCompleted { // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held @@ -576,11 +585,9 @@ func (h *handshake) execute() *tcpip.Error { switch index { case wakerForResend: - timeOut *= 2 - if timeOut > MaxRTO { - return tcpip.ErrTimeout + if err := timer.reset(); err != nil { + return err } - rt.Reset(timeOut) // Resend the SYN/SYN-ACK only if the following conditions hold. // - It's an active handshake (deferAccept does not apply) // - It's a passive handshake and we have not yet got the final-ACK. @@ -598,7 +605,7 @@ func (h *handshake) execute() *tcpip.Error { seq: h.iss, ack: h.ackNum, rcvWnd: h.rcvWnd, - }, synOpts) + }, h.sendSYNOpts) } case wakerForNotification: @@ -637,6 +644,34 @@ func (h *handshake) execute() *tcpip.Error { return nil } +type backoffTimer struct { + timeout time.Duration + maxTimeout time.Duration + t *time.Timer +} + +func newBackoffTimer(timeout, maxTimeout time.Duration, f func()) (*backoffTimer, *tcpip.Error) { + if timeout > maxTimeout { + return nil, tcpip.ErrTimeout + } + bt := &backoffTimer{timeout: timeout, maxTimeout: maxTimeout} + bt.t = time.AfterFunc(timeout, f) + return bt, nil +} + +func (bt *backoffTimer) reset() *tcpip.Error { + bt.timeout *= 2 + if bt.timeout > MaxRTO { + return tcpip.ErrTimeout + } + bt.t.Reset(bt.timeout) + return nil +} + +func (bt *backoffTimer) stop() { + bt.t.Stop() +} + func parseSynSegmentOptions(s *segment) header.TCPSynOptions { synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) if synOpts.TS { @@ -1342,14 +1377,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } if handshake { - // This is an active connection, so we must initiate the 3-way - // handshake, and then inform potential waiters about its - // completion. - initialRcvWnd := e.initialReceiveWindow() - h := newHandshake(e, seqnum.Size(initialRcvWnd)) - h.ep.setEndpointState(StateSynSent) - - if err := h.execute(); err != nil { + if err := e.h.complete(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() @@ -1364,9 +1392,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } } - e.keepalive.timer.init(&e.keepalive.waker) - defer e.keepalive.timer.cleanup() - drained := e.drainDone != nil if drained { close(e.drainDone) |