summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/endpoint.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp/endpoint.go')
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go70
1 files changed, 46 insertions, 24 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index b6e571361..1ebee0cfe 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -291,6 +291,7 @@ func (*EndpointInfo) IsEndpointInfo() {}
// The following three mutexes can be acquired independent of e.mu but if
// acquired with e.mu then e.mu must be acquired first.
//
+// e.acceptMu -> protects acceptedChan.
// e.rcvListMu -> Protects the rcvList and associated fields.
// e.sndBufMu -> Protects the sndQueue and associated fields.
// e.lastErrorMu -> Protects the lastError field.
@@ -533,6 +534,23 @@ type endpoint struct {
// to the acceptedChan below terminate before we close acceptedChan.
pendingAccepted sync.WaitGroup `state:"nosave"`
+ // acceptMu protects acceptedChan.
+ acceptMu sync.Mutex `state:"nosave"`
+
+ // acceptCond is a condition variable that can be used to block on when
+ // acceptedChan is full and an endpoint is ready to be delivered.
+ //
+ // This condition variable is required because just blocking on sending
+ // to acceptedChan does not work in cases where endpoint.Listen is
+ // called twice with different backlog values. In such cases the channel
+ // is closed and a new one created. Any pending goroutines blocking on
+ // the write to the channel will panic.
+ //
+ // We use this condition variable to block/unblock goroutines which
+ // tried to deliver an endpoint but couldn't because accept backlog was
+ // full ( See: endpoint.deliverAccepted ).
+ acceptCond *sync.Cond `state:"nosave"`
+
// acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be
// read by Accept() calls.
@@ -814,6 +832,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue
e.segmentQueue.setLimit(MaxUnprocessedSegments)
e.tsOffset = timeStampOffset()
+ e.acceptCond = sync.NewCond(&e.acceptMu)
return e
}
@@ -834,9 +853,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
case StateListen:
// Check if there's anything in the accepted channel.
if (mask & waiter.EventIn) != 0 {
+ e.acceptMu.Lock()
if len(e.acceptedChan) > 0 {
result |= waiter.EventIn
}
+ e.acceptMu.Unlock()
}
}
if e.EndpointState().connected() {
@@ -981,29 +1002,19 @@ func (e *endpoint) closeNoShutdownLocked() {
// closePendingAcceptableConnections closes all connections that have completed
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
- done := make(chan struct{})
- // Spin a goroutine up as ranging on e.acceptedChan will just block when
- // there are no more connections in the channel. Using a non-blocking
- // select does not work as it can potentially select the default case
- // even when there are pending writes but that are not yet written to
- // the channel.
- go func() {
- defer close(done)
- for n := range e.acceptedChan {
- n.notifyProtocolGoroutine(notifyReset)
- // close all connections that have completed but
- // not accepted by the application.
- n.Close()
- }
- }()
- // pendingAccepted(see endpoint.deliverAccepted) tracks the number of
- // endpoints which have completed handshake but are not yet written to
- // the e.acceptedChan. We wait here till the goroutine above can drain
- // all such connections from e.acceptedChan.
- e.pendingAccepted.Wait()
+ e.acceptMu.Lock()
+ if e.acceptedChan == nil {
+ e.acceptMu.Unlock()
+ return
+ }
+
close(e.acceptedChan)
- <-done
e.acceptedChan = nil
+ e.acceptCond.Broadcast()
+ e.acceptMu.Unlock()
+
+ // Wait for all pending endpoints to close.
+ e.pendingAccepted.Wait()
}
// cleanupLocked frees all resources associated with the endpoint. It is called
@@ -1012,9 +1023,7 @@ func (e *endpoint) closePendingAcceptableConnectionsLocked() {
func (e *endpoint) cleanupLocked() {
// Close all endpoints that might have been accepted by TCP but not by
// the client.
- if e.acceptedChan != nil {
- e.closePendingAcceptableConnectionsLocked()
- }
+ e.closePendingAcceptableConnectionsLocked()
e.workerCleanup = false
@@ -2204,6 +2213,8 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
if e.EndpointState() == StateListen && !e.workerCleanup {
// Adjust the size of the channel iff we can fix existing
// pending connections into the new one.
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
if len(e.acceptedChan) > backlog {
return tcpip.ErrInvalidEndpointState
}
@@ -2216,6 +2227,11 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
for ep := range origChan {
e.acceptedChan <- ep
}
+
+ // Notify any blocked goroutines that they can attempt to
+ // deliver endpoints again.
+ e.acceptCond.Broadcast()
+
return nil
}
@@ -2245,9 +2261,12 @@ func (e *endpoint) listen(backlog int) *tcpip.Error {
// The channel may be non-nil when we're restoring the endpoint, and it
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
+ e.acceptMu.Lock()
if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog)
}
+ e.acceptMu.Unlock()
+
e.workerRunning = true
go e.protocolListenLoop( // S/R-SAFE: drained on save.
seqnum.Size(e.receiveBufferAvailable()))
@@ -2276,9 +2295,12 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
}
// Get the new accepted endpoint.
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
var n *endpoint
select {
case n = <-e.acceptedChan:
+ e.acceptCond.Signal()
default:
return nil, nil, tcpip.ErrWouldBlock
}