summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-11-21 14:48:21 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2019-11-22 13:13:55 +0100
commit2b242f93932e1c4ab8b45dd0f628dd4fe063699b (patch)
tree379f4a2daca264c09f12b6f7fd592243aa05202c
parent4cdf805b29b1aaca1fab317ca4fce54c7fd69bf6 (diff)
wintun: manage ring memory manually
It's large and Go's garbage collector doesn't deal with it especially well.
-rw-r--r--tun/tun_windows.go6
-rw-r--r--tun/wintun/ring_windows.go28
2 files changed, 27 insertions, 7 deletions
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index daad4aa..8fc5174 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -35,11 +35,11 @@ type NativeTun struct {
wt *wintun.Interface
handle windows.Handle
close bool
- rings wintun.RingDescriptor
events chan Event
errors chan error
forcedMTU int
rate rateJuggler
+ rings *wintun.RingDescriptor
}
const WintunPool = wintun.Pool("WireGuard")
@@ -93,13 +93,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
forcedMTU: forcedMTU,
}
- err = tun.rings.Init()
+ tun.rings, err = wintun.NewRingDescriptor()
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error creating events: %v", err)
}
- tun.handle, err = tun.wt.Register(&tun.rings)
+ tun.handle, err = tun.wt.Register(tun.rings)
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error registering rings: %v", err)
diff --git a/tun/wintun/ring_windows.go b/tun/wintun/ring_windows.go
index 8f46bc9..8e6b375 100644
--- a/tun/wintun/ring_windows.go
+++ b/tun/wintun/ring_windows.go
@@ -6,6 +6,7 @@
package wintun
import (
+ "runtime"
"unsafe"
"golang.org/x/sys/windows"
@@ -53,25 +54,44 @@ func PacketAlign(size uint32) uint32 {
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
}
-func (descriptor *RingDescriptor) Init() (err error) {
+func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
+ descriptor = new(RingDescriptor)
+ allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
+ if err != nil {
+ return
+ }
+ defer func() {
+ if err != nil {
+ descriptor.free()
+ descriptor = nil
+ }
+ }()
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
- descriptor.Send.Ring = &Ring{}
+ descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return
}
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
- descriptor.Receive.Ring = &Ring{}
+ descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
windows.CloseHandle(descriptor.Send.TailMoved)
return
}
-
+ runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
return
}
+func (descriptor *RingDescriptor) free() {
+ if descriptor.Send.Ring != nil {
+ windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
+ descriptor.Send.Ring = nil
+ descriptor.Receive.Ring = nil
+ }
+}
+
func (descriptor *RingDescriptor) Close() {
if descriptor.Send.TailMoved != 0 {
windows.CloseHandle(descriptor.Send.TailMoved)