summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-11-17 14:36:08 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-11-17 14:36:08 +0100
commite1227d3af480eae72639cde842b4d538c58936dc (patch)
tree0263f8f1ecee9da28a0d5e951ad520972d8db504 /src
parent88801529fd4097993f7c448b1c3eee0abc8cb51c (diff)
Allows passing UAPI fd to service
Diffstat (limited to 'src')
-rw-r--r--src/main.go59
-rw-r--r--src/tun_linux.go2
-rw-r--r--src/uapi_linux.go117
3 files changed, 111 insertions, 67 deletions
diff --git a/src/main.go b/src/main.go
index 3808c9c..7d86716 100644
--- a/src/main.go
+++ b/src/main.go
@@ -9,7 +9,8 @@ import (
)
const (
- EnvWGTunFD = "WG_TUN_FD"
+ ENV_WG_TUN_FD = "WG_TUN_FD"
+ ENV_WG_UAPI_FD = "WG_UAPI_FD"
)
func printUsage() {
@@ -65,46 +66,69 @@ func main() {
logLevel,
fmt.Sprintf("(%s) ", interfaceName),
)
+
logger.Debug.Println("Debug log enabled")
- // open TUN device
+ // open TUN device (or use supplied fd)
tun, err := func() (TUNDevice, error) {
- tunFdStr := os.Getenv(EnvWGTunFD)
+ tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" {
return CreateTUN(interfaceName)
}
- // construct tun device from supplied FD
+ // construct tun device from supplied fd
fd, err := strconv.ParseUint(tunFdStr, 10, 32)
if err != nil {
return nil, err
}
- file := os.NewFile(uintptr(fd), "/dev/net/tun")
+ file := os.NewFile(uintptr(fd), "")
return CreateTUNFromFile(interfaceName, file)
}()
if err != nil {
logger.Error.Println("Failed to create TUN device:", err)
+ os.Exit(ExitSetupFailed)
}
+ // open UAPI file (or use supplied fd)
+
+ fileUAPI, err := func() (*os.File, error) {
+ uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
+ if uapiFdStr == "" {
+ return UAPIOpen(interfaceName)
+ }
+
+ // use supplied fd
+
+ fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
+ if err != nil {
+ return nil, err
+ }
+
+ return os.NewFile(uintptr(fd), ""), nil
+ }()
+
+ if err != nil {
+ logger.Error.Println("UAPI listen error:", err)
+ os.Exit(ExitSetupFailed)
+ return
+ }
// daemonize the process
if !foreground {
env := os.Environ()
- _, ok := os.LookupEnv(EnvWGTunFD)
- if !ok {
- kvp := fmt.Sprintf("%s=3", EnvWGTunFD)
- env = append(env, kvp)
- }
+ env = append(env, fmt.Sprintf("%s=3", ENV_WG_TUN_FD))
+ env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
attr := &os.ProcAttr{
Files: []*os.File{
nil, // stdin
nil, // stdout
nil, // stderr
tun.File(),
+ fileUAPI,
},
Dir: ".",
Env: env,
@@ -112,6 +136,7 @@ func main() {
err = Daemonize(attr)
if err != nil {
logger.Error.Println("Failed to daemonize:", err)
+ os.Exit(ExitSetupFailed)
}
return
}
@@ -123,20 +148,17 @@ func main() {
// create wireguard device
device := NewDevice(tun, logger)
+
logger.Info.Println("Device started")
- // start configuration lister
-
- uapi, err := NewUAPIListener(interfaceName)
- if err != nil {
- logger.Error.Println("UAPI listen error:", err)
- return
- }
+ // start uapi listener
errs := make(chan error)
term := make(chan os.Signal)
wait := device.WaitChannel()
+ uapi, err := UAPIListen(interfaceName, fileUAPI)
+
go func() {
for {
conn, err := uapi.Accept()
@@ -161,9 +183,10 @@ func main() {
case <-errs:
}
- // clean up UAPI bind
+ // clean up
uapi.Close()
+ device.Close()
logger.Info.Println("Shutting down")
}
diff --git a/src/tun_linux.go b/src/tun_linux.go
index ce6304c..2a5b276 100644
--- a/src/tun_linux.go
+++ b/src/tun_linux.go
@@ -227,7 +227,7 @@ func (tun *NativeTun) MTU() (int, error) {
val := binary.LittleEndian.Uint32(ifr[16:20])
if val >= (1 << 31) {
- return int(val-(1<<31)) - (1 << 31), nil
+ return int(toInt32(val)), nil
}
return int(val), nil
}
diff --git a/src/uapi_linux.go b/src/uapi_linux.go
index cb9d858..f97a18a 100644
--- a/src/uapi_linux.go
+++ b/src/uapi_linux.go
@@ -10,12 +10,12 @@ import (
)
const (
- ipcErrorIO = -int64(unix.EIO)
- ipcErrorProtocol = -int64(unix.EPROTO)
- ipcErrorInvalid = -int64(unix.EINVAL)
- ipcErrorPortInUse = -int64(unix.EADDRINUSE)
- socketDirectory = "/var/run/wireguard"
- socketName = "%s.sock"
+ ipcErrorIO = -int64(unix.EIO)
+ ipcErrorProtocol = -int64(unix.EPROTO)
+ ipcErrorInvalid = -int64(unix.EINVAL)
+ ipcErrorPortInUse = -int64(unix.EADDRINUSE)
+ socketDirectory = "/var/run/wireguard"
+ socketName = "%s.sock"
)
type UAPIListener struct {
@@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr {
return nil
}
-func connectUnixSocket(path string) (net.Listener, error) {
+func UAPIListen(name string, file *os.File) (net.Listener, error) {
- // attempt inital connection
+ // wrap file in listener
- listener, err := net.Listen("unix", path)
- if err == nil {
- return listener, nil
- }
-
- // check if active
-
- _, err = net.Dial("unix", path)
- if err == nil {
- return nil, errors.New("Unix socket in use")
- }
-
- // attempt cleanup
-
- err = os.Remove(path)
- if err != nil {
- return nil, err
- }
-
- return net.Listen("unix", path)
-}
-
-func NewUAPIListener(name string) (net.Listener, error) {
-
- // check if path exist
-
- err := os.MkdirAll(socketDirectory, 077)
- if err != nil && !os.IsExist(err) {
- return nil, err
- }
-
- // open UNIX socket
-
- socketPath := path.Join(
- socketDirectory,
- fmt.Sprintf(socketName, name),
- )
-
- listener, err := connectUnixSocket(socketPath)
+ listener, err := net.FileListener(file)
if err != nil {
return nil, err
}
@@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) {
// watch for deletion of socket
+ socketPath := path.Join(
+ socketDirectory,
+ fmt.Sprintf(socketName, name),
+ )
+
uapi.inotifyFd, err = unix.InotifyInit()
if err != nil {
return nil, err
@@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) {
go func(l *UAPIListener) {
var buff [4096]byte
for {
- unix.Read(uapi.inotifyFd, buff[:])
+ // start with lstat to avoid race condition
if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err
return
}
+ unix.Read(uapi.inotifyFd, buff[:])
}
}(uapi)
@@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) {
return uapi, nil
}
+
+func UAPIOpen(name string) (*os.File, error) {
+
+ // check if path exist
+
+ err := os.MkdirAll(socketDirectory, 0600)
+ if err != nil && !os.IsExist(err) {
+ return nil, err
+ }
+
+ // open UNIX socket
+
+ socketPath := path.Join(
+ socketDirectory,
+ fmt.Sprintf(socketName, name),
+ )
+
+ addr, err := net.ResolveUnixAddr("unix", socketPath)
+ if err != nil {
+ return nil, err
+ }
+
+ listener, err := func() (*net.UnixListener, error) {
+
+ // initial connection attempt
+
+ listener, err := net.ListenUnix("unix", addr)
+ if err == nil {
+ return listener, nil
+ }
+
+ // check if socket already active
+
+ _, err = net.Dial("unix", socketPath)
+ if err == nil {
+ return nil, errors.New("unix socket in use")
+ }
+
+ // cleanup & attempt again
+
+ err = os.Remove(socketPath)
+ if err != nil {
+ return nil, err
+ }
+ return net.ListenUnix("unix", addr)
+ }()
+
+ if err != nil {
+ return nil, err
+ }
+
+ return listener.File()
+}