diff options
-rw-r--r-- | src/main.go | 59 | ||||
-rw-r--r-- | src/tun_linux.go | 2 | ||||
-rw-r--r-- | src/uapi_linux.go | 117 |
3 files changed, 111 insertions, 67 deletions
diff --git a/src/main.go b/src/main.go index 3808c9c..7d86716 100644 --- a/src/main.go +++ b/src/main.go @@ -9,7 +9,8 @@ import ( ) const ( - EnvWGTunFD = "WG_TUN_FD" + ENV_WG_TUN_FD = "WG_TUN_FD" + ENV_WG_UAPI_FD = "WG_UAPI_FD" ) func printUsage() { @@ -65,46 +66,69 @@ func main() { logLevel, fmt.Sprintf("(%s) ", interfaceName), ) + logger.Debug.Println("Debug log enabled") - // open TUN device + // open TUN device (or use supplied fd) tun, err := func() (TUNDevice, error) { - tunFdStr := os.Getenv(EnvWGTunFD) + tunFdStr := os.Getenv(ENV_WG_TUN_FD) if tunFdStr == "" { return CreateTUN(interfaceName) } - // construct tun device from supplied FD + // construct tun device from supplied fd fd, err := strconv.ParseUint(tunFdStr, 10, 32) if err != nil { return nil, err } - file := os.NewFile(uintptr(fd), "/dev/net/tun") + file := os.NewFile(uintptr(fd), "") return CreateTUNFromFile(interfaceName, file) }() if err != nil { logger.Error.Println("Failed to create TUN device:", err) + os.Exit(ExitSetupFailed) } + // open UAPI file (or use supplied fd) + + fileUAPI, err := func() (*os.File, error) { + uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) + if uapiFdStr == "" { + return UAPIOpen(interfaceName) + } + + // use supplied fd + + fd, err := strconv.ParseUint(uapiFdStr, 10, 32) + if err != nil { + return nil, err + } + + return os.NewFile(uintptr(fd), ""), nil + }() + + if err != nil { + logger.Error.Println("UAPI listen error:", err) + os.Exit(ExitSetupFailed) + return + } // daemonize the process if !foreground { env := os.Environ() - _, ok := os.LookupEnv(EnvWGTunFD) - if !ok { - kvp := fmt.Sprintf("%s=3", EnvWGTunFD) - env = append(env, kvp) - } + env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD)) + env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD)) attr := &os.ProcAttr{ Files: []*os.File{ nil, // stdin nil, // stdout nil, // stderr tun.File(), + fileUAPI, }, Dir: ".", Env: env, @@ -112,6 +136,7 @@ func main() { err = Daemonize(attr) if err != nil { logger.Error.Println("Failed to daemonize:", err) + os.Exit(ExitSetupFailed) } return } @@ -123,20 +148,17 @@ func main() { // create wireguard device device := NewDevice(tun, logger) + logger.Info.Println("Device started") - // start configuration lister - - uapi, err := NewUAPIListener(interfaceName) - if err != nil { - logger.Error.Println("UAPI listen error:", err) - return - } + // start uapi listener errs := make(chan error) term := make(chan os.Signal) wait := device.WaitChannel() + uapi, err := UAPIListen(interfaceName, fileUAPI) + go func() { for { conn, err := uapi.Accept() @@ -161,9 +183,10 @@ func main() { case <-errs: } - // clean up UAPI bind + // clean up uapi.Close() + device.Close() logger.Info.Println("Shutting down") } diff --git a/src/tun_linux.go b/src/tun_linux.go index ce6304c..2a5b276 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -227,7 +227,7 @@ func (tun *NativeTun) MTU() (int, error) { val := binary.LittleEndian.Uint32(ifr[16:20]) if val >= (1 << 31) { - return int(val-(1<<31)) - (1 << 31), nil + return int(toInt32(val)), nil } return int(val), nil } diff --git a/src/uapi_linux.go b/src/uapi_linux.go index cb9d858..f97a18a 100644 --- a/src/uapi_linux.go +++ b/src/uapi_linux.go @@ -10,12 +10,12 @@ import ( ) const ( - ipcErrorIO = -int64(unix.EIO) - ipcErrorProtocol = -int64(unix.EPROTO) - ipcErrorInvalid = -int64(unix.EINVAL) - ipcErrorPortInUse = -int64(unix.EADDRINUSE) - socketDirectory = "/var/run/wireguard" - socketName = "%s.sock" + ipcErrorIO = -int64(unix.EIO) + ipcErrorProtocol = -int64(unix.EPROTO) + ipcErrorInvalid = -int64(unix.EINVAL) + ipcErrorPortInUse = -int64(unix.EADDRINUSE) + socketDirectory = "/var/run/wireguard" + socketName = "%s.sock" ) type UAPIListener struct { @@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr { return nil } -func connectUnixSocket(path string) (net.Listener, error) { +func UAPIListen(name string, file *os.File) (net.Listener, error) { - // attempt inital connection + // wrap file in listener - listener, err := net.Listen("unix", path) - if err == nil { - return listener, nil - } - - // check if active - - _, err = net.Dial("unix", path) - if err == nil { - return nil, errors.New("Unix socket in use") - } - - // attempt cleanup - - err = os.Remove(path) - if err != nil { - return nil, err - } - - return net.Listen("unix", path) -} - -func NewUAPIListener(name string) (net.Listener, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 077) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - listener, err := connectUnixSocket(socketPath) + listener, err := net.FileListener(file) if err != nil { return nil, err } @@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) { // watch for deletion of socket + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + uapi.inotifyFd, err = unix.InotifyInit() if err != nil { return nil, err @@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) { go func(l *UAPIListener) { var buff [4096]byte for { - unix.Read(uapi.inotifyFd, buff[:]) + // start with lstat to avoid race condition if _, err := os.Lstat(socketPath); os.IsNotExist(err) { l.connErr <- err return } + unix.Read(uapi.inotifyFd, buff[:]) } }(uapi) @@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) { return uapi, nil } + +func UAPIOpen(name string) (*os.File, error) { + + // check if path exist + + err := os.MkdirAll(socketDirectory, 0600) + if err != nil && !os.IsExist(err) { + return nil, err + } + + // open UNIX socket + + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + + addr, err := net.ResolveUnixAddr("unix", socketPath) + if err != nil { + return nil, err + } + + listener, err := func() (*net.UnixListener, error) { + + // initial connection attempt + + listener, err := net.ListenUnix("unix", addr) + if err == nil { + return listener, nil + } + + // check if socket already active + + _, err = net.Dial("unix", socketPath) + if err == nil { + return nil, errors.New("unix socket in use") + } + + // cleanup & attempt again + + err = os.Remove(socketPath) + if err != nil { + return nil, err + } + return net.ListenUnix("unix", addr) + }() + + if err != nil { + return nil, err + } + + return listener.File() +} |