diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2020-05-02 01:30:23 -0600 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2020-05-02 01:56:48 -0600 |
commit | 28c4d043048e8bb7167e96df6558a6366306fc17 (patch) | |
tree | c98e39cd6ed75e23f54e6d1b72b6f5c70fa9ab8a | |
parent | fdba6c183aa8d4c19680f436517624038a6f3be5 (diff) |
device: use atomic access for unlocked keypair.next
Go's GC semantics might not always guarantee the safety of this, and the
race detector gets upset too, so instead we wrap this all in atomic
accessors.
Reported-by: David Anderson <danderson@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | device/keypair.go | 10 | ||||
-rw-r--r-- | device/noise-protocol.go | 16 | ||||
-rw-r--r-- | device/noise_test.go | 2 | ||||
-rw-r--r-- | device/peer.go | 6 |
4 files changed, 23 insertions, 11 deletions
diff --git a/device/keypair.go b/device/keypair.go index 9c78fa9..d70c7f4 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -8,7 +8,9 @@ package device import ( "crypto/cipher" "sync" + "sync/atomic" "time" + "unsafe" "golang.zx2c4.com/wireguard/replay" ) @@ -38,6 +40,14 @@ type Keypairs struct { next *Keypair } +func (kp *Keypairs) storeNext(next *Keypair) { + atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next)) +} + +func (kp *Keypairs) loadNext() *Keypair { + return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)))) +} + func (kp *Keypairs) Current() *Keypair { kp.RLock() defer kp.RUnlock() diff --git a/device/noise-protocol.go b/device/noise-protocol.go index a848c47..e6f676c 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -14,6 +14,7 @@ import ( "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" + "golang.zx2c4.com/wireguard/tai64n" ) @@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error { defer keypairs.Unlock() previous := keypairs.previous - next := keypairs.next + next := keypairs.loadNext() current := keypairs.current if isInitiator { if next != nil { - keypairs.next = nil + keypairs.storeNext(nil) keypairs.previous = next device.DeleteKeypair(current) } else { @@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error { device.DeleteKeypair(previous) keypairs.current = keypair } else { - keypairs.next = keypair + keypairs.storeNext(keypair) device.DeleteKeypair(next) keypairs.previous = nil device.DeleteKeypair(previous) @@ -608,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { keypairs := &peer.keypairs - if keypairs.next != receivedKeypair { + + if keypairs.loadNext() != receivedKeypair { return false } keypairs.Lock() defer keypairs.Unlock() - if keypairs.next != receivedKeypair { + if keypairs.loadNext() != receivedKeypair { return false } old := keypairs.previous keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - keypairs.current = keypairs.next - keypairs.next = nil + keypairs.current = keypairs.loadNext() + keypairs.storeNext(nil) return true } diff --git a/device/noise_test.go b/device/noise_test.go index 6ba3f2e..b5d5845 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) { t.Fatal("failed to derive keypair for peer 2", err) } - key1 := peer1.keypairs.next + key1 := peer1.keypairs.loadNext() key2 := peer2.keypairs.current // encrypting / decryption test diff --git a/device/peer.go b/device/peer.go index 79d4981..899591b 100644 --- a/device/peer.go +++ b/device/peer.go @@ -223,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() { keypairs.Lock() device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.current) - device.DeleteKeypair(keypairs.next) + device.DeleteKeypair(keypairs.loadNext()) keypairs.previous = nil keypairs.current = nil - keypairs.next = nil + keypairs.storeNext(nil) keypairs.Unlock() // clear handshake state @@ -254,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() { keypairs.current.sendNonce = RejectAfterMessages } if keypairs.next != nil { - keypairs.next.sendNonce = RejectAfterMessages + keypairs.loadNext().sendNonce = RejectAfterMessages } keypairs.Unlock() } |