summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--conn_linux.go36
-rw-r--r--main.go4
-rw-r--r--tun_darwin.go2
-rw-r--r--tun_linux.go52
4 files changed, 69 insertions, 25 deletions
diff --git a/conn_linux.go b/conn_linux.go
index 8d076ac..e30631f 100644
--- a/conn_linux.go
+++ b/conn_linux.go
@@ -15,6 +15,7 @@
package main
import (
+ "./rwcancel"
"errors"
"golang.org/x/sys/unix"
"net"
@@ -55,10 +56,11 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
}
type NativeBind struct {
- sock4 int
- sock6 int
- netlinkSock int
- lastMark uint32
+ sock4 int
+ sock6 int
+ netlinkSock int
+ netlinkCancel *rwcancel.RWCancel
+ lastMark uint32
}
var _ Endpoint = (*NativeEndpoint)(nil)
@@ -125,18 +127,23 @@ func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
if err != nil {
return nil, 0, err
}
+ bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
+ if err != nil {
+ unix.Close(bind.netlinkSock)
+ return nil, 0, err
+ }
go bind.routineRouteListener(device)
bind.sock6, port, err = create6(port)
if err != nil {
- unix.Close(bind.netlinkSock)
+ bind.netlinkCancel.Cancel()
return nil, port, err
}
bind.sock4, port, err = create4(port)
if err != nil {
- unix.Close(bind.netlinkSock)
+ bind.netlinkCancel.Cancel()
unix.Close(bind.sock6)
}
return &bind, port, err
@@ -178,7 +185,8 @@ func closeUnblock(fd int) error {
func (bind *NativeBind) Close() error {
err1 := closeUnblock(bind.sock6)
err2 := closeUnblock(bind.sock4)
- err3 := closeUnblock(bind.netlinkSock)
+ err3 := bind.netlinkCancel.Cancel()
+
if err1 != nil {
return err1
}
@@ -539,8 +547,20 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
func (bind *NativeBind) routineRouteListener(device *Device) {
var reqPeer map[uint32]*Peer
+ defer unix.Close(bind.netlinkSock)
+
for msg := make([]byte, 1<<16); ; {
- msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
+ var err error
+ var msgn int
+ for {
+ msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
+ if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
+ break
+ }
+ if !bind.netlinkCancel.ReadyRead() {
+ return
+ }
+ }
if err != nil {
return
}
diff --git a/main.go b/main.go
index c9ef343..6e876df 100644
--- a/main.go
+++ b/main.go
@@ -221,14 +221,10 @@ func main() {
return
}
- // create wireguard device
-
device := NewDevice(tun, logger)
logger.Info.Println("Device started")
- // start uapi listener
-
errs := make(chan error)
term := make(chan os.Signal)
diff --git a/tun_darwin.go b/tun_darwin.go
index ac8bffd..8f9a5d5 100644
--- a/tun_darwin.go
+++ b/tun_darwin.go
@@ -122,11 +122,13 @@ func CreateTUNFromFile(file *os.File) (TUNDevice, error) {
_, err := tun.Name()
if err != nil {
+ tun.fd.Close()
return nil, err
}
tun.rwcancel, err = rwcancel.NewRWCancel(int(file.Fd()))
if err != nil {
+ tun.fd.Close()
return nil, err
}
diff --git a/tun_linux.go b/tun_linux.go
index 8e42d44..32bd95d 100644
--- a/tun_linux.go
+++ b/tun_linux.go
@@ -31,14 +31,16 @@ const (
)
type NativeTun struct {
- fd *os.File
- index int32 // if index
- name string // name of interface
- errors chan error // async error handling
- events chan TUNEvent // device related events
- nopi bool // the device was pased IFF_NO_PI
- rwcancel *rwcancel.RWCancel
- netlinkSock int
+ fd *os.File
+ fdCancel *rwcancel.RWCancel
+ index int32 // if index
+ name string // name of interface
+ errors chan error // async error handling
+ events chan TUNEvent // device related events
+ nopi bool // the device was pased IFF_NO_PI
+ netlinkSock int
+ netlinkCancel *rwcancel.RWCancel
+
statusListenersShutdown chan struct{}
}
@@ -86,9 +88,22 @@ func createNetlinkSocket() (int, error) {
}
func (tun *NativeTun) RoutineNetlinkListener() {
+ defer unix.Close(tun.netlinkSock)
+
for msg := make([]byte, 1<<16); ; {
- msgn, _, _, _, err := unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
+ var err error
+ var msgn int
+ for {
+ msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
+ if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
+ break
+ }
+ if !tun.netlinkCancel.ReadyRead() {
+ tun.errors <- fmt.Errorf("netlink socket closed: %s", err.Error())
+ return
+ }
+ }
if err != nil {
tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error())
return
@@ -323,7 +338,7 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
return n, err
}
- if !tun.rwcancel.ReadyRead() {
+ if !tun.fdCancel.ReadyRead() {
return 0, errors.New("tun device closed")
}
}
@@ -334,10 +349,13 @@ func (tun *NativeTun) Events() chan TUNEvent {
}
func (tun *NativeTun) Close() error {
+ var err1 error
close(tun.statusListenersShutdown)
- err1 := closeUnblock(tun.netlinkSock)
+ if tun.netlinkCancel != nil {
+ err1 = tun.netlinkCancel.Cancel()
+ }
err2 := tun.fd.Close()
- err3 := tun.rwcancel.Cancel()
+ err3 := tun.fdCancel.Cancel()
close(tun.events)
if err1 != nil {
@@ -404,13 +422,15 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
}
var err error
- tun.rwcancel, err = rwcancel.NewRWCancel(int(fd.Fd()))
+ tun.fdCancel, err = rwcancel.NewRWCancel(int(fd.Fd()))
if err != nil {
+ tun.fd.Close()
return nil, err
}
_, err = tun.Name()
if err != nil {
+ tun.fd.Close()
return nil, err
}
@@ -423,6 +443,12 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
tun.netlinkSock, err = createNetlinkSocket()
if err != nil {
+ tun.fd.Close()
+ return nil, err
+ }
+ tun.netlinkCancel, err = rwcancel.NewRWCancel(tun.netlinkSock)
+ if err != nil {
+ tun.fd.Close()
return nil, err
}