diff options
Diffstat (limited to 'src/noise_protocol.go')
-rw-r--r-- | src/noise_protocol.go | 125 |
1 files changed, 111 insertions, 14 deletions
diff --git a/src/noise_protocol.go b/src/noise_protocol.go index b9c8981..7f26cf1 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -9,9 +9,11 @@ import ( ) const ( - HandshakeInitialCreated = iota + HandshakeReset = iota + HandshakeInitialCreated HandshakeInitialConsumed HandshakeResponseCreated + HandshakeResponseConsumed ) const ( @@ -71,7 +73,6 @@ type Handshake struct { } var ( - EmptyMessage []byte ZeroNonce [chacha20poly1305.NonceSize]byte InitalChainKey [blake2s.Size]byte InitalHash [blake2s.Size]byte @@ -82,6 +83,14 @@ func init() { InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) } +func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte { + return KDF1(c[:], data) +} + +func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte { + return blake2s.Sum256(append(h[:], data...)) +} + func (h *Handshake) addToHash(data []byte) { h.hash = addToHash(h.hash, data) } @@ -90,11 +99,6 @@ func (h *Handshake) addToChainKey(data []byte) { h.chainKey = addToChainKey(h.chainKey, data) } -func (device *Device) Precompute(peer *Peer) { - h := &peer.handshake - h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) -} - func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { handshake := &peer.handshake handshake.mutex.Lock() @@ -116,16 +120,17 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { msg.Type = MessageInitalType msg.Ephemeral = handshake.localEphemeral.publicKey() - msg.Sender, err = device.indices.NewIndex(handshake) + handshake.localIndex, err = device.indices.NewIndex(peer) if err != nil { return nil, err } + msg.Sender = handshake.localIndex handshake.addToChainKey(msg.Ephemeral[:]) handshake.addToHash(msg.Ephemeral[:]) - // encrypt long-term "identity key" + // encrypt identity key func() { var key [chacha20poly1305.KeySize]byte @@ -221,6 +226,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender handshake.remoteEphemeral = msg.Ephemeral + handshake.lastTimestamp = timestamp handshake.state = HandshakeInitialConsumed return peer } @@ -237,14 +243,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error // assign index var err error - var msg MessageResponse - msg.Type = MessageResponseType - msg.Sender, err = device.indices.NewIndex(handshake) - msg.Reciever = handshake.remoteIndex + handshake.localIndex, err = device.indices.NewIndex(peer) if err != nil { return nil, err } + var msg MessageResponse + msg.Type = MessageResponseType + msg.Sender = handshake.localIndex + msg.Reciever = handshake.remoteIndex + // create ephemeral key handshake.localEphemeral, err = newPrivateKey() @@ -252,6 +260,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error return nil, err } msg.Ephemeral = handshake.localEphemeral.publicKey() + handshake.addToHash(msg.Ephemeral[:]) func() { ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) @@ -269,9 +278,97 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error func() { aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) handshake.addToHash(msg.Empty[:]) }() + handshake.state = HandshakeResponseCreated return &msg, nil } + +func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { + if msg.Type != MessageResponseType { + panic(errors.New("bug: invalid message type")) + } + + // lookup handshake by reciever + + peer := device.indices.LookupHandshake(msg.Reciever) + if peer == nil { + return nil + } + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + if handshake.state != HandshakeInitialCreated { + return nil + } + + // finish 3-way DH + + hash := addToHash(handshake.hash, msg.Ephemeral[:]) + chainKey := handshake.chainKey + + func() { + ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + chainKey = addToChainKey(chainKey, ss[:]) + ss = device.privateKey.sharedSecret(msg.Ephemeral) + chainKey = addToChainKey(chainKey, ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) + hash = addToHash(hash, tau[:]) + + // authenticate + + aead, _ := chacha20poly1305.New(key[:]) + _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + if err != nil { + return nil + } + hash = addToHash(hash, msg.Empty[:]) + + // update handshake state + + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.state = HandshakeResponseConsumed + + return peer +} + +func (peer *Peer) NewKeyPair() *KeyPair { + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + // derive keys + + var sendKey [chacha20poly1305.KeySize]byte + var recvKey [chacha20poly1305.KeySize]byte + + if handshake.state == HandshakeResponseConsumed { + sendKey, recvKey = KDF2(handshake.chainKey[:], nil) + } else if handshake.state == HandshakeResponseCreated { + recvKey, sendKey = KDF2(handshake.chainKey[:], nil) + } else { + return nil + } + + // create AEAD instances + + var keyPair KeyPair + keyPair.send, _ = chacha20poly1305.New(sendKey[:]) + keyPair.recv, _ = chacha20poly1305.New(recvKey[:]) + keyPair.sendNonce = 0 + keyPair.recvNonce = 0 + + peer.handshake.state = HandshakeReset + + return &keyPair +} |