diff options
Diffstat (limited to 'tun')
-rw-r--r-- | tun/tun_windows.go | 6 | ||||
-rw-r--r-- | tun/wintun/ring_windows.go | 28 |
2 files changed, 27 insertions, 7 deletions
diff --git a/tun/tun_windows.go b/tun/tun_windows.go index daad4aa..8fc5174 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -35,11 +35,11 @@ type NativeTun struct { wt *wintun.Interface handle windows.Handle close bool - rings wintun.RingDescriptor events chan Event errors chan error forcedMTU int rate rateJuggler + rings *wintun.RingDescriptor } const WintunPool = wintun.Pool("WireGuard") @@ -93,13 +93,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu forcedMTU: forcedMTU, } - err = tun.rings.Init() + tun.rings, err = wintun.NewRingDescriptor() if err != nil { tun.Close() return nil, fmt.Errorf("Error creating events: %v", err) } - tun.handle, err = tun.wt.Register(&tun.rings) + tun.handle, err = tun.wt.Register(tun.rings) if err != nil { tun.Close() return nil, fmt.Errorf("Error registering rings: %v", err) diff --git a/tun/wintun/ring_windows.go b/tun/wintun/ring_windows.go index 8f46bc9..8e6b375 100644 --- a/tun/wintun/ring_windows.go +++ b/tun/wintun/ring_windows.go @@ -6,6 +6,7 @@ package wintun import ( + "runtime" "unsafe" "golang.org/x/sys/windows" @@ -53,25 +54,44 @@ func PacketAlign(size uint32) uint32 { return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1) } -func (descriptor *RingDescriptor) Init() (err error) { +func NewRingDescriptor() (descriptor *RingDescriptor, err error) { + descriptor = new(RingDescriptor) + allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + return + } + defer func() { + if err != nil { + descriptor.free() + descriptor = nil + } + }() descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{})) - descriptor.Send.Ring = &Ring{} + descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion)) descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) if err != nil { return } descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{})) - descriptor.Receive.Ring = &Ring{} + descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{}))) descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) if err != nil { windows.CloseHandle(descriptor.Send.TailMoved) return } - + runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() }) return } +func (descriptor *RingDescriptor) free() { + if descriptor.Send.Ring != nil { + windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE) + descriptor.Send.Ring = nil + descriptor.Receive.Ring = nil + } +} + func (descriptor *RingDescriptor) Close() { if descriptor.Send.TailMoved != 0 { windows.CloseHandle(descriptor.Send.TailMoved) |