diff options
author | Bhasker Hariharan <bhaskerh@google.com> | 2020-04-16 16:48:14 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-04-16 16:49:18 -0700 |
commit | 0eda0104a5a7c95a36dd288199ec1e90be9d8be9 (patch) | |
tree | 3837585eff602122f4b63275ec5e2c36820edc99 /pkg/tcpip/transport/tcp/protocol.go | |
parent | 75e864fc7529bf71484ecabbb2ce2264e96399cf (diff) |
Fix data race in tcp_test.
This change makes SynRcvdCountThreshold and the global synRcvdCount into a stack
configurable value. This is required because in cases like mod_proxy which
create multiple Stack instances the count will be a global value that impacts
all Stack instances.
Further the tests relied on modifying the global threshold to simulate tests
where we want to verify SYN cookie based behaviour. This lead to data races due
to the global being modified/read without locks or atomics.
PiperOrigin-RevId: 306947723
Diffstat (limited to 'pkg/tcpip/transport/tcp/protocol.go')
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 91f25c132..effbf203f 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -94,6 +94,63 @@ const ( ccCubic = "cubic" ) +// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The +// value is protected by a mutex so that we can increment only when it's +// guaranteed not to go above a threshold. +type synRcvdCounter struct { + sync.Mutex + value uint64 + pending sync.WaitGroup + threshold uint64 +} + +// inc tries to increment the global number of endpoints in SYN-RCVD state. It +// succeeds if the increment doesn't make the count go beyond the threshold, and +// fails otherwise. +func (s *synRcvdCounter) inc() bool { + s.Lock() + defer s.Unlock() + if s.value >= s.threshold { + return false + } + + s.pending.Add(1) + s.value++ + + return true +} + +// dec atomically decrements the global number of endpoints in SYN-RCVD +// state. It must only be called if a previous call to inc succeeded. +func (s *synRcvdCounter) dec() { + s.Lock() + defer s.Unlock() + s.value-- + s.pending.Done() +} + +// synCookiesInUse returns true if the synRcvdCount is greater than +// SynRcvdCountThreshold. +func (s *synRcvdCounter) synCookiesInUse() bool { + s.Lock() + defer s.Unlock() + return s.value >= s.threshold +} + +// SetThreshold sets synRcvdCounter.Threshold to ths new threshold. +func (s *synRcvdCounter) SetThreshold(threshold uint64) { + s.Lock() + defer s.Unlock() + s.threshold = threshold +} + +// Threshold returns the current value of synRcvdCounter.Threhsold. +func (s *synRcvdCounter) Threshold() uint64 { + s.Lock() + defer s.Unlock() + return s.threshold +} + type protocol struct { mu sync.RWMutex sackEnabled bool @@ -106,6 +163,7 @@ type protocol struct { tcpLingerTimeout time.Duration tcpTimeWaitTimeout time.Duration minRTO time.Duration + synRcvdCount synRcvdCounter dispatcher *dispatcher } @@ -282,6 +340,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil + case tcpip.TCPSynRcvdCountThresholdOption: + p.mu.Lock() + p.synRcvdCount.SetThreshold(uint64(v)) + p.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -350,6 +414,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { p.mu.RUnlock() return nil + case *tcpip.TCPSynRcvdCountThresholdOption: + p.mu.RLock() + *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold()) + p.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -365,6 +435,12 @@ func (p *protocol) Wait() { p.dispatcher.wait() } +// SynRcvdCounter returns a reference to the synRcvdCount for this protocol +// instance. +func (p *protocol) SynRcvdCounter() *synRcvdCounter { + return &p.synRcvdCount +} + // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { return &protocol{ @@ -374,6 +450,7 @@ func NewProtocol() stack.TransportProtocol { availableCongestionControl: []string{ccReno, ccCubic}, tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), minRTO: MinRTO, } |