summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/daemon_linux.go34
-rw-r--r--src/main.go73
-rw-r--r--src/uapi_linux.go83
3 files changed, 165 insertions, 25 deletions
diff --git a/src/daemon_linux.go b/src/daemon_linux.go
new file mode 100644
index 0000000..809c176
--- /dev/null
+++ b/src/daemon_linux.go
@@ -0,0 +1,34 @@
+package main
+
+import (
+ "os"
+)
+
+/* Daemonizes the process on linux
+ *
+ * This is done by spawning and releasing a copy with the --foreground flag
+ */
+
+func Daemonize() error {
+ argv := []string{os.Args[0], "--foreground"}
+ argv = append(argv, os.Args[1:]...)
+ attr := &os.ProcAttr{
+ Dir: ".",
+ Env: os.Environ(),
+ Files: []*os.File{
+ os.Stdin,
+ nil,
+ nil,
+ },
+ }
+ process, err := os.StartProcess(
+ argv[0],
+ argv,
+ attr,
+ )
+ if err != nil {
+ return err
+ }
+ process.Release()
+ return nil
+}
diff --git a/src/main.go b/src/main.go
index dc27472..74e7ec9 100644
--- a/src/main.go
+++ b/src/main.go
@@ -1,23 +1,45 @@
package main
import (
- "fmt"
"log"
- "net"
"os"
"runtime"
)
-/* TODO: Fix logging
- * TODO: Fix daemon
- */
-
func main() {
- if len(os.Args) != 2 {
+ // parse arguments
+
+ var foreground bool
+ var interfaceName string
+ if len(os.Args) < 2 || len(os.Args) > 3 {
+ return
+ }
+
+ switch os.Args[1] {
+ case "-f", "--foreground":
+ foreground = true
+ if len(os.Args) != 3 {
+ return
+ }
+ interfaceName = os.Args[2]
+ default:
+ foreground = false
+ if len(os.Args) != 2 {
+ return
+ }
+ interfaceName = os.Args[1]
+ }
+
+ // daemonize the process
+
+ if !foreground {
+ err := Daemonize()
+ if err != nil {
+ log.Println("Failed to daemonize:", err)
+ }
return
}
- deviceName := os.Args[1]
// increase number of go workers (for Go <1.5)
@@ -25,32 +47,33 @@ func main() {
// open TUN device
- tun, err := CreateTUN(deviceName)
+ tun, err := CreateTUN(interfaceName)
log.Println(tun, err)
if err != nil {
return
}
+ // create wireguard device
+
device := NewDevice(tun, LogLevelDebug)
- device.log.Info.Println("Starting device")
+
+ logInfo := device.log.Info
+ logError := device.log.Error
+ logInfo.Println("Starting device")
// start configuration lister
- go func() {
- socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName)
- l, err := net.Listen("unix", socketPath)
- if err != nil {
- log.Fatal("listen error:", err)
- }
+ uapi, err := NewUAPIListener(interfaceName)
+ if err != nil {
+ logError.Fatal("UAPI listen error:", err)
+ }
+ defer uapi.Close()
- for {
- conn, err := l.Accept()
- if err != nil {
- log.Fatal("accept error:", err)
- }
- go ipcHandle(device, conn)
+ for {
+ conn, err := uapi.Accept()
+ if err != nil {
+ logError.Fatal("accept error:", err)
}
- }()
-
- device.Wait()
+ go ipcHandle(device, conn)
+ }
}
diff --git a/src/uapi_linux.go b/src/uapi_linux.go
new file mode 100644
index 0000000..ee6ee0b
--- /dev/null
+++ b/src/uapi_linux.go
@@ -0,0 +1,83 @@
+package main
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "time"
+)
+
+/* TODO:
+ * This code can be improved by using fsnotify once:
+ * https://github.com/fsnotify/fsnotify/pull/205
+ * Is merged
+ */
+
+type UAPIListener struct {
+ listener net.Listener // unix socket listener
+ connNew chan net.Conn
+ connErr chan error
+}
+
+func (l *UAPIListener) Accept() (net.Conn, error) {
+ for {
+ select {
+ case conn := <-l.connNew:
+ return conn, nil
+
+ case err := <-l.connErr:
+ return nil, err
+ }
+ }
+}
+
+func (l *UAPIListener) Close() error {
+ return l.listener.Close()
+}
+
+func (l *UAPIListener) Addr() net.Addr {
+ return nil
+}
+
+func NewUAPIListener(name string) (net.Listener, error) {
+
+ // open UNIX socket
+
+ socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", name)
+ listener, err := net.Listen("unix", socketPath)
+ if err != nil {
+ return nil, err
+ }
+
+ uapi := &UAPIListener{
+ listener: listener,
+ connNew: make(chan net.Conn, 1),
+ connErr: make(chan error, 1),
+ }
+
+ // watch for deletion of socket
+
+ go func(l *UAPIListener) {
+ for ; ; time.Sleep(time.Second) {
+ if _, err := os.Stat(socketPath); os.IsNotExist(err) {
+ l.connErr <- err
+ return
+ }
+ }
+ }(uapi)
+
+ // watch for new connections
+
+ go func(l *UAPIListener) {
+ for {
+ conn, err := l.listener.Accept()
+ if err != nil {
+ l.connErr <- err
+ break
+ }
+ l.connNew <- conn
+ }
+ }(uapi)
+
+ return uapi, nil
+}