diff options
-rw-r--r-- | tun/wintun/setupapi/setupapi_windows.go | 29 | ||||
-rw-r--r-- | tun/wintun/setupapi/setupapi_windows_test.go | 5 | ||||
-rw-r--r-- | tun/wintun/wintun_windows.go | 131 |
3 files changed, 106 insertions, 59 deletions
diff --git a/tun/wintun/setupapi/setupapi_windows.go b/tun/wintun/setupapi/setupapi_windows.go index 71732a4..5f9e05c 100644 --- a/tun/wintun/setupapi/setupapi_windows.go +++ b/tun/wintun/setupapi/setupapi_windows.go @@ -7,12 +7,14 @@ package setupapi import ( "encoding/binary" + "errors" "fmt" "syscall" "unsafe" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" + "golang.zx2c4.com/wireguard/tun/wintun/guid" ) //sys setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiCreateDeviceInfoListExW @@ -234,6 +236,33 @@ func (deviceInfoSet DevInfo) OpenDevRegKey(DeviceInfoData *DevInfoData, Scope DI return SetupDiOpenDevRegKey(deviceInfoSet, DeviceInfoData, Scope, HwProfile, KeyType, samDesired) } +// GetInterfaceID method returns network interface ID. +func (deviceInfoSet DevInfo) GetInterfaceID(deviceInfoData *DevInfoData) (*windows.GUID, error) { + // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key. + key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, DICS_FLAG_GLOBAL, 0, DIREG_DRV, registry.READ) + if err != nil { + return nil, errors.New("Device-specific registry key open failed: " + err.Error()) + } + defer key.Close() + + // Read the NetCfgInstanceId value. + value, valueType, err := key.GetStringValue("NetCfgInstanceId") + if err != nil { + return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error()) + } + if valueType != registry.SZ { + return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType) + } + + // Convert to windows.GUID. + ifid, err := guid.FromString(value) + if err != nil { + return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: \"%v\")", value) + } + + return ifid, nil +} + //sys setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceRegistryPropertyW // SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property. diff --git a/tun/wintun/setupapi/setupapi_windows_test.go b/tun/wintun/setupapi/setupapi_windows_test.go index 30f3692..c6f4a15 100644 --- a/tun/wintun/setupapi/setupapi_windows_test.go +++ b/tun/wintun/setupapi/setupapi_windows_test.go @@ -291,6 +291,11 @@ func TestSetupDiOpenDevRegKey(t *testing.T) { t.Errorf("Error calling SetupDiOpenDevRegKey: %s", err.Error()) } defer key.Close() + + _, err = devInfoList.GetInterfaceID(data) + if err != nil { + t.Errorf("Error calling GetInterfaceID: %s", err.Error()) + } } } diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go index 85d29f4..69fd30c 100644 --- a/tun/wintun/wintun_windows.go +++ b/tun/wintun/wintun_windows.go @@ -58,27 +58,24 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) { // Iterate. for index := 0; ; index++ { - // Get the device from the list. + // Get the device from the list. Should anything be wrong with this device, continue with next. deviceData, err := devInfoList.EnumDeviceInfo(index) if err != nil { if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { break } - // Something is wrong with this device. Skip it. continue } // Get interface ID. - ifid, err := getInterfaceID(devInfoList, deviceData, 1) + ifid, err := devInfoList.GetInterfaceID(deviceData) if err != nil { - // Something is wrong with this device. Skip it. continue } // Get interface name. ifname2, err := ((*Wintun)(ifid)).GetInterfaceName() if err != nil { - // Something is wrong with this device. Skip it. continue } @@ -243,8 +240,74 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err rebootRequired = true } - // Get network interface ID from registry. Retry for max 30sec. - ifid, err = getInterfaceID(devInfoList, deviceData, 30) + // Get network interface ID from registry. DIF_INSTALLDEVICE returns almost immediately, + // while the device installation continues in the background. It might take a while, before + // all registry keys and values are populated. + getInterfaceID := func() (*windows.GUID, error) { + // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key. + keyDev, err := devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.READ) + if err != nil { + return nil, errors.New("Device-specific registry key open failed: " + err.Error()) + } + defer keyDev.Close() + + // Read the NetCfgInstanceId value. + value, err := getRegStringValue(keyDev, "NetCfgInstanceId") + if err != nil { + if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND { + return nil, err + } + + return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error()) + } + + // Convert to windows.GUID. + ifid, err := guid.FromString(value) + if err != nil { + return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: \"%v\")", value) + } + + keyNetName := fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", guid.ToString(&deviceClassNetGUID), value) + keyNet, err := registry.OpenKey(registry.LOCAL_MACHINE, keyNetName, registry.QUERY_VALUE) + if err != nil { + if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND { + return nil, err + } + + return nil, errors.New(fmt.Sprintf("RegOpenKeyEx(\"%v\") failed: ", keyNetName) + err.Error()) + } + defer keyNet.Close() + + // Query the interface name. + _, valueType, err := keyNet.GetValue("Name", nil) + if err != nil { + if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND { + return nil, err + } + + return nil, errors.New("RegQueryValueEx(\"Name\") failed: " + err.Error()) + } + switch valueType { + case registry.SZ, registry.EXPAND_SZ: + default: + return nil, fmt.Errorf("Interface name registry value is not REG_SZ or REG_EXPAND_SZ (expected: %v or %v, provided: %v)", registry.SZ, registry.EXPAND_SZ, valueType) + } + + // TUN interface is ready. (As far as we need it.) + return ifid, nil + } + for numAttempts := 0; numAttempts < 30; numAttempts++ { + ifid, err = getInterfaceID() + if err != nil { + if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND { + // Wait and retry. TODO: Wait for a cancellable event instead. + time.Sleep(1000 * time.Millisecond) + continue + } + } + + break + } } if err == nil { @@ -294,20 +357,18 @@ func (wintun *Wintun) DeleteInterface(hwndParent uintptr) (bool, bool, error) { // Iterate. for index := 0; ; index++ { - // Get the device from the list. + // Get the device from the list. Should anything be wrong with this device, continue with next. deviceData, err := devInfoList.EnumDeviceInfo(index) if err != nil { if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { break } - // Something is wrong with this device. Skip it. continue } // Get interface ID. - ifid2, err := getInterfaceID(devInfoList, deviceData, 1) + ifid2, err := devInfoList.GetInterfaceID(deviceData) if err != nil { - // Something is wrong with this device. Skip it. continue } @@ -367,54 +428,6 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf return false, nil } -// getInterfaceID returns network interface ID. -// -// After the device is created, it might take some time before the registry -// key is populated. numAttempts parameter specifies the number of attempts -// to read NetCfgInstanceId value from registry. A 1sec sleep is inserted -// between retry attempts. -// -// Function returns the network interface ID. -// -func getInterfaceID(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData, numAttempts int) (*windows.GUID, error) { - if numAttempts < 1 { - return nil, fmt.Errorf("Invalid numAttempts (expected: >=1, provided: %v)", numAttempts) - } - - // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key. - key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.READ) - if err != nil { - return nil, errors.New("Device-specific registry key open failed: " + err.Error()) - } - defer key.Close() - - for { - // Read the NetCfgInstanceId value. - value, err := getRegStringValue(key, "NetCfgInstanceId") - if err != nil { - if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND { - numAttempts-- - if numAttempts > 0 { - // Wait and retry. - // TODO: Wait for a cancellable event instead. - time.Sleep(1000 * time.Millisecond) - continue - } - } - - return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error()) - } - - // Convert to windows.GUID. - ifid, err := guid.FromString(value) - if err != nil { - return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: \"%v\")", value) - } - - return ifid, err - } -} - // // GetInterfaceName returns network interface name. // |