summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-06-24 15:34:17 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-06-24 15:34:44 +0200
commit25190e43369a79dc77a740dc8cd28b8a9fcb235e (patch)
treeb7057627e0710fe9ef40c077a204904c78bed9cc /src
parent521e77fd54fba275405affd790ac91f7998e4559 (diff)
Restructuring of noise impl.
Diffstat (limited to 'src')
-rw-r--r--src/config.go12
-rw-r--r--src/device.go75
-rw-r--r--src/index.go82
-rw-r--r--src/keypair.go12
-rw-r--r--src/noise_helpers.go2
-rw-r--r--src/noise_protocol.go210
-rw-r--r--src/noise_test.go91
-rw-r--r--src/peer.go40
-rw-r--r--src/routing.go7
-rw-r--r--src/tai64.go5
10 files changed, 414 insertions, 122 deletions
diff --git a/src/config.go b/src/config.go
index a61b940..8865194 100644
--- a/src/config.go
+++ b/src/config.go
@@ -99,11 +99,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
if ok {
peer = found
} else {
- newPeer := &Peer{
- publicKey: pubKey,
- }
- peer = newPeer
- device.peers[pubKey] = newPeer
+ peer = device.NewPeer(pubKey)
}
case "replace_peers":
@@ -125,14 +121,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "remove":
peer.mutex.Lock()
- device.RemovePeer(peer.publicKey)
+ // device.RemovePeer(peer.publicKey)
peer = nil
case "preshared_key":
err := func() error {
peer.mutex.Lock()
defer peer.mutex.Unlock()
- return peer.presharedKey.FromHex(value)
+ return peer.handshake.presharedKey.FromHex(value)
}()
if err != nil {
return &IPCError{Code: ipcErrorInvalidPublicKey}
@@ -144,7 +140,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalidIPAddress}
}
peer.mutex.Lock()
- peer.endpoint = ip
+ // peer.endpoint = ip FIX
peer.mutex.Unlock()
case "persistent_keepalive_interval":
diff --git a/src/device.go b/src/device.go
index 9f1daa6..9969034 100644
--- a/src/device.go
+++ b/src/device.go
@@ -1,17 +1,13 @@
package main
import (
- "math/rand"
"sync"
)
-/* TODO: Locking may be a little broad here
- */
-
type Device struct {
mutex sync.RWMutex
peers map[NoisePublicKey]*Peer
- sessions map[uint32]*Handshake
+ indices IndexTable
privateKey NoisePrivateKey
publicKey NoisePublicKey
fwMark uint32
@@ -19,43 +15,66 @@ type Device struct {
routingTable RoutingTable
}
-func (dev *Device) NewID(h *Handshake) uint32 {
- dev.mutex.Lock()
- defer dev.mutex.Unlock()
- for {
- id := rand.Uint32()
- _, ok := dev.sessions[id]
- if !ok {
- dev.sessions[id] = h
- return id
- }
+func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
+
+ // update key material
+
+ device.privateKey = sk
+ device.publicKey = sk.publicKey()
+
+ // do precomputations
+
+ for _, peer := range device.peers {
+ h := &peer.handshake
+ h.mutex.Lock()
+ h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+ h.mutex.Unlock()
}
}
-func (dev *Device) RemovePeer(key NoisePublicKey) {
- dev.mutex.Lock()
- defer dev.mutex.Unlock()
- peer, ok := dev.peers[key]
+func (device *Device) Init() {
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
+
+ device.peers = make(map[NoisePublicKey]*Peer)
+ device.indices.Init()
+ device.listenPort = 0
+ device.routingTable.Reset()
+}
+
+func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
+ device.mutex.RLock()
+ defer device.mutex.RUnlock()
+ return device.peers[pk]
+}
+
+func (device *Device) RemovePeer(key NoisePublicKey) {
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
+
+ peer, ok := device.peers[key]
if !ok {
return
}
peer.mutex.Lock()
- dev.routingTable.RemovePeer(peer)
- delete(dev.peers, key)
+ device.routingTable.RemovePeer(peer)
+ delete(device.peers, key)
}
-func (dev *Device) RemoveAllAllowedIps(peer *Peer) {
+func (device *Device) RemoveAllAllowedIps(peer *Peer) {
}
-func (dev *Device) RemoveAllPeers() {
- dev.mutex.Lock()
- defer dev.mutex.Unlock()
+func (device *Device) RemoveAllPeers() {
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
- for key, peer := range dev.peers {
+ for key, peer := range device.peers {
peer.mutex.Lock()
- dev.routingTable.RemovePeer(peer)
- delete(dev.peers, key)
+ device.routingTable.RemovePeer(peer)
+ delete(device.peers, key)
peer.mutex.Unlock()
}
}
diff --git a/src/index.go b/src/index.go
new file mode 100644
index 0000000..83a7e29
--- /dev/null
+++ b/src/index.go
@@ -0,0 +1,82 @@
+package main
+
+import (
+ "crypto/rand"
+ "sync"
+)
+
+/* Index=0 is reserved for unset indecies
+ *
+ */
+
+type IndexTable struct {
+ mutex sync.RWMutex
+ keypairs map[uint32]*KeyPair
+ handshakes map[uint32]*Handshake
+}
+
+func randUint32() (uint32, error) {
+ var buff [4]byte
+ _, err := rand.Read(buff[:])
+ id := uint32(buff[0])
+ id <<= 8
+ id |= uint32(buff[1])
+ id <<= 8
+ id |= uint32(buff[2])
+ id <<= 8
+ id |= uint32(buff[3])
+ return id, err
+}
+
+func (table *IndexTable) Init() {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+ table.keypairs = make(map[uint32]*KeyPair)
+ table.handshakes = make(map[uint32]*Handshake)
+}
+
+func (table *IndexTable) NewIndex(handshake *Handshake) (uint32, error) {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+ for {
+ // generate random index
+
+ id, err := randUint32()
+ if err != nil {
+ return id, err
+ }
+ if id == 0 {
+ continue
+ }
+
+ // check if index used
+
+ _, ok := table.keypairs[id]
+ if ok {
+ continue
+ }
+ _, ok = table.handshakes[id]
+ if ok {
+ continue
+ }
+
+ // update the index
+
+ delete(table.handshakes, handshake.localIndex)
+ handshake.localIndex = id
+ table.handshakes[id] = handshake
+ return id, nil
+ }
+}
+
+func (table *IndexTable) LookupKeyPair(id uint32) *KeyPair {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+ return table.keypairs[id]
+}
+
+func (table *IndexTable) LookupHandshake(id uint32) *Handshake {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+ return table.handshakes[id]
+}
diff --git a/src/keypair.go b/src/keypair.go
new file mode 100644
index 0000000..22a8244
--- /dev/null
+++ b/src/keypair.go
@@ -0,0 +1,12 @@
+package main
+
+import (
+ "crypto/cipher"
+)
+
+type KeyPair struct {
+ recieveKey cipher.AEAD
+ recieveNonce NoiseNonce
+ sendKey cipher.AEAD
+ sendNonce NoiseNonce
+}
diff --git a/src/noise_helpers.go b/src/noise_helpers.go
index df25011..eadbc07 100644
--- a/src/noise_helpers.go
+++ b/src/noise_helpers.go
@@ -81,6 +81,6 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
- curve25519.ScalarMult(&ss, apk, ask)
+ curve25519.ScalarMult(&ss, ask, apk)
return ss
}
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index e7c8774..b9c8981 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -56,18 +56,22 @@ type MessageTransport struct {
}
type Handshake struct {
- lock sync.Mutex
- state int
- chainKey [blake2s.Size]byte // chain key
- hash [blake2s.Size]byte // hash value
- staticStatic NoisePublicKey // precomputed DH(S_i, S_r)
- ephemeral NoisePrivateKey // ephemeral secret key
- remoteIndex uint32 // index for sending
- device *Device
- peer *Peer
+ state int
+ mutex sync.Mutex
+ hash [blake2s.Size]byte // hash value
+ chainKey [blake2s.Size]byte // chain key
+ presharedKey NoiseSymmetricKey // psk
+ localEphemeral NoisePrivateKey // ephemeral secret key
+ localIndex uint32 // used to clear hash-table
+ remoteIndex uint32 // index for sending
+ remoteStatic NoisePublicKey // long term key
+ remoteEphemeral NoisePublicKey // ephemeral public key
+ precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
+ lastTimestamp TAI64N
}
var (
+ EmptyMessage []byte
ZeroNonce [chacha20poly1305.NonceSize]byte
InitalChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte
@@ -78,102 +82,196 @@ func init() {
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
}
-func (h *Handshake) Precompute() {
- h.staticStatic = h.device.privateKey.sharedSecret(h.peer.publicKey)
-}
-
-func (h *Handshake) ConsumeMessageResponse(msg *MessageResponse) {
-
-}
-
-func (h *Handshake) addHash(data []byte) {
+func (h *Handshake) addToHash(data []byte) {
h.hash = addToHash(h.hash, data)
}
-func (h *Handshake) addChain(data []byte) {
+func (h *Handshake) addToChainKey(data []byte) {
h.chainKey = addToChainKey(h.chainKey, data)
}
-func (h *Handshake) CreateMessageInital() (*MessageInital, error) {
- h.lock.Lock()
- defer h.lock.Unlock()
+func (device *Device) Precompute(peer *Peer) {
+ h := &peer.handshake
+ h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+}
+
+func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
- // reset handshake
+ // create ephemeral key
var err error
- h.ephemeral, err = newPrivateKey()
+ handshake.chainKey = InitalChainKey
+ handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:])
+ handshake.localEphemeral, err = newPrivateKey()
if err != nil {
return nil, err
}
- h.chainKey = InitalChainKey
- h.hash = addToHash(InitalHash, h.device.publicKey[:])
- // create ephemeral key
+ // assign index
var msg MessageInital
+
msg.Type = MessageInitalType
- msg.Sender = h.device.NewID(h)
- msg.Ephemeral = h.ephemeral.publicKey()
- h.chainKey = addToChainKey(h.chainKey, msg.Ephemeral[:])
- h.hash = addToHash(h.hash, msg.Ephemeral[:])
+ msg.Ephemeral = handshake.localEphemeral.publicKey()
+ msg.Sender, err = device.indices.NewIndex(handshake)
+
+ if err != nil {
+ return nil, err
+ }
+
+ handshake.addToChainKey(msg.Ephemeral[:])
+ handshake.addToHash(msg.Ephemeral[:])
// encrypt long-term "identity key"
func() {
var key [chacha20poly1305.KeySize]byte
- ss := h.ephemeral.sharedSecret(h.peer.publicKey)
- h.chainKey, key = KDF2(h.chainKey[:], ss[:])
+ ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ handshake.chainKey, key = KDF2(handshake.chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Static[:0], ZeroNonce[:], h.device.publicKey[:], nil)
+ aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
}()
- h.addHash(msg.Static[:])
+ handshake.addToHash(msg.Static[:])
// encrypt timestamp
timestamp := Timestamp()
func() {
var key [chacha20poly1305.KeySize]byte
- h.chainKey, key = KDF2(h.chainKey[:], h.staticStatic[:])
+ handshake.chainKey, key = KDF2(
+ handshake.chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], nil)
+ aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}()
- h.addHash(msg.Timestamp[:])
- h.state = HandshakeInitialCreated
+
+ handshake.addToHash(msg.Timestamp[:])
+ handshake.state = HandshakeInitialCreated
+
return &msg, nil
}
-func (h *Handshake) ConsumeMessageInitial(msg *MessageInital) error {
+func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
if msg.Type != MessageInitalType {
panic(errors.New("bug: invalid inital message type"))
}
- hash := addToHash(InitalHash, h.device.publicKey[:])
- chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
+ hash := addToHash(InitalHash, device.publicKey[:])
hash = addToHash(hash, msg.Ephemeral[:])
+ chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
- //
+ // decrypt identity key
- ephemeral, err := newPrivateKey()
+ var err error
+ var peerPK NoisePublicKey
+ func() {
+ var key [chacha20poly1305.KeySize]byte
+ ss := device.privateKey.sharedSecret(msg.Ephemeral)
+ chainKey, key = KDF2(chainKey[:], ss[:])
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
+ }()
if err != nil {
- return err
+ return nil
}
+ hash = addToHash(hash, msg.Static[:])
- // update handshake state
+ // find peer
+
+ peer := device.LookupPeer(peerPK)
+ if peer == nil {
+ return nil
+ }
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
+
+ // decrypt timestamp
+
+ var timestamp TAI64N
+ func() {
+ var key [chacha20poly1305.KeySize]byte
+ chainKey, key = KDF2(
+ chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
+ }()
+ if err != nil {
+ return nil
+ }
+ hash = addToHash(hash, msg.Timestamp[:])
- h.lock.Lock()
- defer h.lock.Unlock()
+ // check for replay attack
- h.hash = hash
- h.chainKey = chainKey
- h.remoteIndex = msg.Sender
- h.ephemeral = ephemeral
- h.state = HandshakeInitialConsumed
+ if !timestamp.After(handshake.lastTimestamp) {
+ return nil
+ }
+
+ // check for flood attack
- return nil
+ // update handshake state
+ handshake.hash = hash
+ handshake.chainKey = chainKey
+ handshake.remoteIndex = msg.Sender
+ handshake.remoteEphemeral = msg.Ephemeral
+ handshake.state = HandshakeInitialConsumed
+ return peer
}
-func (h *Handshake) CreateMessageResponse() []byte {
+func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
+
+ if handshake.state != HandshakeInitialConsumed {
+ panic(errors.New("bug: handshake initation must be consumed first"))
+ }
+
+ // assign index
+
+ var err error
+ var msg MessageResponse
+ msg.Type = MessageResponseType
+ msg.Sender, err = device.indices.NewIndex(handshake)
+ msg.Reciever = handshake.remoteIndex
+ if err != nil {
+ return nil, err
+ }
- return nil
+ // create ephemeral key
+
+ handshake.localEphemeral, err = newPrivateKey()
+ if err != nil {
+ return nil, err
+ }
+ msg.Ephemeral = handshake.localEphemeral.publicKey()
+
+ func() {
+ ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ handshake.addToChainKey(ss[:])
+ ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ handshake.addToChainKey(ss[:])
+ }()
+
+ // add preshared key (psk)
+
+ var tau [blake2s.Size]byte
+ var key [chacha20poly1305.KeySize]byte
+ handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
+ handshake.addToHash(tau[:])
+
+ func() {
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], EmptyMessage, handshake.hash[:])
+ handshake.addToHash(msg.Empty[:])
+ }()
+
+ return &msg, nil
}
diff --git a/src/noise_test.go b/src/noise_test.go
index b3ea54f..8d6a0fa 100644
--- a/src/noise_test.go
+++ b/src/noise_test.go
@@ -1,38 +1,93 @@
package main
import (
+ "bytes"
+ "encoding/binary"
"testing"
)
-func TestHandshake(t *testing.T) {
- var dev1 Device
- var dev2 Device
-
- var err error
-
- dev1.privateKey, err = newPrivateKey()
+func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
+}
+
+func assertEqual(t *testing.T, a []byte, b []byte) {
+ if bytes.Compare(a, b) != 0 {
+ t.Fatal(a, "!=", b)
+ }
+}
- dev2.privateKey, err = newPrivateKey()
+func TestCurveWrappers(t *testing.T) {
+ sk1, err := newPrivateKey()
+ assertNil(t, err)
+
+ sk2, err := newPrivateKey()
+ assertNil(t, err)
+
+ pk1 := sk1.publicKey()
+ pk2 := sk2.publicKey()
+
+ ss1 := sk1.sharedSecret(pk2)
+ ss2 := sk2.sharedSecret(pk1)
+
+ if ss1 != ss2 {
+ t.Fatal("Failed to compute shared secet")
+ }
+}
+
+func newDevice(t *testing.T) *Device {
+ var device Device
+ sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
+ device.Init()
+ device.SetPrivateKey(sk)
+ return &device
+}
+
+func TestNoiseHandshake(t *testing.T) {
+
+ dev1 := newDevice(t)
+ dev2 := newDevice(t)
- var peer1 Peer
- var peer2 Peer
+ peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
+ peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
- peer1.publicKey = dev1.privateKey.publicKey()
- peer2.publicKey = dev2.privateKey.publicKey()
+ assertEqual(
+ t,
+ peer1.handshake.precomputedStaticStatic[:],
+ peer2.handshake.precomputedStaticStatic[:],
+ )
+
+ /* simulate handshake */
+
+ // Initiation message
+
+ msg1, err := dev1.CreateMessageInitial(peer2)
+ assertNil(t, err)
+
+ packet := make([]byte, 0, 256)
+ writer := bytes.NewBuffer(packet)
+ err = binary.Write(writer, binary.LittleEndian, msg1)
+ peer := dev2.ConsumeMessageInitial(msg1)
+ if peer == nil {
+ t.Fatal("handshake failed at initiation message")
+ }
- var handshake1 Handshake
- var handshake2 Handshake
+ assertEqual(
+ t,
+ peer1.handshake.chainKey[:],
+ peer2.handshake.chainKey[:],
+ )
- handshake1.device = &dev1
- handshake2.device = &dev2
+ assertEqual(
+ t,
+ peer1.handshake.hash[:],
+ peer2.handshake.hash[:],
+ )
- handshake1.peer = &peer2
- handshake2.peer = &peer1
+ // Response message
}
diff --git a/src/peer.go b/src/peer.go
index db5e99f..f6eb555 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -6,17 +6,35 @@ import (
"time"
)
-type KeyPair struct {
- recieveKey NoiseSymmetricKey
- recieveNonce NoiseNonce
- sendKey NoiseSymmetricKey
- sendNonce NoiseNonce
-}
-
type Peer struct {
mutex sync.RWMutex
- publicKey NoisePublicKey
- presharedKey NoiseSymmetricKey
- endpoint net.IP
- persistentKeepaliveInterval time.Duration
+ endpointIP net.IP //
+ endpointPort uint16 //
+ persistentKeepaliveInterval time.Duration // 0 = disabled
+ handshake Handshake
+ device *Device
+}
+
+func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
+ var peer Peer
+
+ // map public key
+
+ device.mutex.Lock()
+ device.peers[pk] = &peer
+ device.mutex.Unlock()
+
+ // precompute
+
+ peer.mutex.Lock()
+ peer.device = device
+ func(h *Handshake) {
+ h.mutex.Lock()
+ h.remoteStatic = pk
+ h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+ h.mutex.Unlock()
+ }(&peer.handshake)
+ peer.mutex.Unlock()
+
+ return &peer
}
diff --git a/src/routing.go b/src/routing.go
index 0aa111c..553df11 100644
--- a/src/routing.go
+++ b/src/routing.go
@@ -13,6 +13,13 @@ type RoutingTable struct {
mutex sync.RWMutex
}
+func (table *RoutingTable) Reset() {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+ table.IPv4 = nil
+ table.IPv6 = nil
+}
+
func (table *RoutingTable) RemovePeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
diff --git a/src/tai64.go b/src/tai64.go
index d0d1432..2299a37 100644
--- a/src/tai64.go
+++ b/src/tai64.go
@@ -1,6 +1,7 @@
package main
import (
+ "bytes"
"encoding/binary"
"time"
)
@@ -21,3 +22,7 @@ func Timestamp() TAI64N {
binary.BigEndian.PutUint32(tai64n[8:], nano)
return tai64n
}
+
+func (t1 *TAI64N) After(t2 TAI64N) bool {
+ return bytes.Compare(t1[:], t2[:]) > 0
+}