diff options
Diffstat (limited to 'src/uapi_linux.go')
-rw-r--r-- | src/uapi_linux.go | 117 |
1 files changed, 69 insertions, 48 deletions
diff --git a/src/uapi_linux.go b/src/uapi_linux.go index cb9d858..f97a18a 100644 --- a/src/uapi_linux.go +++ b/src/uapi_linux.go @@ -10,12 +10,12 @@ import ( ) const ( - ipcErrorIO = -int64(unix.EIO) - ipcErrorProtocol = -int64(unix.EPROTO) - ipcErrorInvalid = -int64(unix.EINVAL) - ipcErrorPortInUse = -int64(unix.EADDRINUSE) - socketDirectory = "/var/run/wireguard" - socketName = "%s.sock" + ipcErrorIO = -int64(unix.EIO) + ipcErrorProtocol = -int64(unix.EPROTO) + ipcErrorInvalid = -int64(unix.EINVAL) + ipcErrorPortInUse = -int64(unix.EADDRINUSE) + socketDirectory = "/var/run/wireguard" + socketName = "%s.sock" ) type UAPIListener struct { @@ -50,49 +50,11 @@ func (l *UAPIListener) Addr() net.Addr { return nil } -func connectUnixSocket(path string) (net.Listener, error) { +func UAPIListen(name string, file *os.File) (net.Listener, error) { - // attempt inital connection + // wrap file in listener - listener, err := net.Listen("unix", path) - if err == nil { - return listener, nil - } - - // check if active - - _, err = net.Dial("unix", path) - if err == nil { - return nil, errors.New("Unix socket in use") - } - - // attempt cleanup - - err = os.Remove(path) - if err != nil { - return nil, err - } - - return net.Listen("unix", path) -} - -func NewUAPIListener(name string) (net.Listener, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 077) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - listener, err := connectUnixSocket(socketPath) + listener, err := net.FileListener(file) if err != nil { return nil, err } @@ -105,6 +67,11 @@ func NewUAPIListener(name string) (net.Listener, error) { // watch for deletion of socket + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + uapi.inotifyFd, err = unix.InotifyInit() if err != nil { return nil, err @@ -125,11 +92,12 @@ func NewUAPIListener(name string) (net.Listener, error) { go func(l *UAPIListener) { var buff [4096]byte for { - unix.Read(uapi.inotifyFd, buff[:]) + // start with lstat to avoid race condition if _, err := os.Lstat(socketPath); os.IsNotExist(err) { l.connErr <- err return } + unix.Read(uapi.inotifyFd, buff[:]) } }(uapi) @@ -148,3 +116,56 @@ func NewUAPIListener(name string) (net.Listener, error) { return uapi, nil } + +func UAPIOpen(name string) (*os.File, error) { + + // check if path exist + + err := os.MkdirAll(socketDirectory, 0600) + if err != nil && !os.IsExist(err) { + return nil, err + } + + // open UNIX socket + + socketPath := path.Join( + socketDirectory, + fmt.Sprintf(socketName, name), + ) + + addr, err := net.ResolveUnixAddr("unix", socketPath) + if err != nil { + return nil, err + } + + listener, err := func() (*net.UnixListener, error) { + + // initial connection attempt + + listener, err := net.ListenUnix("unix", addr) + if err == nil { + return listener, nil + } + + // check if socket already active + + _, err = net.Dial("unix", socketPath) + if err == nil { + return nil, errors.New("unix socket in use") + } + + // cleanup & attempt again + + err = os.Remove(socketPath) + if err != nil { + return nil, err + } + return net.ListenUnix("unix", addr) + }() + + if err != nil { + return nil, err + } + + return listener.File() +} |