summaryrefslogtreecommitdiffhomepage
path: root/device
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2019-03-03 04:04:41 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2019-03-03 05:00:40 +0100
commit69f0fe67b63d90e523a5a1241fb1b46c2e8dbe03 (patch)
tree1ef86da3242afde462dcadb7241bb09f499d5bd7 /device
parentd435be35cac49af9367b2005d831d55e570c4b1b (diff)
global: begin modularization
Diffstat (limited to 'device')
-rw-r--r--device/allowedips.go251
-rw-r--r--device/allowedips_rand_test.go131
-rw-r--r--device/allowedips_test.go260
-rw-r--r--device/bind_test.go55
-rw-r--r--device/conn.go180
-rw-r--r--device/conn_default.go170
-rw-r--r--device/conn_linux.go746
-rw-r--r--device/constants.go41
-rw-r--r--device/cookie.go250
-rw-r--r--device/cookie_test.go191
-rw-r--r--device/device.go396
-rw-r--r--device/device_test.go48
-rw-r--r--device/endpoint_test.go53
-rw-r--r--device/indextable.go97
-rw-r--r--device/ip.go22
-rw-r--r--device/kdf_test.go84
-rw-r--r--device/keypair.go50
-rw-r--r--device/logger.go59
-rw-r--r--device/mark_default.go12
-rw-r--r--device/mark_unix.go64
-rw-r--r--device/misc.go48
-rw-r--r--device/noise-helpers.go104
-rw-r--r--device/noise-protocol.go600
-rw-r--r--device/noise-types.go81
-rw-r--r--device/noise_test.go144
-rw-r--r--device/peer.go270
-rw-r--r--device/pools.go89
-rw-r--r--device/queueconstants.go16
-rw-r--r--device/receive.go641
-rw-r--r--device/send.go618
-rw-r--r--device/timers.go227
-rw-r--r--device/tun.go55
-rw-r--r--device/uapi.go426
-rw-r--r--device/version.go3
34 files changed, 6482 insertions, 0 deletions
diff --git a/device/allowedips.go b/device/allowedips.go
new file mode 100644
index 0000000..efc27c0
--- /dev/null
+++ b/device/allowedips.go
@@ -0,0 +1,251 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "errors"
+ "math/bits"
+ "net"
+ "sync"
+ "unsafe"
+)
+
+type trieEntry struct {
+ cidr uint
+ child [2]*trieEntry
+ bits net.IP
+ peer *Peer
+
+ // index of "branching" bit
+
+ bit_at_byte uint
+ bit_at_shift uint
+}
+
+func isLittleEndian() bool {
+ one := uint32(1)
+ return *(*byte)(unsafe.Pointer(&one)) != 0
+}
+
+func swapU32(i uint32) uint32 {
+ if !isLittleEndian() {
+ return i
+ }
+
+ return bits.ReverseBytes32(i)
+}
+
+func swapU64(i uint64) uint64 {
+ if !isLittleEndian() {
+ return i
+ }
+
+ return bits.ReverseBytes64(i)
+}
+
+func commonBits(ip1 net.IP, ip2 net.IP) uint {
+ size := len(ip1)
+ if size == net.IPv4len {
+ a := (*uint32)(unsafe.Pointer(&ip1[0]))
+ b := (*uint32)(unsafe.Pointer(&ip2[0]))
+ x := *a ^ *b
+ return uint(bits.LeadingZeros32(swapU32(x)))
+ } else if size == net.IPv6len {
+ a := (*uint64)(unsafe.Pointer(&ip1[0]))
+ b := (*uint64)(unsafe.Pointer(&ip2[0]))
+ x := *a ^ *b
+ if x != 0 {
+ return uint(bits.LeadingZeros64(swapU64(x)))
+ }
+ a = (*uint64)(unsafe.Pointer(&ip1[8]))
+ b = (*uint64)(unsafe.Pointer(&ip2[8]))
+ x = *a ^ *b
+ return 64 + uint(bits.LeadingZeros64(swapU64(x)))
+ } else {
+ panic("Wrong size bit string")
+ }
+}
+
+func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
+ if node == nil {
+ return node
+ }
+
+ // walk recursively
+
+ node.child[0] = node.child[0].removeByPeer(p)
+ node.child[1] = node.child[1].removeByPeer(p)
+
+ if node.peer != p {
+ return node
+ }
+
+ // remove peer & merge
+
+ node.peer = nil
+ if node.child[0] == nil {
+ return node.child[1]
+ }
+ return node.child[0]
+}
+
+func (node *trieEntry) choose(ip net.IP) byte {
+ return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
+}
+
+func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
+
+ // at leaf
+
+ if node == nil {
+ return &trieEntry{
+ bits: ip,
+ peer: peer,
+ cidr: cidr,
+ bit_at_byte: cidr / 8,
+ bit_at_shift: 7 - (cidr % 8),
+ }
+ }
+
+ // traverse deeper
+
+ common := commonBits(node.bits, ip)
+ if node.cidr <= cidr && common >= node.cidr {
+ if node.cidr == cidr {
+ node.peer = peer
+ return node
+ }
+ bit := node.choose(ip)
+ node.child[bit] = node.child[bit].insert(ip, cidr, peer)
+ return node
+ }
+
+ // split node
+
+ newNode := &trieEntry{
+ bits: ip,
+ peer: peer,
+ cidr: cidr,
+ bit_at_byte: cidr / 8,
+ bit_at_shift: 7 - (cidr % 8),
+ }
+
+ cidr = min(cidr, common)
+
+ // check for shorter prefix
+
+ if newNode.cidr == cidr {
+ bit := newNode.choose(node.bits)
+ newNode.child[bit] = node
+ return newNode
+ }
+
+ // create new parent for node & newNode
+
+ parent := &trieEntry{
+ bits: ip,
+ peer: nil,
+ cidr: cidr,
+ bit_at_byte: cidr / 8,
+ bit_at_shift: 7 - (cidr % 8),
+ }
+
+ bit := parent.choose(ip)
+ parent.child[bit] = newNode
+ parent.child[bit^1] = node
+
+ return parent
+}
+
+func (node *trieEntry) lookup(ip net.IP) *Peer {
+ var found *Peer
+ size := uint(len(ip))
+ for node != nil && commonBits(node.bits, ip) >= node.cidr {
+ if node.peer != nil {
+ found = node.peer
+ }
+ if node.bit_at_byte == size {
+ break
+ }
+ bit := node.choose(ip)
+ node = node.child[bit]
+ }
+ return found
+}
+
+func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
+ if node == nil {
+ return results
+ }
+ if node.peer == p {
+ mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
+ results = append(results, net.IPNet{
+ Mask: mask,
+ IP: node.bits.Mask(mask),
+ })
+ }
+ results = node.child[0].entriesForPeer(p, results)
+ results = node.child[1].entriesForPeer(p, results)
+ return results
+}
+
+type AllowedIPs struct {
+ IPv4 *trieEntry
+ IPv6 *trieEntry
+ mutex sync.RWMutex
+}
+
+func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+
+ allowed := make([]net.IPNet, 0, 10)
+ allowed = table.IPv4.entriesForPeer(peer, allowed)
+ allowed = table.IPv6.entriesForPeer(peer, allowed)
+ return allowed
+}
+
+func (table *AllowedIPs) Reset() {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+
+ table.IPv4 = nil
+ table.IPv6 = nil
+}
+
+func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+
+ table.IPv4 = table.IPv4.removeByPeer(peer)
+ table.IPv6 = table.IPv6.removeByPeer(peer)
+}
+
+func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
+ table.mutex.Lock()
+ defer table.mutex.Unlock()
+
+ switch len(ip) {
+ case net.IPv6len:
+ table.IPv6 = table.IPv6.insert(ip, cidr, peer)
+ case net.IPv4len:
+ table.IPv4 = table.IPv4.insert(ip, cidr, peer)
+ default:
+ panic(errors.New("inserting unknown address type"))
+ }
+}
+
+func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+ return table.IPv4.lookup(address)
+}
+
+func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+ return table.IPv6.lookup(address)
+}
diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go
new file mode 100644
index 0000000..59c10f7
--- /dev/null
+++ b/device/allowedips_rand_test.go
@@ -0,0 +1,131 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "math/rand"
+ "sort"
+ "testing"
+)
+
+const (
+ NumberOfPeers = 100
+ NumberOfAddresses = 250
+ NumberOfTests = 10000
+)
+
+type SlowNode struct {
+ peer *Peer
+ cidr uint
+ bits []byte
+}
+
+type SlowRouter []*SlowNode
+
+func (r SlowRouter) Len() int {
+ return len(r)
+}
+
+func (r SlowRouter) Less(i, j int) bool {
+ return r[i].cidr > r[j].cidr
+}
+
+func (r SlowRouter) Swap(i, j int) {
+ r[i], r[j] = r[j], r[i]
+}
+
+func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
+ for _, t := range r {
+ if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
+ t.peer = peer
+ t.bits = addr
+ return r
+ }
+ }
+ r = append(r, &SlowNode{
+ cidr: cidr,
+ bits: addr,
+ peer: peer,
+ })
+ sort.Sort(r)
+ return r
+}
+
+func (r SlowRouter) Lookup(addr []byte) *Peer {
+ for _, t := range r {
+ common := commonBits(t.bits, addr)
+ if common >= t.cidr {
+ return t.peer
+ }
+ }
+ return nil
+}
+
+func TestTrieRandomIPv4(t *testing.T) {
+ var trie *trieEntry
+ var slow SlowRouter
+ var peers []*Peer
+
+ rand.Seed(1)
+
+ const AddressLength = 4
+
+ for n := 0; n < NumberOfPeers; n += 1 {
+ peers = append(peers, &Peer{})
+ }
+
+ for n := 0; n < NumberOfAddresses; n += 1 {
+ var addr [AddressLength]byte
+ rand.Read(addr[:])
+ cidr := uint(rand.Uint32() % (AddressLength * 8))
+ index := rand.Int() % NumberOfPeers
+ trie = trie.insert(addr[:], cidr, peers[index])
+ slow = slow.Insert(addr[:], cidr, peers[index])
+ }
+
+ for n := 0; n < NumberOfTests; n += 1 {
+ var addr [AddressLength]byte
+ rand.Read(addr[:])
+ peer1 := slow.Lookup(addr[:])
+ peer2 := trie.lookup(addr[:])
+ if peer1 != peer2 {
+ t.Error("Trie did not match naive implementation, for:", addr)
+ }
+ }
+}
+
+func TestTrieRandomIPv6(t *testing.T) {
+ var trie *trieEntry
+ var slow SlowRouter
+ var peers []*Peer
+
+ rand.Seed(1)
+
+ const AddressLength = 16
+
+ for n := 0; n < NumberOfPeers; n += 1 {
+ peers = append(peers, &Peer{})
+ }
+
+ for n := 0; n < NumberOfAddresses; n += 1 {
+ var addr [AddressLength]byte
+ rand.Read(addr[:])
+ cidr := uint(rand.Uint32() % (AddressLength * 8))
+ index := rand.Int() % NumberOfPeers
+ trie = trie.insert(addr[:], cidr, peers[index])
+ slow = slow.Insert(addr[:], cidr, peers[index])
+ }
+
+ for n := 0; n < NumberOfTests; n += 1 {
+ var addr [AddressLength]byte
+ rand.Read(addr[:])
+ peer1 := slow.Lookup(addr[:])
+ peer2 := trie.lookup(addr[:])
+ if peer1 != peer2 {
+ t.Error("Trie did not match naive implementation, for:", addr)
+ }
+ }
+}
diff --git a/device/allowedips_test.go b/device/allowedips_test.go
new file mode 100644
index 0000000..075ff06
--- /dev/null
+++ b/device/allowedips_test.go
@@ -0,0 +1,260 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "math/rand"
+ "net"
+ "testing"
+)
+
+/* Todo: More comprehensive
+ */
+
+type testPairCommonBits struct {
+ s1 []byte
+ s2 []byte
+ match uint
+}
+
+type testPairTrieInsert struct {
+ key []byte
+ cidr uint
+ peer *Peer
+}
+
+type testPairTrieLookup struct {
+ key []byte
+ peer *Peer
+}
+
+func printTrie(t *testing.T, p *trieEntry) {
+ if p == nil {
+ return
+ }
+ t.Log(p)
+ printTrie(t, p.child[0])
+ printTrie(t, p.child[1])
+}
+
+func TestCommonBits(t *testing.T) {
+
+ tests := []testPairCommonBits{
+ {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
+ {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
+ {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
+ {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
+ {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
+ }
+
+ for _, p := range tests {
+ v := commonBits(p.s1, p.s2)
+ if v != p.match {
+ t.Error(
+ "For slice", p.s1, p.s2,
+ "expected match", p.match,
+ ",but got", v,
+ )
+ }
+ }
+}
+
+func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
+ var trie *trieEntry
+ var peers []*Peer
+
+ rand.Seed(1)
+
+ const AddressLength = 4
+
+ for n := 0; n < peerNumber; n += 1 {
+ peers = append(peers, &Peer{})
+ }
+
+ for n := 0; n < addressNumber; n += 1 {
+ var addr [AddressLength]byte
+ rand.Read(addr[:])
+ cidr := uint(rand.Uint32() % (AddressLength * 8))
+ index := rand.Int() % peerNumber
+ trie = trie.insert(addr[:], cidr, peers[index])
+ }
+
+ for n := 0; n < b.N; n += 1 {
+ var addr [AddressLength]byte
+ rand.Read(addr[:])
+ trie.lookup(addr[:])
+ }
+}
+
+func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
+ benchmarkTrie(100, 1000, net.IPv4len, b)
+}
+
+func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
+ benchmarkTrie(10, 10, net.IPv4len, b)
+}
+
+func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
+ benchmarkTrie(100, 1000, net.IPv6len, b)
+}
+
+func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
+ benchmarkTrie(10, 10, net.IPv6len, b)
+}
+
+/* Test ported from kernel implementation:
+ * selftest/allowedips.h
+ */
+func TestTrieIPv4(t *testing.T) {
+ a := &Peer{}
+ b := &Peer{}
+ c := &Peer{}
+ d := &Peer{}
+ e := &Peer{}
+ g := &Peer{}
+ h := &Peer{}
+
+ var trie *trieEntry
+
+ insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
+ trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
+ }
+
+ assertEQ := func(peer *Peer, a, b, c, d byte) {
+ p := trie.lookup([]byte{a, b, c, d})
+ if p != peer {
+ t.Error("Assert EQ failed")
+ }
+ }
+
+ assertNEQ := func(peer *Peer, a, b, c, d byte) {
+ p := trie.lookup([]byte{a, b, c, d})
+ if p == peer {
+ t.Error("Assert NEQ failed")
+ }
+ }
+
+ insert(a, 192, 168, 4, 0, 24)
+ insert(b, 192, 168, 4, 4, 32)
+ insert(c, 192, 168, 0, 0, 16)
+ insert(d, 192, 95, 5, 64, 27)
+ insert(c, 192, 95, 5, 65, 27)
+ insert(e, 0, 0, 0, 0, 0)
+ insert(g, 64, 15, 112, 0, 20)
+ insert(h, 64, 15, 123, 211, 25)
+ insert(a, 10, 0, 0, 0, 25)
+ insert(b, 10, 0, 0, 128, 25)
+ insert(a, 10, 1, 0, 0, 30)
+ insert(b, 10, 1, 0, 4, 30)
+ insert(c, 10, 1, 0, 8, 29)
+ insert(d, 10, 1, 0, 16, 29)
+
+ assertEQ(a, 192, 168, 4, 20)
+ assertEQ(a, 192, 168, 4, 0)
+ assertEQ(b, 192, 168, 4, 4)
+ assertEQ(c, 192, 168, 200, 182)
+ assertEQ(c, 192, 95, 5, 68)
+ assertEQ(e, 192, 95, 5, 96)
+ assertEQ(g, 64, 15, 116, 26)
+ assertEQ(g, 64, 15, 127, 3)
+
+ insert(a, 1, 0, 0, 0, 32)
+ insert(a, 64, 0, 0, 0, 32)
+ insert(a, 128, 0, 0, 0, 32)
+ insert(a, 192, 0, 0, 0, 32)
+ insert(a, 255, 0, 0, 0, 32)
+
+ assertEQ(a, 1, 0, 0, 0)
+ assertEQ(a, 64, 0, 0, 0)
+ assertEQ(a, 128, 0, 0, 0)
+ assertEQ(a, 192, 0, 0, 0)
+ assertEQ(a, 255, 0, 0, 0)
+
+ trie = trie.removeByPeer(a)
+
+ assertNEQ(a, 1, 0, 0, 0)
+ assertNEQ(a, 64, 0, 0, 0)
+ assertNEQ(a, 128, 0, 0, 0)
+ assertNEQ(a, 192, 0, 0, 0)
+ assertNEQ(a, 255, 0, 0, 0)
+
+ trie = nil
+
+ insert(a, 192, 168, 0, 0, 16)
+ insert(a, 192, 168, 0, 0, 24)
+
+ trie = trie.removeByPeer(a)
+
+ assertNEQ(a, 192, 168, 0, 1)
+}
+
+/* Test ported from kernel implementation:
+ * selftest/allowedips.h
+ */
+func TestTrieIPv6(t *testing.T) {
+ a := &Peer{}
+ b := &Peer{}
+ c := &Peer{}
+ d := &Peer{}
+ e := &Peer{}
+ f := &Peer{}
+ g := &Peer{}
+ h := &Peer{}
+
+ var trie *trieEntry
+
+ expand := func(a uint32) []byte {
+ var out [4]byte
+ out[0] = byte(a >> 24 & 0xff)
+ out[1] = byte(a >> 16 & 0xff)
+ out[2] = byte(a >> 8 & 0xff)
+ out[3] = byte(a & 0xff)
+ return out[:]
+ }
+
+ insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
+ var addr []byte
+ addr = append(addr, expand(a)...)
+ addr = append(addr, expand(b)...)
+ addr = append(addr, expand(c)...)
+ addr = append(addr, expand(d)...)
+ trie = trie.insert(addr, cidr, peer)
+ }
+
+ assertEQ := func(peer *Peer, a, b, c, d uint32) {
+ var addr []byte
+ addr = append(addr, expand(a)...)
+ addr = append(addr, expand(b)...)
+ addr = append(addr, expand(c)...)
+ addr = append(addr, expand(d)...)
+ p := trie.lookup(addr)
+ if p != peer {
+ t.Error("Assert EQ failed")
+ }
+ }
+
+ insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
+ insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
+ insert(e, 0, 0, 0, 0, 0)
+ insert(f, 0, 0, 0, 0, 0)
+ insert(g, 0x24046800, 0, 0, 0, 32)
+ insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64)
+ insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128)
+ insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
+ insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
+
+ assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543)
+ assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee)
+ assertEQ(f, 0x26075300, 0x60006b01, 0, 0)
+ assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006)
+ assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678)
+ assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678)
+ assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678)
+ assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678)
+ assertEQ(h, 0x24046800, 0x40040800, 0, 0)
+ assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
+ assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
+}
diff --git a/device/bind_test.go b/device/bind_test.go
new file mode 100644
index 0000000..0c2e2cf
--- /dev/null
+++ b/device/bind_test.go
@@ -0,0 +1,55 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import "errors"
+
+type DummyDatagram struct {
+ msg []byte
+ endpoint Endpoint
+ world bool // better type
+}
+
+type DummyBind struct {
+ in6 chan DummyDatagram
+ ou6 chan DummyDatagram
+ in4 chan DummyDatagram
+ ou4 chan DummyDatagram
+ closed bool
+}
+
+func (b *DummyBind) SetMark(v uint32) error {
+ return nil
+}
+
+func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ datagram, ok := <-b.in6
+ if !ok {
+ return 0, nil, errors.New("closed")
+ }
+ copy(buff, datagram.msg)
+ return len(datagram.msg), datagram.endpoint, nil
+}
+
+func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ datagram, ok := <-b.in4
+ if !ok {
+ return 0, nil, errors.New("closed")
+ }
+ copy(buff, datagram.msg)
+ return len(datagram.msg), datagram.endpoint, nil
+}
+
+func (b *DummyBind) Close() error {
+ close(b.in6)
+ close(b.in4)
+ b.closed = true
+ return nil
+}
+
+func (b *DummyBind) Send(buff []byte, end Endpoint) error {
+ return nil
+}
diff --git a/device/conn.go b/device/conn.go
new file mode 100644
index 0000000..2594680
--- /dev/null
+++ b/device/conn.go
@@ -0,0 +1,180 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "errors"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "net"
+)
+
+const (
+ ConnRoutineNumber = 2
+)
+
+/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
+ */
+type Bind interface {
+ SetMark(value uint32) error
+ ReceiveIPv6(buff []byte) (int, Endpoint, error)
+ ReceiveIPv4(buff []byte) (int, Endpoint, error)
+ Send(buff []byte, end Endpoint) error
+ Close() error
+}
+
+/* An Endpoint maintains the source/destination caching for a peer
+ *
+ * dst : the remote address of a peer ("endpoint" in uapi terminology)
+ * src : the local address from which datagrams originate going to the peer
+ */
+type Endpoint interface {
+ ClearSrc() // clears the source address
+ SrcToString() string // returns the local source address (ip:port)
+ DstToString() string // returns the destination address (ip:port)
+ DstToBytes() []byte // used for mac2 cookie calculations
+ DstIP() net.IP
+ SrcIP() net.IP
+}
+
+func parseEndpoint(s string) (*net.UDPAddr, error) {
+
+ // ensure that the host is an IP address
+
+ host, _, err := net.SplitHostPort(s)
+ if err != nil {
+ return nil, err
+ }
+ if ip := net.ParseIP(host); ip == nil {
+ return nil, errors.New("Failed to parse IP address: " + host)
+ }
+
+ // parse address and port
+
+ addr, err := net.ResolveUDPAddr("udp", s)
+ if err != nil {
+ return nil, err
+ }
+ ip4 := addr.IP.To4()
+ if ip4 != nil {
+ addr.IP = ip4
+ }
+ return addr, err
+}
+
+func unsafeCloseBind(device *Device) error {
+ var err error
+ netc := &device.net
+ if netc.bind != nil {
+ err = netc.bind.Close()
+ netc.bind = nil
+ }
+ netc.stopping.Wait()
+ return err
+}
+
+func (device *Device) BindSetMark(mark uint32) error {
+
+ device.net.Lock()
+ defer device.net.Unlock()
+
+ // check if modified
+
+ if device.net.fwmark == mark {
+ return nil
+ }
+
+ // update fwmark on existing bind
+
+ device.net.fwmark = mark
+ if device.isUp.Get() && device.net.bind != nil {
+ if err := device.net.bind.SetMark(mark); err != nil {
+ return err
+ }
+ }
+
+ // clear cached source addresses
+
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Lock()
+ defer peer.Unlock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ }
+ device.peers.RUnlock()
+
+ return nil
+}
+
+func (device *Device) BindUpdate() error {
+
+ device.net.Lock()
+ defer device.net.Unlock()
+
+ // close existing sockets
+
+ if err := unsafeCloseBind(device); err != nil {
+ return err
+ }
+
+ // open new sockets
+
+ if device.isUp.Get() {
+
+ // bind to new port
+
+ var err error
+ netc := &device.net
+ netc.bind, netc.port, err = CreateBind(netc.port, device)
+ if err != nil {
+ netc.bind = nil
+ netc.port = 0
+ return err
+ }
+
+ // set fwmark
+
+ if netc.fwmark != 0 {
+ err = netc.bind.SetMark(netc.fwmark)
+ if err != nil {
+ return err
+ }
+ }
+
+ // clear cached source addresses
+
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Lock()
+ defer peer.Unlock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ }
+ device.peers.RUnlock()
+
+ // start receiving routines
+
+ device.net.starting.Add(ConnRoutineNumber)
+ device.net.stopping.Add(ConnRoutineNumber)
+ go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
+ go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
+ device.net.starting.Wait()
+
+ device.log.Debug.Println("UDP bind has been updated")
+ }
+
+ return nil
+}
+
+func (device *Device) BindClose() error {
+ device.net.Lock()
+ err := unsafeCloseBind(device)
+ device.net.Unlock()
+ return err
+} \ No newline at end of file
diff --git a/device/conn_default.go b/device/conn_default.go
new file mode 100644
index 0000000..8a86719
--- /dev/null
+++ b/device/conn_default.go
@@ -0,0 +1,170 @@
+// +build !linux android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "net"
+ "os"
+ "syscall"
+)
+
+/* This code is meant to be a temporary solution
+ * on platforms for which the sticky socket / source caching behavior
+ * has not yet been implemented.
+ *
+ * See conn_linux.go for an implementation on the linux platform.
+ */
+
+type NativeBind struct {
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+}
+
+type NativeEndpoint net.UDPAddr
+
+var _ Bind = (*NativeBind)(nil)
+var _ Endpoint = (*NativeEndpoint)(nil)
+
+func CreateEndpoint(s string) (Endpoint, error) {
+ addr, err := parseEndpoint(s)
+ return (*NativeEndpoint)(addr), err
+}
+
+func (_ *NativeEndpoint) ClearSrc() {}
+
+func (e *NativeEndpoint) DstIP() net.IP {
+ return (*net.UDPAddr)(e).IP
+}
+
+func (e *NativeEndpoint) SrcIP() net.IP {
+ return nil // not supported
+}
+
+func (e *NativeEndpoint) DstToBytes() []byte {
+ addr := (*net.UDPAddr)(e)
+ out := addr.IP.To4()
+ if out == nil {
+ out = addr.IP
+ }
+ out = append(out, byte(addr.Port&0xff))
+ out = append(out, byte((addr.Port>>8)&0xff))
+ return out
+}
+
+func (e *NativeEndpoint) DstToString() string {
+ return (*net.UDPAddr)(e).String()
+}
+
+func (e *NativeEndpoint) SrcToString() string {
+ return ""
+}
+
+func listenNet(network string, port int) (*net.UDPConn, int, error) {
+
+ // listen
+
+ conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // retrieve port
+
+ laddr := conn.LocalAddr()
+ uaddr, err := net.ResolveUDPAddr(
+ laddr.Network(),
+ laddr.String(),
+ )
+ if err != nil {
+ return nil, 0, err
+ }
+ return conn, uaddr.Port, nil
+}
+
+func extractErrno(err error) error {
+ opErr, ok := err.(*net.OpError)
+ if !ok {
+ return nil
+ }
+ syscallErr, ok := opErr.Err.(*os.SyscallError)
+ if !ok {
+ return nil
+ }
+ return syscallErr.Err
+}
+
+func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
+ var err error
+ var bind NativeBind
+
+ port := int(uport)
+
+ bind.ipv4, port, err = listenNet("udp4", port)
+ if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
+ return nil, 0, err
+ }
+
+ bind.ipv6, port, err = listenNet("udp6", port)
+ if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
+ bind.ipv4.Close()
+ bind.ipv4 = nil
+ return nil, 0, err
+ }
+
+ return &bind, uint16(port), nil
+}
+
+func (bind *NativeBind) Close() error {
+ var err1, err2 error
+ if bind.ipv4 != nil {
+ err1 = bind.ipv4.Close()
+ }
+ if bind.ipv6 != nil {
+ err2 = bind.ipv6.Close()
+ }
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ if bind.ipv4 == nil {
+ return 0, nil, syscall.EAFNOSUPPORT
+ }
+ n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
+ if endpoint != nil {
+ endpoint.IP = endpoint.IP.To4()
+ }
+ return n, (*NativeEndpoint)(endpoint), err
+}
+
+func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ if bind.ipv6 == nil {
+ return 0, nil, syscall.EAFNOSUPPORT
+ }
+ n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
+ return n, (*NativeEndpoint)(endpoint), err
+}
+
+func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
+ var err error
+ nend := endpoint.(*NativeEndpoint)
+ if nend.IP.To4() != nil {
+ if bind.ipv4 == nil {
+ return syscall.EAFNOSUPPORT
+ }
+ _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
+ } else {
+ if bind.ipv6 == nil {
+ return syscall.EAFNOSUPPORT
+ }
+ _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
+ }
+ return err
+}
diff --git a/device/conn_linux.go b/device/conn_linux.go
new file mode 100644
index 0000000..49949d5
--- /dev/null
+++ b/device/conn_linux.go
@@ -0,0 +1,746 @@
+// +build !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ *
+ * This implements userspace semantics of "sticky sockets", modeled after
+ * WireGuard's kernelspace implementation. This is more or less a straight port
+ * of the sticky-sockets.c example code:
+ * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
+ *
+ * Currently there is no way to achieve this within the net package:
+ * See e.g. https://github.com/golang/go/issues/17930
+ * So this code is remains platform dependent.
+ */
+
+package device
+
+import (
+ "errors"
+ "golang.org/x/sys/unix"
+ "golang.zx2c4.com/wireguard/rwcancel"
+ "net"
+ "strconv"
+ "sync"
+ "syscall"
+ "unsafe"
+)
+
+const (
+ FD_ERR = -1
+)
+
+type IPv4Source struct {
+ src [4]byte
+ ifindex int32
+}
+
+type IPv6Source struct {
+ src [16]byte
+ //ifindex belongs in dst.ZoneId
+}
+
+type NativeEndpoint struct {
+ dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
+ src [unsafe.Sizeof(IPv6Source{})]byte
+ isV6 bool
+}
+
+func (endpoint *NativeEndpoint) src4() *IPv4Source {
+ return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
+}
+
+func (endpoint *NativeEndpoint) src6() *IPv6Source {
+ return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
+}
+
+func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
+ return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
+}
+
+func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
+ return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
+}
+
+type NativeBind struct {
+ sock4 int
+ sock6 int
+ netlinkSock int
+ netlinkCancel *rwcancel.RWCancel
+ lastMark uint32
+}
+
+var _ Endpoint = (*NativeEndpoint)(nil)
+var _ Bind = (*NativeBind)(nil)
+
+func CreateEndpoint(s string) (Endpoint, error) {
+ var end NativeEndpoint
+ addr, err := parseEndpoint(s)
+ if err != nil {
+ return nil, err
+ }
+
+ ipv4 := addr.IP.To4()
+ if ipv4 != nil {
+ dst := end.dst4()
+ end.isV6 = false
+ dst.Port = addr.Port
+ copy(dst.Addr[:], ipv4)
+ end.ClearSrc()
+ return &end, nil
+ }
+
+ ipv6 := addr.IP.To16()
+ if ipv6 != nil {
+ zone, err := zoneToUint32(addr.Zone)
+ if err != nil {
+ return nil, err
+ }
+ dst := end.dst6()
+ end.isV6 = true
+ dst.Port = addr.Port
+ dst.ZoneId = zone
+ copy(dst.Addr[:], ipv6[:])
+ end.ClearSrc()
+ return &end, nil
+ }
+
+ return nil, errors.New("Invalid IP address")
+}
+
+func createNetlinkRouteSocket() (int, error) {
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ if err != nil {
+ return -1, err
+ }
+ saddr := &unix.SockaddrNetlink{
+ Family: unix.AF_NETLINK,
+ Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
+ }
+ err = unix.Bind(sock, saddr)
+ if err != nil {
+ unix.Close(sock)
+ return -1, err
+ }
+ return sock, nil
+
+}
+
+func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
+ var err error
+ var bind NativeBind
+ var newPort uint16
+
+ bind.netlinkSock, err = createNetlinkRouteSocket()
+ if err != nil {
+ return nil, 0, err
+ }
+ bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
+ if err != nil {
+ unix.Close(bind.netlinkSock)
+ return nil, 0, err
+ }
+
+ go bind.routineRouteListener(device)
+
+ // attempt ipv6 bind, update port if succesful
+
+ bind.sock6, newPort, err = create6(port)
+ if err != nil {
+ if err != syscall.EAFNOSUPPORT {
+ bind.netlinkCancel.Cancel()
+ return nil, 0, err
+ }
+ } else {
+ port = newPort
+ }
+
+ // attempt ipv4 bind, update port if succesful
+
+ bind.sock4, newPort, err = create4(port)
+ if err != nil {
+ if err != syscall.EAFNOSUPPORT {
+ bind.netlinkCancel.Cancel()
+ unix.Close(bind.sock6)
+ return nil, 0, err
+ }
+ } else {
+ port = newPort
+ }
+
+ if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
+ return nil, 0, errors.New("ipv4 and ipv6 not supported")
+ }
+
+ return &bind, port, nil
+}
+
+func (bind *NativeBind) SetMark(value uint32) error {
+ if bind.sock6 != -1 {
+ err := unix.SetsockoptInt(
+ bind.sock6,
+ unix.SOL_SOCKET,
+ unix.SO_MARK,
+ int(value),
+ )
+
+ if err != nil {
+ return err
+ }
+ }
+
+ if bind.sock4 != -1 {
+ err := unix.SetsockoptInt(
+ bind.sock4,
+ unix.SOL_SOCKET,
+ unix.SO_MARK,
+ int(value),
+ )
+
+ if err != nil {
+ return err
+ }
+ }
+
+ bind.lastMark = value
+ return nil
+}
+
+func closeUnblock(fd int) error {
+ // shutdown to unblock readers and writers
+ unix.Shutdown(fd, unix.SHUT_RDWR)
+ return unix.Close(fd)
+}
+
+func (bind *NativeBind) Close() error {
+ var err1, err2, err3 error
+ if bind.sock6 != -1 {
+ err1 = closeUnblock(bind.sock6)
+ }
+ if bind.sock4 != -1 {
+ err2 = closeUnblock(bind.sock4)
+ }
+ err3 = bind.netlinkCancel.Cancel()
+
+ if err1 != nil {
+ return err1
+ }
+ if err2 != nil {
+ return err2
+ }
+ return err3
+}
+
+func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ var end NativeEndpoint
+ if bind.sock6 == -1 {
+ return 0, nil, syscall.EAFNOSUPPORT
+ }
+ n, err := receive6(
+ bind.sock6,
+ buff,
+ &end,
+ )
+ return n, &end, err
+}
+
+func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ var end NativeEndpoint
+ if bind.sock4 == -1 {
+ return 0, nil, syscall.EAFNOSUPPORT
+ }
+ n, err := receive4(
+ bind.sock4,
+ buff,
+ &end,
+ )
+ return n, &end, err
+}
+
+func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
+ nend := end.(*NativeEndpoint)
+ if !nend.isV6 {
+ if bind.sock4 == -1 {
+ return syscall.EAFNOSUPPORT
+ }
+ return send4(bind.sock4, nend, buff)
+ } else {
+ if bind.sock6 == -1 {
+ return syscall.EAFNOSUPPORT
+ }
+ return send6(bind.sock6, nend, buff)
+ }
+}
+
+func (end *NativeEndpoint) SrcIP() net.IP {
+ if !end.isV6 {
+ return net.IPv4(
+ end.src4().src[0],
+ end.src4().src[1],
+ end.src4().src[2],
+ end.src4().src[3],
+ )
+ } else {
+ return end.src6().src[:]
+ }
+}
+
+func (end *NativeEndpoint) DstIP() net.IP {
+ if !end.isV6 {
+ return net.IPv4(
+ end.dst4().Addr[0],
+ end.dst4().Addr[1],
+ end.dst4().Addr[2],
+ end.dst4().Addr[3],
+ )
+ } else {
+ return end.dst6().Addr[:]
+ }
+}
+
+func (end *NativeEndpoint) DstToBytes() []byte {
+ if !end.isV6 {
+ return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
+ } else {
+ return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
+ }
+}
+
+func (end *NativeEndpoint) SrcToString() string {
+ return end.SrcIP().String()
+}
+
+func (end *NativeEndpoint) DstToString() string {
+ var udpAddr net.UDPAddr
+ udpAddr.IP = end.DstIP()
+ if !end.isV6 {
+ udpAddr.Port = end.dst4().Port
+ } else {
+ udpAddr.Port = end.dst6().Port
+ }
+ return udpAddr.String()
+}
+
+func (end *NativeEndpoint) ClearDst() {
+ for i := range end.dst {
+ end.dst[i] = 0
+ }
+}
+
+func (end *NativeEndpoint) ClearSrc() {
+ for i := range end.src {
+ end.src[i] = 0
+ }
+}
+
+func zoneToUint32(zone string) (uint32, error) {
+ if zone == "" {
+ return 0, nil
+ }
+ if intr, err := net.InterfaceByName(zone); err == nil {
+ return uint32(intr.Index), nil
+ }
+ n, err := strconv.ParseUint(zone, 10, 32)
+ return uint32(n), err
+}
+
+func create4(port uint16) (int, uint16, error) {
+
+ // create socket
+
+ fd, err := unix.Socket(
+ unix.AF_INET,
+ unix.SOCK_DGRAM,
+ 0,
+ )
+
+ if err != nil {
+ return FD_ERR, 0, err
+ }
+
+ addr := unix.SockaddrInet4{
+ Port: int(port),
+ }
+
+ // set sockopts and bind
+
+ if err := func() error {
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.SOL_SOCKET,
+ unix.SO_REUSEADDR,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IP,
+ unix.IP_PKTINFO,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ return unix.Bind(fd, &addr)
+ }(); err != nil {
+ unix.Close(fd)
+ return FD_ERR, 0, err
+ }
+
+ return fd, uint16(addr.Port), err
+}
+
+func create6(port uint16) (int, uint16, error) {
+
+ // create socket
+
+ fd, err := unix.Socket(
+ unix.AF_INET6,
+ unix.SOCK_DGRAM,
+ 0,
+ )
+
+ if err != nil {
+ return FD_ERR, 0, err
+ }
+
+ // set sockopts and bind
+
+ addr := unix.SockaddrInet6{
+ Port: int(port),
+ }
+
+ if err := func() error {
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.SOL_SOCKET,
+ unix.SO_REUSEADDR,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_RECVPKTINFO,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_V6ONLY,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ return unix.Bind(fd, &addr)
+
+ }(); err != nil {
+ unix.Close(fd)
+ return FD_ERR, 0, err
+ }
+
+ return fd, uint16(addr.Port), err
+}
+
+func send4(sock int, end *NativeEndpoint, buff []byte) error {
+
+ // construct message header
+
+ cmsg := struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet4Pktinfo
+ }{
+ unix.Cmsghdr{
+ Level: unix.IPPROTO_IP,
+ Type: unix.IP_PKTINFO,
+ Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
+ },
+ unix.Inet4Pktinfo{
+ Spec_dst: end.src4().src,
+ Ifindex: end.src4().ifindex,
+ },
+ }
+
+ _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
+
+ if err == nil {
+ return nil
+ }
+
+ // clear src and retry
+
+ if err == unix.EINVAL {
+ end.ClearSrc()
+ cmsg.pktinfo = unix.Inet4Pktinfo{}
+ _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
+ }
+
+ return err
+}
+
+func send6(sock int, end *NativeEndpoint, buff []byte) error {
+
+ // construct message header
+
+ cmsg := struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet6Pktinfo
+ }{
+ unix.Cmsghdr{
+ Level: unix.IPPROTO_IPV6,
+ Type: unix.IPV6_PKTINFO,
+ Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
+ },
+ unix.Inet6Pktinfo{
+ Addr: end.src6().src,
+ Ifindex: end.dst6().ZoneId,
+ },
+ }
+
+ if cmsg.pktinfo.Addr == [16]byte{} {
+ cmsg.pktinfo.Ifindex = 0
+ }
+
+ _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
+
+ if err == nil {
+ return nil
+ }
+
+ // clear src and retry
+
+ if err == unix.EINVAL {
+ end.ClearSrc()
+ cmsg.pktinfo = unix.Inet6Pktinfo{}
+ _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
+ }
+
+ return err
+}
+
+func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
+
+ // contruct message header
+
+ var cmsg struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet4Pktinfo
+ }
+
+ size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
+
+ if err != nil {
+ return 0, err
+ }
+ end.isV6 = false
+
+ if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
+ *end.dst4() = *newDst4
+ }
+
+ // update source cache
+
+ if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
+ cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
+ cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
+ end.src4().src = cmsg.pktinfo.Spec_dst
+ end.src4().ifindex = cmsg.pktinfo.Ifindex
+ }
+
+ return size, nil
+}
+
+func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
+
+ // contruct message header
+
+ var cmsg struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet6Pktinfo
+ }
+
+ size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
+
+ if err != nil {
+ return 0, err
+ }
+ end.isV6 = true
+
+ if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
+ *end.dst6() = *newDst6
+ }
+
+ // update source cache
+
+ if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
+ cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
+ cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
+ end.src6().src = cmsg.pktinfo.Addr
+ end.dst6().ZoneId = cmsg.pktinfo.Ifindex
+ }
+
+ return size, nil
+}
+
+func (bind *NativeBind) routineRouteListener(device *Device) {
+ type peerEndpointPtr struct {
+ peer *Peer
+ endpoint *Endpoint
+ }
+ var reqPeer map[uint32]peerEndpointPtr
+ var reqPeerLock sync.Mutex
+
+ defer unix.Close(bind.netlinkSock)
+
+ for msg := make([]byte, 1<<16); ; {
+ var err error
+ var msgn int
+ for {
+ msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
+ if err == nil || !rwcancel.RetryAfterError(err) {
+ break
+ }
+ if !bind.netlinkCancel.ReadyRead() {
+ return
+ }
+ }
+ if err != nil {
+ return
+ }
+
+ for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
+
+ hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
+
+ if uint(hdr.Len) > uint(len(remain)) {
+ break
+ }
+
+ switch hdr.Type {
+ case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
+ if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
+ if uint(len(remain)) < uint(hdr.Len) {
+ break
+ }
+ if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
+ attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
+ for {
+ if uint(len(attr)) < uint(unix.SizeofRtAttr) {
+ break
+ }
+ attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
+ if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
+ break
+ }
+ if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
+ ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
+ reqPeerLock.Lock()
+ if reqPeer == nil {
+ reqPeerLock.Unlock()
+ break
+ }
+ pePtr, ok := reqPeer[hdr.Seq]
+ reqPeerLock.Unlock()
+ if !ok {
+ break
+ }
+ pePtr.peer.Lock()
+ if &pePtr.peer.endpoint != pePtr.endpoint {
+ pePtr.peer.Unlock()
+ break
+ }
+ if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
+ pePtr.peer.Unlock()
+ break
+ }
+ pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
+ pePtr.peer.Unlock()
+ }
+ attr = attr[attrhdr.Len:]
+ }
+ }
+ break
+ }
+ reqPeerLock.Lock()
+ reqPeer = make(map[uint32]peerEndpointPtr)
+ reqPeerLock.Unlock()
+ go func() {
+ device.peers.RLock()
+ i := uint32(1)
+ for _, peer := range device.peers.keyMap {
+ peer.RLock()
+ if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
+ peer.RUnlock()
+ continue
+ }
+ if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
+ peer.RUnlock()
+ break
+ }
+ nlmsg := struct {
+ hdr unix.NlMsghdr
+ msg unix.RtMsg
+ dsthdr unix.RtAttr
+ dst [4]byte
+ srchdr unix.RtAttr
+ src [4]byte
+ markhdr unix.RtAttr
+ mark uint32
+ }{
+ unix.NlMsghdr{
+ Type: uint16(unix.RTM_GETROUTE),
+ Flags: unix.NLM_F_REQUEST,
+ Seq: i,
+ },
+ unix.RtMsg{
+ Family: unix.AF_INET,
+ Dst_len: 32,
+ Src_len: 32,
+ },
+ unix.RtAttr{
+ Len: 8,
+ Type: unix.RTA_DST,
+ },
+ peer.endpoint.(*NativeEndpoint).dst4().Addr,
+ unix.RtAttr{
+ Len: 8,
+ Type: unix.RTA_SRC,
+ },
+ peer.endpoint.(*NativeEndpoint).src4().src,
+ unix.RtAttr{
+ Len: 8,
+ Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix
+ },
+ uint32(bind.lastMark),
+ }
+ nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
+ reqPeerLock.Lock()
+ reqPeer[i] = peerEndpointPtr{
+ peer: peer,
+ endpoint: &peer.endpoint,
+ }
+ reqPeerLock.Unlock()
+ peer.RUnlock()
+ i++
+ _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
+ if err != nil {
+ break
+ }
+ }
+ device.peers.RUnlock()
+ }()
+ }
+ remain = remain[hdr.Len:]
+ }
+ }
+}
diff --git a/device/constants.go b/device/constants.go
new file mode 100644
index 0000000..27d910f
--- /dev/null
+++ b/device/constants.go
@@ -0,0 +1,41 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "time"
+)
+
+/* Specification constants */
+
+const (
+ RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
+ RejectAfterMessages = (1 << 64) - (1 << 4) - 1
+ RekeyAfterTime = time.Second * 120
+ RekeyAttemptTime = time.Second * 90
+ RekeyTimeout = time.Second * 5
+ MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */
+ RekeyTimeoutJitterMaxMs = 334
+ RejectAfterTime = time.Second * 180
+ KeepaliveTimeout = time.Second * 10
+ CookieRefreshTime = time.Second * 120
+ HandshakeInitationRate = time.Second / 20
+ PaddingMultiple = 16
+)
+
+const (
+ MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
+ MaxMessageSize = MaxSegmentSize // maximum size of transport message
+ MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content
+)
+
+/* Implementation constants */
+
+const (
+ UnderLoadQueueSize = QueueHandshakeSize / 8
+ UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
+ MaxPeers = 1 << 16 // maximum number of configured peers
+)
diff --git a/device/cookie.go b/device/cookie.go
new file mode 100644
index 0000000..2f21067
--- /dev/null
+++ b/device/cookie.go
@@ -0,0 +1,250 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "crypto/hmac"
+ "crypto/rand"
+ "golang.org/x/crypto/blake2s"
+ "golang.org/x/crypto/chacha20poly1305"
+ "sync"
+ "time"
+)
+
+type CookieChecker struct {
+ sync.RWMutex
+ mac1 struct {
+ key [blake2s.Size]byte
+ }
+ mac2 struct {
+ secret [blake2s.Size]byte
+ secretSet time.Time
+ encryptionKey [chacha20poly1305.KeySize]byte
+ }
+}
+
+type CookieGenerator struct {
+ sync.RWMutex
+ mac1 struct {
+ key [blake2s.Size]byte
+ }
+ mac2 struct {
+ cookie [blake2s.Size128]byte
+ cookieSet time.Time
+ hasLastMAC1 bool
+ lastMAC1 [blake2s.Size128]byte
+ encryptionKey [chacha20poly1305.KeySize]byte
+ }
+}
+
+func (st *CookieChecker) Init(pk NoisePublicKey) {
+ st.Lock()
+ defer st.Unlock()
+
+ // mac1 state
+
+ func() {
+ hash, _ := blake2s.New256(nil)
+ hash.Write([]byte(WGLabelMAC1))
+ hash.Write(pk[:])
+ hash.Sum(st.mac1.key[:0])
+ }()
+
+ // mac2 state
+
+ func() {
+ hash, _ := blake2s.New256(nil)
+ hash.Write([]byte(WGLabelCookie))
+ hash.Write(pk[:])
+ hash.Sum(st.mac2.encryptionKey[:0])
+ }()
+
+ st.mac2.secretSet = time.Time{}
+}
+
+func (st *CookieChecker) CheckMAC1(msg []byte) bool {
+ st.RLock()
+ defer st.RUnlock()
+
+ size := len(msg)
+ smac2 := size - blake2s.Size128
+ smac1 := smac2 - blake2s.Size128
+
+ var mac1 [blake2s.Size128]byte
+
+ mac, _ := blake2s.New128(st.mac1.key[:])
+ mac.Write(msg[:smac1])
+ mac.Sum(mac1[:0])
+
+ return hmac.Equal(mac1[:], msg[smac1:smac2])
+}
+
+func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
+ st.RLock()
+ defer st.RUnlock()
+
+ if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime {
+ return false
+ }
+
+ // derive cookie key
+
+ var cookie [blake2s.Size128]byte
+ func() {
+ mac, _ := blake2s.New128(st.mac2.secret[:])
+ mac.Write(src)
+ mac.Sum(cookie[:0])
+ }()
+
+ // calculate mac of packet (including mac1)
+
+ smac2 := len(msg) - blake2s.Size128
+
+ var mac2 [blake2s.Size128]byte
+ func() {
+ mac, _ := blake2s.New128(cookie[:])
+ mac.Write(msg[:smac2])
+ mac.Sum(mac2[:0])
+ }()
+
+ return hmac.Equal(mac2[:], msg[smac2:])
+}
+
+func (st *CookieChecker) CreateReply(
+ msg []byte,
+ recv uint32,
+ src []byte,
+) (*MessageCookieReply, error) {
+
+ st.RLock()
+
+ // refresh cookie secret
+
+ if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime {
+ st.RUnlock()
+ st.Lock()
+ _, err := rand.Read(st.mac2.secret[:])
+ if err != nil {
+ st.Unlock()
+ return nil, err
+ }
+ st.mac2.secretSet = time.Now()
+ st.Unlock()
+ st.RLock()
+ }
+
+ // derive cookie
+
+ var cookie [blake2s.Size128]byte
+ func() {
+ mac, _ := blake2s.New128(st.mac2.secret[:])
+ mac.Write(src)
+ mac.Sum(cookie[:0])
+ }()
+
+ // encrypt cookie
+
+ size := len(msg)
+
+ smac2 := size - blake2s.Size128
+ smac1 := smac2 - blake2s.Size128
+
+ reply := new(MessageCookieReply)
+ reply.Type = MessageCookieReplyType
+ reply.Receiver = recv
+
+ _, err := rand.Read(reply.Nonce[:])
+ if err != nil {
+ st.RUnlock()
+ return nil, err
+ }
+
+ xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
+ xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2])
+
+ st.RUnlock()
+
+ return reply, nil
+}
+
+func (st *CookieGenerator) Init(pk NoisePublicKey) {
+ st.Lock()
+ defer st.Unlock()
+
+ func() {
+ hash, _ := blake2s.New256(nil)
+ hash.Write([]byte(WGLabelMAC1))
+ hash.Write(pk[:])
+ hash.Sum(st.mac1.key[:0])
+ }()
+
+ func() {
+ hash, _ := blake2s.New256(nil)
+ hash.Write([]byte(WGLabelCookie))
+ hash.Write(pk[:])
+ hash.Sum(st.mac2.encryptionKey[:0])
+ }()
+
+ st.mac2.cookieSet = time.Time{}
+}
+
+func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
+ st.Lock()
+ defer st.Unlock()
+
+ if !st.mac2.hasLastMAC1 {
+ return false
+ }
+
+ var cookie [blake2s.Size128]byte
+
+ xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
+ _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
+
+ if err != nil {
+ return false
+ }
+
+ st.mac2.cookieSet = time.Now()
+ st.mac2.cookie = cookie
+ return true
+}
+
+func (st *CookieGenerator) AddMacs(msg []byte) {
+
+ size := len(msg)
+
+ smac2 := size - blake2s.Size128
+ smac1 := smac2 - blake2s.Size128
+
+ mac1 := msg[smac1:smac2]
+ mac2 := msg[smac2:]
+
+ st.Lock()
+ defer st.Unlock()
+
+ // set mac1
+
+ func() {
+ mac, _ := blake2s.New128(st.mac1.key[:])
+ mac.Write(msg[:smac1])
+ mac.Sum(mac1[:0])
+ }()
+ copy(st.mac2.lastMAC1[:], mac1)
+ st.mac2.hasLastMAC1 = true
+
+ // set mac2
+
+ if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime {
+ return
+ }
+
+ func() {
+ mac, _ := blake2s.New128(st.mac2.cookie[:])
+ mac.Write(msg[:smac2])
+ mac.Sum(mac2[:0])
+ }()
+}
diff --git a/device/cookie_test.go b/device/cookie_test.go
new file mode 100644
index 0000000..79a6a86
--- /dev/null
+++ b/device/cookie_test.go
@@ -0,0 +1,191 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "testing"
+)
+
+func TestCookieMAC1(t *testing.T) {
+
+ // setup generator / checker
+
+ var (
+ generator CookieGenerator
+ checker CookieChecker
+ )
+
+ sk, err := newPrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+ pk := sk.publicKey()
+
+ generator.Init(pk)
+ checker.Init(pk)
+
+ // check mac1
+
+ src := []byte{192, 168, 13, 37, 10, 10, 10}
+
+ checkMAC1 := func(msg []byte) {
+ generator.AddMacs(msg)
+ if !checker.CheckMAC1(msg) {
+ t.Fatal("MAC1 generation/verification failed")
+ }
+ if checker.CheckMAC2(msg, src) {
+ t.Fatal("MAC2 generation/verification failed")
+ }
+ }
+
+ checkMAC1([]byte{
+ 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd,
+ 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62,
+ 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64,
+ 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91,
+ 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4,
+ 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c,
+ 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b,
+ 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5,
+ 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8,
+ 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8,
+ 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e,
+ 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82,
+ 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53,
+ 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd,
+ 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad,
+ 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f,
+ })
+
+ checkMAC1([]byte{
+ 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c,
+ 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56,
+ 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a,
+ 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c,
+ 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b,
+ 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc,
+ 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b,
+ 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05,
+ })
+
+ checkMAC1([]byte{
+ 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b,
+ 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc,
+ 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b,
+ 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05,
+ })
+
+ // exchange cookie reply
+
+ func() {
+ msg := []byte{
+ 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf,
+ 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8,
+ 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5,
+ 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f,
+ 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2,
+ 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b,
+ 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e,
+ 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9,
+ 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22,
+ 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27,
+ 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f,
+ 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2,
+ 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b,
+ 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb,
+ 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61,
+ 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
+ }
+ generator.AddMacs(msg)
+ reply, err := checker.CreateReply(msg, 1377, src)
+ if err != nil {
+ t.Fatal("Failed to create cookie reply:", err)
+ }
+ if !generator.ConsumeReply(reply) {
+ t.Fatal("Failed to consume cookie reply")
+ }
+ }()
+
+ // check mac2
+
+ checkMAC2 := func(msg []byte) {
+ generator.AddMacs(msg)
+
+ if !checker.CheckMAC1(msg) {
+ t.Fatal("MAC1 generation/verification failed")
+ }
+ if !checker.CheckMAC2(msg, src) {
+ t.Fatal("MAC2 generation/verification failed")
+ }
+
+ msg[5] ^= 0x20
+
+ if checker.CheckMAC1(msg) {
+ t.Fatal("MAC1 generation/verification failed")
+ }
+ if checker.CheckMAC2(msg, src) {
+ t.Fatal("MAC2 generation/verification failed")
+ }
+
+ msg[5] ^= 0x20
+
+ srcBad1 := []byte{192, 168, 13, 37, 40, 01}
+ if checker.CheckMAC2(msg, srcBad1) {
+ t.Fatal("MAC2 generation/verification failed")
+ }
+
+ srcBad2 := []byte{192, 168, 13, 38, 40, 01}
+ if checker.CheckMAC2(msg, srcBad2) {
+ t.Fatal("MAC2 generation/verification failed")
+ }
+ }
+
+ checkMAC2([]byte{
+ 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3,
+ 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15,
+ 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f,
+ 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f,
+ 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69,
+ 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb,
+ 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c,
+ 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38,
+ })
+
+ checkMAC2([]byte{
+ 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3,
+ 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85,
+ 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84,
+ 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5,
+ 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85,
+ 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55,
+ 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a,
+ 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca,
+ 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59,
+ 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca,
+ 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b,
+ 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7,
+ 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55,
+ 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2,
+ 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f,
+ 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d,
+ 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d,
+ 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f,
+ 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2,
+ 0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0,
+ 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35,
+ 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a,
+ 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee,
+ 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe,
+ 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e,
+ 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4,
+ 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06,
+ 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93,
+ 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa,
+ 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64,
+ 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b,
+ 0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa,
+ })
+}
diff --git a/device/device.go b/device/device.go
new file mode 100644
index 0000000..d6c96d6
--- /dev/null
+++ b/device/device.go
@@ -0,0 +1,396 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "golang.zx2c4.com/wireguard/ratelimiter"
+ "golang.zx2c4.com/wireguard/tun"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+const (
+ DeviceRoutineNumberPerCPU = 3
+ DeviceRoutineNumberAdditional = 2
+)
+
+type Device struct {
+ isUp AtomicBool // device is (going) up
+ isClosed AtomicBool // device is closed? (acting as guard)
+ log *Logger
+
+ // synchronized resources (locks acquired in order)
+
+ state struct {
+ starting sync.WaitGroup
+ stopping sync.WaitGroup
+ sync.Mutex
+ changing AtomicBool
+ current bool
+ }
+
+ net struct {
+ starting sync.WaitGroup
+ stopping sync.WaitGroup
+ sync.RWMutex
+ bind Bind // bind interface
+ port uint16 // listening port
+ fwmark uint32 // mark value (0 = disabled)
+ }
+
+ staticIdentity struct {
+ sync.RWMutex
+ privateKey NoisePrivateKey
+ publicKey NoisePublicKey
+ }
+
+ peers struct {
+ sync.RWMutex
+ keyMap map[NoisePublicKey]*Peer
+ }
+
+ // unprotected / "self-synchronising resources"
+
+ allowedips AllowedIPs
+ indexTable IndexTable
+ cookieChecker CookieChecker
+
+ rate struct {
+ underLoadUntil atomic.Value
+ limiter ratelimiter.Ratelimiter
+ }
+
+ pool struct {
+ messageBufferPool *sync.Pool
+ messageBufferReuseChan chan *[MaxMessageSize]byte
+ inboundElementPool *sync.Pool
+ inboundElementReuseChan chan *QueueInboundElement
+ outboundElementPool *sync.Pool
+ outboundElementReuseChan chan *QueueOutboundElement
+ }
+
+ queue struct {
+ encryption chan *QueueOutboundElement
+ decryption chan *QueueInboundElement
+ handshake chan QueueHandshakeElement
+ }
+
+ signals struct {
+ stop chan struct{}
+ }
+
+ tun struct {
+ device tun.TUNDevice
+ mtu int32
+ }
+}
+
+/* Converts the peer into a "zombie", which remains in the peer map,
+ * but processes no packets and does not exists in the routing table.
+ *
+ * Must hold device.peers.Mutex
+ */
+func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
+
+ // stop routing and processing of packets
+
+ device.allowedips.RemoveByPeer(peer)
+ peer.Stop()
+
+ // remove from peer map
+
+ delete(device.peers.keyMap, key)
+}
+
+func deviceUpdateState(device *Device) {
+
+ // check if state already being updated (guard)
+
+ if device.state.changing.Swap(true) {
+ return
+ }
+
+ // compare to current state of device
+
+ device.state.Lock()
+
+ newIsUp := device.isUp.Get()
+
+ if newIsUp == device.state.current {
+ device.state.changing.Set(false)
+ device.state.Unlock()
+ return
+ }
+
+ // change state of device
+
+ switch newIsUp {
+ case true:
+ if err := device.BindUpdate(); err != nil {
+ device.isUp.Set(false)
+ break
+ }
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Start()
+ if peer.persistentKeepaliveInterval > 0 {
+ peer.SendKeepalive()
+ }
+ }
+ device.peers.RUnlock()
+
+ case false:
+ device.BindClose()
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Stop()
+ }
+ device.peers.RUnlock()
+ }
+
+ // update state variables
+
+ device.state.current = newIsUp
+ device.state.changing.Set(false)
+ device.state.Unlock()
+
+ // check for state change in the mean time
+
+ deviceUpdateState(device)
+}
+
+func (device *Device) Up() {
+
+ // closed device cannot be brought up
+
+ if device.isClosed.Get() {
+ return
+ }
+
+ device.isUp.Set(true)
+ deviceUpdateState(device)
+}
+
+func (device *Device) Down() {
+ device.isUp.Set(false)
+ deviceUpdateState(device)
+}
+
+func (device *Device) IsUnderLoad() bool {
+
+ // check if currently under load
+
+ now := time.Now()
+ underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
+ if underLoad {
+ device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
+ return true
+ }
+
+ // check if recently under load
+
+ until := device.rate.underLoadUntil.Load().(time.Time)
+ return until.After(now)
+}
+
+func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
+
+ // lock required resources
+
+ device.staticIdentity.Lock()
+ defer device.staticIdentity.Unlock()
+
+ device.peers.Lock()
+ defer device.peers.Unlock()
+
+ for _, peer := range device.peers.keyMap {
+ peer.handshake.mutex.RLock()
+ defer peer.handshake.mutex.RUnlock()
+ }
+
+ // remove peers with matching public keys
+
+ publicKey := sk.publicKey()
+ for key, peer := range device.peers.keyMap {
+ if peer.handshake.remoteStatic.Equals(publicKey) {
+ unsafeRemovePeer(device, peer, key)
+ }
+ }
+
+ // update key material
+
+ device.staticIdentity.privateKey = sk
+ device.staticIdentity.publicKey = publicKey
+ device.cookieChecker.Init(publicKey)
+
+ // do static-static DH pre-computations
+
+ rmKey := device.staticIdentity.privateKey.IsZero()
+
+ for key, peer := range device.peers.keyMap {
+
+ handshake := &peer.handshake
+
+ if rmKey {
+ handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
+ } else {
+ handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ }
+
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ unsafeRemovePeer(device, peer, key)
+ }
+ }
+
+ return nil
+}
+
+func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device {
+ device := new(Device)
+
+ device.isUp.Set(false)
+ device.isClosed.Set(false)
+
+ device.log = logger
+
+ device.tun.device = tunDevice
+ mtu, err := device.tun.device.MTU()
+ if err != nil {
+ logger.Error.Println("Trouble determining MTU, assuming default:", err)
+ mtu = DefaultMTU
+ }
+ device.tun.mtu = int32(mtu)
+
+ device.peers.keyMap = make(map[NoisePublicKey]*Peer)
+
+ device.rate.limiter.Init()
+ device.rate.underLoadUntil.Store(time.Time{})
+
+ device.indexTable.Init()
+ device.allowedips.Reset()
+
+ device.PopulatePools()
+
+ // create queues
+
+ device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
+ device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
+ device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
+
+ // prepare signals
+
+ device.signals.stop = make(chan struct{})
+
+ // prepare net
+
+ device.net.port = 0
+ device.net.bind = nil
+
+ // start workers
+
+ cpus := runtime.NumCPU()
+ device.state.starting.Wait()
+ device.state.stopping.Wait()
+ device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
+ device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
+ for i := 0; i < cpus; i += 1 {
+ go device.RoutineEncryption()
+ go device.RoutineDecryption()
+ go device.RoutineHandshake()
+ }
+
+ go device.RoutineReadFromTUN()
+ go device.RoutineTUNEventReader()
+
+ device.state.starting.Wait()
+
+ return device
+}
+
+func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
+ device.peers.RLock()
+ defer device.peers.RUnlock()
+
+ return device.peers.keyMap[pk]
+}
+
+func (device *Device) RemovePeer(key NoisePublicKey) {
+ device.peers.Lock()
+ defer device.peers.Unlock()
+
+ // stop peer and remove from routing
+
+ peer, ok := device.peers.keyMap[key]
+ if ok {
+ unsafeRemovePeer(device, peer, key)
+ }
+}
+
+func (device *Device) RemoveAllPeers() {
+ device.peers.Lock()
+ defer device.peers.Unlock()
+
+ for key, peer := range device.peers.keyMap {
+ unsafeRemovePeer(device, peer, key)
+ }
+
+ device.peers.keyMap = make(map[NoisePublicKey]*Peer)
+}
+
+func (device *Device) FlushPacketQueues() {
+ for {
+ select {
+ case elem, ok := <-device.queue.decryption:
+ if ok {
+ elem.Drop()
+ }
+ case elem, ok := <-device.queue.encryption:
+ if ok {
+ elem.Drop()
+ }
+ case <-device.queue.handshake:
+ default:
+ return
+ }
+ }
+
+}
+
+func (device *Device) Close() {
+ if device.isClosed.Swap(true) {
+ return
+ }
+
+ device.state.starting.Wait()
+
+ device.log.Info.Println("Device closing")
+ device.state.changing.Set(true)
+ device.state.Lock()
+ defer device.state.Unlock()
+
+ device.tun.device.Close()
+ device.BindClose()
+
+ device.isUp.Set(false)
+
+ close(device.signals.stop)
+
+ device.RemoveAllPeers()
+
+ device.state.stopping.Wait()
+ device.FlushPacketQueues()
+
+ device.rate.limiter.Close()
+
+ device.state.changing.Set(false)
+ device.log.Info.Println("Interface closed")
+}
+
+func (device *Device) Wait() chan struct{} {
+ return device.signals.stop
+}
diff --git a/device/device_test.go b/device/device_test.go
new file mode 100644
index 0000000..db5a3c0
--- /dev/null
+++ b/device/device_test.go
@@ -0,0 +1,48 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+/* Create two device instances and simulate full WireGuard interaction
+ * without network dependencies
+ */
+
+import "testing"
+
+func TestDevice(t *testing.T) {
+
+ // prepare tun devices for generating traffic
+
+ tun1, err := CreateDummyTUN("tun1")
+ if err != nil {
+ t.Error("failed to create tun:", err.Error())
+ }
+
+ tun2, err := CreateDummyTUN("tun2")
+ if err != nil {
+ t.Error("failed to create tun:", err.Error())
+ }
+
+ _ = tun1
+ _ = tun2
+
+ // prepare endpoints
+
+ end1, err := CreateDummyEndpoint()
+ if err != nil {
+ t.Error("failed to create endpoint:", err.Error())
+ }
+
+ end2, err := CreateDummyEndpoint()
+ if err != nil {
+ t.Error("failed to create endpoint:", err.Error())
+ }
+
+ _ = end1
+ _ = end2
+
+ // create binds
+
+}
diff --git a/device/endpoint_test.go b/device/endpoint_test.go
new file mode 100644
index 0000000..1896790
--- /dev/null
+++ b/device/endpoint_test.go
@@ -0,0 +1,53 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "math/rand"
+ "net"
+)
+
+type DummyEndpoint struct {
+ src [16]byte
+ dst [16]byte
+}
+
+func CreateDummyEndpoint() (*DummyEndpoint, error) {
+ var end DummyEndpoint
+ if _, err := rand.Read(end.src[:]); err != nil {
+ return nil, err
+ }
+ _, err := rand.Read(end.dst[:])
+ return &end, err
+}
+
+func (e *DummyEndpoint) ClearSrc() {}
+
+func (e *DummyEndpoint) SrcToString() string {
+ var addr net.UDPAddr
+ addr.IP = e.SrcIP()
+ addr.Port = 1000
+ return addr.String()
+}
+
+func (e *DummyEndpoint) DstToString() string {
+ var addr net.UDPAddr
+ addr.IP = e.DstIP()
+ addr.Port = 1000
+ return addr.String()
+}
+
+func (e *DummyEndpoint) SrcToBytes() []byte {
+ return e.src[:]
+}
+
+func (e *DummyEndpoint) DstIP() net.IP {
+ return e.dst[:]
+}
+
+func (e *DummyEndpoint) SrcIP() net.IP {
+ return e.src[:]
+}
diff --git a/device/indextable.go b/device/indextable.go
new file mode 100644
index 0000000..4cba970
--- /dev/null
+++ b/device/indextable.go
@@ -0,0 +1,97 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "crypto/rand"
+ "sync"
+ "unsafe"
+)
+
+type IndexTableEntry struct {
+ peer *Peer
+ handshake *Handshake
+ keypair *Keypair
+}
+
+type IndexTable struct {
+ sync.RWMutex
+ table map[uint32]IndexTableEntry
+}
+
+func randUint32() (uint32, error) {
+ var integer [4]byte
+ _, err := rand.Read(integer[:])
+ return *(*uint32)(unsafe.Pointer(&integer[0])), err
+}
+
+func (table *IndexTable) Init() {
+ table.Lock()
+ defer table.Unlock()
+ table.table = make(map[uint32]IndexTableEntry)
+}
+
+func (table *IndexTable) Delete(index uint32) {
+ table.Lock()
+ defer table.Unlock()
+ delete(table.table, index)
+}
+
+func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
+ table.Lock()
+ defer table.Unlock()
+ entry, ok := table.table[index]
+ if !ok {
+ return
+ }
+ table.table[index] = IndexTableEntry{
+ peer: entry.peer,
+ keypair: keypair,
+ handshake: nil,
+ }
+}
+
+func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) {
+ for {
+ // generate random index
+
+ index, err := randUint32()
+ if err != nil {
+ return index, err
+ }
+
+ // check if index used
+
+ table.RLock()
+ _, ok := table.table[index]
+ table.RUnlock()
+ if ok {
+ continue
+ }
+
+ // check again while locked
+
+ table.Lock()
+ _, found := table.table[index]
+ if found {
+ table.Unlock()
+ continue
+ }
+ table.table[index] = IndexTableEntry{
+ peer: peer,
+ handshake: handshake,
+ keypair: nil,
+ }
+ table.Unlock()
+ return index, nil
+ }
+}
+
+func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
+ table.RLock()
+ defer table.RUnlock()
+ return table.table[id]
+}
diff --git a/device/ip.go b/device/ip.go
new file mode 100644
index 0000000..9d4fb74
--- /dev/null
+++ b/device/ip.go
@@ -0,0 +1,22 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "net"
+)
+
+const (
+ IPv4offsetTotalLength = 2
+ IPv4offsetSrc = 12
+ IPv4offsetDst = IPv4offsetSrc + net.IPv4len
+)
+
+const (
+ IPv6offsetPayloadLength = 4
+ IPv6offsetSrc = 8
+ IPv6offsetDst = IPv6offsetSrc + net.IPv6len
+)
diff --git a/device/kdf_test.go b/device/kdf_test.go
new file mode 100644
index 0000000..11ea8d5
--- /dev/null
+++ b/device/kdf_test.go
@@ -0,0 +1,84 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "encoding/hex"
+ "golang.org/x/crypto/blake2s"
+ "testing"
+)
+
+type KDFTest struct {
+ key string
+ input string
+ t0 string
+ t1 string
+ t2 string
+}
+
+func assertEquals(t *testing.T, a string, b string) {
+ if a != b {
+ t.Fatal("expected", a, "=", b)
+ }
+}
+
+func TestKDF(t *testing.T) {
+ tests := []KDFTest{
+ {
+ key: "746573742d6b6579",
+ input: "746573742d696e707574",
+ t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633",
+ t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a",
+ t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24",
+ },
+ {
+ key: "776972656775617264",
+ input: "776972656775617264",
+ t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8",
+ t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f",
+ t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160",
+ },
+ {
+ key: "",
+ input: "",
+ t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0",
+ t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e",
+ t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e",
+ },
+ }
+
+ var t0, t1, t2 [blake2s.Size]byte
+
+ for _, test := range tests {
+ key, _ := hex.DecodeString(test.key)
+ input, _ := hex.DecodeString(test.input)
+ KDF3(&t0, &t1, &t2, key, input)
+ t0s := hex.EncodeToString(t0[:])
+ t1s := hex.EncodeToString(t1[:])
+ t2s := hex.EncodeToString(t2[:])
+ assertEquals(t, t0s, test.t0)
+ assertEquals(t, t1s, test.t1)
+ assertEquals(t, t2s, test.t2)
+ }
+
+ for _, test := range tests {
+ key, _ := hex.DecodeString(test.key)
+ input, _ := hex.DecodeString(test.input)
+ KDF2(&t0, &t1, key, input)
+ t0s := hex.EncodeToString(t0[:])
+ t1s := hex.EncodeToString(t1[:])
+ assertEquals(t, t0s, test.t0)
+ assertEquals(t, t1s, test.t1)
+ }
+
+ for _, test := range tests {
+ key, _ := hex.DecodeString(test.key)
+ input, _ := hex.DecodeString(test.input)
+ KDF1(&t0, key, input)
+ t0s := hex.EncodeToString(t0[:])
+ assertEquals(t, t0s, test.t0)
+ }
+}
diff --git a/device/keypair.go b/device/keypair.go
new file mode 100644
index 0000000..a9fbfce
--- /dev/null
+++ b/device/keypair.go
@@ -0,0 +1,50 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "crypto/cipher"
+ "golang.zx2c4.com/wireguard/replay"
+ "sync"
+ "time"
+)
+
+/* Due to limitations in Go and /x/crypto there is currently
+ * no way to ensure that key material is securely ereased in memory.
+ *
+ * Since this may harm the forward secrecy property,
+ * we plan to resolve this issue; whenever Go allows us to do so.
+ */
+
+type Keypair struct {
+ sendNonce uint64
+ send cipher.AEAD
+ receive cipher.AEAD
+ replayFilter replay.ReplayFilter
+ isInitiator bool
+ created time.Time
+ localIndex uint32
+ remoteIndex uint32
+}
+
+type Keypairs struct {
+ sync.RWMutex
+ current *Keypair
+ previous *Keypair
+ next *Keypair
+}
+
+func (kp *Keypairs) Current() *Keypair {
+ kp.RLock()
+ defer kp.RUnlock()
+ return kp.current
+}
+
+func (device *Device) DeleteKeypair(key *Keypair) {
+ if key != nil {
+ device.indexTable.Delete(key.localIndex)
+ }
+}
diff --git a/device/logger.go b/device/logger.go
new file mode 100644
index 0000000..7c8b704
--- /dev/null
+++ b/device/logger.go
@@ -0,0 +1,59 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "io"
+ "io/ioutil"
+ "log"
+ "os"
+)
+
+const (
+ LogLevelSilent = iota
+ LogLevelError
+ LogLevelInfo
+ LogLevelDebug
+)
+
+type Logger struct {
+ Debug *log.Logger
+ Info *log.Logger
+ Error *log.Logger
+}
+
+func NewLogger(level int, prepend string) *Logger {
+ output := os.Stdout
+ logger := new(Logger)
+
+ logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
+ if level >= LogLevelDebug {
+ return output, output, output
+ }
+ if level >= LogLevelInfo {
+ return output, output, ioutil.Discard
+ }
+ if level >= LogLevelError {
+ return output, ioutil.Discard, ioutil.Discard
+ }
+ return ioutil.Discard, ioutil.Discard, ioutil.Discard
+ }()
+
+ logger.Debug = log.New(logDebug,
+ "DEBUG: "+prepend,
+ log.Ldate|log.Ltime,
+ )
+
+ logger.Info = log.New(logInfo,
+ "INFO: "+prepend,
+ log.Ldate|log.Ltime,
+ )
+ logger.Error = log.New(logErr,
+ "ERROR: "+prepend,
+ log.Ldate|log.Ltime,
+ )
+ return logger
+}
diff --git a/device/mark_default.go b/device/mark_default.go
new file mode 100644
index 0000000..76b1015
--- /dev/null
+++ b/device/mark_default.go
@@ -0,0 +1,12 @@
+// +build !linux,!openbsd,!freebsd
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+func (bind *NativeBind) SetMark(mark uint32) error {
+ return nil
+}
diff --git a/device/mark_unix.go b/device/mark_unix.go
new file mode 100644
index 0000000..ee64cc9
--- /dev/null
+++ b/device/mark_unix.go
@@ -0,0 +1,64 @@
+// +build android openbsd freebsd
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "golang.org/x/sys/unix"
+ "runtime"
+)
+
+var fwmarkIoctl int
+
+func init() {
+ switch runtime.GOOS {
+ case "linux", "android":
+ fwmarkIoctl = 36 /* unix.SO_MARK */
+ case "freebsd":
+ fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
+ case "openbsd":
+ fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
+ }
+}
+
+func (bind *NativeBind) SetMark(mark uint32) error {
+ var operr error
+ if fwmarkIoctl == 0 {
+ return nil
+ }
+ if bind.ipv4 != nil {
+ fd, err := bind.ipv4.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err = fd.Control(func(fd uintptr) {
+ operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
+ })
+ if err == nil {
+ err = operr
+ }
+ if err != nil {
+ return err
+ }
+ }
+ if bind.ipv6 != nil {
+ fd, err := bind.ipv6.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err = fd.Control(func(fd uintptr) {
+ operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
+ })
+ if err == nil {
+ err = operr
+ }
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
diff --git a/device/misc.go b/device/misc.go
new file mode 100644
index 0000000..a38d1c1
--- /dev/null
+++ b/device/misc.go
@@ -0,0 +1,48 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "sync/atomic"
+)
+
+/* Atomic Boolean */
+
+const (
+ AtomicFalse = int32(iota)
+ AtomicTrue
+)
+
+type AtomicBool struct {
+ int32
+}
+
+func (a *AtomicBool) Get() bool {
+ return atomic.LoadInt32(&a.int32) == AtomicTrue
+}
+
+func (a *AtomicBool) Swap(val bool) bool {
+ flag := AtomicFalse
+ if val {
+ flag = AtomicTrue
+ }
+ return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
+}
+
+func (a *AtomicBool) Set(val bool) {
+ flag := AtomicFalse
+ if val {
+ flag = AtomicTrue
+ }
+ atomic.StoreInt32(&a.int32, flag)
+}
+
+func min(a, b uint) uint {
+ if a > b {
+ return b
+ }
+ return a
+}
diff --git a/device/noise-helpers.go b/device/noise-helpers.go
new file mode 100644
index 0000000..4b09bf3
--- /dev/null
+++ b/device/noise-helpers.go
@@ -0,0 +1,104 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "crypto/hmac"
+ "crypto/rand"
+ "crypto/subtle"
+ "golang.org/x/crypto/blake2s"
+ "golang.org/x/crypto/curve25519"
+ "hash"
+)
+
+/* KDF related functions.
+ * HMAC-based Key Derivation Function (HKDF)
+ * https://tools.ietf.org/html/rfc5869
+ */
+
+func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) {
+ mac := hmac.New(func() hash.Hash {
+ h, _ := blake2s.New256(nil)
+ return h
+ }, key)
+ mac.Write(in0)
+ mac.Sum(sum[:0])
+}
+
+func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
+ mac := hmac.New(func() hash.Hash {
+ h, _ := blake2s.New256(nil)
+ return h
+ }, key)
+ mac.Write(in0)
+ mac.Write(in1)
+ mac.Sum(sum[:0])
+}
+
+func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
+ HMAC1(t0, key, input)
+ HMAC1(t0, t0[:], []byte{0x1})
+ return
+}
+
+func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
+ var prk [blake2s.Size]byte
+ HMAC1(&prk, key, input)
+ HMAC1(t0, prk[:], []byte{0x1})
+ HMAC2(t1, prk[:], t0[:], []byte{0x2})
+ setZero(prk[:])
+ return
+}
+
+func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
+ var prk [blake2s.Size]byte
+ HMAC1(&prk, key, input)
+ HMAC1(t0, prk[:], []byte{0x1})
+ HMAC2(t1, prk[:], t0[:], []byte{0x2})
+ HMAC2(t2, prk[:], t1[:], []byte{0x3})
+ setZero(prk[:])
+ return
+}
+
+func isZero(val []byte) bool {
+ acc := 1
+ for _, b := range val {
+ acc &= subtle.ConstantTimeByteEq(b, 0)
+ }
+ return acc == 1
+}
+
+/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */
+func setZero(arr []byte) {
+ for i := range arr {
+ arr[i] = 0
+ }
+}
+
+func (sk *NoisePrivateKey) clamp() {
+ sk[0] &= 248
+ sk[31] = (sk[31] & 127) | 64
+}
+
+func newPrivateKey() (sk NoisePrivateKey, err error) {
+ _, err = rand.Read(sk[:])
+ sk.clamp()
+ return
+}
+
+func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
+ apk := (*[NoisePublicKeySize]byte)(&pk)
+ ask := (*[NoisePrivateKeySize]byte)(sk)
+ curve25519.ScalarBaseMult(apk, ask)
+ return
+}
+
+func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
+ apk := (*[NoisePublicKeySize]byte)(&pk)
+ ask := (*[NoisePrivateKeySize]byte)(sk)
+ curve25519.ScalarMult(&ss, ask, apk)
+ return ss
+}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
new file mode 100644
index 0000000..73826e1
--- /dev/null
+++ b/device/noise-protocol.go
@@ -0,0 +1,600 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "errors"
+ "golang.org/x/crypto/blake2s"
+ "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/crypto/poly1305"
+ "golang.zx2c4.com/wireguard/tai64n"
+ "sync"
+ "time"
+)
+
+const (
+ HandshakeZeroed = iota
+ HandshakeInitiationCreated
+ HandshakeInitiationConsumed
+ HandshakeResponseCreated
+ HandshakeResponseConsumed
+)
+
+const (
+ NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
+ WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
+ WGLabelMAC1 = "mac1----"
+ WGLabelCookie = "cookie--"
+)
+
+const (
+ MessageInitiationType = 1
+ MessageResponseType = 2
+ MessageCookieReplyType = 3
+ MessageTransportType = 4
+)
+
+const (
+ MessageInitiationSize = 148 // size of handshake initation message
+ MessageResponseSize = 92 // size of response message
+ MessageCookieReplySize = 64 // size of cookie reply message
+ MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
+ MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
+ MessageKeepaliveSize = MessageTransportSize // size of keepalive
+ MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
+)
+
+const (
+ MessageTransportOffsetReceiver = 4
+ MessageTransportOffsetCounter = 8
+ MessageTransportOffsetContent = 16
+)
+
+/* Type is an 8-bit field, followed by 3 nul bytes,
+ * by marshalling the messages in little-endian byteorder
+ * we can treat these as a 32-bit unsigned int (for now)
+ *
+ */
+
+type MessageInitiation struct {
+ Type uint32
+ Sender uint32
+ Ephemeral NoisePublicKey
+ Static [NoisePublicKeySize + poly1305.TagSize]byte
+ Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
+ MAC1 [blake2s.Size128]byte
+ MAC2 [blake2s.Size128]byte
+}
+
+type MessageResponse struct {
+ Type uint32
+ Sender uint32
+ Receiver uint32
+ Ephemeral NoisePublicKey
+ Empty [poly1305.TagSize]byte
+ MAC1 [blake2s.Size128]byte
+ MAC2 [blake2s.Size128]byte
+}
+
+type MessageTransport struct {
+ Type uint32
+ Receiver uint32
+ Counter uint64
+ Content []byte
+}
+
+type MessageCookieReply struct {
+ Type uint32
+ Receiver uint32
+ Nonce [chacha20poly1305.NonceSizeX]byte
+ Cookie [blake2s.Size128 + poly1305.TagSize]byte
+}
+
+type Handshake struct {
+ state int
+ mutex sync.RWMutex
+ hash [blake2s.Size]byte // hash value
+ chainKey [blake2s.Size]byte // chain key
+ presharedKey NoiseSymmetricKey // psk
+ localEphemeral NoisePrivateKey // ephemeral secret key
+ localIndex uint32 // used to clear hash-table
+ remoteIndex uint32 // index for sending
+ remoteStatic NoisePublicKey // long term key
+ remoteEphemeral NoisePublicKey // ephemeral public key
+ precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
+ lastTimestamp tai64n.Timestamp
+ lastInitiationConsumption time.Time
+ lastSentHandshake time.Time
+}
+
+var (
+ InitialChainKey [blake2s.Size]byte
+ InitialHash [blake2s.Size]byte
+ ZeroNonce [chacha20poly1305.NonceSize]byte
+)
+
+func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
+ KDF1(dst, c[:], data)
+}
+
+func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
+ hash, _ := blake2s.New256(nil)
+ hash.Write(h[:])
+ hash.Write(data)
+ hash.Sum(dst[:0])
+ hash.Reset()
+}
+
+func (h *Handshake) Clear() {
+ setZero(h.localEphemeral[:])
+ setZero(h.remoteEphemeral[:])
+ setZero(h.chainKey[:])
+ setZero(h.hash[:])
+ h.localIndex = 0
+ h.state = HandshakeZeroed
+}
+
+func (h *Handshake) mixHash(data []byte) {
+ mixHash(&h.hash, &h.hash, data)
+}
+
+func (h *Handshake) mixKey(data []byte) {
+ mixKey(&h.chainKey, &h.chainKey, data)
+}
+
+/* Do basic precomputations
+ */
+func init() {
+ InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
+ mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
+}
+
+func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
+
+ device.staticIdentity.RLock()
+ defer device.staticIdentity.RUnlock()
+
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
+
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ return nil, errors.New("static shared secret is zero")
+ }
+
+ // create ephemeral key
+
+ var err error
+ handshake.hash = InitialHash
+ handshake.chainKey = InitialChainKey
+ handshake.localEphemeral, err = newPrivateKey()
+ if err != nil {
+ return nil, err
+ }
+
+ // assign index
+
+ device.indexTable.Delete(handshake.localIndex)
+ handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
+
+ if err != nil {
+ return nil, err
+ }
+
+ handshake.mixHash(handshake.remoteStatic[:])
+
+ msg := MessageInitiation{
+ Type: MessageInitiationType,
+ Ephemeral: handshake.localEphemeral.publicKey(),
+ Sender: handshake.localIndex,
+ }
+
+ handshake.mixKey(msg.Ephemeral[:])
+ handshake.mixHash(msg.Ephemeral[:])
+
+ // encrypt static key
+
+ func() {
+ var key [chacha20poly1305.KeySize]byte
+ ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ ss[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
+ }()
+ handshake.mixHash(msg.Static[:])
+
+ // encrypt timestamp
+
+ timestamp := tai64n.Now()
+ func() {
+ var key [chacha20poly1305.KeySize]byte
+ KDF2(
+ &handshake.chainKey,
+ &key,
+ handshake.chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
+ }()
+
+ handshake.mixHash(msg.Timestamp[:])
+ handshake.state = HandshakeInitiationCreated
+ return &msg, nil
+}
+
+func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
+ var (
+ hash [blake2s.Size]byte
+ chainKey [blake2s.Size]byte
+ )
+
+ if msg.Type != MessageInitiationType {
+ return nil
+ }
+
+ device.staticIdentity.RLock()
+ defer device.staticIdentity.RUnlock()
+
+ mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
+ mixHash(&hash, &hash, msg.Ephemeral[:])
+ mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
+
+ // decrypt static key
+
+ var err error
+ var peerPK NoisePublicKey
+ func() {
+ var key [chacha20poly1305.KeySize]byte
+ ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ KDF2(&chainKey, &key, chainKey[:], ss[:])
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
+ }()
+ if err != nil {
+ return nil
+ }
+ mixHash(&hash, &hash, msg.Static[:])
+
+ // lookup peer
+
+ peer := device.LookupPeer(peerPK)
+ if peer == nil {
+ return nil
+ }
+
+ handshake := &peer.handshake
+ if isZero(handshake.precomputedStaticStatic[:]) {
+ return nil
+ }
+
+ // verify identity
+
+ var timestamp tai64n.Timestamp
+ var key [chacha20poly1305.KeySize]byte
+
+ handshake.mutex.RLock()
+ KDF2(
+ &chainKey,
+ &key,
+ chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
+ if err != nil {
+ handshake.mutex.RUnlock()
+ return nil
+ }
+ mixHash(&hash, &hash, msg.Timestamp[:])
+
+ // protect against replay & flood
+
+ var ok bool
+ ok = timestamp.After(handshake.lastTimestamp)
+ ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate
+ handshake.mutex.RUnlock()
+ if !ok {
+ return nil
+ }
+
+ // update handshake state
+
+ handshake.mutex.Lock()
+
+ handshake.hash = hash
+ handshake.chainKey = chainKey
+ handshake.remoteIndex = msg.Sender
+ handshake.remoteEphemeral = msg.Ephemeral
+ handshake.lastTimestamp = timestamp
+ handshake.lastInitiationConsumption = time.Now()
+ handshake.state = HandshakeInitiationConsumed
+
+ handshake.mutex.Unlock()
+
+ setZero(hash[:])
+ setZero(chainKey[:])
+
+ return peer
+}
+
+func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
+
+ if handshake.state != HandshakeInitiationConsumed {
+ return nil, errors.New("handshake initiation must be consumed first")
+ }
+
+ // assign index
+
+ var err error
+ device.indexTable.Delete(handshake.localIndex)
+ handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
+ if err != nil {
+ return nil, err
+ }
+
+ var msg MessageResponse
+ msg.Type = MessageResponseType
+ msg.Sender = handshake.localIndex
+ msg.Receiver = handshake.remoteIndex
+
+ // create ephemeral key
+
+ handshake.localEphemeral, err = newPrivateKey()
+ if err != nil {
+ return nil, err
+ }
+ msg.Ephemeral = handshake.localEphemeral.publicKey()
+ handshake.mixHash(msg.Ephemeral[:])
+ handshake.mixKey(msg.Ephemeral[:])
+
+ func() {
+ ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ handshake.mixKey(ss[:])
+ ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ handshake.mixKey(ss[:])
+ }()
+
+ // add preshared key
+
+ var tau [blake2s.Size]byte
+ var key [chacha20poly1305.KeySize]byte
+
+ KDF3(
+ &handshake.chainKey,
+ &tau,
+ &key,
+ handshake.chainKey[:],
+ handshake.presharedKey[:],
+ )
+
+ handshake.mixHash(tau[:])
+
+ func() {
+ aead, _ := chacha20poly1305.New(key[:])
+ aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
+ handshake.mixHash(msg.Empty[:])
+ }()
+
+ handshake.state = HandshakeResponseCreated
+
+ return &msg, nil
+}
+
+func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
+ if msg.Type != MessageResponseType {
+ return nil
+ }
+
+ // lookup handshake by receiver
+
+ lookup := device.indexTable.Lookup(msg.Receiver)
+ handshake := lookup.handshake
+ if handshake == nil {
+ return nil
+ }
+
+ var (
+ hash [blake2s.Size]byte
+ chainKey [blake2s.Size]byte
+ )
+
+ ok := func() bool {
+
+ // lock handshake state
+
+ handshake.mutex.RLock()
+ defer handshake.mutex.RUnlock()
+
+ if handshake.state != HandshakeInitiationCreated {
+ return false
+ }
+
+ // lock private key for reading
+
+ device.staticIdentity.RLock()
+ defer device.staticIdentity.RUnlock()
+
+ // finish 3-way DH
+
+ mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
+ mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
+
+ func() {
+ ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
+ }()
+
+ func() {
+ ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ mixKey(&chainKey, &chainKey, ss[:])
+ setZero(ss[:])
+ }()
+
+ // add preshared key (psk)
+
+ var tau [blake2s.Size]byte
+ var key [chacha20poly1305.KeySize]byte
+ KDF3(
+ &chainKey,
+ &tau,
+ &key,
+ chainKey[:],
+ handshake.presharedKey[:],
+ )
+ mixHash(&hash, &hash, tau[:])
+
+ // authenticate transcript
+
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ if err != nil {
+ return false
+ }
+ mixHash(&hash, &hash, msg.Empty[:])
+ return true
+ }()
+
+ if !ok {
+ return nil
+ }
+
+ // update handshake state
+
+ handshake.mutex.Lock()
+
+ handshake.hash = hash
+ handshake.chainKey = chainKey
+ handshake.remoteIndex = msg.Sender
+ handshake.state = HandshakeResponseConsumed
+
+ handshake.mutex.Unlock()
+
+ setZero(hash[:])
+ setZero(chainKey[:])
+
+ return lookup.peer
+}
+
+/* Derives a new keypair from the current handshake state
+ *
+ */
+func (peer *Peer) BeginSymmetricSession() error {
+ device := peer.device
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ defer handshake.mutex.Unlock()
+
+ // derive keys
+
+ var isInitiator bool
+ var sendKey [chacha20poly1305.KeySize]byte
+ var recvKey [chacha20poly1305.KeySize]byte
+
+ if handshake.state == HandshakeResponseConsumed {
+ KDF2(
+ &sendKey,
+ &recvKey,
+ handshake.chainKey[:],
+ nil,
+ )
+ isInitiator = true
+ } else if handshake.state == HandshakeResponseCreated {
+ KDF2(
+ &recvKey,
+ &sendKey,
+ handshake.chainKey[:],
+ nil,
+ )
+ isInitiator = false
+ } else {
+ return errors.New("invalid state for keypair derivation")
+ }
+
+ // zero handshake
+
+ setZero(handshake.chainKey[:])
+ setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
+ setZero(handshake.localEphemeral[:])
+ peer.handshake.state = HandshakeZeroed
+
+ // create AEAD instances
+
+ keypair := new(Keypair)
+ keypair.send, _ = chacha20poly1305.New(sendKey[:])
+ keypair.receive, _ = chacha20poly1305.New(recvKey[:])
+
+ setZero(sendKey[:])
+ setZero(recvKey[:])
+
+ keypair.created = time.Now()
+ keypair.sendNonce = 0
+ keypair.replayFilter.Init()
+ keypair.isInitiator = isInitiator
+ keypair.localIndex = peer.handshake.localIndex
+ keypair.remoteIndex = peer.handshake.remoteIndex
+
+ // remap index
+
+ device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
+ handshake.localIndex = 0
+
+ // rotate key pairs
+
+ keypairs := &peer.keypairs
+ keypairs.Lock()
+ defer keypairs.Unlock()
+
+ previous := keypairs.previous
+ next := keypairs.next
+ current := keypairs.current
+
+ if isInitiator {
+ if next != nil {
+ keypairs.next = nil
+ keypairs.previous = next
+ device.DeleteKeypair(current)
+ } else {
+ keypairs.previous = current
+ }
+ device.DeleteKeypair(previous)
+ keypairs.current = keypair
+ } else {
+ keypairs.next = keypair
+ device.DeleteKeypair(next)
+ keypairs.previous = nil
+ device.DeleteKeypair(previous)
+ }
+
+ return nil
+}
+
+func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
+ keypairs := &peer.keypairs
+ if keypairs.next != receivedKeypair {
+ return false
+ }
+ keypairs.Lock()
+ defer keypairs.Unlock()
+ if keypairs.next != receivedKeypair {
+ return false
+ }
+ old := keypairs.previous
+ keypairs.previous = keypairs.current
+ peer.device.DeleteKeypair(old)
+ keypairs.current = keypairs.next
+ keypairs.next = nil
+ return true
+}
diff --git a/device/noise-types.go b/device/noise-types.go
new file mode 100644
index 0000000..82b12c1
--- /dev/null
+++ b/device/noise-types.go
@@ -0,0 +1,81 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "crypto/subtle"
+ "encoding/hex"
+ "errors"
+ "golang.org/x/crypto/chacha20poly1305"
+)
+
+const (
+ NoisePublicKeySize = 32
+ NoisePrivateKeySize = 32
+)
+
+type (
+ NoisePublicKey [NoisePublicKeySize]byte
+ NoisePrivateKey [NoisePrivateKeySize]byte
+ NoiseSymmetricKey [chacha20poly1305.KeySize]byte
+ NoiseNonce uint64 // padded to 12-bytes
+)
+
+func loadExactHex(dst []byte, src string) error {
+ slice, err := hex.DecodeString(src)
+ if err != nil {
+ return err
+ }
+ if len(slice) != len(dst) {
+ return errors.New("hex string does not fit the slice")
+ }
+ copy(dst, slice)
+ return nil
+}
+
+func (key NoisePrivateKey) IsZero() bool {
+ var zero NoisePrivateKey
+ return key.Equals(zero)
+}
+
+func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
+ return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
+}
+
+func (key *NoisePrivateKey) FromHex(src string) (err error) {
+ err = loadExactHex(key[:], src)
+ key.clamp()
+ return
+}
+
+func (key NoisePrivateKey) ToHex() string {
+ return hex.EncodeToString(key[:])
+}
+
+func (key *NoisePublicKey) FromHex(src string) error {
+ return loadExactHex(key[:], src)
+}
+
+func (key NoisePublicKey) ToHex() string {
+ return hex.EncodeToString(key[:])
+}
+
+func (key NoisePublicKey) IsZero() bool {
+ var zero NoisePublicKey
+ return key.Equals(zero)
+}
+
+func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
+ return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
+}
+
+func (key *NoiseSymmetricKey) FromHex(src string) error {
+ return loadExactHex(key[:], src)
+}
+
+func (key NoiseSymmetricKey) ToHex() string {
+ return hex.EncodeToString(key[:])
+}
diff --git a/device/noise_test.go b/device/noise_test.go
new file mode 100644
index 0000000..6ba3f2e
--- /dev/null
+++ b/device/noise_test.go
@@ -0,0 +1,144 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "bytes"
+ "encoding/binary"
+ "testing"
+)
+
+func TestCurveWrappers(t *testing.T) {
+ sk1, err := newPrivateKey()
+ assertNil(t, err)
+
+ sk2, err := newPrivateKey()
+ assertNil(t, err)
+
+ pk1 := sk1.publicKey()
+ pk2 := sk2.publicKey()
+
+ ss1 := sk1.sharedSecret(pk2)
+ ss2 := sk2.sharedSecret(pk1)
+
+ if ss1 != ss2 {
+ t.Fatal("Failed to compute shared secet")
+ }
+}
+
+func TestNoiseHandshake(t *testing.T) {
+ dev1 := randDevice(t)
+ dev2 := randDevice(t)
+
+ defer dev1.Close()
+ defer dev2.Close()
+
+ peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
+ peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
+
+ assertEqual(
+ t,
+ peer1.handshake.precomputedStaticStatic[:],
+ peer2.handshake.precomputedStaticStatic[:],
+ )
+
+ /* simulate handshake */
+
+ // initiation message
+
+ t.Log("exchange initiation message")
+
+ msg1, err := dev1.CreateMessageInitiation(peer2)
+ assertNil(t, err)
+
+ packet := make([]byte, 0, 256)
+ writer := bytes.NewBuffer(packet)
+ err = binary.Write(writer, binary.LittleEndian, msg1)
+ assertNil(t, err)
+ peer := dev2.ConsumeMessageInitiation(msg1)
+ if peer == nil {
+ t.Fatal("handshake failed at initiation message")
+ }
+
+ assertEqual(
+ t,
+ peer1.handshake.chainKey[:],
+ peer2.handshake.chainKey[:],
+ )
+
+ assertEqual(
+ t,
+ peer1.handshake.hash[:],
+ peer2.handshake.hash[:],
+ )
+
+ // response message
+
+ t.Log("exchange response message")
+
+ msg2, err := dev2.CreateMessageResponse(peer1)
+ assertNil(t, err)
+
+ peer = dev1.ConsumeMessageResponse(msg2)
+ if peer == nil {
+ t.Fatal("handshake failed at response message")
+ }
+
+ assertEqual(
+ t,
+ peer1.handshake.chainKey[:],
+ peer2.handshake.chainKey[:],
+ )
+
+ assertEqual(
+ t,
+ peer1.handshake.hash[:],
+ peer2.handshake.hash[:],
+ )
+
+ // key pairs
+
+ t.Log("deriving keys")
+
+ err = peer1.BeginSymmetricSession()
+ if err != nil {
+ t.Fatal("failed to derive keypair for peer 1", err)
+ }
+
+ err = peer2.BeginSymmetricSession()
+ if err != nil {
+ t.Fatal("failed to derive keypair for peer 2", err)
+ }
+
+ key1 := peer1.keypairs.next
+ key2 := peer2.keypairs.current
+
+ // encrypting / decryption test
+
+ t.Log("test key pairs")
+
+ func() {
+ testMsg := []byte("wireguard test message 1")
+ var err error
+ var out []byte
+ var nonce [12]byte
+ out = key1.send.Seal(out, nonce[:], testMsg, nil)
+ out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
+ assertNil(t, err)
+ assertEqual(t, out, testMsg)
+ }()
+
+ func() {
+ testMsg := []byte("wireguard test message 2")
+ var err error
+ var out []byte
+ var nonce [12]byte
+ out = key2.send.Seal(out, nonce[:], testMsg, nil)
+ out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
+ assertNil(t, err)
+ assertEqual(t, out, testMsg)
+ }()
+}
diff --git a/device/peer.go b/device/peer.go
new file mode 100644
index 0000000..af3ef9d
--- /dev/null
+++ b/device/peer.go
@@ -0,0 +1,270 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "encoding/base64"
+ "errors"
+ "fmt"
+ "sync"
+ "time"
+)
+
+const (
+ PeerRoutineNumber = 3
+)
+
+type Peer struct {
+ isRunning AtomicBool
+ sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
+ keypairs Keypairs
+ handshake Handshake
+ device *Device
+ endpoint Endpoint
+ persistentKeepaliveInterval uint16
+
+ // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
+ stats struct {
+ txBytes uint64 // bytes send to peer (endpoint)
+ rxBytes uint64 // bytes received from peer
+ lastHandshakeNano int64 // nano seconds since epoch
+ }
+
+ timers struct {
+ retransmitHandshake *Timer
+ sendKeepalive *Timer
+ newHandshake *Timer
+ zeroKeyMaterial *Timer
+ persistentKeepalive *Timer
+ handshakeAttempts uint32
+ needAnotherKeepalive AtomicBool
+ sentLastMinuteHandshake AtomicBool
+ }
+
+ signals struct {
+ newKeypairArrived chan struct{}
+ flushNonceQueue chan struct{}
+ }
+
+ queue struct {
+ nonce chan *QueueOutboundElement // nonce / pre-handshake queue
+ outbound chan *QueueOutboundElement // sequential ordering of work
+ inbound chan *QueueInboundElement // sequential ordering of work
+ packetInNonceQueueIsAwaitingKey AtomicBool
+ }
+
+ routines struct {
+ sync.Mutex // held when stopping / starting routines
+ starting sync.WaitGroup // routines pending start
+ stopping sync.WaitGroup // routines pending stop
+ stop chan struct{} // size 0, stop all go routines in peer
+ }
+
+ cookieGenerator CookieGenerator
+}
+
+func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
+
+ if device.isClosed.Get() {
+ return nil, errors.New("device closed")
+ }
+
+ // lock resources
+
+ device.staticIdentity.RLock()
+ defer device.staticIdentity.RUnlock()
+
+ device.peers.Lock()
+ defer device.peers.Unlock()
+
+ // check if over limit
+
+ if len(device.peers.keyMap) >= MaxPeers {
+ return nil, errors.New("too many peers")
+ }
+
+ // create peer
+
+ peer := new(Peer)
+ peer.Lock()
+ defer peer.Unlock()
+
+ peer.cookieGenerator.Init(pk)
+ peer.device = device
+ peer.isRunning.Set(false)
+
+ // map public key
+
+ _, ok := device.peers.keyMap[pk]
+ if ok {
+ return nil, errors.New("adding existing peer")
+ }
+ device.peers.keyMap[pk] = peer
+
+ // pre-compute DH
+
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ handshake.remoteStatic = pk
+ handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
+ handshake.mutex.Unlock()
+
+ // reset endpoint
+
+ peer.endpoint = nil
+
+ // start peer
+
+ if peer.device.isUp.Get() {
+ peer.Start()
+ }
+
+ return peer, nil
+}
+
+func (peer *Peer) SendBuffer(buffer []byte) error {
+ peer.device.net.RLock()
+ defer peer.device.net.RUnlock()
+
+ if peer.device.net.bind == nil {
+ return errors.New("no bind")
+ }
+
+ peer.RLock()
+ defer peer.RUnlock()
+
+ if peer.endpoint == nil {
+ return errors.New("no known endpoint for peer")
+ }
+
+ return peer.device.net.bind.Send(buffer, peer.endpoint)
+}
+
+func (peer *Peer) String() string {
+ base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
+ abbreviatedKey := "invalid"
+ if len(base64Key) == 44 {
+ abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
+ }
+ return fmt.Sprintf("peer(%s)", abbreviatedKey)
+}
+
+func (peer *Peer) Start() {
+
+ // should never start a peer on a closed device
+
+ if peer.device.isClosed.Get() {
+ return
+ }
+
+ // prevent simultaneous start/stop operations
+
+ peer.routines.Lock()
+ defer peer.routines.Unlock()
+
+ if peer.isRunning.Get() {
+ return
+ }
+
+ device := peer.device
+ device.log.Debug.Println(peer, "- Starting...")
+
+ // reset routine state
+
+ peer.routines.starting.Wait()
+ peer.routines.stopping.Wait()
+ peer.routines.stop = make(chan struct{})
+ peer.routines.starting.Add(PeerRoutineNumber)
+ peer.routines.stopping.Add(PeerRoutineNumber)
+
+ // prepare queues
+
+ peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
+ peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
+ peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
+
+ peer.timersInit()
+ peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
+ peer.signals.newKeypairArrived = make(chan struct{}, 1)
+ peer.signals.flushNonceQueue = make(chan struct{}, 1)
+
+ // wait for routines to start
+
+ go peer.RoutineNonce()
+ go peer.RoutineSequentialSender()
+ go peer.RoutineSequentialReceiver()
+
+ peer.routines.starting.Wait()
+ peer.isRunning.Set(true)
+}
+
+func (peer *Peer) ZeroAndFlushAll() {
+ device := peer.device
+
+ // clear key pairs
+
+ keypairs := &peer.keypairs
+ keypairs.Lock()
+ device.DeleteKeypair(keypairs.previous)
+ device.DeleteKeypair(keypairs.current)
+ device.DeleteKeypair(keypairs.next)
+ keypairs.previous = nil
+ keypairs.current = nil
+ keypairs.next = nil
+ keypairs.Unlock()
+
+ // clear handshake state
+
+ handshake := &peer.handshake
+ handshake.mutex.Lock()
+ device.indexTable.Delete(handshake.localIndex)
+ handshake.Clear()
+ handshake.mutex.Unlock()
+
+ peer.FlushNonceQueue()
+}
+
+func (peer *Peer) Stop() {
+
+ // prevent simultaneous start/stop operations
+
+ if !peer.isRunning.Swap(false) {
+ return
+ }
+
+ peer.routines.starting.Wait()
+
+ peer.routines.Lock()
+ defer peer.routines.Unlock()
+
+ peer.device.log.Debug.Println(peer, "- Stopping...")
+
+ peer.timersStop()
+
+ // stop & wait for ongoing peer routines
+
+ close(peer.routines.stop)
+ peer.routines.stopping.Wait()
+
+ // close queues
+
+ close(peer.queue.nonce)
+ close(peer.queue.outbound)
+ close(peer.queue.inbound)
+
+ peer.ZeroAndFlushAll()
+}
+
+var roamingDisabled bool
+
+func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
+ if roamingDisabled {
+ return
+ }
+ peer.Lock()
+ peer.endpoint = endpoint
+ peer.Unlock()
+}
diff --git a/device/pools.go b/device/pools.go
new file mode 100644
index 0000000..98f4ef1
--- /dev/null
+++ b/device/pools.go
@@ -0,0 +1,89 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import "sync"
+
+func (device *Device) PopulatePools() {
+ if PreallocatedBuffersPerPool == 0 {
+ device.pool.messageBufferPool = &sync.Pool{
+ New: func() interface{} {
+ return new([MaxMessageSize]byte)
+ },
+ }
+ device.pool.inboundElementPool = &sync.Pool{
+ New: func() interface{} {
+ return new(QueueInboundElement)
+ },
+ }
+ device.pool.outboundElementPool = &sync.Pool{
+ New: func() interface{} {
+ return new(QueueOutboundElement)
+ },
+ }
+ } else {
+ device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
+ for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
+ device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
+ }
+ device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
+ for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
+ device.pool.inboundElementReuseChan <- new(QueueInboundElement)
+ }
+ device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
+ for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
+ device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
+ }
+ }
+}
+
+func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
+ if PreallocatedBuffersPerPool == 0 {
+ return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
+ } else {
+ return <-device.pool.messageBufferReuseChan
+ }
+}
+
+func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
+ if PreallocatedBuffersPerPool == 0 {
+ device.pool.messageBufferPool.Put(msg)
+ } else {
+ device.pool.messageBufferReuseChan <- msg
+ }
+}
+
+func (device *Device) GetInboundElement() *QueueInboundElement {
+ if PreallocatedBuffersPerPool == 0 {
+ return device.pool.inboundElementPool.Get().(*QueueInboundElement)
+ } else {
+ return <-device.pool.inboundElementReuseChan
+ }
+}
+
+func (device *Device) PutInboundElement(msg *QueueInboundElement) {
+ if PreallocatedBuffersPerPool == 0 {
+ device.pool.inboundElementPool.Put(msg)
+ } else {
+ device.pool.inboundElementReuseChan <- msg
+ }
+}
+
+func (device *Device) GetOutboundElement() *QueueOutboundElement {
+ if PreallocatedBuffersPerPool == 0 {
+ return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
+ } else {
+ return <-device.pool.outboundElementReuseChan
+ }
+}
+
+func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
+ if PreallocatedBuffersPerPool == 0 {
+ device.pool.outboundElementPool.Put(msg)
+ } else {
+ device.pool.outboundElementReuseChan <- msg
+ }
+}
diff --git a/device/queueconstants.go b/device/queueconstants.go
new file mode 100644
index 0000000..3e94b7f
--- /dev/null
+++ b/device/queueconstants.go
@@ -0,0 +1,16 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+/* Implementation specific constants */
+
+const (
+ QueueOutboundSize = 1024
+ QueueInboundSize = 1024
+ QueueHandshakeSize = 1024
+ MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
+ PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
+)
diff --git a/device/receive.go b/device/receive.go
new file mode 100644
index 0000000..5c837c1
--- /dev/null
+++ b/device/receive.go
@@ -0,0 +1,641 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "bytes"
+ "encoding/binary"
+ "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "net"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+type QueueHandshakeElement struct {
+ msgType uint32
+ packet []byte
+ endpoint Endpoint
+ buffer *[MaxMessageSize]byte
+}
+
+type QueueInboundElement struct {
+ dropped int32
+ sync.Mutex
+ buffer *[MaxMessageSize]byte
+ packet []byte
+ counter uint64
+ keypair *Keypair
+ endpoint Endpoint
+}
+
+func (elem *QueueInboundElement) Drop() {
+ atomic.StoreInt32(&elem.dropped, AtomicTrue)
+}
+
+func (elem *QueueInboundElement) IsDropped() bool {
+ return atomic.LoadInt32(&elem.dropped) == AtomicTrue
+}
+
+func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool {
+ select {
+ case inboundQueue <- element:
+ select {
+ case decryptionQueue <- element:
+ return true
+ default:
+ element.Drop()
+ element.Unlock()
+ return false
+ }
+ default:
+ device.PutInboundElement(element)
+ return false
+ }
+}
+
+func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool {
+ select {
+ case queue <- element:
+ return true
+ default:
+ return false
+ }
+}
+
+/* Called when a new authenticated message has been received
+ *
+ * NOTE: Not thread safe, but called by sequential receiver!
+ */
+func (peer *Peer) keepKeyFreshReceiving() {
+ if peer.timers.sentLastMinuteHandshake.Get() {
+ return
+ }
+ keypair := peer.keypairs.Current()
+ if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
+ peer.timers.sentLastMinuteHandshake.Set(true)
+ peer.SendHandshakeInitiation(false)
+ }
+}
+
+/* Receives incoming datagrams for the device
+ *
+ * Every time the bind is updated a new routine is started for
+ * IPv4 and IPv6 (separately)
+ */
+func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
+
+ logDebug := device.log.Debug
+ defer func() {
+ logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
+ device.net.stopping.Done()
+ }()
+
+ logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started")
+ device.net.starting.Done()
+
+ // receive datagrams until conn is closed
+
+ buffer := device.GetMessageBuffer()
+
+ var (
+ err error
+ size int
+ endpoint Endpoint
+ )
+
+ for {
+
+ // read next datagram
+
+ switch IP {
+ case ipv4.Version:
+ size, endpoint, err = bind.ReceiveIPv4(buffer[:])
+ case ipv6.Version:
+ size, endpoint, err = bind.ReceiveIPv6(buffer[:])
+ default:
+ panic("invalid IP version")
+ }
+
+ if err != nil {
+ device.PutMessageBuffer(buffer)
+ return
+ }
+
+ if size < MinMessageSize {
+ continue
+ }
+
+ // check size of packet
+
+ packet := buffer[:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
+
+ var okay bool
+
+ switch msgType {
+
+ // check if transport
+
+ case MessageTransportType:
+
+ // check size
+
+ if len(packet) < MessageTransportSize {
+ continue
+ }
+
+ // lookup key pair
+
+ receiver := binary.LittleEndian.Uint32(
+ packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
+ )
+ value := device.indexTable.Lookup(receiver)
+ keypair := value.keypair
+ if keypair == nil {
+ continue
+ }
+
+ // check keypair expiry
+
+ if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
+ continue
+ }
+
+ // create work element
+ peer := value.peer
+ elem := device.GetInboundElement()
+ elem.packet = packet
+ elem.buffer = buffer
+ elem.keypair = keypair
+ elem.dropped = AtomicFalse
+ elem.endpoint = endpoint
+ elem.counter = 0
+ elem.Mutex = sync.Mutex{}
+ elem.Lock()
+
+ // add to decryption queues
+
+ if peer.isRunning.Get() {
+ if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
+ buffer = device.GetMessageBuffer()
+ }
+ }
+
+ continue
+
+ // otherwise it is a fixed size & handshake related packet
+
+ case MessageInitiationType:
+ okay = len(packet) == MessageInitiationSize
+
+ case MessageResponseType:
+ okay = len(packet) == MessageResponseSize
+
+ case MessageCookieReplyType:
+ okay = len(packet) == MessageCookieReplySize
+
+ default:
+ logDebug.Println("Received message with unknown type")
+ }
+
+ if okay {
+ if (device.addToHandshakeQueue(
+ device.queue.handshake,
+ QueueHandshakeElement{
+ msgType: msgType,
+ buffer: buffer,
+ packet: packet,
+ endpoint: endpoint,
+ },
+ )) {
+ buffer = device.GetMessageBuffer()
+ }
+ }
+ }
+}
+
+func (device *Device) RoutineDecryption() {
+
+ var nonce [chacha20poly1305.NonceSize]byte
+
+ logDebug := device.log.Debug
+ defer func() {
+ logDebug.Println("Routine: decryption worker - stopped")
+ device.state.stopping.Done()
+ }()
+ logDebug.Println("Routine: decryption worker - started")
+ device.state.starting.Done()
+
+ for {
+ select {
+ case <-device.signals.stop:
+ return
+
+ case elem, ok := <-device.queue.decryption:
+
+ if !ok {
+ return
+ }
+
+ // check if dropped
+
+ if elem.IsDropped() {
+ continue
+ }
+
+ // split message into fields
+
+ counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
+ content := elem.packet[MessageTransportOffsetContent:]
+
+ // expand nonce
+
+ nonce[0x4] = counter[0x0]
+ nonce[0x5] = counter[0x1]
+ nonce[0x6] = counter[0x2]
+ nonce[0x7] = counter[0x3]
+
+ nonce[0x8] = counter[0x4]
+ nonce[0x9] = counter[0x5]
+ nonce[0xa] = counter[0x6]
+ nonce[0xb] = counter[0x7]
+
+ // decrypt and release to consumer
+
+ var err error
+ elem.counter = binary.LittleEndian.Uint64(counter)
+ elem.packet, err = elem.keypair.receive.Open(
+ content[:0],
+ nonce[:],
+ content,
+ nil,
+ )
+ if err != nil {
+ elem.Drop()
+ device.PutMessageBuffer(elem.buffer)
+ }
+ elem.Unlock()
+ }
+ }
+}
+
+/* Handles incoming packets related to handshake
+ */
+func (device *Device) RoutineHandshake() {
+
+ logInfo := device.log.Info
+ logError := device.log.Error
+ logDebug := device.log.Debug
+
+ var elem QueueHandshakeElement
+ var ok bool
+
+ defer func() {
+ logDebug.Println("Routine: handshake worker - stopped")
+ device.state.stopping.Done()
+ if elem.buffer != nil {
+ device.PutMessageBuffer(elem.buffer)
+ }
+ }()
+
+ logDebug.Println("Routine: handshake worker - started")
+ device.state.starting.Done()
+
+ for {
+ if elem.buffer != nil {
+ device.PutMessageBuffer(elem.buffer)
+ elem.buffer = nil
+ }
+
+ select {
+ case elem, ok = <-device.queue.handshake:
+ case <-device.signals.stop:
+ return
+ }
+
+ if !ok {
+ return
+ }
+
+ // handle cookie fields and ratelimiting
+
+ switch elem.msgType {
+
+ case MessageCookieReplyType:
+
+ // unmarshal packet
+
+ var reply MessageCookieReply
+ reader := bytes.NewReader(elem.packet)
+ err := binary.Read(reader, binary.LittleEndian, &reply)
+ if err != nil {
+ logDebug.Println("Failed to decode cookie reply")
+ return
+ }
+
+ // lookup peer from index
+
+ entry := device.indexTable.Lookup(reply.Receiver)
+
+ if entry.peer == nil {
+ continue
+ }
+
+ // consume reply
+
+ if peer := entry.peer; peer.isRunning.Get() {
+ logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString())
+ if !peer.cookieGenerator.ConsumeReply(&reply) {
+ logDebug.Println("Could not decrypt invalid cookie response")
+ }
+ }
+
+ continue
+
+ case MessageInitiationType, MessageResponseType:
+
+ // check mac fields and maybe ratelimit
+
+ if !device.cookieChecker.CheckMAC1(elem.packet) {
+ logDebug.Println("Received packet with invalid mac1")
+ continue
+ }
+
+ // endpoints destination address is the source of the datagram
+
+ if device.IsUnderLoad() {
+
+ // verify MAC2 field
+
+ if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
+ device.SendHandshakeCookie(&elem)
+ continue
+ }
+
+ // check ratelimiter
+
+ if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
+ continue
+ }
+ }
+
+ default:
+ logError.Println("Invalid packet ended up in the handshake queue")
+ continue
+ }
+
+ // handle handshake initiation/response content
+
+ switch elem.msgType {
+ case MessageInitiationType:
+
+ // unmarshal
+
+ var msg MessageInitiation
+ reader := bytes.NewReader(elem.packet)
+ err := binary.Read(reader, binary.LittleEndian, &msg)
+ if err != nil {
+ logError.Println("Failed to decode initiation message")
+ continue
+ }
+
+ // consume initiation
+
+ peer := device.ConsumeMessageInitiation(&msg)
+ if peer == nil {
+ logInfo.Println(
+ "Received invalid initiation message from",
+ elem.endpoint.DstToString(),
+ )
+ continue
+ }
+
+ // update timers
+
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
+
+ // update endpoint
+ peer.SetEndpointFromPacket(elem.endpoint)
+
+ logDebug.Println(peer, "- Received handshake initiation")
+
+ peer.SendHandshakeResponse()
+
+ case MessageResponseType:
+
+ // unmarshal
+
+ var msg MessageResponse
+ reader := bytes.NewReader(elem.packet)
+ err := binary.Read(reader, binary.LittleEndian, &msg)
+ if err != nil {
+ logError.Println("Failed to decode response message")
+ continue
+ }
+
+ // consume response
+
+ peer := device.ConsumeMessageResponse(&msg)
+ if peer == nil {
+ logInfo.Println(
+ "Received invalid response message from",
+ elem.endpoint.DstToString(),
+ )
+ continue
+ }
+
+ // update endpoint
+ peer.SetEndpointFromPacket(elem.endpoint)
+
+ logDebug.Println(peer, "- Received handshake response")
+
+ // update timers
+
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
+
+ // derive keypair
+
+ err = peer.BeginSymmetricSession()
+
+ if err != nil {
+ logError.Println(peer, "- Failed to derive keypair:", err)
+ continue
+ }
+
+ peer.timersSessionDerived()
+ peer.timersHandshakeComplete()
+ peer.SendKeepalive()
+ select {
+ case peer.signals.newKeypairArrived <- struct{}{}:
+ default:
+ }
+ }
+ }
+}
+
+func (peer *Peer) RoutineSequentialReceiver() {
+
+ device := peer.device
+ logInfo := device.log.Info
+ logError := device.log.Error
+ logDebug := device.log.Debug
+
+ var elem *QueueInboundElement
+ var ok bool
+
+ defer func() {
+ logDebug.Println(peer, "- Routine: sequential receiver - stopped")
+ peer.routines.stopping.Done()
+ if elem != nil {
+ if !elem.IsDropped() {
+ device.PutMessageBuffer(elem.buffer)
+ }
+ device.PutInboundElement(elem)
+ }
+ }()
+
+ logDebug.Println(peer, "- Routine: sequential receiver - started")
+
+ peer.routines.starting.Done()
+
+ for {
+ if elem != nil {
+ if !elem.IsDropped() {
+ device.PutMessageBuffer(elem.buffer)
+ }
+ device.PutInboundElement(elem)
+ elem = nil
+ }
+
+ select {
+
+ case <-peer.routines.stop:
+ return
+
+ case elem, ok = <-peer.queue.inbound:
+
+ if !ok {
+ return
+ }
+
+ // wait for decryption
+
+ elem.Lock()
+
+ if elem.IsDropped() {
+ continue
+ }
+
+ // check for replay
+
+ if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
+ continue
+ }
+
+ // update endpoint
+ peer.SetEndpointFromPacket(elem.endpoint)
+
+ // check if using new keypair
+ if peer.ReceivedWithKeypair(elem.keypair) {
+ peer.timersHandshakeComplete()
+ select {
+ case peer.signals.newKeypairArrived <- struct{}{}:
+ default:
+ }
+ }
+
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
+
+ // check for keepalive
+
+ if len(elem.packet) == 0 {
+ logDebug.Println(peer, "- Receiving keepalive packet")
+ continue
+ }
+ peer.timersDataReceived()
+
+ // verify source and strip padding
+
+ switch elem.packet[0] >> 4 {
+ case ipv4.Version:
+
+ // strip padding
+
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
+
+ elem.packet = elem.packet[:length]
+
+ // verify IPv4 source
+
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.allowedips.LookupIPv4(src) != peer {
+ logInfo.Println(
+ "IPv4 packet with disallowed source address from",
+ peer,
+ )
+ continue
+ }
+
+ case ipv6.Version:
+
+ // strip padding
+
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
+ }
+
+ elem.packet = elem.packet[:length]
+
+ // verify IPv6 source
+
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.allowedips.LookupIPv6(src) != peer {
+ logInfo.Println(
+ peer,
+ "sent packet with disallowed IPv6 source",
+ )
+ continue
+ }
+
+ default:
+ logInfo.Println("Packet with invalid IP version from", peer)
+ continue
+ }
+
+ // write to tun device
+
+ offset := MessageTransportOffsetContent
+ atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
+ if err != nil {
+ logError.Println("Failed to write packet to TUN device:", err)
+ }
+ }
+ }
+}
diff --git a/device/send.go b/device/send.go
new file mode 100644
index 0000000..b4e23c7
--- /dev/null
+++ b/device/send.go
@@ -0,0 +1,618 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "bytes"
+ "encoding/binary"
+ "golang.org/x/crypto/chacha20poly1305"
+ "golang.org/x/net/ipv4"
+ "golang.org/x/net/ipv6"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+/* Outbound flow
+ *
+ * 1. TUN queue
+ * 2. Routing (sequential)
+ * 3. Nonce assignment (sequential)
+ * 4. Encryption (parallel)
+ * 5. Transmission (sequential)
+ *
+ * The functions in this file occur (roughly) in the order in
+ * which the packets are processed.
+ *
+ * Locking, Producers and Consumers
+ *
+ * The order of packets (per peer) must be maintained,
+ * but encryption of packets happen out-of-order:
+ *
+ * The sequential consumers will attempt to take the lock,
+ * workers release lock when they have completed work (encryption) on the packet.
+ *
+ * If the element is inserted into the "encryption queue",
+ * the content is preceded by enough "junk" to contain the transport header
+ * (to allow the construction of transport messages in-place)
+ */
+
+type QueueOutboundElement struct {
+ dropped int32
+ sync.Mutex
+ buffer *[MaxMessageSize]byte // slice holding the packet data
+ packet []byte // slice of "buffer" (always!)
+ nonce uint64 // nonce for encryption
+ keypair *Keypair // keypair for encryption
+ peer *Peer // related peer
+}
+
+func (device *Device) NewOutboundElement() *QueueOutboundElement {
+ elem := device.GetOutboundElement()
+ elem.dropped = AtomicFalse
+ elem.buffer = device.GetMessageBuffer()
+ elem.Mutex = sync.Mutex{}
+ elem.nonce = 0
+ elem.keypair = nil
+ elem.peer = nil
+ return elem
+}
+
+func (elem *QueueOutboundElement) Drop() {
+ atomic.StoreInt32(&elem.dropped, AtomicTrue)
+}
+
+func (elem *QueueOutboundElement) IsDropped() bool {
+ return atomic.LoadInt32(&elem.dropped) == AtomicTrue
+}
+
+func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) {
+ for {
+ select {
+ case queue <- element:
+ return
+ default:
+ select {
+ case old := <-queue:
+ device.PutMessageBuffer(old.buffer)
+ device.PutOutboundElement(old)
+ default:
+ }
+ }
+ }
+}
+
+func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) {
+ select {
+ case outboundQueue <- element:
+ select {
+ case encryptionQueue <- element:
+ return
+ default:
+ element.Drop()
+ element.peer.device.PutMessageBuffer(element.buffer)
+ element.Unlock()
+ }
+ default:
+ element.peer.device.PutMessageBuffer(element.buffer)
+ element.peer.device.PutOutboundElement(element)
+ }
+}
+
+/* Queues a keepalive if no packets are queued for peer
+ */
+func (peer *Peer) SendKeepalive() bool {
+ if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
+ return false
+ }
+ elem := peer.device.NewOutboundElement()
+ elem.packet = nil
+ select {
+ case peer.queue.nonce <- elem:
+ peer.device.log.Debug.Println(peer, "- Sending keepalive packet")
+ return true
+ default:
+ peer.device.PutMessageBuffer(elem.buffer)
+ peer.device.PutOutboundElement(elem)
+ return false
+ }
+}
+
+func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
+ if !isRetry {
+ atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
+ }
+
+ peer.handshake.mutex.RLock()
+ if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout {
+ peer.handshake.mutex.RUnlock()
+ return nil
+ }
+ peer.handshake.mutex.RUnlock()
+
+ peer.handshake.mutex.Lock()
+ if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout {
+ peer.handshake.mutex.Unlock()
+ return nil
+ }
+ peer.handshake.lastSentHandshake = time.Now()
+ peer.handshake.mutex.Unlock()
+
+ peer.device.log.Debug.Println(peer, "- Sending handshake initiation")
+
+ msg, err := peer.device.CreateMessageInitiation(peer)
+ if err != nil {
+ peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err)
+ return err
+ }
+
+ var buff [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buff[:0])
+ binary.Write(writer, binary.LittleEndian, msg)
+ packet := writer.Bytes()
+ peer.cookieGenerator.AddMacs(packet)
+
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketSent()
+
+ err = peer.SendBuffer(packet)
+ if err != nil {
+ peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err)
+ }
+ peer.timersHandshakeInitiated()
+
+ return err
+}
+
+func (peer *Peer) SendHandshakeResponse() error {
+ peer.handshake.mutex.Lock()
+ peer.handshake.lastSentHandshake = time.Now()
+ peer.handshake.mutex.Unlock()
+
+ peer.device.log.Debug.Println(peer, "- Sending handshake response")
+
+ response, err := peer.device.CreateMessageResponse(peer)
+ if err != nil {
+ peer.device.log.Error.Println(peer, "- Failed to create response message:", err)
+ return err
+ }
+
+ var buff [MessageResponseSize]byte
+ writer := bytes.NewBuffer(buff[:0])
+ binary.Write(writer, binary.LittleEndian, response)
+ packet := writer.Bytes()
+ peer.cookieGenerator.AddMacs(packet)
+
+ err = peer.BeginSymmetricSession()
+ if err != nil {
+ peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err)
+ return err
+ }
+
+ peer.timersSessionDerived()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketSent()
+
+ err = peer.SendBuffer(packet)
+ if err != nil {
+ peer.device.log.Error.Println(peer, "- Failed to send handshake response", err)
+ }
+ return err
+}
+
+func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
+
+ device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString())
+
+ sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
+ reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
+ if err != nil {
+ device.log.Error.Println("Failed to create cookie reply:", err)
+ return err
+ }
+
+ var buff [MessageCookieReplySize]byte
+ writer := bytes.NewBuffer(buff[:0])
+ binary.Write(writer, binary.LittleEndian, reply)
+ device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
+ if err != nil {
+ device.log.Error.Println("Failed to send cookie reply:", err)
+ }
+ return err
+}
+
+func (peer *Peer) keepKeyFreshSending() {
+ keypair := peer.keypairs.Current()
+ if keypair == nil {
+ return
+ }
+ nonce := atomic.LoadUint64(&keypair.sendNonce)
+ if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Now().Sub(keypair.created) > RekeyAfterTime) {
+ peer.SendHandshakeInitiation(false)
+ }
+}
+
+/* Reads packets from the TUN and inserts
+ * into nonce queue for peer
+ *
+ * Obs. Single instance per TUN device
+ */
+func (device *Device) RoutineReadFromTUN() {
+
+ logDebug := device.log.Debug
+ logError := device.log.Error
+
+ defer func() {
+ logDebug.Println("Routine: TUN reader - stopped")
+ device.state.stopping.Done()
+ }()
+
+ logDebug.Println("Routine: TUN reader - started")
+ device.state.starting.Done()
+
+ var elem *QueueOutboundElement
+
+ for {
+ if elem != nil {
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ }
+ elem = device.NewOutboundElement()
+
+ // read packet
+
+ offset := MessageTransportHeaderSize
+ size, err := device.tun.device.Read(elem.buffer[:], offset)
+
+ if err != nil {
+ if !device.isClosed.Get() {
+ logError.Println("Failed to read packet from TUN device:", err)
+ device.Close()
+ }
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ return
+ }
+
+ if size == 0 || size > MaxContentSize {
+ continue
+ }
+
+ elem.packet = elem.buffer[offset : offset+size]
+
+ // lookup peer
+
+ var peer *Peer
+ switch elem.packet[0] >> 4 {
+ case ipv4.Version:
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer = device.allowedips.LookupIPv4(dst)
+
+ case ipv6.Version:
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
+ dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer = device.allowedips.LookupIPv6(dst)
+
+ default:
+ logDebug.Println("Received packet with unknown IP version")
+ }
+
+ if peer == nil {
+ continue
+ }
+
+ // insert into nonce/pre-handshake queue
+
+ if peer.isRunning.Get() {
+ if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
+ peer.SendHandshakeInitiation(false)
+ }
+ addToNonceQueue(peer.queue.nonce, elem, device)
+ elem = nil
+ }
+ }
+}
+
+func (peer *Peer) FlushNonceQueue() {
+ select {
+ case peer.signals.flushNonceQueue <- struct{}{}:
+ default:
+ }
+}
+
+/* Queues packets when there is no handshake.
+ * Then assigns nonces to packets sequentially
+ * and creates "work" structs for workers
+ *
+ * Obs. A single instance per peer
+ */
+func (peer *Peer) RoutineNonce() {
+ var keypair *Keypair
+
+ device := peer.device
+ logDebug := device.log.Debug
+
+ flush := func() {
+ for {
+ select {
+ case elem := <-peer.queue.nonce:
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ default:
+ return
+ }
+ }
+ }
+
+ defer func() {
+ flush()
+ logDebug.Println(peer, "- Routine: nonce worker - stopped")
+ peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
+ peer.routines.stopping.Done()
+ }()
+
+ peer.routines.starting.Done()
+ logDebug.Println(peer, "- Routine: nonce worker - started")
+
+ for {
+ NextPacket:
+ peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
+
+ select {
+ case <-peer.routines.stop:
+ return
+
+ case <-peer.signals.flushNonceQueue:
+ flush()
+ goto NextPacket
+
+ case elem, ok := <-peer.queue.nonce:
+
+ if !ok {
+ return
+ }
+
+ // make sure to always pick the newest key
+
+ for {
+
+ // check validity of newest key pair
+
+ keypair = peer.keypairs.Current()
+ if keypair != nil && keypair.sendNonce < RejectAfterMessages {
+ if time.Now().Sub(keypair.created) < RejectAfterTime {
+ break
+ }
+ }
+ peer.queue.packetInNonceQueueIsAwaitingKey.Set(true)
+
+ // no suitable key pair, request for new handshake
+
+ select {
+ case <-peer.signals.newKeypairArrived:
+ default:
+ }
+
+ peer.SendHandshakeInitiation(false)
+
+ // wait for key to be established
+
+ logDebug.Println(peer, "- Awaiting keypair")
+
+ select {
+ case <-peer.signals.newKeypairArrived:
+ logDebug.Println(peer, "- Obtained awaited keypair")
+
+ case <-peer.signals.flushNonceQueue:
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ flush()
+ goto NextPacket
+
+ case <-peer.routines.stop:
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ return
+ }
+ }
+ peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
+
+ // populate work element
+
+ elem.peer = peer
+ elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
+
+ // double check in case of race condition added by future code
+
+ if elem.nonce >= RejectAfterMessages {
+ atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ goto NextPacket
+ }
+
+ elem.keypair = keypair
+ elem.dropped = AtomicFalse
+ elem.Lock()
+
+ // add to parallel and sequential queue
+ addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
+ }
+ }
+}
+
+/* Encrypts the elements in the queue
+ * and marks them for sequential consumption (by releasing the mutex)
+ *
+ * Obs. One instance per core
+ */
+func (device *Device) RoutineEncryption() {
+
+ var nonce [chacha20poly1305.NonceSize]byte
+
+ logDebug := device.log.Debug
+
+ defer func() {
+ for {
+ select {
+ case elem, ok := <-device.queue.encryption:
+ if ok && !elem.IsDropped() {
+ elem.Drop()
+ device.PutMessageBuffer(elem.buffer)
+ elem.Unlock()
+ }
+ default:
+ goto out
+ }
+ }
+ out:
+ logDebug.Println("Routine: encryption worker - stopped")
+ device.state.stopping.Done()
+ }()
+
+ logDebug.Println("Routine: encryption worker - started")
+ device.state.starting.Done()
+
+ for {
+
+ // fetch next element
+
+ select {
+ case <-device.signals.stop:
+ return
+
+ case elem, ok := <-device.queue.encryption:
+
+ if !ok {
+ return
+ }
+
+ // check if dropped
+
+ if elem.IsDropped() {
+ continue
+ }
+
+ // populate header fields
+
+ header := elem.buffer[:MessageTransportHeaderSize]
+
+ fieldType := header[0:4]
+ fieldReceiver := header[4:8]
+ fieldNonce := header[8:16]
+
+ binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
+ binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
+ binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
+
+ // pad content to multiple of 16
+
+ mtu := int(atomic.LoadInt32(&device.tun.mtu))
+ lastUnit := len(elem.packet) % mtu
+ paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
+ if paddedSize > mtu {
+ paddedSize = mtu
+ }
+ for i := len(elem.packet); i < paddedSize; i++ {
+ elem.packet = append(elem.packet, 0)
+ }
+
+ // encrypt content and release to consumer
+
+ binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
+ elem.packet = elem.keypair.send.Seal(
+ header,
+ nonce[:],
+ elem.packet,
+ nil,
+ )
+ elem.Unlock()
+ }
+ }
+}
+
+/* Sequentially reads packets from queue and sends to endpoint
+ *
+ * Obs. Single instance per peer.
+ * The routine terminates then the outbound queue is closed.
+ */
+func (peer *Peer) RoutineSequentialSender() {
+
+ device := peer.device
+
+ logDebug := device.log.Debug
+ logError := device.log.Error
+
+ defer func() {
+ for {
+ select {
+ case elem, ok := <-peer.queue.outbound:
+ if ok {
+ if !elem.IsDropped() {
+ device.PutMessageBuffer(elem.buffer)
+ elem.Drop()
+ }
+ device.PutOutboundElement(elem)
+ }
+ default:
+ goto out
+ }
+ }
+ out:
+ logDebug.Println(peer, "- Routine: sequential sender - stopped")
+ peer.routines.stopping.Done()
+ }()
+
+ logDebug.Println(peer, "- Routine: sequential sender - started")
+
+ peer.routines.starting.Done()
+
+ for {
+ select {
+
+ case <-peer.routines.stop:
+ return
+
+ case elem, ok := <-peer.queue.outbound:
+
+ if !ok {
+ return
+ }
+
+ elem.Lock()
+ if elem.IsDropped() {
+ device.PutOutboundElement(elem)
+ continue
+ }
+
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketSent()
+
+ // send message and return buffer to pool
+
+ length := uint64(len(elem.packet))
+ err := peer.SendBuffer(elem.packet)
+ device.PutMessageBuffer(elem.buffer)
+ device.PutOutboundElement(elem)
+ if err != nil {
+ logError.Println(peer, "- Failed to send data packet", err)
+ continue
+ }
+ atomic.AddUint64(&peer.stats.txBytes, length)
+
+ if len(elem.packet) != MessageKeepaliveSize {
+ peer.timersDataSent()
+ }
+ peer.keepKeyFreshSending()
+ }
+ }
+}
diff --git a/device/timers.go b/device/timers.go
new file mode 100644
index 0000000..5f28fcc
--- /dev/null
+++ b/device/timers.go
@@ -0,0 +1,227 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ *
+ * This is based heavily on timers.c from the kernel implementation.
+ */
+
+package device
+
+import (
+ "math/rand"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+/* This Timer structure and related functions should roughly copy the interface of
+ * the Linux kernel's struct timer_list.
+ */
+
+type Timer struct {
+ *time.Timer
+ modifyingLock sync.RWMutex
+ runningLock sync.Mutex
+ isPending bool
+}
+
+func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
+ timer := &Timer{}
+ timer.Timer = time.AfterFunc(time.Hour, func() {
+ timer.runningLock.Lock()
+
+ timer.modifyingLock.Lock()
+ if !timer.isPending {
+ timer.modifyingLock.Unlock()
+ timer.runningLock.Unlock()
+ return
+ }
+ timer.isPending = false
+ timer.modifyingLock.Unlock()
+
+ expirationFunction(peer)
+ timer.runningLock.Unlock()
+ })
+ timer.Stop()
+ return timer
+}
+
+func (timer *Timer) Mod(d time.Duration) {
+ timer.modifyingLock.Lock()
+ timer.isPending = true
+ timer.Reset(d)
+ timer.modifyingLock.Unlock()
+}
+
+func (timer *Timer) Del() {
+ timer.modifyingLock.Lock()
+ timer.isPending = false
+ timer.Stop()
+ timer.modifyingLock.Unlock()
+}
+
+func (timer *Timer) DelSync() {
+ timer.Del()
+ timer.runningLock.Lock()
+ timer.Del()
+ timer.runningLock.Unlock()
+}
+
+func (timer *Timer) IsPending() bool {
+ timer.modifyingLock.RLock()
+ defer timer.modifyingLock.RUnlock()
+ return timer.isPending
+}
+
+func (peer *Peer) timersActive() bool {
+ return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0
+}
+
+func expiredRetransmitHandshake(peer *Peer) {
+ if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
+ peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2)
+
+ if peer.timersActive() {
+ peer.timers.sendKeepalive.Del()
+ }
+
+ /* We drop all packets without a keypair and don't try again,
+ * if we try unsuccessfully for too long to make a handshake.
+ */
+ peer.FlushNonceQueue()
+
+ /* We set a timer for destroying any residue that might be left
+ * of a partial exchange.
+ */
+ if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() {
+ peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
+ }
+ } else {
+ atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
+ peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
+
+ /* We clear the endpoint address src address, in case this is the cause of trouble. */
+ peer.Lock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ peer.Unlock()
+
+ peer.SendHandshakeInitiation(true)
+ }
+}
+
+func expiredSendKeepalive(peer *Peer) {
+ peer.SendKeepalive()
+ if peer.timers.needAnotherKeepalive.Get() {
+ peer.timers.needAnotherKeepalive.Set(false)
+ if peer.timersActive() {
+ peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
+ }
+ }
+}
+
+func expiredNewHandshake(peer *Peer) {
+ peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
+ /* We clear the endpoint address src address, in case this is the cause of trouble. */
+ peer.Lock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ peer.Unlock()
+ peer.SendHandshakeInitiation(false)
+
+}
+
+func expiredZeroKeyMaterial(peer *Peer) {
+ peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
+ peer.ZeroAndFlushAll()
+}
+
+func expiredPersistentKeepalive(peer *Peer) {
+ if peer.persistentKeepaliveInterval > 0 {
+ peer.SendKeepalive()
+ }
+}
+
+/* Should be called after an authenticated data packet is sent. */
+func (peer *Peer) timersDataSent() {
+ if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
+ peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout)
+ }
+}
+
+/* Should be called after an authenticated data packet is received. */
+func (peer *Peer) timersDataReceived() {
+ if peer.timersActive() {
+ if !peer.timers.sendKeepalive.IsPending() {
+ peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
+ } else {
+ peer.timers.needAnotherKeepalive.Set(true)
+ }
+ }
+}
+
+/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */
+func (peer *Peer) timersAnyAuthenticatedPacketSent() {
+ if peer.timersActive() {
+ peer.timers.sendKeepalive.Del()
+ }
+}
+
+/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */
+func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
+ if peer.timersActive() {
+ peer.timers.newHandshake.Del()
+ }
+}
+
+/* Should be called after a handshake initiation message is sent. */
+func (peer *Peer) timersHandshakeInitiated() {
+ if peer.timersActive() {
+ peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+ }
+}
+
+/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
+func (peer *Peer) timersHandshakeComplete() {
+ if peer.timersActive() {
+ peer.timers.retransmitHandshake.Del()
+ }
+ atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
+ peer.timers.sentLastMinuteHandshake.Set(false)
+ atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
+}
+
+/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
+func (peer *Peer) timersSessionDerived() {
+ if peer.timersActive() {
+ peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
+ }
+}
+
+/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
+func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
+ if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
+ peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
+ }
+}
+
+func (peer *Peer) timersInit() {
+ peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake)
+ peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive)
+ peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
+ peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
+ peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
+ atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
+ peer.timers.sentLastMinuteHandshake.Set(false)
+ peer.timers.needAnotherKeepalive.Set(false)
+}
+
+func (peer *Peer) timersStop() {
+ peer.timers.retransmitHandshake.DelSync()
+ peer.timers.sendKeepalive.DelSync()
+ peer.timers.newHandshake.DelSync()
+ peer.timers.zeroKeyMaterial.DelSync()
+ peer.timers.persistentKeepalive.DelSync()
+}
diff --git a/device/tun.go b/device/tun.go
new file mode 100644
index 0000000..bc5f1f1
--- /dev/null
+++ b/device/tun.go
@@ -0,0 +1,55 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "golang.zx2c4.com/wireguard/tun"
+ "sync/atomic"
+)
+
+const DefaultMTU = 1420
+
+func (device *Device) RoutineTUNEventReader() {
+ setUp := false
+ logDebug := device.log.Debug
+ logInfo := device.log.Info
+ logError := device.log.Error
+
+ logDebug.Println("Routine: event worker - started")
+ device.state.starting.Done()
+
+ for event := range device.tun.device.Events() {
+ if event&tun.TUNEventMTUUpdate != 0 {
+ mtu, err := device.tun.device.MTU()
+ old := atomic.LoadInt32(&device.tun.mtu)
+ if err != nil {
+ logError.Println("Failed to load updated MTU of device:", err)
+ } else if int(old) != mtu {
+ if mtu+MessageTransportSize > MaxMessageSize {
+ logInfo.Println("MTU updated:", mtu, "(too large)")
+ } else {
+ logInfo.Println("MTU updated:", mtu)
+ }
+ atomic.StoreInt32(&device.tun.mtu, int32(mtu))
+ }
+ }
+
+ if event&tun.TUNEventUp != 0 && !setUp {
+ logInfo.Println("Interface set up")
+ setUp = true
+ device.Up()
+ }
+
+ if event&tun.TUNEventDown != 0 && setUp {
+ logInfo.Println("Interface set down")
+ setUp = false
+ device.Down()
+ }
+ }
+
+ logDebug.Println("Routine: event worker - stopped")
+ device.state.stopping.Done()
+}
diff --git a/device/uapi.go b/device/uapi.go
new file mode 100644
index 0000000..5c65917
--- /dev/null
+++ b/device/uapi.go
@@ -0,0 +1,426 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package device
+
+import (
+ "bufio"
+ "fmt"
+ "golang.zx2c4.com/wireguard/ipc"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "time"
+)
+
+type IPCError struct {
+ int64
+}
+
+func (s *IPCError) Error() string {
+ return fmt.Sprintf("IPC error: %d", s.int64)
+}
+
+func (s *IPCError) ErrorCode() int64 {
+ return s.int64
+}
+
+func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
+
+ device.log.Debug.Println("UAPI: Processing get operation")
+
+ // create lines
+
+ lines := make([]string, 0, 100)
+ send := func(line string) {
+ lines = append(lines, line)
+ }
+
+ func() {
+
+ // lock required resources
+
+ device.net.RLock()
+ defer device.net.RUnlock()
+
+ device.staticIdentity.RLock()
+ defer device.staticIdentity.RUnlock()
+
+ device.peers.RLock()
+ defer device.peers.RUnlock()
+
+ // serialize device related values
+
+ if !device.staticIdentity.privateKey.IsZero() {
+ send("private_key=" + device.staticIdentity.privateKey.ToHex())
+ }
+
+ if device.net.port != 0 {
+ send(fmt.Sprintf("listen_port=%d", device.net.port))
+ }
+
+ if device.net.fwmark != 0 {
+ send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
+ }
+
+ // serialize each peer state
+
+ for _, peer := range device.peers.keyMap {
+ peer.RLock()
+ defer peer.RUnlock()
+
+ send("public_key=" + peer.handshake.remoteStatic.ToHex())
+ send("preshared_key=" + peer.handshake.presharedKey.ToHex())
+ send("protocol_version=1")
+ if peer.endpoint != nil {
+ send("endpoint=" + peer.endpoint.DstToString())
+ }
+
+ nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
+ secs := nano / time.Second.Nanoseconds()
+ nano %= time.Second.Nanoseconds()
+
+ send(fmt.Sprintf("last_handshake_time_sec=%d", secs))
+ send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
+ send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
+ send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
+ send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
+
+ for _, ip := range device.allowedips.EntriesForPeer(peer) {
+ send("allowed_ip=" + ip.String())
+ }
+
+ }
+ }()
+
+ // send lines (does not require resource locks)
+
+ for _, line := range lines {
+ _, err := socket.WriteString(line + "\n")
+ if err != nil {
+ return &IPCError{ipc.IpcErrorIO}
+ }
+ }
+
+ return nil
+}
+
+func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
+ scanner := bufio.NewScanner(socket)
+ logError := device.log.Error
+ logDebug := device.log.Debug
+
+ var peer *Peer
+
+ dummy := false
+ deviceConfig := true
+
+ for scanner.Scan() {
+
+ // parse line
+
+ line := scanner.Text()
+ if line == "" {
+ return nil
+ }
+ parts := strings.Split(line, "=")
+ if len(parts) != 2 {
+ return &IPCError{ipc.IpcErrorProtocol}
+ }
+ key := parts[0]
+ value := parts[1]
+
+ /* device configuration */
+
+ if deviceConfig {
+
+ switch key {
+ case "private_key":
+ var sk NoisePrivateKey
+ err := sk.FromHex(value)
+ if err != nil {
+ logError.Println("Failed to set private_key:", err)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+ logDebug.Println("UAPI: Updating private key")
+ device.SetPrivateKey(sk)
+
+ case "listen_port":
+
+ // parse port number
+
+ port, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ logError.Println("Failed to parse listen_port:", err)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+
+ // update port and rebind
+
+ logDebug.Println("UAPI: Updating listen port")
+
+ device.net.Lock()
+ device.net.port = uint16(port)
+ device.net.Unlock()
+
+ if err := device.BindUpdate(); err != nil {
+ logError.Println("Failed to set listen_port:", err)
+ return &IPCError{ipc.IpcErrorPortInUse}
+ }
+
+ case "fwmark":
+
+ // parse fwmark field
+
+ fwmark, err := func() (uint32, error) {
+ if value == "" {
+ return 0, nil
+ }
+ mark, err := strconv.ParseUint(value, 10, 32)
+ return uint32(mark), err
+ }()
+
+ if err != nil {
+ logError.Println("Invalid fwmark", err)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+
+ logDebug.Println("UAPI: Updating fwmark")
+
+ if err := device.BindSetMark(uint32(fwmark)); err != nil {
+ logError.Println("Failed to update fwmark:", err)
+ return &IPCError{ipc.IpcErrorPortInUse}
+ }
+
+ case "public_key":
+ // switch to peer configuration
+ logDebug.Println("UAPI: Transition to peer configuration")
+ deviceConfig = false
+
+ case "replace_peers":
+ if value != "true" {
+ logError.Println("Failed to set replace_peers, invalid value:", value)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+ logDebug.Println("UAPI: Removing all peers")
+ device.RemoveAllPeers()
+
+ default:
+ logError.Println("Invalid UAPI device key:", key)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+ }
+
+ /* peer configuration */
+
+ if !deviceConfig {
+
+ switch key {
+
+ case "public_key":
+ var publicKey NoisePublicKey
+ err := publicKey.FromHex(value)
+ if err != nil {
+ logError.Println("Failed to get peer by public key:", err)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+
+ // ignore peer with public key of device
+
+ device.staticIdentity.RLock()
+ dummy = device.staticIdentity.publicKey.Equals(publicKey)
+ device.staticIdentity.RUnlock()
+
+ if dummy {
+ peer = &Peer{}
+ } else {
+ peer = device.LookupPeer(publicKey)
+ }
+
+ if peer == nil {
+ peer, err = device.NewPeer(publicKey)
+ if err != nil {
+ logError.Println("Failed to create new peer:", err)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+ logDebug.Println(peer, "- UAPI: Created")
+ }
+
+ case "remove":
+
+ // remove currently selected peer from device
+
+ if value != "true" {
+ logError.Println("Failed to set remove, invalid value:", value)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+ if !dummy {
+ logDebug.Println(peer, "- UAPI: Removing")
+ device.RemovePeer(peer.handshake.remoteStatic)
+ }
+ peer = &Peer{}
+ dummy = true
+
+ case "preshared_key":
+
+ // update PSK
+
+ logDebug.Println(peer, "- UAPI: Updating preshared key")
+
+ peer.handshake.mutex.Lock()
+ err := peer.handshake.presharedKey.FromHex(value)
+ peer.handshake.mutex.Unlock()
+
+ if err != nil {
+ logError.Println("Failed to set preshared key:", err)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+
+ case "endpoint":
+
+ // set endpoint destination
+
+ logDebug.Println(peer, "- UAPI: Updating endpoint")
+
+ err := func() error {
+ peer.Lock()
+ defer peer.Unlock()
+ endpoint, err := CreateEndpoint(value)
+ if err != nil {
+ return err
+ }
+ peer.endpoint = endpoint
+ return nil
+ }()
+
+ if err != nil {
+ logError.Println("Failed to set endpoint:", value)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+
+ case "persistent_keepalive_interval":
+
+ // update persistent keepalive interval
+
+ logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval")
+
+ secs, err := strconv.ParseUint(value, 10, 16)
+ if err != nil {
+ logError.Println("Failed to set persistent keepalive interval:", err)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+
+ old := peer.persistentKeepaliveInterval
+ peer.persistentKeepaliveInterval = uint16(secs)
+
+ // send immediate keepalive if we're turning it on and before it wasn't on
+
+ if old == 0 && secs != 0 {
+ if err != nil {
+ logError.Println("Failed to get tun device status:", err)
+ return &IPCError{ipc.IpcErrorIO}
+ }
+ if device.isUp.Get() && !dummy {
+ peer.SendKeepalive()
+ }
+ }
+
+ case "replace_allowed_ips":
+
+ logDebug.Println(peer, "- UAPI: Removing all allowedips")
+
+ if value != "true" {
+ logError.Println("Failed to replace allowedips, invalid value:", value)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+
+ if dummy {
+ continue
+ }
+
+ device.allowedips.RemoveByPeer(peer)
+
+ case "allowed_ip":
+
+ logDebug.Println(peer, "- UAPI: Adding allowedip")
+
+ _, network, err := net.ParseCIDR(value)
+ if err != nil {
+ logError.Println("Failed to set allowed ip:", err)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+
+ if dummy {
+ continue
+ }
+
+ ones, _ := network.Mask.Size()
+ device.allowedips.Insert(network.IP, uint(ones), peer)
+
+ case "protocol_version":
+
+ if value != "1" {
+ logError.Println("Invalid protocol version:", value)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+
+ default:
+ logError.Println("Invalid UAPI peer key:", key)
+ return &IPCError{ipc.IpcErrorInvalid}
+ }
+ }
+ }
+
+ return nil
+}
+
+func (device *Device) IpcHandle(socket net.Conn) {
+
+ // create buffered read/writer
+
+ defer socket.Close()
+
+ buffered := func(s io.ReadWriter) *bufio.ReadWriter {
+ reader := bufio.NewReader(s)
+ writer := bufio.NewWriter(s)
+ return bufio.NewReadWriter(reader, writer)
+ }(socket)
+
+ defer buffered.Flush()
+
+ op, err := buffered.ReadString('\n')
+ if err != nil {
+ return
+ }
+
+ // handle operation
+
+ var status *IPCError
+
+ switch op {
+ case "set=1\n":
+ device.log.Debug.Println("UAPI: Set operation")
+ status = device.IpcSetOperation(buffered.Reader)
+
+ case "get=1\n":
+ device.log.Debug.Println("UAPI: Get operation")
+ status = device.IpcGetOperation(buffered.Writer)
+
+ default:
+ device.log.Error.Println("Invalid UAPI operation:", op)
+ return
+ }
+
+ // write status
+
+ if status != nil {
+ device.log.Error.Println(status)
+ fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
+ } else {
+ fmt.Fprintf(buffered, "errno=0\n\n")
+ }
+}
diff --git a/device/version.go b/device/version.go
new file mode 100644
index 0000000..9077cdc
--- /dev/null
+++ b/device/version.go
@@ -0,0 +1,3 @@
+package device
+
+const WireGuardGoVersion = "0.0.20181222"