summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/connect.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp/connect.go')
-rw-r--r--pkg/tcpip/transport/tcp/connect.go56
1 files changed, 44 insertions, 12 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 4d20f4d3f..698e2b440 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -296,6 +296,21 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
return nil
}
+func (h *handshake) handleSegment(s *segment) *tcpip.Error {
+ h.sndWnd = s.window
+ if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 {
+ h.sndWnd <<= uint8(h.sndWndScale)
+ }
+
+ switch h.state {
+ case handshakeSynRcvd:
+ return h.synRcvdState(s)
+ case handshakeSynSent:
+ return h.synSentState(s)
+ }
+ return nil
+}
+
// processSegments goes through the segment queue and processes up to
// maxSegmentsPerWake (if they're available).
func (h *handshake) processSegments() *tcpip.Error {
@@ -305,18 +320,7 @@ func (h *handshake) processSegments() *tcpip.Error {
return nil
}
- h.sndWnd = s.window
- if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 {
- h.sndWnd <<= uint8(h.sndWndScale)
- }
-
- var err *tcpip.Error
- switch h.state {
- case handshakeSynRcvd:
- err = h.synRcvdState(s)
- case handshakeSynSent:
- err = h.synSentState(s)
- }
+ err := h.handleSegment(s)
s.decRef()
if err != nil {
return err
@@ -364,6 +368,10 @@ func (h *handshake) resolveRoute() *tcpip.Error {
h.ep.route.RemoveWaker(resolutionWaker)
return tcpip.ErrAborted
}
+ if n&notifyDrain != 0 {
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
}
// Wait for notification.
@@ -434,6 +442,20 @@ func (h *handshake) execute() *tcpip.Error {
if n&notifyClose != 0 {
return tcpip.ErrAborted
}
+ if n&notifyDrain != 0 {
+ for s := h.ep.segmentQueue.dequeue(); s != nil; s = h.ep.segmentQueue.dequeue() {
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+ if h.state == handshakeCompleted {
+ return nil
+ }
+ }
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
@@ -833,7 +855,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
e.mu.Lock()
e.state = stateError
e.hardError = err
+ drained := e.drainDone != nil
e.mu.Unlock()
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
return err
}
@@ -851,7 +878,12 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
// Tell waiters that the endpoint is connected and writable.
e.mu.Lock()
e.state = stateConnected
+ drained := e.drainDone != nil
e.mu.Unlock()
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
e.waiterQueue.Notify(waiter.EventOut)