summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorDavid Crawshaw <crawshaw@tailscale.com>2020-01-07 07:43:17 -0800
committerJason A. Donenfeld <Jason@zx2c4.com>2020-05-02 01:46:42 -0600
commit1a1c3d096888816c94cf1eb7c2747e83f008549f (patch)
treee5e7229c8198cf3fea04d173d6e41a99f2ce82c8
parent85a45a9651d2b38e3616a7899800601656cf143e (diff)
tuntest: split out testing package
This code is useful to other packages writing tests. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
-rw-r--r--device/device_test.go146
-rw-r--r--tun/tuntest/tuntest.go150
2 files changed, 155 insertions, 141 deletions
diff --git a/device/device_test.go b/device/device_test.go
index 14cc605..87ecfc8 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -8,15 +8,12 @@ package device
import (
"bufio"
"bytes"
- "encoding/binary"
- "io"
"net"
- "os"
"strings"
"testing"
"time"
- "golang.zx2c4.com/wireguard/tun"
+ "golang.zx2c4.com/wireguard/tun/tuntest"
)
func TestTwoDevicePing(t *testing.T) {
@@ -29,7 +26,7 @@ protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.2/32
endpoint=127.0.0.1:53512`
- tun1 := NewChannelTUN()
+ tun1 := tuntest.NewChannelTUN()
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
dev1.Up()
defer dev1.Close()
@@ -45,7 +42,7 @@ protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.1/32
endpoint=127.0.0.1:53511`
- tun2 := NewChannelTUN()
+ tun2 := tuntest.NewChannelTUN()
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
dev2.Up()
defer dev2.Close()
@@ -54,7 +51,7 @@ endpoint=127.0.0.1:53511`
}
t.Run("ping 1.0.0.1", func(t *testing.T) {
- msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
+ msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1
select {
case msgRecv := <-tun1.Inbound:
@@ -67,7 +64,7 @@ endpoint=127.0.0.1:53511`
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
- msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
+ msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2
select {
case msgRecv := <-tun2.Inbound:
@@ -80,139 +77,6 @@ endpoint=127.0.0.1:53511`
})
}
-func ping(dst, src net.IP) []byte {
- localPort := uint16(1337)
- seq := uint16(0)
-
- payload := make([]byte, 4)
- binary.BigEndian.PutUint16(payload[0:], localPort)
- binary.BigEndian.PutUint16(payload[2:], seq)
-
- return genICMPv4(payload, dst, src)
-}
-
-// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
-func checksum(buf []byte, initial uint16) uint16 {
- v := uint32(initial)
- for i := 0; i < len(buf)-1; i += 2 {
- v += uint32(binary.BigEndian.Uint16(buf[i:]))
- }
- if len(buf)%2 == 1 {
- v += uint32(buf[len(buf)-1]) << 8
- }
- for v > 0xffff {
- v = (v >> 16) + (v & 0xffff)
- }
- return ^uint16(v)
-}
-
-func genICMPv4(payload []byte, dst, src net.IP) []byte {
- const (
- icmpv4ProtocolNumber = 1
- icmpv4Echo = 8
- icmpv4ChecksumOffset = 2
- icmpv4Size = 8
- ipv4Size = 20
- ipv4TotalLenOffset = 2
- ipv4ChecksumOffset = 10
- ttl = 65
- )
-
- hdr := make([]byte, ipv4Size+icmpv4Size)
-
- ip := hdr[0:ipv4Size]
- icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
-
- // https://tools.ietf.org/html/rfc792
- icmpv4[0] = icmpv4Echo // type
- icmpv4[1] = 0 // code
- chksum := ^checksum(icmpv4, checksum(payload, 0))
- binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
-
- // https://tools.ietf.org/html/rfc760 section 3.1
- length := uint16(len(hdr) + len(payload))
- ip[0] = (4 << 4) | (ipv4Size / 4)
- binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
- ip[8] = ttl
- ip[9] = icmpv4ProtocolNumber
- copy(ip[12:], src.To4())
- copy(ip[16:], dst.To4())
- chksum = ^checksum(ip[:], 0)
- binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
-
- var v []byte
- v = append(v, hdr...)
- v = append(v, payload...)
- return []byte(v)
-}
-
-// TODO(crawshaw): find a reusable home for this. package devicetest?
-type ChannelTUN struct {
- Inbound chan []byte // incoming packets, closed on TUN close
- Outbound chan []byte // outbound packets, blocks forever on TUN close
-
- closed chan struct{}
- events chan tun.Event
- tun chTun
-}
-
-func NewChannelTUN() *ChannelTUN {
- c := &ChannelTUN{
- Inbound: make(chan []byte),
- Outbound: make(chan []byte),
- closed: make(chan struct{}),
- events: make(chan tun.Event, 1),
- }
- c.tun.c = c
- c.events <- tun.EventUp
- return c
-}
-
-func (c *ChannelTUN) TUN() tun.Device {
- return &c.tun
-}
-
-type chTun struct {
- c *ChannelTUN
-}
-
-func (t *chTun) File() *os.File { return nil }
-
-func (t *chTun) Read(data []byte, offset int) (int, error) {
- select {
- case <-t.c.closed:
- return 0, io.EOF // TODO(crawshaw): what is the correct error value?
- case msg := <-t.c.Outbound:
- return copy(data[offset:], msg), nil
- }
-}
-
-// Write is called by the wireguard device to deliver a packet for routing.
-func (t *chTun) Write(data []byte, offset int) (int, error) {
- if offset == -1 {
- close(t.c.closed)
- close(t.c.events)
- return 0, io.EOF
- }
- msg := make([]byte, len(data)-offset)
- copy(msg, data[offset:])
- select {
- case <-t.c.closed:
- return 0, io.EOF // TODO(crawshaw): what is the correct error value?
- case t.c.Inbound <- msg:
- return len(data) - offset, nil
- }
-}
-
-func (t *chTun) Flush() error { return nil }
-func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
-func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
-func (t *chTun) Events() chan tun.Event { return t.c.events }
-func (t *chTun) Close() error {
- t.Write(nil, -1)
- return nil
-}
-
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go
new file mode 100644
index 0000000..bdd96ac
--- /dev/null
+++ b/tun/tuntest/tuntest.go
@@ -0,0 +1,150 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package tuntest
+
+import (
+ "encoding/binary"
+ "io"
+ "net"
+ "os"
+
+ "golang.zx2c4.com/wireguard/tun"
+)
+
+func Ping(dst, src net.IP) []byte {
+ localPort := uint16(1337)
+ seq := uint16(0)
+
+ payload := make([]byte, 4)
+ binary.BigEndian.PutUint16(payload[0:], localPort)
+ binary.BigEndian.PutUint16(payload[2:], seq)
+
+ return genICMPv4(payload, dst, src)
+}
+
+// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
+func checksum(buf []byte, initial uint16) uint16 {
+ v := uint32(initial)
+ for i := 0; i < len(buf)-1; i += 2 {
+ v += uint32(binary.BigEndian.Uint16(buf[i:]))
+ }
+ if len(buf)%2 == 1 {
+ v += uint32(buf[len(buf)-1]) << 8
+ }
+ for v > 0xffff {
+ v = (v >> 16) + (v & 0xffff)
+ }
+ return ^uint16(v)
+}
+
+func genICMPv4(payload []byte, dst, src net.IP) []byte {
+ const (
+ icmpv4ProtocolNumber = 1
+ icmpv4Echo = 8
+ icmpv4ChecksumOffset = 2
+ icmpv4Size = 8
+ ipv4Size = 20
+ ipv4TotalLenOffset = 2
+ ipv4ChecksumOffset = 10
+ ttl = 65
+ )
+
+ hdr := make([]byte, ipv4Size+icmpv4Size)
+
+ ip := hdr[0:ipv4Size]
+ icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
+
+ // https://tools.ietf.org/html/rfc792
+ icmpv4[0] = icmpv4Echo // type
+ icmpv4[1] = 0 // code
+ chksum := ^checksum(icmpv4, checksum(payload, 0))
+ binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
+
+ // https://tools.ietf.org/html/rfc760 section 3.1
+ length := uint16(len(hdr) + len(payload))
+ ip[0] = (4 << 4) | (ipv4Size / 4)
+ binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
+ ip[8] = ttl
+ ip[9] = icmpv4ProtocolNumber
+ copy(ip[12:], src.To4())
+ copy(ip[16:], dst.To4())
+ chksum = ^checksum(ip[:], 0)
+ binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
+
+ var v []byte
+ v = append(v, hdr...)
+ v = append(v, payload...)
+ return []byte(v)
+}
+
+// TODO(crawshaw): find a reusable home for this. package devicetest?
+type ChannelTUN struct {
+ Inbound chan []byte // incoming packets, closed on TUN close
+ Outbound chan []byte // outbound packets, blocks forever on TUN close
+
+ closed chan struct{}
+ events chan tun.Event
+ tun chTun
+}
+
+func NewChannelTUN() *ChannelTUN {
+ c := &ChannelTUN{
+ Inbound: make(chan []byte),
+ Outbound: make(chan []byte),
+ closed: make(chan struct{}),
+ events: make(chan tun.Event, 1),
+ }
+ c.tun.c = c
+ c.events <- tun.EventUp
+ return c
+}
+
+func (c *ChannelTUN) TUN() tun.Device {
+ return &c.tun
+}
+
+type chTun struct {
+ c *ChannelTUN
+}
+
+func (t *chTun) File() *os.File { return nil }
+
+func (t *chTun) Read(data []byte, offset int) (int, error) {
+ select {
+ case <-t.c.closed:
+ return 0, io.EOF // TODO(crawshaw): what is the correct error value?
+ case msg := <-t.c.Outbound:
+ return copy(data[offset:], msg), nil
+ }
+}
+
+// Write is called by the wireguard device to deliver a packet for routing.
+func (t *chTun) Write(data []byte, offset int) (int, error) {
+ if offset == -1 {
+ close(t.c.closed)
+ close(t.c.events)
+ return 0, io.EOF
+ }
+ msg := make([]byte, len(data)-offset)
+ copy(msg, data[offset:])
+ select {
+ case <-t.c.closed:
+ return 0, io.EOF // TODO(crawshaw): what is the correct error value?
+ case t.c.Inbound <- msg:
+ return len(data) - offset, nil
+ }
+}
+
+const DefaultMTU = 1420
+
+func (t *chTun) Flush() error { return nil }
+func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
+func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
+func (t *chTun) Events() chan tun.Event { return t.c.events }
+func (t *chTun) Close() error {
+ t.Write(nil, -1)
+ return nil
+}