summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/connect.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp/connect.go')
-rw-r--r--pkg/tcpip/transport/tcp/connect.go90
1 files changed, 43 insertions, 47 deletions
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 5115dabe6..0571ceaa5 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -690,7 +690,7 @@ func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value,
return err
}
-func (e *endpoint) handleWrite() bool {
+func (e *endpoint) handleWrite() *tcpip.Error {
// Move packets from send queue to send list. The queue is accessible
// from other goroutines and protected by the send mutex, while the send
// list is only accessible from the handler goroutine, so it needs no
@@ -714,47 +714,42 @@ func (e *endpoint) handleWrite() bool {
// Push out any new packets.
e.snd.sendData()
- return true
+ return nil
}
-func (e *endpoint) handleClose() bool {
+func (e *endpoint) handleClose() *tcpip.Error {
// Drain the send queue.
e.handleWrite()
// Mark send side as closed.
e.snd.closed = true
- return true
+ return nil
}
-// resetConnection sends a RST segment and puts the endpoint in an error state
-// with the given error code.
-// This method must only be called from the protocol goroutine.
-func (e *endpoint) resetConnection(err *tcpip.Error) {
+// resetConnectionLocked sends a RST segment and puts the endpoint in an error
+// state with the given error code. This method must only be called from the
+// protocol goroutine.
+func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
e.sendRaw(nil, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
- e.mu.Lock()
e.state = stateError
e.hardError = err
- e.mu.Unlock()
}
-// completeWorker is called by the worker goroutine when it's about to exit. It
-// marks the worker as completed and performs cleanup work if requested by
-// Close().
-func (e *endpoint) completeWorker() {
- e.mu.Lock()
- defer e.mu.Unlock()
-
+// completeWorkerLocked is called by the worker goroutine when it's about to
+// exit. It marks the worker as completed and performs cleanup work if requested
+// by Close().
+func (e *endpoint) completeWorkerLocked() {
e.workerRunning = false
if e.workerCleanup {
- e.cleanup()
+ e.cleanupLocked()
}
}
// handleSegments pulls segments from the queue and processes them. It returns
-// true if the protocol loop should continue, false otherwise.
-func (e *endpoint) handleSegments() bool {
+// no error if the protocol loop should continue, an error otherwise.
+func (e *endpoint) handleSegments() *tcpip.Error {
checkRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
s := e.segmentQueue.dequeue()
@@ -775,11 +770,7 @@ func (e *endpoint) handleSegments() bool {
// validated by checking their SEQ-fields." So
// we only process it if it's acceptable.
s.decRef()
- e.mu.Lock()
- e.state = stateError
- e.hardError = tcpip.ErrConnectionReset
- e.mu.Unlock()
- return false
+ return tcpip.ErrConnectionReset
}
} else if s.flagIsSet(flagAck) {
// Patch the window size in the segment according to the
@@ -816,7 +807,7 @@ func (e *endpoint) handleSegments() bool {
e.snd.sendAck()
}
- return true
+ return nil
}
// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
@@ -827,9 +818,9 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
var closeWaker sleep.Waker
defer func() {
- // When the protocol loop exits we should wake up our waiters.
- e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
- e.completeWorker()
+ // e.mu is expected to be hold upon entering this section.
+
+ e.completeWorkerLocked()
if e.snd != nil {
e.snd.resendTimer.cleanup()
@@ -838,6 +829,15 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
if closeTimer != nil {
closeTimer.Stop()
}
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+
+ e.mu.Unlock()
+
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}()
if !passive {
@@ -856,12 +856,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
e.mu.Lock()
e.state = stateError
e.hardError = err
- drained := e.drainDone != nil
- e.mu.Unlock()
- if drained {
- close(e.drainDone)
- <-e.undrain
- }
+ // Lock released in deferred statement.
return err
}
@@ -892,7 +887,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
// wakes up.
funcs := []struct {
w *sleep.Waker
- f func() bool
+ f func() *tcpip.Error
}{
{
w: &e.sndWaker,
@@ -908,24 +903,22 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
},
{
w: &closeWaker,
- f: func() bool {
- e.resetConnection(tcpip.ErrConnectionAborted)
- return false
+ f: func() *tcpip.Error {
+ return tcpip.ErrConnectionAborted
},
},
{
w: &e.snd.resendWaker,
- f: func() bool {
+ f: func() *tcpip.Error {
if !e.snd.retransmitTimerExpired() {
- e.resetConnection(tcpip.ErrTimeout)
- return false
+ return tcpip.ErrTimeout
}
- return true
+ return nil
},
},
{
w: &e.notificationWaker,
- f: func() bool {
+ f: func() *tcpip.Error {
n := e.fetchNotifications()
if n&notifyNonZeroReceiveWindow != 0 {
e.rcv.nonZeroWindow()
@@ -952,7 +945,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
closeWaker.Assert()
})
}
- return true
+ return nil
},
},
}
@@ -969,7 +962,10 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
e.workMu.Unlock()
v, _ := s.Fetch(true)
e.workMu.Lock()
- if !funcs[v].f() {
+ if err := funcs[v].f(); err != nil {
+ e.mu.Lock()
+ e.resetConnectionLocked(err)
+ // Lock released in deferred statement.
return nil
}
}
@@ -977,7 +973,7 @@ func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
// Mark endpoint as closed.
e.mu.Lock()
e.state = stateClosed
- e.mu.Unlock()
+ // Lock released in deferred statement.
return nil
}