summaryrefslogtreecommitdiffhomepage
path: root/device
diff options
context:
space:
mode:
Diffstat (limited to 'device')
-rw-r--r--device/alignment_test.go1
-rw-r--r--device/device.go9
-rw-r--r--device/pools.go113
-rw-r--r--device/pools_test.go60
4 files changed, 116 insertions, 67 deletions
diff --git a/device/alignment_test.go b/device/alignment_test.go
index 5587cbe..46baeb1 100644
--- a/device/alignment_test.go
+++ b/device/alignment_test.go
@@ -42,7 +42,6 @@ func TestPeerAlignment(t *testing.T) {
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
}
-
// TestDeviceAlignment checks that atomically-accessed fields are
// aligned to 64-bit boundaries, as required by the atomic package.
//
diff --git a/device/device.go b/device/device.go
index bac361e..5f36036 100644
--- a/device/device.go
+++ b/device/device.go
@@ -67,12 +67,9 @@ type Device struct {
}
pool struct {
- messageBufferPool *sync.Pool
- messageBufferReuseChan chan *[MaxMessageSize]byte
- inboundElementPool *sync.Pool
- inboundElementReuseChan chan *QueueInboundElement
- outboundElementPool *sync.Pool
- outboundElementReuseChan chan *QueueOutboundElement
+ messageBuffers *WaitPool
+ inboundElements *WaitPool
+ outboundElements *WaitPool
}
queue struct {
diff --git a/device/pools.go b/device/pools.go
index eb6d6be..f1d1fa0 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -5,87 +5,80 @@
package device
-import "sync"
+import (
+ "sync"
+ "sync/atomic"
+)
-func (device *Device) PopulatePools() {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.messageBufferPool = &sync.Pool{
- New: func() interface{} {
- return new([MaxMessageSize]byte)
- },
- }
- device.pool.inboundElementPool = &sync.Pool{
- New: func() interface{} {
- return new(QueueInboundElement)
- },
- }
- device.pool.outboundElementPool = &sync.Pool{
- New: func() interface{} {
- return new(QueueOutboundElement)
- },
- }
- } else {
- device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i++ {
- device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
- }
- device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i++ {
- device.pool.inboundElementReuseChan <- new(QueueInboundElement)
- }
- device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
- for i := 0; i < PreallocatedBuffersPerPool; i++ {
- device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
+type WaitPool struct {
+ pool sync.Pool
+ cond sync.Cond
+ lock sync.Mutex
+ count uint32
+ max uint32
+}
+
+func NewWaitPool(max uint32, new func() interface{}) *WaitPool {
+ p := &WaitPool{pool: sync.Pool{New: new}, max: max}
+ p.cond = sync.Cond{L: &p.lock}
+ return p
+}
+
+func (p *WaitPool) Get() interface{} {
+ if p.max != 0 {
+ p.lock.Lock()
+ for atomic.LoadUint32(&p.count) >= p.max {
+ p.cond.Wait()
}
+ atomic.AddUint32(&p.count, 1)
+ p.lock.Unlock()
}
+ return p.pool.Get()
}
-func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
- } else {
- return <-device.pool.messageBufferReuseChan
+func (p *WaitPool) Put(x interface{}) {
+ p.pool.Put(x)
+ if p.max == 0 {
+ return
}
+ atomic.AddUint32(&p.count, ^uint32(0))
+ p.cond.Signal()
+}
+
+func (device *Device) PopulatePools() {
+ device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+ return new([MaxMessageSize]byte)
+ })
+ device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+ return new(QueueInboundElement)
+ })
+ device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() interface{} {
+ return new(QueueOutboundElement)
+ })
+}
+
+func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
+ return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
- if PreallocatedBuffersPerPool == 0 {
- device.pool.messageBufferPool.Put(msg)
- } else {
- device.pool.messageBufferReuseChan <- msg
- }
+ device.pool.messageBuffers.Put(msg)
}
func (device *Device) GetInboundElement() *QueueInboundElement {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.inboundElementPool.Get().(*QueueInboundElement)
- } else {
- return <-device.pool.inboundElementReuseChan
- }
+ return device.pool.inboundElements.Get().(*QueueInboundElement)
}
func (device *Device) PutInboundElement(elem *QueueInboundElement) {
elem.clearPointers()
- if PreallocatedBuffersPerPool == 0 {
- device.pool.inboundElementPool.Put(elem)
- } else {
- device.pool.inboundElementReuseChan <- elem
- }
+ device.pool.inboundElements.Put(elem)
}
func (device *Device) GetOutboundElement() *QueueOutboundElement {
- if PreallocatedBuffersPerPool == 0 {
- return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
- } else {
- return <-device.pool.outboundElementReuseChan
- }
+ return device.pool.outboundElements.Get().(*QueueOutboundElement)
}
func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
elem.clearPointers()
- if PreallocatedBuffersPerPool == 0 {
- device.pool.outboundElementPool.Put(elem)
- } else {
- device.pool.outboundElementReuseChan <- elem
- }
+ device.pool.outboundElements.Put(elem)
}
diff --git a/device/pools_test.go b/device/pools_test.go
new file mode 100644
index 0000000..e6cbac5
--- /dev/null
+++ b/device/pools_test.go
@@ -0,0 +1,60 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "math/rand"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestWaitPool(t *testing.T) {
+ var wg sync.WaitGroup
+ trials := int32(100000)
+ workers := runtime.NumCPU() + 2
+ if workers-4 <= 0 {
+ t.Skip("Not enough cores")
+ }
+ p := NewWaitPool(uint32(workers-4), func() interface{} { return make([]byte, 16) })
+ wg.Add(workers)
+ max := uint32(0)
+ updateMax := func() {
+ count := atomic.LoadUint32(&p.count)
+ if count > p.max {
+ t.Errorf("count (%d) > max (%d)", count, p.max)
+ }
+ for {
+ old := atomic.LoadUint32(&max)
+ if count <= old {
+ break
+ }
+ if atomic.CompareAndSwapUint32(&max, old, count) {
+ break
+ }
+ }
+ }
+ for i := 0; i < workers; i++ {
+ go func() {
+ defer wg.Done()
+ for atomic.AddInt32(&trials, -1) > 0 {
+ updateMax()
+ x := p.Get()
+ updateMax()
+ time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
+ updateMax()
+ p.Put(x)
+ updateMax()
+ }
+ }()
+ }
+ wg.Wait()
+ if max != p.max {
+ t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
+ }
+}