summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/index.go17
-rw-r--r--src/keypair.go8
-rw-r--r--src/noise_helpers.go17
-rw-r--r--src/noise_protocol.go125
-rw-r--r--src/noise_test.go68
5 files changed, 191 insertions, 44 deletions
diff --git a/src/index.go b/src/index.go
index 83a7e29..81f71e9 100644
--- a/src/index.go
+++ b/src/index.go
@@ -7,12 +7,14 @@ import (
/* Index=0 is reserved for unset indecies
*
+ * TODO: Rethink map[id] -> peer VS map[id] -> handshake and handshake <ref> peer
+ *
*/
type IndexTable struct {
mutex sync.RWMutex
keypairs map[uint32]*KeyPair
- handshakes map[uint32]*Handshake
+ handshakes map[uint32]*Peer
}
func randUint32() (uint32, error) {
@@ -32,10 +34,10 @@ func (table *IndexTable) Init() {
table.mutex.Lock()
defer table.mutex.Unlock()
table.keypairs = make(map[uint32]*KeyPair)
- table.handshakes = make(map[uint32]*Handshake)
+ table.handshakes = make(map[uint32]*Peer)
}
-func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
+func (table *IndexTable) NewIndex(peer *Peer) (uint32, error) {
table.mutex.Lock()
defer table.mutex.Unlock()
for {
@@ -60,11 +62,10 @@ func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
continue
}
- // update the index
+ // clean old index
- delete(table.handshakes, handshake.localIndex)
- handshake.localIndex = id
- table.handshakes[id] = handshake
+ delete(table.handshakes, peer.handshake.localIndex)
+ table.handshakes[id] = peer
return id, nil
}
}
@@ -75,7 +76,7 @@ func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
return table.keypairs[id]
}
-func (table *IndexTable) LookupHandshake(id uint32) *Handshake {
+func (table *IndexTable) LookupHandshake(id uint32) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.handshakes[id]
diff --git a/src/keypair.go b/src/keypair.go
index 22a8244..e434c74 100644
--- a/src/keypair.go
+++ b/src/keypair.go
@@ -5,8 +5,8 @@ import (
)
type KeyPair struct {
- recieveKey cipher.AEAD
- recieveNonce NoiseNonce
- sendKey cipher.AEAD
- sendNonce NoiseNonce
+ recv cipher.AEAD
+ recvNonce NoiseNonce
+ send cipher.AEAD
+ sendNonce NoiseNonce
}
diff --git a/src/noise_helpers.go b/src/noise_helpers.go
index eadbc07..e163ace 100644
--- a/src/noise_helpers.go
+++ b/src/noise_helpers.go
@@ -45,22 +45,7 @@ func KDF3(key []byte, input []byte) (t0 [blake2s.Size]byte, t1 [blake2s.Size]byt
return
}
-/*
- *
- */
-
-func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
- return KDF1(c[:], data)
-}
-
-func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
- return blake2s.Sum256(append(h[:], data...))
-}
-
-/* Curve25519 wrappers
- *
- * TODO: Rethink this
- */
+/* curve25519 wrappers */
func newPrivateKey() (sk NoisePrivateKey, err error) {
// clamping: https://cr.yp.to/ecdh.html
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index b9c8981..7f26cf1 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -9,9 +9,11 @@ import (
)
const (
- HandshakeInitialCreated = iota
+ HandshakeReset = iota
+ HandshakeInitialCreated
HandshakeInitialConsumed
HandshakeResponseCreated
+ HandshakeResponseConsumed
)
const (
@@ -71,7 +73,6 @@ type Handshake struct {
}
var (
- EmptyMessage []byte
ZeroNonce [chacha20poly1305.NonceSize]byte
InitalChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte
@@ -82,6 +83,14 @@ func init() {
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
}
+func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+ return KDF1(c[:], data)
+}
+
+func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+ return blake2s.Sum256(append(h[:], data...))
+}
+
func (h *Handshake) addToHash(data []byte) {
h.hash = addToHash(h.hash, data)
}
@@ -90,11 +99,6 @@ func (h *Handshake) addToChainKey(data []byte) {
h.chainKey = addToChainKey(h.chainKey, data)
}
-func (device *Device) Precompute(peer *Peer) {
- h := &peer.handshake
- h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
-}
-
func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
@@ -116,16 +120,17 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
msg.Type = MessageInitalType
msg.Ephemeral = handshake.localEphemeral.publicKey()
- msg.Sender, err = device.indices.NewIndex(handshake)
+ handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil {
return nil, err
}
+ msg.Sender = handshake.localIndex
handshake.addToChainKey(msg.Ephemeral[:])
handshake.addToHash(msg.Ephemeral[:])
- // encrypt long-term "identity key"
+ // encrypt identity key
func() {
var key [chacha20poly1305.KeySize]byte
@@ -221,6 +226,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral
+ handshake.lastTimestamp = timestamp
handshake.state = HandshakeInitialConsumed
return peer
}
@@ -237,14 +243,16 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
// assign index
var err error
- var msg MessageResponse
- msg.Type = MessageResponseType
- msg.Sender, err = device.indices.NewIndex(handshake)
- msg.Reciever = handshake.remoteIndex
+ handshake.localIndex, err = device.indices.NewIndex(peer)
if err != nil {
return nil, err
}
+ var msg MessageResponse
+ msg.Type = MessageResponseType
+ msg.Sender = handshake.localIndex
+ msg.Reciever = handshake.remoteIndex
+
// create ephemeral key
handshake.localEphemeral, err = newPrivateKey()
@@ -252,6 +260,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
return nil, err
}
msg.Ephemeral = handshake.localEphemeral.publicKey()
+ handshake.addToHash(msg.Ephemeral[:])
func() {
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
@@ -269,9 +278,97 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
func() {
aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
handshake.addToHash(msg.Empty[:])
}()
+ handshake.state = HandshakeResponseCreated
return &msg, nil
}
+
+func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
+ if msg.Type != MessageResponseType {
+ panic(errors.New("bug: invalid message type"))
+ }
+
+ // lookup handshake by reciever
+
+ peer := device.indices.LookupHandshake(msg.Reciever)
+ if peer == nil {
+ return nil
+ }
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
+ if handshake.state != HandshakeInitialCreated {
+ return nil
+ }
+
+ // finish 3-way DH
+
+ hash := addToHash(handshake.hash, msg.Ephemeral[:])
+ chainKey := handshake.chainKey
+
+ func() {
+ ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ chainKey = addToChainKey(chainKey, ss[:])
+ ss = device.privateKey.sharedSecret(msg.Ephemeral)
+ chainKey = addToChainKey(chainKey, ss[:])
+ }()
+
+ // add preshared key (psk)
+
+ var tau [blake2s.Size]byte
+ var key [chacha20poly1305.KeySize]byte
+ chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
+ hash = addToHash(hash, tau[:])
+
+ // authenticate
+
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ if err != nil {
+ return nil
+ }
+ hash = addToHash(hash, msg.Empty[:])
+
+ // update handshake state
+
+ handshake.hash = hash
+ handshake.chainKey = chainKey
+ handshake.remoteIndex = msg.Sender
+ handshake.state = HandshakeResponseConsumed
+
+ return peer
+}
+
+func (peer *Peer) NewKeyPair() *KeyPair {
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
+
+ // derive keys
+
+ var sendKey [chacha20poly1305.KeySize]byte
+ var recvKey [chacha20poly1305.KeySize]byte
+
+ if handshake.state == HandshakeResponseConsumed {
+ sendKey, recvKey = KDF2(handshake.chainKey[:], nil)
+ } else if handshake.state == HandshakeResponseCreated {
+ recvKey, sendKey = KDF2(handshake.chainKey[:], nil)
+ } else {
+ return nil
+ }
+
+ // create AEAD instances
+
+ var keyPair KeyPair
+ keyPair.send, _ = chacha20poly1305.New(sendKey[:])
+ keyPair.recv, _ = chacha20poly1305.New(recvKey[:])
+ keyPair.sendNonce = 0
+ keyPair.recvNonce = 0
+
+ peer.handshake.state = HandshakeReset
+
+ return &keyPair
+}
diff --git a/src/noise_test.go b/src/noise_test.go
index 8d6a0fa..ddabf8e 100644
--- a/src/noise_test.go
+++ b/src/noise_test.go
@@ -63,7 +63,9 @@ func TestNoiseHandshake(t *testing.T) {
/* simulate handshake */
- // Initiation message
+ // initiation message
+
+ t.Log("exchange initiation message")
msg1, err := dev1.CreateMessageInitial(peer2)
assertNil(t, err)
@@ -88,6 +90,68 @@ func TestNoiseHandshake(t *testing.T) {
peer2.handshake.hash[:],
)
- // Response message
+ // response message
+
+ t.Log("exchange response message")
+
+ msg2, err := dev2.CreateMessageResponse(peer1)
+ assertNil(t, err)
+
+ peer = dev1.ConsumeMessageResponse(msg2)
+ if peer == nil {
+ t.Fatal("handshake failed at response message")
+ }
+
+ assertEqual(
+ t,
+ peer1.handshake.chainKey[:],
+ peer2.handshake.chainKey[:],
+ )
+
+ assertEqual(
+ t,
+ peer1.handshake.hash[:],
+ peer2.handshake.hash[:],
+ )
+
+ // key pairs
+
+ t.Log("deriving keys")
+
+ key1 := peer1.NewKeyPair()
+ key2 := peer2.NewKeyPair()
+
+ if key1 == nil {
+ t.Fatal("failed to dervice key-pair for peer 1")
+ }
+
+ if key2 == nil {
+ t.Fatal("failed to dervice key-pair for peer 2")
+ }
+ // encrypting / decryption test
+
+ t.Log("test key pairs")
+
+ func() {
+ testMsg := []byte("wireguard test message 1")
+ var err error
+ var out []byte
+ var nonce [12]byte
+ out = key1.send.Seal(out, nonce[:], testMsg, nil)
+ out, err = key2.recv.Open(out[:0], nonce[:], out, nil)
+ assertNil(t, err)
+ assertEqual(t, out, testMsg)
+ }()
+
+ func() {
+ testMsg := []byte("wireguard test message 2")
+ var err error
+ var out []byte
+ var nonce [12]byte
+ out = key2.send.Seal(out, nonce[:], testMsg, nil)
+ out, err = key1.recv.Open(out[:0], nonce[:], out, nil)
+ assertNil(t, err)
+ assertEqual(t, out, testMsg)
+ }()
}