summaryrefslogtreecommitdiffhomepage
path: root/device
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2020-05-02 01:30:23 -0600
committerJason A. Donenfeld <Jason@zx2c4.com>2020-05-02 01:56:48 -0600
commit28c4d043048e8bb7167e96df6558a6366306fc17 (patch)
treec98e39cd6ed75e23f54e6d1b72b6f5c70fa9ab8a /device
parentfdba6c183aa8d4c19680f436517624038a6f3be5 (diff)
device: use atomic access for unlocked keypair.next
Go's GC semantics might not always guarantee the safety of this, and the race detector gets upset too, so instead we wrap this all in atomic accessors. Reported-by: David Anderson <danderson@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'device')
-rw-r--r--device/keypair.go10
-rw-r--r--device/noise-protocol.go16
-rw-r--r--device/noise_test.go2
-rw-r--r--device/peer.go6
4 files changed, 23 insertions, 11 deletions
diff --git a/device/keypair.go b/device/keypair.go
index 9c78fa9..d70c7f4 100644
--- a/device/keypair.go
+++ b/device/keypair.go
@@ -8,7 +8,9 @@ package device
import (
"crypto/cipher"
"sync"
+ "sync/atomic"
"time"
+ "unsafe"
"golang.zx2c4.com/wireguard/replay"
)
@@ -38,6 +40,14 @@ type Keypairs struct {
next *Keypair
}
+func (kp *Keypairs) storeNext(next *Keypair) {
+ atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
+}
+
+func (kp *Keypairs) loadNext() *Keypair {
+ return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
+}
+
func (kp *Keypairs) Current() *Keypair {
kp.RLock()
defer kp.RUnlock()
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index a848c47..e6f676c 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -14,6 +14,7 @@ import (
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
+
"golang.zx2c4.com/wireguard/tai64n"
)
@@ -583,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
defer keypairs.Unlock()
previous := keypairs.previous
- next := keypairs.next
+ next := keypairs.loadNext()
current := keypairs.current
if isInitiator {
if next != nil {
- keypairs.next = nil
+ keypairs.storeNext(nil)
keypairs.previous = next
device.DeleteKeypair(current)
} else {
@@ -597,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
- keypairs.next = keypair
+ keypairs.storeNext(keypair)
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
@@ -608,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
- if keypairs.next != receivedKeypair {
+
+ if keypairs.loadNext() != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
- if keypairs.next != receivedKeypair {
+ if keypairs.loadNext() != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
- keypairs.current = keypairs.next
- keypairs.next = nil
+ keypairs.current = keypairs.loadNext()
+ keypairs.storeNext(nil)
return true
}
diff --git a/device/noise_test.go b/device/noise_test.go
index 6ba3f2e..b5d5845 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
t.Fatal("failed to derive keypair for peer 2", err)
}
- key1 := peer1.keypairs.next
+ key1 := peer1.keypairs.loadNext()
key2 := peer2.keypairs.current
// encrypting / decryption test
diff --git a/device/peer.go b/device/peer.go
index 79d4981..899591b 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -223,10 +223,10 @@ func (peer *Peer) ZeroAndFlushAll() {
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
- device.DeleteKeypair(keypairs.next)
+ device.DeleteKeypair(keypairs.loadNext())
keypairs.previous = nil
keypairs.current = nil
- keypairs.next = nil
+ keypairs.storeNext(nil)
keypairs.Unlock()
// clear handshake state
@@ -254,7 +254,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
keypairs.current.sendNonce = RejectAfterMessages
}
if keypairs.next != nil {
- keypairs.next.sendNonce = RejectAfterMessages
+ keypairs.loadNext().sendNonce = RejectAfterMessages
}
keypairs.Unlock()
}