diff options
Diffstat (limited to 'src/noise_protocol.go')
-rw-r--r-- | src/noise_protocol.go | 127 |
1 files changed, 81 insertions, 46 deletions
diff --git a/src/noise_protocol.go b/src/noise_protocol.go index e237dbe..46ceeda 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -77,7 +77,7 @@ type MessageCookieReply struct { type Handshake struct { state int - mutex sync.Mutex + mutex sync.RWMutex hash [blake2s.Size]byte // hash value chainKey [blake2s.Size]byte // chain key presharedKey NoiseSymmetricKey // psk @@ -205,49 +205,64 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { } hash = mixHash(hash, msg.Static[:]) - // find peer + // lookup peer peer := device.LookupPeer(peerPK) if peer == nil { return nil } handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - // decrypt timestamp + // verify identity var timestamp TAI64N - func() { - var key [chacha20poly1305.KeySize]byte - chainKey, key = KDF2( - chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) - }() - if err != nil { - return nil - } - hash = mixHash(hash, msg.Timestamp[:]) + ok := func() bool { + + // read lock handshake + + handshake.mutex.RLock() + defer handshake.mutex.RUnlock() + + // decrypt timestamp + + func() { + var key [chacha20poly1305.KeySize]byte + chainKey, key = KDF2( + chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) + }() + if err != nil { + return false + } + hash = mixHash(hash, msg.Timestamp[:]) + + // TODO: check for flood attack + + // check for replay attack - // check for replay attack + return timestamp.After(handshake.lastTimestamp) + }() - if !timestamp.After(handshake.lastTimestamp) { + if !ok { return nil } - // TODO: check for flood attack - // update handshake state + handshake.mutex.Lock() + handshake.hash = hash handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender handshake.remoteEphemeral = msg.Ephemeral handshake.lastTimestamp = timestamp handshake.state = HandshakeInitiationConsumed + + handshake.mutex.Unlock() + return peer } @@ -320,47 +335,67 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { return nil } - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - if handshake.state != HandshakeInitiationCreated { - return nil - } + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) - // finish 3-way DH + ok := func() bool { - hash := mixHash(handshake.hash, msg.Ephemeral[:]) - chainKey := handshake.chainKey + // read lock handshake - func() { - ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - chainKey = mixKey(chainKey, ss[:]) - ss = device.privateKey.sharedSecret(msg.Ephemeral) - chainKey = mixKey(chainKey, ss[:]) - }() + handshake.mutex.RLock() + defer handshake.mutex.RUnlock() - // add preshared key (psk) + if handshake.state != HandshakeInitiationCreated { + return false + } - var tau [blake2s.Size]byte - var key [chacha20poly1305.KeySize]byte - chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) - hash = mixHash(hash, tau[:]) + // finish 3-way DH - // authenticate + hash = mixHash(handshake.hash, msg.Ephemeral[:]) + chainKey = handshake.chainKey - aead, _ := chacha20poly1305.New(key[:]) - _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) - if err != nil { + func() { + ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + chainKey = mixKey(chainKey, ss[:]) + ss = device.privateKey.sharedSecret(msg.Ephemeral) + chainKey = mixKey(chainKey, ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) + hash = mixHash(hash, tau[:]) + + // authenticate + + aead, _ := chacha20poly1305.New(key[:]) + _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + if err != nil { + return false + } + hash = mixHash(hash, msg.Empty[:]) + return true + }() + + if !ok { return nil } - hash = mixHash(hash, msg.Empty[:]) // update handshake state + handshake.mutex.Lock() + handshake.hash = hash handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender handshake.state = HandshakeResponseConsumed + handshake.mutex.Unlock() + return lookup.peer } |