author    gVisor bot <gvisor-bot@google.com>  2021-09-29 17:54:48 +0000
committer gVisor bot <gvisor-bot@google.com>  2021-09-29 17:54:48 +0000
commit    c322765e4826bef4847bb4c6bf2330b7df4796e7 (patch)
tree      a493b79bfd236711c17c07044e795d51c6e9fdea /pkg/tcpip/transport
parent    c5d32df9efa5daf091ba384ad23f17f0824cc3c8 (diff)
parent    5aa37994c15883f4922ef3d81834d2f8ba3557a1 (diff)

Merge release-20210921.0-40-g5aa37994c (automated)
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go            | 109
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go          |   4
-rw-r--r--  pkg/tcpip/transport/tcp/tcp_state_autogen.go |  91
3 files changed, 87 insertions(+), 117 deletions(-)
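This merge removes the endpoint's atomic synRcvdCount: the number of connections in SYN-RCVD state is now len(ctx.pendingEndpoints), a set the listen context already kept, guarded by pendingMu and paired with a sync.WaitGroup so shutdown can drain in-flight handshakes. A minimal, self-contained sketch of that bookkeeping pattern (illustrative names, not the gVisor types):

    package main

    import (
        "fmt"
        "sync"
    )

    type conn struct{ id int }

    // tracker stands in for listenContext's pending bookkeeping: the map is
    // the set of in-flight handshakes (its length doubles as the SYN-RCVD
    // count) and the WaitGroup lets shutdown wait for them to drain.
    type tracker struct {
        mu      sync.Mutex
        pending map[*conn]struct{}
        wg      sync.WaitGroup
    }

    func (t *tracker) add(c *conn) {
        t.mu.Lock()
        defer t.mu.Unlock()
        t.pending[c] = struct{}{}
        t.wg.Add(1)
    }

    func (t *tracker) remove(c *conn) {
        t.mu.Lock()
        defer t.mu.Unlock()
        delete(t.pending, c)
        t.wg.Done()
    }

    func main() {
        t := &tracker{pending: make(map[*conn]struct{})}
        c := &conn{id: 1}
        t.add(c)
        fmt.Println(len(t.pending)) // 1: no separate counter needed
        t.remove(c)
        t.wg.Wait() // returns once every pending handshake is done
    }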
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index 62b8d9de9..d41e07521 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -20,7 +20,6 @@ import (
"fmt"
"hash"
"io"
- "sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/sleep"
@@ -103,14 +102,14 @@ type listenContext struct {
// pendingMu protects pendingEndpoints. This should only be accessed
// by the listening endpoint's worker goroutine.
- //
- // Lock Ordering: listenEP.workerMu -> pendingMu
pendingMu sync.Mutex
// pending is used to wait for all pendingEndpoints to finish when
// a socket is closed.
pending sync.WaitGroup
// pendingEndpoints is a set of all endpoints for which a handshake is
// in progress.
+ //
+ // +checklocks:pendingMu
pendingEndpoints map[*endpoint]struct{}
}
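The new +checklocks:pendingMu annotation is machine-checked: gVisor's checklocks analyzer rejects accesses to pendingEndpoints made without pendingMu held, which is why the prose lock-ordering comment above could be dropped. A toy illustration of the annotation style (illustrative struct, not gVisor code):

    package main

    import "sync"

    // guarded shows the annotation style only; gVisor's checklocks analyzer
    // reads the "+checklocks:mu" comment and flags any access to items made
    // without mu held.
    type guarded struct {
        mu sync.Mutex

        // +checklocks:mu
        items map[string]int
    }

    func (g *guarded) set(k string, v int) {
        g.mu.Lock()
        defer g.mu.Unlock() // the analyzer can see mu is held here
        g.items[k] = v
    }

    func main() {
        g := &guarded{items: make(map[string]int)}
        g.set("pending", 1)
    }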
@@ -265,7 +264,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
return nil, &tcpip.ErrConnectionAborted{}
}
- l.addPendingEndpoint(ep)
// Propagate any inheritable options from the listening endpoint
// to the newly created endpoint.
@@ -275,8 +273,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()
- l.removePendingEndpoint(ep)
-
return nil, &tcpip.ErrConnectionAborted{}
}
@@ -295,10 +291,6 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu
ep.mu.Unlock()
ep.Close()
- if l.listenEP != nil {
- l.removePendingEndpoint(ep)
- }
-
ep.drainClosingSegmentQueue()
return nil, err
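startHandshake no longer registers the endpoint itself; the caller now adds it to pendingEndpoints only after startHandshake succeeds. That is why the error paths above lose their removePendingEndpoint calls: a fallible constructor with no side effects on the tracker needs no unwinding. A small sketch of that ownership shift (hypothetical names):

    package main

    import (
        "errors"
        "fmt"
    )

    // start is a stand-in for startHandshake: fallible, and with no side
    // effects on the pending set, so its error paths need no cleanup.
    func start(fail bool) (string, error) {
        if fail {
            return "", errors.New("handshake setup failed")
        }
        return "ep1", nil
    }

    func main() {
        pending := map[string]struct{}{}
        ep, err := start(false)
        if err != nil {
            fmt.Println(err) // nothing to unwind: ep was never registered
            return
        }
        pending[ep] = struct{}{} // register on the success path only
        fmt.Println(len(pending))
    }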
@@ -336,38 +328,12 @@ func (l *listenContext) performHandshake(s *segment, opts header.TCPSynOptions,
return ep, nil
}
-func (l *listenContext) addPendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- l.pendingEndpoints[n] = struct{}{}
- l.pending.Add(1)
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) removePendingEndpoint(n *endpoint) {
- l.pendingMu.Lock()
- delete(l.pendingEndpoints, n)
- l.pending.Done()
- l.pendingMu.Unlock()
-}
-
-func (l *listenContext) closeAllPendingEndpoints() {
- l.pendingMu.Lock()
- for n := range l.pendingEndpoints {
- n.notifyProtocolGoroutine(notifyClose)
- }
- l.pendingMu.Unlock()
- l.pending.Wait()
-}
-
// +checklocks:h.ep.mu
func (l *listenContext) cleanupFailedHandshake(h *handshake) {
e := h.ep
e.mu.Unlock()
e.Close()
e.notifyAborted()
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.drainClosingSegmentQueue()
e.h = nil
}
@@ -378,9 +344,6 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) {
// +checklocks:h.ep.mu
func (l *listenContext) cleanupCompletedHandshake(h *handshake) {
e := h.ep
- if l.listenEP != nil {
- l.removePendingEndpoint(e)
- }
e.isConnectNotified = true
// Update the receive window scaling. We can't do it before the
@@ -444,21 +407,6 @@ func (e *endpoint) notifyAborted() {
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents)
}
-func (e *endpoint) synRcvdBacklogFull() bool {
- e.acceptMu.Lock()
- acceptedCap := e.accepted.cap
- e.acceptMu.Unlock()
- // The capacity of the accepted queue would always be one greater than the
- // listen backlog. But, the SYNRCVD connections count is always checked
- // against the listen backlog value for Linux parity reason.
- // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
- //
- // We maintain an equality check here as the synRcvdCount is incremented
- // and compared only from a single listener context and the capacity of
- // the accepted queue can only increase by a new listen call.
- return int(atomic.LoadInt32(&e.synRcvdCount)) == acceptedCap-1
-}
-
func (e *endpoint) acceptQueueIsFull() bool {
e.acceptMu.Lock()
full := e.accepted.acceptQueueIsFullLocked()
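The removed synRcvdBacklogFull encoded an off-by-one worth spelling out: the accepted queue's capacity is always one more than the listen backlog, so comparing the SYN-RCVD count against cap-1 compares it against the backlog itself, as Linux does. Worked with illustrative numbers:

    package main

    import "fmt"

    func main() {
        backlog := 5
        acceptedCap := backlog + 1 // the queue always holds backlog+1
        synRcvd := 5               // connections currently in SYN-RCVD

        // Old code: atomic counter == cap-1. New code: len(pendingEndpoints)
        // == cap-1. Both mean "SYN-RCVD count has reached the backlog".
        full := synRcvd == acceptedCap-1
        fmt.Println(full) // true: new SYNs fall back to SYN cookies
    }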
@@ -500,34 +448,53 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
return nil
}
- alwaysUseSynCookies := func() bool {
+ opts := parseSynSegmentOptions(s)
+
+ useSynCookies, err := func() (bool, tcpip.Error) {
var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies
if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil {
panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err))
}
- return bool(alwaysUseSynCookies)
- }()
-
- opts := parseSynSegmentOptions(s)
- if !alwaysUseSynCookies && !e.synRcvdBacklogFull() {
- atomic.AddInt32(&e.synRcvdCount, 1)
+ if alwaysUseSynCookies {
+ return true, nil
+ }
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+
+ ctx.pendingMu.Lock()
+ defer ctx.pendingMu.Unlock()
+ // The capacity of the accepted queue would always be one greater than the
+ // listen backlog. But, the SYNRCVD connections count is always checked
+ // against the listen backlog value for Linux parity reason.
+ // https://github.com/torvalds/linux/blob/7acac4b3196/include/net/inet_connection_sock.h#L280
+ if len(ctx.pendingEndpoints) == e.accepted.cap-1 {
+ return true, nil
+ }
h, err := ctx.startHandshake(s, opts, &waiter.Queue{}, e.owner)
if err != nil {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
- atomic.AddInt32(&e.synRcvdCount, -1)
- return err
+ return false, err
}
+ ctx.pendingEndpoints[h.ep] = struct{}{}
+ ctx.pending.Add(1)
+
go func() {
+ defer func() {
+ ctx.pendingMu.Lock()
+ defer ctx.pendingMu.Unlock()
+ delete(ctx.pendingEndpoints, h.ep)
+ ctx.pending.Done()
+ }()
+
// Note that startHandshake returns a locked endpoint. The force call
// here just makes it so.
if err := h.complete(); err != nil { // +checklocksforce
e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
e.stats.FailedConnectionAttempts.Increment()
ctx.cleanupFailedHandshake(h)
- atomic.AddInt32(&e.synRcvdCount, -1)
return
}
ctx.cleanupCompletedHandshake(h)
@@ -558,7 +525,6 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
e.accepted.endpoints.PushBack(h.ep)
- atomic.AddInt32(&e.synRcvdCount, -1)
return true
}
}()
@@ -570,6 +536,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err
}
}()
+ return false, nil
+ }()
+ if err != nil {
+ return err
+ }
+ if !useSynCookies {
return nil
}
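The refactor wraps the decision in an immediately invoked closure returning (useSynCookies, err): the deferred acceptMu and pendingMu unlocks fire when the closure returns, so both locks are released before the caller acts on the result. The shape of that pattern, reduced to a runnable toy:

    package main

    import (
        "fmt"
        "sync"
    )

    func main() {
        var mu sync.Mutex
        useCookies, err := func() (bool, error) {
            mu.Lock()
            defer mu.Unlock() // released as soon as the closure returns
            pending, backlog := 5, 5
            if pending == backlog {
                return true, nil // over budget: signal the fallback
            }
            return false, nil
        }()
        if err != nil {
            fmt.Println("error:", err)
            return
        }
        fmt.Println("fall back to SYN cookies:", useCookies) // true
    }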
@@ -780,7 +752,12 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) {
e.setEndpointState(StateClose)
// Close any endpoints in SYN-RCVD state.
- ctx.closeAllPendingEndpoints()
+ ctx.pendingMu.Lock()
+ for n := range ctx.pendingEndpoints {
+ n.notifyProtocolGoroutine(notifyClose)
+ }
+ ctx.pendingMu.Unlock()
+ ctx.pending.Wait()
// Do cleanup if needed.
e.completeWorkerLocked()
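The inlined shutdown sequence has a subtle ordering requirement: ctx.pending.Wait() must run after pendingMu is released, because each handshake goroutine reacquires pendingMu in its deferred cleanup before calling pending.Done(); waiting under the lock would deadlock. A runnable reduction (stand-in names):

    package main

    import "sync"

    func main() {
        var (
            mu      sync.Mutex
            wg      sync.WaitGroup
            pending = map[int]chan struct{}{}
        )
        for i := 0; i < 3; i++ {
            stop := make(chan struct{})
            mu.Lock()
            pending[i] = stop
            wg.Add(1)
            mu.Unlock()
            go func(i int, stop chan struct{}) {
                <-stop // stand-in for handling notifyClose
                mu.Lock()
                delete(pending, i) // needs mu, like the deferred cleanup
                mu.Unlock()
                wg.Done()
            }(i, stop)
        }
        mu.Lock()
        for _, stop := range pending {
            close(stop) // stand-in for notifyProtocolGoroutine(notifyClose)
        }
        mu.Unlock() // must release before Wait, or the goroutines block
        wg.Wait()
    }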
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 7c8a11cfb..9c5b1b016 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -508,10 +508,6 @@ type endpoint struct {
// and dropped when it is.
segmentQueue segmentQueue `state:"wait"`
- // synRcvdCount is the number of connections for this endpoint that are
- // in SYN-RCVD state; this is only accessed atomically.
- synRcvdCount int32
-
// userMSS if non-zero is the MSS value explicitly set by the user
// for this endpoint using the TCP_MAXSEG setsockopt.
userMSS uint16
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index ba94228d9..13f791243 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -309,7 +309,6 @@ func (e *endpoint) StateFields() []string {
"delay",
"scoreboard",
"segmentQueue",
- "synRcvdCount",
"userMSS",
"maxSynRetries",
"windowClamp",
@@ -368,29 +367,28 @@ func (e *endpoint) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(25, &e.delay)
stateSinkObject.Save(26, &e.scoreboard)
stateSinkObject.Save(27, &e.segmentQueue)
- stateSinkObject.Save(28, &e.synRcvdCount)
- stateSinkObject.Save(29, &e.userMSS)
- stateSinkObject.Save(30, &e.maxSynRetries)
- stateSinkObject.Save(31, &e.windowClamp)
- stateSinkObject.Save(32, &e.sndQueueInfo)
- stateSinkObject.Save(33, &e.cc)
- stateSinkObject.Save(34, &e.keepalive)
- stateSinkObject.Save(35, &e.userTimeout)
- stateSinkObject.Save(36, &e.deferAccept)
- stateSinkObject.Save(37, &e.accepted)
- stateSinkObject.Save(38, &e.rcv)
- stateSinkObject.Save(39, &e.snd)
- stateSinkObject.Save(40, &e.connectingAddress)
- stateSinkObject.Save(41, &e.amss)
- stateSinkObject.Save(42, &e.sendTOS)
- stateSinkObject.Save(43, &e.gso)
- stateSinkObject.Save(44, &e.stats)
- stateSinkObject.Save(45, &e.tcpLingerTimeout)
- stateSinkObject.Save(46, &e.closed)
- stateSinkObject.Save(47, &e.txHash)
- stateSinkObject.Save(48, &e.owner)
- stateSinkObject.Save(49, &e.ops)
- stateSinkObject.Save(50, &e.lastOutOfWindowAckTime)
+ stateSinkObject.Save(28, &e.userMSS)
+ stateSinkObject.Save(29, &e.maxSynRetries)
+ stateSinkObject.Save(30, &e.windowClamp)
+ stateSinkObject.Save(31, &e.sndQueueInfo)
+ stateSinkObject.Save(32, &e.cc)
+ stateSinkObject.Save(33, &e.keepalive)
+ stateSinkObject.Save(34, &e.userTimeout)
+ stateSinkObject.Save(35, &e.deferAccept)
+ stateSinkObject.Save(36, &e.accepted)
+ stateSinkObject.Save(37, &e.rcv)
+ stateSinkObject.Save(38, &e.snd)
+ stateSinkObject.Save(39, &e.connectingAddress)
+ stateSinkObject.Save(40, &e.amss)
+ stateSinkObject.Save(41, &e.sendTOS)
+ stateSinkObject.Save(42, &e.gso)
+ stateSinkObject.Save(43, &e.stats)
+ stateSinkObject.Save(44, &e.tcpLingerTimeout)
+ stateSinkObject.Save(45, &e.closed)
+ stateSinkObject.Save(46, &e.txHash)
+ stateSinkObject.Save(47, &e.owner)
+ stateSinkObject.Save(48, &e.ops)
+ stateSinkObject.Save(49, &e.lastOutOfWindowAckTime)
}
// +checklocksignore
@@ -422,29 +420,28 @@ func (e *endpoint) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(25, &e.delay)
stateSourceObject.Load(26, &e.scoreboard)
stateSourceObject.LoadWait(27, &e.segmentQueue)
- stateSourceObject.Load(28, &e.synRcvdCount)
- stateSourceObject.Load(29, &e.userMSS)
- stateSourceObject.Load(30, &e.maxSynRetries)
- stateSourceObject.Load(31, &e.windowClamp)
- stateSourceObject.Load(32, &e.sndQueueInfo)
- stateSourceObject.Load(33, &e.cc)
- stateSourceObject.Load(34, &e.keepalive)
- stateSourceObject.Load(35, &e.userTimeout)
- stateSourceObject.Load(36, &e.deferAccept)
- stateSourceObject.Load(37, &e.accepted)
- stateSourceObject.LoadWait(38, &e.rcv)
- stateSourceObject.LoadWait(39, &e.snd)
- stateSourceObject.Load(40, &e.connectingAddress)
- stateSourceObject.Load(41, &e.amss)
- stateSourceObject.Load(42, &e.sendTOS)
- stateSourceObject.Load(43, &e.gso)
- stateSourceObject.Load(44, &e.stats)
- stateSourceObject.Load(45, &e.tcpLingerTimeout)
- stateSourceObject.Load(46, &e.closed)
- stateSourceObject.Load(47, &e.txHash)
- stateSourceObject.Load(48, &e.owner)
- stateSourceObject.Load(49, &e.ops)
- stateSourceObject.Load(50, &e.lastOutOfWindowAckTime)
+ stateSourceObject.Load(28, &e.userMSS)
+ stateSourceObject.Load(29, &e.maxSynRetries)
+ stateSourceObject.Load(30, &e.windowClamp)
+ stateSourceObject.Load(31, &e.sndQueueInfo)
+ stateSourceObject.Load(32, &e.cc)
+ stateSourceObject.Load(33, &e.keepalive)
+ stateSourceObject.Load(34, &e.userTimeout)
+ stateSourceObject.Load(35, &e.deferAccept)
+ stateSourceObject.Load(36, &e.accepted)
+ stateSourceObject.LoadWait(37, &e.rcv)
+ stateSourceObject.LoadWait(38, &e.snd)
+ stateSourceObject.Load(39, &e.connectingAddress)
+ stateSourceObject.Load(40, &e.amss)
+ stateSourceObject.Load(41, &e.sendTOS)
+ stateSourceObject.Load(42, &e.gso)
+ stateSourceObject.Load(43, &e.stats)
+ stateSourceObject.Load(44, &e.tcpLingerTimeout)
+ stateSourceObject.Load(45, &e.closed)
+ stateSourceObject.Load(46, &e.txHash)
+ stateSourceObject.Load(47, &e.owner)
+ stateSourceObject.Load(48, &e.ops)
+ stateSourceObject.Load(49, &e.lastOutOfWindowAckTime)
stateSourceObject.LoadValue(10, new(EndpointState), func(y interface{}) { e.loadState(y.(EndpointState)) })
stateSourceObject.AfterLoad(e.afterLoad)
}
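The autogenerated churn above is purely mechanical: the state indices are positional, so deleting the saved synRcvdCount field (index 28) shifts every later field down by one in both StateSave and StateLoad. A toy model of why save and load must agree on positions (not the gVisor state API):

    package main

    import "fmt"

    // sink is a toy stand-in for the state Sink/Source: values are keyed by
    // position, so save and load must enumerate fields in the same order.
    type sink map[int]interface{}

    func (s sink) Save(idx int, v interface{}) { s[idx] = v }
    func (s sink) Load(idx int) interface{}    { return s[idx] }

    func main() {
        s := sink{}
        // After synRcvdCount's removal, userMSS moves from index 29 to 28.
        fields := []string{"segmentQueue", "userMSS", "maxSynRetries"}
        for i, f := range fields {
            s.Save(27+i, f)
        }
        fmt.Println(s.Load(28)) // userMSS
    }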