summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/config.go190
-rw-r--r--src/device.go14
-rw-r--r--src/main.go28
-rw-r--r--src/misc.go8
-rw-r--r--src/noise.go51
-rw-r--r--src/peer.go18
-rw-r--r--src/trie.go154
-rw-r--r--src/trie_test.go66
8 files changed, 529 insertions, 0 deletions
diff --git a/src/config.go b/src/config.go
new file mode 100644
index 0000000..f6f1378
--- /dev/null
+++ b/src/config.go
@@ -0,0 +1,190 @@
+package main
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+ "log"
+)
+
+/* todo : use real error code
+ * Many of which will be the same
+ */
+const (
+ ipcErrorNoPeer = 0
+ ipcErrorNoKeyValue = 1
+ ipcErrorInvalidKey = 2
+ ipcErrorInvalidPrivateKey = 3
+ ipcErrorInvalidPublicKey = 4
+ ipcErrorInvalidPort = 5
+)
+
+type IPCError struct {
+ Code int
+}
+
+func (s *IPCError) Error() string {
+ return fmt.Sprintf("IPC error: %d", s.Code)
+}
+
+func (s *IPCError) ErrorCode() int {
+ return s.Code
+}
+
+// Writes the configuration to the socket
+func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
+
+}
+
+// Creates new config, from old and socket message
+func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError {
+
+ scanner := bufio.NewScanner(socket)
+
+ dev.mutex.Lock()
+ defer dev.mutex.Unlock()
+
+ for scanner.Scan() {
+ var key string
+ var value string
+ var peer *Peer
+
+ // Parse line
+
+ line := scanner.Text()
+ if line == "\n" {
+ break
+ }
+ fmt.Println(line)
+ n, err := fmt.Sscanf(line, "%s=%s\n", &key, &value)
+ if n != 2 || err != nil {
+ fmt.Println(err, n)
+ return &IPCError{Code: ipcErrorNoKeyValue}
+ }
+
+ switch key {
+
+ /* Interface configuration */
+
+ case "private_key":
+ if value == "" {
+ dev.privateKey = NoisePrivateKey{}
+ } else {
+ err := dev.privateKey.FromHex(value)
+ if err != nil {
+ return &IPCError{Code: ipcErrorInvalidPrivateKey}
+ }
+ }
+
+ case "listen_port":
+ _, err := fmt.Sscanf(value, "%ud", &dev.listenPort)
+ if err != nil {
+ return &IPCError{Code: ipcErrorInvalidPort}
+ }
+
+ case "fwmark":
+ panic(nil) // not handled yet
+
+ case "public_key":
+ var pubKey NoisePublicKey
+ err := pubKey.FromHex(value)
+ if err != nil {
+ return &IPCError{Code: ipcErrorInvalidPublicKey}
+ }
+ found, ok := dev.peers[pubKey]
+ if ok {
+ peer = found
+ } else {
+ newPeer := &Peer{
+ publicKey: pubKey,
+ }
+ peer = newPeer
+ dev.peers[pubKey] = newPeer
+ }
+
+ case "replace_peers":
+
+ default:
+ /* Peer configuration */
+
+ if peer == nil {
+ return &IPCError{Code: ipcErrorNoPeer}
+ }
+
+ switch key {
+
+ case "remove":
+ peer.mutex.Lock()
+
+ peer = nil
+
+ case "preshared_key":
+ func() {
+ peer.mutex.Lock()
+ defer peer.mutex.Unlock()
+ }()
+
+ case "endpoint":
+ func() {
+ peer.mutex.Lock()
+ defer peer.mutex.Unlock()
+ }()
+
+ case "persistent_keepalive_interval":
+ func() {
+ peer.mutex.Lock()
+ defer peer.mutex.Unlock()
+ }()
+
+ case "replace_allowed_ips":
+ // remove peer from trie
+
+ case "allowed_ip":
+
+ /* Invalid key */
+
+ default:
+ return &IPCError{Code: ipcErrorInvalidKey}
+ }
+ }
+ }
+
+ return nil
+}
+
+func ipcListen(dev *Device, socket io.ReadWriter) error {
+
+ buffered := func(s io.ReadWriter) *bufio.ReadWriter {
+ reader := bufio.NewReader(s)
+ writer := bufio.NewWriter(s)
+ return bufio.NewReadWriter(reader, writer)
+ }(socket)
+
+ for {
+ op, err := buffered.ReadString('\n')
+ if err != nil {
+ return err
+ }
+ log.Println(op)
+
+ switch op {
+
+ case "set=1\n":
+ err := ipcSetOperation(dev, buffered)
+ if err != nil {
+ fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode())
+ return err
+ } else {
+ fmt.Fprintf(buffered, "errno=0\n")
+ }
+ buffered.Flush()
+
+ case "get=1\n":
+
+ default:
+ return errors.New("handle this please")
+ }
+ }
+
+}
diff --git a/src/device.go b/src/device.go
new file mode 100644
index 0000000..cd0835c
--- /dev/null
+++ b/src/device.go
@@ -0,0 +1,14 @@
+package main
+
+import (
+ "sync"
+)
+
+type Device struct {
+ mutex sync.RWMutex
+ peers map[NoisePublicKey]*Peer
+ privateKey NoisePrivateKey
+ publicKey NoisePublicKey
+ fwMark uint32
+ listenPort uint16
+}
diff --git a/src/main.go b/src/main.go
new file mode 100644
index 0000000..0f5016d
--- /dev/null
+++ b/src/main.go
@@ -0,0 +1,28 @@
+package main
+
+import (
+ "fmt"
+ "log"
+ "net"
+)
+
+func main() {
+ l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
+ if err != nil {
+ log.Fatal("listen error:", err)
+ }
+
+ for {
+ fd, err := l.Accept()
+ if err != nil {
+ log.Fatal("accept error:", err)
+ }
+
+ var dev Device
+ go func(conn net.Conn) {
+ err := ipcListen(&dev, conn)
+ fmt.Println(err)
+ }(fd)
+ }
+
+}
diff --git a/src/misc.go b/src/misc.go
new file mode 100644
index 0000000..e1244d6
--- /dev/null
+++ b/src/misc.go
@@ -0,0 +1,8 @@
+package main
+
+func min(a uint, b uint) uint {
+ if a > b {
+ return b
+ }
+ return a
+}
diff --git a/src/noise.go b/src/noise.go
new file mode 100644
index 0000000..d13bdd6
--- /dev/null
+++ b/src/noise.go
@@ -0,0 +1,51 @@
+package main
+
+import (
+ "encoding/hex"
+ "errors"
+)
+
+const (
+ NoisePublicKeySize = 32
+ NoisePrivateKeySize = 32
+ NoiseSymmetricKeySize = 32
+)
+
+type (
+ NoisePublicKey [NoisePublicKeySize]byte
+ NoisePrivateKey [NoisePrivateKeySize]byte
+ NoiseSymmetricKey [NoiseSymmetricKeySize]byte
+ NoiseNonce uint64 // padded to 12-bytes
+)
+
+func (key *NoisePrivateKey) FromHex(s string) error {
+ slice, err := hex.DecodeString(s)
+ if err != nil {
+ return err
+ }
+ if len(slice) != NoisePrivateKeySize {
+ return errors.New("Invalid length of hex string for curve25519 point")
+ }
+ copy(key[:], slice)
+ return nil
+}
+
+func (key *NoisePrivateKey) ToHex() string {
+ return hex.EncodeToString(key[:])
+}
+
+func (key *NoisePublicKey) FromHex(s string) error {
+ slice, err := hex.DecodeString(s)
+ if err != nil {
+ return err
+ }
+ if len(slice) != NoisePublicKeySize {
+ return errors.New("Invalid length of hex string for curve25519 scalar")
+ }
+ copy(key[:], slice)
+ return nil
+}
+
+func (key *NoisePublicKey) ToHex() string {
+ return hex.EncodeToString(key[:])
+}
diff --git a/src/peer.go b/src/peer.go
new file mode 100644
index 0000000..7c000da
--- /dev/null
+++ b/src/peer.go
@@ -0,0 +1,18 @@
+package main
+
+import (
+ "sync"
+)
+
+type KeyPair struct {
+ recieveKey NoiseSymmetricKey
+ recieveNonce NoiseNonce
+ sendKey NoiseSymmetricKey
+ sendNonce NoiseNonce
+}
+
+type Peer struct {
+ mutex sync.RWMutex
+ publicKey NoisePublicKey
+ presharedKey NoiseSymmetricKey
+}
diff --git a/src/trie.go b/src/trie.go
new file mode 100644
index 0000000..7fd7c5f
--- /dev/null
+++ b/src/trie.go
@@ -0,0 +1,154 @@
+package main
+
+import "fmt"
+
+/* Syncronization must be done seperatly
+ *
+ */
+
+type Trie struct {
+ cidr uint
+ child [2]*Trie
+ bits []byte
+ peer *Peer
+
+ // Index of "branching" bit
+ // bit_at_shift
+ bit_at_byte uint
+ bit_at_shift uint
+}
+
+/* Finds length of matching prefix
+ * Maybe there is a faster way
+ *
+ * Assumption: len(s1) == len(s2)
+ */
+func commonBits(s1 []byte, s2 []byte) uint {
+ var i uint
+ size := uint(len(s1))
+ for i = 0; i < size; i += 1 {
+ v := s1[i] ^ s2[i]
+ if v != 0 {
+ v >>= 1
+ if v == 0 {
+ return i*8 + 7
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 6
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 5
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 4
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 3
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 2
+ }
+
+ v >>= 1
+ if v == 0 {
+ return i*8 + 1
+ }
+ return i * 8
+ }
+ }
+ return i * 8
+}
+
+func (node *Trie) RemovePeer(p *Peer) *Trie {
+ if node == nil {
+ return node
+ }
+
+ // Walk recursivly
+
+ node.child[0] = node.child[0].RemovePeer(p)
+ node.child[1] = node.child[1].RemovePeer(p)
+
+ if node.peer != p {
+ return node
+ }
+
+ // Remove peer & merge
+
+ node.peer = nil
+ if node.child[0] == nil {
+ return node.child[1]
+ }
+ return node.child[0]
+}
+
+func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie {
+ if node == nil {
+ return &Trie{
+ bits: key,
+ peer: peer,
+ cidr: cidr,
+ bit_at_byte: cidr / 8,
+ bit_at_shift: 7 - (cidr % 8),
+ }
+ }
+
+ // Traverse deeper
+
+ common := commonBits(node.bits, key)
+ if node.cidr <= cidr && common >= node.cidr {
+ // Check if match the t.bits[:t.cidr] exactly
+ if node.cidr == cidr {
+ node.peer = peer
+ return node
+ }
+
+ // Go to child
+ bit := (key[node.bit_at_byte] >> node.bit_at_shift) & 1
+ node.child[bit] = node.child[bit].Insert(key, cidr, peer)
+ return node
+ }
+
+ // Split node
+
+ fmt.Println("new", common)
+
+ newNode := &Trie{
+ bits: key,
+ peer: peer,
+ cidr: cidr,
+ bit_at_byte: cidr / 8,
+ bit_at_shift: 7 - (cidr % 8),
+ }
+
+ cidr = min(cidr, common)
+ node.cidr = cidr
+ node.bit_at_byte = cidr / 8
+ node.bit_at_shift = 7 - (cidr % 8)
+
+ // bval := node.bits[node.bit_at_byte] >> node.bit_at_shift // todo : remember index
+ // Work in progress
+ node.child[0] = newNode
+ node.child[1] = newNode
+
+ return node
+}
+
+func (t *Trie) Lookup(key []byte) *Peer {
+ if t == nil {
+ return nil
+ }
+
+ return nil
+
+}
diff --git a/src/trie_test.go b/src/trie_test.go
new file mode 100644
index 0000000..ec4cde3
--- /dev/null
+++ b/src/trie_test.go
@@ -0,0 +1,66 @@
+package main
+
+import (
+ "testing"
+)
+
+type testPairCommonBits struct {
+ s1 []byte
+ s2 []byte
+ match uint
+}
+
+type testPairTrieInsert struct {
+ key []byte
+ cidr uint
+ peer *Peer
+}
+
+func printTrie(t *testing.T, p *Trie) {
+ if p == nil {
+ return
+ }
+ t.Log(p)
+ printTrie(t, p.child[0])
+ printTrie(t, p.child[1])
+}
+
+func TestCommonBits(t *testing.T) {
+
+ tests := []testPairCommonBits{
+ {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
+ {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
+ {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
+ {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
+ {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
+ }
+
+ for _, p := range tests {
+ v := commonBits(p.s1, p.s2)
+ if v != p.match {
+ t.Error(
+ "For slice", p.s1, p.s2,
+ "expected match", p.match,
+ "got", v,
+ )
+ }
+ }
+}
+
+func TestTrieInsertV4(t *testing.T) {
+ var trie *Trie
+
+ peer1 := Peer{}
+ peer2 := Peer{}
+
+ tests := []testPairTrieInsert{
+ {key: []byte{192, 168, 1, 1}, cidr: 24, peer: &peer1},
+ {key: []byte{192, 169, 1, 1}, cidr: 24, peer: &peer2},
+ }
+
+ for _, p := range tests {
+ trie = trie.Insert(p.key, p.cidr, p.peer)
+ printTrie(t, trie)
+ }
+
+}