diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/connect.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 56 |
1 files changed, 44 insertions, 12 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 4d20f4d3f..698e2b440 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -296,6 +296,21 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { return nil } +func (h *handshake) handleSegment(s *segment) *tcpip.Error { + h.sndWnd = s.window + if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 { + h.sndWnd <<= uint8(h.sndWndScale) + } + + switch h.state { + case handshakeSynRcvd: + return h.synRcvdState(s) + case handshakeSynSent: + return h.synSentState(s) + } + return nil +} + // processSegments goes through the segment queue and processes up to // maxSegmentsPerWake (if they're available). func (h *handshake) processSegments() *tcpip.Error { @@ -305,18 +320,7 @@ func (h *handshake) processSegments() *tcpip.Error { return nil } - h.sndWnd = s.window - if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 { - h.sndWnd <<= uint8(h.sndWndScale) - } - - var err *tcpip.Error - switch h.state { - case handshakeSynRcvd: - err = h.synRcvdState(s) - case handshakeSynSent: - err = h.synSentState(s) - } + err := h.handleSegment(s) s.decRef() if err != nil { return err @@ -364,6 +368,10 @@ func (h *handshake) resolveRoute() *tcpip.Error { h.ep.route.RemoveWaker(resolutionWaker) return tcpip.ErrAborted } + if n¬ifyDrain != 0 { + close(h.ep.drainDone) + <-h.ep.undrain + } } // Wait for notification. @@ -434,6 +442,20 @@ func (h *handshake) execute() *tcpip.Error { if n¬ifyClose != 0 { return tcpip.ErrAborted } + if n¬ifyDrain != 0 { + for s := h.ep.segmentQueue.dequeue(); s != nil; s = h.ep.segmentQueue.dequeue() { + err := h.handleSegment(s) + s.decRef() + if err != nil { + return err + } + if h.state == handshakeCompleted { + return nil + } + } + close(h.ep.drainDone) + <-h.ep.undrain + } case wakerForNewSegment: if err := h.processSegments(); err != nil { @@ -833,7 +855,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.mu.Lock() e.state = stateError e.hardError = err + drained := e.drainDone != nil e.mu.Unlock() + if drained { + close(e.drainDone) + <-e.undrain + } return err } @@ -851,7 +878,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() e.state = stateConnected + drained := e.drainDone != nil e.mu.Unlock() + if drained { + close(e.drainDone) + <-e.undrain + } e.waiterQueue.Notify(waiter.EventOut) |