summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/accept.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/tcp/accept.go')
-rw-r--r--pkg/tcpip/transport/tcp/accept.go89
1 files changed, 71 insertions, 18 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 52fd1bfa3..e9c5099ea 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -96,6 +96,17 @@ type listenContext struct {
hasher hash.Hash
v6only bool
netProto tcpip.NetworkProtocolNumber
+ // pendingMu protects pendingEndpoints. This should only be accessed
+ // by the listening endpoint's worker goroutine.
+ //
+ // Lock Ordering: listenEP.workerMu -> pendingMu
+ pendingMu sync.Mutex
+ // pending is used to wait for all pendingEndpoints to finish when
+ // a socket is closed.
+ pending sync.WaitGroup
+ // pendingEndpoints is a map of all endpoints for which a handshake is
+ // in progress.
+ pendingEndpoints map[stack.TransportEndpointID]*endpoint
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
@@ -133,14 +144,15 @@ func decSynRcvdCount() {
}
// newListenContext creates a new listen context.
-func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
- stack: stack,
- rcvWnd: rcvWnd,
- hasher: sha1.New(),
- v6only: v6only,
- netProto: netProto,
- listenEP: listenEP,
+ stack: stk,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6only: v6only,
+ netProto: netProto,
+ listenEP: listenEP,
+ pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
rand.Read(l.nonce[0][:])
@@ -253,6 +265,17 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return nil, err
}
+ // listenEP is nil when listenContext is used by tcp.Forwarder.
+ if l.listenEP != nil {
+ l.listenEP.mu.Lock()
+ if l.listenEP.state != StateListen {
+ l.listenEP.mu.Unlock()
+ return nil, tcpip.ErrConnectionAborted
+ }
+ l.addPendingEndpoint(ep)
+ l.listenEP.mu.Unlock()
+ }
+
// Perform the 3-way handshake.
h := newHandshake(ep, seqnum.Size(ep.initialReceiveWindow()))
@@ -260,6 +283,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
if err := h.execute(); err != nil {
ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
ep.Close()
+ if l.listenEP != nil {
+ l.removePendingEndpoint(ep)
+ }
return nil, err
}
ep.mu.Lock()
@@ -274,15 +300,41 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head
return ep, nil
}
+func (l *listenContext) addPendingEndpoint(n *endpoint) {
+ l.pendingMu.Lock()
+ l.pendingEndpoints[n.id] = n
+ l.pending.Add(1)
+ l.pendingMu.Unlock()
+}
+
+func (l *listenContext) removePendingEndpoint(n *endpoint) {
+ l.pendingMu.Lock()
+ delete(l.pendingEndpoints, n.id)
+ l.pending.Done()
+ l.pendingMu.Unlock()
+}
+
+func (l *listenContext) closeAllPendingEndpoints() {
+ l.pendingMu.Lock()
+ for _, n := range l.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
+ }
+ l.pendingMu.Unlock()
+ l.pending.Wait()
+}
+
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state, the new endpoint is closed
// instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
- e.mu.RLock()
+ e.mu.Lock()
state := e.state
- e.mu.RUnlock()
+ e.pendingAccepted.Add(1)
+ defer e.pendingAccepted.Done()
+ acceptedChan := e.acceptedChan
+ e.mu.Unlock()
if state == StateListen {
- e.acceptedChan <- n
+ acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
} else {
n.Close()
@@ -304,7 +356,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
return
}
-
+ ctx.removePendingEndpoint(n)
e.deliverAccepted(n)
}
@@ -451,6 +503,11 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+ e.mu.Lock()
+ v6only := e.v6only
+ e.mu.Unlock()
+ ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
@@ -458,6 +515,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.mu.Lock()
e.state = StateClose
+ // close any endpoints in SYN-RCVD state.
+ ctx.closeAllPendingEndpoints()
+
// Do cleanup if needed.
e.completeWorkerLocked()
@@ -470,12 +530,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
}()
- e.mu.Lock()
- v6only := e.v6only
- e.mu.Unlock()
-
- ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
-
s := sleep.Sleeper{}
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
@@ -492,7 +546,6 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
e.handleListenSegment(ctx, s)
s.decRef()
}
- synRcvdCount.pending.Wait()
close(e.drainDone)
<-e.undrain
}