diff options
Diffstat (limited to 'tun/wintun/registry')
-rw-r--r-- | tun/wintun/registry/registry_windows.go | 147 |
1 files changed, 81 insertions, 66 deletions
diff --git a/tun/wintun/registry/registry_windows.go b/tun/wintun/registry/registry_windows.go index 415aa00..b996c23 100644 --- a/tun/wintun/registry/registry_windows.go +++ b/tun/wintun/registry/registry_windows.go @@ -10,7 +10,9 @@ import ( "fmt" "runtime" "strings" + "syscall" "time" + "unsafe" "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" @@ -102,18 +104,44 @@ func WaitForKey(k registry.Key, path string, timeout time.Duration) error { } // -// getStringValueRetry function reads a string value from registry. It waits for +// getValue is the same as windows/registry's getValue, which is unfortunately +// private. +// +func getValue(k registry.Key, name string, buf []byte) ([]byte, uint32, error) { + p, err := syscall.UTF16PtrFromString(name) + if err != nil { + return nil, 0, err + } + var t uint32 + n := uint32(len(buf)) + for { + err = syscall.RegQueryValueEx(syscall.Handle(k), p, nil, &t, (*byte)(unsafe.Pointer(&buf[0])), &n) + if err == nil { + return buf[:n], t, nil + } + if err != syscall.ERROR_MORE_DATA { + return nil, 0, err + } + if n <= uint32(len(buf)) { + return nil, 0, err + } + buf = make([]byte, n) + } +} + +// +// getValueRetry function reads any value from registry. It waits for // the registry value to become available or returns error on timeout. // // Key must be opened with at least QUERY_VALUE|NOTIFY access. // -func getStringValueRetry(key registry.Key, name string, timeout time.Duration, useFirstFromMulti bool) (string, uint32, error) { +func getValueRetry(key registry.Key, name string, buf []byte, timeout time.Duration) ([]byte, uint32, error) { runtime.LockOSThread() defer runtime.UnlockOSThread() event, err := windows.CreateEvent(nil, 0, 0, nil) if err != nil { - return "", 0, fmt.Errorf("Error creating event: %v", err) + return nil, 0, fmt.Errorf("Error creating event: %v", err) } defer windows.CloseHandle(event) @@ -121,46 +149,47 @@ func getStringValueRetry(key registry.Key, name string, timeout time.Duration, u for { err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true) if err != nil { - return "", 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err) + return nil, 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err) } - var value string - var values []string - var valueType uint32 - if !useFirstFromMulti { - value, valueType, err = key.GetStringValue(name) - } else { - values, valueType, err = key.GetStringsValue(name) - } - if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND || (useFirstFromMulti && len(values) == 0) { + buf, valueType, err := getValue(key, name, buf) + if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND { timeout := time.Until(deadline) / time.Millisecond if timeout < 0 { timeout = 0 } s, err := windows.WaitForSingleObject(event, uint32(timeout)) if err != nil { - return "", 0, fmt.Errorf("Unable to wait on registry value: %v", err) + return nil, 0, fmt.Errorf("Unable to wait on registry value: %v", err) } if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows - return "", 0, errors.New("Timeout waiting for registry value") + return nil, 0, errors.New("Timeout waiting for registry value") } } else if err != nil { - return "", 0, fmt.Errorf("Error reading registry value %v: %v", name, err) + return nil, 0, fmt.Errorf("Error reading registry value %v: %v", name, err) } else { - if !useFirstFromMulti { - return value, valueType, nil - } else { - return values[0], registry.SZ, nil - } + return buf, valueType, nil } } } -func expandString(value string, valueType uint32, err error) (string, error) { +func toString(buf []byte, valueType uint32, err error) (string, error) { if err != nil { return "", err } + var value string + switch valueType { + case registry.SZ, registry.EXPAND_SZ, registry.MULTI_SZ: + if len(buf) == 0 { + return "", nil + } + value = syscall.UTF16ToString((*[1 << 29]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2]) + + default: + return "", registry.ErrUnexpectedType + } + if valueType != registry.EXPAND_SZ { // Value does not require expansion. return value, nil @@ -176,6 +205,29 @@ func expandString(value string, valueType uint32, err error) (string, error) { return valueExp, nil } +func toInteger(buf []byte, valueType uint32, err error) (uint64, error) { + if err != nil { + return 0, err + } + + switch valueType { + case registry.DWORD: + if len(buf) != 4 { + return 0, errors.New("DWORD value is not 4 bytes long") + } + return uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), nil + + case registry.QWORD: + if len(buf) != 8 { + return 0, errors.New("QWORD value is not 8 bytes long") + } + return uint64(*(*uint64)(unsafe.Pointer(&buf[0]))), nil + + default: + return 0, registry.ErrUnexpectedType + } +} + // // GetStringValueWait function reads a string value from registry. It waits // for the registry value to become available or returns error on timeout. @@ -185,15 +237,10 @@ func expandString(value string, valueType uint32, err error) (string, error) { // If the value type is REG_EXPAND_SZ the environment variables are expanded. // Should expanding fail, original string value and nil error are returned. // -func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) { - return expandString(getStringValueRetry(key, name, timeout, false)) -} - -// -// Same as GetStringValueWait, but returns the first from a MULTI_SZ. +// If the value type is REG_MULTI_SZ only the first string is returned. // -func GetFirstStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) { - return expandString(getStringValueRetry(key, name, timeout, true)) +func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) { + return toString(getValueRetry(key, name, make([]byte, 64), timeout)) } // @@ -204,8 +251,10 @@ func GetFirstStringValueWait(key registry.Key, name string, timeout time.Duratio // If the value type is REG_EXPAND_SZ the environment variables are expanded. // Should expanding fail, original string value and nil error are returned. // +// If the value type is REG_MULTI_SZ only the first string is returned. +// func GetStringValue(key registry.Key, name string) (string, error) { - return expandString(key.GetStringValue(name)) + return toString(getValue(key, name, make([]byte, 64))) } // @@ -216,39 +265,5 @@ func GetStringValue(key registry.Key, name string) (string, error) { // Key must be opened with at least QUERY_VALUE|NOTIFY access. // func GetIntegerValueWait(key registry.Key, name string, timeout time.Duration) (uint64, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - event, err := windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return 0, fmt.Errorf("Error creating event: %v", err) - } - defer windows.CloseHandle(event) - - deadline := time.Now().Add(timeout) - for { - err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true) - if err != nil { - return 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err) - } - - value, _, err := key.GetIntegerValue(name) - if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND { - timeout := time.Until(deadline) / time.Millisecond - if timeout < 0 { - timeout = 0 - } - s, err := windows.WaitForSingleObject(event, uint32(timeout)) - if err != nil { - return 0, fmt.Errorf("Unable to wait on registry value: %v", err) - } - if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows - return 0, errors.New("Timeout waiting for registry value") - } - } else if err != nil { - return 0, fmt.Errorf("Error reading registry value %v: %v", name, err) - } else { - return value, nil - } - } + return toInteger(getValueRetry(key, name, make([]byte, 8), timeout)) } |