summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rwxr-xr-xpkg/sync/aliases.go5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go46
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go70
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go3
4 files changed, 82 insertions, 42 deletions
diff --git a/pkg/sync/aliases.go b/pkg/sync/aliases.go
index d2d7132fa..0d4316254 100755
--- a/pkg/sync/aliases.go
+++ b/pkg/sync/aliases.go
@@ -29,3 +29,8 @@ type (
// Map is an alias of sync.Map.
Map = sync.Map
)
+
+// NewCond is a wrapper around sync.NewCond.
+func NewCond(l Locker) *Cond {
+ return sync.NewCond(l)
+}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index b4c4c8ab1..375ca21f6 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -365,21 +365,29 @@ func (l *listenContext) closeAllPendingEndpoints() {
}
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
-// endpoint has transitioned out of the listen state, the new endpoint is closed
-// instead.
+// endpoint has transitioned out of the listen state (acceptedChan is nil),
+// the new endpoint is closed instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.Lock()
- state := e.EndpointState()
e.pendingAccepted.Add(1)
- defer e.pendingAccepted.Done()
- acceptedChan := e.acceptedChan
e.mu.Unlock()
+ defer e.pendingAccepted.Done()
- if state == StateListen {
- acceptedChan <- n
- e.waiterQueue.Notify(waiter.EventIn)
- } else {
- n.Close()
+ e.acceptMu.Lock()
+ for {
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ n.Close()
+ return
+ }
+ select {
+ case e.acceptedChan <- n:
+ e.acceptMu.Unlock()
+ e.waiterQueue.Notify(waiter.EventIn)
+ return
+ default:
+ e.acceptCond.Wait()
+ }
}
}
@@ -420,11 +428,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
}
func (e *endpoint) incSynRcvdCount() bool {
- if e.synRcvdCount >= cap(e.acceptedChan) {
- return false
+ e.acceptMu.Lock()
+ canInc := e.synRcvdCount < cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ if canInc {
+ e.synRcvdCount++
}
- e.synRcvdCount++
- return true
+ return canInc
}
func (e *endpoint) decSynRcvdCount() {
@@ -432,10 +442,10 @@ func (e *endpoint) decSynRcvdCount() {
}
func (e *endpoint) acceptQueueIsFull() bool {
- if l, c := len(e.acceptedChan)+e.synRcvdCount, cap(e.acceptedChan); l >= c {
- return true
- }
- return false
+ e.acceptMu.Lock()
+ full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+ e.acceptMu.Unlock()
+ return full
}
// handleListenSegment is called when a listening endpoint receives a segment
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b6e571361..1ebee0cfe 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -291,6 +291,7 @@ func (*EndpointInfo) IsEndpointInfo() {}
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
+// e.acceptMu -> protects acceptedChan.
// e.rcvListMu -> Protects the rcvList and associated fields.
// e.sndBufMu -> Protects the sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
@@ -533,6 +534,23 @@ type endpoint struct {
// to the acceptedChan below terminate before we close acceptedChan.
pendingAccepted sync.WaitGroup `state:"nosave"`
+ // acceptMu protects acceptedChan.
+ acceptMu sync.Mutex `state:"nosave"`
+
+ // acceptCond is a condition variable that can be used to block on when
+ // acceptedChan is full and an endpoint is ready to be delivered.
+ //
+ // This condition variable is required because just blocking on sending
+ // to acceptedChan does not work in cases where endpoint.Listen is
+ // called twice with different backlog values. In such cases the channel
+ // is closed and a new one created. Any pending goroutines blocking on
+ // the write to the channel will panic.
+ //
+ // We use this condition variable to block/unblock goroutines which
+ // tried to deliver an endpoint but couldn't because accept backlog was
+ // full ( See: endpoint.deliverAccepted ).
+ acceptCond *sync.Cond `state:"nosave"`
+
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@@ -814,6 +832,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.tsOffset = timeStampOffset()
+ e.acceptCond = sync.NewCond(&e.acceptMu)
return e
}
@@ -834,9 +853,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
case StateListen:
// Check if there's anything in the accepted channel.
if (mask & waiter.EventIn) != 0 {
+ e.acceptMu.Lock()
if len(e.acceptedChan) > 0 {
result |= waiter.EventIn
}
+ e.acceptMu.Unlock()
}
}
if e.EndpointState().connected() {
@@ -981,29 +1002,19 @@ func (e *endpoint) closeNoShutdownLocked() {
// closePendingAcceptableConnections closes all connections that have completed
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
- done := make(chan struct{})
- // Spin a goroutine up as ranging on e.acceptedChan will just block when
- // there are no more connections in the channel. Using a non-blocking
- // select does not work as it can potentially select the default case
- // even when there are pending writes but that are not yet written to
- // the channel.
- go func() {
- defer close(done)
- for n := range e.acceptedChan {
- n.notifyProtocolGoroutine(notifyReset)
- // close all connections that have completed but
- // not accepted by the application.
- n.Close()
- }
- }()
- // pendingAccepted(see endpoint.deliverAccepted) tracks the number of
- // endpoints which have completed handshake but are not yet written to
- // the e.acceptedChan. We wait here till the goroutine above can drain
- // all such connections from e.acceptedChan.
- e.pendingAccepted.Wait()
+ e.acceptMu.Lock()
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ return
+ }
+
close(e.acceptedChan)
- <-done
e.acceptedChan = nil
+ e.acceptCond.Broadcast()
+ e.acceptMu.Unlock()
+
+ // Wait for all pending endpoints to close.
+ e.pendingAccepted.Wait()
}
// cleanupLocked frees all resources associated with the endpoint. It is called
@@ -1012,9 +1023,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
- if e.acceptedChan != nil {
- e.closePendingAcceptableConnectionsLocked()
- }
+ e.closePendingAcceptableConnectionsLocked()
e.workerCleanup = false
@@ -2204,6 +2213,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
if e.EndpointState() == StateListen && !e.workerCleanup {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
if len(e.acceptedChan) > backlog {
return tcpip.ErrInvalidEndpointState
}
@@ -2216,6 +2227,11 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
for ep := range origChan {
e.acceptedChan <- ep
}
+
+ // Notify any blocked goroutines that they can attempt to
+ // deliver endpoints again.
+ e.acceptCond.Broadcast()
+
return nil
}
@@ -2245,9 +2261,12 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// The channel may be non-nil when we're restoring the endpoint, and it
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
+ e.acceptMu.Lock()
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
+ e.acceptMu.Unlock()
+
e.workerRunning = true
go e.protocolListenLoop( // S/R-SAFE: drained on save.
seqnum.Size(e.receiveBufferAvailable()))
@@ -2276,9 +2295,12 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
// Get the new accepted endpoint.
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
var n *endpoint
select {
case n = <-e.acceptedChan:
+ e.acceptCond.Signal()
default:
return nil, nil, tcpip.ErrWouldBlock
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
index 9175de441..c3c692555 100644
--- a/pkg/tcpip/transport/tcp/endpoint_state.go
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -173,6 +173,9 @@ func (e *endpoint) afterLoad() {
// Restore the endpoint to InitialState as it will be moved to
// its origEndpointState during Resume.
e.state = StateInitial
+ // Condition variables and mutexs are not S/R'ed so reinitialize
+ // acceptCond with e.acceptMu.
+ e.acceptCond = sync.NewCond(&e.acceptMu)
stack.StackFromEnv.RegisterRestoredEndpoint(e)
}