diff options
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 62 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 23 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 7 |
3 files changed, 42 insertions, 50 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 2a77f07cf..95fcdc1b6 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -100,18 +100,6 @@ type listenContext struct { // netProto indicates the network protocol(IPv4/v6) for the listening // endpoint. netProto tcpip.NetworkProtocolNumber - - // pendingMu protects pendingEndpoints. This should only be accessed - // by the listening endpoint's worker goroutine. - pendingMu sync.Mutex - // pending is used to wait for all pendingEndpoints to finish when - // a socket is closed. - pending sync.WaitGroup - // pendingEndpoints is a set of all endpoints for which a handshake is - // in progress. - // - // +checklocks:pendingMu - pendingEndpoints map[*endpoint]struct{} } // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. @@ -122,14 +110,13 @@ func timeStamp(clock tcpip.Clock) uint32 { // newListenContext creates a new listen context. func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { l := &listenContext{ - stack: stk, - protocol: protocol, - rcvWnd: rcvWnd, - hasher: sha1.New(), - v6Only: v6Only, - netProto: netProto, - listenEP: listenEP, - pendingEndpoints: make(map[*endpoint]struct{}), + stack: stk, + protocol: protocol, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6Only: v6Only, + netProto: netProto, + listenEP: listenEP, } for i := range l.nonce { @@ -422,6 +409,10 @@ type acceptQueue struct { // dispatcher's list. endpoints list.List `state:".([]*endpoint)"` + // pendingEndpoints is a set of all endpoints for which a handshake is + // in progress. + pendingEndpoints map[*endpoint]struct{} + // capacity is the maximum number of endpoints that can be in endpoints. capacity int } @@ -473,13 +464,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.acceptMu.Lock() defer e.acceptMu.Unlock() - ctx.pendingMu.Lock() - defer ctx.pendingMu.Unlock() // The capacity of the accepted queue would always be one greater than the // listen backlog. But, the SYNRCVD connections count is always checked // against the listen backlog value for Linux parity reason. // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280 - if len(ctx.pendingEndpoints) == e.acceptQueue.capacity-1 { + if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 { return true, nil } @@ -490,15 +479,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return false, err } - ctx.pendingEndpoints[h.ep] = struct{}{} - ctx.pending.Add(1) + e.acceptQueue.pendingEndpoints[h.ep] = struct{}{} + e.pendingAccepted.Add(1) go func() { defer func() { - ctx.pendingMu.Lock() - defer ctx.pendingMu.Unlock() - delete(ctx.pendingEndpoints, h.ep) - ctx.pending.Done() + e.pendingAccepted.Done() + + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + delete(e.acceptQueue.pendingEndpoints, h.ep) }() // Note that startHandshake returns a locked endpoint. The force call @@ -514,11 +504,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() // Deliver the endpoint to the accept queue. - e.mu.Lock() - e.pendingAccepted.Add(1) - e.mu.Unlock() - defer e.pendingAccepted.Done() - + // // Drop the lock before notifying to avoid deadlock in user-specified // callbacks. delivered := func() bool { @@ -761,14 +747,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { defer func() { e.setEndpointState(StateClose) - // Close any endpoints in SYN-RCVD state. - ctx.pendingMu.Lock() - for n := range ctx.pendingEndpoints { - n.notifyProtocolGoroutine(notifyClose) - } - ctx.pendingMu.Unlock() - ctx.pending.Wait() - // Do cleanup if needed. e.completeWorkerLocked() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 6fca6346b..b60f9becf 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1081,16 +1081,20 @@ func (e *endpoint) closeNoShutdownLocked() { // handshake but not yet been delivered to the application. func (e *endpoint) closePendingAcceptableConnectionsLocked() { e.acceptMu.Lock() - acceptedCopy := e.acceptQueue - e.acceptQueue = acceptQueue{} + // Close any endpoints in SYN-RCVD state. + for n := range e.acceptQueue.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) + } + e.acceptQueue.pendingEndpoints = nil + // Reset all connections that are waiting to be accepted. + for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() { + n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset) + } + e.acceptQueue.endpoints.Init() e.acceptMu.Unlock() e.acceptCond.Broadcast() - // Reset all connections that are waiting to be accepted. - for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() { - n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset) - } // Wait for reset of all endpoints that are still waiting to be delivered to // the now closed accepted. e.pendingAccepted.Wait() @@ -2490,6 +2494,10 @@ func (e *endpoint) listen(backlog int) tcpip.Error { } e.acceptQueue.capacity = backlog + if e.acceptQueue.pendingEndpoints == nil { + e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{}) + } + e.shutdownFlags = 0 e.rcvQueueInfo.rcvQueueMu.Lock() e.rcvQueueInfo.RcvClosed = false @@ -2529,6 +2537,9 @@ func (e *endpoint) listen(backlog int) tcpip.Error { // may be pre-populated with some previously accepted (but not Accepted) // endpoints. e.acceptMu.Lock() + if e.acceptQueue.pendingEndpoints == nil { + e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{}) + } if e.acceptQueue.capacity == 0 { e.acceptQueue.capacity = backlog } diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go index 7c5ef8952..8c28da609 100644 --- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -14,6 +14,7 @@ func (a *acceptQueue) StateTypeName() string { func (a *acceptQueue) StateFields() []string { return []string{ "endpoints", + "pendingEndpoints", "capacity", } } @@ -26,14 +27,16 @@ func (a *acceptQueue) StateSave(stateSinkObject state.Sink) { var endpointsValue []*endpoint endpointsValue = a.saveEndpoints() stateSinkObject.SaveValue(0, endpointsValue) - stateSinkObject.Save(1, &a.capacity) + stateSinkObject.Save(1, &a.pendingEndpoints) + stateSinkObject.Save(2, &a.capacity) } func (a *acceptQueue) afterLoad() {} // +checklocksignore func (a *acceptQueue) StateLoad(stateSourceObject state.Source) { - stateSourceObject.Load(1, &a.capacity) + stateSourceObject.Load(1, &a.pendingEndpoints) + stateSourceObject.Load(2, &a.capacity) stateSourceObject.LoadValue(0, new([]*endpoint), func(y interface{}) { a.loadEndpoints(y.([]*endpoint)) }) } |