diff options
-rw-r--r-- | src/daemon_linux.go | 34 | ||||
-rw-r--r-- | src/main.go | 73 | ||||
-rw-r--r-- | src/uapi_linux.go | 83 |
3 files changed, 165 insertions, 25 deletions
diff --git a/src/daemon_linux.go b/src/daemon_linux.go new file mode 100644 index 0000000..809c176 --- /dev/null +++ b/src/daemon_linux.go @@ -0,0 +1,34 @@ +package main + +import ( + "os" +) + +/* Daemonizes the process on linux + * + * This is done by spawning and releasing a copy with the --foreground flag + */ + +func Daemonize() error { + argv := []string{os.Args[0], "--foreground"} + argv = append(argv, os.Args[1:]...) + attr := &os.ProcAttr{ + Dir: ".", + Env: os.Environ(), + Files: []*os.File{ + os.Stdin, + nil, + nil, + }, + } + process, err := os.StartProcess( + argv[0], + argv, + attr, + ) + if err != nil { + return err + } + process.Release() + return nil +} diff --git a/src/main.go b/src/main.go index dc27472..74e7ec9 100644 --- a/src/main.go +++ b/src/main.go @@ -1,23 +1,45 @@ package main import ( - "fmt" "log" - "net" "os" "runtime" ) -/* TODO: Fix logging - * TODO: Fix daemon - */ - func main() { - if len(os.Args) != 2 { + // parse arguments + + var foreground bool + var interfaceName string + if len(os.Args) < 2 || len(os.Args) > 3 { + return + } + + switch os.Args[1] { + case "-f", "--foreground": + foreground = true + if len(os.Args) != 3 { + return + } + interfaceName = os.Args[2] + default: + foreground = false + if len(os.Args) != 2 { + return + } + interfaceName = os.Args[1] + } + + // daemonize the process + + if !foreground { + err := Daemonize() + if err != nil { + log.Println("Failed to daemonize:", err) + } return } - deviceName := os.Args[1] // increase number of go workers (for Go <1.5) @@ -25,32 +47,33 @@ func main() { // open TUN device - tun, err := CreateTUN(deviceName) + tun, err := CreateTUN(interfaceName) log.Println(tun, err) if err != nil { return } + // create wireguard device + device := NewDevice(tun, LogLevelDebug) - device.log.Info.Println("Starting device") + + logInfo := device.log.Info + logError := device.log.Error + logInfo.Println("Starting device") // start configuration lister - go func() { - socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName) - l, err := net.Listen("unix", socketPath) - if err != nil { - log.Fatal("listen error:", err) - } + uapi, err := NewUAPIListener(interfaceName) + if err != nil { + logError.Fatal("UAPI listen error:", err) + } + defer uapi.Close() - for { - conn, err := l.Accept() - if err != nil { - log.Fatal("accept error:", err) - } - go ipcHandle(device, conn) + for { + conn, err := uapi.Accept() + if err != nil { + logError.Fatal("accept error:", err) } - }() - - device.Wait() + go ipcHandle(device, conn) + } } diff --git a/src/uapi_linux.go b/src/uapi_linux.go new file mode 100644 index 0000000..ee6ee0b --- /dev/null +++ b/src/uapi_linux.go @@ -0,0 +1,83 @@ +package main + +import ( + "fmt" + "net" + "os" + "time" +) + +/* TODO: + * This code can be improved by using fsnotify once: + * https://github.com/fsnotify/fsnotify/pull/205 + * Is merged + */ + +type UAPIListener struct { + listener net.Listener // unix socket listener + connNew chan net.Conn + connErr chan error +} + +func (l *UAPIListener) Accept() (net.Conn, error) { + for { + select { + case conn := <-l.connNew: + return conn, nil + + case err := <-l.connErr: + return nil, err + } + } +} + +func (l *UAPIListener) Close() error { + return l.listener.Close() +} + +func (l *UAPIListener) Addr() net.Addr { + return nil +} + +func NewUAPIListener(name string) (net.Listener, error) { + + // open UNIX socket + + socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", name) + listener, err := net.Listen("unix", socketPath) + if err != nil { + return nil, err + } + + uapi := &UAPIListener{ + listener: listener, + connNew: make(chan net.Conn, 1), + connErr: make(chan error, 1), + } + + // watch for deletion of socket + + go func(l *UAPIListener) { + for ; ; time.Sleep(time.Second) { + if _, err := os.Stat(socketPath); os.IsNotExist(err) { + l.connErr <- err + return + } + } + }(uapi) + + // watch for new connections + + go func(l *UAPIListener) { + for { + conn, err := l.listener.Accept() + if err != nil { + l.connErr <- err + break + } + l.connNew <- conn + } + }(uapi) + + return uapi, nil +} |