From 0294a5c0dd753786996e62236b7d8d524201ace4 Mon Sep 17 00:00:00 2001 From: Mathias Hall-Andersen Date: Fri, 1 Sep 2017 14:21:53 +0200 Subject: Improved handling of key-material --- src/noise_protocol.go | 136 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 97 insertions(+), 39 deletions(-) (limited to 'src/noise_protocol.go') diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 1f1301e..a50e3dc 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -109,27 +109,31 @@ var ( ZeroNonce [chacha20poly1305.NonceSize]byte ) -func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { - return KDF1(c[:], data) +func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { + KDF1(dst, c[:], data) } -func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { - return blake2s.Sum256(append(h[:], data...)) +func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { + hsh, _ := blake2s.New256(nil) + hsh.Write(h[:]) + hsh.Write(data) + hsh.Sum(dst[:0]) + hsh.Reset() } func (h *Handshake) mixHash(data []byte) { - h.hash = mixHash(h.hash, data) + mixHash(&h.hash, &h.hash, data) } func (h *Handshake) mixKey(data []byte) { - h.chainKey = mixKey(h.chainKey, data) + mixKey(&h.chainKey, &h.chainKey, data) } /* Do basic precomputations */ func init() { InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) - InitialHash = mixHash(InitialChainKey, []byte(WGIdentifier)) + mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { @@ -176,7 +180,12 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e func() { var key [chacha20poly1305.KeySize]byte ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) - handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:]) + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + ss[:], + ) aead, _ := chacha20poly1305.New(key[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) }() @@ -187,7 +196,9 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e timestamp := Timestamp() func() { var key [chacha20poly1305.KeySize]byte - handshake.chainKey, key = KDF2( + KDF2( + &handshake.chainKey, + &key, handshake.chainKey[:], handshake.precomputedStaticStatic[:], ) @@ -197,7 +208,6 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e handshake.mixHash(msg.Timestamp[:]) handshake.state = HandshakeInitiationCreated - return &msg, nil } @@ -206,9 +216,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { return nil } - hash := mixHash(InitialHash, device.publicKey[:]) - hash = mixHash(hash, msg.Ephemeral[:]) - chainKey := mixKey(InitialChainKey, msg.Ephemeral[:]) + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) + + mixHash(&hash, &InitialHash, device.publicKey[:]) + mixHash(&hash, &hash, msg.Ephemeral[:]) + mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) // decrypt static key @@ -217,14 +232,14 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { func() { var key [chacha20poly1305.KeySize]byte ss := device.privateKey.sharedSecret(msg.Ephemeral) - chainKey, key = KDF2(chainKey[:], ss[:]) + KDF2(&chainKey, &key, chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(key[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) }() if err != nil { return nil } - hash = mixHash(hash, msg.Static[:]) + mixHash(&hash, &hash, msg.Static[:]) // lookup peer @@ -244,7 +259,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { var key [chacha20poly1305.KeySize]byte handshake.mutex.RLock() - chainKey, key = KDF2( + KDF2( + &chainKey, + &key, chainKey[:], handshake.precomputedStaticStatic[:], ) @@ -254,7 +271,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { handshake.mutex.RUnlock() return nil } - hash = mixHash(hash, msg.Timestamp[:]) + mixHash(&hash, &hash, msg.Timestamp[:]) // protect against replay & flood @@ -327,7 +344,15 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte - handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:]) + + KDF3( + &handshake.chainKey, + &tau, + &key, + handshake.chainKey[:], + handshake.presharedKey[:], + ) + handshake.mixHash(tau[:]) func() { @@ -337,6 +362,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error }() handshake.state = HandshakeResponseCreated + return &msg, nil } @@ -371,22 +397,33 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // finish 3-way DH - hash = mixHash(handshake.hash, msg.Ephemeral[:]) - chainKey = mixKey(handshake.chainKey, msg.Ephemeral[:]) + mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) + mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) func() { ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - chainKey = mixKey(chainKey, ss[:]) - ss = device.privateKey.sharedSecret(msg.Ephemeral) - chainKey = mixKey(chainKey, ss[:]) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) + }() + + func() { + ss := device.privateKey.sharedSecret(msg.Ephemeral) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) }() // add preshared key (psk) var tau [blake2s.Size]byte var key [chacha20poly1305.KeySize]byte - chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) - hash = mixHash(hash, tau[:]) + KDF3( + &chainKey, + &tau, + &key, + chainKey[:], + handshake.presharedKey[:], + ) + mixHash(&hash, &hash, tau[:]) // authenticate @@ -396,7 +433,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { device.log.Debug.Println("failed to open") return false } - hash = mixHash(hash, msg.Empty[:]) + mixHash(&hash, &hash, msg.Empty[:]) return true }() @@ -415,6 +452,9 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { handshake.mutex.Unlock() + setZero(hash[:]) + setZero(chainKey[:]) + return lookup.peer } @@ -422,6 +462,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { * */ func (peer *Peer) NewKeyPair() *KeyPair { + device := peer.device handshake := &peer.handshake handshake.mutex.Lock() defer handshake.mutex.Unlock() @@ -433,10 +474,20 @@ func (peer *Peer) NewKeyPair() *KeyPair { var recvKey [chacha20poly1305.KeySize]byte if handshake.state == HandshakeResponseConsumed { - sendKey, recvKey = KDF2(handshake.chainKey[:], nil) + KDF2( + &sendKey, + &recvKey, + handshake.chainKey[:], + nil, + ) isInitiator = true } else if handshake.state == HandshakeResponseCreated { - recvKey, sendKey = KDF2(handshake.chainKey[:], nil) + KDF2( + &recvKey, + &sendKey, + handshake.chainKey[:], + nil, + ) isInitiator = false } else { return nil @@ -444,16 +495,20 @@ func (peer *Peer) NewKeyPair() *KeyPair { // zero handshake - handshake.chainKey = [blake2s.Size]byte{} - handshake.localEphemeral = NoisePrivateKey{} + setZero(handshake.chainKey[:]) + setZero(handshake.localEphemeral[:]) peer.handshake.state = HandshakeZeroed // create AEAD instances keyPair := new(KeyPair) + keyPair.send.setKey(&sendKey) + keyPair.receive.setKey(&recvKey) + + setZero(sendKey[:]) + setZero(recvKey[:]) + keyPair.created = time.Now() - keyPair.send, _ = chacha20poly1305.New(sendKey[:]) - keyPair.receive, _ = chacha20poly1305.New(recvKey[:]) keyPair.sendNonce = 0 keyPair.replayFilter.Init() keyPair.isInitiator = isInitiator @@ -462,12 +517,14 @@ func (peer *Peer) NewKeyPair() *KeyPair { // remap index - indices := &peer.device.indices - indices.Insert(handshake.localIndex, IndexTableEntry{ - peer: peer, - keyPair: keyPair, - handshake: nil, - }) + device.indices.Insert( + handshake.localIndex, + IndexTableEntry{ + peer: peer, + keyPair: keyPair, + handshake: nil, + }, + ) handshake.localIndex = 0 // rotate key pairs @@ -479,7 +536,8 @@ func (peer *Peer) NewKeyPair() *KeyPair { // TODO: Adapt kernel behavior noise.c:161 if isInitiator { if kp.previous != nil { - indices.Delete(kp.previous.localIndex) + device.DeleteKeyPair(kp.previous) + kp.previous = nil } if kp.next != nil { -- cgit v1.2.3