summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/accept.go62
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go23
-rw-r--r--pkg/tcpip/transport/tcp/tcp_state_autogen.go7
3 files changed, 42 insertions, 50 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 2a77f07cf..95fcdc1b6 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -100,18 +100,6 @@ type listenContext struct {
// netProto indicates the network protocol(IPv4/v6) for the listening
// endpoint.
netProto tcpip.NetworkProtocolNumber
-
- // pendingMu protects pendingEndpoints. This should only be accessed
- // by the listening endpoint's worker goroutine.
- pendingMu sync.Mutex
- // pending is used to wait for all pendingEndpoints to finish when
- // a socket is closed.
- pending sync.WaitGroup
- // pendingEndpoints is a set of all endpoints for which a handshake is
- // in progress.
- //
- // +checklocks:pendingMu
- pendingEndpoints map[*endpoint]struct{}
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -122,14 +110,13 @@ func timeStamp(clock tcpip.Clock) uint32 {
// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, protocol *protocol, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
- stack: stk,
- protocol: protocol,
- rcvWnd: rcvWnd,
- hasher: sha1.New(),
- v6Only: v6Only,
- netProto: netProto,
- listenEP: listenEP,
- pendingEndpoints: make(map[*endpoint]struct{}),
+ stack: stk,
+ protocol: protocol,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6Only: v6Only,
+ netProto: netProto,
+ listenEP: listenEP,
}
for i := range l.nonce {
@@ -422,6 +409,10 @@ type acceptQueue struct {
// dispatcher's list.
endpoints list.List `state:".([]*endpoint)"`
+ // pendingEndpoints is a set of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[*endpoint]struct{}
+
// capacity is the maximum number of endpoints that can be in endpoints.
capacity int
}
@@ -473,13 +464,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
e.acceptMu.Lock()
defer e.acceptMu.Unlock()
- ctx.pendingMu.Lock()
- defer ctx.pendingMu.Unlock()
// The capacity of the accepted queue would always be one greater than the
// listen backlog. But, the SYNRCVD connections count is always checked
// against the listen backlog value for Linux parity reason.
// https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
- if len(ctx.pendingEndpoints) == e.acceptQueue.capacity-1 {
+ if len(e.acceptQueue.pendingEndpoints) == e.acceptQueue.capacity-1 {
return true, nil
}
@@ -490,15 +479,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return false, err
}
- ctx.pendingEndpoints[h.ep] = struct{}{}
- ctx.pending.Add(1)
+ e.acceptQueue.pendingEndpoints[h.ep] = struct{}{}
+ e.pendingAccepted.Add(1)
go func() {
defer func() {
- ctx.pendingMu.Lock()
- defer ctx.pendingMu.Unlock()
- delete(ctx.pendingEndpoints, h.ep)
- ctx.pending.Done()
+ e.pendingAccepted.Done()
+
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ delete(e.acceptQueue.pendingEndpoints, h.ep)
}()
// Note that startHandshake returns a locked endpoint. The force call
@@ -514,11 +504,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
// Deliver the endpoint to the accept queue.
- e.mu.Lock()
- e.pendingAccepted.Add(1)
- e.mu.Unlock()
- defer e.pendingAccepted.Done()
-
+ //
// Drop the lock before notifying to avoid deadlock in user-specified
// callbacks.
delivered := func() bool {
@@ -761,14 +747,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
defer func() {
e.setEndpointState(StateClose)
- // Close any endpoints in SYN-RCVD state.
- ctx.pendingMu.Lock()
- for n := range ctx.pendingEndpoints {
- n.notifyProtocolGoroutine(notifyClose)
- }
- ctx.pendingMu.Unlock()
- ctx.pending.Wait()
-
// Do cleanup if needed.
e.completeWorkerLocked()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 6fca6346b..b60f9becf 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1081,16 +1081,20 @@ func (e *endpoint) closeNoShutdownLocked() {
// handshake but not yet been delivered to the application.
func (e *endpoint) closePendingAcceptableConnectionsLocked() {
e.acceptMu.Lock()
- acceptedCopy := e.acceptQueue
- e.acceptQueue = acceptQueue{}
+ // Close any endpoints in SYN-RCVD state.
+ for n := range e.acceptQueue.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
+ }
+ e.acceptQueue.pendingEndpoints = nil
+ // Reset all connections that are waiting to be accepted.
+ for n := e.acceptQueue.endpoints.Front(); n != nil; n = n.Next() {
+ n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset)
+ }
+ e.acceptQueue.endpoints.Init()
e.acceptMu.Unlock()
e.acceptCond.Broadcast()
- // Reset all connections that are waiting to be accepted.
- for n := acceptedCopy.endpoints.Front(); n != nil; n = n.Next() {
- n.Value.(*endpoint).notifyProtocolGoroutine(notifyReset)
- }
// Wait for reset of all endpoints that are still waiting to be delivered to
// the now closed accepted.
e.pendingAccepted.Wait()
@@ -2490,6 +2494,10 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
}
e.acceptQueue.capacity = backlog
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
+ }
+
e.shutdownFlags = 0
e.rcvQueueInfo.rcvQueueMu.Lock()
e.rcvQueueInfo.RcvClosed = false
@@ -2529,6 +2537,9 @@ func (e *endpoint) listen(backlog int) tcpip.Error {
// may be pre-populated with some previously accepted (but not Accepted)
// endpoints.
e.acceptMu.Lock()
+ if e.acceptQueue.pendingEndpoints == nil {
+ e.acceptQueue.pendingEndpoints = make(map[*endpoint]struct{})
+ }
if e.acceptQueue.capacity == 0 {
e.acceptQueue.capacity = backlog
}
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index 7c5ef8952..8c28da609 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -14,6 +14,7 @@ func (a *acceptQueue) StateTypeName() string {
func (a *acceptQueue) StateFields() []string {
return []string{
"endpoints",
+ "pendingEndpoints",
"capacity",
}
}
@@ -26,14 +27,16 @@ func (a *acceptQueue) StateSave(stateSinkObject state.Sink) {
var endpointsValue []*endpoint
endpointsValue = a.saveEndpoints()
stateSinkObject.SaveValue(0, endpointsValue)
- stateSinkObject.Save(1, &a.capacity)
+ stateSinkObject.Save(1, &a.pendingEndpoints)
+ stateSinkObject.Save(2, &a.capacity)
}
func (a *acceptQueue) afterLoad() {}
// +checklocksignore
func (a *acceptQueue) StateLoad(stateSourceObject state.Source) {
- stateSourceObject.Load(1, &a.capacity)
+ stateSourceObject.Load(1, &a.pendingEndpoints)
+ stateSourceObject.Load(2, &a.capacity)
stateSourceObject.LoadValue(0, new([]*endpoint), func(y interface{}) { a.loadEndpoints(y.([]*endpoint)) })
}