diff options
Diffstat (limited to 'device/device_test.go')
-rw-r--r-- | device/device_test.go | 228 |
1 files changed, 199 insertions, 29 deletions
diff --git a/device/device_test.go b/device/device_test.go index cdbd458..14cc605 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -5,54 +5,212 @@ package device -/* Create two device instances and simulate full WireGuard interaction - * without network dependencies - */ - import ( + "bufio" "bytes" + "encoding/binary" + "io" + "net" + "os" + "strings" "testing" + "time" + + "golang.zx2c4.com/wireguard/tun" ) -func TestDevice(t *testing.T) { +func TestTwoDevicePing(t *testing.T) { + // TODO(crawshaw): pick unused ports on localhost + cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58 +listen_port=53511 +replace_peers=true +public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725 +protocol_version=1 +replace_allowed_ips=true +allowed_ip=1.0.0.2/32 +endpoint=127.0.0.1:53512` + tun1 := NewChannelTUN() + dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: ")) + dev1.Up() + defer dev1.Close() + if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil { + t.Fatal(err) + } - // prepare tun devices for generating traffic + cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768 +listen_port=53512 +replace_peers=true +public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427 +protocol_version=1 +replace_allowed_ips=true +allowed_ip=1.0.0.1/32 +endpoint=127.0.0.1:53511` + tun2 := NewChannelTUN() + dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: ")) + dev2.Up() + defer dev2.Close() + if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil { + t.Fatal(err) + } - tun1 := newDummyTUN("tun1") - tun2 := newDummyTUN("tun2") + t.Run("ping 1.0.0.1", func(t *testing.T) { + msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) + tun2.Outbound <- msg2to1 + select { + case msgRecv := <-tun1.Inbound: + if !bytes.Equal(msg2to1, msgRecv) { + t.Error("ping did not transit correctly") + } + case <-time.After(300 * time.Millisecond): + t.Error("ping did not transit") + } + }) + + t.Run("ping 1.0.0.2", func(t *testing.T) { + msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) + tun1.Outbound <- msg1to2 + select { + case msgRecv := <-tun2.Inbound: + if !bytes.Equal(msg1to2, msgRecv) { + t.Error("return ping did not transit correctly") + } + case <-time.After(300 * time.Millisecond): + t.Error("return ping did not transit") + } + }) +} - _ = tun1 - _ = tun2 +func ping(dst, src net.IP) []byte { + localPort := uint16(1337) + seq := uint16(0) - // prepare endpoints + payload := make([]byte, 4) + binary.BigEndian.PutUint16(payload[0:], localPort) + binary.BigEndian.PutUint16(payload[2:], seq) - end1, err := CreateDummyEndpoint() - if err != nil { - t.Error("failed to create endpoint:", err.Error()) + return genICMPv4(payload, dst, src) +} + +// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071. +func checksum(buf []byte, initial uint16) uint16 { + v := uint32(initial) + for i := 0; i < len(buf)-1; i += 2 { + v += uint32(binary.BigEndian.Uint16(buf[i:])) + } + if len(buf)%2 == 1 { + v += uint32(buf[len(buf)-1]) << 8 + } + for v > 0xffff { + v = (v >> 16) + (v & 0xffff) } + return ^uint16(v) +} - end2, err := CreateDummyEndpoint() - if err != nil { - t.Error("failed to create endpoint:", err.Error()) +func genICMPv4(payload []byte, dst, src net.IP) []byte { + const ( + icmpv4ProtocolNumber = 1 + icmpv4Echo = 8 + icmpv4ChecksumOffset = 2 + icmpv4Size = 8 + ipv4Size = 20 + ipv4TotalLenOffset = 2 + ipv4ChecksumOffset = 10 + ttl = 65 + ) + + hdr := make([]byte, ipv4Size+icmpv4Size) + + ip := hdr[0:ipv4Size] + icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size] + + // https://tools.ietf.org/html/rfc792 + icmpv4[0] = icmpv4Echo // type + icmpv4[1] = 0 // code + chksum := ^checksum(icmpv4, checksum(payload, 0)) + binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) + + // https://tools.ietf.org/html/rfc760 section 3.1 + length := uint16(len(hdr) + len(payload)) + ip[0] = (4 << 4) | (ipv4Size / 4) + binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) + ip[8] = ttl + ip[9] = icmpv4ProtocolNumber + copy(ip[12:], src.To4()) + copy(ip[16:], dst.To4()) + chksum = ^checksum(ip[:], 0) + binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) + + var v []byte + v = append(v, hdr...) + v = append(v, payload...) + return []byte(v) +} + +// TODO(crawshaw): find a reusable home for this. package devicetest? +type ChannelTUN struct { + Inbound chan []byte // incoming packets, closed on TUN close + Outbound chan []byte // outbound packets, blocks forever on TUN close + + closed chan struct{} + events chan tun.Event + tun chTun +} + +func NewChannelTUN() *ChannelTUN { + c := &ChannelTUN{ + Inbound: make(chan []byte), + Outbound: make(chan []byte), + closed: make(chan struct{}), + events: make(chan tun.Event, 1), } + c.tun.c = c + c.events <- tun.EventUp + return c +} + +func (c *ChannelTUN) TUN() tun.Device { + return &c.tun +} - _ = end1 - _ = end2 +type chTun struct { + c *ChannelTUN +} - // create binds +func (t *chTun) File() *os.File { return nil } +func (t *chTun) Read(data []byte, offset int) (int, error) { + select { + case <-t.c.closed: + return 0, io.EOF // TODO(crawshaw): what is the correct error value? + case msg := <-t.c.Outbound: + return copy(data[offset:], msg), nil + } } -func randDevice(t *testing.T) *Device { - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) +// Write is called by the wireguard device to deliver a packet for routing. +func (t *chTun) Write(data []byte, offset int) (int, error) { + if offset == -1 { + close(t.c.closed) + close(t.c.events) + return 0, io.EOF } - tun := newDummyTUN("dummy") - logger := NewLogger(LogLevelError, "") - device := NewDevice(tun, logger) - device.SetPrivateKey(sk) - return device + msg := make([]byte, len(data)-offset) + copy(msg, data[offset:]) + select { + case <-t.c.closed: + return 0, io.EOF // TODO(crawshaw): what is the correct error value? + case t.c.Inbound <- msg: + return len(data) - offset, nil + } +} + +func (t *chTun) Flush() error { return nil } +func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } +func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } +func (t *chTun) Events() chan tun.Event { return t.c.events } +func (t *chTun) Close() error { + t.Write(nil, -1) + return nil } func assertNil(t *testing.T, err error) { @@ -66,3 +224,15 @@ func assertEqual(t *testing.T, a, b []byte) { t.Fatal(a, "!=", b) } } + +func randDevice(t *testing.T) *Device { + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + tun := newDummyTUN("dummy") + logger := NewLogger(LogLevelError, "") + device := NewDevice(tun, logger) + device.SetPrivateKey(sk) + return device +} |