diff options
Diffstat (limited to 'device/device_test.go')
-rw-r--r-- | device/device_test.go | 136 |
1 files changed, 104 insertions, 32 deletions
diff --git a/device/device_test.go b/device/device_test.go index a89dcc2..65942ec 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -7,9 +7,11 @@ package device import ( "bytes" + "errors" "fmt" "io" "net" + "sync" "testing" "time" @@ -79,18 +81,74 @@ func genConfigs(t *testing.T) (cfgs [2]io.Reader) { return } -// genChannelTUNs creates a usable pair of ChannelTUNs for use in a test. -func genChannelTUNs(t *testing.T) (tun [2]*tuntest.ChannelTUN) { +// A testPair is a pair of testPeers. +type testPair [2]testPeer + +// A testPeer is a peer used for testing. +type testPeer struct { + tun *tuntest.ChannelTUN + dev *Device + ip net.IP +} + +type SendDirection bool + +const ( + Ping SendDirection = true + Pong SendDirection = false +) + +func (pair *testPair) Send(t *testing.T, ping SendDirection, done chan struct{}) { + t.Helper() + p0, p1 := pair[0], pair[1] + if !ping { + // pong is the new ping + p0, p1 = p1, p0 + } + msg := tuntest.Ping(p0.ip, p1.ip) + p1.tun.Outbound <- msg + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + var err error + select { + case msgRecv := <-p0.tun.Inbound: + if !bytes.Equal(msg, msgRecv) { + err = errors.New("ping did not transit correctly") + } + case <-timer.C: + err = errors.New("ping did not transit") + case <-done: + } + if err != nil { + // The error may have occurred because the test is done. + select { + case <-done: + return + default: + } + // Real error. + t.Error(err) + } +} + +// genTestPair creates a testPair. +func genTestPair(t *testing.T) (pair testPair) { const maxAttempts = 10 NextAttempt: for i := 0; i < maxAttempts; i++ { cfg := genConfigs(t) // Bring up a ChannelTun for each config. - for i := range tun { - tun[i] = tuntest.NewChannelTUN() - dev := NewDevice(tun[i].TUN(), NewLogger(LogLevelDebug, fmt.Sprintf("dev%d: ", i))) - dev.Up() - if err := dev.IpcSetOperation(cfg[i]); err != nil { + for i := range pair { + p := &pair[i] + p.tun = tuntest.NewChannelTUN() + if i == 0 { + p.ip = net.ParseIP("1.0.0.1") + } else { + p.ip = net.ParseIP("1.0.0.2") + } + p.dev = NewDevice(p.tun.TUN(), NewLogger(LogLevelDebug, fmt.Sprintf("dev%d: ", i))) + p.dev.Up() + if err := p.dev.IpcSetOperation(cfg[i]); err != nil { // genConfigs attempted to pick ports that were free. // There's a tiny window between genConfigs closing the port // and us opening it, during which another process could @@ -104,12 +162,12 @@ NextAttempt: // The device might still not be up, e.g. due to an error // in RoutineTUNEventReader's call to dev.Up that got swallowed. // Assume it's due to a transient error (port in use), and retry. - if !dev.isUp.Get() { - t.Logf("%v did not come up, trying again", dev) + if !p.dev.isUp.Get() { + t.Logf("device %d did not come up, trying again", i) continue NextAttempt } // The device is up. Close it when the test completes. - t.Cleanup(dev.Close) + t.Cleanup(p.dev.Close) } return // success } @@ -119,33 +177,47 @@ NextAttempt: } func TestTwoDevicePing(t *testing.T) { - tun := genChannelTUNs(t) - + pair := genTestPair(t) t.Run("ping 1.0.0.1", func(t *testing.T) { - msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) - tun[1].Outbound <- msg2to1 - select { - case msgRecv := <-tun[0].Inbound: - if !bytes.Equal(msg2to1, msgRecv) { - t.Error("ping did not transit correctly") - } - case <-time.After(5 * time.Second): - t.Error("ping did not transit") - } + pair.Send(t, Ping, nil) }) - t.Run("ping 1.0.0.2", func(t *testing.T) { - msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - tun[0].Outbound <- msg1to2 - select { - case msgRecv := <-tun[1].Inbound: - if !bytes.Equal(msg1to2, msgRecv) { - t.Error("return ping did not transit correctly") + pair.Send(t, Pong, nil) + }) +} + +// TestConcurrencySafety does other things concurrently with tunnel use. +// It is intended to be used with the race detector to catch data races. +func TestConcurrencySafety(t *testing.T) { + pair := genTestPair(t) + done := make(chan struct{}) + + const warmupIters = 10 + var warmup sync.WaitGroup + warmup.Add(warmupIters) + go func() { + // Send data continuously back and forth until we're done. + // Note that we may continue to attempt to send data + // even after done is closed. + i := warmupIters + for ping := Ping; ; ping = !ping { + pair.Send(t, ping, done) + select { + case <-done: + return + default: + } + if i > 0 { + warmup.Done() + i-- } - case <-time.After(5 * time.Second): - t.Error("return ping did not transit") } - }) + }() + warmup.Wait() + + // coming soon: more things here... + + close(done) } func assertNil(t *testing.T, err error) { |