diff options
Diffstat (limited to 'tun')
-rw-r--r-- | tun/tun_windows.go | 128 | ||||
-rw-r--r-- | tun/wintun/ring.go | 97 | ||||
-rw-r--r-- | tun/wintun/wintun_windows.go | 4 |
3 files changed, 133 insertions, 96 deletions
diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 1891d21..9c635b5 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -19,40 +19,11 @@ import ( ) const ( - packetAlignment = 4 // Number of bytes packets are aligned to in rings - packetSizeMax = 0xffff // Maximum packet size - packetCapacity = 0x800000 // Ring capacity, 8MiB - packetTrailingSize = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment - ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14) rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond) spinloopRateThreshold = 800000000 / 8 // 800mbps spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s ) -type packetHeader struct { - size uint32 -} - -type packet struct { - packetHeader - data [packetSizeMax]byte -} - -type ring struct { - head uint32 - tail uint32 - alertable int32 - data [packetCapacity + packetTrailingSize]byte -} - -type ringDescriptor struct { - send, receive struct { - size uint32 - ring *ring - tailMoved windows.Handle - } -} - type rateJuggler struct { current uint64 nextByteCount uint64 @@ -64,7 +35,7 @@ type NativeTun struct { wt *wintun.Interface handle windows.Handle close bool - rings ringDescriptor + rings wintun.RingDescriptor events chan Event errors chan error forcedMTU int @@ -79,10 +50,6 @@ func procyield(cycles uint32) //go:linkname nanotime runtime.nanotime func nanotime() int64 -func packetAlign(size uint32) uint32 { - return (size + (packetAlignment - 1)) &^ (packetAlignment - 1) -} - // // CreateTUN creates a Wintun interface with the given name. Should a Wintun // interface with the same name exist, it is reused. @@ -127,30 +94,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev forcedMTU: 1500, } - tun.rings.send.size = uint32(unsafe.Sizeof(ring{})) - tun.rings.send.ring = &ring{} - tun.rings.send.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil) + err = tun.rings.Init() if err != nil { tun.Close() - return nil, fmt.Errorf("Error creating event: %v", err) + return nil, fmt.Errorf("Error creating events: %v", err) } - tun.rings.receive.size = uint32(unsafe.Sizeof(ring{})) - tun.rings.receive.ring = &ring{} - tun.rings.receive.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - tun.Close() - return nil, fmt.Errorf("Error creating event: %v", err) - } - - tun.handle, err = tun.wt.Handle() - if err != nil { - tun.Close() - return nil, err - } - - var bytesReturned uint32 - err = windows.DeviceIoControl(tun.handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(&tun.rings)), uint32(unsafe.Sizeof(tun.rings)), nil, 0, &bytesReturned, nil) + tun.handle, err = tun.wt.Register(&tun.rings) if err != nil { tun.Close() return nil, fmt.Errorf("Error registering rings: %v", err) @@ -172,18 +122,13 @@ func (tun *NativeTun) Events() chan Event { func (tun *NativeTun) Close() error { tun.close = true - if tun.rings.send.tailMoved != 0 { - windows.SetEvent(tun.rings.send.tailMoved) // wake the reader if it's sleeping + if tun.rings.Send.TailMoved != 0 { + windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping } if tun.handle != windows.InvalidHandle { windows.CloseHandle(tun.handle) } - if tun.rings.send.tailMoved != 0 { - windows.CloseHandle(tun.rings.send.tailMoved) - } - if tun.rings.send.tailMoved != 0 { - windows.CloseHandle(tun.rings.receive.tailMoved) - } + tun.rings.Close() var err error if tun.wt != nil { _, err = tun.wt.DeleteInterface() @@ -214,8 +159,8 @@ retry: return 0, os.ErrClosed } - buffHead := atomic.LoadUint32(&tun.rings.send.ring.head) - if buffHead >= packetCapacity { + buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head) + if buffHead >= wintun.PacketCapacity { return 0, os.ErrClosed } @@ -223,7 +168,7 @@ retry: shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 var buffTail uint32 for { - buffTail = atomic.LoadUint32(&tun.rings.send.ring.tail) + buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail) if buffHead != buffTail { break } @@ -231,35 +176,35 @@ retry: return 0, os.ErrClosed } if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { - windows.WaitForSingleObject(tun.rings.send.tailMoved, windows.INFINITE) + windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE) goto retry } procyield(1) } - if buffTail >= packetCapacity { + if buffTail >= wintun.PacketCapacity { return 0, os.ErrClosed } - buffContent := tun.rings.send.ring.wrap(buffTail - buffHead) - if buffContent < uint32(unsafe.Sizeof(packetHeader{})) { + buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead) + if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) { return 0, errors.New("incomplete packet header in send ring") } - packet := (*packet)(unsafe.Pointer(&tun.rings.send.ring.data[buffHead])) - if packet.size > packetSizeMax { + packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead])) + if packet.Size > wintun.PacketSizeMax { return 0, errors.New("packet too big in send ring") } - alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packet.size) + alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size) if alignedPacketSize > buffContent { return 0, errors.New("incomplete packet in send ring") } - copy(buff[offset:], packet.data[:packet.size]) - buffHead = tun.rings.send.ring.wrap(buffHead + alignedPacketSize) - atomic.StoreUint32(&tun.rings.send.ring.head, buffHead) - tun.rate.update(uint64(packet.size)) - return int(packet.size), nil + copy(buff[offset:], packet.Data[:packet.Size]) + buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize) + atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead) + tun.rate.update(uint64(packet.Size)) + return int(packet.Size), nil } func (tun *NativeTun) Flush() error { @@ -273,29 +218,29 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { packetSize := uint32(len(buff) - offset) tun.rate.update(uint64(packetSize)) - alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packetSize) + alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize) - buffHead := atomic.LoadUint32(&tun.rings.receive.ring.head) - if buffHead >= packetCapacity { + buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head) + if buffHead >= wintun.PacketCapacity { return 0, os.ErrClosed } - buffTail := atomic.LoadUint32(&tun.rings.receive.ring.tail) - if buffTail >= packetCapacity { + buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail) + if buffTail >= wintun.PacketCapacity { return 0, os.ErrClosed } - buffSpace := tun.rings.receive.ring.wrap(buffHead - buffTail - packetAlignment) + buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment) if alignedPacketSize > buffSpace { return 0, nil // Dropping when ring is full. } - packet := (*packet)(unsafe.Pointer(&tun.rings.receive.ring.data[buffTail])) - packet.size = packetSize - copy(packet.data[:packetSize], buff[offset:]) - atomic.StoreUint32(&tun.rings.receive.ring.tail, tun.rings.receive.ring.wrap(buffTail+alignedPacketSize)) - if atomic.LoadInt32(&tun.rings.receive.ring.alertable) != 0 { - windows.SetEvent(tun.rings.receive.tailMoved) + packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail])) + packet.Size = packetSize + copy(packet.Data[:packetSize], buff[offset:]) + atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize)) + if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 { + windows.SetEvent(tun.rings.Receive.TailMoved) } return int(packetSize), nil } @@ -305,11 +250,6 @@ func (tun *NativeTun) LUID() uint64 { return tun.wt.LUID() } -// wrap returns value modulo ring capacity -func (rb *ring) wrap(value uint32) uint32 { - return value & (packetCapacity - 1) -} - func (rate *rateJuggler) update(packetLen uint64) { now := nanotime() total := atomic.AddUint64(&rate.nextByteCount, packetLen) diff --git a/tun/wintun/ring.go b/tun/wintun/ring.go new file mode 100644 index 0000000..8f46bc9 --- /dev/null +++ b/tun/wintun/ring.go @@ -0,0 +1,97 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wintun + +import ( + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + PacketAlignment = 4 // Number of bytes packets are aligned to in rings + PacketSizeMax = 0xffff // Maximum packet size + PacketCapacity = 0x800000 // Ring capacity, 8MiB + PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment + ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14) +) + +type PacketHeader struct { + Size uint32 +} + +type Packet struct { + PacketHeader + Data [PacketSizeMax]byte +} + +type Ring struct { + Head uint32 + Tail uint32 + Alertable int32 + Data [PacketCapacity + PacketTrailingSize]byte +} + +type RingDescriptor struct { + Send, Receive struct { + Size uint32 + Ring *Ring + TailMoved windows.Handle + } +} + +// Wrap returns value modulo ring capacity +func (rb *Ring) Wrap(value uint32) uint32 { + return value & (PacketCapacity - 1) +} + +// Aligns a packet size to PacketAlignment +func PacketAlign(size uint32) uint32 { + return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1) +} + +func (descriptor *RingDescriptor) Init() (err error) { + descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{})) + descriptor.Send.Ring = &Ring{} + descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return + } + + descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{})) + descriptor.Receive.Ring = &Ring{} + descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + windows.CloseHandle(descriptor.Send.TailMoved) + return + } + + return +} + +func (descriptor *RingDescriptor) Close() { + if descriptor.Send.TailMoved != 0 { + windows.CloseHandle(descriptor.Send.TailMoved) + descriptor.Send.TailMoved = 0 + } + if descriptor.Send.TailMoved != 0 { + windows.CloseHandle(descriptor.Receive.TailMoved) + descriptor.Receive.TailMoved = 0 + } +} + +func (wintun *Interface) Register(descriptor *RingDescriptor) (windows.Handle, error) { + handle, err := wintun.handle() + if err != nil { + return 0, err + } + var bytesReturned uint32 + err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil) + if err != nil { + return 0, err + } + return handle, nil +} diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go index fb8b908..e726748 100644 --- a/tun/wintun/wintun_windows.go +++ b/tun/wintun/wintun_windows.go @@ -698,8 +698,8 @@ func (wintun *Interface) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, return 0, nil, windows.ERROR_OBJECT_NOT_FOUND } -// Handle returns a handle to the interface device object. -func (wintun *Interface) Handle() (windows.Handle, error) { +// handle returns a handle to the interface device object. +func (wintun *Interface) handle() (windows.Handle, error) { interfaces, err := setupapi.CM_Get_Device_Interface_List(wintun.devInstanceID, &deviceInterfaceNetGUID, setupapi.CM_GET_DEVICE_INTERFACE_LIST_PRESENT) if err != nil { return windows.InvalidHandle, fmt.Errorf("Error listing NDIS interfaces: %v", err) |