summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/tcpip.go5
-rw-r--r--pkg/tcpip/transport/tcp/accept.go106
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go77
3 files changed, 124 insertions, 64 deletions
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 109121dbc..1ca4088c9 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -685,6 +685,11 @@ type TCPDeferAcceptOption time.Duration
// default MinRTO used by the Stack.
type TCPMinRTOOption time.Duration
+// TCPSynRcvdCountThresholdOption is used by SetSockOpt/GetSockOpt to specify
+// the number of endpoints that can be in SYN-RCVD state before the stack
+// switches to using SYN cookies.
+type TCPSynRcvdCountThresholdOption uint64
+
// MulticastInterfaceOption is used by SetSockOpt/GetSockOpt to specify a
// default interface for multicast.
type MulticastInterfaceOption struct {
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
index e07b436c4..b61c2a8c3 100644
--- a/pkg/tcpip/transport/tcp/accept.go
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -17,6 +17,7 @@ package tcp
import (
"crypto/sha1"
"encoding/binary"
+ "fmt"
"hash"
"io"
"time"
@@ -49,17 +50,14 @@ const (
// timestamp and the current timestamp. If the difference is greater
// than maxTSDiff, the cookie is expired.
maxTSDiff = 2
-)
-var (
- // SynRcvdCountThreshold is the global maximum number of connections
- // that are allowed to be in SYN-RCVD state before TCP starts using SYN
- // cookies to accept connections.
- //
- // It is an exported variable only for testing, and should not otherwise
- // be used by importers of this package.
+ // SynRcvdCountThreshold is the default global maximum number of
+ // connections that are allowed to be in SYN-RCVD state before TCP
+ // starts using SYN cookies to accept connections.
SynRcvdCountThreshold uint64 = 1000
+)
+var (
// mssTable is a slice containing the possible MSS values that we
// encode in the SYN cookie with two bits.
mssTable = []uint16{536, 1300, 1440, 1460}
@@ -74,29 +72,42 @@ func encodeMSS(mss uint16) uint32 {
return 0
}
-// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
-// protected by a mutex so that we can increment only when it's guaranteed not
-// to go above a threshold.
-var synRcvdCount struct {
- sync.Mutex
- value uint64
- pending sync.WaitGroup
-}
-
// listenContext is used by a listening endpoint to store state used while
// listening for connections. This struct is allocated by the listen goroutine
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
- stack *stack.Stack
- rcvWnd seqnum.Size
- nonce [2][sha1.BlockSize]byte
+ stack *stack.Stack
+
+ // synRcvdCount is a reference to the stack level synRcvdCount.
+ synRcvdCount *synRcvdCounter
+
+ // rcvWnd is the receive window that is sent by this listening context
+ // in the initial SYN-ACK.
+ rcvWnd seqnum.Size
+
+ // nonce are random bytes that are initialized once when the context
+ // is created and used to seed the hash function when generating
+ // the SYN cookie.
+ nonce [2][sha1.BlockSize]byte
+
+ // listenEP is a reference to the listening endpoint associated with
+ // this context. Can be nil if the context is created by the forwarder.
listenEP *endpoint
+ // hasherMu protects hasher.
hasherMu sync.Mutex
- hasher hash.Hash
- v6only bool
+ // hasher is the hash function used to generate a SYN cookie.
+ hasher hash.Hash
+
+ // v6Only is true if listenEP is a dual stack socket and has the
+ // IPV6_V6ONLY option set.
+ v6only bool
+
+ // netProto indicates the network protocol(IPv4/v6) for the listening
+ // endpoint.
netProto tcpip.NetworkProtocolNumber
+
// pendingMu protects pendingEndpoints. This should only be accessed
// by the listening endpoint's worker goroutine.
//
@@ -115,44 +126,6 @@ func timeStamp() uint32 {
return uint32(time.Now().Unix()>>6) & tsMask
}
-// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD
-// state. It succeeds if the increment doesn't make the count go beyond the
-// threshold, and fails otherwise.
-func incSynRcvdCount() bool {
- synRcvdCount.Lock()
-
- if synRcvdCount.value >= SynRcvdCountThreshold {
- synRcvdCount.Unlock()
- return false
- }
-
- synRcvdCount.pending.Add(1)
- synRcvdCount.value++
-
- synRcvdCount.Unlock()
- return true
-}
-
-// decSynRcvdCount atomically decrements the global number of endpoints in
-// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount
-// succeeded.
-func decSynRcvdCount() {
- synRcvdCount.Lock()
-
- synRcvdCount.value--
- synRcvdCount.pending.Done()
- synRcvdCount.Unlock()
-}
-
-// synCookiesInUse() returns true if the synRcvdCount is greater than
-// SynRcvdCountThreshold.
-func synCookiesInUse() bool {
- synRcvdCount.Lock()
- v := synRcvdCount.value
- synRcvdCount.Unlock()
- return v >= SynRcvdCountThreshold
-}
-
// newListenContext creates a new listen context.
func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
@@ -164,6 +137,11 @@ func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size,
listenEP: listenEP,
pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
}
+ p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol)
+ if !ok {
+ panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk))
+ }
+ l.synRcvdCount = p.SynRcvdCounter()
rand.Read(l.nonce[0][:])
rand.Read(l.nonce[1][:])
@@ -410,7 +388,7 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) {
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
- defer decSynRcvdCount()
+ defer ctx.synRcvdCount.dec()
defer func() {
e.mu.Lock()
e.decSynRcvdCount()
@@ -477,7 +455,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
switch {
case s.flags == header.TCPFlagSyn:
opts := parseSynSegmentOptions(s)
- if incSynRcvdCount() {
+ if ctx.synRcvdCount.inc() {
// Only handle the syn if the following conditions hold
// - accept queue is not full.
// - number of connections in synRcvd state is less than the
@@ -487,7 +465,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
return
}
- decSynRcvdCount()
+ ctx.synRcvdCount.dec()
e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
e.stack.Stats().DroppedPackets.Increment()
@@ -540,7 +518,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
return
}
- if !synCookiesInUse() {
+ if !ctx.synRcvdCount.synCookiesInUse() {
// When not using SYN cookies, as per RFC 793, section 3.9, page 64:
// Any acknowledgment is bad if it arrives on a connection still in
// the LISTEN state. An acceptable reset segment should be formed
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
index 91f25c132..effbf203f 100644
--- a/pkg/tcpip/transport/tcp/protocol.go
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -94,6 +94,63 @@ const (
ccCubic = "cubic"
)
+// syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
+// value is protected by a mutex so that we can increment only when it's
+// guaranteed not to go above a threshold.
+type synRcvdCounter struct {
+ sync.Mutex
+ value uint64
+ pending sync.WaitGroup
+ threshold uint64
+}
+
+// inc tries to increment the global number of endpoints in SYN-RCVD state. It
+// succeeds if the increment doesn't make the count go beyond the threshold, and
+// fails otherwise.
+func (s *synRcvdCounter) inc() bool {
+ s.Lock()
+ defer s.Unlock()
+ if s.value >= s.threshold {
+ return false
+ }
+
+ s.pending.Add(1)
+ s.value++
+
+ return true
+}
+
+// dec atomically decrements the global number of endpoints in SYN-RCVD
+// state. It must only be called if a previous call to inc succeeded.
+func (s *synRcvdCounter) dec() {
+ s.Lock()
+ defer s.Unlock()
+ s.value--
+ s.pending.Done()
+}
+
+// synCookiesInUse returns true if the synRcvdCount is greater than
+// SynRcvdCountThreshold.
+func (s *synRcvdCounter) synCookiesInUse() bool {
+ s.Lock()
+ defer s.Unlock()
+ return s.value >= s.threshold
+}
+
+// SetThreshold sets synRcvdCounter.Threshold to ths new threshold.
+func (s *synRcvdCounter) SetThreshold(threshold uint64) {
+ s.Lock()
+ defer s.Unlock()
+ s.threshold = threshold
+}
+
+// Threshold returns the current value of synRcvdCounter.Threhsold.
+func (s *synRcvdCounter) Threshold() uint64 {
+ s.Lock()
+ defer s.Unlock()
+ return s.threshold
+}
+
type protocol struct {
mu sync.RWMutex
sackEnabled bool
@@ -106,6 +163,7 @@ type protocol struct {
tcpLingerTimeout time.Duration
tcpTimeWaitTimeout time.Duration
minRTO time.Duration
+ synRcvdCount synRcvdCounter
dispatcher *dispatcher
}
@@ -282,6 +340,12 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error {
p.mu.Unlock()
return nil
+ case tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.Lock()
+ p.synRcvdCount.SetThreshold(uint64(v))
+ p.mu.Unlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -350,6 +414,12 @@ func (p *protocol) Option(option interface{}) *tcpip.Error {
p.mu.RUnlock()
return nil
+ case *tcpip.TCPSynRcvdCountThresholdOption:
+ p.mu.RLock()
+ *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold())
+ p.mu.RUnlock()
+ return nil
+
default:
return tcpip.ErrUnknownProtocolOption
}
@@ -365,6 +435,12 @@ func (p *protocol) Wait() {
p.dispatcher.wait()
}
+// SynRcvdCounter returns a reference to the synRcvdCount for this protocol
+// instance.
+func (p *protocol) SynRcvdCounter() *synRcvdCounter {
+ return &p.synRcvdCount
+}
+
// NewProtocol returns a TCP transport protocol.
func NewProtocol() stack.TransportProtocol {
return &protocol{
@@ -374,6 +450,7 @@ func NewProtocol() stack.TransportProtocol {
availableCongestionControl: []string{ccReno, ccCubic},
tcpLingerTimeout: DefaultTCPLingerTimeout,
tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout,
+ synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold},
dispatcher: newDispatcher(runtime.GOMAXPROCS(0)),
minRTO: MinRTO,
}