diff options
Diffstat (limited to 'src/config.go')
-rw-r--r-- | src/config.go | 160 |
1 files changed, 83 insertions, 77 deletions
diff --git a/src/config.go b/src/config.go index 4edaa2e..d92e8d7 100644 --- a/src/config.go +++ b/src/config.go @@ -8,39 +8,36 @@ import ( "net" "strconv" "strings" + "sync/atomic" + "syscall" ) -// #include <errno.h> -import "C" - -/* TODO: More fine grained? - */ const ( - ipcErrorNoPeer = C.EPROTO - ipcErrorNoKeyValue = C.EPROTO - ipcErrorInvalidKey = C.EPROTO - ipcErrorInvalidValue = C.EPROTO + ipcErrorIO = syscall.EIO + ipcErrorNoPeer = syscall.EPROTO + ipcErrorNoKeyValue = syscall.EPROTO + ipcErrorInvalidKey = syscall.EPROTO + ipcErrorInvalidValue = syscall.EPROTO ) type IPCError struct { - Code int + Code syscall.Errno } func (s *IPCError) Error() string { return fmt.Sprintf("IPC error: %d", s.Code) } -func (s *IPCError) ErrorCode() int { - return s.Code +func (s *IPCError) ErrorCode() uintptr { + return uintptr(s.Code) } -func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { - - device.mutex.RLock() - defer device.mutex.RUnlock() +func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { // create lines + device.mutex.RLock() + lines := make([]string, 0, 100) send := func(line string) { lines = append(lines, line) @@ -63,19 +60,25 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { } send(fmt.Sprintf("tx_bytes=%d", peer.txBytes)) send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes)) - send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) + send(fmt.Sprintf("persistent_keepalive_interval=%d", + atomic.LoadUint64(&peer.persistentKeepaliveInterval), + )) for _, ip := range device.routingTable.AllowedIPs(peer) { send("allowed_ip=" + ip.String()) } }() } + device.mutex.RUnlock() + // send lines for _, line := range lines { _, err := socket.WriteString(line + "\n") if err != nil { - return err + return &IPCError{ + Code: ipcErrorIO, + } } } @@ -83,13 +86,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { } func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { - logger := device.log.Debug scanner := bufio.NewScanner(socket) + logError := device.log.Error + logDebug := device.log.Debug var peer *Peer for scanner.Scan() { - // Parse line + // parse line line := scanner.Text() if line == "" { @@ -97,7 +101,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } parts := strings.Split(line, "=") if len(parts) != 2 { - device.log.Debug.Println(parts) return &IPCError{Code: ipcErrorNoKeyValue} } key := parts[0] @@ -105,7 +108,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { switch key { - /* Interface configuration */ + /* interface configuration */ case "private_key": if value == "" { @@ -116,7 +119,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { var sk NoisePrivateKey err := sk.FromHex(value) if err != nil { - logger.Println("Failed to set private_key:", err) + logError.Println("Failed to set private_key:", err) return &IPCError{Code: ipcErrorInvalidValue} } device.SetPrivateKey(sk) @@ -126,22 +129,26 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { var port int _, err := fmt.Sscanf(value, "%d", &port) if err != nil || port > (1<<16) || port < 0 { - logger.Println("Failed to set listen_port:", err) + logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorInvalidValue} } device.net.mutex.Lock() device.net.addr.Port = port device.net.conn, err = net.ListenUDP("udp", device.net.addr) device.net.mutex.Unlock() + if err != nil { + logError.Println("Failed to create UDP listener:", err) + return &IPCError{Code: ipcErrorInvalidValue} + } case "fwmark": - logger.Println("FWMark not handled yet") + logError.Println("FWMark not handled yet") case "public_key": var pubKey NoisePublicKey err := pubKey.FromHex(value) if err != nil { - logger.Println("Failed to get peer by public_key:", err) + logError.Println("Failed to get peer by public_key:", err) return &IPCError{Code: ipcErrorInvalidValue} } device.mutex.RLock() @@ -153,22 +160,23 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { peer = device.NewPeer(pubKey) } if peer == nil { - panic(errors.New("bug: failed to find peer")) + panic(errors.New("bug: failed to find / create peer")) } case "replace_peers": if value == "true" { device.RemoveAllPeers() } else { - logger.Println("Failed to set replace_peers, invalid value:", value) + logError.Println("Failed to set replace_peers, invalid value:", value) return &IPCError{Code: ipcErrorInvalidValue} } default: - /* Peer configuration */ + + /* peer configuration */ if peer == nil { - logger.Println("No peer referenced, before peer operation") + logError.Println("No peer referenced, before peer operation") return &IPCError{Code: ipcErrorNoPeer} } @@ -178,7 +186,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { peer.mutex.Lock() device.RemovePeer(peer.handshake.remoteStatic) peer.mutex.Unlock() - logger.Println("Remove peer") + logDebug.Println("Removing", peer.String()) peer = nil case "preshared_key": @@ -188,14 +196,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return peer.handshake.presharedKey.FromHex(value) }() if err != nil { - logger.Println("Failed to set preshared_key:", err) + logError.Println("Failed to set preshared_key:", err) return &IPCError{Code: ipcErrorInvalidValue} } case "endpoint": addr, err := net.ResolveUDPAddr("udp", value) if err != nil { - logger.Println("Failed to set endpoint:", value) + logError.Println("Failed to set endpoint:", value) return &IPCError{Code: ipcErrorInvalidValue} } peer.mutex.Lock() @@ -205,35 +213,34 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "persistent_keepalive_interval": secs, err := strconv.ParseInt(value, 10, 64) if secs < 0 || err != nil { - logger.Println("Failed to set persistent_keepalive_interval:", err) + logError.Println("Failed to set persistent_keepalive_interval:", err) return &IPCError{Code: ipcErrorInvalidValue} } - peer.mutex.Lock() - peer.persistentKeepaliveInterval = uint64(secs) - peer.mutex.Unlock() + atomic.StoreUint64( + &peer.persistentKeepaliveInterval, + uint64(secs), + ) case "replace_allowed_ips": if value == "true" { device.routingTable.RemovePeer(peer) } else { - logger.Println("Failed to set replace_allowed_ips, invalid value:", value) + logError.Println("Failed to set replace_allowed_ips, invalid value:", value) return &IPCError{Code: ipcErrorInvalidValue} } case "allowed_ip": _, network, err := net.ParseCIDR(value) if err != nil { - logger.Println("Failed to set allowed_ip:", err) + logError.Println("Failed to set allowed_ip:", err) return &IPCError{Code: ipcErrorInvalidValue} } ones, _ := network.Mask.Size() - logger.Println(network, ones, network.IP) + logError.Println(network, ones, network.IP) device.routingTable.Insert(network.IP, uint(ones), peer) - /* Invalid key */ - default: - logger.Println("Invalid key:", key) + logError.Println("Invalid UAPI key:", key) return &IPCError{Code: ipcErrorInvalidKey} } } @@ -244,46 +251,45 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { func ipcHandle(device *Device, socket net.Conn) { - func() { - buffered := func(s io.ReadWriter) *bufio.ReadWriter { - reader := bufio.NewReader(s) - writer := bufio.NewWriter(s) - return bufio.NewReadWriter(reader, writer) - }(socket) + defer socket.Close() - defer buffered.Flush() + buffered := func(s io.ReadWriter) *bufio.ReadWriter { + reader := bufio.NewReader(s) + writer := bufio.NewWriter(s) + return bufio.NewReadWriter(reader, writer) + }(socket) - op, err := buffered.ReadString('\n') - if err != nil { - return - } + defer buffered.Flush() - switch op { + op, err := buffered.ReadString('\n') + if err != nil { + return + } - case "set=1\n": - device.log.Debug.Println("Config, set operation") - err := ipcSetOperation(device, buffered) - if err != nil { - fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } - break + switch op { - case "get=1\n": - device.log.Debug.Println("Config, get operation") - err := ipcGetOperation(device, buffered) - if err != nil { - fmt.Fprintf(buffered, "errno=1\n\n") // fix - } else { - fmt.Fprintf(buffered, "errno=0\n\n") - } - break + case "set=1\n": + device.log.Debug.Println("Config, set operation") + err := ipcSetOperation(device, buffered) + if err != nil { + fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") + } + return - default: - device.log.Info.Println("Invalid UAPI operation:", op) + case "get=1\n": + device.log.Debug.Println("Config, get operation") + err := ipcGetOperation(device, buffered) + if err != nil { + fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") } - }() + return + + default: + device.log.Error.Println("Invalid UAPI operation:", op) - socket.Close() + } } |