diff options
-rw-r--r-- | tun/tun_windows.go | 304 | ||||
-rw-r--r-- | tun/wintun/wintun_windows.go | 14 |
2 files changed, 100 insertions, 218 deletions
diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 5da5a70..6a84394 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "os" - "sync" "sync/atomic" "time" "unsafe" @@ -21,14 +20,10 @@ import ( const ( packetAlignment uint32 = 4 // Number of bytes packets are aligned to in rings - packetSizeMax uint32 = 0xffff // Maximum packet size - packetCapacity uint32 = 0x800000 // Ring capacity, 8MiB - packetTrailingSize uint32 = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment - - ioctlRegisterRings uint32 = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14) - - retryRate = 4 // Number of retries per second to reopen device pipe - retryTimeout = 30 // Number of seconds to tolerate adapter unavailable + packetSizeMax = 0xffff // Maximum packet size + packetCapacity = 0x800000 // Ring capacity, 8MiB + packetTrailingSize = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment + ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14) ) type packetHeader struct { @@ -57,8 +52,7 @@ type ringDescriptor struct { type NativeTun struct { wt *wintun.Wintun - tunDev windows.Handle - tunLock sync.Mutex + handle windows.Handle close bool rings ringDescriptor events chan Event @@ -70,15 +64,6 @@ func packetAlign(size uint32) uint32 { return (size + (packetAlignment - 1)) &^ (packetAlignment - 1) } -var shouldRetryOpen = windows.RtlGetVersion().MajorVersion < 10 - -func maybeRetry(x int) int { - if shouldRetryOpen { - return x - } - return 0 -} - // // CreateTUN creates a Wintun adapter with the given name. Should a Wintun // adapter with the same name exist, it is reused. @@ -119,7 +104,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev tun := &NativeTun{ wt: wt, - tunDev: windows.InvalidHandle, + handle: windows.InvalidHandle, events: make(chan Event, 10), errors: make(chan error, 1), forcedMTU: 1500, @@ -129,7 +114,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev tun.rings.send.ring = &ring{} tun.rings.send.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil) if err != nil { - wt.DeleteInterface() + tun.Close() return nil, fmt.Errorf("Error creating event: %v", err) } @@ -137,98 +122,23 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev tun.rings.receive.ring = &ring{} tun.rings.receive.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil) if err != nil { - windows.CloseHandle(tun.rings.send.tailMoved) - wt.DeleteInterface() + tun.Close() return nil, fmt.Errorf("Error creating event: %v", err) } - _, err = tun.getTUN() + tun.handle, err = tun.wt.AdapterHandle() if err != nil { - windows.CloseHandle(tun.rings.send.tailMoved) - windows.CloseHandle(tun.rings.receive.tailMoved) - tun.closeTUN() - wt.DeleteInterface() + tun.Close() return nil, err } - return tun, nil -} - -func (tun *NativeTun) openTUN() error { - filename, err := tun.wt.NdisFileName() - if err != nil { - return err - } - - retries := maybeRetry(retryTimeout * retryRate) - if tun.close { - return os.ErrClosed - } - - name, err := windows.UTF16PtrFromString(filename) + var bytesReturned uint32 + err = windows.DeviceIoControl(tun.handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(&tun.rings)), uint32(unsafe.Sizeof(tun.rings)), nil, 0, &bytesReturned, nil) if err != nil { - return err + tun.Close() + return nil, fmt.Errorf("Error registering rings: %v", err) } - for tun.tunDev == windows.InvalidHandle { - tun.tunDev, err = windows.CreateFile(name, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, 0, 0) - if err != nil { - if retries > 0 && !tun.close { - time.Sleep(time.Second / retryRate) - retries-- - continue - } - return err - } - - atomic.StoreUint32(&tun.rings.send.ring.head, 0) - atomic.StoreUint32(&tun.rings.send.ring.tail, 0) - atomic.StoreInt32(&tun.rings.send.ring.alertable, 0) - atomic.StoreUint32(&tun.rings.receive.ring.head, 0) - atomic.StoreUint32(&tun.rings.receive.ring.tail, 0) - atomic.StoreInt32(&tun.rings.receive.ring.alertable, 0) - - var bytesReturned uint32 - err = windows.DeviceIoControl(tun.tunDev, ioctlRegisterRings, (*byte)(unsafe.Pointer(&tun.rings)), uint32(unsafe.Sizeof(tun.rings)), nil, 0, &bytesReturned, nil) - if err != nil { - return fmt.Errorf("Error registering rings: %v", err) - } - } - return nil -} - -func (tun *NativeTun) closeTUN() (err error) { - for tun.tunDev != windows.InvalidHandle { - tun.tunLock.Lock() - if tun.tunDev == windows.InvalidHandle { - tun.tunLock.Unlock() - break - } - t := tun.tunDev - tun.tunDev = windows.InvalidHandle - err = windows.CloseHandle(t) - tun.tunLock.Unlock() - break - } - return -} - -func (tun *NativeTun) getTUN() (handle windows.Handle, err error) { - handle = tun.tunDev - if handle == windows.InvalidHandle { - tun.tunLock.Lock() - if tun.tunDev != windows.InvalidHandle { - handle = tun.tunDev - tun.tunLock.Unlock() - return - } - err = tun.openTUN() - if err == nil { - handle = tun.tunDev - } - tun.tunLock.Unlock() - return - } - return + return tun, nil } func (tun *NativeTun) Name() (string, error) { @@ -245,29 +155,22 @@ func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Close() error { tun.close = true - windows.SetEvent(tun.rings.send.tailMoved) // wake the reader if it's sleeping - var err, err2 error - err = tun.closeTUN() - - if tun.events != nil { - close(tun.events) + if tun.rings.send.tailMoved != 0 { + windows.SetEvent(tun.rings.send.tailMoved) // wake the reader if it's sleeping } - - err2 = windows.CloseHandle(tun.rings.receive.tailMoved) - if err == nil { - err = err2 + if tun.handle != windows.InvalidHandle { + windows.CloseHandle(tun.handle) } - - err2 = windows.CloseHandle(tun.rings.send.tailMoved) - if err == nil { - err = err2 + if tun.rings.send.tailMoved != 0 { + windows.CloseHandle(tun.rings.send.tailMoved) } - - _, err2 = tun.wt.DeleteInterface() - if err == nil { - err = err2 + if tun.rings.send.tailMoved != 0 { + windows.CloseHandle(tun.rings.receive.tailMoved) + } + var err error + if tun.wt != nil { + _, err = tun.wt.DeleteInterface() } - return err } @@ -286,74 +189,60 @@ func procyield(cycles uint32) // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +retry: select { case err := <-tun.errors: return 0, err default: } + if tun.close { + return 0, os.ErrClosed + } - retries := maybeRetry(1000) -top: - for !tun.close { - _, err := tun.getTUN() - if err != nil { - return 0, err - } - - buffHead := atomic.LoadUint32(&tun.rings.send.ring.head) - if buffHead >= packetCapacity { - return 0, errors.New("send ring head out of bounds") - } + buffHead := atomic.LoadUint32(&tun.rings.send.ring.head) + if buffHead >= packetCapacity { + return 0, os.ErrClosed + } - start := time.Now() - var buffTail uint32 - for { - buffTail = atomic.LoadUint32(&tun.rings.send.ring.tail) - if buffHead != buffTail { - break - } - if tun.close { - return 0, os.ErrClosed - } - if time.Since(start) >= time.Millisecond*50 { - windows.WaitForSingleObject(tun.rings.send.tailMoved, windows.INFINITE) - continue top - } - procyield(1) + start := time.Now() + var buffTail uint32 + for { + buffTail = atomic.LoadUint32(&tun.rings.send.ring.tail) + if buffHead != buffTail { + break } - if buffTail >= packetCapacity { - if retries > 0 { - tun.closeTUN() - time.Sleep(time.Millisecond * 2) - retries-- - continue - } - return 0, errors.New("send ring tail out of bounds") + if tun.close { + return 0, os.ErrClosed } - retries = maybeRetry(1000) - - buffContent := tun.rings.send.ring.wrap(buffTail - buffHead) - if buffContent < uint32(unsafe.Sizeof(packetHeader{})) { - return 0, errors.New("incomplete packet header in send ring") + if time.Since(start) >= time.Millisecond*50 { + windows.WaitForSingleObject(tun.rings.send.tailMoved, windows.INFINITE) + goto retry } + procyield(1) + } + if buffTail >= packetCapacity { + return 0, os.ErrClosed + } - packet := (*packet)(unsafe.Pointer(&tun.rings.send.ring.data[buffHead])) - if packet.size > packetSizeMax { - return 0, errors.New("packet too big in send ring") - } + buffContent := tun.rings.send.ring.wrap(buffTail - buffHead) + if buffContent < uint32(unsafe.Sizeof(packetHeader{})) { + return 0, errors.New("incomplete packet header in send ring") + } - alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packet.size) - if alignedPacketSize > buffContent { - return 0, errors.New("incomplete packet in send ring") - } + packet := (*packet)(unsafe.Pointer(&tun.rings.send.ring.data[buffHead])) + if packet.size > packetSizeMax { + return 0, errors.New("packet too big in send ring") + } - copy(buff[offset:], packet.data[:packet.size]) - buffHead = tun.rings.send.ring.wrap(buffHead + alignedPacketSize) - atomic.StoreUint32(&tun.rings.send.ring.head, buffHead) - return int(packet.size), nil + alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packet.size) + if alignedPacketSize > buffContent { + return 0, errors.New("incomplete packet in send ring") } - return 0, os.ErrClosed + copy(buff[offset:], packet.data[:packet.size]) + buffHead = tun.rings.send.ring.wrap(buffHead + alignedPacketSize) + atomic.StoreUint32(&tun.rings.send.ring.head, buffHead) + return int(packet.size), nil } func (tun *NativeTun) Flush() error { @@ -361,47 +250,36 @@ func (tun *NativeTun) Flush() error { } func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - retries := maybeRetry(1000) - for { - _, err := tun.getTUN() - if err != nil { - return 0, err - } + if tun.close { + return 0, os.ErrClosed + } - packetSize := uint32(len(buff) - offset) - alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packetSize) - - buffHead := atomic.LoadUint32(&tun.rings.receive.ring.head) - if buffHead >= packetCapacity { - if retries > 0 { - tun.closeTUN() - time.Sleep(time.Millisecond * 2) - retries-- - continue - } - return 0, errors.New("receive ring head out of bounds") - } - retries = maybeRetry(1000) + packetSize := uint32(len(buff) - offset) + alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packetSize) - buffTail := atomic.LoadUint32(&tun.rings.receive.ring.tail) - if buffTail >= packetCapacity { - return 0, errors.New("receive ring tail out of bounds") - } + buffHead := atomic.LoadUint32(&tun.rings.receive.ring.head) + if buffHead >= packetCapacity { + return 0, os.ErrClosed + } - buffSpace := tun.rings.receive.ring.wrap(buffHead - buffTail - packetAlignment) - if alignedPacketSize > buffSpace { - return 0, nil // Dropping when ring is full. - } + buffTail := atomic.LoadUint32(&tun.rings.receive.ring.tail) + if buffTail >= packetCapacity { + return 0, os.ErrClosed + } - packet := (*packet)(unsafe.Pointer(&tun.rings.receive.ring.data[buffTail])) - packet.size = packetSize - copy(packet.data[:packetSize], buff[offset:]) - atomic.StoreUint32(&tun.rings.receive.ring.tail, tun.rings.receive.ring.wrap(buffTail+alignedPacketSize)) - if atomic.LoadInt32(&tun.rings.receive.ring.alertable) != 0 { - windows.SetEvent(tun.rings.receive.tailMoved) - } - return int(packetSize), nil + buffSpace := tun.rings.receive.ring.wrap(buffHead - buffTail - packetAlignment) + if alignedPacketSize > buffSpace { + return 0, nil // Dropping when ring is full. + } + + packet := (*packet)(unsafe.Pointer(&tun.rings.receive.ring.data[buffTail])) + packet.size = packetSize + copy(packet.data[:packetSize], buff[offset:]) + atomic.StoreUint32(&tun.rings.receive.ring.tail, tun.rings.receive.ring.wrap(buffTail+alignedPacketSize)) + if atomic.LoadInt32(&tun.rings.receive.ring.alertable) != 0 { + windows.SetEvent(tun.rings.receive.tailMoved) } + return int(packetSize), nil } // LUID returns Windows adapter instance ID. diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go index e8eadf5..88d565d 100644 --- a/tun/wintun/wintun_windows.go +++ b/tun/wintun/wintun_windows.go @@ -612,21 +612,25 @@ func (wintun *Wintun) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, err return 0, nil, windows.ERROR_OBJECT_NOT_FOUND } -// NdisFileName returns the Wintun NDIS device object name. -func (wintun *Wintun) NdisFileName() (string, error) { +// AdapterHandle returns a handle to the adapter device object. +func (wintun *Wintun) AdapterHandle() (windows.Handle, error) { key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.netRegKeyName(), registry.QUERY_VALUE) if err != nil { - return "", fmt.Errorf("Network-specific registry key open failed: %v", err) + return windows.InvalidHandle, fmt.Errorf("Network-specific registry key open failed: %v", err) } defer key.Close() // Get the interface name. pnpInstanceID, err := registryEx.GetStringValue(key, "PnPInstanceId") if err != nil { - return "", fmt.Errorf("PnpInstanceId registry key read failed: %v", err) + return windows.InvalidHandle, fmt.Errorf("PnPInstanceId registry key read failed: %v", err) } mangledPnpNode := strings.ReplaceAll(fmt.Sprintf("%s\\{cac88484-7515-4c03-82e6-71a87abac361}", pnpInstanceID), "\\", "#") - return fmt.Sprintf("\\\\.\\Global\\%s", mangledPnpNode), nil + handle, err := windows.CreateFile(windows.StringToUTF16Ptr(fmt.Sprintf("\\\\.\\Global\\%s", mangledPnpNode)), windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, 0, 0) + if err != nil { + return windows.InvalidHandle, fmt.Errorf("CreateFile on mangled PnPInstanceId path failed: %v", err) + } + return handle, nil } // GUID returns the GUID of the interface. |