diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/config.go | 53 | ||||
-rw-r--r-- | src/ip.go | 17 | ||||
-rw-r--r-- | src/main.go | 26 | ||||
-rw-r--r-- | src/peer.go | 10 | ||||
-rw-r--r-- | src/routing.go | 55 | ||||
-rw-r--r-- | src/trie.go | 40 | ||||
-rw-r--r-- | src/trie_test.go | 63 | ||||
-rw-r--r-- | src/tun.go | 8 | ||||
-rw-r--r-- | src/tun_linux.go | 80 |
9 files changed, 290 insertions, 62 deletions
diff --git a/src/config.go b/src/config.go index 62af67a..a61b940 100644 --- a/src/config.go +++ b/src/config.go @@ -7,6 +7,8 @@ import ( "io" "log" "net" + "strconv" + "time" ) /* todo : use real error code @@ -16,6 +18,7 @@ const ( ipcErrorNoPeer = 0 ipcErrorNoKeyValue = 1 ipcErrorInvalidKey = 2 + ipcErrorInvalidValue = 2 ipcErrorInvalidPrivateKey = 3 ipcErrorInvalidPublicKey = 4 ipcErrorInvalidPort = 5 @@ -34,18 +37,16 @@ func (s *IPCError) ErrorCode() int { return s.Code } -// Writes the configuration to the socket func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) { } -// Creates new config, from old and socket message -func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError { +func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { scanner := bufio.NewScanner(socket) - dev.mutex.Lock() - defer dev.mutex.Unlock() + device.mutex.Lock() + defer device.mutex.Unlock() for scanner.Scan() { var key string @@ -71,16 +72,16 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError { case "private_key": if value == "" { - dev.privateKey = NoisePrivateKey{} + device.privateKey = NoisePrivateKey{} } else { - err := dev.privateKey.FromHex(value) + err := device.privateKey.FromHex(value) if err != nil { return &IPCError{Code: ipcErrorInvalidPrivateKey} } } case "listen_port": - _, err := fmt.Sscanf(value, "%ud", &dev.listenPort) + _, err := fmt.Sscanf(value, "%ud", &device.listenPort) if err != nil { return &IPCError{Code: ipcErrorInvalidPort} } @@ -94,7 +95,7 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError { if err != nil { return &IPCError{Code: ipcErrorInvalidPublicKey} } - found, ok := dev.peers[pubKey] + found, ok := device.peers[pubKey] if ok { peer = found } else { @@ -102,14 +103,16 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError { publicKey: pubKey, } peer = newPeer - dev.peers[pubKey] = newPeer + device.peers[pubKey] = newPeer } case "replace_peers": if key == "true" { - dev.RemoveAllPeers() + device.RemoveAllPeers() + } else if key == "false" { + } else { + return &IPCError{Code: ipcErrorInvalidValue} } - // todo: else fail default: /* Peer configuration */ @@ -122,7 +125,7 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError { case "remove": peer.mutex.Lock() - dev.RemovePeer(peer.publicKey) + device.RemovePeer(peer.publicKey) peer = nil case "preshared_key": @@ -145,15 +148,29 @@ func ipcSetOperation(dev *Device, socket *bufio.ReadWriter) *IPCError { peer.mutex.Unlock() case "persistent_keepalive_interval": - func() { - peer.mutex.Lock() - defer peer.mutex.Unlock() - }() + secs, err := strconv.ParseInt(value, 10, 64) + if secs < 0 || err != nil { + return &IPCError{Code: ipcErrorInvalidValue} + } + peer.mutex.Lock() + peer.persistentKeepaliveInterval = time.Duration(secs) * time.Second + peer.mutex.Unlock() case "replace_allowed_ips": - // remove peer from trie + if key == "true" { + device.routingTable.RemovePeer(peer) + } else if key == "false" { + } else { + return &IPCError{Code: ipcErrorInvalidValue} + } case "allowed_ip": + _, network, err := net.ParseCIDR(value) + if err != nil { + return &IPCError{Code: ipcErrorInvalidValue} + } + ones, _ := network.Mask.Size() + device.routingTable.Insert(network.IP, uint(ones), peer) /* Invalid key */ diff --git a/src/ip.go b/src/ip.go new file mode 100644 index 0000000..3137891 --- /dev/null +++ b/src/ip.go @@ -0,0 +1,17 @@ +package main + +import ( + "net" +) + +const ( + IPv4version = 4 + IPv4offsetSrc = 12 + IPv4offsetDst = IPv4offsetSrc + net.IPv4len +) + +const ( + IPv6version = 6 + IPv6offsetSrc = 8 + IPv6offsetDst = IPv6offsetSrc + net.IPv6len +) diff --git a/src/main.go b/src/main.go index 0f5016d..af336f0 100644 --- a/src/main.go +++ b/src/main.go @@ -1,11 +1,33 @@ package main +import "fmt" + +func main() { + fd, err := CreateTUN("test0") + fmt.Println(fd, err) + + queue := make(chan []byte, 1000) + + var device Device + + go OutgoingRoutingWorker(&device, queue) + + for { + tmp := make([]byte, 1<<16) + n, err := fd.Read(tmp) + if err != nil { + break + } + queue <- tmp[:n] + } +} + +/* import ( "fmt" "log" "net" ) - func main() { l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock") if err != nil { @@ -24,5 +46,5 @@ func main() { fmt.Println(err) }(fd) } - } +*/ diff --git a/src/peer.go b/src/peer.go index 7b2b2a6..db5e99f 100644 --- a/src/peer.go +++ b/src/peer.go @@ -3,6 +3,7 @@ package main import ( "net" "sync" + "time" ) type KeyPair struct { @@ -13,8 +14,9 @@ type KeyPair struct { } type Peer struct { - mutex sync.RWMutex - publicKey NoisePublicKey - presharedKey NoiseSymmetricKey - endpoint net.IP + mutex sync.RWMutex + publicKey NoisePublicKey + presharedKey NoiseSymmetricKey + endpoint net.IP + persistentKeepaliveInterval time.Duration } diff --git a/src/routing.go b/src/routing.go index 99b180c..0aa111c 100644 --- a/src/routing.go +++ b/src/routing.go @@ -1,13 +1,12 @@ package main import ( + "errors" + "fmt" + "net" "sync" ) -/* Thread-safe high level functions for cryptkey routing. - * - */ - type RoutingTable struct { IPv4 *Trie IPv6 *Trie @@ -20,3 +19,51 @@ func (table *RoutingTable) RemovePeer(peer *Peer) { table.IPv4 = table.IPv4.RemovePeer(peer) table.IPv6 = table.IPv6.RemovePeer(peer) } + +func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + switch len(ip) { + case net.IPv6len: + table.IPv6 = table.IPv6.Insert(ip, cidr, peer) + case net.IPv4len: + table.IPv4 = table.IPv4.Insert(ip, cidr, peer) + default: + panic(errors.New("Inserting unknown address type")) + } +} + +func (table *RoutingTable) LookupIPv4(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv4.Lookup(address) +} + +func (table *RoutingTable) LookupIPv6(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv6.Lookup(address) +} + +func OutgoingRoutingWorker(device *Device, queue chan []byte) { + for { + packet := <-queue + switch packet[0] >> 4 { + + case IPv4version: + dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer := device.routingTable.LookupIPv4(dst) + fmt.Println("IPv4", peer) + + case IPv6version: + dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer := device.routingTable.LookupIPv6(dst) + fmt.Println("IPv6", peer) + + default: + // todo: log + fmt.Println("Unknown IP version") + } + } +} diff --git a/src/trie.go b/src/trie.go index 31a4d92..746c1b4 100644 --- a/src/trie.go +++ b/src/trie.go @@ -1,5 +1,9 @@ package main +import ( + "net" +) + /* Binary trie * * Syncronization done seperatly @@ -22,13 +26,13 @@ type Trie struct { /* Finds length of matching prefix * Maybe there is a faster way * - * Assumption: len(s1) == len(s2) + * Assumption: len(ip1) == len(ip2) */ -func commonBits(s1 []byte, s2 []byte) uint { +func commonBits(ip1 net.IP, ip2 net.IP) uint { var i uint - size := uint(len(s1)) + size := uint(len(ip1)) for i = 0; i < size; i += 1 { - v := s1[i] ^ s2[i] + v := ip1[i] ^ ip2[i] if v != 0 { v >>= 1 if v == 0 { @@ -93,17 +97,17 @@ func (node *Trie) RemovePeer(p *Peer) *Trie { return node.child[0] } -func (node *Trie) choose(key []byte) byte { - return (key[node.bit_at_byte] >> node.bit_at_shift) & 1 +func (node *Trie) choose(ip net.IP) byte { + return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 } -func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie { +func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { // At leaf if node == nil { return &Trie{ - bits: key, + bits: ip, peer: peer, cidr: cidr, bit_at_byte: cidr / 8, @@ -113,21 +117,21 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie { // Traverse deeper - common := commonBits(node.bits, key) + common := commonBits(node.bits, ip) if node.cidr <= cidr && common >= node.cidr { if node.cidr == cidr { node.peer = peer return node } - bit := node.choose(key) - node.child[bit] = node.child[bit].Insert(key, cidr, peer) + bit := node.choose(ip) + node.child[bit] = node.child[bit].Insert(ip, cidr, peer) return node } // Split node newNode := &Trie{ - bits: key, + bits: ip, peer: peer, cidr: cidr, bit_at_byte: cidr / 8, @@ -147,31 +151,31 @@ func (node *Trie) Insert(key []byte, cidr uint, peer *Peer) *Trie { // Create new parent for node & newNode parent := &Trie{ - bits: key, + bits: ip, peer: nil, cidr: cidr, bit_at_byte: cidr / 8, bit_at_shift: 7 - (cidr % 8), } - bit := parent.choose(key) + bit := parent.choose(ip) parent.child[bit] = newNode parent.child[bit^1] = node return parent } -func (node *Trie) Lookup(key []byte) *Peer { +func (node *Trie) Lookup(ip net.IP) *Peer { var found *Peer - size := uint(len(key)) - for node != nil && commonBits(node.bits, key) >= node.cidr { + size := uint(len(ip)) + for node != nil && commonBits(node.bits, ip) >= node.cidr { if node.peer != nil { found = node.peer } if node.bit_at_byte == size { break } - bit := node.choose(key) + bit := node.choose(ip) node = node.child[bit] } return found diff --git a/src/trie_test.go b/src/trie_test.go index 35af0aa..9d53df3 100644 --- a/src/trie_test.go +++ b/src/trie_test.go @@ -1,6 +1,8 @@ package main import ( + "math/rand" + "net" "testing" ) @@ -55,6 +57,49 @@ func TestCommonBits(t *testing.T) { } } +func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { + var trie *Trie + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 4 + + for n := 0; n < peerNumber; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < addressNumber; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % peerNumber + trie = trie.Insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < b.N; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + trie.Lookup(addr[:]) + } +} + +func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { + benchmarkTrie(100, 1000, net.IPv4len, b) +} + +func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { + benchmarkTrie(10, 10, net.IPv4len, b) +} + +func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { + benchmarkTrie(100, 1000, net.IPv6len, b) +} + +func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { + benchmarkTrie(10, 10, net.IPv6len, b) +} + /* Test ported from kernel implementation: * selftest/routingtable.h */ @@ -91,10 +136,10 @@ func TestTrieIPv4(t *testing.T) { insert(b, 192, 168, 4, 4, 32) insert(c, 192, 168, 0, 0, 16) insert(d, 192, 95, 5, 64, 27) - insert(c, 192, 95, 5, 65, 27) /* replaces previous entry, and maskself is required */ + insert(c, 192, 95, 5, 65, 27) insert(e, 0, 0, 0, 0, 0) insert(g, 64, 15, 112, 0, 20) - insert(h, 64, 15, 123, 211, 25) /* maskself is required */ + insert(h, 64, 15, 123, 211, 25) insert(a, 10, 0, 0, 0, 25) insert(b, 10, 0, 0, 128, 25) insert(a, 10, 1, 0, 0, 30) @@ -186,20 +231,6 @@ func TestTrieIPv6(t *testing.T) { } } - /* - assertNEQ := func(peer *Peer, a, b, c, d uint32) { - var addr []byte - addr = append(addr, expand(a)...) - addr = append(addr, expand(b)...) - addr = append(addr, expand(c)...) - addr = append(addr, expand(d)...) - p := trie.Lookup(addr) - if p == peer { - t.Error("Assert NEQ failed") - } - } - */ - insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) insert(c, 0x26075300, 0x60006b00, 0, 0, 64) insert(e, 0, 0, 0, 0, 0) diff --git a/src/tun.go b/src/tun.go new file mode 100644 index 0000000..1a8bb82 --- /dev/null +++ b/src/tun.go @@ -0,0 +1,8 @@ +package main + +type TUN interface { + Read([]byte) (int, error) + Write([]byte) (int, error) + Name() string + MTU() uint +} diff --git a/src/tun_linux.go b/src/tun_linux.go new file mode 100644 index 0000000..d545dfa --- /dev/null +++ b/src/tun_linux.go @@ -0,0 +1,80 @@ +package main + +import ( + "encoding/binary" + "errors" + "os" + "strings" + "syscall" + "unsafe" +) + +/* Platform dependent functions for interacting with + * TUN devices on linux systems + * + */ + +const CloneDevicePath = "/dev/net/tun" + +const ( + IFF_NO_PI = 0x1000 + IFF_TUN = 0x1 + IFNAMSIZ = 0x10 + TUNSETIFF = 0x400454CA +) + +type NativeTun struct { + fd *os.File + name string + mtu uint +} + +func (tun *NativeTun) Name() string { + return tun.name +} + +func (tun *NativeTun) MTU() uint { + return tun.mtu +} + +func (tun *NativeTun) Write(d []byte) (int, error) { + return tun.fd.Write(d) +} + +func (tun *NativeTun) Read(d []byte) (int, error) { + return tun.fd.Read(d) +} + +func CreateTUN(name string) (TUN, error) { + // Open clone device + fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0) + if err != nil { + return nil, err + } + + // Prepare ifreq struct + var ifr [18]byte + var flags uint16 = IFF_TUN | IFF_NO_PI + nameBytes := []byte(name) + if len(nameBytes) >= IFNAMSIZ { + return nil, errors.New("Name size too long") + } + copy(ifr[:], nameBytes) + binary.LittleEndian.PutUint16(ifr[16:], flags) + + // Create new device + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, + uintptr(fd.Fd()), uintptr(TUNSETIFF), + uintptr(unsafe.Pointer(&ifr[0]))) + if errno != 0 { + return nil, errors.New("Failed to create tun, ioctl call failed") + } + + // Read name of interface + newName := string(ifr[:]) + newName = newName[:strings.Index(newName, "\000")] + return &NativeTun{ + fd: fd, + name: newName, + }, nil +} |