summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/cookie.go39
-rw-r--r--src/device.go20
-rw-r--r--src/keypair.go13
-rw-r--r--src/main.go8
-rw-r--r--src/noise_protocol.go106
-rw-r--r--src/noise_test.go4
-rw-r--r--src/peer.go43
-rw-r--r--src/routing.go23
-rw-r--r--src/send.go154
9 files changed, 315 insertions, 95 deletions
diff --git a/src/cookie.go b/src/cookie.go
new file mode 100644
index 0000000..a6987a2
--- /dev/null
+++ b/src/cookie.go
@@ -0,0 +1,39 @@
+package main
+
+import (
+ "errors"
+ "golang.org/x/crypto/blake2s"
+)
+
+func CalculateCookie(peer *Peer, msg []byte) {
+ size := len(msg)
+
+ if size < blake2s.Size128*2 {
+ panic(errors.New("bug: message too short"))
+ }
+
+ startMac1 := size - (blake2s.Size128 * 2)
+ startMac2 := size - blake2s.Size128
+
+ mac1 := msg[startMac1 : startMac1+blake2s.Size128]
+ mac2 := msg[startMac2 : startMac2+blake2s.Size128]
+
+ peer.mutex.RLock()
+ defer peer.mutex.RUnlock()
+
+ // set mac1
+
+ func() {
+ mac, _ := blake2s.New128(peer.macKey[:])
+ mac.Write(msg[:startMac1])
+ mac.Sum(mac1[:0])
+ }()
+
+ // set mac2
+
+ if peer.cookie != nil {
+ mac, _ := blake2s.New128(peer.cookie)
+ mac.Write(msg[:startMac2])
+ mac.Sum(mac2[:0])
+ }
+}
diff --git a/src/device.go b/src/device.go
index 9969034..ce10a63 100644
--- a/src/device.go
+++ b/src/device.go
@@ -1,18 +1,22 @@
package main
import (
+ "log"
"sync"
)
type Device struct {
- mutex sync.RWMutex
- peers map[NoisePublicKey]*Peer
- indices IndexTable
- privateKey NoisePrivateKey
- publicKey NoisePublicKey
- fwMark uint32
- listenPort uint16
- routingTable RoutingTable
+ mtu int
+ mutex sync.RWMutex
+ peers map[NoisePublicKey]*Peer
+ indices IndexTable
+ privateKey NoisePrivateKey
+ publicKey NoisePublicKey
+ fwMark uint32
+ listenPort uint16
+ routingTable RoutingTable
+ logger log.Logger
+ queueWorkOutbound chan *OutboundWorkQueueElement
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
diff --git a/src/keypair.go b/src/keypair.go
index e434c74..e7961a8 100644
--- a/src/keypair.go
+++ b/src/keypair.go
@@ -2,11 +2,20 @@ package main
import (
"crypto/cipher"
+ "sync"
)
type KeyPair struct {
recv cipher.AEAD
- recvNonce NoiseNonce
+ recvNonce uint64
send cipher.AEAD
- sendNonce NoiseNonce
+ sendNonce uint64
+}
+
+type KeyPairs struct {
+ mutex sync.RWMutex
+ current *KeyPair
+ previous *KeyPair
+ next *KeyPair
+ newKeyPair chan bool
}
diff --git a/src/main.go b/src/main.go
index af336f0..b6f6deb 100644
--- a/src/main.go
+++ b/src/main.go
@@ -1,6 +1,8 @@
package main
-import "fmt"
+import (
+ "fmt"
+)
func main() {
fd, err := CreateTUN("test0")
@@ -8,9 +10,9 @@ func main() {
queue := make(chan []byte, 1000)
- var device Device
+ // var device Device
- go OutgoingRoutingWorker(&device, queue)
+ // go OutgoingRoutingWorker(&device, queue)
for {
tmp := make([]byte, 1<<16)
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index 7f26cf1..a16908a 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -9,9 +9,9 @@ import (
)
const (
- HandshakeReset = iota
- HandshakeInitialCreated
- HandshakeInitialConsumed
+ HandshakeZeroed = iota
+ HandshakeInitiationCreated
+ HandshakeInitiationConsumed
HandshakeResponseCreated
HandshakeResponseConsumed
)
@@ -24,13 +24,19 @@ const (
)
const (
- MessageInitalType = 1
+ MessageInitiationType = 1
MessageResponseType = 2
MessageCookieResponseType = 3
MessageTransportType = 4
)
-type MessageInital struct {
+/* Type is an 8-bit field, followed by 3 nul bytes,
+ * by marshalling the messages in little-endian byteorder
+ * we can treat these as a 32-bit int
+ *
+ */
+
+type MessageInitiation struct {
Type uint32
Sender uint32
Ephemeral NoisePublicKey
@@ -73,9 +79,9 @@ type Handshake struct {
}
var (
- ZeroNonce [chacha20poly1305.NonceSize]byte
InitalChainKey [blake2s.Size]byte
InitalHash [blake2s.Size]byte
+ ZeroNonce [chacha20poly1305.NonceSize]byte
)
func init() {
@@ -83,23 +89,23 @@ func init() {
InitalHash = blake2s.Sum256(append(InitalChainKey[:], []byte(WGIdentifier)...))
}
-func addToChainKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+func mixKey(c [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return KDF1(c[:], data)
}
-func addToHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
+func mixHash(h [blake2s.Size]byte, data []byte) [blake2s.Size]byte {
return blake2s.Sum256(append(h[:], data...))
}
-func (h *Handshake) addToHash(data []byte) {
- h.hash = addToHash(h.hash, data)
+func (h *Handshake) mixHash(data []byte) {
+ h.hash = mixHash(h.hash, data)
}
-func (h *Handshake) addToChainKey(data []byte) {
- h.chainKey = addToChainKey(h.chainKey, data)
+func (h *Handshake) mixKey(data []byte) {
+ h.chainKey = mixKey(h.chainKey, data)
}
-func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
+func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
@@ -108,7 +114,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
var err error
handshake.chainKey = InitalChainKey
- handshake.hash = addToHash(InitalHash, handshake.remoteStatic[:])
+ handshake.hash = mixHash(InitalHash, handshake.remoteStatic[:])
handshake.localEphemeral, err = newPrivateKey()
if err != nil {
return nil, err
@@ -116,9 +122,9 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
// assign index
- var msg MessageInital
+ var msg MessageInitiation
- msg.Type = MessageInitalType
+ msg.Type = MessageInitiationType
msg.Ephemeral = handshake.localEphemeral.publicKey()
handshake.localIndex, err = device.indices.NewIndex(peer)
@@ -127,10 +133,10 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
}
msg.Sender = handshake.localIndex
- handshake.addToChainKey(msg.Ephemeral[:])
- handshake.addToHash(msg.Ephemeral[:])
+ handshake.mixKey(msg.Ephemeral[:])
+ handshake.mixHash(msg.Ephemeral[:])
- // encrypt identity key
+ // encrypt static key
func() {
var key [chacha20poly1305.KeySize]byte
@@ -139,7 +145,7 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
}()
- handshake.addToHash(msg.Static[:])
+ handshake.mixHash(msg.Static[:])
// encrypt timestamp
@@ -154,22 +160,22 @@ func (device *Device) CreateMessageInitial(peer *Peer) (*MessageInital, error) {
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}()
- handshake.addToHash(msg.Timestamp[:])
- handshake.state = HandshakeInitialCreated
+ handshake.mixHash(msg.Timestamp[:])
+ handshake.state = HandshakeInitiationCreated
return &msg, nil
}
-func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
- if msg.Type != MessageInitalType {
- panic(errors.New("bug: invalid inital message type"))
+func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
+ if msg.Type != MessageInitiationType {
+ return nil
}
- hash := addToHash(InitalHash, device.publicKey[:])
- hash = addToHash(hash, msg.Ephemeral[:])
- chainKey := addToChainKey(InitalChainKey, msg.Ephemeral[:])
+ hash := mixHash(InitalHash, device.publicKey[:])
+ hash = mixHash(hash, msg.Ephemeral[:])
+ chainKey := mixKey(InitalChainKey, msg.Ephemeral[:])
- // decrypt identity key
+ // decrypt static key
var err error
var peerPK NoisePublicKey
@@ -183,7 +189,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
if err != nil {
return nil
}
- hash = addToHash(hash, msg.Static[:])
+ hash = mixHash(hash, msg.Static[:])
// find peer
@@ -210,7 +216,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
if err != nil {
return nil
}
- hash = addToHash(hash, msg.Timestamp[:])
+ hash = mixHash(hash, msg.Timestamp[:])
// check for replay attack
@@ -218,7 +224,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
return nil
}
- // check for flood attack
+ // TODO: check for flood attack
// update handshake state
@@ -227,7 +233,7 @@ func (device *Device) ConsumeMessageInitial(msg *MessageInital) *Peer {
handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp
- handshake.state = HandshakeInitialConsumed
+ handshake.state = HandshakeInitiationConsumed
return peer
}
@@ -236,8 +242,8 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if handshake.state != HandshakeInitialConsumed {
- panic(errors.New("bug: handshake initation must be consumed first"))
+ if handshake.state != HandshakeInitiationConsumed {
+ return nil, errors.New("handshake initation must be consumed first")
}
// assign index
@@ -260,13 +266,13 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
return nil, err
}
msg.Ephemeral = handshake.localEphemeral.publicKey()
- handshake.addToHash(msg.Ephemeral[:])
+ handshake.mixHash(msg.Ephemeral[:])
func() {
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
- handshake.addToChainKey(ss[:])
+ handshake.mixKey(ss[:])
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
- handshake.addToChainKey(ss[:])
+ handshake.mixKey(ss[:])
}()
// add preshared key (psk)
@@ -274,12 +280,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
handshake.chainKey, tau, key = KDF3(handshake.chainKey[:], handshake.presharedKey[:])
- handshake.addToHash(tau[:])
+ handshake.mixHash(tau[:])
func() {
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
- handshake.addToHash(msg.Empty[:])
+ handshake.mixHash(msg.Empty[:])
}()
handshake.state = HandshakeResponseCreated
@@ -288,7 +294,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
if msg.Type != MessageResponseType {
- panic(errors.New("bug: invalid message type"))
+ return nil
}
// lookup handshake by reciever
@@ -300,20 +306,20 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
- if handshake.state != HandshakeInitialCreated {
+ if handshake.state != HandshakeInitiationCreated {
return nil
}
// finish 3-way DH
- hash := addToHash(handshake.hash, msg.Ephemeral[:])
+ hash := mixHash(handshake.hash, msg.Ephemeral[:])
chainKey := handshake.chainKey
func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- chainKey = addToChainKey(chainKey, ss[:])
+ chainKey = mixKey(chainKey, ss[:])
ss = device.privateKey.sharedSecret(msg.Ephemeral)
- chainKey = addToChainKey(chainKey, ss[:])
+ chainKey = mixKey(chainKey, ss[:])
}()
// add preshared key (psk)
@@ -321,7 +327,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
- hash = addToHash(hash, tau[:])
+ hash = mixHash(hash, tau[:])
// authenticate
@@ -330,7 +336,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
if err != nil {
return nil
}
- hash = addToHash(hash, msg.Empty[:])
+ hash = mixHash(hash, msg.Empty[:])
// update handshake state
@@ -368,7 +374,11 @@ func (peer *Peer) NewKeyPair() *KeyPair {
keyPair.sendNonce = 0
keyPair.recvNonce = 0
- peer.handshake.state = HandshakeReset
+ // zero handshake
+
+ handshake.chainKey = [blake2s.Size]byte{}
+ handshake.localEphemeral = NoisePrivateKey{}
+ peer.handshake.state = HandshakeZeroed
return &keyPair
}
diff --git a/src/noise_test.go b/src/noise_test.go
index ddabf8e..8450c1c 100644
--- a/src/noise_test.go
+++ b/src/noise_test.go
@@ -67,13 +67,13 @@ func TestNoiseHandshake(t *testing.T) {
t.Log("exchange initiation message")
- msg1, err := dev1.CreateMessageInitial(peer2)
+ msg1, err := dev1.CreateMessageInitiation(peer2)
assertNil(t, err)
packet := make([]byte, 0, 256)
writer := bytes.NewBuffer(packet)
err = binary.Write(writer, binary.LittleEndian, msg1)
- peer := dev2.ConsumeMessageInitial(msg1)
+ peer := dev2.ConsumeMessageInitiation(msg1)
if peer == nil {
t.Fatal("handshake failed at initiation message")
}
diff --git a/src/peer.go b/src/peer.go
index f6eb555..42b9e8d 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -1,39 +1,64 @@
package main
import (
+ "errors"
+ "golang.org/x/crypto/blake2s"
"net"
"sync"
"time"
)
+const (
+ OutboundQueueSize = 64
+)
+
type Peer struct {
mutex sync.RWMutex
endpointIP net.IP //
endpointPort uint16 //
persistentKeepaliveInterval time.Duration // 0 = disabled
+ keyPairs KeyPairs
handshake Handshake
device *Device
+ macKey [blake2s.Size]byte // Hash(Label-Mac1 || publicKey)
+ cookie []byte // cookie
+ cookieExpire time.Time
+ queueInbound chan []byte
+ queueOutbound chan *OutboundWorkQueueElement
+ queueOutboundRouting chan []byte
}
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
var peer Peer
+ // create peer
+
+ peer.mutex.Lock()
+ peer.device = device
+ peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
+
// map public key
device.mutex.Lock()
+ _, ok := device.peers[pk]
+ if ok {
+ panic(errors.New("bug: adding existing peer"))
+ }
device.peers[pk] = &peer
device.mutex.Unlock()
- // precompute
+ // precompute DH
- peer.mutex.Lock()
- peer.device = device
- func(h *Handshake) {
- h.mutex.Lock()
- h.remoteStatic = pk
- h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
- h.mutex.Unlock()
- }(&peer.handshake)
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ handshake.remoteStatic = pk
+ handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic)
+
+ // compute mac key
+
+ peer.macKey = blake2s.Sum256(append([]byte(WGLabelMAC1[:]), handshake.remoteStatic[:]...))
+
+ handshake.mutex.Unlock()
peer.mutex.Unlock()
return &peer
diff --git a/src/routing.go b/src/routing.go
index 553df11..4189c25 100644
--- a/src/routing.go
+++ b/src/routing.go
@@ -2,7 +2,6 @@ package main
import (
"errors"
- "fmt"
"net"
"sync"
)
@@ -52,25 +51,3 @@ func (table *RoutingTable) LookupIPv6(address []byte) *Peer {
defer table.mutex.RUnlock()
return table.IPv6.Lookup(address)
}
-
-func OutgoingRoutingWorker(device *Device, queue chan []byte) {
- for {
- packet := <-queue
- switch packet[0] >> 4 {
-
- case IPv4version:
- dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer := device.routingTable.LookupIPv4(dst)
- fmt.Println("IPv4", peer)
-
- case IPv6version:
- dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer := device.routingTable.LookupIPv6(dst)
- fmt.Println("IPv6", peer)
-
- default:
- // todo: log
- fmt.Println("Unknown IP version")
- }
- }
-}
diff --git a/src/send.go b/src/send.go
new file mode 100644
index 0000000..9790320
--- /dev/null
+++ b/src/send.go
@@ -0,0 +1,154 @@
+package main
+
+import (
+ "net"
+ "sync"
+ "sync/atomic"
+)
+
+/* Handles outbound flow
+ *
+ * 1. TUN queue
+ * 2. Routing
+ * 3. Per peer queuing
+ * 4. (work queuing)
+ *
+ */
+
+type OutboundWorkQueueElement struct {
+ wg sync.WaitGroup
+ packet []byte
+ nonce uint64
+ keyPair *KeyPair
+}
+
+func (device *Device) SendPacket(packet []byte) {
+
+ // lookup peer
+
+ var peer *Peer
+ switch packet[0] >> 4 {
+ case IPv4version:
+ dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer = device.routingTable.LookupIPv4(dst)
+
+ case IPv6version:
+ dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer = device.routingTable.LookupIPv6(dst)
+
+ default:
+ device.logger.Println("unknown IP version")
+ return
+ }
+
+ if peer == nil {
+ return
+ }
+
+ // insert into peer queue
+
+ for {
+ select {
+ case peer.queueOutboundRouting <- packet:
+ default:
+ select {
+ case <-peer.queueOutboundRouting:
+ default:
+ }
+ continue
+ }
+ break
+ }
+}
+
+/* Go routine
+ *
+ *
+ * 1. waits for handshake.
+ * 2. assigns key pair & nonce
+ * 3. inserts to working queue
+ *
+ * TODO: avoid dynamic allocation of work queue elements
+ */
+func (peer *Peer) ConsumeOutboundPackets() {
+ for {
+ // wait for key pair
+ keyPair := func() *KeyPair {
+ peer.keyPairs.mutex.RLock()
+ defer peer.keyPairs.mutex.RUnlock()
+ return peer.keyPairs.current
+ }()
+ if keyPair == nil {
+ if len(peer.queueOutboundRouting) > 0 {
+ // TODO: start handshake
+ <-peer.keyPairs.newKeyPair
+ }
+ continue
+ }
+
+ // assign packets key pair
+ for {
+ select {
+ case <-peer.keyPairs.newKeyPair:
+ default:
+ case <-peer.keyPairs.newKeyPair:
+ case packet := <-peer.queueOutboundRouting:
+
+ // create new work element
+
+ work := new(OutboundWorkQueueElement)
+ work.wg.Add(1)
+ work.keyPair = keyPair
+ work.packet = packet
+ work.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+
+ peer.queueOutbound <- work
+
+ // drop packets until there is room
+
+ for {
+ select {
+ case peer.device.queueWorkOutbound <- work:
+ break
+ default:
+ drop := <-peer.device.queueWorkOutbound
+ drop.packet = nil
+ drop.wg.Done()
+ }
+ }
+ }
+ }
+ }
+}
+
+func (peer *Peer) RoutineSequential() {
+ for work := range peer.queueOutbound {
+ work.wg.Wait()
+ if work.packet == nil {
+ continue
+ }
+ }
+}
+
+func (device *Device) EncryptionWorker() {
+ for {
+ work := <-device.queueWorkOutbound
+
+ func() {
+ defer work.wg.Done()
+
+ // pad packet
+ padding := device.mtu - len(work.packet)
+ if padding < 0 {
+ work.packet = nil
+ return
+ }
+ for n := 0; n < padding; n += 1 {
+ work.packet = append(work.packet, 0) // TODO: gotta be a faster way
+ }
+
+ //
+
+ }()
+ }
+}