summaryrefslogtreecommitdiffhomepage
path: root/tun/wintun
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-03-31 10:17:11 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2019-04-01 09:07:43 +0200
commit92f847483200a63193d55418381e685621b24e5c (patch)
treed03dec21526ee523e5364a615321ede58e3a2ec7 /tun/wintun
parent2e0ed4614addc5e1842cf652c5d23779581ca7f2 (diff)
wintun: add more retry loops
Diffstat (limited to 'tun/wintun')
-rw-r--r--tun/wintun/registryhacks_windows.go42
-rw-r--r--tun/wintun/wintun_windows.go27
2 files changed, 51 insertions, 18 deletions
diff --git a/tun/wintun/registryhacks_windows.go b/tun/wintun/registryhacks_windows.go
new file mode 100644
index 0000000..62a629a
--- /dev/null
+++ b/tun/wintun/registryhacks_windows.go
@@ -0,0 +1,42 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package wintun
+
+import (
+ "golang.org/x/sys/windows/registry"
+ "time"
+)
+
+const (
+ numRetries = 25
+ retryTimeout = 100 * time.Millisecond
+)
+
+func registryOpenKeyRetry(k registry.Key, path string, access uint32) (key registry.Key, err error) {
+ for i := 0; i < numRetries; i++ {
+ key, err = registry.OpenKey(k, path, access)
+ if err == nil {
+ break
+ }
+ if i != numRetries - 1 {
+ time.Sleep(retryTimeout)
+ }
+ }
+ return
+}
+
+func keyGetStringValueRetry(k registry.Key, name string) (val string, valtype uint32, err error) {
+ for i := 0; i < numRetries; i++ {
+ val, valtype, err = k.GetStringValue(name)
+ if err == nil {
+ break
+ }
+ if i != numRetries - 1 {
+ time.Sleep(retryTimeout)
+ }
+ }
+ return
+}
diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go
index ba94b11..77e83a0 100644
--- a/tun/wintun/wintun_windows.go
+++ b/tun/wintun/wintun_windows.go
@@ -48,22 +48,14 @@ func MakeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
var valueStr string
var valueType uint32
- //TODO: Figure out a way to not need to loop like this.
- for i := 0; i < 30; i++ {
- // Read the NetCfgInstanceId value.
- valueStr, valueType, err = key.GetStringValue("NetCfgInstanceId")
- if err != nil {
- time.Sleep(time.Millisecond * 100)
- continue
- }
- if valueType != registry.SZ {
- return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
- }
- break
- }
+ // Read the NetCfgInstanceId value.
+ valueStr, valueType, err = keyGetStringValueRetry(key, "NetCfgInstanceId")
if err != nil {
return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error())
}
+ if valueType != registry.SZ {
+ return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
+ }
// Convert to windows.GUID.
ifid, err := guid.FromString(valueStr)
@@ -117,7 +109,6 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
// "foobar" would cause conflict with "FooBar".
ifname = strings.ToLower(ifname)
- // Iterate.
for index := 0; ; index++ {
// Get the device from the list. Should anything be wrong with this device, continue with next.
deviceData, err := devInfoList.EnumDeviceInfo(index)
@@ -174,7 +165,7 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
}
// This interface is not using Wintun driver.
- return wintun, errors.New("Foreign network interface with the same name exists")
+ return nil, errors.New("Foreign network interface with the same name exists")
}
}
@@ -444,7 +435,7 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
// GetInterfaceName returns network interface name.
//
func (wintun *Wintun) GetInterfaceName() (string, error) {
- key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE)
+ key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE)
if err != nil {
return "", errors.New("Network-specific registry key open failed: " + err.Error())
}
@@ -458,7 +449,7 @@ func (wintun *Wintun) GetInterfaceName() (string, error) {
// SetInterfaceName sets network interface name.
//
func (wintun *Wintun) SetInterfaceName(ifname string) error {
- key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
+ key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
if err != nil {
return errors.New("Network-specific registry key open failed: " + err.Error())
}
@@ -483,7 +474,7 @@ func (wintun *Wintun) GetNetRegKeyName() string {
//
func getRegStringValue(key registry.Key, name string) (string, error) {
// Read string value.
- value, valueType, err := key.GetStringValue(name)
+ value, valueType, err := keyGetStringValueRetry(key, name)
if err != nil {
return "", err
}