summaryrefslogtreecommitdiffhomepage
path: root/src/noise_protocol.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/noise_protocol.go')
-rw-r--r--src/noise_protocol.go127
1 files changed, 81 insertions, 46 deletions
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index e237dbe..46ceeda 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -77,7 +77,7 @@ type MessageCookieReply struct {
type Handshake struct {
state int
- mutex sync.Mutex
+ mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
presharedKey NoiseSymmetricKey // psk
@@ -205,49 +205,64 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
}
hash = mixHash(hash, msg.Static[:])
- // find peer
+ // lookup peer
peer := device.LookupPeer(peerPK)
if peer == nil {
return nil
}
handshake := &peer.handshake
- handshake.mutex.Lock()
- defer handshake.mutex.Unlock()
- // decrypt timestamp
+ // verify identity
var timestamp TAI64N
- func() {
- var key [chacha20poly1305.KeySize]byte
- chainKey, key = KDF2(
- chainKey[:],
- handshake.precomputedStaticStatic[:],
- )
- aead, _ := chacha20poly1305.New(key[:])
- _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
- }()
- if err != nil {
- return nil
- }
- hash = mixHash(hash, msg.Timestamp[:])
+ ok := func() bool {
+
+ // read lock handshake
+
+ handshake.mutex.RLock()
+ defer handshake.mutex.RUnlock()
+
+ // decrypt timestamp
+
+ func() {
+ var key [chacha20poly1305.KeySize]byte
+ chainKey, key = KDF2(
+ chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
+ }()
+ if err != nil {
+ return false
+ }
+ hash = mixHash(hash, msg.Timestamp[:])
+
+ // TODO: check for flood attack
+
+ // check for replay attack
- // check for replay attack
+ return timestamp.After(handshake.lastTimestamp)
+ }()
- if !timestamp.After(handshake.lastTimestamp) {
+ if !ok {
return nil
}
- // TODO: check for flood attack
-
// update handshake state
+ handshake.mutex.Lock()
+
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp
handshake.state = HandshakeInitiationConsumed
+
+ handshake.mutex.Unlock()
+
return peer
}
@@ -320,47 +335,67 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return nil
}
- handshake.mutex.Lock()
- defer handshake.mutex.Unlock()
- if handshake.state != HandshakeInitiationCreated {
- return nil
- }
+ var (
+ hash [blake2s.Size]byte
+ chainKey [blake2s.Size]byte
+ )
- // finish 3-way DH
+ ok := func() bool {
- hash := mixHash(handshake.hash, msg.Ephemeral[:])
- chainKey := handshake.chainKey
+ // read lock handshake
- func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- chainKey = mixKey(chainKey, ss[:])
- ss = device.privateKey.sharedSecret(msg.Ephemeral)
- chainKey = mixKey(chainKey, ss[:])
- }()
+ handshake.mutex.RLock()
+ defer handshake.mutex.RUnlock()
- // add preshared key (psk)
+ if handshake.state != HandshakeInitiationCreated {
+ return false
+ }
- var tau [blake2s.Size]byte
- var key [chacha20poly1305.KeySize]byte
- chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
- hash = mixHash(hash, tau[:])
+ // finish 3-way DH
- // authenticate
+ hash = mixHash(handshake.hash, msg.Ephemeral[:])
+ chainKey = handshake.chainKey
- aead, _ := chacha20poly1305.New(key[:])
- _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
- if err != nil {
+ func() {
+ ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ chainKey = mixKey(chainKey, ss[:])
+ ss = device.privateKey.sharedSecret(msg.Ephemeral)
+ chainKey = mixKey(chainKey, ss[:])
+ }()
+
+ // add preshared key (psk)
+
+ var tau [blake2s.Size]byte
+ var key [chacha20poly1305.KeySize]byte
+ chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
+ hash = mixHash(hash, tau[:])
+
+ // authenticate
+
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ if err != nil {
+ return false
+ }
+ hash = mixHash(hash, msg.Empty[:])
+ return true
+ }()
+
+ if !ok {
return nil
}
- hash = mixHash(hash, msg.Empty[:])
// update handshake state
+ handshake.mutex.Lock()
+
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed
+ handshake.mutex.Unlock()
+
return lookup.peer
}