summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/keypair.go15
-rw-r--r--src/misc.go7
-rw-r--r--src/noise_protocol.go8
-rw-r--r--src/receive.go4
-rw-r--r--src/replay.go71
-rw-r--r--src/replay_test.go114
-rw-r--r--src/timers.go50
7 files changed, 227 insertions, 42 deletions
diff --git a/src/keypair.go b/src/keypair.go
index b24dbe4..b5f46df 100644
--- a/src/keypair.go
+++ b/src/keypair.go
@@ -7,13 +7,14 @@ import (
)
type KeyPair struct {
- receive cipher.AEAD
- send cipher.AEAD
- sendNonce uint64
- isInitiator bool
- created time.Time
- localIndex uint32
- remoteIndex uint32
+ receive cipher.AEAD
+ replayFilter ReplayFilter
+ send cipher.AEAD
+ sendNonce uint64
+ isInitiator bool
+ created time.Time
+ localIndex uint32
+ remoteIndex uint32
}
type KeyPairs struct {
diff --git a/src/misc.go b/src/misc.go
index 75561b2..fc75c0d 100644
--- a/src/misc.go
+++ b/src/misc.go
@@ -19,6 +19,13 @@ func min(a uint, b uint) uint {
return a
}
+func minUint64(a uint64, b uint64) uint64 {
+ if a > b {
+ return b
+ }
+ return a
+}
+
func signalSend(c chan struct{}) {
select {
case c <- struct{}{}:
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index a90fe4c..bfa3797 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -415,6 +415,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return lookup.peer
}
+/* Derives a new key-pair from the current handshake state
+ *
+ */
func (peer *Peer) NewKeyPair() *KeyPair {
handshake := &peer.handshake
handshake.mutex.Lock()
@@ -445,10 +448,11 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// create AEAD instances
keyPair := new(KeyPair)
+ keyPair.created = time.Now()
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
keyPair.sendNonce = 0
- keyPair.created = time.Now()
+ keyPair.replayFilter.Init()
keyPair.isInitiator = isInitiator
keyPair.localIndex = peer.handshake.localIndex
keyPair.remoteIndex = peer.handshake.remoteIndex
@@ -462,8 +466,6 @@ func (peer *Peer) NewKeyPair() *KeyPair {
})
handshake.localIndex = 0
- // TODO: start timer for keypair (clearing)
-
// rotate key pairs
kp := &peer.keyPairs
diff --git a/src/receive.go b/src/receive.go
index e780c66..6530c47 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -432,6 +432,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check for replay
+ if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
+ return
+ }
+
// time (passive) keep-alive
peer.TimerStartKeepalive()
diff --git a/src/replay.go b/src/replay.go
new file mode 100644
index 0000000..49c7e08
--- /dev/null
+++ b/src/replay.go
@@ -0,0 +1,71 @@
+package main
+
+/* Implementation of RFC6479
+ * https://tools.ietf.org/html/rfc6479
+ *
+ * The implementation is not safe for concurrent use!
+ */
+
+const (
+ // See: https://golang.org/src/math/big/arith.go
+ _Wordm = ^uintptr(0)
+ _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
+ _WordSize = 1 << _WordLogSize
+)
+
+const (
+ CounterRedundantBitsLog = _WordLogSize + 3
+ CounterRedundantBits = _WordSize * 8
+ CounterBitsTotal = 2048
+ CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
+)
+
+const (
+ BacktrackWords = CounterBitsTotal / _WordSize
+)
+
+type ReplayFilter struct {
+ counter uint64
+ backtrack [BacktrackWords]uintptr
+}
+
+func (filter *ReplayFilter) Init() {
+ filter.counter = 0
+ filter.backtrack[0] = 0
+}
+
+func (filter *ReplayFilter) ValidateCounter(counter uint64) bool {
+ if counter >= RejectAfterMessages {
+ return false
+ }
+
+ indexWord := counter >> CounterRedundantBitsLog
+
+ if counter > filter.counter {
+
+ // move window forward
+
+ current := filter.counter >> CounterRedundantBitsLog
+ diff := minUint64(indexWord-current, BacktrackWords)
+ for i := uint64(1); i <= diff; i++ {
+ filter.backtrack[(current+i)%BacktrackWords] = 0
+ }
+ filter.counter = counter
+
+ } else if filter.counter-counter > CounterWindowSize {
+
+ // behind current window
+
+ return false
+ }
+
+ indexWord %= BacktrackWords
+ indexBit := counter & uint64(CounterRedundantBits-1)
+
+ // check and set bit
+
+ oldValue := filter.backtrack[indexWord]
+ newValue := oldValue | (1 << indexBit)
+ filter.backtrack[indexWord] = newValue
+ return oldValue != newValue
+}
diff --git a/src/replay_test.go b/src/replay_test.go
new file mode 100644
index 0000000..e75c5c1
--- /dev/null
+++ b/src/replay_test.go
@@ -0,0 +1,114 @@
+package main
+
+import (
+ "testing"
+)
+
+/* Ported from the linux kernel implementation
+ *
+ *
+ */
+
+/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
+
+func TestReplay(t *testing.T) {
+ var filter ReplayFilter
+
+ T_LIM := CounterWindowSize + 1
+
+ testNumber := 0
+ T := func(n uint64, v bool) {
+ testNumber++
+ if filter.ValidateCounter(n) != v {
+ t.Fatal("Test", testNumber, "failed", n, v)
+ }
+ }
+
+ filter.Init()
+
+ /* 1 */ T(0, true)
+ /* 2 */ T(1, true)
+ /* 3 */ T(1, false)
+ /* 4 */ T(9, true)
+ /* 5 */ T(8, true)
+ /* 6 */ T(7, true)
+ /* 7 */ T(7, false)
+ /* 8 */ T(T_LIM, true)
+ /* 9 */ T(T_LIM-1, true)
+ /* 10 */ T(T_LIM-1, false)
+ /* 11 */ T(T_LIM-2, true)
+ /* 12 */ T(2, true)
+ /* 13 */ T(2, false)
+ /* 14 */ T(T_LIM+16, true)
+ /* 15 */ T(3, false)
+ /* 16 */ T(T_LIM+16, false)
+ /* 17 */ T(T_LIM*4, true)
+ /* 18 */ T(T_LIM*4-(T_LIM-1), true)
+ /* 19 */ T(10, false)
+ /* 20 */ T(T_LIM*4-T_LIM, false)
+ /* 21 */ T(T_LIM*4-(T_LIM+1), false)
+ /* 22 */ T(T_LIM*4-(T_LIM-2), true)
+ /* 23 */ T(T_LIM*4+1-T_LIM, false)
+ /* 24 */ T(0, false)
+ /* 25 */ T(RejectAfterMessages, false)
+ /* 26 */ T(RejectAfterMessages-1, true)
+ /* 27 */ T(RejectAfterMessages, false)
+ /* 28 */ T(RejectAfterMessages-1, false)
+ /* 29 */ T(RejectAfterMessages-2, true)
+ /* 30 */ T(RejectAfterMessages+1, false)
+ /* 31 */ T(RejectAfterMessages+2, false)
+ /* 32 */ T(RejectAfterMessages-2, false)
+ /* 33 */ T(RejectAfterMessages-3, true)
+ /* 34 */ T(0, false)
+
+ t.Log("Bulk test 1")
+ filter.Init()
+ testNumber = 0
+ for i := uint64(1); i <= CounterWindowSize; i++ {
+ T(i, true)
+ }
+ T(0, true)
+ T(0, false)
+
+ t.Log("Bulk test 2")
+ filter.Init()
+ testNumber = 0
+ for i := uint64(2); i <= CounterWindowSize+1; i++ {
+ T(i, true)
+ }
+ T(1, true)
+ T(0, false)
+
+ t.Log("Bulk test 3")
+ filter.Init()
+ testNumber = 0
+ for i := CounterWindowSize + 1; i > 0; i-- {
+ T(i, true)
+ }
+
+ t.Log("Bulk test 4")
+ filter.Init()
+ testNumber = 0
+ for i := CounterWindowSize + 2; i > 1; i-- {
+ T(i, true)
+ }
+ T(0, false)
+
+ t.Log("Bulk test 5")
+ filter.Init()
+ testNumber = 0
+ for i := CounterWindowSize; i > 0; i-- {
+ T(i, true)
+ }
+ T(CounterWindowSize+1, true)
+ T(0, false)
+
+ t.Log("Bulk test 6")
+ filter.Init()
+ testNumber = 0
+ for i := CounterWindowSize; i > 0; i-- {
+ T(i, true)
+ }
+ T(0, true)
+ T(CounterWindowSize+1, true)
+}
diff --git a/src/timers.go b/src/timers.go
index 26926c2..70e0766 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -12,22 +12,15 @@ import (
*
*/
func (peer *Peer) KeepKeyFreshSending() {
- send := func() bool {
- peer.keyPairs.mutex.RLock()
- defer peer.keyPairs.mutex.RUnlock()
-
- kp := peer.keyPairs.current
- if kp == nil {
- return false
- }
-
- if !kp.isInitiator {
- return false
- }
-
- nonce := atomic.LoadUint64(&kp.sendNonce)
- return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
- }()
+ kp := peer.keyPairs.Current()
+ if kp == nil {
+ return
+ }
+ if !kp.isInitiator {
+ return
+ }
+ nonce := atomic.LoadUint64(&kp.sendNonce)
+ send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
if send {
signalSend(peer.signal.handshakeBegin)
}
@@ -37,22 +30,15 @@ func (peer *Peer) KeepKeyFreshSending() {
*
*/
func (peer *Peer) KeepKeyFreshReceiving() {
- send := func() bool {
- peer.keyPairs.mutex.RLock()
- defer peer.keyPairs.mutex.RUnlock()
-
- kp := peer.keyPairs.current
- if kp == nil {
- return false
- }
-
- if !kp.isInitiator {
- return false
- }
-
- nonce := atomic.LoadUint64(&kp.sendNonce)
- return nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
- }()
+ kp := peer.keyPairs.Current()
+ if kp == nil {
+ return
+ }
+ if !kp.isInitiator {
+ return
+ }
+ nonce := atomic.LoadUint64(&kp.sendNonce)
+ send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send {
signalSend(peer.signal.handshakeBegin)
}