diff options
-rw-r--r-- | replay/replay.go | 97 | ||||
-rw-r--r-- | replay/replay_test.go | 24 |
2 files changed, 50 insertions, 71 deletions
diff --git a/replay/replay.go b/replay/replay.go index 85647f5..8685712 100644 --- a/replay/replay.go +++ b/replay/replay.go @@ -3,81 +3,60 @@ * Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved. */ +// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. package replay -/* Implementation of RFC6479 - * https://tools.ietf.org/html/rfc6479 - * - * The implementation is not safe for concurrent use! - */ - -const ( - // See: https://golang.org/src/math/big/arith.go - _Wordm = ^uintptr(0) - _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1 - _WordSize = 1 << _WordLogSize -) +type block uint64 const ( - CounterRedundantBitsLog = _WordLogSize + 3 - CounterRedundantBits = _WordSize * 8 - CounterBitsTotal = 8192 - CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits) + blockBitLog = 6 // 1<<6 == 64 bits + blockBits = 1 << blockBitLog // must be power of 2 + ringBlocks = 1 << 7 // must be power of 2 + windowSize = (ringBlocks - 1) * blockBits + blockMask = ringBlocks - 1 + bitMask = blockBits - 1 ) -const ( - BacktrackWords = CounterBitsTotal / 8 / _WordSize -) - -func minUint64(a uint64, b uint64) uint64 { - if a > b { - return b - } - return a -} - +// A ReplayFilter rejects replayed messages by checking if message counter value is +// within a sliding window of previously received messages. +// The zero value for ReplayFilter is an empty filter ready to use. +// Filters are unsafe for concurrent use. type ReplayFilter struct { - counter uint64 - backtrack [BacktrackWords]uintptr + last uint64 + ring [ringBlocks]block } -func (filter *ReplayFilter) Init() { - filter.counter = 0 - filter.backtrack[0] = 0 +// Init resets the filter to empty state. +func (f *ReplayFilter) Init() { + f.last = 0 + f.ring[0] = 0 } -func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool { +// ValidateCounter checks if the counter should be accepted. +// Overlimit counters (>= limit) are always rejected. +func (f *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool { if counter >= limit { return false } - - indexWord := counter >> CounterRedundantBitsLog - - if counter > filter.counter { - - // move window forward - - current := filter.counter >> CounterRedundantBitsLog - diff := minUint64(indexWord-current, BacktrackWords) - for i := uint64(1); i <= diff; i++ { - filter.backtrack[(current+i)%BacktrackWords] = 0 + indexBlock := counter >> blockBitLog + if counter > f.last { // move window forward + current := f.last >> blockBitLog + diff := indexBlock - current + if diff > ringBlocks { + diff = ringBlocks // cap diff to clear the whole ring } - filter.counter = counter - - } else if filter.counter-counter > CounterWindowSize { - - // behind current window - + for i := current + 1; i <= current+diff; i++ { + f.ring[i&blockMask] = 0 + } + f.last = counter + } else if f.last-counter > windowSize { // behind current window return false } - - indexWord %= BacktrackWords - indexBit := counter & uint64(CounterRedundantBits-1) - // check and set bit - - oldValue := filter.backtrack[indexWord] - newValue := oldValue | (1 << indexBit) - filter.backtrack[indexWord] = newValue - return oldValue != newValue + indexBlock &= blockMask + indexBit := counter & bitMask + old := f.ring[indexBlock] + new := old | 1<<indexBit + f.ring[indexBlock] = new + return old != new } diff --git a/replay/replay_test.go b/replay/replay_test.go index ceae2f3..5af66ff 100644 --- a/replay/replay_test.go +++ b/replay/replay_test.go @@ -19,13 +19,13 @@ const RejectAfterMessages = (1 << 64) - (1 << 4) - 1 func TestReplay(t *testing.T) { var filter ReplayFilter - T_LIM := CounterWindowSize + 1 + const T_LIM = windowSize + 1 testNumber := 0 - T := func(n uint64, v bool) { + T := func(n uint64, expected bool) { testNumber++ - if filter.ValidateCounter(n, RejectAfterMessages) != v { - t.Fatal("Test", testNumber, "failed", n, v) + if filter.ValidateCounter(n, RejectAfterMessages) != expected { + t.Fatal("Test", testNumber, "failed", n, expected) } } @@ -69,7 +69,7 @@ func TestReplay(t *testing.T) { t.Log("Bulk test 1") filter.Init() testNumber = 0 - for i := uint64(1); i <= CounterWindowSize; i++ { + for i := uint64(1); i <= windowSize; i++ { T(i, true) } T(0, true) @@ -78,7 +78,7 @@ func TestReplay(t *testing.T) { t.Log("Bulk test 2") filter.Init() testNumber = 0 - for i := uint64(2); i <= CounterWindowSize+1; i++ { + for i := uint64(2); i <= windowSize+1; i++ { T(i, true) } T(1, true) @@ -87,14 +87,14 @@ func TestReplay(t *testing.T) { t.Log("Bulk test 3") filter.Init() testNumber = 0 - for i := CounterWindowSize + 1; i > 0; i-- { + for i := uint64(windowSize + 1); i > 0; i-- { T(i, true) } t.Log("Bulk test 4") filter.Init() testNumber = 0 - for i := CounterWindowSize + 2; i > 1; i-- { + for i := uint64(windowSize + 2); i > 1; i-- { T(i, true) } T(0, false) @@ -102,18 +102,18 @@ func TestReplay(t *testing.T) { t.Log("Bulk test 5") filter.Init() testNumber = 0 - for i := CounterWindowSize; i > 0; i-- { + for i := uint64(windowSize); i > 0; i-- { T(i, true) } - T(CounterWindowSize+1, true) + T(windowSize+1, true) T(0, false) t.Log("Bulk test 6") filter.Init() testNumber = 0 - for i := CounterWindowSize; i > 0; i-- { + for i := uint64(windowSize); i > 0; i-- { T(i, true) } T(0, true) - T(CounterWindowSize+1, true) + T(windowSize+1, true) } |