diff options
Diffstat (limited to 'device/noise-protocol.go')
-rw-r--r-- | device/noise-protocol.go | 600 |
1 files changed, 600 insertions, 0 deletions
diff --git a/device/noise-protocol.go b/device/noise-protocol.go new file mode 100644 index 0000000..73826e1 --- /dev/null +++ b/device/noise-protocol.go @@ -0,0 +1,600 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/poly1305" + "golang.zx2c4.com/wireguard/tai64n" + "sync" + "time" +) + +const ( + HandshakeZeroed = iota + HandshakeInitiationCreated + HandshakeInitiationConsumed + HandshakeResponseCreated + HandshakeResponseConsumed +) + +const ( + NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" + WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" + WGLabelMAC1 = "mac1----" + WGLabelCookie = "cookie--" +) + +const ( + MessageInitiationType = 1 + MessageResponseType = 2 + MessageCookieReplyType = 3 + MessageTransportType = 4 +) + +const ( + MessageInitiationSize = 148 // size of handshake initation message + MessageResponseSize = 92 // size of response message + MessageCookieReplySize = 64 // size of cookie reply message + MessageTransportHeaderSize = 16 // size of data preceeding content in transport message + MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport + MessageKeepaliveSize = MessageTransportSize // size of keepalive + MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message +) + +const ( + MessageTransportOffsetReceiver = 4 + MessageTransportOffsetCounter = 8 + MessageTransportOffsetContent = 16 +) + +/* Type is an 8-bit field, followed by 3 nul bytes, + * by marshalling the messages in little-endian byteorder + * we can treat these as a 32-bit unsigned int (for now) + * + */ + +type MessageInitiation struct { + Type uint32 + Sender uint32 + Ephemeral NoisePublicKey + Static [NoisePublicKeySize + poly1305.TagSize]byte + Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte + MAC1 [blake2s.Size128]byte + MAC2 [blake2s.Size128]byte +} + +type MessageResponse struct { + Type uint32 + Sender uint32 + Receiver uint32 + Ephemeral NoisePublicKey + Empty [poly1305.TagSize]byte + MAC1 [blake2s.Size128]byte + MAC2 [blake2s.Size128]byte +} + +type MessageTransport struct { + Type uint32 + Receiver uint32 + Counter uint64 + Content []byte +} + +type MessageCookieReply struct { + Type uint32 + Receiver uint32 + Nonce [chacha20poly1305.NonceSizeX]byte + Cookie [blake2s.Size128 + poly1305.TagSize]byte +} + +type Handshake struct { + state int + mutex sync.RWMutex + hash [blake2s.Size]byte // hash value + chainKey [blake2s.Size]byte // chain key + presharedKey NoiseSymmetricKey // psk + localEphemeral NoisePrivateKey // ephemeral secret key + localIndex uint32 // used to clear hash-table + remoteIndex uint32 // index for sending + remoteStatic NoisePublicKey // long term key + remoteEphemeral NoisePublicKey // ephemeral public key + precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret + lastTimestamp tai64n.Timestamp + lastInitiationConsumption time.Time + lastSentHandshake time.Time +} + +var ( + InitialChainKey [blake2s.Size]byte + InitialHash [blake2s.Size]byte + ZeroNonce [chacha20poly1305.NonceSize]byte +) + +func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { + KDF1(dst, c[:], data) +} + +func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { + hash, _ := blake2s.New256(nil) + hash.Write(h[:]) + hash.Write(data) + hash.Sum(dst[:0]) + hash.Reset() +} + +func (h *Handshake) Clear() { + setZero(h.localEphemeral[:]) + setZero(h.remoteEphemeral[:]) + setZero(h.chainKey[:]) + setZero(h.hash[:]) + h.localIndex = 0 + h.state = HandshakeZeroed +} + +func (h *Handshake) mixHash(data []byte) { + mixHash(&h.hash, &h.hash, data) +} + +func (h *Handshake) mixKey(data []byte) { + mixKey(&h.chainKey, &h.chainKey, data) +} + +/* Do basic precomputations + */ +func init() { + InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) + mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) +} + +func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + if isZero(handshake.precomputedStaticStatic[:]) { + return nil, errors.New("static shared secret is zero") + } + + // create ephemeral key + + var err error + handshake.hash = InitialHash + handshake.chainKey = InitialChainKey + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + + // assign index + + device.indexTable.Delete(handshake.localIndex) + handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) + + if err != nil { + return nil, err + } + + handshake.mixHash(handshake.remoteStatic[:]) + + msg := MessageInitiation{ + Type: MessageInitiationType, + Ephemeral: handshake.localEphemeral.publicKey(), + Sender: handshake.localIndex, + } + + handshake.mixKey(msg.Ephemeral[:]) + handshake.mixHash(msg.Ephemeral[:]) + + // encrypt static key + + func() { + var key [chacha20poly1305.KeySize]byte + ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + ss[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) + }() + handshake.mixHash(msg.Static[:]) + + // encrypt timestamp + + timestamp := tai64n.Now() + func() { + var key [chacha20poly1305.KeySize]byte + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) + }() + + handshake.mixHash(msg.Timestamp[:]) + handshake.state = HandshakeInitiationCreated + return &msg, nil +} + +func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) + + if msg.Type != MessageInitiationType { + return nil + } + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) + mixHash(&hash, &hash, msg.Ephemeral[:]) + mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) + + // decrypt static key + + var err error + var peerPK NoisePublicKey + func() { + var key [chacha20poly1305.KeySize]byte + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + KDF2(&chainKey, &key, chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) + }() + if err != nil { + return nil + } + mixHash(&hash, &hash, msg.Static[:]) + + // lookup peer + + peer := device.LookupPeer(peerPK) + if peer == nil { + return nil + } + + handshake := &peer.handshake + if isZero(handshake.precomputedStaticStatic[:]) { + return nil + } + + // verify identity + + var timestamp tai64n.Timestamp + var key [chacha20poly1305.KeySize]byte + + handshake.mutex.RLock() + KDF2( + &chainKey, + &key, + chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) + if err != nil { + handshake.mutex.RUnlock() + return nil + } + mixHash(&hash, &hash, msg.Timestamp[:]) + + // protect against replay & flood + + var ok bool + ok = timestamp.After(handshake.lastTimestamp) + ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate + handshake.mutex.RUnlock() + if !ok { + return nil + } + + // update handshake state + + handshake.mutex.Lock() + + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.remoteEphemeral = msg.Ephemeral + handshake.lastTimestamp = timestamp + handshake.lastInitiationConsumption = time.Now() + handshake.state = HandshakeInitiationConsumed + + handshake.mutex.Unlock() + + setZero(hash[:]) + setZero(chainKey[:]) + + return peer +} + +func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + if handshake.state != HandshakeInitiationConsumed { + return nil, errors.New("handshake initiation must be consumed first") + } + + // assign index + + var err error + device.indexTable.Delete(handshake.localIndex) + handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) + if err != nil { + return nil, err + } + + var msg MessageResponse + msg.Type = MessageResponseType + msg.Sender = handshake.localIndex + msg.Receiver = handshake.remoteIndex + + // create ephemeral key + + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + msg.Ephemeral = handshake.localEphemeral.publicKey() + handshake.mixHash(msg.Ephemeral[:]) + handshake.mixKey(msg.Ephemeral[:]) + + func() { + ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + handshake.mixKey(ss[:]) + ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + handshake.mixKey(ss[:]) + }() + + // add preshared key + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + + KDF3( + &handshake.chainKey, + &tau, + &key, + handshake.chainKey[:], + handshake.presharedKey[:], + ) + + handshake.mixHash(tau[:]) + + func() { + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) + handshake.mixHash(msg.Empty[:]) + }() + + handshake.state = HandshakeResponseCreated + + return &msg, nil +} + +func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { + if msg.Type != MessageResponseType { + return nil + } + + // lookup handshake by receiver + + lookup := device.indexTable.Lookup(msg.Receiver) + handshake := lookup.handshake + if handshake == nil { + return nil + } + + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) + + ok := func() bool { + + // lock handshake state + + handshake.mutex.RLock() + defer handshake.mutex.RUnlock() + + if handshake.state != HandshakeInitiationCreated { + return false + } + + // lock private key for reading + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + // finish 3-way DH + + mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) + mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) + + func() { + ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) + }() + + func() { + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + KDF3( + &chainKey, + &tau, + &key, + chainKey[:], + handshake.presharedKey[:], + ) + mixHash(&hash, &hash, tau[:]) + + // authenticate transcript + + aead, _ := chacha20poly1305.New(key[:]) + _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + if err != nil { + return false + } + mixHash(&hash, &hash, msg.Empty[:]) + return true + }() + + if !ok { + return nil + } + + // update handshake state + + handshake.mutex.Lock() + + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.state = HandshakeResponseConsumed + + handshake.mutex.Unlock() + + setZero(hash[:]) + setZero(chainKey[:]) + + return lookup.peer +} + +/* Derives a new keypair from the current handshake state + * + */ +func (peer *Peer) BeginSymmetricSession() error { + device := peer.device + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + // derive keys + + var isInitiator bool + var sendKey [chacha20poly1305.KeySize]byte + var recvKey [chacha20poly1305.KeySize]byte + + if handshake.state == HandshakeResponseConsumed { + KDF2( + &sendKey, + &recvKey, + handshake.chainKey[:], + nil, + ) + isInitiator = true + } else if handshake.state == HandshakeResponseCreated { + KDF2( + &recvKey, + &sendKey, + handshake.chainKey[:], + nil, + ) + isInitiator = false + } else { + return errors.New("invalid state for keypair derivation") + } + + // zero handshake + + setZero(handshake.chainKey[:]) + setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. + setZero(handshake.localEphemeral[:]) + peer.handshake.state = HandshakeZeroed + + // create AEAD instances + + keypair := new(Keypair) + keypair.send, _ = chacha20poly1305.New(sendKey[:]) + keypair.receive, _ = chacha20poly1305.New(recvKey[:]) + + setZero(sendKey[:]) + setZero(recvKey[:]) + + keypair.created = time.Now() + keypair.sendNonce = 0 + keypair.replayFilter.Init() + keypair.isInitiator = isInitiator + keypair.localIndex = peer.handshake.localIndex + keypair.remoteIndex = peer.handshake.remoteIndex + + // remap index + + device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) + handshake.localIndex = 0 + + // rotate key pairs + + keypairs := &peer.keypairs + keypairs.Lock() + defer keypairs.Unlock() + + previous := keypairs.previous + next := keypairs.next + current := keypairs.current + + if isInitiator { + if next != nil { + keypairs.next = nil + keypairs.previous = next + device.DeleteKeypair(current) + } else { + keypairs.previous = current + } + device.DeleteKeypair(previous) + keypairs.current = keypair + } else { + keypairs.next = keypair + device.DeleteKeypair(next) + keypairs.previous = nil + device.DeleteKeypair(previous) + } + + return nil +} + +func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { + keypairs := &peer.keypairs + if keypairs.next != receivedKeypair { + return false + } + keypairs.Lock() + defer keypairs.Unlock() + if keypairs.next != receivedKeypair { + return false + } + old := keypairs.previous + keypairs.previous = keypairs.current + peer.device.DeleteKeypair(old) + keypairs.current = keypairs.next + keypairs.next = nil + return true +} |