diff options
-rw-r--r-- | tun/tun_linux.go | 100 |
1 files changed, 46 insertions, 54 deletions
diff --git a/tun/tun_linux.go b/tun/tun_linux.go index c352c1a..b7c429c 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -17,8 +17,8 @@ import ( "golang.zx2c4.com/wireguard/rwcancel" "net" "os" - "strconv" "sync" + "syscall" "time" "unsafe" ) @@ -30,8 +30,6 @@ const ( type NativeTun struct { tunFile *os.File - fd uintptr - fdCancel *rwcancel.RWCancel index int32 // if index name string // name of interface errors chan error // async error handling @@ -52,9 +50,17 @@ func (tun *NativeTun) routineHackListener() { /* This is needed for the detection to work across network namespaces * If you are reading this and know a better method, please get in touch. */ - fd := int(tun.fd) for { - _, err := unix.Write(fd, nil) + sysconn, err := tun.tunFile.SyscallConn() + if err != nil { + return + } + err2 := sysconn.Control(func(fd uintptr) { + _, err = unix.Write(int(fd), nil) + }) + if err2 != nil { + return + } switch err { case unix.EINVAL: tun.events <- TUNEventUp @@ -248,22 +254,32 @@ func (tun *NativeTun) MTU() (int, error) { uintptr(unsafe.Pointer(&ifr[0])), ) if errno != 0 { - return 0, errors.New("failed to get MTU of TUN device: " + strconv.FormatInt(int64(errno), 10)) + return 0, errors.New("failed to get MTU of TUN device: " + errno.Error()) } return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil } func (tun *NativeTun) Name() (string, error) { + sysconn, err := tun.tunFile.SyscallConn() + if err != nil { + return "", err + } var ifr [ifReqSize]byte - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - tun.fd, - uintptr(unix.TUNGETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) + var errno syscall.Errno + err = sysconn.Control(func(fd uintptr) { + _, _, errno = unix.Syscall( + unix.SYS_IOCTL, + fd, + uintptr(unix.TUNGETIFF), + uintptr(unsafe.Pointer(&ifr[0])), + ) + }) + if err != nil { + return "", errors.New("failed to get name of TUN device: " + err.Error()) + } if errno != 0 { - return "", errors.New("failed to get name of TUN device: " + strconv.FormatInt(int64(errno), 10)) + return "", errors.New("failed to get name of TUN device: " + errno.Error()) } nullStr := ifr[:] i := bytes.IndexByte(nullStr, 0) @@ -302,7 +318,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { return tun.tunFile.Write(buff) } -func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { select { case err := <-tun.errors: return 0, err @@ -320,18 +336,6 @@ func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) { } } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { - for { - n, err := tun.doRead(buff, offset) - if err == nil || !rwcancel.RetryAfterError(err) { - return n, err - } - if !tun.fdCancel.ReadyRead() { - return 0, errors.New("tun device closed") - } - } -} - func (tun *NativeTun) Events() chan TUNEvent { return tun.events } @@ -347,15 +351,11 @@ func (tun *NativeTun) Close() error { close(tun.events) } err2 := tun.tunFile.Close() - err3 := tun.fdCancel.Cancel() if err1 != nil { return err1 } - if err2 != nil { - return err2 - } - return err3 + return err2 } func CreateTUN(name string, mtu int) (TUNDevice, error) { @@ -364,13 +364,6 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) { return nil, err } - fd := os.NewFile(uintptr(nfd), cloneDevicePath) - if err != nil { - return nil, err - } - - // create new device - var ifr [ifReqSize]byte var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) nameBytes := []byte(name) @@ -382,13 +375,21 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) { _, _, errno := unix.Syscall( unix.SYS_IOCTL, - fd.Fd(), + uintptr(nfd), uintptr(unix.TUNSETIFF), uintptr(unsafe.Pointer(&ifr[0])), ) if errno != 0 { return nil, errno } + err = unix.SetNonblock(nfd, true) + + // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. + + fd := os.NewFile(uintptr(nfd), cloneDevicePath) + if err != nil { + return nil, err + } return CreateTUNFromFile(fd, mtu) } @@ -396,7 +397,6 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) { func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) { tun := &NativeTun{ tunFile: file, - fd: file.Fd(), events: make(chan TUNEvent, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), @@ -404,11 +404,6 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) { } var err error - tun.fdCancel, err = rwcancel.NewRWCancel(int(tun.fd)) - if err != nil { - return nil, err - } - _, err = tun.Name() if err != nil { return nil, err @@ -444,23 +439,20 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) { return tun, nil } -func CreateUnmonitoredTUNFromFD(tunFd int) (TUNDevice, string, error) { - file := os.NewFile(uintptr(tunFd), "/dev/tun") +func CreateUnmonitoredTUNFromFD(fd int) (TUNDevice, string, error) { + err := unix.SetNonblock(fd, true) + if err != nil { + return nil, "", err + } + file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ tunFile: file, - fd: file.Fd(), events: make(chan TUNEvent, 5), errors: make(chan error, 5), nopi: true, } - var err error - tun.fdCancel, err = rwcancel.NewRWCancel(int(tun.fd)) - if err != nil { - return nil, "", err - } name, err := tun.Name() if err != nil { - tun.fdCancel.Cancel() return nil, "", err } return tun, name, nil |