summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/tcp/protocol.go
diff options
context:
space:
mode:
authorBhasker Hariharan <bhaskerh@google.com>2020-04-16 16:48:14 -0700
committergVisor bot <gvisor-bot@google.com>2020-04-16 16:49:18 -0700
commit0eda0104a5a7c95a36dd288199ec1e90be9d8be9 (patch)
tree3837585eff602122f4b63275ec5e2c36820edc99 /pkg/tcpip/transport/tcp/protocol.go
parent75e864fc7529bf71484ecabbb2ce2264e96399cf (diff)
Fix data race in tcp_test.
This change makes SynRcvdCountThreshold and the global synRcvdCount into a stack configurable value. This is required because in cases like mod_proxy which create multiple Stack instances the count will be a global value that impacts all Stack instances. Further the tests relied on modifying the global threshold to simulate tests where we want to verify SYN cookie based behaviour. This lead to data races due to the global being modified/read without locks or atomics. PiperOrigin-RevId: 306947723
Diffstat (limited to 'pkg/tcpip/transport/tcp/protocol.go')
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go77
1 files changed, 77 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 91f25c132..effbf203f 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -94,6 +94,63 @@ const (
ccCubic = "cubic"
)
+// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
+// value is protected by a mutex so that we can increment only when it's
+// guaranteed not to go above a threshold.
+type synRcvdCounter struct {
+ sync.Mutex
+ value uint64
+ pending sync.WaitGroup
+ threshold uint64
+}
+
+// inc tries to increment the global number of endpoints in SYN-RCVD state. It
+// succeeds if the increment doesn't make the count go beyond the threshold, and
+// fails otherwise.
+func (s *synRcvdCounter) inc() bool {
+ s.Lock()
+ defer s.Unlock()
+ if s.value >= s.threshold {
+ return false
+ }
+
+ s.pending.Add(1)
+ s.value++
+
+ return true
+}
+
+// dec atomically decrements the global number of endpoints in SYN-RCVD
+// state. It must only be called if a previous call to inc succeeded.
+func (s *synRcvdCounter) dec() {
+ s.Lock()
+ defer s.Unlock()
+ s.value--
+ s.pending.Done()
+}
+
+// synCookiesInUse returns true if the synRcvdCount is greater than
+// SynRcvdCountThreshold.
+func (s *synRcvdCounter) synCookiesInUse() bool {
+ s.Lock()
+ defer s.Unlock()
+ return s.value >= s.threshold
+}
+
+// SetThreshold sets synRcvdCounter.Threshold to ths new threshold.
+func (s *synRcvdCounter) SetThreshold(threshold uint64) {
+ s.Lock()
+ defer s.Unlock()
+ s.threshold = threshold
+}
+
+// Threshold returns the current value of synRcvdCounter.Threhsold.
+func (s *synRcvdCounter) Threshold() uint64 {
+ s.Lock()
+ defer s.Unlock()
+ return s.threshold
+}
+
type protocol struct {
mu sync.RWMutex
sackEnabled bool
@@ -106,6 +163,7 @@ type protocol struct {
tcpLingerTimeout time.Duration
tcpTimeWaitTimeout time.Duration
minRTO time.Duration
+ synRcvdCount synRcvdCounter
dispatcher *dispatcher
}
@@ -282,6 +340,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.Lock()
+ p.synRcvdCount.SetThreshold(uint64(v))
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -350,6 +414,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.RUnlock()
return nil
+ case *tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.RLock()
+ *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold())
+ p.mu.RUnlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -365,6 +435,12 @@ func (p *protocol) Wait() {
p.dispatcher.wait()
}
+// SynRcvdCounter returns a reference to the synRcvdCount for this protocol
+// instance.
+func (p *protocol) SynRcvdCounter() *synRcvdCounter {
+ return &p.synRcvdCount
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
@@ -374,6 +450,7 @@ func NewProtocol() stack.TransportProtocol {
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
minRTO: MinRTO,
}