author | gVisor bot <gvisor-bot@google.com> | 2021-09-29 17:54:48 +0000
committer | gVisor bot <gvisor-bot@google.com> | 2021-09-29 17:54:48 +0000
commit | c322765e4826bef4847bb4c6bf2330b7df4796e7 (patch)
tree | a493b79bfd236711c17c07044e795d51c6e9fdea /pkg/tcpip/transport
parent | c5d32df9efa5daf091ba384ad23f17f0824cc3c8 (diff)
parent | 5aa37994c15883f4922ef3d81834d2f8ba3557a1 (diff)
Merge release-20210921.0-40-g5aa37994c (automated)
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 109
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 4
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 91
3 files changed, 87 insertions(+), 117 deletions(-)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 62b8d9de9..d41e07521 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -20,7 +20,6 @@ import (
     "fmt"
     "hash"
     "io"
-    "sync/atomic"
     "time"
 
     "gvisor.dev/gvisor/pkg/sleep"
@@ -103,14 +102,14 @@ type listenContext struct {
     // pendingMu protects pendingEndpoints. This should only be accessed
     // by the listening endpoint's worker goroutine.
-    //
-    // Lock Ordering: listenEP.workerMu -> pendingMu
     pendingMu sync.Mutex
 
     // pending is used to wait for all pendingEndpoints to finish when
     // a socket is closed.
     pending sync.WaitGroup
 
     // pendingEndpoints is a set of all endpoints for which a handshake is
     // in progress.
+    //
+    // +checklocks:pendingMu
     pendingEndpoints map[*endpoint]struct{}
 }
@@ -265,7 +264,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
         return nil, &tcpip.ErrConnectionAborted{}
     }
 
-    l.addPendingEndpoint(ep)
 
     // Propagate any inheritable options from the listening endpoint
     // to the newly created endpoint.
@@ -275,8 +273,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
         ep.mu.Unlock()
         ep.Close()
 
-        l.removePendingEndpoint(ep)
-
         return nil, &tcpip.ErrConnectionAborted{}
     }
@@ -295,10 +291,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
         ep.mu.Unlock()
         ep.Close()
 
-        if l.listenEP != nil {
-            l.removePendingEndpoint(ep)
-        }
-
         ep.drainClosingSegmentQueue()
 
         return nil, err
@@ -336,38 +328,12 @@
     return ep, nil
 }
 
-func (l *listenContext) addPendingEndpoint(n *endpoint) {
-    l.pendingMu.Lock()
-    l.pendingEndpoints[n] = struct{}{}
-    l.pending.Add(1)
-    l.pendingMu.Unlock()
-}
-
-func (l *listenContext) removePendingEndpoint(n *endpoint) {
-    l.pendingMu.Lock()
-    delete(l.pendingEndpoints, n)
-    l.pending.Done()
-    l.pendingMu.Unlock()
-}
-
-func (l *listenContext) closeAllPendingEndpoints() {
-    l.pendingMu.Lock()
-    for n := range l.pendingEndpoints {
-        n.notifyProtocolGoroutine(notifyClose)
-    }
-    l.pendingMu.Unlock()
-    l.pending.Wait()
-}
-
 // +checklocks:h.ep.mu
 func (l *listenContext) cleanupFailedHandshake(h *handshake) {
     e := h.ep
     e.mu.Unlock()
     e.Close()
     e.notifyAborted()
-    if l.listenEP != nil {
-        l.removePendingEndpoint(e)
-    }
     e.drainClosingSegmentQueue()
     e.h = nil
 }
@@ -378,9 +344,6 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
     e := h.ep
-    if l.listenEP != nil {
-        l.removePendingEndpoint(e)
-    }
     e.isConnectNotified = true
 
     // Update the receive window scaling. We can't do it before the
@@ -444,21 +407,6 @@ func (e *endpoint) notifyAborted() {
     e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
 }
 
-func (e *endpoint) synRcvdBacklogFull() bool {
-    e.acceptMu.Lock()
-    acceptedCap := e.accepted.cap
-    e.acceptMu.Unlock()
-    // The capacity of the accepted queue would always be one greater than the
-    // listen backlog. But, the SYNRCVD connections count is always checked
-    // against the listen backlog value for Linux parity reason.
-    // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
-    //
-    // We maintain an equality check here as the synRcvdCount is incremented
-    // and compared only from a single listener context and the capacity of
-    // the accepted queue can only increase by a new listen call.
-    return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
-}
-
 func (e *endpoint) acceptQueueIsFull() bool {
     e.acceptMu.Lock()
     full := e.accepted.acceptQueueIsFullLocked()
     e.acceptMu.Unlock()
     return full
 }
@@ -500,34 +448,53 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
         return nil
     }
 
-    alwaysUseSynCookies := func() bool {
+    opts := parseSynSegmentOptions(s)
+
+    useSynCookies, err := func() (bool, tcpip.Error) {
         var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
         if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
             panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
         }
-        return bool(alwaysUseSynCookies)
-    }()
-
-    opts := parseSynSegmentOptions(s)
-    if !alwaysUseSynCookies && !e.synRcvdBacklogFull() {
-        atomic.AddInt32(&e.synRcvdCount, 1)
+        if alwaysUseSynCookies {
+            return true, nil
+        }
+        e.acceptMu.Lock()
+        defer e.acceptMu.Unlock()
+
+        ctx.pendingMu.Lock()
+        defer ctx.pendingMu.Unlock()
+        // The capacity of the accepted queue would always be one greater than the
+        // listen backlog. But, the SYNRCVD connections count is always checked
+        // against the listen backlog value for Linux parity reason.
+        // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+        if len(ctx.pendingEndpoints) == e.accepted.cap-1 {
+            return true, nil
+        }
 
         h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
         if err != nil {
             e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
             e.stats.FailedConnectionAttempts.Increment()
-            atomic.AddInt32(&e.synRcvdCount, -1)
-            return err
+            return false, err
         }
 
+        ctx.pendingEndpoints[h.ep] = struct{}{}
+        ctx.pending.Add(1)
+
         go func() {
+            defer func() {
+                ctx.pendingMu.Lock()
+                defer ctx.pendingMu.Unlock()
+                delete(ctx.pendingEndpoints, h.ep)
+                ctx.pending.Done()
+            }()
+
             // Note that startHandshake returns a locked endpoint. The force call
             // here just makes it so.
             if err := h.complete(); err != nil { // +checklocksforce
                 e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
                 e.stats.FailedConnectionAttempts.Increment()
                 ctx.cleanupFailedHandshake(h)
-                atomic.AddInt32(&e.synRcvdCount, -1)
                 return
             }
             ctx.cleanupCompletedHandshake(h)
@@ -558,7 +525,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
             }
 
             e.accepted.endpoints.PushBack(h.ep)
-            atomic.AddInt32(&e.synRcvdCount, -1)
             return true
         }
     }()
@@ -570,6 +536,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
         }
     }()
 
+        return false, nil
+    }()
+    if err != nil {
+        return err
+    }
+    if !useSynCookies {
         return nil
     }
@@ -780,7 +752,12 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
         e.setEndpointState(StateClose)
 
         // Close any endpoints in SYN-RCVD state.
-        ctx.closeAllPendingEndpoints()
+        ctx.pendingMu.Lock()
+        for n := range ctx.pendingEndpoints {
+            n.notifyProtocolGoroutine(notifyClose)
+        }
+        ctx.pendingMu.Unlock()
+        ctx.pending.Wait()
 
         // Do cleanup if needed.
         e.completeWorkerLocked()
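The accept.go change above is the heart of this merge: the atomic synRcvdCount counter is gone, and SYN-RCVD accounting now lives directly in ctx.pendingEndpoints, with the capacity check and the registration of a new handshake performed together under acceptMu and pendingMu. The following standalone Go sketch illustrates that pattern in isolation; every name in it (tracker, conn, maxPending) is an illustrative stand-in, not a gVisor API.

package main

import (
	"fmt"
	"sync"
)

type conn struct{ id int }

type tracker struct {
	mu      sync.Mutex
	pending map[*conn]struct{} // guarded by mu, like pendingEndpoints
	wg      sync.WaitGroup     // waited on at shutdown, like ctx.pending
}

func newTracker() *tracker {
	return &tracker{pending: make(map[*conn]struct{})}
}

// tryAdmit mirrors the new handleListenSegment logic: the capacity check and
// the registration happen under one lock, eliminating the window the old
// atomic counter left between the check and the increment.
func (t *tracker) tryAdmit(c *conn, maxPending int) bool {
	t.mu.Lock()
	defer t.mu.Unlock()
	if len(t.pending) >= maxPending {
		return false // caller would fall back to SYN cookies
	}
	t.pending[c] = struct{}{}
	t.wg.Add(1)
	return true
}

// done removes a connection once its handshake completes or fails, as the
// deferred closure in the handshake goroutine does above.
func (t *tracker) done(c *conn) {
	t.mu.Lock()
	delete(t.pending, c)
	t.mu.Unlock()
	t.wg.Done()
}

func main() {
	t := newTracker()
	for i := 0; i < 3; i++ {
		c := &conn{id: i}
		if t.tryAdmit(c, 2) {
			go t.done(c) // handshake completes asynchronously
		} else {
			fmt.Printf("conn %d: backlog full, would use SYN cookies\n", c.id)
		}
	}
	t.wg.Wait() // shutdown waits for all in-flight handshakes
}

As in protocolListenLoop above, the WaitGroup lets shutdown notify every pending connection and then wait for their goroutines to drain.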
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7c8a11cfb..9c5b1b016 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -508,10 +508,6 @@ type endpoint struct {
     // and dropped when it is.
     segmentQueue segmentQueue `state:"wait"`
 
-    // synRcvdCount is the number of connections for this endpoint that are
-    // in SYN-RCVD state; this is only accessed atomically.
-    synRcvdCount int32
-
     // userMSS if non-zero is the MSS value explicitly set by the user
     // for this endpoint using the TCP_MAXSEG setsockopt.
     userMSS uint16
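The endpoint.go change is only the removal of the synRcvdCount field, but because that field was part of the endpoint's saved state, dropping it renumbers every later field in the autogenerated tcp_state_autogen.go shown below. The toy example that follows illustrates the positional contract those generated Save calls obey; the sink type is a simplified stand-in for gVisor's state.Sink, not the real API.

package main

import "fmt"

// sink stands in for state.Sink: values are stored by positional index, so
// StateFields, StateSave, and StateLoad must all agree on the numbering.
type sink map[int]interface{}

func (s sink) Save(idx int, v interface{}) { s[idx] = v }

type endpoint struct {
	userMSS       uint16
	maxSynRetries uint8
}

// StateFields lists saved fields in order; deleting an entry (as the diff
// below deletes "synRcvdCount") shifts every subsequent index down by one.
func (e *endpoint) StateFields() []string {
	return []string{"userMSS", "maxSynRetries"}
}

func (e *endpoint) StateSave(s sink) {
	s.Save(0, &e.userMSS)       // indices are positions in StateFields
	s.Save(1, &e.maxSynRetries) // so Save and Load must stay in lockstep
}

func main() {
	e := &endpoint{userMSS: 1460, maxSynRetries: 6}
	s := sink{}
	e.StateSave(s)
	fmt.Println(len(s)) // 2
}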
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index ba94228d9..13f791243 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -309,7 +309,6 @@ func (e *endpoint) StateFields() []string {
         "delay",
         "scoreboard",
         "segmentQueue",
-        "synRcvdCount",
         "userMSS",
         "maxSynRetries",
         "windowClamp",
@@ -368,29 +367,28 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
     stateSinkObject.Save(25, &e.delay)
     stateSinkObject.Save(26, &e.scoreboard)
     stateSinkObject.Save(27, &e.segmentQueue)
-    stateSinkObject.Save(28, &e.synRcvdCount)
-    stateSinkObject.Save(29, &e.userMSS)
-    stateSinkObject.Save(30, &e.maxSynRetries)
-    stateSinkObject.Save(31, &e.windowClamp)
-    stateSinkObject.Save(32, &e.sndQueueInfo)
-    stateSinkObject.Save(33, &e.cc)
-    stateSinkObject.Save(34, &e.keepalive)
-    stateSinkObject.Save(35, &e.userTimeout)
-    stateSinkObject.Save(36, &e.deferAccept)
-    stateSinkObject.Save(37, &e.accepted)
-    stateSinkObject.Save(38, &e.rcv)
-    stateSinkObject.Save(39, &e.snd)
-    stateSinkObject.Save(40, &e.connectingAddress)
-    stateSinkObject.Save(41, &e.amss)
-    stateSinkObject.Save(42, &e.sendTOS)
-    stateSinkObject.Save(43, &e.gso)
-    stateSinkObject.Save(44, &e.stats)
-    stateSinkObject.Save(45, &e.tcpLingerTimeout)
-    stateSinkObject.Save(46, &e.closed)
-    stateSinkObject.Save(47, &e.txHash)
-    stateSinkObject.Save(48, &e.owner)
-    stateSinkObject.Save(49, &e.ops)
-    stateSinkObject.Save(50, &e.lastOutOfWindowAckTime)
+    stateSinkObject.Save(28, &e.userMSS)
+    stateSinkObject.Save(29, &e.maxSynRetries)
+    stateSinkObject.Save(30, &e.windowClamp)
+    stateSinkObject.Save(31, &e.sndQueueInfo)
+    stateSinkObject.Save(32, &e.cc)
+    stateSinkObject.Save(33, &e.keepalive)
+    stateSinkObject.Save(34, &e.userTimeout)
+    stateSinkObject.Save(35, &e.deferAccept)
+    stateSinkObject.Save(36, &e.accepted)
+    stateSinkObject.Save(37, &e.rcv)
+    stateSinkObject.Save(38, &e.snd)
+    stateSinkObject.Save(39, &e.connectingAddress)
+    stateSinkObject.Save(40, &e.amss)
+    stateSinkObject.Save(41, &e.sendTOS)
+    stateSinkObject.Save(42, &e.gso)
+    stateSinkObject.Save(43, &e.stats)
+    stateSinkObject.Save(44, &e.tcpLingerTimeout)
+    stateSinkObject.Save(45, &e.closed)
+    stateSinkObject.Save(46, &e.txHash)
+    stateSinkObject.Save(47, &e.owner)
+    stateSinkObject.Save(48, &e.ops)
+    stateSinkObject.Save(49, &e.lastOutOfWindowAckTime)
 }
 
 // +checklocksignore
@@ -422,29 +420,28 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
     stateSourceObject.Load(25, &e.delay)
     stateSourceObject.Load(26, &e.scoreboard)
     stateSourceObject.LoadWait(27, &e.segmentQueue)
-    stateSourceObject.Load(28, &e.synRcvdCount)
-    stateSourceObject.Load(29, &e.userMSS)
-    stateSourceObject.Load(30, &e.maxSynRetries)
-    stateSourceObject.Load(31, &e.windowClamp)
-    stateSourceObject.Load(32, &e.sndQueueInfo)
-    stateSourceObject.Load(33, &e.cc)
-    stateSourceObject.Load(34, &e.keepalive)
-    stateSourceObject.Load(35, &e.userTimeout)
-    stateSourceObject.Load(36, &e.deferAccept)
-    stateSourceObject.Load(37, &e.accepted)
-    stateSourceObject.LoadWait(38, &e.rcv)
-    stateSourceObject.LoadWait(39, &e.snd)
-    stateSourceObject.Load(40, &e.connectingAddress)
-    stateSourceObject.Load(41, &e.amss)
-    stateSourceObject.Load(42, &e.sendTOS)
-    stateSourceObject.Load(43, &e.gso)
-    stateSourceObject.Load(44, &e.stats)
-    stateSourceObject.Load(45, &e.tcpLingerTimeout)
-    stateSourceObject.Load(46, &e.closed)
-    stateSourceObject.Load(47, &e.txHash)
-    stateSourceObject.Load(48, &e.owner)
-    stateSourceObject.Load(49, &e.ops)
-    stateSourceObject.Load(50, &e.lastOutOfWindowAckTime)
+    stateSourceObject.Load(28, &e.userMSS)
+    stateSourceObject.Load(29, &e.maxSynRetries)
+    stateSourceObject.Load(30, &e.windowClamp)
+    stateSourceObject.Load(31, &e.sndQueueInfo)
+    stateSourceObject.Load(32, &e.cc)
+    stateSourceObject.Load(33, &e.keepalive)
+    stateSourceObject.Load(34, &e.userTimeout)
+    stateSourceObject.Load(35, &e.deferAccept)
+    stateSourceObject.Load(36, &e.accepted)
+    stateSourceObject.LoadWait(37, &e.rcv)
+    stateSourceObject.LoadWait(38, &e.snd)
+    stateSourceObject.Load(39, &e.connectingAddress)
+    stateSourceObject.Load(40, &e.amss)
+    stateSourceObject.Load(41, &e.sendTOS)
+    stateSourceObject.Load(42, &e.gso)
+    stateSourceObject.Load(43, &e.stats)
+    stateSourceObject.Load(44, &e.tcpLingerTimeout)
+    stateSourceObject.Load(45, &e.closed)
+    stateSourceObject.Load(46, &e.txHash)
+    stateSourceObject.Load(47, &e.owner)
+    stateSourceObject.Load(48, &e.ops)
+    stateSourceObject.Load(49, &e.lastOutOfWindowAckTime)
     stateSourceObject.LoadValue(10, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) })
     stateSourceObject.AfterLoad(e.afterLoad)
 }
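Finally, a note on the `+checklocks:pendingMu` annotation that replaced the prose lock-ordering comment in listenContext, and on the `+checklocksforce` and `+checklocksignore` markers visible in these diffs: they drive gVisor's checklocks static analyzer, which verifies that an annotated field is only accessed while its guarding mutex is held. A rough sketch of the annotation shape follows, assuming the analyzer from gVisor's tools/checklocks is run over the package; the types here are invented for illustration.

package listener

import "sync"

// pendingSet shows the annotation shape used in the accept.go diff above.
type pendingSet struct {
	mu sync.Mutex

	// +checklocks:mu
	pending map[int]struct{}
}

// +checklocks:p.mu
func (p *pendingSet) addLocked(id int) {
	// The analyzer accepts this access because the annotation on addLocked
	// promises that p.mu is already held by the caller.
	p.pending[id] = struct{}{}
}

func (p *pendingSet) add(id int) {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.addLocked(id)
}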