diff options
-rw-r--r-- | tun/mksyscall.go | 8 | ||||
-rw-r--r-- | tun/tun_windows.go | 176 | ||||
-rw-r--r-- | tun/ztun_windows.go | 61 |
3 files changed, 83 insertions, 162 deletions
diff --git a/tun/mksyscall.go b/tun/mksyscall.go deleted file mode 100644 index 06bb41e..0000000 --- a/tun/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package tun - -//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output ztun_windows.go tun_windows.go diff --git a/tun/tun_windows.go b/tun/tun_windows.go index fffd802..ea244f1 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -9,7 +9,6 @@ import ( "errors" "os" "sync" - "syscall" "time" "unsafe" @@ -39,22 +38,18 @@ type exchgBufWrite struct { } type NativeTun struct { - wt *wintun.Wintun - tunName *uint16 - tunFile windows.Handle - tunLock sync.Mutex - close bool - rdBuff *exchgBufRead - wrBuff *exchgBufWrite - rdEvent windows.Handle - wrEvent windows.Handle - events chan TUNEvent - errors chan error - forcedMtu int + wt *wintun.Wintun + tunFileRead *os.File + tunFileWrite *os.File + tunLock sync.Mutex + close bool + rdBuff *exchgBufRead + wrBuff *exchgBufWrite + events chan TUNEvent + errors chan error + forcedMtu int } -//sys getOverlappedResult(handle windows.Handle, overlapped *windows.Overlapped, done *uint32, wait bool) (err error) = kernel32.GetOverlappedResult - func packetAlign(size uint32) uint32 { return (size + (packetExchangeAlignment - 1)) &^ (packetExchangeAlignment - 1) } @@ -92,32 +87,10 @@ func CreateTUN(ifname string) (TUNDevice, error) { return nil, errors.New("Flushing interface failed: " + err.Error()) } - tunNameUTF16, err := windows.UTF16PtrFromString(wt.DataFileName()) - if err != nil { - wt.DeleteInterface(0) - return nil, err - } - - rde, err := windows.CreateEvent(nil, 1 /*TRUE*/, 0 /*FALSE*/, nil) - if err != nil { - wt.DeleteInterface(0) - return nil, err - } - wre, err := windows.CreateEvent(nil, 1 /*TRUE*/, 0 /*FALSE*/, nil) - if err != nil { - windows.CloseHandle(rde) - wt.DeleteInterface(0) - return nil, err - } - return &NativeTun{ wt: wt, - tunName: tunNameUTF16, - tunFile: windows.InvalidHandle, rdBuff: &exchgBufRead{}, wrBuff: &exchgBufWrite{}, - rdEvent: rde, - wrEvent: wre, events: make(chan TUNEvent, 10), errors: make(chan error, 1), forcedMtu: 1500, @@ -126,12 +99,25 @@ func CreateTUN(ifname string) (TUNDevice, error) { func (tun *NativeTun) openTUN() error { retries := retryTimeout * retryRate - for { - if tun.close { - return errors.New("Cancelled") - } + if tun.close { + return os.ErrClosed + } - file, err := windows.CreateFile(tun.tunName, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_ATTRIBUTE_NORMAL|windows.FILE_FLAG_OVERLAPPED|windows.FILE_FLAG_NO_BUFFERING, 0) + var err error + name := tun.wt.DataFileName() + for tun.tunFileRead == nil { + tun.tunFileRead, err = os.OpenFile(name, os.O_RDONLY, 0) + if err != nil { + if retries > 0 && !tun.close { + time.Sleep(time.Second / retryRate) + retries-- + continue + } + return err + } + } + for tun.tunFileWrite == nil { + tun.tunFileWrite, err = os.OpenFile(name, os.O_WRONLY, 0) if err != nil { if retries > 0 { time.Sleep(time.Second / retryRate) @@ -140,47 +126,69 @@ func (tun *NativeTun) openTUN() error { } return err } - - tun.tunFile = file - return nil } + return nil } func (tun *NativeTun) closeTUN() (err error) { - if tun.tunFile != windows.InvalidHandle { + for tun.tunFileRead != nil { tun.tunLock.Lock() - defer tun.tunLock.Unlock() - if tun.tunFile == windows.InvalidHandle { - return + if tun.tunFileRead == nil { + tun.tunLock.Unlock() + break } - t := tun.tunFile - tun.tunFile = windows.InvalidHandle - err = windows.CloseHandle(t) + t := tun.tunFileRead + tun.tunFileRead = nil + err = t.Close() + tun.tunLock.Unlock() + break + } + for tun.tunFileWrite != nil { + tun.tunLock.Lock() + if tun.tunFileWrite == nil { + tun.tunLock.Unlock() + break + } + t := tun.tunFileWrite + tun.tunFileWrite = nil + err2 := t.Close() + tun.tunLock.Unlock() + if err == nil { + err = err2 + } + break } return } -func (tun *NativeTun) getTUN() (windows.Handle, error) { - if tun.tunFile == windows.InvalidHandle { +func (tun *NativeTun) getTUN() (read *os.File, write *os.File, err error) { + read, write = tun.tunFileRead, tun.tunFileWrite + if read == nil || write == nil { + read, write = nil, nil tun.tunLock.Lock() - defer tun.tunLock.Unlock() - if tun.tunFile != windows.InvalidHandle { - return tun.tunFile, nil + if tun.tunFileRead != nil && tun.tunFileWrite != nil { + read, write = tun.tunFileRead, tun.tunFileWrite + tun.tunLock.Unlock() + return } - err := tun.openTUN() + err = tun.closeTUN() if err != nil { - return windows.InvalidHandle, err + tun.tunLock.Unlock() + return } + err = tun.openTUN() + if err == nil { + read, write = tun.tunFileRead, tun.tunFileWrite + } + tun.tunLock.Unlock() + return } - return tun.tunFile, nil + return } -func (tun *NativeTun) isIOCancelled(err error) bool { - // Read&WriteFile() return the same ERROR_OPERATION_ABORTED if we close the handle - // or the TUN device is put down. We need a "close" flag to distinguish. - en, ok := err.(syscall.Errno) - if tun.close && ok && en == windows.ERROR_OPERATION_ABORTED { - return true +func (tun *NativeTun) shouldReopenHandle(err error) bool { + if pe, ok := err.(*os.PathError); ok && pe.Err == os.ErrClosed { + return !tun.close } return false } @@ -210,9 +218,6 @@ func (tun *NativeTun) Close() error { err1 = err2 } - windows.CloseHandle(tun.rdEvent) - windows.CloseHandle(tun.wrEvent) - return err1 } @@ -252,27 +257,20 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { } // Get TUN data pipe. - file, err := tun.getTUN() + file, _, err := tun.getTUN() if err != nil { return 0, err } // Fill queue. - var n uint32 - overlapped := &windows.Overlapped{HEvent: tun.rdEvent} - err = windows.ReadFile(file, tun.rdBuff.data[:], &n, overlapped) + n, err := file.Read(tun.rdBuff.data[:]) if err != nil { - if en, ok := err.(syscall.Errno); ok && en == windows.ERROR_IO_PENDING { - err = getOverlappedResult(file, overlapped, &n, true) - } - if err != nil { - tun.rdBuff.avail = 0 - if tun.isIOCancelled(err) { - return 0, err - } + tun.rdBuff.avail = 0 + if tun.shouldReopenHandle(err) { tun.closeTUN() continue } + return 0, err } tun.rdBuff.offset = 0 tun.rdBuff.avail = uint32(n) @@ -287,30 +285,22 @@ func (tun *NativeTun) Flush() error { } // Get TUN data pipe. - file, err := tun.getTUN() + _, file, err := tun.getTUN() if err != nil { return err } // Flush write buffer. - var n uint32 - overlapped := &windows.Overlapped{HEvent: tun.wrEvent} - err = windows.WriteFile(file, tun.wrBuff.data[:tun.wrBuff.offset], &n, overlapped) + _, err = file.Write(tun.wrBuff.data[:tun.wrBuff.offset]) tun.wrBuff.packetNum = 0 tun.wrBuff.offset = 0 if err != nil { - if en, ok := err.(syscall.Errno); ok && en == windows.ERROR_IO_PENDING { - err = getOverlappedResult(file, overlapped, &n, true) - } - if err != nil { - if tun.isIOCancelled(err) { - return err - } + if tun.shouldReopenHandle(err) { tun.closeTUN() return nil } + return err } - return nil } diff --git a/tun/ztun_windows.go b/tun/ztun_windows.go deleted file mode 100644 index ed779c1..0000000 --- a/tun/ztun_windows.go +++ /dev/null @@ -1,61 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package tun - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - - procGetOverlappedResult = modkernel32.NewProc("GetOverlappedResult") -) - -func getOverlappedResult(handle windows.Handle, overlapped *windows.Overlapped, done *uint32, wait bool) (err error) { - var _p0 uint32 - if wait { - _p0 = 1 - } else { - _p0 = 0 - } - r1, _, e1 := syscall.Syscall6(procGetOverlappedResult.Addr(), 4, uintptr(handle), uintptr(unsafe.Pointer(overlapped)), uintptr(unsafe.Pointer(done)), uintptr(_p0), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} |