From 5eede4e7563e245a685d6529dffddbf9c3a53f50 Mon Sep 17 00:00:00 2001 From: Mithun Iyer Date: Tue, 16 Mar 2021 15:06:26 -0700 Subject: Fix a race with synRcvdCount and accept There is a race in handling new incoming connections on a listening endpoint that causes the endpoint to reply to more incoming SYNs than what is permitted by the listen backlog. The race occurs when there is a successful passive connection handshake and the synRcvdCount counter is decremented, followed by the endpoint delivered to the accept queue. In the window of time between synRcvdCount decrementing and the endpoint being enqueued for accept, new incoming SYNs can be handled without honoring the listen backlog value, as the backlog could be perceived not full. Fixes #5637 PiperOrigin-RevId: 363279372 --- pkg/tcpip/transport/tcp/accept.go | 25 ++++++++++++------------- pkg/tcpip/transport/tcp/endpoint.go | 4 ++-- 2 files changed, 14 insertions(+), 15 deletions(-) (limited to 'pkg/tcpip/transport/tcp') diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 3b574837c..0a2f3291c 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -20,6 +20,7 @@ import ( "fmt" "hash" "io" + "sync/atomic" "time" "gvisor.dev/gvisor/pkg/rand" @@ -390,7 +391,7 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // deliverAccepted delivers the newly-accepted endpoint to the listener. If the // endpoint has transitioned out of the listen state (acceptedChan is nil), // the new endpoint is closed instead. -func (e *endpoint) deliverAccepted(n *endpoint) { +func (e *endpoint) deliverAccepted(n *endpoint, withSynCookie bool) { e.mu.Lock() e.pendingAccepted.Add(1) e.mu.Unlock() @@ -405,6 +406,9 @@ func (e *endpoint) deliverAccepted(n *endpoint) { } select { case e.acceptedChan <- n: + if !withSynCookie { + atomic.AddInt32(&e.synRcvdCount, -1) + } e.acceptMu.Unlock() e.waiterQueue.Notify(waiter.EventIn) return @@ -476,7 +480,7 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - e.synRcvdCount-- + atomic.AddInt32(&e.synRcvdCount, -1) return err } @@ -486,18 +490,13 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() ctx.cleanupFailedHandshake(h) - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() + atomic.AddInt32(&e.synRcvdCount, -1) return } ctx.cleanupCompletedHandshake(h) - e.mu.Lock() - e.synRcvdCount-- - e.mu.Unlock() h.ep.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - e.deliverAccepted(h.ep) + e.deliverAccepted(h.ep, false /*withSynCookie*/) }() // S/R-SAFE: synRcvdCount is the barrier. return nil @@ -505,17 +504,17 @@ func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header func (e *endpoint) incSynRcvdCount() bool { e.acceptMu.Lock() - canInc := e.synRcvdCount < cap(e.acceptedChan) + canInc := int(atomic.LoadInt32(&e.synRcvdCount)) < cap(e.acceptedChan) e.acceptMu.Unlock() if canInc { - e.synRcvdCount++ + atomic.AddInt32(&e.synRcvdCount, 1) } return canInc } func (e *endpoint) acceptQueueIsFull() bool { e.acceptMu.Lock() - full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan) + full := len(e.acceptedChan)+int(atomic.LoadInt32(&e.synRcvdCount)) >= cap(e.acceptedChan) e.acceptMu.Unlock() return full } @@ -737,7 +736,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err // Start the protocol goroutine. n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - go e.deliverAccepted(n) + go e.deliverAccepted(n, true /*withSynCookie*/) return nil default: diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 129f36d11..43d344350 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -532,8 +532,8 @@ type endpoint struct { segmentQueue segmentQueue `state:"wait"` // synRcvdCount is the number of connections for this endpoint that are - // in SYN-RCVD state. - synRcvdCount int + // in SYN-RCVD state; this is only accessed atomically. + synRcvdCount int32 // userMSS if non-zero is the MSS value explicitly set by the user // for this endpoint using the TCP_MAXSEG setsockopt. -- cgit v1.2.3