diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/config.go | 132 | ||||
-rw-r--r-- | src/device.go | 8 | ||||
-rw-r--r-- | src/main.go | 25 | ||||
-rw-r--r-- | src/routing.go | 6 | ||||
-rw-r--r-- | src/send.go | 8 | ||||
-rw-r--r-- | src/trie.go | 10 |
6 files changed, 109 insertions, 80 deletions
diff --git a/src/config.go b/src/config.go index 3b91d00..2f8dc76 100644 --- a/src/config.go +++ b/src/config.go @@ -5,24 +5,22 @@ import ( "errors" "fmt" "io" - "log" "net" "strconv" + "strings" "time" ) -/* TODO : use real error code - * Many of which will be the same +// #include <errno.h> +import "C" + +/* TODO: More fine grained? */ const ( - ipcErrorNoPeer = 0 - ipcErrorNoKeyValue = 1 - ipcErrorInvalidKey = 2 - ipcErrorInvalidValue = 2 - ipcErrorInvalidPrivateKey = 3 - ipcErrorInvalidPublicKey = 4 - ipcErrorInvalidPort = 5 - ipcErrorInvalidIPAddress = 6 + ipcErrorNoPeer = C.EPROTO + ipcErrorNoKeyValue = C.EPROTO + ipcErrorInvalidKey = C.EPROTO + ipcErrorInvalidValue = C.EPROTO ) type IPCError struct { @@ -78,7 +76,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { // send lines for _, line := range lines { - device.log.Debug.Println("config:", line) + device.log.Debug.Println("Response:", line) _, err := socket.WriteString(line + "\n") if err != nil { return err @@ -89,29 +87,26 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { } func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { - + logger := device.log.Debug scanner := bufio.NewScanner(socket) - device.mutex.Lock() - defer device.mutex.Unlock() - + var peer *Peer for scanner.Scan() { - var key string - var value string - var peer *Peer // Parse line line := scanner.Text() - if line == "\n" { - break + if line == "" { + return nil } - fmt.Println(line) - n, err := fmt.Sscanf(line, "%s=%s\n", &key, &value) - if n != 2 || err != nil { - fmt.Println(err, n) + parts := strings.Split(line, "=") + if len(parts) != 2 { + device.log.Debug.Println(parts) return &IPCError{Code: ipcErrorNoKeyValue} } + key := parts[0] + value := parts[1] + logger.Println("Key-value pair: (", key, ",", value, ")") // TODO: Remove, leaks private key to log switch key { @@ -119,41 +114,60 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "private_key": if value == "" { + device.mutex.Lock() device.privateKey = NoisePrivateKey{} + device.mutex.Unlock() } else { + device.mutex.Lock() err := device.privateKey.FromHex(value) + device.mutex.Unlock() if err != nil { - return &IPCError{Code: ipcErrorInvalidPrivateKey} + logger.Println("Failed to set private_key:", err) + return &IPCError{Code: ipcErrorInvalidValue} } } case "listen_port": - _, err := fmt.Sscanf(value, "%ud", &device.address.Port) - if err != nil { - return &IPCError{Code: ipcErrorInvalidPort} + var port int + _, err := fmt.Sscanf(value, "%d", &port) + if err != nil || port > (1<<16) || port < 0 { + logger.Println("Failed to set listen_port:", err) + return &IPCError{Code: ipcErrorInvalidValue} } + device.mutex.Lock() + if device.address == nil { + device.address = &net.UDPAddr{} + } + device.address.Port = port + device.mutex.Unlock() case "fwmark": - panic(nil) // not handled yet + logger.Println("FWMark not handled yet") case "public_key": var pubKey NoisePublicKey err := pubKey.FromHex(value) if err != nil { - return &IPCError{Code: ipcErrorInvalidPublicKey} + logger.Println("Failed to get peer by public_key:", err) + return &IPCError{Code: ipcErrorInvalidValue} } + device.mutex.RLock() found, ok := device.peers[pubKey] + device.mutex.RUnlock() if ok { peer = found } else { peer = device.NewPeer(pubKey) } + if peer == nil { + panic(errors.New("bug: failed to find peer")) + } case "replace_peers": - if key == "true" { + if value == "true" { device.RemoveAllPeers() - } else if key == "false" { } else { + logger.Println("Failed to set replace_peers, invalid value:", value) return &IPCError{Code: ipcErrorInvalidValue} } @@ -161,6 +175,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { /* Peer configuration */ if peer == nil { + logger.Println("No peer referenced, before peer operation") return &IPCError{Code: ipcErrorNoPeer} } @@ -168,7 +183,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "remove": peer.mutex.Lock() - // device.RemovePeer(peer.publicKey) + device.RemovePeer(peer.handshake.remoteStatic) + peer.mutex.Unlock() + logger.Println("Remove peer") peer = nil case "preshared_key": @@ -178,13 +195,15 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return peer.handshake.presharedKey.FromHex(value) }() if err != nil { - return &IPCError{Code: ipcErrorInvalidPublicKey} + logger.Println("Failed to set preshared_key:", err) + return &IPCError{Code: ipcErrorInvalidValue} } case "endpoint": ip := net.ParseIP(value) if ip == nil { - return &IPCError{Code: ipcErrorInvalidIPAddress} + logger.Println("Failed to set endpoint:", value) + return &IPCError{Code: ipcErrorInvalidValue} } peer.mutex.Lock() // peer.endpoint = ip FIX @@ -193,6 +212,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "persistent_keepalive_interval": secs, err := strconv.ParseInt(value, 10, 64) if secs < 0 || err != nil { + logger.Println("Failed to set persistent_keepalive_interval:", err) return &IPCError{Code: ipcErrorInvalidValue} } peer.mutex.Lock() @@ -200,24 +220,27 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { peer.mutex.Unlock() case "replace_allowed_ips": - if key == "true" { + if value == "true" { device.routingTable.RemovePeer(peer) - } else if key == "false" { } else { + logger.Println("Failed to set replace_allowed_ips, invalid value:", value) return &IPCError{Code: ipcErrorInvalidValue} } case "allowed_ip": _, network, err := net.ParseCIDR(value) if err != nil { + logger.Println("Failed to set allowed_ip:", err) return &IPCError{Code: ipcErrorInvalidValue} } ones, _ := network.Mask.Size() + logger.Println(network, ones, network.IP) device.routingTable.Insert(network.IP, uint(ones), peer) /* Invalid key */ default: + logger.Println("Invalid key:", key) return &IPCError{Code: ipcErrorInvalidKey} } } @@ -226,49 +249,48 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return nil } -func ipcListen(device *Device, socket io.ReadWriter) error { +func ipcHandle(device *Device, socket net.Conn) { - buffered := func(s io.ReadWriter) *bufio.ReadWriter { - reader := bufio.NewReader(s) - writer := bufio.NewWriter(s) - return bufio.NewReadWriter(reader, writer) - }(socket) + func() { + buffered := func(s io.ReadWriter) *bufio.ReadWriter { + reader := bufio.NewReader(s) + writer := bufio.NewWriter(s) + return bufio.NewReadWriter(reader, writer) + }(socket) - defer buffered.Flush() + defer buffered.Flush() - for { op, err := buffered.ReadString('\n') if err != nil { - return err + return } - log.Println(op) switch op { case "set=1\n": + device.log.Debug.Println("Config, set operation") err := ipcSetOperation(device, buffered) if err != nil { fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) - return err } else { fmt.Fprintf(buffered, "errno=0\n\n") } - buffered.Flush() + break case "get=1\n": + device.log.Debug.Println("Config, get operation") err := ipcGetOperation(device, buffered) if err != nil { fmt.Fprintf(buffered, "errno=1\n\n") // fix - return err } else { fmt.Fprintf(buffered, "errno=0\n\n") } - buffered.Flush() + break - case "\n": default: - return errors.New("handle this please") + device.log.Info.Println("Invalid UAPI operation:", op) } - } + }() + socket.Close() } diff --git a/src/device.go b/src/device.go index a7a5c7b..52ac6a4 100644 --- a/src/device.go +++ b/src/device.go @@ -81,10 +81,7 @@ func (device *Device) RemovePeer(key NoisePublicKey) { peer.mutex.Lock() device.routingTable.RemovePeer(peer) delete(device.peers, key) -} - -func (device *Device) RemoveAllAllowedIps(peer *Peer) { - + peer.Close() } func (device *Device) RemoveAllPeers() { @@ -93,8 +90,7 @@ func (device *Device) RemoveAllPeers() { for key, peer := range device.peers { peer.mutex.Lock() - device.routingTable.RemovePeer(peer) delete(device.peers, key) - peer.mutex.Unlock() + peer.Close() } } diff --git a/src/main.go b/src/main.go index 7c58972..9c76ff4 100644 --- a/src/main.go +++ b/src/main.go @@ -1,21 +1,28 @@ package main import ( + "fmt" "log" "net" + "os" ) -/* - * - * TODO: Fix logging +/* TODO: Fix logging + * TODO: Fix daemon */ func main() { + + if len(os.Args) != 2 { + return + } + deviceName := os.Args[1] + // Open TUN device // TODO: Fix capabilities - tun, err := CreateTUN("test0") + tun, err := CreateTUN(deviceName) log.Println(tun, err) if err != nil { return @@ -25,19 +32,17 @@ func main() { // Start configuration lister - l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock") + socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName) + l, err := net.Listen("unix", socketPath) if err != nil { log.Fatal("listen error:", err) } for { - fd, err := l.Accept() + conn, err := l.Accept() if err != nil { log.Fatal("accept error:", err) } - go func(conn net.Conn) { - err := ipcListen(device, conn) - log.Println(err) - }(fd) + go ipcHandle(device, conn) } } diff --git a/src/routing.go b/src/routing.go index 6a5e1f3..2a2e237 100644 --- a/src/routing.go +++ b/src/routing.go @@ -16,9 +16,9 @@ func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet { table.mutex.RLock() defer table.mutex.RUnlock() - allowed := make([]net.IPNet, 10) - table.IPv4.AllowedIPs(peer, allowed) - table.IPv6.AllowedIPs(peer, allowed) + allowed := make([]net.IPNet, 0, 10) + allowed = table.IPv4.AllowedIPs(peer, allowed) + allowed = table.IPv6.AllowedIPs(peer, allowed) return allowed } diff --git a/src/send.go b/src/send.go index 4ff75db..ab75750 100644 --- a/src/send.go +++ b/src/send.go @@ -61,9 +61,11 @@ func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) { * Obs. Single instance per TUN device */ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { + device.log.Debug.Println("Routine, TUN Reader: started") for { // read packet + device.log.Debug.Println("Read") packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation size, err := tun.Read(packet) if err != nil { @@ -76,8 +78,6 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { continue } - device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove. - // lookup peer var peer *Peer @@ -85,10 +85,12 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { case IPv4version: dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] peer = device.routingTable.LookupIPv4(dst) + device.log.Debug.Println("New IPv4 packet:", packet, dst) case IPv6version: dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] peer = device.routingTable.LookupIPv6(dst) + device.log.Debug.Println("New IPv6 packet:", packet, dst) default: device.log.Debug.Println("Receieved packet with unknown IP version") @@ -97,7 +99,7 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { if peer == nil { device.log.Debug.Println("No peer configured for IP") - return + continue } // insert into nonce/pre-handshake queue diff --git a/src/trie.go b/src/trie.go index 4049167..c2304b2 100644 --- a/src/trie.go +++ b/src/trie.go @@ -195,7 +195,10 @@ func (node *Trie) Count() uint { return l + r } -func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) { +func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { + if node == nil { + return results + } if node.peer == p { var mask net.IPNet mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) @@ -213,6 +216,7 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) { } results = append(results, mask) } - node.child[0].AllowedIPs(p, results) - node.child[1].AllowedIPs(p, results) + results = node.child[0].AllowedIPs(p, results) + results = node.child[1].AllowedIPs(p, results) + return results } |