diff options
Diffstat (limited to 'src/noise_protocol.go')
-rw-r--r-- | src/noise_protocol.go | 210 |
1 files changed, 154 insertions, 56 deletions
diff --git a/src/noise_protocol.go b/src/noise_protocol.go index e7c8774..b9c8981 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -56,18 +56,22 @@ type MessageTransport struct { } type Handshake struct { - lock sync.Mutex - state int - chainKey [blake2s.Size]byte // chain key - hash [blake2s.Size]byte // hash value - staticStatic NoisePublicKey // precomputed DH(S_i, S_r) - ephemeral NoisePrivateKey // ephemeral secret key - remoteIndex uint32 // index for sending - device *Device - peer *Peer + state int + mutex sync.Mutex + hash [blake2s.Size]byte // hash value + chainKey [blake2s.Size]byte // chain key + presharedKey NoiseSymmetricKey // psk + localEphemeral NoisePrivateKey // ephemeral secret key + localIndex uint32 // used to clear hash-table + remoteIndex uint32 // index for sending + remoteStatic NoisePublicKey // long term key + remoteEphemeral NoisePublicKey // ephemeral public key + precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret + lastTimestamp TAI64N } var ( + EmptyMessage []byte ZeroNonce [chacha20poly1305.NonceSize]byte InitalChainKey [blake2s.Size]byte InitalHash [blake2s.Size]byte @@ -78,102 +82,196 @@ func init() { InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...)) } -func (h *Handshake) Precompute() { - h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey) -} - -func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) { - -} - -func (h *Handshake) addHash(data []byte) { +func (h *Handshake) addToHash(data []byte) { h.hash = addToHash(h.hash, data) } -func (h *Handshake) addChain(data []byte) { +func (h *Handshake) addToChainKey(data []byte) { h.chainKey = addToChainKey(h.chainKey, data) } -func (h *Handshake) CreateMessageInital() (*MessageInital, error) { - h.lock.Lock() - defer h.lock.Unlock() +func (device *Device) Precompute(peer *Peer) { + h := &peer.handshake + h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) +} + +func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) { + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() - // reset handshake + // create ephemeral key var err error - h.ephemeral, err = newPrivateKey() + handshake.chainKey = InitalChainKey + handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:]) + handshake.localEphemeral, err = newPrivateKey() if err != nil { return nil, err } - h.chainKey = InitalChainKey - h.hash = addToHash(InitalHash, h.device.publicKey[:]) - // create ephemeral key + // assign index var msg MessageInital + msg.Type = MessageInitalType - msg.Sender = h.device.NewID(h) - msg.Ephemeral = h.ephemeral.publicKey() - h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:]) - h.hash = addToHash(h.hash, msg.Ephemeral[:]) + msg.Ephemeral = handshake.localEphemeral.publicKey() + msg.Sender, err = device.indices.NewIndex(handshake) + + if err != nil { + return nil, err + } + + handshake.addToChainKey(msg.Ephemeral[:]) + handshake.addToHash(msg.Ephemeral[:]) // encrypt long-term "identity key" func() { var key [chacha20poly1305.KeySize]byte - ss := h.ephemeral.sharedSecret(h.peer.publicKey) - h.chainKey, key = KDF2(h.chainKey[:], ss[:]) + ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:]) aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:]) }() - h.addHash(msg.Static[:]) + handshake.addToHash(msg.Static[:]) // encrypt timestamp timestamp := Timestamp() func() { var key [chacha20poly1305.KeySize]byte - h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:]) + handshake.chainKey, key = KDF2( + handshake.chainKey[:], + handshake.precomputedStaticStatic[:], + ) aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil) + aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) }() - h.addHash(msg.Timestamp[:]) - h.state = HandshakeInitialCreated + + handshake.addToHash(msg.Timestamp[:]) + handshake.state = HandshakeInitialCreated + return &msg, nil } -func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error { +func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer { if msg.Type != MessageInitalType { panic(errors.New("bug: invalid inital message type")) } - hash := addToHash(InitalHash, h.device.publicKey[:]) - chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:]) + hash := addToHash(InitalHash, device.publicKey[:]) hash = addToHash(hash, msg.Ephemeral[:]) + chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:]) - // + // decrypt identity key - ephemeral, err := newPrivateKey() + var err error + var peerPK NoisePublicKey + func() { + var key [chacha20poly1305.KeySize]byte + ss := device.privateKey.sharedSecret(msg.Ephemeral) + chainKey, key = KDF2(chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) + }() if err != nil { - return err + return nil } + hash = addToHash(hash, msg.Static[:]) - // update handshake state + // find peer + + peer := device.LookupPeer(peerPK) + if peer == nil { + return nil + } + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + // decrypt timestamp + + var timestamp TAI64N + func() { + var key [chacha20poly1305.KeySize]byte + chainKey, key = KDF2( + chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) + }() + if err != nil { + return nil + } + hash = addToHash(hash, msg.Timestamp[:]) - h.lock.Lock() - defer h.lock.Unlock() + // check for replay attack - h.hash = hash - h.chainKey = chainKey - h.remoteIndex = msg.Sender - h.ephemeral = ephemeral - h.state = HandshakeInitialConsumed + if !timestamp.After(handshake.lastTimestamp) { + return nil + } + + // check for flood attack - return nil + // update handshake state + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.remoteEphemeral = msg.Ephemeral + handshake.state = HandshakeInitialConsumed + return peer } -func (h *Handshake) CreateMessageResponse() []byte { +func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + if handshake.state != HandshakeInitialConsumed { + panic(errors.New("bug: handshake initation must be consumed first")) + } + + // assign index + + var err error + var msg MessageResponse + msg.Type = MessageResponseType + msg.Sender, err = device.indices.NewIndex(handshake) + msg.Reciever = handshake.remoteIndex + if err != nil { + return nil, err + } - return nil + // create ephemeral key + + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + msg.Ephemeral = handshake.localEphemeral.publicKey() + + func() { + ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + handshake.addToChainKey(ss[:]) + ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + handshake.addToChainKey(ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:]) + handshake.addToHash(tau[:]) + + func() { + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:]) + handshake.addToHash(msg.Empty[:]) + }() + + return &msg, nil } |