diff options
Diffstat (limited to 'noise-protocol.go')
-rw-r--r-- | noise-protocol.go | 67 |
1 files changed, 33 insertions, 34 deletions
diff --git a/noise-protocol.go b/noise-protocol.go index f72dcc4..ffc2b50 100644 --- a/noise-protocol.go +++ b/noise-protocol.go @@ -107,6 +107,7 @@ type Handshake struct { precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret lastTimestamp tai64n.Timestamp lastInitiationConsumption time.Time + lastSentHandshake time.Time } var ( @@ -153,8 +154,8 @@ func init() { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() + device.staticIdentity.mutex.RLock() + defer device.staticIdentity.mutex.RUnlock() handshake := &peer.handshake handshake.mutex.Lock() @@ -206,7 +207,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e ss[:], ) aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) }() handshake.mixHash(msg.Static[:]) @@ -240,10 +241,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { return nil } - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() + device.staticIdentity.mutex.RLock() + defer device.staticIdentity.mutex.RUnlock() - mixHash(&hash, &InitialHash, device.noise.publicKey[:]) + mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) mixHash(&hash, &hash, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) @@ -253,7 +254,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { var peerPK NoisePublicKey func() { var key [chacha20poly1305.KeySize]byte - ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) KDF2(&chainKey, &key, chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(key[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) @@ -422,8 +423,8 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // lock private key for reading - device.noise.mutex.RLock() - defer device.noise.mutex.RUnlock() + device.staticIdentity.mutex.RLock() + defer device.staticIdentity.mutex.RUnlock() // finish 3-way DH @@ -437,7 +438,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { }() func() { - ss := device.noise.privateKey.sharedSecret(msg.Ephemeral) + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) }() @@ -490,7 +491,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { /* Derives a new keypair from the current handshake state * */ -func (peer *Peer) DeriveNewKeypair() error { +func (peer *Peer) BeginSymmetricSession() error { device := peer.device handshake := &peer.handshake handshake.mutex.Lock() @@ -552,50 +553,48 @@ func (peer *Peer) DeriveNewKeypair() error { // rotate key pairs - kp := &peer.keypairs - kp.mutex.Lock() + keypairs := &peer.keypairs + keypairs.mutex.Lock() + defer keypairs.mutex.Unlock() - peer.timersSessionDerived() - - previous := kp.previous - next := kp.next - current := kp.current + previous := keypairs.previous + next := keypairs.next + current := keypairs.current if isInitiator { if next != nil { - kp.next = nil - kp.previous = next + keypairs.next = nil + keypairs.previous = next device.DeleteKeypair(current) } else { - kp.previous = current + keypairs.previous = current } device.DeleteKeypair(previous) - kp.current = keypair + keypairs.current = keypair } else { - kp.next = keypair + keypairs.next = keypair device.DeleteKeypair(next) - kp.previous = nil + keypairs.previous = nil device.DeleteKeypair(previous) } - kp.mutex.Unlock() return nil } func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { - kp := &peer.keypairs - if kp.next != receivedKeypair { + keypairs := &peer.keypairs + if keypairs.next != receivedKeypair { return false } - kp.mutex.Lock() - defer kp.mutex.Unlock() - if kp.next != receivedKeypair { + keypairs.mutex.Lock() + defer keypairs.mutex.Unlock() + if keypairs.next != receivedKeypair { return false } - old := kp.previous - kp.previous = kp.current + old := keypairs.previous + keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - kp.current = kp.next - kp.next = nil + keypairs.current = keypairs.next + keypairs.next = nil return true } |