summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/tcp/accept.go109
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go4
2 files changed, 43 insertions, 70 deletions
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 62b8d9de9..d41e07521 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -20,7 +20,6 @@ import (
"fmt"
"hash"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -103,14 +102,14 @@ type listenContext struct {
// pendingMu protects pendingEndpoints. This should only be accessed
// by the listening endpoint's worker goroutine.
- //
- // Lock Ordering: listenEP.workerMu -> pendingMu
pendingMu sync.Mutex
// pending is used to wait for all pendingEndpoints to finish when
// a socket is closed.
pending sync.WaitGroup
// pendingEndpoints is a set of all endpoints for which a handshake is
// in progress.
+ //
+ // +checklocks:pendingMu
pendingEndpoints map[*endpoint]struct{}
}
@@ -265,7 +264,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
return nil, &tcpip.ErrConnectionAborted{}
}
- l.addPendingEndpoint(ep)
// Propagate any inheritable options from the listening endpoint
// to the newly created endpoint.
@@ -275,8 +273,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()
- l.removePendingEndpoint(ep)
-
return nil, &tcpip.ErrConnectionAborted{}
}
@@ -295,10 +291,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
ep.drainClosingSegmentQueue()
return nil, err
@@ -336,38 +328,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions,
return ep, nil
}
-func (l *listenContext) addPendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- l.pendingEndpoints[n] = struct{}{}
- l.pending.Add(1)
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) removePendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- delete(l.pendingEndpoints, n)
- l.pending.Done()
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) closeAllPendingEndpoints() {
- l.pendingMu.Lock()
- for n := range l.pendingEndpoints {
- n.notifyProtocolGoroutine(notifyClose)
- }
- l.pendingMu.Unlock()
- l.pending.Wait()
-}
-
// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
e.Close()
e.notifyAborted()
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.drainClosingSegmentQueue()
e.h = nil
}
@@ -378,9 +344,6 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) {
// +checklocks:h.ep.mu
func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e := h.ep
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.isConnectNotified = true
// Update the receive window scaling. We can't do it before the
@@ -444,21 +407,6 @@ func (e *endpoint) notifyAborted() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
-func (e *endpoint) synRcvdBacklogFull() bool {
- e.acceptMu.Lock()
- acceptedCap := e.accepted.cap
- e.acceptMu.Unlock()
- // The capacity of the accepted queue would always be one greater than the
- // listen backlog. But, the SYNRCVD connections count is always checked
- // against the listen backlog value for Linux parity reason.
- // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
- //
- // We maintain an equality check here as the synRcvdCount is incremented
- // and compared only from a single listener context and the capacity of
- // the accepted queue can only increase by a new listen call.
- return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
-}
-
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
full := e.accepted.acceptQueueIsFullLocked()
@@ -500,34 +448,53 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
}
- alwaysUseSynCookies := func() bool {
+ opts := parseSynSegmentOptions(s)
+
+ useSynCookies, err := func() (bool, tcpip.Error) {
var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
}
- return bool(alwaysUseSynCookies)
- }()
-
- opts := parseSynSegmentOptions(s)
- if !alwaysUseSynCookies && !e.synRcvdBacklogFull() {
- atomic.AddInt32(&e.synRcvdCount, 1)
+ if alwaysUseSynCookies {
+ return true, nil
+ }
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+
+ ctx.pendingMu.Lock()
+ defer ctx.pendingMu.Unlock()
+ // The capacity of the accepted queue would always be one greater than the
+ // listen backlog. But, the SYNRCVD connections count is always checked
+ // against the listen backlog value for Linux parity reason.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+ if len(ctx.pendingEndpoints) == e.accepted.cap-1 {
+ return true, nil
+ }
h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- atomic.AddInt32(&e.synRcvdCount, -1)
- return err
+ return false, err
}
+ ctx.pendingEndpoints[h.ep] = struct{}{}
+ ctx.pending.Add(1)
+
go func() {
+ defer func() {
+ ctx.pendingMu.Lock()
+ defer ctx.pendingMu.Unlock()
+ delete(ctx.pendingEndpoints, h.ep)
+ ctx.pending.Done()
+ }()
+
// Note that startHandshake returns a locked endpoint. The force call
// here just makes it so.
if err := h.complete(); err != nil { // +checklocksforce
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
- atomic.AddInt32(&e.synRcvdCount, -1)
return
}
ctx.cleanupCompletedHandshake(h)
@@ -558,7 +525,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
e.accepted.endpoints.PushBack(h.ep)
- atomic.AddInt32(&e.synRcvdCount, -1)
return true
}
}()
@@ -570,6 +536,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
}()
+ return false, nil
+ }()
+ if err != nil {
+ return err
+ }
+ if !useSynCookies {
return nil
}
@@ -780,7 +752,12 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.setEndpointState(StateClose)
// Close any endpoints in SYN-RCVD state.
- ctx.closeAllPendingEndpoints()
+ ctx.pendingMu.Lock()
+ for n := range ctx.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
+ }
+ ctx.pendingMu.Unlock()
+ ctx.pending.Wait()
// Do cleanup if needed.
e.completeWorkerLocked()
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7c8a11cfb..9c5b1b016 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -508,10 +508,6 @@ type endpoint struct {
// and dropped when it is.
segmentQueue segmentQueue `state:"wait"`
- // synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state; this is only accessed atomically.
- synRcvdCount int32
-
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
userMSS uint16