diff options
Diffstat (limited to 'pkg/tcpip/transport/tcp/connect.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 90 |
1 files changed, 43 insertions, 47 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 5115dabe6..0571ceaa5 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -690,7 +690,7 @@ func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value, return err } -func (e *endpoint) handleWrite() bool { +func (e *endpoint) handleWrite() *tcpip.Error { // Move packets from send queue to send list. The queue is accessible // from other goroutines and protected by the send mutex, while the send // list is only accessible from the handler goroutine, so it needs no @@ -714,47 +714,42 @@ func (e *endpoint) handleWrite() bool { // Push out any new packets. e.snd.sendData() - return true + return nil } -func (e *endpoint) handleClose() bool { +func (e *endpoint) handleClose() *tcpip.Error { // Drain the send queue. e.handleWrite() // Mark send side as closed. e.snd.closed = true - return true + return nil } -// resetConnection sends a RST segment and puts the endpoint in an error state -// with the given error code. -// This method must only be called from the protocol goroutine. -func (e *endpoint) resetConnection(err *tcpip.Error) { +// resetConnectionLocked sends a RST segment and puts the endpoint in an error +// state with the given error code. This method must only be called from the +// protocol goroutine. +func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { e.sendRaw(nil, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) - e.mu.Lock() e.state = stateError e.hardError = err - e.mu.Unlock() } -// completeWorker is called by the worker goroutine when it's about to exit. It -// marks the worker as completed and performs cleanup work if requested by -// Close(). -func (e *endpoint) completeWorker() { - e.mu.Lock() - defer e.mu.Unlock() - +// completeWorkerLocked is called by the worker goroutine when it's about to +// exit. It marks the worker as completed and performs cleanup work if requested +// by Close(). +func (e *endpoint) completeWorkerLocked() { e.workerRunning = false if e.workerCleanup { - e.cleanup() + e.cleanupLocked() } } // handleSegments pulls segments from the queue and processes them. It returns -// true if the protocol loop should continue, false otherwise. -func (e *endpoint) handleSegments() bool { +// no error if the protocol loop should continue, an error otherwise. +func (e *endpoint) handleSegments() *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { s := e.segmentQueue.dequeue() @@ -775,11 +770,7 @@ func (e *endpoint) handleSegments() bool { // validated by checking their SEQ-fields." So // we only process it if it's acceptable. s.decRef() - e.mu.Lock() - e.state = stateError - e.hardError = tcpip.ErrConnectionReset - e.mu.Unlock() - return false + return tcpip.ErrConnectionReset } } else if s.flagIsSet(flagAck) { // Patch the window size in the segment according to the @@ -816,7 +807,7 @@ func (e *endpoint) handleSegments() bool { e.snd.sendAck() } - return true + return nil } // protocolMainLoop is the main loop of the TCP protocol. It runs in its own @@ -827,9 +818,9 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { var closeWaker sleep.Waker defer func() { - // When the protocol loop exits we should wake up our waiters. - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) - e.completeWorker() + // e.mu is expected to be hold upon entering this section. + + e.completeWorkerLocked() if e.snd != nil { e.snd.resendTimer.cleanup() @@ -838,6 +829,15 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { if closeTimer != nil { closeTimer.Stop() } + + if e.drainDone != nil { + close(e.drainDone) + } + + e.mu.Unlock() + + // When the protocol loop exits we should wake up our waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) }() if !passive { @@ -856,12 +856,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.mu.Lock() e.state = stateError e.hardError = err - drained := e.drainDone != nil - e.mu.Unlock() - if drained { - close(e.drainDone) - <-e.undrain - } + // Lock released in deferred statement. return err } @@ -892,7 +887,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { // wakes up. funcs := []struct { w *sleep.Waker - f func() bool + f func() *tcpip.Error }{ { w: &e.sndWaker, @@ -908,24 +903,22 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { }, { w: &closeWaker, - f: func() bool { - e.resetConnection(tcpip.ErrConnectionAborted) - return false + f: func() *tcpip.Error { + return tcpip.ErrConnectionAborted }, }, { w: &e.snd.resendWaker, - f: func() bool { + f: func() *tcpip.Error { if !e.snd.retransmitTimerExpired() { - e.resetConnection(tcpip.ErrTimeout) - return false + return tcpip.ErrTimeout } - return true + return nil }, }, { w: &e.notificationWaker, - f: func() bool { + f: func() *tcpip.Error { n := e.fetchNotifications() if n¬ifyNonZeroReceiveWindow != 0 { e.rcv.nonZeroWindow() @@ -952,7 +945,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { closeWaker.Assert() }) } - return true + return nil }, }, } @@ -969,7 +962,10 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { e.workMu.Unlock() v, _ := s.Fetch(true) e.workMu.Lock() - if !funcs[v].f() { + if err := funcs[v].f(); err != nil { + e.mu.Lock() + e.resetConnectionLocked(err) + // Lock released in deferred statement. return nil } } @@ -977,7 +973,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() e.state = stateClosed - e.mu.Unlock() + // Lock released in deferred statement. return nil } |