diff options
Diffstat (limited to 'pkg/tcpip/stack/nud.go')
-rw-r--r-- | pkg/tcpip/stack/nud.go | 80 |
1 files changed, 36 insertions, 44 deletions
diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index 5a94e9ac6..ca9822bca 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -16,6 +16,7 @@ package stack import ( "math" + "math/rand" "sync" "time" @@ -313,45 +314,36 @@ func calcMaxRandomFactor(minRandomFactor float32) float32 { return defaultMaxRandomFactor } -// A Rand is a source of random numbers. -type Rand interface { - // Float32 returns, as a float32, a pseudo-random number in [0.0,1.0). - Float32() float32 -} - // NUDState stores states needed for calculating reachable time. type NUDState struct { - rng Rand + clock tcpip.Clock + rng *rand.Rand - // mu protects the fields below. - // - // It is necessary for NUDState to handle its own locking since neighbor - // entries may access the NUD state from within the goroutine spawned by - // time.AfterFunc(). This goroutine may run concurrently with the main - // process for controlling the neighbor cache and would otherwise introduce - // race conditions if NUDState was not locked properly. - mu sync.RWMutex + mu struct { + sync.RWMutex - config NUDConfigurations + config NUDConfigurations - // reachableTime is the duration to wait for a REACHABLE entry to - // transition into STALE after inactivity. This value is calculated with - // the algorithm defined in RFC 4861 section 6.3.2. - reachableTime time.Duration + // reachableTime is the duration to wait for a REACHABLE entry to + // transition into STALE after inactivity. This value is calculated with + // the algorithm defined in RFC 4861 section 6.3.2. + reachableTime time.Duration - expiration time.Time - prevBaseReachableTime time.Duration - prevMinRandomFactor float32 - prevMaxRandomFactor float32 + expiration time.Time + prevBaseReachableTime time.Duration + prevMinRandomFactor float32 + prevMaxRandomFactor float32 + } } // NewNUDState returns new NUDState using c as configuration and the specified // random number generator for use in recomputing ReachableTime. -func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { +func NewNUDState(c NUDConfigurations, clock tcpip.Clock, rng *rand.Rand) *NUDState { s := &NUDState{ - rng: rng, + clock: clock, + rng: rng, } - s.config = c + s.mu.config = c return s } @@ -359,14 +351,14 @@ func NewNUDState(c NUDConfigurations, rng Rand) *NUDState { func (s *NUDState) Config() NUDConfigurations { s.mu.RLock() defer s.mu.RUnlock() - return s.config + return s.mu.config } // SetConfig replaces the existing NUD configurations with c. func (s *NUDState) SetConfig(c NUDConfigurations) { s.mu.Lock() defer s.mu.Unlock() - s.config = c + s.mu.config = c } // ReachableTime returns the duration to wait for a REACHABLE entry to @@ -377,13 +369,13 @@ func (s *NUDState) ReachableTime() time.Duration { s.mu.Lock() defer s.mu.Unlock() - if time.Now().After(s.expiration) || - s.config.BaseReachableTime != s.prevBaseReachableTime || - s.config.MinRandomFactor != s.prevMinRandomFactor || - s.config.MaxRandomFactor != s.prevMaxRandomFactor { + if s.clock.Now().After(s.mu.expiration) || + s.mu.config.BaseReachableTime != s.mu.prevBaseReachableTime || + s.mu.config.MinRandomFactor != s.mu.prevMinRandomFactor || + s.mu.config.MaxRandomFactor != s.mu.prevMaxRandomFactor { s.recomputeReachableTimeLocked() } - return s.reachableTime + return s.mu.reachableTime } // recomputeReachableTimeLocked forces a recalculation of ReachableTime using @@ -408,23 +400,23 @@ func (s *NUDState) ReachableTime() time.Duration { // // s.mu MUST be locked for writing. func (s *NUDState) recomputeReachableTimeLocked() { - s.prevBaseReachableTime = s.config.BaseReachableTime - s.prevMinRandomFactor = s.config.MinRandomFactor - s.prevMaxRandomFactor = s.config.MaxRandomFactor + s.mu.prevBaseReachableTime = s.mu.config.BaseReachableTime + s.mu.prevMinRandomFactor = s.mu.config.MinRandomFactor + s.mu.prevMaxRandomFactor = s.mu.config.MaxRandomFactor - randomFactor := s.config.MinRandomFactor + s.rng.Float32()*(s.config.MaxRandomFactor-s.config.MinRandomFactor) + randomFactor := s.mu.config.MinRandomFactor + s.rng.Float32()*(s.mu.config.MaxRandomFactor-s.mu.config.MinRandomFactor) // Check for overflow, given that minRandomFactor and maxRandomFactor are // guaranteed to be positive numbers. - if float32(math.MaxInt64)/randomFactor < float32(s.config.BaseReachableTime) { - s.reachableTime = time.Duration(math.MaxInt64) + if math.MaxInt64/randomFactor < float32(s.mu.config.BaseReachableTime) { + s.mu.reachableTime = time.Duration(math.MaxInt64) } else if randomFactor == 1 { // Avoid loss of precision when a large base reachable time is used. - s.reachableTime = s.config.BaseReachableTime + s.mu.reachableTime = s.mu.config.BaseReachableTime } else { - reachableTime := int64(float32(s.config.BaseReachableTime) * randomFactor) - s.reachableTime = time.Duration(reachableTime) + reachableTime := int64(float32(s.mu.config.BaseReachableTime) * randomFactor) + s.mu.reachableTime = time.Duration(reachableTime) } - s.expiration = time.Now().Add(2 * time.Hour) + s.mu.expiration = s.clock.Now().Add(2 * time.Hour) } |