summaryrefslogtreecommitdiffhomepage
path: root/conn/winrio/rio_windows.go
diff options
context:
space:
mode:
Diffstat (limited to 'conn/winrio/rio_windows.go')
-rw-r--r--conn/winrio/rio_windows.go243
1 files changed, 243 insertions, 0 deletions
diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go
new file mode 100644
index 0000000..1785a02
--- /dev/null
+++ b/conn/winrio/rio_windows.go
@@ -0,0 +1,243 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package winrio
+
+import (
+ "log"
+ "sync"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+const (
+ MsgDontNotify = 1
+ MsgDefer = 2
+ MsgWaitAll = 4
+ MsgCommitOnly = 8
+
+ MaxCqSize = 0x8000000
+
+ invalidBufferId = 0xFFFFFFFF
+ invalidCq = 0
+ invalidRq = 0
+ corruptCq = 0xFFFFFFFF
+)
+
+var extensionFunctionTable struct {
+ cbSize uint32
+ rioReceive uintptr
+ rioReceiveEx uintptr
+ rioSend uintptr
+ rioSendEx uintptr
+ rioCloseCompletionQueue uintptr
+ rioCreateCompletionQueue uintptr
+ rioCreateRequestQueue uintptr
+ rioDequeueCompletion uintptr
+ rioDeregisterBuffer uintptr
+ rioNotify uintptr
+ rioRegisterBuffer uintptr
+ rioResizeCompletionQueue uintptr
+ rioResizeRequestQueue uintptr
+}
+
+type Cq uintptr
+
+type Rq uintptr
+
+type BufferId uintptr
+
+type Buffer struct {
+ Id BufferId
+ Offset uint32
+ Length uint32
+}
+
+type Result struct {
+ Status int32
+ BytesTransferred uint32
+ SocketContext uint64
+ RequestContext uint64
+}
+
+type notificationCompletionType uint32
+
+const (
+ eventCompletion notificationCompletionType = 1
+ iocpCompletion notificationCompletionType = 2
+)
+
+type eventNotificationCompletion struct {
+ completionType notificationCompletionType
+ event windows.Handle
+ notifyReset uint32
+}
+
+type iocpNotificationCompletion struct {
+ completionType notificationCompletionType
+ iocp windows.Handle
+ key uintptr
+ overlapped *windows.Overlapped
+}
+
+var initialized sync.Once
+var available bool
+
+func Initialize() bool {
+ initialized.Do(func() {
+ var (
+ err error
+ socket windows.Handle
+ cq Cq
+ )
+ defer func() {
+ if err == nil {
+ return
+ }
+ if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
+ return
+ }
+ log.Printf("Registered I/O is unavailable: %v", err)
+ }()
+ socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
+ if err != nil {
+ return
+ }
+ defer windows.CloseHandle(socket)
+ var WSAID_MULTIPLE_RIO = &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
+ const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
+ ob := uint32(0)
+ err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
+ (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
+ (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
+ &ob, nil, 0)
+ if err != nil {
+ return
+ }
+ // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
+ // failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
+ cq, err = CreatePolledCompletionQueue(2)
+ if err != nil {
+ return
+ }
+ defer CloseCompletionQueue(cq)
+ _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
+ if err != nil {
+ return
+ }
+ available = true
+ })
+ return available
+}
+
+func Socket(af, typ, proto int32) (windows.Handle, error) {
+ return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
+}
+
+func CloseCompletionQueue(cq Cq) {
+ _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
+}
+
+func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
+ notificationCompletion := &eventNotificationCompletion{
+ completionType: eventCompletion,
+ event: event,
+ }
+ if notifyReset {
+ notificationCompletion.notifyReset = 1
+ }
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
+ notificationCompletion := &iocpNotificationCompletion{
+ completionType: iocpCompletion,
+ iocp: iocp,
+ overlapped: overlapped,
+ }
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
+ if ret == invalidCq {
+ return 0, err
+ }
+ return Cq(ret), nil
+}
+
+func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
+ if ret == invalidRq {
+ return 0, err
+ }
+ return Rq(ret), nil
+}
+
+func DequeueCompletion(cq Cq, results []Result) uint32 {
+ var array uintptr
+ if len(results) > 0 {
+ array = uintptr(unsafe.Pointer(&results[0]))
+ }
+ ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
+ if ret == corruptCq {
+ panic("cq is corrupt")
+ }
+ return uint32(ret)
+}
+
+func DeregisterBuffer(id BufferId) {
+ _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
+}
+
+func RegisterBuffer(buffer []byte) (BufferId, error) {
+ var buf unsafe.Pointer
+ if len(buffer) > 0 {
+ buf = unsafe.Pointer(&buffer[0])
+ }
+ return RegisterPointer(buf, uint32(len(buffer)))
+}
+
+func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
+ ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
+ if ret == invalidBufferId {
+ return 0, err
+ }
+ return BufferId(ret), nil
+}
+
+func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
+ if ret == 0 {
+ return err
+ }
+ return nil
+}
+
+func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
+ ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
+ if ret == 0 {
+ return err
+ }
+ return nil
+}
+
+func Notify(cq Cq) error {
+ ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
+ if ret != 0 {
+ return windows.Errno(ret)
+ }
+ return nil
+}