summaryrefslogtreecommitdiffhomepage
path: root/src/noise_protocol.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/noise_protocol.go')
-rw-r--r--src/noise_protocol.go210
1 files changed, 154 insertions, 56 deletions
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index e7c8774..b9c8981 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -56,18 +56,22 @@ type MessageTransport struct {
}
type Handshake struct {
- lock sync.Mutex
- state int
- chainKey [blake2s.Size]byte // chain key
- hash [blake2s.Size]byte // hash value
- staticStatic NoisePublicKey // precomputed DH(S_i, S_r)
- ephemeral NoisePrivateKey // ephemeral secret key
- remoteIndex uint32 // index for sending
- device *Device
- peer *Peer
+ state int
+ mutex sync.Mutex
+ hash [blake2s.Size]byte // hash value
+ chainKey [blake2s.Size]byte // chain key
+ presharedKey NoiseSymmetricKey // psk
+ localEphemeral NoisePrivateKey // ephemeral secret key
+ localIndex uint32 // used to clear hash-table
+ remoteIndex uint32 // index for sending
+ remoteStatic NoisePublicKey // long term key
+ remoteEphemeral NoisePublicKey // ephemeral public key
+ precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
+ lastTimestamp TAI64N
}
var (
+ EmptyMessage []byte
ZeroNonce [chacha20poly1305.NonceSize]byte
InitalChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte
@@ -78,102 +82,196 @@ func init() {
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
}
-func (h *Handshake) Precompute() {
- h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey)
-}
-
-func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) {
-
-}
-
-func (h *Handshake) addHash(data []byte) {
+func (h *Handshake) addToHash(data []byte) {
h.hash = addToHash(h.hash, data)
}
-func (h *Handshake) addChain(data []byte) {
+func (h *Handshake) addToChainKey(data []byte) {
h.chainKey = addToChainKey(h.chainKey, data)
}
-func (h *Handshake) CreateMessageInital() (*MessageInital, error) {
- h.lock.Lock()
- defer h.lock.Unlock()
+func (device *Device) Precompute(peer *Peer) {
+ h := &peer.handshake
+ h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+}
+
+func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
- // reset handshake
+ // create ephemeral key
var err error
- h.ephemeral, err = newPrivateKey()
+ handshake.chainKey = InitalChainKey
+ handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:])
+ handshake.localEphemeral, err = newPrivateKey()
if err != nil {
return nil, err
}
- h.chainKey = InitalChainKey
- h.hash = addToHash(InitalHash, h.device.publicKey[:])
- // create ephemeral key
+ // assign index
var msg MessageInital
+
msg.Type = MessageInitalType
- msg.Sender = h.device.NewID(h)
- msg.Ephemeral = h.ephemeral.publicKey()
- h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:])
- h.hash = addToHash(h.hash, msg.Ephemeral[:])
+ msg.Ephemeral = handshake.localEphemeral.publicKey()
+ msg.Sender, err = device.indices.NewIndex(handshake)
+
+ if err != nil {
+ return nil, err
+ }
+
+ handshake.addToChainKey(msg.Ephemeral[:])
+ handshake.addToHash(msg.Ephemeral[:])
// encrypt long-term "identity key"
func() {
var key [chacha20poly1305.KeySize]byte
- ss := h.ephemeral.sharedSecret(h.peer.publicKey)
- h.chainKey, key = KDF2(h.chainKey[:], ss[:])
+ ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil)
+ aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
}()
- h.addHash(msg.Static[:])
+ handshake.addToHash(msg.Static[:])
// encrypt timestamp
timestamp := Timestamp()
func() {
var key [chacha20poly1305.KeySize]byte
- h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:])
+ handshake.chainKey, key = KDF2(
+ handshake.chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil)
+ aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}()
- h.addHash(msg.Timestamp[:])
- h.state = HandshakeInitialCreated
+
+ handshake.addToHash(msg.Timestamp[:])
+ handshake.state = HandshakeInitialCreated
+
return &msg, nil
}
-func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error {
+func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
if msg.Type != MessageInitalType {
panic(errors.New("bug: invalid inital message type"))
}
- hash := addToHash(InitalHash, h.device.publicKey[:])
- chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
+ hash := addToHash(InitalHash, device.publicKey[:])
hash = addToHash(hash, msg.Ephemeral[:])
+ chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
- //
+ // decrypt identity key
- ephemeral, err := newPrivateKey()
+ var err error
+ var peerPK NoisePublicKey
+ func() {
+ var key [chacha20poly1305.KeySize]byte
+ ss := device.privateKey.sharedSecret(msg.Ephemeral)
+ chainKey, key = KDF2(chainKey[:], ss[:])
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
+ }()
if err != nil {
- return err
+ return nil
}
+ hash = addToHash(hash, msg.Static[:])
- // update handshake state
+ // find peer
+
+ peer := device.LookupPeer(peerPK)
+ if peer == nil {
+ return nil
+ }
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
+
+ // decrypt timestamp
+
+ var timestamp TAI64N
+ func() {
+ var key [chacha20poly1305.KeySize]byte
+ chainKey, key = KDF2(
+ chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
+ }()
+ if err != nil {
+ return nil
+ }
+ hash = addToHash(hash, msg.Timestamp[:])
- h.lock.Lock()
- defer h.lock.Unlock()
+ // check for replay attack
- h.hash = hash
- h.chainKey = chainKey
- h.remoteIndex = msg.Sender
- h.ephemeral = ephemeral
- h.state = HandshakeInitialConsumed
+ if !timestamp.After(handshake.lastTimestamp) {
+ return nil
+ }
+
+ // check for flood attack
- return nil
+ // update handshake state
+ handshake.hash = hash
+ handshake.chainKey = chainKey
+ handshake.remoteIndex = msg.Sender
+ handshake.remoteEphemeral = msg.Ephemeral
+ handshake.state = HandshakeInitialConsumed
+ return peer
}
-func (h *Handshake) CreateMessageResponse() []byte {
+func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
+
+ if handshake.state != HandshakeInitialConsumed {
+ panic(errors.New("bug: handshake initation must be consumed first"))
+ }
+
+ // assign index
+
+ var err error
+ var msg MessageResponse
+ msg.Type = MessageResponseType
+ msg.Sender, err = device.indices.NewIndex(handshake)
+ msg.Reciever = handshake.remoteIndex
+ if err != nil {
+ return nil, err
+ }
- return nil
+ // create ephemeral key
+
+ handshake.localEphemeral, err = newPrivateKey()
+ if err != nil {
+ return nil, err
+ }
+ msg.Ephemeral = handshake.localEphemeral.publicKey()
+
+ func() {
+ ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ handshake.addToChainKey(ss[:])
+ ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ handshake.addToChainKey(ss[:])
+ }()
+
+ // add preshared key (psk)
+
+ var tau [blake2s.Size]byte
+ var key [chacha20poly1305.KeySize]byte
+ handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
+ handshake.addToHash(tau[:])
+
+ func() {
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
+ handshake.addToHash(msg.Empty[:])
+ }()
+
+ return &msg, nil
}