diff options
-rw-r--r-- | tun/tun_linux.go | 75 |
1 files changed, 48 insertions, 27 deletions
diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 0ee9f6c..17075d9 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -19,7 +19,6 @@ import ( "os" "strconv" "sync" - "syscall" "time" "unsafe" ) @@ -31,6 +30,8 @@ const ( type nativeTun struct { tunFile *os.File + fd uintptr + fdCancel *rwcancel.RWCancel index int32 // if index name string // name of interface errors chan error // async error handling @@ -51,11 +52,9 @@ func (tun *nativeTun) routineHackListener() { /* This is needed for the detection to work across network namespaces * If you are reading this and know a better method, please get in touch. */ + fd := int(tun.fd) for { - var err error - tun.operateOnFd(func(fd uintptr) { - _, err = unix.Write(int(fd), nil) - }) + _, err := unix.Write(fd, nil) switch err { case unix.EINVAL: tun.events <- TUNEventUp @@ -257,15 +256,12 @@ func (tun *nativeTun) MTU() (int, error) { func (tun *nativeTun) Name() (string, error) { var ifr [ifReqSize]byte - var errno syscall.Errno - tun.operateOnFd(func(fd uintptr) { - _, _, errno = unix.Syscall( - unix.SYS_IOCTL, - fd, - uintptr(unix.TUNGETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - }) + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + tun.fd, + uintptr(unix.TUNGETIFF), + uintptr(unsafe.Pointer(&ifr[0])), + ) if errno != 0 { return "", errors.New("failed to get name of TUN device: " + strconv.FormatInt(int64(errno), 10)) } @@ -306,7 +302,7 @@ func (tun *nativeTun) Write(buff []byte, offset int) (int, error) { return tun.tunFile.Write(buff) } -func (tun *nativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *nativeTun) doRead(buff []byte, offset int) (int, error) { select { case err := <-tun.errors: return 0, err @@ -324,6 +320,18 @@ func (tun *nativeTun) Read(buff []byte, offset int) (int, error) { } } +func (tun *nativeTun) Read(buff []byte, offset int) (int, error) { + for { + n, err := tun.doRead(buff, offset) + if err == nil || !rwcancel.RetryAfterError(err) { + return n, err + } + if !tun.fdCancel.ReadyRead() { + return 0, errors.New("tun device closed") + } + } +} + func (tun *nativeTun) Events() chan TUNEvent { return tun.events } @@ -339,20 +347,30 @@ func (tun *nativeTun) Close() error { close(tun.events) } err2 := tun.tunFile.Close() + err3 := tun.fdCancel.Cancel() if err1 != nil { return err1 } - return err2 + if err2 != nil { + return err2 + } + return err3 } func CreateTUN(name string, mtu int) (TUNDevice, error) { - tunFile, err := os.OpenFile(cloneDevicePath, os.O_RDWR, 0) + nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0) + if err != nil { + return nil, err + } + + fd := os.NewFile(uintptr(nfd), cloneDevicePath) if err != nil { return nil, err } // create new device + var ifr [ifReqSize]byte var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) nameBytes := []byte(name) @@ -362,25 +380,23 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) { copy(ifr[:], nameBytes) *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags - var errno syscall.Errno - (&nativeTun{tunFile: tunFile}).operateOnFd(func(fd uintptr) { - _, _, errno = unix.Syscall( - unix.SYS_IOCTL, - fd, - uintptr(unix.TUNSETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - }) + _, _, errno := unix.Syscall( + unix.SYS_IOCTL, + fd.Fd(), + uintptr(unix.TUNSETIFF), + uintptr(unsafe.Pointer(&ifr[0])), + ) if errno != 0 { return nil, errno } - return CreateTUNFromFile(tunFile, mtu) + return CreateTUNFromFile(fd, mtu) } func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) { tun := &nativeTun{ tunFile: file, + fd: file.Fd(), events: make(chan TUNEvent, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), @@ -388,6 +404,11 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) { } var err error + tun.fdCancel, err = rwcancel.NewRWCancel(int(tun.fd)) + if err != nil { + return nil, err + } + _, err = tun.Name() if err != nil { return nil, err |