summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--device/device.go39
-rw-r--r--device/device_test.go136
-rw-r--r--device/send.go93
3 files changed, 170 insertions, 98 deletions
diff --git a/device/device.go b/device/device.go
index 9e2d001..d9367e5 100644
--- a/device/device.go
+++ b/device/device.go
@@ -74,7 +74,7 @@ type Device struct {
}
queue struct {
- encryption chan *QueueOutboundElement
+ encryption *encryptionQueue
decryption chan *QueueInboundElement
handshake chan QueueHandshakeElement
}
@@ -89,6 +89,31 @@ type Device struct {
}
}
+// An encryptionQueue is a channel of QueueOutboundElements awaiting encryption.
+// An encryptionQueue is ref-counted using its wg field.
+// An encryptionQueue created with newEncryptionQueue has one reference.
+// Every additional writer must call wg.Add(1).
+// Every completed writer must call wg.Done().
+// When no further writers will be added,
+// call wg.Done to remove the initial reference.
+// When the refcount hits 0, the queue's channel is closed.
+type encryptionQueue struct {
+ c chan *QueueOutboundElement
+ wg sync.WaitGroup
+}
+
+func newEncryptionQueue() *encryptionQueue {
+ q := &encryptionQueue{
+ c: make(chan *QueueOutboundElement, QueueOutboundSize),
+ }
+ q.wg.Add(1)
+ go func() {
+ q.wg.Wait()
+ close(q.c)
+ }()
+ return q
+}
+
/* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table.
*
@@ -280,7 +305,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
// create queues
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
- device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
+ device.queue.encryption = newEncryptionQueue()
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signals
@@ -297,7 +322,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
cpus := runtime.NumCPU()
device.state.stopping.Wait()
for i := 0; i < cpus; i += 1 {
- device.state.stopping.Add(3)
+ device.state.stopping.Add(2) // decryption and handshake
go device.RoutineEncryption()
go device.RoutineDecryption()
go device.RoutineHandshake()
@@ -346,10 +371,6 @@ func (device *Device) FlushPacketQueues() {
if ok {
elem.Drop()
}
- case elem, ok := <-device.queue.encryption:
- if ok {
- elem.Drop()
- }
case <-device.queue.handshake:
default:
return
@@ -373,6 +394,10 @@ func (device *Device) Close() {
device.isUp.Set(false)
+ // We kept a reference to the encryption queue,
+ // in case we started any new peers that might write to it.
+ // No new peers are coming; we are done with the encryption queue.
+ device.queue.encryption.wg.Done()
close(device.signals.stop)
device.state.stopping.Wait()
diff --git a/device/device_test.go b/device/device_test.go
index a89dcc2..65942ec 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -7,9 +7,11 @@ package device
import (
"bytes"
+ "errors"
"fmt"
"io"
"net"
+ "sync"
"testing"
"time"
@@ -79,18 +81,74 @@ func genConfigs(t *testing.T) (cfgs [2]io.Reader) {
return
}
-// genChannelTUNs creates a usable pair of ChannelTUNs for use in a test.
-func genChannelTUNs(t *testing.T) (tun [2]*tuntest.ChannelTUN) {
+// A testPair is a pair of testPeers.
+type testPair [2]testPeer
+
+// A testPeer is a peer used for testing.
+type testPeer struct {
+ tun *tuntest.ChannelTUN
+ dev *Device
+ ip net.IP
+}
+
+type SendDirection bool
+
+const (
+ Ping SendDirection = true
+ Pong SendDirection = false
+)
+
+func (pair *testPair) Send(t *testing.T, ping SendDirection, done chan struct{}) {
+ t.Helper()
+ p0, p1 := pair[0], pair[1]
+ if !ping {
+ // pong is the new ping
+ p0, p1 = p1, p0
+ }
+ msg := tuntest.Ping(p0.ip, p1.ip)
+ p1.tun.Outbound <- msg
+ timer := time.NewTimer(5 * time.Second)
+ defer timer.Stop()
+ var err error
+ select {
+ case msgRecv := <-p0.tun.Inbound:
+ if !bytes.Equal(msg, msgRecv) {
+ err = errors.New("ping did not transit correctly")
+ }
+ case <-timer.C:
+ err = errors.New("ping did not transit")
+ case <-done:
+ }
+ if err != nil {
+ // The error may have occurred because the test is done.
+ select {
+ case <-done:
+ return
+ default:
+ }
+ // Real error.
+ t.Error(err)
+ }
+}
+
+// genTestPair creates a testPair.
+func genTestPair(t *testing.T) (pair testPair) {
const maxAttempts = 10
NextAttempt:
for i := 0; i < maxAttempts; i++ {
cfg := genConfigs(t)
// Bring up a ChannelTun for each config.
- for i := range tun {
- tun[i] = tuntest.NewChannelTUN()
- dev := NewDevice(tun[i].TUN(), NewLogger(LogLevelDebug, fmt.Sprintf("dev%d: ", i)))
- dev.Up()
- if err := dev.IpcSetOperation(cfg[i]); err != nil {
+ for i := range pair {
+ p := &pair[i]
+ p.tun = tuntest.NewChannelTUN()
+ if i == 0 {
+ p.ip = net.ParseIP("1.0.0.1")
+ } else {
+ p.ip = net.ParseIP("1.0.0.2")
+ }
+ p.dev = NewDevice(p.tun.TUN(), NewLogger(LogLevelDebug, fmt.Sprintf("dev%d: ", i)))
+ p.dev.Up()
+ if err := p.dev.IpcSetOperation(cfg[i]); err != nil {
// genConfigs attempted to pick ports that were free.
// There's a tiny window between genConfigs closing the port
// and us opening it, during which another process could
@@ -104,12 +162,12 @@ NextAttempt:
// The device might still not be up, e.g. due to an error
// in RoutineTUNEventReader's call to dev.Up that got swallowed.
// Assume it's due to a transient error (port in use), and retry.
- if !dev.isUp.Get() {
- t.Logf("%v did not come up, trying again", dev)
+ if !p.dev.isUp.Get() {
+ t.Logf("device %d did not come up, trying again", i)
continue NextAttempt
}
// The device is up. Close it when the test completes.
- t.Cleanup(dev.Close)
+ t.Cleanup(p.dev.Close)
}
return // success
}
@@ -119,33 +177,47 @@ NextAttempt:
}
func TestTwoDevicePing(t *testing.T) {
- tun := genChannelTUNs(t)
-
+ pair := genTestPair(t)
t.Run("ping 1.0.0.1", func(t *testing.T) {
- msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
- tun[1].Outbound <- msg2to1
- select {
- case msgRecv := <-tun[0].Inbound:
- if !bytes.Equal(msg2to1, msgRecv) {
- t.Error("ping did not transit correctly")
- }
- case <-time.After(5 * time.Second):
- t.Error("ping did not transit")
- }
+ pair.Send(t, Ping, nil)
})
-
t.Run("ping 1.0.0.2", func(t *testing.T) {
- msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
- tun[0].Outbound <- msg1to2
- select {
- case msgRecv := <-tun[1].Inbound:
- if !bytes.Equal(msg1to2, msgRecv) {
- t.Error("return ping did not transit correctly")
+ pair.Send(t, Pong, nil)
+ })
+}
+
+// TestConcurrencySafety does other things concurrently with tunnel use.
+// It is intended to be used with the race detector to catch data races.
+func TestConcurrencySafety(t *testing.T) {
+ pair := genTestPair(t)
+ done := make(chan struct{})
+
+ const warmupIters = 10
+ var warmup sync.WaitGroup
+ warmup.Add(warmupIters)
+ go func() {
+ // Send data continuously back and forth until we're done.
+ // Note that we may continue to attempt to send data
+ // even after done is closed.
+ i := warmupIters
+ for ping := Ping; ; ping = !ping {
+ pair.Send(t, ping, done)
+ select {
+ case <-done:
+ return
+ default:
+ }
+ if i > 0 {
+ warmup.Done()
+ i--
}
- case <-time.After(5 * time.Second):
- t.Error("return ping did not transit")
}
- })
+ }()
+ warmup.Wait()
+
+ // coming soon: more things here...
+
+ close(done)
}
func assertNil(t *testing.T, err error) {
diff --git a/device/send.go b/device/send.go
index 0801b71..1b16edd 100644
--- a/device/send.go
+++ b/device/send.go
@@ -352,6 +352,9 @@ func (peer *Peer) RoutineNonce() {
device := peer.device
logDebug := device.log.Debug
+ // We write to the encryption queue; keep it alive until we are done.
+ device.queue.encryption.wg.Add(1)
+
flush := func() {
for {
select {
@@ -368,6 +371,7 @@ func (peer *Peer) RoutineNonce() {
flush()
logDebug.Println(peer, "- Routine: nonce worker - stopped")
peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
+ device.queue.encryption.wg.Done() // no more writes from us
peer.routines.stopping.Done()
}()
@@ -455,7 +459,7 @@ NextPacket:
elem.Lock()
// add to parallel and sequential queue
- addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
+ addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption.c, elem)
}
}
}
@@ -486,76 +490,46 @@ func (device *Device) RoutineEncryption() {
logDebug := device.log.Debug
- defer func() {
- for {
- select {
- case elem, ok := <-device.queue.encryption:
- if ok && !elem.IsDropped() {
- elem.Drop()
- device.PutMessageBuffer(elem.buffer)
- elem.Unlock()
- }
- default:
- goto out
- }
- }
- out:
- logDebug.Println("Routine: encryption worker - stopped")
- device.state.stopping.Done()
- }()
-
+ defer logDebug.Println("Routine: encryption worker - stopped")
logDebug.Println("Routine: encryption worker - started")
- for {
-
- // fetch next element
+ for elem := range device.queue.encryption.c {
- select {
- case <-device.signals.stop:
- return
-
- case elem, ok := <-device.queue.encryption:
-
- if !ok {
- return
- }
-
- // check if dropped
+ // check if dropped
- if elem.IsDropped() {
- continue
- }
+ if elem.IsDropped() {
+ continue
+ }
- // populate header fields
+ // populate header fields
- header := elem.buffer[:MessageTransportHeaderSize]
+ header := elem.buffer[:MessageTransportHeaderSize]
- fieldType := header[0:4]
- fieldReceiver := header[4:8]
- fieldNonce := header[8:16]
+ fieldType := header[0:4]
+ fieldReceiver := header[4:8]
+ fieldNonce := header[8:16]
- binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
- binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
- binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
+ binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
+ binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
+ binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
- // pad content to multiple of 16
+ // pad content to multiple of 16
- paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
- for i := 0; i < paddingSize; i++ {
- elem.packet = append(elem.packet, 0)
- }
+ paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
+ for i := 0; i < paddingSize; i++ {
+ elem.packet = append(elem.packet, 0)
+ }
- // encrypt content and release to consumer
+ // encrypt content and release to consumer
- binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
- elem.packet = elem.keypair.send.Seal(
- header,
- nonce[:],
- elem.packet,
- nil,
- )
- elem.Unlock()
- }
+ binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
+ elem.packet = elem.keypair.send.Seal(
+ header,
+ nonce[:],
+ elem.packet,
+ nil,
+ )
+ elem.Unlock()
}
}
@@ -576,6 +550,7 @@ func (peer *Peer) RoutineSequentialSender() {
select {
case elem, ok := <-peer.queue.outbound:
if ok {
+ elem.Lock()
if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer)
elem.Drop()