summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/device.go22
-rw-r--r--src/keypair.go4
-rw-r--r--src/noise_protocol.go22
-rw-r--r--src/peer.go44
-rw-r--r--src/timers.go7
5 files changed, 58 insertions, 41 deletions
diff --git a/src/device.go b/src/device.go
index 0317b60..c041987 100644
--- a/src/device.go
+++ b/src/device.go
@@ -88,28 +88,6 @@ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
device.routing.table.RemovePeer(peer)
peer.Stop()
- // clean index table
-
- kp := &peer.keyPairs
- kp.mutex.Lock()
-
- if kp.previous != nil {
- device.indices.Delete(kp.previous.localIndex)
- }
-
- if kp.current != nil {
- device.indices.Delete(kp.current.localIndex)
- }
-
- if kp.next != nil {
- device.indices.Delete(kp.next.localIndex)
- }
-
- kp.previous = nil
- kp.current = nil
- kp.next = nil
- kp.mutex.Unlock()
-
// remove from peer map
delete(device.peers.keyMap, key)
diff --git a/src/keypair.go b/src/keypair.go
index 7e5297b..283cb92 100644
--- a/src/keypair.go
+++ b/src/keypair.go
@@ -38,5 +38,7 @@ func (kp *KeyPairs) Current() *KeyPair {
}
func (device *Device) DeleteKeyPair(key *KeyPair) {
- device.indices.Delete(key.localIndex)
+ if key != nil {
+ device.indices.Delete(key.localIndex)
+ }
}
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index d620a0d..c9713c0 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -121,6 +121,15 @@ func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
hsh.Reset()
}
+func (h *Handshake) Clear() {
+ setZero(h.localEphemeral[:])
+ setZero(h.remoteEphemeral[:])
+ setZero(h.chainKey[:])
+ setZero(h.hash[:])
+ h.localIndex = 0
+ h.state = HandshakeZeroed
+}
+
func (h *Handshake) mixHash(data []byte) {
mixHash(&h.hash, &h.hash, data)
}
@@ -138,8 +147,8 @@ func init() {
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
- device.noise.mutex.Lock()
- defer device.noise.mutex.Unlock()
+ device.noise.mutex.RLock()
+ defer device.noise.mutex.RUnlock()
handshake := &peer.handshake
handshake.mutex.Lock()
@@ -393,7 +402,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
ok := func() bool {
- // read lock handshake
+ // lock handshake state
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
@@ -402,6 +411,11 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return false
}
+ // lock private key for reading
+
+ device.noise.mutex.RLock()
+ defer device.noise.mutex.RUnlock()
+
// finish 3-way DH
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
@@ -432,7 +446,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
)
mixHash(&hash, &hash, tau[:])
- // authenticate
+ // authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
diff --git a/src/peer.go b/src/peer.go
index 3b8f7cc..7776b71 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -48,8 +48,8 @@ type Peer struct {
// state related to WireGuard timers
- keepalivePersistent Timer // set for persistent keepalives
- keepalivePassive Timer // set upon recieving messages
+ keepalivePersistent Timer // set for persistent keep-alive
+ keepalivePassive Timer // set upon receiving messages
zeroAllKeys Timer // zero all key material
handshakeNew Timer // begin a new handshake (stale)
handshakeDeadline Timer // complete handshake timeout
@@ -69,7 +69,7 @@ type Peer struct {
mutex deadlock.Mutex // held when stopping / starting routines
starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop
- stop Signal // size 0, stop all goroutines in peer
+ stop Signal // size 0, stop all go-routines in peer
}
mac CookieGenerator
@@ -123,7 +123,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
}
device.peers.keyMap[pk] = peer
- // precompute DH
+ // pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
@@ -186,16 +186,19 @@ func (peer *Peer) String() string {
func (peer *Peer) Start() {
+ // should never start a peer on a closed device
+
if peer.device.isClosed.Get() {
return
}
+ // prevent simultaneous start/stop operations
+
peer.routines.mutex.Lock()
defer peer.routines.mutex.Unlock()
-
peer.device.log.Debug.Println("Starting:", peer.String())
- // stop & wait for ungoing routines (if any)
+ // stop & wait for ongoing routines (if any)
peer.isRunning.Set(false)
peer.routines.stop.Broadcast()
@@ -230,12 +233,15 @@ func (peer *Peer) Start() {
func (peer *Peer) Stop() {
+ // prevent simultaneous start/stop operations
+
peer.routines.mutex.Lock()
defer peer.routines.mutex.Unlock()
- peer.device.log.Debug.Println("Stopping:", peer.String())
+ device := peer.device
+ device.log.Debug.Println("Stopping:", peer.String())
- // stop & wait for ungoing routines (if any)
+ // stop & wait for ongoing peer routines (if any)
peer.routines.stop.Broadcast()
peer.routines.starting.Wait()
@@ -247,6 +253,28 @@ func (peer *Peer) Stop() {
close(peer.queue.outbound)
close(peer.queue.inbound)
+ // clear key pairs
+
+ kp := &peer.keyPairs
+ kp.mutex.Lock()
+
+ device.DeleteKeyPair(kp.previous)
+ device.DeleteKeyPair(kp.current)
+ device.DeleteKeyPair(kp.next)
+
+ kp.previous = nil
+ kp.current = nil
+ kp.next = nil
+ kp.mutex.Unlock()
+
+ // clear handshake state
+
+ hs := &peer.handshake
+ hs.mutex.Lock()
+ device.indices.Delete(hs.localIndex)
+ hs.Clear()
+ hs.mutex.Unlock()
+
// reset signal (to handle repeated stopping)
peer.routines.stop = NewSignal()
diff --git a/src/timers.go b/src/timers.go
index 2ef105e..7092688 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -274,12 +274,7 @@ func (peer *Peer) RoutineTimerHandler() {
// zero out handshake
device.indices.Delete(hs.localIndex)
-
- hs.localIndex = 0
- setZero(hs.localEphemeral[:])
- setZero(hs.remoteEphemeral[:])
- setZero(hs.chainKey[:])
- setZero(hs.hash[:])
+ hs.Clear()
hs.mutex.Unlock()
// handshake timers