From 69f0fe67b63d90e523a5a1241fb1b46c2e8dbe03 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sun, 3 Mar 2019 04:04:41 +0100 Subject: global: begin modularization --- device/allowedips.go | 251 ++++++++++++++ device/allowedips_rand_test.go | 131 ++++++++ device/allowedips_test.go | 260 ++++++++++++++ device/bind_test.go | 55 +++ device/conn.go | 180 ++++++++++ device/conn_default.go | 170 ++++++++++ device/conn_linux.go | 746 +++++++++++++++++++++++++++++++++++++++++ device/constants.go | 41 +++ device/cookie.go | 250 ++++++++++++++ device/cookie_test.go | 191 +++++++++++ device/device.go | 396 ++++++++++++++++++++++ device/device_test.go | 48 +++ device/endpoint_test.go | 53 +++ device/indextable.go | 97 ++++++ device/ip.go | 22 ++ device/kdf_test.go | 84 +++++ device/keypair.go | 50 +++ device/logger.go | 59 ++++ device/mark_default.go | 12 + device/mark_unix.go | 64 ++++ device/misc.go | 48 +++ device/noise-helpers.go | 104 ++++++ device/noise-protocol.go | 600 +++++++++++++++++++++++++++++++++ device/noise-types.go | 81 +++++ device/noise_test.go | 144 ++++++++ device/peer.go | 270 +++++++++++++++ device/pools.go | 89 +++++ device/queueconstants.go | 16 + device/receive.go | 641 +++++++++++++++++++++++++++++++++++ device/send.go | 618 ++++++++++++++++++++++++++++++++++ device/timers.go | 227 +++++++++++++ device/tun.go | 55 +++ device/uapi.go | 426 +++++++++++++++++++++++ device/version.go | 3 + 34 files changed, 6482 insertions(+) create mode 100644 device/allowedips.go create mode 100644 device/allowedips_rand_test.go create mode 100644 device/allowedips_test.go create mode 100644 device/bind_test.go create mode 100644 device/conn.go create mode 100644 device/conn_default.go create mode 100644 device/conn_linux.go create mode 100644 device/constants.go create mode 100644 device/cookie.go create mode 100644 device/cookie_test.go create mode 100644 device/device.go create mode 100644 device/device_test.go create mode 100644 device/endpoint_test.go create mode 100644 device/indextable.go create mode 100644 device/ip.go create mode 100644 device/kdf_test.go create mode 100644 device/keypair.go create mode 100644 device/logger.go create mode 100644 device/mark_default.go create mode 100644 device/mark_unix.go create mode 100644 device/misc.go create mode 100644 device/noise-helpers.go create mode 100644 device/noise-protocol.go create mode 100644 device/noise-types.go create mode 100644 device/noise_test.go create mode 100644 device/peer.go create mode 100644 device/pools.go create mode 100644 device/queueconstants.go create mode 100644 device/receive.go create mode 100644 device/send.go create mode 100644 device/timers.go create mode 100644 device/tun.go create mode 100644 device/uapi.go create mode 100644 device/version.go (limited to 'device') diff --git a/device/allowedips.go b/device/allowedips.go new file mode 100644 index 0000000..efc27c0 --- /dev/null +++ b/device/allowedips.go @@ -0,0 +1,251 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + "math/bits" + "net" + "sync" + "unsafe" +) + +type trieEntry struct { + cidr uint + child [2]*trieEntry + bits net.IP + peer *Peer + + // index of "branching" bit + + bit_at_byte uint + bit_at_shift uint +} + +func isLittleEndian() bool { + one := uint32(1) + return *(*byte)(unsafe.Pointer(&one)) != 0 +} + +func swapU32(i uint32) uint32 { + if !isLittleEndian() { + return i + } + + return bits.ReverseBytes32(i) +} + +func swapU64(i uint64) uint64 { + if !isLittleEndian() { + return i + } + + return bits.ReverseBytes64(i) +} + +func commonBits(ip1 net.IP, ip2 net.IP) uint { + size := len(ip1) + if size == net.IPv4len { + a := (*uint32)(unsafe.Pointer(&ip1[0])) + b := (*uint32)(unsafe.Pointer(&ip2[0])) + x := *a ^ *b + return uint(bits.LeadingZeros32(swapU32(x))) + } else if size == net.IPv6len { + a := (*uint64)(unsafe.Pointer(&ip1[0])) + b := (*uint64)(unsafe.Pointer(&ip2[0])) + x := *a ^ *b + if x != 0 { + return uint(bits.LeadingZeros64(swapU64(x))) + } + a = (*uint64)(unsafe.Pointer(&ip1[8])) + b = (*uint64)(unsafe.Pointer(&ip2[8])) + x = *a ^ *b + return 64 + uint(bits.LeadingZeros64(swapU64(x))) + } else { + panic("Wrong size bit string") + } +} + +func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { + if node == nil { + return node + } + + // walk recursively + + node.child[0] = node.child[0].removeByPeer(p) + node.child[1] = node.child[1].removeByPeer(p) + + if node.peer != p { + return node + } + + // remove peer & merge + + node.peer = nil + if node.child[0] == nil { + return node.child[1] + } + return node.child[0] +} + +func (node *trieEntry) choose(ip net.IP) byte { + return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 +} + +func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { + + // at leaf + + if node == nil { + return &trieEntry{ + bits: ip, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + } + + // traverse deeper + + common := commonBits(node.bits, ip) + if node.cidr <= cidr && common >= node.cidr { + if node.cidr == cidr { + node.peer = peer + return node + } + bit := node.choose(ip) + node.child[bit] = node.child[bit].insert(ip, cidr, peer) + return node + } + + // split node + + newNode := &trieEntry{ + bits: ip, + peer: peer, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + cidr = min(cidr, common) + + // check for shorter prefix + + if newNode.cidr == cidr { + bit := newNode.choose(node.bits) + newNode.child[bit] = node + return newNode + } + + // create new parent for node & newNode + + parent := &trieEntry{ + bits: ip, + peer: nil, + cidr: cidr, + bit_at_byte: cidr / 8, + bit_at_shift: 7 - (cidr % 8), + } + + bit := parent.choose(ip) + parent.child[bit] = newNode + parent.child[bit^1] = node + + return parent +} + +func (node *trieEntry) lookup(ip net.IP) *Peer { + var found *Peer + size := uint(len(ip)) + for node != nil && commonBits(node.bits, ip) >= node.cidr { + if node.peer != nil { + found = node.peer + } + if node.bit_at_byte == size { + break + } + bit := node.choose(ip) + node = node.child[bit] + } + return found +} + +func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { + if node == nil { + return results + } + if node.peer == p { + mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) + results = append(results, net.IPNet{ + Mask: mask, + IP: node.bits.Mask(mask), + }) + } + results = node.child[0].entriesForPeer(p, results) + results = node.child[1].entriesForPeer(p, results) + return results +} + +type AllowedIPs struct { + IPv4 *trieEntry + IPv6 *trieEntry + mutex sync.RWMutex +} + +func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { + table.mutex.RLock() + defer table.mutex.RUnlock() + + allowed := make([]net.IPNet, 0, 10) + allowed = table.IPv4.entriesForPeer(peer, allowed) + allowed = table.IPv6.entriesForPeer(peer, allowed) + return allowed +} + +func (table *AllowedIPs) Reset() { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = nil + table.IPv6 = nil +} + +func (table *AllowedIPs) RemoveByPeer(peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = table.IPv4.removeByPeer(peer) + table.IPv6 = table.IPv6.removeByPeer(peer) +} + +func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + switch len(ip) { + case net.IPv6len: + table.IPv6 = table.IPv6.insert(ip, cidr, peer) + case net.IPv4len: + table.IPv4 = table.IPv4.insert(ip, cidr, peer) + default: + panic(errors.New("inserting unknown address type")) + } +} + +func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv4.lookup(address) +} + +func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv6.lookup(address) +} diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go new file mode 100644 index 0000000..59c10f7 --- /dev/null +++ b/device/allowedips_rand_test.go @@ -0,0 +1,131 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "math/rand" + "sort" + "testing" +) + +const ( + NumberOfPeers = 100 + NumberOfAddresses = 250 + NumberOfTests = 10000 +) + +type SlowNode struct { + peer *Peer + cidr uint + bits []byte +} + +type SlowRouter []*SlowNode + +func (r SlowRouter) Len() int { + return len(r) +} + +func (r SlowRouter) Less(i, j int) bool { + return r[i].cidr > r[j].cidr +} + +func (r SlowRouter) Swap(i, j int) { + r[i], r[j] = r[j], r[i] +} + +func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { + for _, t := range r { + if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { + t.peer = peer + t.bits = addr + return r + } + } + r = append(r, &SlowNode{ + cidr: cidr, + bits: addr, + peer: peer, + }) + sort.Sort(r) + return r +} + +func (r SlowRouter) Lookup(addr []byte) *Peer { + for _, t := range r { + common := commonBits(t.bits, addr) + if common >= t.cidr { + return t.peer + } + } + return nil +} + +func TestTrieRandomIPv4(t *testing.T) { + var trie *trieEntry + var slow SlowRouter + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 4 + + for n := 0; n < NumberOfPeers; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < NumberOfAddresses; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % NumberOfPeers + trie = trie.insert(addr[:], cidr, peers[index]) + slow = slow.Insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < NumberOfTests; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + peer1 := slow.Lookup(addr[:]) + peer2 := trie.lookup(addr[:]) + if peer1 != peer2 { + t.Error("Trie did not match naive implementation, for:", addr) + } + } +} + +func TestTrieRandomIPv6(t *testing.T) { + var trie *trieEntry + var slow SlowRouter + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 16 + + for n := 0; n < NumberOfPeers; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < NumberOfAddresses; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % NumberOfPeers + trie = trie.insert(addr[:], cidr, peers[index]) + slow = slow.Insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < NumberOfTests; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + peer1 := slow.Lookup(addr[:]) + peer2 := trie.lookup(addr[:]) + if peer1 != peer2 { + t.Error("Trie did not match naive implementation, for:", addr) + } + } +} diff --git a/device/allowedips_test.go b/device/allowedips_test.go new file mode 100644 index 0000000..075ff06 --- /dev/null +++ b/device/allowedips_test.go @@ -0,0 +1,260 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "math/rand" + "net" + "testing" +) + +/* Todo: More comprehensive + */ + +type testPairCommonBits struct { + s1 []byte + s2 []byte + match uint +} + +type testPairTrieInsert struct { + key []byte + cidr uint + peer *Peer +} + +type testPairTrieLookup struct { + key []byte + peer *Peer +} + +func printTrie(t *testing.T, p *trieEntry) { + if p == nil { + return + } + t.Log(p) + printTrie(t, p.child[0]) + printTrie(t, p.child[1]) +} + +func TestCommonBits(t *testing.T) { + + tests := []testPairCommonBits{ + {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, + {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, + {s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31}, + {s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15}, + {s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0}, + } + + for _, p := range tests { + v := commonBits(p.s1, p.s2) + if v != p.match { + t.Error( + "For slice", p.s1, p.s2, + "expected match", p.match, + ",but got", v, + ) + } + } +} + +func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { + var trie *trieEntry + var peers []*Peer + + rand.Seed(1) + + const AddressLength = 4 + + for n := 0; n < peerNumber; n += 1 { + peers = append(peers, &Peer{}) + } + + for n := 0; n < addressNumber; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + cidr := uint(rand.Uint32() % (AddressLength * 8)) + index := rand.Int() % peerNumber + trie = trie.insert(addr[:], cidr, peers[index]) + } + + for n := 0; n < b.N; n += 1 { + var addr [AddressLength]byte + rand.Read(addr[:]) + trie.lookup(addr[:]) + } +} + +func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) { + benchmarkTrie(100, 1000, net.IPv4len, b) +} + +func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) { + benchmarkTrie(10, 10, net.IPv4len, b) +} + +func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) { + benchmarkTrie(100, 1000, net.IPv6len, b) +} + +func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) { + benchmarkTrie(10, 10, net.IPv6len, b) +} + +/* Test ported from kernel implementation: + * selftest/allowedips.h + */ +func TestTrieIPv4(t *testing.T) { + a := &Peer{} + b := &Peer{} + c := &Peer{} + d := &Peer{} + e := &Peer{} + g := &Peer{} + h := &Peer{} + + var trie *trieEntry + + insert := func(peer *Peer, a, b, c, d byte, cidr uint) { + trie = trie.insert([]byte{a, b, c, d}, cidr, peer) + } + + assertEQ := func(peer *Peer, a, b, c, d byte) { + p := trie.lookup([]byte{a, b, c, d}) + if p != peer { + t.Error("Assert EQ failed") + } + } + + assertNEQ := func(peer *Peer, a, b, c, d byte) { + p := trie.lookup([]byte{a, b, c, d}) + if p == peer { + t.Error("Assert NEQ failed") + } + } + + insert(a, 192, 168, 4, 0, 24) + insert(b, 192, 168, 4, 4, 32) + insert(c, 192, 168, 0, 0, 16) + insert(d, 192, 95, 5, 64, 27) + insert(c, 192, 95, 5, 65, 27) + insert(e, 0, 0, 0, 0, 0) + insert(g, 64, 15, 112, 0, 20) + insert(h, 64, 15, 123, 211, 25) + insert(a, 10, 0, 0, 0, 25) + insert(b, 10, 0, 0, 128, 25) + insert(a, 10, 1, 0, 0, 30) + insert(b, 10, 1, 0, 4, 30) + insert(c, 10, 1, 0, 8, 29) + insert(d, 10, 1, 0, 16, 29) + + assertEQ(a, 192, 168, 4, 20) + assertEQ(a, 192, 168, 4, 0) + assertEQ(b, 192, 168, 4, 4) + assertEQ(c, 192, 168, 200, 182) + assertEQ(c, 192, 95, 5, 68) + assertEQ(e, 192, 95, 5, 96) + assertEQ(g, 64, 15, 116, 26) + assertEQ(g, 64, 15, 127, 3) + + insert(a, 1, 0, 0, 0, 32) + insert(a, 64, 0, 0, 0, 32) + insert(a, 128, 0, 0, 0, 32) + insert(a, 192, 0, 0, 0, 32) + insert(a, 255, 0, 0, 0, 32) + + assertEQ(a, 1, 0, 0, 0) + assertEQ(a, 64, 0, 0, 0) + assertEQ(a, 128, 0, 0, 0) + assertEQ(a, 192, 0, 0, 0) + assertEQ(a, 255, 0, 0, 0) + + trie = trie.removeByPeer(a) + + assertNEQ(a, 1, 0, 0, 0) + assertNEQ(a, 64, 0, 0, 0) + assertNEQ(a, 128, 0, 0, 0) + assertNEQ(a, 192, 0, 0, 0) + assertNEQ(a, 255, 0, 0, 0) + + trie = nil + + insert(a, 192, 168, 0, 0, 16) + insert(a, 192, 168, 0, 0, 24) + + trie = trie.removeByPeer(a) + + assertNEQ(a, 192, 168, 0, 1) +} + +/* Test ported from kernel implementation: + * selftest/allowedips.h + */ +func TestTrieIPv6(t *testing.T) { + a := &Peer{} + b := &Peer{} + c := &Peer{} + d := &Peer{} + e := &Peer{} + f := &Peer{} + g := &Peer{} + h := &Peer{} + + var trie *trieEntry + + expand := func(a uint32) []byte { + var out [4]byte + out[0] = byte(a >> 24 & 0xff) + out[1] = byte(a >> 16 & 0xff) + out[2] = byte(a >> 8 & 0xff) + out[3] = byte(a & 0xff) + return out[:] + } + + insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + trie = trie.insert(addr, cidr, peer) + } + + assertEQ := func(peer *Peer, a, b, c, d uint32) { + var addr []byte + addr = append(addr, expand(a)...) + addr = append(addr, expand(b)...) + addr = append(addr, expand(c)...) + addr = append(addr, expand(d)...) + p := trie.lookup(addr) + if p != peer { + t.Error("Assert EQ failed") + } + } + + insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128) + insert(c, 0x26075300, 0x60006b00, 0, 0, 64) + insert(e, 0, 0, 0, 0, 0) + insert(f, 0, 0, 0, 0, 0) + insert(g, 0x24046800, 0, 0, 0, 32) + insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64) + insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128) + insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128) + insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98) + + assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543) + assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee) + assertEQ(f, 0x26075300, 0x60006b01, 0, 0) + assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006) + assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678) + assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678) + assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678) + assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678) + assertEQ(h, 0x24046800, 0x40040800, 0, 0) + assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010) + assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef) +} diff --git a/device/bind_test.go b/device/bind_test.go new file mode 100644 index 0000000..0c2e2cf --- /dev/null +++ b/device/bind_test.go @@ -0,0 +1,55 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import "errors" + +type DummyDatagram struct { + msg []byte + endpoint Endpoint + world bool // better type +} + +type DummyBind struct { + in6 chan DummyDatagram + ou6 chan DummyDatagram + in4 chan DummyDatagram + ou4 chan DummyDatagram + closed bool +} + +func (b *DummyBind) SetMark(v uint32) error { + return nil +} + +func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + datagram, ok := <-b.in6 + if !ok { + return 0, nil, errors.New("closed") + } + copy(buff, datagram.msg) + return len(datagram.msg), datagram.endpoint, nil +} + +func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + datagram, ok := <-b.in4 + if !ok { + return 0, nil, errors.New("closed") + } + copy(buff, datagram.msg) + return len(datagram.msg), datagram.endpoint, nil +} + +func (b *DummyBind) Close() error { + close(b.in6) + close(b.in4) + b.closed = true + return nil +} + +func (b *DummyBind) Send(buff []byte, end Endpoint) error { + return nil +} diff --git a/device/conn.go b/device/conn.go new file mode 100644 index 0000000..2594680 --- /dev/null +++ b/device/conn.go @@ -0,0 +1,180 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" +) + +const ( + ConnRoutineNumber = 2 +) + +/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic + */ +type Bind interface { + SetMark(value uint32) error + ReceiveIPv6(buff []byte) (int, Endpoint, error) + ReceiveIPv4(buff []byte) (int, Endpoint, error) + Send(buff []byte, end Endpoint) error + Close() error +} + +/* An Endpoint maintains the source/destination caching for a peer + * + * dst : the remote address of a peer ("endpoint" in uapi terminology) + * src : the local address from which datagrams originate going to the peer + */ +type Endpoint interface { + ClearSrc() // clears the source address + SrcToString() string // returns the local source address (ip:port) + DstToString() string // returns the destination address (ip:port) + DstToBytes() []byte // used for mac2 cookie calculations + DstIP() net.IP + SrcIP() net.IP +} + +func parseEndpoint(s string) (*net.UDPAddr, error) { + + // ensure that the host is an IP address + + host, _, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + if ip := net.ParseIP(host); ip == nil { + return nil, errors.New("Failed to parse IP address: " + host) + } + + // parse address and port + + addr, err := net.ResolveUDPAddr("udp", s) + if err != nil { + return nil, err + } + ip4 := addr.IP.To4() + if ip4 != nil { + addr.IP = ip4 + } + return addr, err +} + +func unsafeCloseBind(device *Device) error { + var err error + netc := &device.net + if netc.bind != nil { + err = netc.bind.Close() + netc.bind = nil + } + netc.stopping.Wait() + return err +} + +func (device *Device) BindSetMark(mark uint32) error { + + device.net.Lock() + defer device.net.Unlock() + + // check if modified + + if device.net.fwmark == mark { + return nil + } + + // update fwmark on existing bind + + device.net.fwmark = mark + if device.isUp.Get() && device.net.bind != nil { + if err := device.net.bind.SetMark(mark); err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + return nil +} + +func (device *Device) BindUpdate() error { + + device.net.Lock() + defer device.net.Unlock() + + // close existing sockets + + if err := unsafeCloseBind(device); err != nil { + return err + } + + // open new sockets + + if device.isUp.Get() { + + // bind to new port + + var err error + netc := &device.net + netc.bind, netc.port, err = CreateBind(netc.port, device) + if err != nil { + netc.bind = nil + netc.port = 0 + return err + } + + // set fwmark + + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + // start receiving routines + + device.net.starting.Add(ConnRoutineNumber) + device.net.stopping.Add(ConnRoutineNumber) + go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + device.net.starting.Wait() + + device.log.Debug.Println("UDP bind has been updated") + } + + return nil +} + +func (device *Device) BindClose() error { + device.net.Lock() + err := unsafeCloseBind(device) + device.net.Unlock() + return err +} \ No newline at end of file diff --git a/device/conn_default.go b/device/conn_default.go new file mode 100644 index 0000000..8a86719 --- /dev/null +++ b/device/conn_default.go @@ -0,0 +1,170 @@ +// +build !linux android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "net" + "os" + "syscall" +) + +/* This code is meant to be a temporary solution + * on platforms for which the sticky socket / source caching behavior + * has not yet been implemented. + * + * See conn_linux.go for an implementation on the linux platform. + */ + +type NativeBind struct { + ipv4 *net.UDPConn + ipv6 *net.UDPConn +} + +type NativeEndpoint net.UDPAddr + +var _ Bind = (*NativeBind)(nil) +var _ Endpoint = (*NativeEndpoint)(nil) + +func CreateEndpoint(s string) (Endpoint, error) { + addr, err := parseEndpoint(s) + return (*NativeEndpoint)(addr), err +} + +func (_ *NativeEndpoint) ClearSrc() {} + +func (e *NativeEndpoint) DstIP() net.IP { + return (*net.UDPAddr)(e).IP +} + +func (e *NativeEndpoint) SrcIP() net.IP { + return nil // not supported +} + +func (e *NativeEndpoint) DstToBytes() []byte { + addr := (*net.UDPAddr)(e) + out := addr.IP.To4() + if out == nil { + out = addr.IP + } + out = append(out, byte(addr.Port&0xff)) + out = append(out, byte((addr.Port>>8)&0xff)) + return out +} + +func (e *NativeEndpoint) DstToString() string { + return (*net.UDPAddr)(e).String() +} + +func (e *NativeEndpoint) SrcToString() string { + return "" +} + +func listenNet(network string, port int) (*net.UDPConn, int, error) { + + // listen + + conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + if err != nil { + return nil, 0, err + } + + // retrieve port + + laddr := conn.LocalAddr() + uaddr, err := net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + if err != nil { + return nil, 0, err + } + return conn, uaddr.Port, nil +} + +func extractErrno(err error) error { + opErr, ok := err.(*net.OpError) + if !ok { + return nil + } + syscallErr, ok := opErr.Err.(*os.SyscallError) + if !ok { + return nil + } + return syscallErr.Err +} + +func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { + var err error + var bind NativeBind + + port := int(uport) + + bind.ipv4, port, err = listenNet("udp4", port) + if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { + return nil, 0, err + } + + bind.ipv6, port, err = listenNet("udp6", port) + if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { + bind.ipv4.Close() + bind.ipv4 = nil + return nil, 0, err + } + + return &bind, uint16(port), nil +} + +func (bind *NativeBind) Close() error { + var err1, err2 error + if bind.ipv4 != nil { + err1 = bind.ipv4.Close() + } + if bind.ipv6 != nil { + err2 = bind.ipv6.Close() + } + if err1 != nil { + return err1 + } + return err2 +} + +func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + if bind.ipv4 == nil { + return 0, nil, syscall.EAFNOSUPPORT + } + n, endpoint, err := bind.ipv4.ReadFromUDP(buff) + if endpoint != nil { + endpoint.IP = endpoint.IP.To4() + } + return n, (*NativeEndpoint)(endpoint), err +} + +func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + if bind.ipv6 == nil { + return 0, nil, syscall.EAFNOSUPPORT + } + n, endpoint, err := bind.ipv6.ReadFromUDP(buff) + return n, (*NativeEndpoint)(endpoint), err +} + +func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error { + var err error + nend := endpoint.(*NativeEndpoint) + if nend.IP.To4() != nil { + if bind.ipv4 == nil { + return syscall.EAFNOSUPPORT + } + _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) + } else { + if bind.ipv6 == nil { + return syscall.EAFNOSUPPORT + } + _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) + } + return err +} diff --git a/device/conn_linux.go b/device/conn_linux.go new file mode 100644 index 0000000..49949d5 --- /dev/null +++ b/device/conn_linux.go @@ -0,0 +1,746 @@ +// +build !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * + * This implements userspace semantics of "sticky sockets", modeled after + * WireGuard's kernelspace implementation. This is more or less a straight port + * of the sticky-sockets.c example code: + * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c + * + * Currently there is no way to achieve this within the net package: + * See e.g. https://github.com/golang/go/issues/17930 + * So this code is remains platform dependent. + */ + +package device + +import ( + "errors" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/rwcancel" + "net" + "strconv" + "sync" + "syscall" + "unsafe" +) + +const ( + FD_ERR = -1 +) + +type IPv4Source struct { + src [4]byte + ifindex int32 +} + +type IPv6Source struct { + src [16]byte + //ifindex belongs in dst.ZoneId +} + +type NativeEndpoint struct { + dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte + src [unsafe.Sizeof(IPv6Source{})]byte + isV6 bool +} + +func (endpoint *NativeEndpoint) src4() *IPv4Source { + return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) +} + +func (endpoint *NativeEndpoint) src6() *IPv6Source { + return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) +} + +func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { + return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) +} + +func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { + return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) +} + +type NativeBind struct { + sock4 int + sock6 int + netlinkSock int + netlinkCancel *rwcancel.RWCancel + lastMark uint32 +} + +var _ Endpoint = (*NativeEndpoint)(nil) +var _ Bind = (*NativeBind)(nil) + +func CreateEndpoint(s string) (Endpoint, error) { + var end NativeEndpoint + addr, err := parseEndpoint(s) + if err != nil { + return nil, err + } + + ipv4 := addr.IP.To4() + if ipv4 != nil { + dst := end.dst4() + end.isV6 = false + dst.Port = addr.Port + copy(dst.Addr[:], ipv4) + end.ClearSrc() + return &end, nil + } + + ipv6 := addr.IP.To16() + if ipv6 != nil { + zone, err := zoneToUint32(addr.Zone) + if err != nil { + return nil, err + } + dst := end.dst6() + end.isV6 = true + dst.Port = addr.Port + dst.ZoneId = zone + copy(dst.Addr[:], ipv6[:]) + end.ClearSrc() + return &end, nil + } + + return nil, errors.New("Invalid IP address") +} + +func createNetlinkRouteSocket() (int, error) { + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + if err != nil { + return -1, err + } + saddr := &unix.SockaddrNetlink{ + Family: unix.AF_NETLINK, + Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), + } + err = unix.Bind(sock, saddr) + if err != nil { + unix.Close(sock) + return -1, err + } + return sock, nil + +} + +func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) { + var err error + var bind NativeBind + var newPort uint16 + + bind.netlinkSock, err = createNetlinkRouteSocket() + if err != nil { + return nil, 0, err + } + bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock) + if err != nil { + unix.Close(bind.netlinkSock) + return nil, 0, err + } + + go bind.routineRouteListener(device) + + // attempt ipv6 bind, update port if succesful + + bind.sock6, newPort, err = create6(port) + if err != nil { + if err != syscall.EAFNOSUPPORT { + bind.netlinkCancel.Cancel() + return nil, 0, err + } + } else { + port = newPort + } + + // attempt ipv4 bind, update port if succesful + + bind.sock4, newPort, err = create4(port) + if err != nil { + if err != syscall.EAFNOSUPPORT { + bind.netlinkCancel.Cancel() + unix.Close(bind.sock6) + return nil, 0, err + } + } else { + port = newPort + } + + if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR { + return nil, 0, errors.New("ipv4 and ipv6 not supported") + } + + return &bind, port, nil +} + +func (bind *NativeBind) SetMark(value uint32) error { + if bind.sock6 != -1 { + err := unix.SetsockoptInt( + bind.sock6, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) + + if err != nil { + return err + } + } + + if bind.sock4 != -1 { + err := unix.SetsockoptInt( + bind.sock4, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) + + if err != nil { + return err + } + } + + bind.lastMark = value + return nil +} + +func closeUnblock(fd int) error { + // shutdown to unblock readers and writers + unix.Shutdown(fd, unix.SHUT_RDWR) + return unix.Close(fd) +} + +func (bind *NativeBind) Close() error { + var err1, err2, err3 error + if bind.sock6 != -1 { + err1 = closeUnblock(bind.sock6) + } + if bind.sock4 != -1 { + err2 = closeUnblock(bind.sock4) + } + err3 = bind.netlinkCancel.Cancel() + + if err1 != nil { + return err1 + } + if err2 != nil { + return err2 + } + return err3 +} + +func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + if bind.sock6 == -1 { + return 0, nil, syscall.EAFNOSUPPORT + } + n, err := receive6( + bind.sock6, + buff, + &end, + ) + return n, &end, err +} + +func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + var end NativeEndpoint + if bind.sock4 == -1 { + return 0, nil, syscall.EAFNOSUPPORT + } + n, err := receive4( + bind.sock4, + buff, + &end, + ) + return n, &end, err +} + +func (bind *NativeBind) Send(buff []byte, end Endpoint) error { + nend := end.(*NativeEndpoint) + if !nend.isV6 { + if bind.sock4 == -1 { + return syscall.EAFNOSUPPORT + } + return send4(bind.sock4, nend, buff) + } else { + if bind.sock6 == -1 { + return syscall.EAFNOSUPPORT + } + return send6(bind.sock6, nend, buff) + } +} + +func (end *NativeEndpoint) SrcIP() net.IP { + if !end.isV6 { + return net.IPv4( + end.src4().src[0], + end.src4().src[1], + end.src4().src[2], + end.src4().src[3], + ) + } else { + return end.src6().src[:] + } +} + +func (end *NativeEndpoint) DstIP() net.IP { + if !end.isV6 { + return net.IPv4( + end.dst4().Addr[0], + end.dst4().Addr[1], + end.dst4().Addr[2], + end.dst4().Addr[3], + ) + } else { + return end.dst6().Addr[:] + } +} + +func (end *NativeEndpoint) DstToBytes() []byte { + if !end.isV6 { + return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] + } else { + return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] + } +} + +func (end *NativeEndpoint) SrcToString() string { + return end.SrcIP().String() +} + +func (end *NativeEndpoint) DstToString() string { + var udpAddr net.UDPAddr + udpAddr.IP = end.DstIP() + if !end.isV6 { + udpAddr.Port = end.dst4().Port + } else { + udpAddr.Port = end.dst6().Port + } + return udpAddr.String() +} + +func (end *NativeEndpoint) ClearDst() { + for i := range end.dst { + end.dst[i] = 0 + } +} + +func (end *NativeEndpoint) ClearSrc() { + for i := range end.src { + end.src[i] = 0 + } +} + +func zoneToUint32(zone string) (uint32, error) { + if zone == "" { + return 0, nil + } + if intr, err := net.InterfaceByName(zone); err == nil { + return uint32(intr.Index), nil + } + n, err := strconv.ParseUint(zone, 10, 32) + return uint32(n), err +} + +func create4(port uint16) (int, uint16, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return FD_ERR, 0, err + } + + addr := unix.SockaddrInet4{ + Port: int(port), + } + + // set sockopts and bind + + if err := func() error { + if err := unix.SetsockoptInt( + fd, + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IP, + unix.IP_PKTINFO, + 1, + ); err != nil { + return err + } + + return unix.Bind(fd, &addr) + }(); err != nil { + unix.Close(fd) + return FD_ERR, 0, err + } + + return fd, uint16(addr.Port), err +} + +func create6(port uint16) (int, uint16, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET6, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return FD_ERR, 0, err + } + + // set sockopts and bind + + addr := unix.SockaddrInet6{ + Port: int(port), + } + + if err := func() error { + + if err := unix.SetsockoptInt( + fd, + unix.SOL_SOCKET, + unix.SO_REUSEADDR, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_RECVPKTINFO, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_V6ONLY, + 1, + ); err != nil { + return err + } + + return unix.Bind(fd, &addr) + + }(); err != nil { + unix.Close(fd) + return FD_ERR, 0, err + } + + return fd, uint16(addr.Port), err +} + +func send4(sock int, end *NativeEndpoint, buff []byte) error { + + // construct message header + + cmsg := struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + }{ + unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, + }, + unix.Inet4Pktinfo{ + Spec_dst: end.src4().src, + Ifindex: end.src4().ifindex, + }, + } + + _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) + + if err == nil { + return nil + } + + // clear src and retry + + if err == unix.EINVAL { + end.ClearSrc() + cmsg.pktinfo = unix.Inet4Pktinfo{} + _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) + } + + return err +} + +func send6(sock int, end *NativeEndpoint, buff []byte) error { + + // construct message header + + cmsg := struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + }{ + unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, + }, + unix.Inet6Pktinfo{ + Addr: end.src6().src, + Ifindex: end.dst6().ZoneId, + }, + } + + if cmsg.pktinfo.Addr == [16]byte{} { + cmsg.pktinfo.Ifindex = 0 + } + + _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) + + if err == nil { + return nil + } + + // clear src and retry + + if err == unix.EINVAL { + end.ClearSrc() + cmsg.pktinfo = unix.Inet6Pktinfo{} + _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) + } + + return err +} + +func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { + + // contruct message header + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + } + + size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) + + if err != nil { + return 0, err + } + end.isV6 = false + + if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { + *end.dst4() = *newDst4 + } + + // update source cache + + if cmsg.cmsghdr.Level == unix.IPPROTO_IP && + cmsg.cmsghdr.Type == unix.IP_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { + end.src4().src = cmsg.pktinfo.Spec_dst + end.src4().ifindex = cmsg.pktinfo.Ifindex + } + + return size, nil +} + +func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { + + // contruct message header + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + } + + size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) + + if err != nil { + return 0, err + } + end.isV6 = true + + if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { + *end.dst6() = *newDst6 + } + + // update source cache + + if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && + cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { + end.src6().src = cmsg.pktinfo.Addr + end.dst6().ZoneId = cmsg.pktinfo.Ifindex + } + + return size, nil +} + +func (bind *NativeBind) routineRouteListener(device *Device) { + type peerEndpointPtr struct { + peer *Peer + endpoint *Endpoint + } + var reqPeer map[uint32]peerEndpointPtr + var reqPeerLock sync.Mutex + + defer unix.Close(bind.netlinkSock) + + for msg := make([]byte, 1<<16); ; { + var err error + var msgn int + for { + msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) + if err == nil || !rwcancel.RetryAfterError(err) { + break + } + if !bind.netlinkCancel.ReadyRead() { + return + } + } + if err != nil { + return + } + + for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; { + + hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0])) + + if uint(hdr.Len) > uint(len(remain)) { + break + } + + switch hdr.Type { + case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: + if hdr.Seq <= MaxPeers && hdr.Seq > 0 { + if uint(len(remain)) < uint(hdr.Len) { + break + } + if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg { + attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:] + for { + if uint(len(attr)) < uint(unix.SizeofRtAttr) { + break + } + attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0])) + if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) { + break + } + if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { + ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) + reqPeerLock.Lock() + if reqPeer == nil { + reqPeerLock.Unlock() + break + } + pePtr, ok := reqPeer[hdr.Seq] + reqPeerLock.Unlock() + if !ok { + break + } + pePtr.peer.Lock() + if &pePtr.peer.endpoint != pePtr.endpoint { + pePtr.peer.Unlock() + break + } + if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx { + pePtr.peer.Unlock() + break + } + pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc() + pePtr.peer.Unlock() + } + attr = attr[attrhdr.Len:] + } + } + break + } + reqPeerLock.Lock() + reqPeer = make(map[uint32]peerEndpointPtr) + reqPeerLock.Unlock() + go func() { + device.peers.RLock() + i := uint32(1) + for _, peer := range device.peers.keyMap { + peer.RLock() + if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil { + peer.RUnlock() + continue + } + if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 { + peer.RUnlock() + break + } + nlmsg := struct { + hdr unix.NlMsghdr + msg unix.RtMsg + dsthdr unix.RtAttr + dst [4]byte + srchdr unix.RtAttr + src [4]byte + markhdr unix.RtAttr + mark uint32 + }{ + unix.NlMsghdr{ + Type: uint16(unix.RTM_GETROUTE), + Flags: unix.NLM_F_REQUEST, + Seq: i, + }, + unix.RtMsg{ + Family: unix.AF_INET, + Dst_len: 32, + Src_len: 32, + }, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_DST, + }, + peer.endpoint.(*NativeEndpoint).dst4().Addr, + unix.RtAttr{ + Len: 8, + Type: unix.RTA_SRC, + }, + peer.endpoint.(*NativeEndpoint).src4().src, + unix.RtAttr{ + Len: 8, + Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix + }, + uint32(bind.lastMark), + } + nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) + reqPeerLock.Lock() + reqPeer[i] = peerEndpointPtr{ + peer: peer, + endpoint: &peer.endpoint, + } + reqPeerLock.Unlock() + peer.RUnlock() + i++ + _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) + if err != nil { + break + } + } + device.peers.RUnlock() + }() + } + remain = remain[hdr.Len:] + } + } +} diff --git a/device/constants.go b/device/constants.go new file mode 100644 index 0000000..27d910f --- /dev/null +++ b/device/constants.go @@ -0,0 +1,41 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "time" +) + +/* Specification constants */ + +const ( + RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 + RejectAfterMessages = (1 << 64) - (1 << 4) - 1 + RekeyAfterTime = time.Second * 120 + RekeyAttemptTime = time.Second * 90 + RekeyTimeout = time.Second * 5 + MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */ + RekeyTimeoutJitterMaxMs = 334 + RejectAfterTime = time.Second * 180 + KeepaliveTimeout = time.Second * 10 + CookieRefreshTime = time.Second * 120 + HandshakeInitationRate = time.Second / 20 + PaddingMultiple = 16 +) + +const ( + MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) + MaxMessageSize = MaxSegmentSize // maximum size of transport message + MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content +) + +/* Implementation constants */ + +const ( + UnderLoadQueueSize = QueueHandshakeSize / 8 + UnderLoadAfterTime = time.Second // how long does the device remain under load after detected + MaxPeers = 1 << 16 // maximum number of configured peers +) diff --git a/device/cookie.go b/device/cookie.go new file mode 100644 index 0000000..2f21067 --- /dev/null +++ b/device/cookie.go @@ -0,0 +1,250 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/hmac" + "crypto/rand" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/chacha20poly1305" + "sync" + "time" +) + +type CookieChecker struct { + sync.RWMutex + mac1 struct { + key [blake2s.Size]byte + } + mac2 struct { + secret [blake2s.Size]byte + secretSet time.Time + encryptionKey [chacha20poly1305.KeySize]byte + } +} + +type CookieGenerator struct { + sync.RWMutex + mac1 struct { + key [blake2s.Size]byte + } + mac2 struct { + cookie [blake2s.Size128]byte + cookieSet time.Time + hasLastMAC1 bool + lastMAC1 [blake2s.Size128]byte + encryptionKey [chacha20poly1305.KeySize]byte + } +} + +func (st *CookieChecker) Init(pk NoisePublicKey) { + st.Lock() + defer st.Unlock() + + // mac1 state + + func() { + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelMAC1)) + hash.Write(pk[:]) + hash.Sum(st.mac1.key[:0]) + }() + + // mac2 state + + func() { + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelCookie)) + hash.Write(pk[:]) + hash.Sum(st.mac2.encryptionKey[:0]) + }() + + st.mac2.secretSet = time.Time{} +} + +func (st *CookieChecker) CheckMAC1(msg []byte) bool { + st.RLock() + defer st.RUnlock() + + size := len(msg) + smac2 := size - blake2s.Size128 + smac1 := smac2 - blake2s.Size128 + + var mac1 [blake2s.Size128]byte + + mac, _ := blake2s.New128(st.mac1.key[:]) + mac.Write(msg[:smac1]) + mac.Sum(mac1[:0]) + + return hmac.Equal(mac1[:], msg[smac1:smac2]) +} + +func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { + st.RLock() + defer st.RUnlock() + + if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { + return false + } + + // derive cookie key + + var cookie [blake2s.Size128]byte + func() { + mac, _ := blake2s.New128(st.mac2.secret[:]) + mac.Write(src) + mac.Sum(cookie[:0]) + }() + + // calculate mac of packet (including mac1) + + smac2 := len(msg) - blake2s.Size128 + + var mac2 [blake2s.Size128]byte + func() { + mac, _ := blake2s.New128(cookie[:]) + mac.Write(msg[:smac2]) + mac.Sum(mac2[:0]) + }() + + return hmac.Equal(mac2[:], msg[smac2:]) +} + +func (st *CookieChecker) CreateReply( + msg []byte, + recv uint32, + src []byte, +) (*MessageCookieReply, error) { + + st.RLock() + + // refresh cookie secret + + if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { + st.RUnlock() + st.Lock() + _, err := rand.Read(st.mac2.secret[:]) + if err != nil { + st.Unlock() + return nil, err + } + st.mac2.secretSet = time.Now() + st.Unlock() + st.RLock() + } + + // derive cookie + + var cookie [blake2s.Size128]byte + func() { + mac, _ := blake2s.New128(st.mac2.secret[:]) + mac.Write(src) + mac.Sum(cookie[:0]) + }() + + // encrypt cookie + + size := len(msg) + + smac2 := size - blake2s.Size128 + smac1 := smac2 - blake2s.Size128 + + reply := new(MessageCookieReply) + reply.Type = MessageCookieReplyType + reply.Receiver = recv + + _, err := rand.Read(reply.Nonce[:]) + if err != nil { + st.RUnlock() + return nil, err + } + + xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) + xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2]) + + st.RUnlock() + + return reply, nil +} + +func (st *CookieGenerator) Init(pk NoisePublicKey) { + st.Lock() + defer st.Unlock() + + func() { + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelMAC1)) + hash.Write(pk[:]) + hash.Sum(st.mac1.key[:0]) + }() + + func() { + hash, _ := blake2s.New256(nil) + hash.Write([]byte(WGLabelCookie)) + hash.Write(pk[:]) + hash.Sum(st.mac2.encryptionKey[:0]) + }() + + st.mac2.cookieSet = time.Time{} +} + +func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { + st.Lock() + defer st.Unlock() + + if !st.mac2.hasLastMAC1 { + return false + } + + var cookie [blake2s.Size128]byte + + xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) + _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) + + if err != nil { + return false + } + + st.mac2.cookieSet = time.Now() + st.mac2.cookie = cookie + return true +} + +func (st *CookieGenerator) AddMacs(msg []byte) { + + size := len(msg) + + smac2 := size - blake2s.Size128 + smac1 := smac2 - blake2s.Size128 + + mac1 := msg[smac1:smac2] + mac2 := msg[smac2:] + + st.Lock() + defer st.Unlock() + + // set mac1 + + func() { + mac, _ := blake2s.New128(st.mac1.key[:]) + mac.Write(msg[:smac1]) + mac.Sum(mac1[:0]) + }() + copy(st.mac2.lastMAC1[:], mac1) + st.mac2.hasLastMAC1 = true + + // set mac2 + + if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime { + return + } + + func() { + mac, _ := blake2s.New128(st.mac2.cookie[:]) + mac.Write(msg[:smac2]) + mac.Sum(mac2[:0]) + }() +} diff --git a/device/cookie_test.go b/device/cookie_test.go new file mode 100644 index 0000000..79a6a86 --- /dev/null +++ b/device/cookie_test.go @@ -0,0 +1,191 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "testing" +) + +func TestCookieMAC1(t *testing.T) { + + // setup generator / checker + + var ( + generator CookieGenerator + checker CookieChecker + ) + + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + pk := sk.publicKey() + + generator.Init(pk) + checker.Init(pk) + + // check mac1 + + src := []byte{192, 168, 13, 37, 10, 10, 10} + + checkMAC1 := func(msg []byte) { + generator.AddMacs(msg) + if !checker.CheckMAC1(msg) { + t.Fatal("MAC1 generation/verification failed") + } + if checker.CheckMAC2(msg, src) { + t.Fatal("MAC2 generation/verification failed") + } + } + + checkMAC1([]byte{ + 0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd, + 0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62, + 0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64, + 0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91, + 0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4, + 0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c, + 0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b, + 0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5, + 0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8, + 0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8, + 0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e, + 0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82, + 0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53, + 0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd, + 0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad, + 0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f, + }) + + checkMAC1([]byte{ + 0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c, + 0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56, + 0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a, + 0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c, + 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, + 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, + 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, + 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, + }) + + checkMAC1([]byte{ + 0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b, + 0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc, + 0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b, + 0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05, + }) + + // exchange cookie reply + + func() { + msg := []byte{ + 0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf, + 0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8, + 0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5, + 0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f, + 0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2, + 0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b, + 0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e, + 0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9, + 0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22, + 0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27, + 0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f, + 0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2, + 0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b, + 0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb, + 0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61, + 0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d, + } + generator.AddMacs(msg) + reply, err := checker.CreateReply(msg, 1377, src) + if err != nil { + t.Fatal("Failed to create cookie reply:", err) + } + if !generator.ConsumeReply(reply) { + t.Fatal("Failed to consume cookie reply") + } + }() + + // check mac2 + + checkMAC2 := func(msg []byte) { + generator.AddMacs(msg) + + if !checker.CheckMAC1(msg) { + t.Fatal("MAC1 generation/verification failed") + } + if !checker.CheckMAC2(msg, src) { + t.Fatal("MAC2 generation/verification failed") + } + + msg[5] ^= 0x20 + + if checker.CheckMAC1(msg) { + t.Fatal("MAC1 generation/verification failed") + } + if checker.CheckMAC2(msg, src) { + t.Fatal("MAC2 generation/verification failed") + } + + msg[5] ^= 0x20 + + srcBad1 := []byte{192, 168, 13, 37, 40, 01} + if checker.CheckMAC2(msg, srcBad1) { + t.Fatal("MAC2 generation/verification failed") + } + + srcBad2 := []byte{192, 168, 13, 38, 40, 01} + if checker.CheckMAC2(msg, srcBad2) { + t.Fatal("MAC2 generation/verification failed") + } + } + + checkMAC2([]byte{ + 0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3, + 0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15, + 0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f, + 0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f, + 0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69, + 0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb, + 0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c, + 0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38, + }) + + checkMAC2([]byte{ + 0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3, + 0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85, + 0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84, + 0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5, + 0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85, + 0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55, + 0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a, + 0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca, + 0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59, + 0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca, + 0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b, + 0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7, + 0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55, + 0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2, + 0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f, + 0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d, + 0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d, + 0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f, + 0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2, + 0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0, + 0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35, + 0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a, + 0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee, + 0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe, + 0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e, + 0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4, + 0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06, + 0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93, + 0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa, + 0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64, + 0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b, + 0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa, + }) +} diff --git a/device/device.go b/device/device.go new file mode 100644 index 0000000..d6c96d6 --- /dev/null +++ b/device/device.go @@ -0,0 +1,396 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "golang.zx2c4.com/wireguard/ratelimiter" + "golang.zx2c4.com/wireguard/tun" + "runtime" + "sync" + "sync/atomic" + "time" +) + +const ( + DeviceRoutineNumberPerCPU = 3 + DeviceRoutineNumberAdditional = 2 +) + +type Device struct { + isUp AtomicBool // device is (going) up + isClosed AtomicBool // device is closed? (acting as guard) + log *Logger + + // synchronized resources (locks acquired in order) + + state struct { + starting sync.WaitGroup + stopping sync.WaitGroup + sync.Mutex + changing AtomicBool + current bool + } + + net struct { + starting sync.WaitGroup + stopping sync.WaitGroup + sync.RWMutex + bind Bind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + } + + staticIdentity struct { + sync.RWMutex + privateKey NoisePrivateKey + publicKey NoisePublicKey + } + + peers struct { + sync.RWMutex + keyMap map[NoisePublicKey]*Peer + } + + // unprotected / "self-synchronising resources" + + allowedips AllowedIPs + indexTable IndexTable + cookieChecker CookieChecker + + rate struct { + underLoadUntil atomic.Value + limiter ratelimiter.Ratelimiter + } + + pool struct { + messageBufferPool *sync.Pool + messageBufferReuseChan chan *[MaxMessageSize]byte + inboundElementPool *sync.Pool + inboundElementReuseChan chan *QueueInboundElement + outboundElementPool *sync.Pool + outboundElementReuseChan chan *QueueOutboundElement + } + + queue struct { + encryption chan *QueueOutboundElement + decryption chan *QueueInboundElement + handshake chan QueueHandshakeElement + } + + signals struct { + stop chan struct{} + } + + tun struct { + device tun.TUNDevice + mtu int32 + } +} + +/* Converts the peer into a "zombie", which remains in the peer map, + * but processes no packets and does not exists in the routing table. + * + * Must hold device.peers.Mutex + */ +func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { + + // stop routing and processing of packets + + device.allowedips.RemoveByPeer(peer) + peer.Stop() + + // remove from peer map + + delete(device.peers.keyMap, key) +} + +func deviceUpdateState(device *Device) { + + // check if state already being updated (guard) + + if device.state.changing.Swap(true) { + return + } + + // compare to current state of device + + device.state.Lock() + + newIsUp := device.isUp.Get() + + if newIsUp == device.state.current { + device.state.changing.Set(false) + device.state.Unlock() + return + } + + // change state of device + + switch newIsUp { + case true: + if err := device.BindUpdate(); err != nil { + device.isUp.Set(false) + break + } + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Start() + if peer.persistentKeepaliveInterval > 0 { + peer.SendKeepalive() + } + } + device.peers.RUnlock() + + case false: + device.BindClose() + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Stop() + } + device.peers.RUnlock() + } + + // update state variables + + device.state.current = newIsUp + device.state.changing.Set(false) + device.state.Unlock() + + // check for state change in the mean time + + deviceUpdateState(device) +} + +func (device *Device) Up() { + + // closed device cannot be brought up + + if device.isClosed.Get() { + return + } + + device.isUp.Set(true) + deviceUpdateState(device) +} + +func (device *Device) Down() { + device.isUp.Set(false) + deviceUpdateState(device) +} + +func (device *Device) IsUnderLoad() bool { + + // check if currently under load + + now := time.Now() + underLoad := len(device.queue.handshake) >= UnderLoadQueueSize + if underLoad { + device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime)) + return true + } + + // check if recently under load + + until := device.rate.underLoadUntil.Load().(time.Time) + return until.After(now) +} + +func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { + + // lock required resources + + device.staticIdentity.Lock() + defer device.staticIdentity.Unlock() + + device.peers.Lock() + defer device.peers.Unlock() + + for _, peer := range device.peers.keyMap { + peer.handshake.mutex.RLock() + defer peer.handshake.mutex.RUnlock() + } + + // remove peers with matching public keys + + publicKey := sk.publicKey() + for key, peer := range device.peers.keyMap { + if peer.handshake.remoteStatic.Equals(publicKey) { + unsafeRemovePeer(device, peer, key) + } + } + + // update key material + + device.staticIdentity.privateKey = sk + device.staticIdentity.publicKey = publicKey + device.cookieChecker.Init(publicKey) + + // do static-static DH pre-computations + + rmKey := device.staticIdentity.privateKey.IsZero() + + for key, peer := range device.peers.keyMap { + + handshake := &peer.handshake + + if rmKey { + handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{} + } else { + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) + } + + if isZero(handshake.precomputedStaticStatic[:]) { + unsafeRemovePeer(device, peer, key) + } + } + + return nil +} + +func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device { + device := new(Device) + + device.isUp.Set(false) + device.isClosed.Set(false) + + device.log = logger + + device.tun.device = tunDevice + mtu, err := device.tun.device.MTU() + if err != nil { + logger.Error.Println("Trouble determining MTU, assuming default:", err) + mtu = DefaultMTU + } + device.tun.mtu = int32(mtu) + + device.peers.keyMap = make(map[NoisePublicKey]*Peer) + + device.rate.limiter.Init() + device.rate.underLoadUntil.Store(time.Time{}) + + device.indexTable.Init() + device.allowedips.Reset() + + device.PopulatePools() + + // create queues + + device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) + device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) + device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) + + // prepare signals + + device.signals.stop = make(chan struct{}) + + // prepare net + + device.net.port = 0 + device.net.bind = nil + + // start workers + + cpus := runtime.NumCPU() + device.state.starting.Wait() + device.state.stopping.Wait() + device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) + device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) + for i := 0; i < cpus; i += 1 { + go device.RoutineEncryption() + go device.RoutineDecryption() + go device.RoutineHandshake() + } + + go device.RoutineReadFromTUN() + go device.RoutineTUNEventReader() + + device.state.starting.Wait() + + return device +} + +func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { + device.peers.RLock() + defer device.peers.RUnlock() + + return device.peers.keyMap[pk] +} + +func (device *Device) RemovePeer(key NoisePublicKey) { + device.peers.Lock() + defer device.peers.Unlock() + + // stop peer and remove from routing + + peer, ok := device.peers.keyMap[key] + if ok { + unsafeRemovePeer(device, peer, key) + } +} + +func (device *Device) RemoveAllPeers() { + device.peers.Lock() + defer device.peers.Unlock() + + for key, peer := range device.peers.keyMap { + unsafeRemovePeer(device, peer, key) + } + + device.peers.keyMap = make(map[NoisePublicKey]*Peer) +} + +func (device *Device) FlushPacketQueues() { + for { + select { + case elem, ok := <-device.queue.decryption: + if ok { + elem.Drop() + } + case elem, ok := <-device.queue.encryption: + if ok { + elem.Drop() + } + case <-device.queue.handshake: + default: + return + } + } + +} + +func (device *Device) Close() { + if device.isClosed.Swap(true) { + return + } + + device.state.starting.Wait() + + device.log.Info.Println("Device closing") + device.state.changing.Set(true) + device.state.Lock() + defer device.state.Unlock() + + device.tun.device.Close() + device.BindClose() + + device.isUp.Set(false) + + close(device.signals.stop) + + device.RemoveAllPeers() + + device.state.stopping.Wait() + device.FlushPacketQueues() + + device.rate.limiter.Close() + + device.state.changing.Set(false) + device.log.Info.Println("Interface closed") +} + +func (device *Device) Wait() chan struct{} { + return device.signals.stop +} diff --git a/device/device_test.go b/device/device_test.go new file mode 100644 index 0000000..db5a3c0 --- /dev/null +++ b/device/device_test.go @@ -0,0 +1,48 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +/* Create two device instances and simulate full WireGuard interaction + * without network dependencies + */ + +import "testing" + +func TestDevice(t *testing.T) { + + // prepare tun devices for generating traffic + + tun1, err := CreateDummyTUN("tun1") + if err != nil { + t.Error("failed to create tun:", err.Error()) + } + + tun2, err := CreateDummyTUN("tun2") + if err != nil { + t.Error("failed to create tun:", err.Error()) + } + + _ = tun1 + _ = tun2 + + // prepare endpoints + + end1, err := CreateDummyEndpoint() + if err != nil { + t.Error("failed to create endpoint:", err.Error()) + } + + end2, err := CreateDummyEndpoint() + if err != nil { + t.Error("failed to create endpoint:", err.Error()) + } + + _ = end1 + _ = end2 + + // create binds + +} diff --git a/device/endpoint_test.go b/device/endpoint_test.go new file mode 100644 index 0000000..1896790 --- /dev/null +++ b/device/endpoint_test.go @@ -0,0 +1,53 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "math/rand" + "net" +) + +type DummyEndpoint struct { + src [16]byte + dst [16]byte +} + +func CreateDummyEndpoint() (*DummyEndpoint, error) { + var end DummyEndpoint + if _, err := rand.Read(end.src[:]); err != nil { + return nil, err + } + _, err := rand.Read(end.dst[:]) + return &end, err +} + +func (e *DummyEndpoint) ClearSrc() {} + +func (e *DummyEndpoint) SrcToString() string { + var addr net.UDPAddr + addr.IP = e.SrcIP() + addr.Port = 1000 + return addr.String() +} + +func (e *DummyEndpoint) DstToString() string { + var addr net.UDPAddr + addr.IP = e.DstIP() + addr.Port = 1000 + return addr.String() +} + +func (e *DummyEndpoint) SrcToBytes() []byte { + return e.src[:] +} + +func (e *DummyEndpoint) DstIP() net.IP { + return e.dst[:] +} + +func (e *DummyEndpoint) SrcIP() net.IP { + return e.src[:] +} diff --git a/device/indextable.go b/device/indextable.go new file mode 100644 index 0000000..4cba970 --- /dev/null +++ b/device/indextable.go @@ -0,0 +1,97 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/rand" + "sync" + "unsafe" +) + +type IndexTableEntry struct { + peer *Peer + handshake *Handshake + keypair *Keypair +} + +type IndexTable struct { + sync.RWMutex + table map[uint32]IndexTableEntry +} + +func randUint32() (uint32, error) { + var integer [4]byte + _, err := rand.Read(integer[:]) + return *(*uint32)(unsafe.Pointer(&integer[0])), err +} + +func (table *IndexTable) Init() { + table.Lock() + defer table.Unlock() + table.table = make(map[uint32]IndexTableEntry) +} + +func (table *IndexTable) Delete(index uint32) { + table.Lock() + defer table.Unlock() + delete(table.table, index) +} + +func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) { + table.Lock() + defer table.Unlock() + entry, ok := table.table[index] + if !ok { + return + } + table.table[index] = IndexTableEntry{ + peer: entry.peer, + keypair: keypair, + handshake: nil, + } +} + +func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) { + for { + // generate random index + + index, err := randUint32() + if err != nil { + return index, err + } + + // check if index used + + table.RLock() + _, ok := table.table[index] + table.RUnlock() + if ok { + continue + } + + // check again while locked + + table.Lock() + _, found := table.table[index] + if found { + table.Unlock() + continue + } + table.table[index] = IndexTableEntry{ + peer: peer, + handshake: handshake, + keypair: nil, + } + table.Unlock() + return index, nil + } +} + +func (table *IndexTable) Lookup(id uint32) IndexTableEntry { + table.RLock() + defer table.RUnlock() + return table.table[id] +} diff --git a/device/ip.go b/device/ip.go new file mode 100644 index 0000000..9d4fb74 --- /dev/null +++ b/device/ip.go @@ -0,0 +1,22 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "net" +) + +const ( + IPv4offsetTotalLength = 2 + IPv4offsetSrc = 12 + IPv4offsetDst = IPv4offsetSrc + net.IPv4len +) + +const ( + IPv6offsetPayloadLength = 4 + IPv6offsetSrc = 8 + IPv6offsetDst = IPv6offsetSrc + net.IPv6len +) diff --git a/device/kdf_test.go b/device/kdf_test.go new file mode 100644 index 0000000..11ea8d5 --- /dev/null +++ b/device/kdf_test.go @@ -0,0 +1,84 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "encoding/hex" + "golang.org/x/crypto/blake2s" + "testing" +) + +type KDFTest struct { + key string + input string + t0 string + t1 string + t2 string +} + +func assertEquals(t *testing.T, a string, b string) { + if a != b { + t.Fatal("expected", a, "=", b) + } +} + +func TestKDF(t *testing.T) { + tests := []KDFTest{ + { + key: "746573742d6b6579", + input: "746573742d696e707574", + t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633", + t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a", + t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24", + }, + { + key: "776972656775617264", + input: "776972656775617264", + t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8", + t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f", + t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160", + }, + { + key: "", + input: "", + t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0", + t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e", + t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e", + }, + } + + var t0, t1, t2 [blake2s.Size]byte + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + KDF3(&t0, &t1, &t2, key, input) + t0s := hex.EncodeToString(t0[:]) + t1s := hex.EncodeToString(t1[:]) + t2s := hex.EncodeToString(t2[:]) + assertEquals(t, t0s, test.t0) + assertEquals(t, t1s, test.t1) + assertEquals(t, t2s, test.t2) + } + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + KDF2(&t0, &t1, key, input) + t0s := hex.EncodeToString(t0[:]) + t1s := hex.EncodeToString(t1[:]) + assertEquals(t, t0s, test.t0) + assertEquals(t, t1s, test.t1) + } + + for _, test := range tests { + key, _ := hex.DecodeString(test.key) + input, _ := hex.DecodeString(test.input) + KDF1(&t0, key, input) + t0s := hex.EncodeToString(t0[:]) + assertEquals(t, t0s, test.t0) + } +} diff --git a/device/keypair.go b/device/keypair.go new file mode 100644 index 0000000..a9fbfce --- /dev/null +++ b/device/keypair.go @@ -0,0 +1,50 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/cipher" + "golang.zx2c4.com/wireguard/replay" + "sync" + "time" +) + +/* Due to limitations in Go and /x/crypto there is currently + * no way to ensure that key material is securely ereased in memory. + * + * Since this may harm the forward secrecy property, + * we plan to resolve this issue; whenever Go allows us to do so. + */ + +type Keypair struct { + sendNonce uint64 + send cipher.AEAD + receive cipher.AEAD + replayFilter replay.ReplayFilter + isInitiator bool + created time.Time + localIndex uint32 + remoteIndex uint32 +} + +type Keypairs struct { + sync.RWMutex + current *Keypair + previous *Keypair + next *Keypair +} + +func (kp *Keypairs) Current() *Keypair { + kp.RLock() + defer kp.RUnlock() + return kp.current +} + +func (device *Device) DeleteKeypair(key *Keypair) { + if key != nil { + device.indexTable.Delete(key.localIndex) + } +} diff --git a/device/logger.go b/device/logger.go new file mode 100644 index 0000000..7c8b704 --- /dev/null +++ b/device/logger.go @@ -0,0 +1,59 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "io" + "io/ioutil" + "log" + "os" +) + +const ( + LogLevelSilent = iota + LogLevelError + LogLevelInfo + LogLevelDebug +) + +type Logger struct { + Debug *log.Logger + Info *log.Logger + Error *log.Logger +} + +func NewLogger(level int, prepend string) *Logger { + output := os.Stdout + logger := new(Logger) + + logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) { + if level >= LogLevelDebug { + return output, output, output + } + if level >= LogLevelInfo { + return output, output, ioutil.Discard + } + if level >= LogLevelError { + return output, ioutil.Discard, ioutil.Discard + } + return ioutil.Discard, ioutil.Discard, ioutil.Discard + }() + + logger.Debug = log.New(logDebug, + "DEBUG: "+prepend, + log.Ldate|log.Ltime, + ) + + logger.Info = log.New(logInfo, + "INFO: "+prepend, + log.Ldate|log.Ltime, + ) + logger.Error = log.New(logErr, + "ERROR: "+prepend, + log.Ldate|log.Ltime, + ) + return logger +} diff --git a/device/mark_default.go b/device/mark_default.go new file mode 100644 index 0000000..76b1015 --- /dev/null +++ b/device/mark_default.go @@ -0,0 +1,12 @@ +// +build !linux,!openbsd,!freebsd + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +func (bind *NativeBind) SetMark(mark uint32) error { + return nil +} diff --git a/device/mark_unix.go b/device/mark_unix.go new file mode 100644 index 0000000..ee64cc9 --- /dev/null +++ b/device/mark_unix.go @@ -0,0 +1,64 @@ +// +build android openbsd freebsd + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "golang.org/x/sys/unix" + "runtime" +) + +var fwmarkIoctl int + +func init() { + switch runtime.GOOS { + case "linux", "android": + fwmarkIoctl = 36 /* unix.SO_MARK */ + case "freebsd": + fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */ + case "openbsd": + fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */ + } +} + +func (bind *NativeBind) SetMark(mark uint32) error { + var operr error + if fwmarkIoctl == 0 { + return nil + } + if bind.ipv4 != nil { + fd, err := bind.ipv4.SyscallConn() + if err != nil { + return err + } + err = fd.Control(func(fd uintptr) { + operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + if err == nil { + err = operr + } + if err != nil { + return err + } + } + if bind.ipv6 != nil { + fd, err := bind.ipv6.SyscallConn() + if err != nil { + return err + } + err = fd.Control(func(fd uintptr) { + operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + if err == nil { + err = operr + } + if err != nil { + return err + } + } + return nil +} diff --git a/device/misc.go b/device/misc.go new file mode 100644 index 0000000..a38d1c1 --- /dev/null +++ b/device/misc.go @@ -0,0 +1,48 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "sync/atomic" +) + +/* Atomic Boolean */ + +const ( + AtomicFalse = int32(iota) + AtomicTrue +) + +type AtomicBool struct { + int32 +} + +func (a *AtomicBool) Get() bool { + return atomic.LoadInt32(&a.int32) == AtomicTrue +} + +func (a *AtomicBool) Swap(val bool) bool { + flag := AtomicFalse + if val { + flag = AtomicTrue + } + return atomic.SwapInt32(&a.int32, flag) == AtomicTrue +} + +func (a *AtomicBool) Set(val bool) { + flag := AtomicFalse + if val { + flag = AtomicTrue + } + atomic.StoreInt32(&a.int32, flag) +} + +func min(a, b uint) uint { + if a > b { + return b + } + return a +} diff --git a/device/noise-helpers.go b/device/noise-helpers.go new file mode 100644 index 0000000..4b09bf3 --- /dev/null +++ b/device/noise-helpers.go @@ -0,0 +1,104 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/subtle" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/curve25519" + "hash" +) + +/* KDF related functions. + * HMAC-based Key Derivation Function (HKDF) + * https://tools.ietf.org/html/rfc5869 + */ + +func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) { + mac := hmac.New(func() hash.Hash { + h, _ := blake2s.New256(nil) + return h + }, key) + mac.Write(in0) + mac.Sum(sum[:0]) +} + +func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) { + mac := hmac.New(func() hash.Hash { + h, _ := blake2s.New256(nil) + return h + }, key) + mac.Write(in0) + mac.Write(in1) + mac.Sum(sum[:0]) +} + +func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { + HMAC1(t0, key, input) + HMAC1(t0, t0[:], []byte{0x1}) + return +} + +func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { + var prk [blake2s.Size]byte + HMAC1(&prk, key, input) + HMAC1(t0, prk[:], []byte{0x1}) + HMAC2(t1, prk[:], t0[:], []byte{0x2}) + setZero(prk[:]) + return +} + +func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { + var prk [blake2s.Size]byte + HMAC1(&prk, key, input) + HMAC1(t0, prk[:], []byte{0x1}) + HMAC2(t1, prk[:], t0[:], []byte{0x2}) + HMAC2(t2, prk[:], t1[:], []byte{0x3}) + setZero(prk[:]) + return +} + +func isZero(val []byte) bool { + acc := 1 + for _, b := range val { + acc &= subtle.ConstantTimeByteEq(b, 0) + } + return acc == 1 +} + +/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ +func setZero(arr []byte) { + for i := range arr { + arr[i] = 0 + } +} + +func (sk *NoisePrivateKey) clamp() { + sk[0] &= 248 + sk[31] = (sk[31] & 127) | 64 +} + +func newPrivateKey() (sk NoisePrivateKey, err error) { + _, err = rand.Read(sk[:]) + sk.clamp() + return +} + +func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarBaseMult(apk, ask) + return +} + +func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarMult(&ss, ask, apk) + return ss +} diff --git a/device/noise-protocol.go b/device/noise-protocol.go new file mode 100644 index 0000000..73826e1 --- /dev/null +++ b/device/noise-protocol.go @@ -0,0 +1,600 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "errors" + "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/crypto/poly1305" + "golang.zx2c4.com/wireguard/tai64n" + "sync" + "time" +) + +const ( + HandshakeZeroed = iota + HandshakeInitiationCreated + HandshakeInitiationConsumed + HandshakeResponseCreated + HandshakeResponseConsumed +) + +const ( + NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s" + WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com" + WGLabelMAC1 = "mac1----" + WGLabelCookie = "cookie--" +) + +const ( + MessageInitiationType = 1 + MessageResponseType = 2 + MessageCookieReplyType = 3 + MessageTransportType = 4 +) + +const ( + MessageInitiationSize = 148 // size of handshake initation message + MessageResponseSize = 92 // size of response message + MessageCookieReplySize = 64 // size of cookie reply message + MessageTransportHeaderSize = 16 // size of data preceeding content in transport message + MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport + MessageKeepaliveSize = MessageTransportSize // size of keepalive + MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message +) + +const ( + MessageTransportOffsetReceiver = 4 + MessageTransportOffsetCounter = 8 + MessageTransportOffsetContent = 16 +) + +/* Type is an 8-bit field, followed by 3 nul bytes, + * by marshalling the messages in little-endian byteorder + * we can treat these as a 32-bit unsigned int (for now) + * + */ + +type MessageInitiation struct { + Type uint32 + Sender uint32 + Ephemeral NoisePublicKey + Static [NoisePublicKeySize + poly1305.TagSize]byte + Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte + MAC1 [blake2s.Size128]byte + MAC2 [blake2s.Size128]byte +} + +type MessageResponse struct { + Type uint32 + Sender uint32 + Receiver uint32 + Ephemeral NoisePublicKey + Empty [poly1305.TagSize]byte + MAC1 [blake2s.Size128]byte + MAC2 [blake2s.Size128]byte +} + +type MessageTransport struct { + Type uint32 + Receiver uint32 + Counter uint64 + Content []byte +} + +type MessageCookieReply struct { + Type uint32 + Receiver uint32 + Nonce [chacha20poly1305.NonceSizeX]byte + Cookie [blake2s.Size128 + poly1305.TagSize]byte +} + +type Handshake struct { + state int + mutex sync.RWMutex + hash [blake2s.Size]byte // hash value + chainKey [blake2s.Size]byte // chain key + presharedKey NoiseSymmetricKey // psk + localEphemeral NoisePrivateKey // ephemeral secret key + localIndex uint32 // used to clear hash-table + remoteIndex uint32 // index for sending + remoteStatic NoisePublicKey // long term key + remoteEphemeral NoisePublicKey // ephemeral public key + precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret + lastTimestamp tai64n.Timestamp + lastInitiationConsumption time.Time + lastSentHandshake time.Time +} + +var ( + InitialChainKey [blake2s.Size]byte + InitialHash [blake2s.Size]byte + ZeroNonce [chacha20poly1305.NonceSize]byte +) + +func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { + KDF1(dst, c[:], data) +} + +func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { + hash, _ := blake2s.New256(nil) + hash.Write(h[:]) + hash.Write(data) + hash.Sum(dst[:0]) + hash.Reset() +} + +func (h *Handshake) Clear() { + setZero(h.localEphemeral[:]) + setZero(h.remoteEphemeral[:]) + setZero(h.chainKey[:]) + setZero(h.hash[:]) + h.localIndex = 0 + h.state = HandshakeZeroed +} + +func (h *Handshake) mixHash(data []byte) { + mixHash(&h.hash, &h.hash, data) +} + +func (h *Handshake) mixKey(data []byte) { + mixKey(&h.chainKey, &h.chainKey, data) +} + +/* Do basic precomputations + */ +func init() { + InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction)) + mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier)) +} + +func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + if isZero(handshake.precomputedStaticStatic[:]) { + return nil, errors.New("static shared secret is zero") + } + + // create ephemeral key + + var err error + handshake.hash = InitialHash + handshake.chainKey = InitialChainKey + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + + // assign index + + device.indexTable.Delete(handshake.localIndex) + handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) + + if err != nil { + return nil, err + } + + handshake.mixHash(handshake.remoteStatic[:]) + + msg := MessageInitiation{ + Type: MessageInitiationType, + Ephemeral: handshake.localEphemeral.publicKey(), + Sender: handshake.localIndex, + } + + handshake.mixKey(msg.Ephemeral[:]) + handshake.mixHash(msg.Ephemeral[:]) + + // encrypt static key + + func() { + var key [chacha20poly1305.KeySize]byte + ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + ss[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) + }() + handshake.mixHash(msg.Static[:]) + + // encrypt timestamp + + timestamp := tai64n.Now() + func() { + var key [chacha20poly1305.KeySize]byte + KDF2( + &handshake.chainKey, + &key, + handshake.chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) + }() + + handshake.mixHash(msg.Timestamp[:]) + handshake.state = HandshakeInitiationCreated + return &msg, nil +} + +func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) + + if msg.Type != MessageInitiationType { + return nil + } + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) + mixHash(&hash, &hash, msg.Ephemeral[:]) + mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) + + // decrypt static key + + var err error + var peerPK NoisePublicKey + func() { + var key [chacha20poly1305.KeySize]byte + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + KDF2(&chainKey, &key, chainKey[:], ss[:]) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) + }() + if err != nil { + return nil + } + mixHash(&hash, &hash, msg.Static[:]) + + // lookup peer + + peer := device.LookupPeer(peerPK) + if peer == nil { + return nil + } + + handshake := &peer.handshake + if isZero(handshake.precomputedStaticStatic[:]) { + return nil + } + + // verify identity + + var timestamp tai64n.Timestamp + var key [chacha20poly1305.KeySize]byte + + handshake.mutex.RLock() + KDF2( + &chainKey, + &key, + chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) + if err != nil { + handshake.mutex.RUnlock() + return nil + } + mixHash(&hash, &hash, msg.Timestamp[:]) + + // protect against replay & flood + + var ok bool + ok = timestamp.After(handshake.lastTimestamp) + ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate + handshake.mutex.RUnlock() + if !ok { + return nil + } + + // update handshake state + + handshake.mutex.Lock() + + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.remoteEphemeral = msg.Ephemeral + handshake.lastTimestamp = timestamp + handshake.lastInitiationConsumption = time.Now() + handshake.state = HandshakeInitiationConsumed + + handshake.mutex.Unlock() + + setZero(hash[:]) + setZero(chainKey[:]) + + return peer +} + +func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) { + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + if handshake.state != HandshakeInitiationConsumed { + return nil, errors.New("handshake initiation must be consumed first") + } + + // assign index + + var err error + device.indexTable.Delete(handshake.localIndex) + handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake) + if err != nil { + return nil, err + } + + var msg MessageResponse + msg.Type = MessageResponseType + msg.Sender = handshake.localIndex + msg.Receiver = handshake.remoteIndex + + // create ephemeral key + + handshake.localEphemeral, err = newPrivateKey() + if err != nil { + return nil, err + } + msg.Ephemeral = handshake.localEphemeral.publicKey() + handshake.mixHash(msg.Ephemeral[:]) + handshake.mixKey(msg.Ephemeral[:]) + + func() { + ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + handshake.mixKey(ss[:]) + ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + handshake.mixKey(ss[:]) + }() + + // add preshared key + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + + KDF3( + &handshake.chainKey, + &tau, + &key, + handshake.chainKey[:], + handshake.presharedKey[:], + ) + + handshake.mixHash(tau[:]) + + func() { + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) + handshake.mixHash(msg.Empty[:]) + }() + + handshake.state = HandshakeResponseCreated + + return &msg, nil +} + +func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { + if msg.Type != MessageResponseType { + return nil + } + + // lookup handshake by receiver + + lookup := device.indexTable.Lookup(msg.Receiver) + handshake := lookup.handshake + if handshake == nil { + return nil + } + + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) + + ok := func() bool { + + // lock handshake state + + handshake.mutex.RLock() + defer handshake.mutex.RUnlock() + + if handshake.state != HandshakeInitiationCreated { + return false + } + + // lock private key for reading + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + // finish 3-way DH + + mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) + mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) + + func() { + ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) + }() + + func() { + ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + KDF3( + &chainKey, + &tau, + &key, + chainKey[:], + handshake.presharedKey[:], + ) + mixHash(&hash, &hash, tau[:]) + + // authenticate transcript + + aead, _ := chacha20poly1305.New(key[:]) + _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + if err != nil { + return false + } + mixHash(&hash, &hash, msg.Empty[:]) + return true + }() + + if !ok { + return nil + } + + // update handshake state + + handshake.mutex.Lock() + + handshake.hash = hash + handshake.chainKey = chainKey + handshake.remoteIndex = msg.Sender + handshake.state = HandshakeResponseConsumed + + handshake.mutex.Unlock() + + setZero(hash[:]) + setZero(chainKey[:]) + + return lookup.peer +} + +/* Derives a new keypair from the current handshake state + * + */ +func (peer *Peer) BeginSymmetricSession() error { + device := peer.device + handshake := &peer.handshake + handshake.mutex.Lock() + defer handshake.mutex.Unlock() + + // derive keys + + var isInitiator bool + var sendKey [chacha20poly1305.KeySize]byte + var recvKey [chacha20poly1305.KeySize]byte + + if handshake.state == HandshakeResponseConsumed { + KDF2( + &sendKey, + &recvKey, + handshake.chainKey[:], + nil, + ) + isInitiator = true + } else if handshake.state == HandshakeResponseCreated { + KDF2( + &recvKey, + &sendKey, + handshake.chainKey[:], + nil, + ) + isInitiator = false + } else { + return errors.New("invalid state for keypair derivation") + } + + // zero handshake + + setZero(handshake.chainKey[:]) + setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line. + setZero(handshake.localEphemeral[:]) + peer.handshake.state = HandshakeZeroed + + // create AEAD instances + + keypair := new(Keypair) + keypair.send, _ = chacha20poly1305.New(sendKey[:]) + keypair.receive, _ = chacha20poly1305.New(recvKey[:]) + + setZero(sendKey[:]) + setZero(recvKey[:]) + + keypair.created = time.Now() + keypair.sendNonce = 0 + keypair.replayFilter.Init() + keypair.isInitiator = isInitiator + keypair.localIndex = peer.handshake.localIndex + keypair.remoteIndex = peer.handshake.remoteIndex + + // remap index + + device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair) + handshake.localIndex = 0 + + // rotate key pairs + + keypairs := &peer.keypairs + keypairs.Lock() + defer keypairs.Unlock() + + previous := keypairs.previous + next := keypairs.next + current := keypairs.current + + if isInitiator { + if next != nil { + keypairs.next = nil + keypairs.previous = next + device.DeleteKeypair(current) + } else { + keypairs.previous = current + } + device.DeleteKeypair(previous) + keypairs.current = keypair + } else { + keypairs.next = keypair + device.DeleteKeypair(next) + keypairs.previous = nil + device.DeleteKeypair(previous) + } + + return nil +} + +func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { + keypairs := &peer.keypairs + if keypairs.next != receivedKeypair { + return false + } + keypairs.Lock() + defer keypairs.Unlock() + if keypairs.next != receivedKeypair { + return false + } + old := keypairs.previous + keypairs.previous = keypairs.current + peer.device.DeleteKeypair(old) + keypairs.current = keypairs.next + keypairs.next = nil + return true +} diff --git a/device/noise-types.go b/device/noise-types.go new file mode 100644 index 0000000..82b12c1 --- /dev/null +++ b/device/noise-types.go @@ -0,0 +1,81 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/subtle" + "encoding/hex" + "errors" + "golang.org/x/crypto/chacha20poly1305" +) + +const ( + NoisePublicKeySize = 32 + NoisePrivateKeySize = 32 +) + +type ( + NoisePublicKey [NoisePublicKeySize]byte + NoisePrivateKey [NoisePrivateKeySize]byte + NoiseSymmetricKey [chacha20poly1305.KeySize]byte + NoiseNonce uint64 // padded to 12-bytes +) + +func loadExactHex(dst []byte, src string) error { + slice, err := hex.DecodeString(src) + if err != nil { + return err + } + if len(slice) != len(dst) { + return errors.New("hex string does not fit the slice") + } + copy(dst, slice) + return nil +} + +func (key NoisePrivateKey) IsZero() bool { + var zero NoisePrivateKey + return key.Equals(zero) +} + +func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + +func (key *NoisePrivateKey) FromHex(src string) (err error) { + err = loadExactHex(key[:], src) + key.clamp() + return +} + +func (key NoisePrivateKey) ToHex() string { + return hex.EncodeToString(key[:]) +} + +func (key *NoisePublicKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoisePublicKey) ToHex() string { + return hex.EncodeToString(key[:]) +} + +func (key NoisePublicKey) IsZero() bool { + var zero NoisePublicKey + return key.Equals(zero) +} + +func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + +func (key *NoiseSymmetricKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoiseSymmetricKey) ToHex() string { + return hex.EncodeToString(key[:]) +} diff --git a/device/noise_test.go b/device/noise_test.go new file mode 100644 index 0000000..6ba3f2e --- /dev/null +++ b/device/noise_test.go @@ -0,0 +1,144 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "bytes" + "encoding/binary" + "testing" +) + +func TestCurveWrappers(t *testing.T) { + sk1, err := newPrivateKey() + assertNil(t, err) + + sk2, err := newPrivateKey() + assertNil(t, err) + + pk1 := sk1.publicKey() + pk2 := sk2.publicKey() + + ss1 := sk1.sharedSecret(pk2) + ss2 := sk2.sharedSecret(pk1) + + if ss1 != ss2 { + t.Fatal("Failed to compute shared secet") + } +} + +func TestNoiseHandshake(t *testing.T) { + dev1 := randDevice(t) + dev2 := randDevice(t) + + defer dev1.Close() + defer dev2.Close() + + peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) + peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) + + assertEqual( + t, + peer1.handshake.precomputedStaticStatic[:], + peer2.handshake.precomputedStaticStatic[:], + ) + + /* simulate handshake */ + + // initiation message + + t.Log("exchange initiation message") + + msg1, err := dev1.CreateMessageInitiation(peer2) + assertNil(t, err) + + packet := make([]byte, 0, 256) + writer := bytes.NewBuffer(packet) + err = binary.Write(writer, binary.LittleEndian, msg1) + assertNil(t, err) + peer := dev2.ConsumeMessageInitiation(msg1) + if peer == nil { + t.Fatal("handshake failed at initiation message") + } + + assertEqual( + t, + peer1.handshake.chainKey[:], + peer2.handshake.chainKey[:], + ) + + assertEqual( + t, + peer1.handshake.hash[:], + peer2.handshake.hash[:], + ) + + // response message + + t.Log("exchange response message") + + msg2, err := dev2.CreateMessageResponse(peer1) + assertNil(t, err) + + peer = dev1.ConsumeMessageResponse(msg2) + if peer == nil { + t.Fatal("handshake failed at response message") + } + + assertEqual( + t, + peer1.handshake.chainKey[:], + peer2.handshake.chainKey[:], + ) + + assertEqual( + t, + peer1.handshake.hash[:], + peer2.handshake.hash[:], + ) + + // key pairs + + t.Log("deriving keys") + + err = peer1.BeginSymmetricSession() + if err != nil { + t.Fatal("failed to derive keypair for peer 1", err) + } + + err = peer2.BeginSymmetricSession() + if err != nil { + t.Fatal("failed to derive keypair for peer 2", err) + } + + key1 := peer1.keypairs.next + key2 := peer2.keypairs.current + + // encrypting / decryption test + + t.Log("test key pairs") + + func() { + testMsg := []byte("wireguard test message 1") + var err error + var out []byte + var nonce [12]byte + out = key1.send.Seal(out, nonce[:], testMsg, nil) + out, err = key2.receive.Open(out[:0], nonce[:], out, nil) + assertNil(t, err) + assertEqual(t, out, testMsg) + }() + + func() { + testMsg := []byte("wireguard test message 2") + var err error + var out []byte + var nonce [12]byte + out = key2.send.Seal(out, nonce[:], testMsg, nil) + out, err = key1.receive.Open(out[:0], nonce[:], out, nil) + assertNil(t, err) + assertEqual(t, out, testMsg) + }() +} diff --git a/device/peer.go b/device/peer.go new file mode 100644 index 0000000..af3ef9d --- /dev/null +++ b/device/peer.go @@ -0,0 +1,270 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "encoding/base64" + "errors" + "fmt" + "sync" + "time" +) + +const ( + PeerRoutineNumber = 3 +) + +type Peer struct { + isRunning AtomicBool + sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer + keypairs Keypairs + handshake Handshake + device *Device + endpoint Endpoint + persistentKeepaliveInterval uint16 + + // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly + stats struct { + txBytes uint64 // bytes send to peer (endpoint) + rxBytes uint64 // bytes received from peer + lastHandshakeNano int64 // nano seconds since epoch + } + + timers struct { + retransmitHandshake *Timer + sendKeepalive *Timer + newHandshake *Timer + zeroKeyMaterial *Timer + persistentKeepalive *Timer + handshakeAttempts uint32 + needAnotherKeepalive AtomicBool + sentLastMinuteHandshake AtomicBool + } + + signals struct { + newKeypairArrived chan struct{} + flushNonceQueue chan struct{} + } + + queue struct { + nonce chan *QueueOutboundElement // nonce / pre-handshake queue + outbound chan *QueueOutboundElement // sequential ordering of work + inbound chan *QueueInboundElement // sequential ordering of work + packetInNonceQueueIsAwaitingKey AtomicBool + } + + routines struct { + sync.Mutex // held when stopping / starting routines + starting sync.WaitGroup // routines pending start + stopping sync.WaitGroup // routines pending stop + stop chan struct{} // size 0, stop all go routines in peer + } + + cookieGenerator CookieGenerator +} + +func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { + + if device.isClosed.Get() { + return nil, errors.New("device closed") + } + + // lock resources + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + device.peers.Lock() + defer device.peers.Unlock() + + // check if over limit + + if len(device.peers.keyMap) >= MaxPeers { + return nil, errors.New("too many peers") + } + + // create peer + + peer := new(Peer) + peer.Lock() + defer peer.Unlock() + + peer.cookieGenerator.Init(pk) + peer.device = device + peer.isRunning.Set(false) + + // map public key + + _, ok := device.peers.keyMap[pk] + if ok { + return nil, errors.New("adding existing peer") + } + device.peers.keyMap[pk] = peer + + // pre-compute DH + + handshake := &peer.handshake + handshake.mutex.Lock() + handshake.remoteStatic = pk + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) + handshake.mutex.Unlock() + + // reset endpoint + + peer.endpoint = nil + + // start peer + + if peer.device.isUp.Get() { + peer.Start() + } + + return peer, nil +} + +func (peer *Peer) SendBuffer(buffer []byte) error { + peer.device.net.RLock() + defer peer.device.net.RUnlock() + + if peer.device.net.bind == nil { + return errors.New("no bind") + } + + peer.RLock() + defer peer.RUnlock() + + if peer.endpoint == nil { + return errors.New("no known endpoint for peer") + } + + return peer.device.net.bind.Send(buffer, peer.endpoint) +} + +func (peer *Peer) String() string { + base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) + abbreviatedKey := "invalid" + if len(base64Key) == 44 { + abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] + } + return fmt.Sprintf("peer(%s)", abbreviatedKey) +} + +func (peer *Peer) Start() { + + // should never start a peer on a closed device + + if peer.device.isClosed.Get() { + return + } + + // prevent simultaneous start/stop operations + + peer.routines.Lock() + defer peer.routines.Unlock() + + if peer.isRunning.Get() { + return + } + + device := peer.device + device.log.Debug.Println(peer, "- Starting...") + + // reset routine state + + peer.routines.starting.Wait() + peer.routines.stopping.Wait() + peer.routines.stop = make(chan struct{}) + peer.routines.starting.Add(PeerRoutineNumber) + peer.routines.stopping.Add(PeerRoutineNumber) + + // prepare queues + + peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + + peer.timersInit() + peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + peer.signals.newKeypairArrived = make(chan struct{}, 1) + peer.signals.flushNonceQueue = make(chan struct{}, 1) + + // wait for routines to start + + go peer.RoutineNonce() + go peer.RoutineSequentialSender() + go peer.RoutineSequentialReceiver() + + peer.routines.starting.Wait() + peer.isRunning.Set(true) +} + +func (peer *Peer) ZeroAndFlushAll() { + device := peer.device + + // clear key pairs + + keypairs := &peer.keypairs + keypairs.Lock() + device.DeleteKeypair(keypairs.previous) + device.DeleteKeypair(keypairs.current) + device.DeleteKeypair(keypairs.next) + keypairs.previous = nil + keypairs.current = nil + keypairs.next = nil + keypairs.Unlock() + + // clear handshake state + + handshake := &peer.handshake + handshake.mutex.Lock() + device.indexTable.Delete(handshake.localIndex) + handshake.Clear() + handshake.mutex.Unlock() + + peer.FlushNonceQueue() +} + +func (peer *Peer) Stop() { + + // prevent simultaneous start/stop operations + + if !peer.isRunning.Swap(false) { + return + } + + peer.routines.starting.Wait() + + peer.routines.Lock() + defer peer.routines.Unlock() + + peer.device.log.Debug.Println(peer, "- Stopping...") + + peer.timersStop() + + // stop & wait for ongoing peer routines + + close(peer.routines.stop) + peer.routines.stopping.Wait() + + // close queues + + close(peer.queue.nonce) + close(peer.queue.outbound) + close(peer.queue.inbound) + + peer.ZeroAndFlushAll() +} + +var roamingDisabled bool + +func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { + if roamingDisabled { + return + } + peer.Lock() + peer.endpoint = endpoint + peer.Unlock() +} diff --git a/device/pools.go b/device/pools.go new file mode 100644 index 0000000..98f4ef1 --- /dev/null +++ b/device/pools.go @@ -0,0 +1,89 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import "sync" + +func (device *Device) PopulatePools() { + if PreallocatedBuffersPerPool == 0 { + device.pool.messageBufferPool = &sync.Pool{ + New: func() interface{} { + return new([MaxMessageSize]byte) + }, + } + device.pool.inboundElementPool = &sync.Pool{ + New: func() interface{} { + return new(QueueInboundElement) + }, + } + device.pool.outboundElementPool = &sync.Pool{ + New: func() interface{} { + return new(QueueOutboundElement) + }, + } + } else { + device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool) + for i := 0; i < PreallocatedBuffersPerPool; i += 1 { + device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte) + } + device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool) + for i := 0; i < PreallocatedBuffersPerPool; i += 1 { + device.pool.inboundElementReuseChan <- new(QueueInboundElement) + } + device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool) + for i := 0; i < PreallocatedBuffersPerPool; i += 1 { + device.pool.outboundElementReuseChan <- new(QueueOutboundElement) + } + } +} + +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + if PreallocatedBuffersPerPool == 0 { + return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte) + } else { + return <-device.pool.messageBufferReuseChan + } +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + if PreallocatedBuffersPerPool == 0 { + device.pool.messageBufferPool.Put(msg) + } else { + device.pool.messageBufferReuseChan <- msg + } +} + +func (device *Device) GetInboundElement() *QueueInboundElement { + if PreallocatedBuffersPerPool == 0 { + return device.pool.inboundElementPool.Get().(*QueueInboundElement) + } else { + return <-device.pool.inboundElementReuseChan + } +} + +func (device *Device) PutInboundElement(msg *QueueInboundElement) { + if PreallocatedBuffersPerPool == 0 { + device.pool.inboundElementPool.Put(msg) + } else { + device.pool.inboundElementReuseChan <- msg + } +} + +func (device *Device) GetOutboundElement() *QueueOutboundElement { + if PreallocatedBuffersPerPool == 0 { + return device.pool.outboundElementPool.Get().(*QueueOutboundElement) + } else { + return <-device.pool.outboundElementReuseChan + } +} + +func (device *Device) PutOutboundElement(msg *QueueOutboundElement) { + if PreallocatedBuffersPerPool == 0 { + device.pool.outboundElementPool.Put(msg) + } else { + device.pool.outboundElementReuseChan <- msg + } +} diff --git a/device/queueconstants.go b/device/queueconstants.go new file mode 100644 index 0000000..3e94b7f --- /dev/null +++ b/device/queueconstants.go @@ -0,0 +1,16 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +/* Implementation specific constants */ + +const ( + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram + PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth +) diff --git a/device/receive.go b/device/receive.go new file mode 100644 index 0000000..5c837c1 --- /dev/null +++ b/device/receive.go @@ -0,0 +1,641 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "bytes" + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" + "strconv" + "sync" + "sync/atomic" + "time" +) + +type QueueHandshakeElement struct { + msgType uint32 + packet []byte + endpoint Endpoint + buffer *[MaxMessageSize]byte +} + +type QueueInboundElement struct { + dropped int32 + sync.Mutex + buffer *[MaxMessageSize]byte + packet []byte + counter uint64 + keypair *Keypair + endpoint Endpoint +} + +func (elem *QueueInboundElement) Drop() { + atomic.StoreInt32(&elem.dropped, AtomicTrue) +} + +func (elem *QueueInboundElement) IsDropped() bool { + return atomic.LoadInt32(&elem.dropped) == AtomicTrue +} + +func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool { + select { + case inboundQueue <- element: + select { + case decryptionQueue <- element: + return true + default: + element.Drop() + element.Unlock() + return false + } + default: + device.PutInboundElement(element) + return false + } +} + +func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool { + select { + case queue <- element: + return true + default: + return false + } +} + +/* Called when a new authenticated message has been received + * + * NOTE: Not thread safe, but called by sequential receiver! + */ +func (peer *Peer) keepKeyFreshReceiving() { + if peer.timers.sentLastMinuteHandshake.Get() { + return + } + keypair := peer.keypairs.Current() + if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { + peer.timers.sentLastMinuteHandshake.Set(true) + peer.SendHandshakeInitiation(false) + } +} + +/* Receives incoming datagrams for the device + * + * Every time the bind is updated a new routine is started for + * IPv4 and IPv6 (separately) + */ +func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { + + logDebug := device.log.Debug + defer func() { + logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + device.net.stopping.Done() + }() + + logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started") + device.net.starting.Done() + + // receive datagrams until conn is closed + + buffer := device.GetMessageBuffer() + + var ( + err error + size int + endpoint Endpoint + ) + + for { + + // read next datagram + + switch IP { + case ipv4.Version: + size, endpoint, err = bind.ReceiveIPv4(buffer[:]) + case ipv6.Version: + size, endpoint, err = bind.ReceiveIPv6(buffer[:]) + default: + panic("invalid IP version") + } + + if err != nil { + device.PutMessageBuffer(buffer) + return + } + + if size < MinMessageSize { + continue + } + + // check size of packet + + packet := buffer[:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) + + var okay bool + + switch msgType { + + // check if transport + + case MessageTransportType: + + // check size + + if len(packet) < MessageTransportSize { + continue + } + + // lookup key pair + + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { + continue + } + + // check keypair expiry + + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } + + // create work element + peer := value.peer + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = buffer + elem.keypair = keypair + elem.dropped = AtomicFalse + elem.endpoint = endpoint + elem.counter = 0 + elem.Mutex = sync.Mutex{} + elem.Lock() + + // add to decryption queues + + if peer.isRunning.Get() { + if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { + buffer = device.GetMessageBuffer() + } + } + + continue + + // otherwise it is a fixed size & handshake related packet + + case MessageInitiationType: + okay = len(packet) == MessageInitiationSize + + case MessageResponseType: + okay = len(packet) == MessageResponseSize + + case MessageCookieReplyType: + okay = len(packet) == MessageCookieReplySize + + default: + logDebug.Println("Received message with unknown type") + } + + if okay { + if (device.addToHandshakeQueue( + device.queue.handshake, + QueueHandshakeElement{ + msgType: msgType, + buffer: buffer, + packet: packet, + endpoint: endpoint, + }, + )) { + buffer = device.GetMessageBuffer() + } + } + } +} + +func (device *Device) RoutineDecryption() { + + var nonce [chacha20poly1305.NonceSize]byte + + logDebug := device.log.Debug + defer func() { + logDebug.Println("Routine: decryption worker - stopped") + device.state.stopping.Done() + }() + logDebug.Println("Routine: decryption worker - started") + device.state.starting.Done() + + for { + select { + case <-device.signals.stop: + return + + case elem, ok := <-device.queue.decryption: + + if !ok { + return + } + + // check if dropped + + if elem.IsDropped() { + continue + } + + // split message into fields + + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + content := elem.packet[MessageTransportOffsetContent:] + + // expand nonce + + nonce[0x4] = counter[0x0] + nonce[0x5] = counter[0x1] + nonce[0x6] = counter[0x2] + nonce[0x7] = counter[0x3] + + nonce[0x8] = counter[0x4] + nonce[0x9] = counter[0x5] + nonce[0xa] = counter[0x6] + nonce[0xb] = counter[0x7] + + // decrypt and release to consumer + + var err error + elem.counter = binary.LittleEndian.Uint64(counter) + elem.packet, err = elem.keypair.receive.Open( + content[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.Drop() + device.PutMessageBuffer(elem.buffer) + } + elem.Unlock() + } + } +} + +/* Handles incoming packets related to handshake + */ +func (device *Device) RoutineHandshake() { + + logInfo := device.log.Info + logError := device.log.Error + logDebug := device.log.Debug + + var elem QueueHandshakeElement + var ok bool + + defer func() { + logDebug.Println("Routine: handshake worker - stopped") + device.state.stopping.Done() + if elem.buffer != nil { + device.PutMessageBuffer(elem.buffer) + } + }() + + logDebug.Println("Routine: handshake worker - started") + device.state.starting.Done() + + for { + if elem.buffer != nil { + device.PutMessageBuffer(elem.buffer) + elem.buffer = nil + } + + select { + case elem, ok = <-device.queue.handshake: + case <-device.signals.stop: + return + } + + if !ok { + return + } + + // handle cookie fields and ratelimiting + + switch elem.msgType { + + case MessageCookieReplyType: + + // unmarshal packet + + var reply MessageCookieReply + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &reply) + if err != nil { + logDebug.Println("Failed to decode cookie reply") + return + } + + // lookup peer from index + + entry := device.indexTable.Lookup(reply.Receiver) + + if entry.peer == nil { + continue + } + + // consume reply + + if peer := entry.peer; peer.isRunning.Get() { + logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString()) + if !peer.cookieGenerator.ConsumeReply(&reply) { + logDebug.Println("Could not decrypt invalid cookie response") + } + } + + continue + + case MessageInitiationType, MessageResponseType: + + // check mac fields and maybe ratelimit + + if !device.cookieChecker.CheckMAC1(elem.packet) { + logDebug.Println("Received packet with invalid mac1") + continue + } + + // endpoints destination address is the source of the datagram + + if device.IsUnderLoad() { + + // verify MAC2 field + + if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { + device.SendHandshakeCookie(&elem) + continue + } + + // check ratelimiter + + if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { + continue + } + } + + default: + logError.Println("Invalid packet ended up in the handshake queue") + continue + } + + // handle handshake initiation/response content + + switch elem.msgType { + case MessageInitiationType: + + // unmarshal + + var msg MessageInitiation + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode initiation message") + continue + } + + // consume initiation + + peer := device.ConsumeMessageInitiation(&msg) + if peer == nil { + logInfo.Println( + "Received invalid initiation message from", + elem.endpoint.DstToString(), + ) + continue + } + + // update timers + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) + + logDebug.Println(peer, "- Received handshake initiation") + + peer.SendHandshakeResponse() + + case MessageResponseType: + + // unmarshal + + var msg MessageResponse + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode response message") + continue + } + + // consume response + + peer := device.ConsumeMessageResponse(&msg) + if peer == nil { + logInfo.Println( + "Received invalid response message from", + elem.endpoint.DstToString(), + ) + continue + } + + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) + + logDebug.Println(peer, "- Received handshake response") + + // update timers + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + + // derive keypair + + err = peer.BeginSymmetricSession() + + if err != nil { + logError.Println(peer, "- Failed to derive keypair:", err) + continue + } + + peer.timersSessionDerived() + peer.timersHandshakeComplete() + peer.SendKeepalive() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } + } + } +} + +func (peer *Peer) RoutineSequentialReceiver() { + + device := peer.device + logInfo := device.log.Info + logError := device.log.Error + logDebug := device.log.Debug + + var elem *QueueInboundElement + var ok bool + + defer func() { + logDebug.Println(peer, "- Routine: sequential receiver - stopped") + peer.routines.stopping.Done() + if elem != nil { + if !elem.IsDropped() { + device.PutMessageBuffer(elem.buffer) + } + device.PutInboundElement(elem) + } + }() + + logDebug.Println(peer, "- Routine: sequential receiver - started") + + peer.routines.starting.Done() + + for { + if elem != nil { + if !elem.IsDropped() { + device.PutMessageBuffer(elem.buffer) + } + device.PutInboundElement(elem) + elem = nil + } + + select { + + case <-peer.routines.stop: + return + + case elem, ok = <-peer.queue.inbound: + + if !ok { + return + } + + // wait for decryption + + elem.Lock() + + if elem.IsDropped() { + continue + } + + // check for replay + + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { + continue + } + + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) + + // check if using new keypair + if peer.ReceivedWithKeypair(elem.keypair) { + peer.timersHandshakeComplete() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } + } + + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + + // check for keepalive + + if len(elem.packet) == 0 { + logDebug.Println(peer, "- Receiving keepalive packet") + continue + } + peer.timersDataReceived() + + // verify source and strip padding + + switch elem.packet[0] >> 4 { + case ipv4.Version: + + // strip padding + + if len(elem.packet) < ipv4.HeaderLen { + continue + } + + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv4 source + + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.LookupIPv4(src) != peer { + logInfo.Println( + "IPv4 packet with disallowed source address from", + peer, + ) + continue + } + + case ipv6.Version: + + // strip padding + + if len(elem.packet) < ipv6.HeaderLen { + continue + } + + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv6 source + + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.LookupIPv6(src) != peer { + logInfo.Println( + peer, + "sent packet with disallowed IPv6 source", + ) + continue + } + + default: + logInfo.Println("Packet with invalid IP version from", peer) + continue + } + + // write to tun device + + offset := MessageTransportOffsetContent + atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) + if err != nil { + logError.Println("Failed to write packet to TUN device:", err) + } + } + } +} diff --git a/device/send.go b/device/send.go new file mode 100644 index 0000000..b4e23c7 --- /dev/null +++ b/device/send.go @@ -0,0 +1,618 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "bytes" + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" + "sync" + "sync/atomic" + "time" +) + +/* Outbound flow + * + * 1. TUN queue + * 2. Routing (sequential) + * 3. Nonce assignment (sequential) + * 4. Encryption (parallel) + * 5. Transmission (sequential) + * + * The functions in this file occur (roughly) in the order in + * which the packets are processed. + * + * Locking, Producers and Consumers + * + * The order of packets (per peer) must be maintained, + * but encryption of packets happen out-of-order: + * + * The sequential consumers will attempt to take the lock, + * workers release lock when they have completed work (encryption) on the packet. + * + * If the element is inserted into the "encryption queue", + * the content is preceded by enough "junk" to contain the transport header + * (to allow the construction of transport messages in-place) + */ + +type QueueOutboundElement struct { + dropped int32 + sync.Mutex + buffer *[MaxMessageSize]byte // slice holding the packet data + packet []byte // slice of "buffer" (always!) + nonce uint64 // nonce for encryption + keypair *Keypair // keypair for encryption + peer *Peer // related peer +} + +func (device *Device) NewOutboundElement() *QueueOutboundElement { + elem := device.GetOutboundElement() + elem.dropped = AtomicFalse + elem.buffer = device.GetMessageBuffer() + elem.Mutex = sync.Mutex{} + elem.nonce = 0 + elem.keypair = nil + elem.peer = nil + return elem +} + +func (elem *QueueOutboundElement) Drop() { + atomic.StoreInt32(&elem.dropped, AtomicTrue) +} + +func (elem *QueueOutboundElement) IsDropped() bool { + return atomic.LoadInt32(&elem.dropped) == AtomicTrue +} + +func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) { + for { + select { + case queue <- element: + return + default: + select { + case old := <-queue: + device.PutMessageBuffer(old.buffer) + device.PutOutboundElement(old) + default: + } + } + } +} + +func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) { + select { + case outboundQueue <- element: + select { + case encryptionQueue <- element: + return + default: + element.Drop() + element.peer.device.PutMessageBuffer(element.buffer) + element.Unlock() + } + default: + element.peer.device.PutMessageBuffer(element.buffer) + element.peer.device.PutOutboundElement(element) + } +} + +/* Queues a keepalive if no packets are queued for peer + */ +func (peer *Peer) SendKeepalive() bool { + if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() { + return false + } + elem := peer.device.NewOutboundElement() + elem.packet = nil + select { + case peer.queue.nonce <- elem: + peer.device.log.Debug.Println(peer, "- Sending keepalive packet") + return true + default: + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + return false + } +} + +func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { + if !isRetry { + atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + } + + peer.handshake.mutex.RLock() + if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout { + peer.handshake.mutex.RUnlock() + return nil + } + peer.handshake.mutex.RUnlock() + + peer.handshake.mutex.Lock() + if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout { + peer.handshake.mutex.Unlock() + return nil + } + peer.handshake.lastSentHandshake = time.Now() + peer.handshake.mutex.Unlock() + + peer.device.log.Debug.Println(peer, "- Sending handshake initiation") + + msg, err := peer.device.CreateMessageInitiation(peer) + if err != nil { + peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err) + return err + } + + var buff [MessageInitiationSize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, msg) + packet := writer.Bytes() + peer.cookieGenerator.AddMacs(packet) + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() + + err = peer.SendBuffer(packet) + if err != nil { + peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err) + } + peer.timersHandshakeInitiated() + + return err +} + +func (peer *Peer) SendHandshakeResponse() error { + peer.handshake.mutex.Lock() + peer.handshake.lastSentHandshake = time.Now() + peer.handshake.mutex.Unlock() + + peer.device.log.Debug.Println(peer, "- Sending handshake response") + + response, err := peer.device.CreateMessageResponse(peer) + if err != nil { + peer.device.log.Error.Println(peer, "- Failed to create response message:", err) + return err + } + + var buff [MessageResponseSize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, response) + packet := writer.Bytes() + peer.cookieGenerator.AddMacs(packet) + + err = peer.BeginSymmetricSession() + if err != nil { + peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err) + return err + } + + peer.timersSessionDerived() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() + + err = peer.SendBuffer(packet) + if err != nil { + peer.device.log.Error.Println(peer, "- Failed to send handshake response", err) + } + return err +} + +func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { + + device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString()) + + sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) + reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) + if err != nil { + device.log.Error.Println("Failed to create cookie reply:", err) + return err + } + + var buff [MessageCookieReplySize]byte + writer := bytes.NewBuffer(buff[:0]) + binary.Write(writer, binary.LittleEndian, reply) + device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + if err != nil { + device.log.Error.Println("Failed to send cookie reply:", err) + } + return err +} + +func (peer *Peer) keepKeyFreshSending() { + keypair := peer.keypairs.Current() + if keypair == nil { + return + } + nonce := atomic.LoadUint64(&keypair.sendNonce) + if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Now().Sub(keypair.created) > RekeyAfterTime) { + peer.SendHandshakeInitiation(false) + } +} + +/* Reads packets from the TUN and inserts + * into nonce queue for peer + * + * Obs. Single instance per TUN device + */ +func (device *Device) RoutineReadFromTUN() { + + logDebug := device.log.Debug + logError := device.log.Error + + defer func() { + logDebug.Println("Routine: TUN reader - stopped") + device.state.stopping.Done() + }() + + logDebug.Println("Routine: TUN reader - started") + device.state.starting.Done() + + var elem *QueueOutboundElement + + for { + if elem != nil { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + elem = device.NewOutboundElement() + + // read packet + + offset := MessageTransportHeaderSize + size, err := device.tun.device.Read(elem.buffer[:], offset) + + if err != nil { + if !device.isClosed.Get() { + logError.Println("Failed to read packet from TUN device:", err) + device.Close() + } + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + return + } + + if size == 0 || size > MaxContentSize { + continue + } + + elem.packet = elem.buffer[offset : offset+size] + + // lookup peer + + var peer *Peer + switch elem.packet[0] >> 4 { + case ipv4.Version: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.allowedips.LookupIPv4(dst) + + case ipv6.Version: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.allowedips.LookupIPv6(dst) + + default: + logDebug.Println("Received packet with unknown IP version") + } + + if peer == nil { + continue + } + + // insert into nonce/pre-handshake queue + + if peer.isRunning.Get() { + if peer.queue.packetInNonceQueueIsAwaitingKey.Get() { + peer.SendHandshakeInitiation(false) + } + addToNonceQueue(peer.queue.nonce, elem, device) + elem = nil + } + } +} + +func (peer *Peer) FlushNonceQueue() { + select { + case peer.signals.flushNonceQueue <- struct{}{}: + default: + } +} + +/* Queues packets when there is no handshake. + * Then assigns nonces to packets sequentially + * and creates "work" structs for workers + * + * Obs. A single instance per peer + */ +func (peer *Peer) RoutineNonce() { + var keypair *Keypair + + device := peer.device + logDebug := device.log.Debug + + flush := func() { + for { + select { + case elem := <-peer.queue.nonce: + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + default: + return + } + } + } + + defer func() { + flush() + logDebug.Println(peer, "- Routine: nonce worker - stopped") + peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) + peer.routines.stopping.Done() + }() + + peer.routines.starting.Done() + logDebug.Println(peer, "- Routine: nonce worker - started") + + for { + NextPacket: + peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) + + select { + case <-peer.routines.stop: + return + + case <-peer.signals.flushNonceQueue: + flush() + goto NextPacket + + case elem, ok := <-peer.queue.nonce: + + if !ok { + return + } + + // make sure to always pick the newest key + + for { + + // check validity of newest key pair + + keypair = peer.keypairs.Current() + if keypair != nil && keypair.sendNonce < RejectAfterMessages { + if time.Now().Sub(keypair.created) < RejectAfterTime { + break + } + } + peer.queue.packetInNonceQueueIsAwaitingKey.Set(true) + + // no suitable key pair, request for new handshake + + select { + case <-peer.signals.newKeypairArrived: + default: + } + + peer.SendHandshakeInitiation(false) + + // wait for key to be established + + logDebug.Println(peer, "- Awaiting keypair") + + select { + case <-peer.signals.newKeypairArrived: + logDebug.Println(peer, "- Obtained awaited keypair") + + case <-peer.signals.flushNonceQueue: + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + flush() + goto NextPacket + + case <-peer.routines.stop: + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + return + } + } + peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) + + // populate work element + + elem.peer = peer + elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 + + // double check in case of race condition added by future code + + if elem.nonce >= RejectAfterMessages { + atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + goto NextPacket + } + + elem.keypair = keypair + elem.dropped = AtomicFalse + elem.Lock() + + // add to parallel and sequential queue + addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem) + } + } +} + +/* Encrypts the elements in the queue + * and marks them for sequential consumption (by releasing the mutex) + * + * Obs. One instance per core + */ +func (device *Device) RoutineEncryption() { + + var nonce [chacha20poly1305.NonceSize]byte + + logDebug := device.log.Debug + + defer func() { + for { + select { + case elem, ok := <-device.queue.encryption: + if ok && !elem.IsDropped() { + elem.Drop() + device.PutMessageBuffer(elem.buffer) + elem.Unlock() + } + default: + goto out + } + } + out: + logDebug.Println("Routine: encryption worker - stopped") + device.state.stopping.Done() + }() + + logDebug.Println("Routine: encryption worker - started") + device.state.starting.Done() + + for { + + // fetch next element + + select { + case <-device.signals.stop: + return + + case elem, ok := <-device.queue.encryption: + + if !ok { + return + } + + // check if dropped + + if elem.IsDropped() { + continue + } + + // populate header fields + + header := elem.buffer[:MessageTransportHeaderSize] + + fieldType := header[0:4] + fieldReceiver := header[4:8] + fieldNonce := header[8:16] + + binary.LittleEndian.PutUint32(fieldType, MessageTransportType) + binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex) + binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) + + // pad content to multiple of 16 + + mtu := int(atomic.LoadInt32(&device.tun.mtu)) + lastUnit := len(elem.packet) % mtu + paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1) + if paddedSize > mtu { + paddedSize = mtu + } + for i := len(elem.packet); i < paddedSize; i++ { + elem.packet = append(elem.packet, 0) + } + + // encrypt content and release to consumer + + binary.LittleEndian.PutUint64(nonce[4:], elem.nonce) + elem.packet = elem.keypair.send.Seal( + header, + nonce[:], + elem.packet, + nil, + ) + elem.Unlock() + } + } +} + +/* Sequentially reads packets from queue and sends to endpoint + * + * Obs. Single instance per peer. + * The routine terminates then the outbound queue is closed. + */ +func (peer *Peer) RoutineSequentialSender() { + + device := peer.device + + logDebug := device.log.Debug + logError := device.log.Error + + defer func() { + for { + select { + case elem, ok := <-peer.queue.outbound: + if ok { + if !elem.IsDropped() { + device.PutMessageBuffer(elem.buffer) + elem.Drop() + } + device.PutOutboundElement(elem) + } + default: + goto out + } + } + out: + logDebug.Println(peer, "- Routine: sequential sender - stopped") + peer.routines.stopping.Done() + }() + + logDebug.Println(peer, "- Routine: sequential sender - started") + + peer.routines.starting.Done() + + for { + select { + + case <-peer.routines.stop: + return + + case elem, ok := <-peer.queue.outbound: + + if !ok { + return + } + + elem.Lock() + if elem.IsDropped() { + device.PutOutboundElement(elem) + continue + } + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() + + // send message and return buffer to pool + + length := uint64(len(elem.packet)) + err := peer.SendBuffer(elem.packet) + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + if err != nil { + logError.Println(peer, "- Failed to send data packet", err) + continue + } + atomic.AddUint64(&peer.stats.txBytes, length) + + if len(elem.packet) != MessageKeepaliveSize { + peer.timersDataSent() + } + peer.keepKeyFreshSending() + } + } +} diff --git a/device/timers.go b/device/timers.go new file mode 100644 index 0000000..5f28fcc --- /dev/null +++ b/device/timers.go @@ -0,0 +1,227 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * + * This is based heavily on timers.c from the kernel implementation. + */ + +package device + +import ( + "math/rand" + "sync" + "sync/atomic" + "time" +) + +/* This Timer structure and related functions should roughly copy the interface of + * the Linux kernel's struct timer_list. + */ + +type Timer struct { + *time.Timer + modifyingLock sync.RWMutex + runningLock sync.Mutex + isPending bool +} + +func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { + timer := &Timer{} + timer.Timer = time.AfterFunc(time.Hour, func() { + timer.runningLock.Lock() + + timer.modifyingLock.Lock() + if !timer.isPending { + timer.modifyingLock.Unlock() + timer.runningLock.Unlock() + return + } + timer.isPending = false + timer.modifyingLock.Unlock() + + expirationFunction(peer) + timer.runningLock.Unlock() + }) + timer.Stop() + return timer +} + +func (timer *Timer) Mod(d time.Duration) { + timer.modifyingLock.Lock() + timer.isPending = true + timer.Reset(d) + timer.modifyingLock.Unlock() +} + +func (timer *Timer) Del() { + timer.modifyingLock.Lock() + timer.isPending = false + timer.Stop() + timer.modifyingLock.Unlock() +} + +func (timer *Timer) DelSync() { + timer.Del() + timer.runningLock.Lock() + timer.Del() + timer.runningLock.Unlock() +} + +func (timer *Timer) IsPending() bool { + timer.modifyingLock.RLock() + defer timer.modifyingLock.RUnlock() + return timer.isPending +} + +func (peer *Peer) timersActive() bool { + return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0 +} + +func expiredRetransmitHandshake(peer *Peer) { + if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { + peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2) + + if peer.timersActive() { + peer.timers.sendKeepalive.Del() + } + + /* We drop all packets without a keypair and don't try again, + * if we try unsuccessfully for too long to make a handshake. + */ + peer.FlushNonceQueue() + + /* We set a timer for destroying any residue that might be left + * of a partial exchange. + */ + if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() { + peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) + } + } else { + atomic.AddUint32(&peer.timers.handshakeAttempts, 1) + peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) + + /* We clear the endpoint address src address, in case this is the cause of trouble. */ + peer.Lock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + peer.Unlock() + + peer.SendHandshakeInitiation(true) + } +} + +func expiredSendKeepalive(peer *Peer) { + peer.SendKeepalive() + if peer.timers.needAnotherKeepalive.Get() { + peer.timers.needAnotherKeepalive.Set(false) + if peer.timersActive() { + peer.timers.sendKeepalive.Mod(KeepaliveTimeout) + } + } +} + +func expiredNewHandshake(peer *Peer) { + peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) + /* We clear the endpoint address src address, in case this is the cause of trouble. */ + peer.Lock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + peer.Unlock() + peer.SendHandshakeInitiation(false) + +} + +func expiredZeroKeyMaterial(peer *Peer) { + peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) + peer.ZeroAndFlushAll() +} + +func expiredPersistentKeepalive(peer *Peer) { + if peer.persistentKeepaliveInterval > 0 { + peer.SendKeepalive() + } +} + +/* Should be called after an authenticated data packet is sent. */ +func (peer *Peer) timersDataSent() { + if peer.timersActive() && !peer.timers.newHandshake.IsPending() { + peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout) + } +} + +/* Should be called after an authenticated data packet is received. */ +func (peer *Peer) timersDataReceived() { + if peer.timersActive() { + if !peer.timers.sendKeepalive.IsPending() { + peer.timers.sendKeepalive.Mod(KeepaliveTimeout) + } else { + peer.timers.needAnotherKeepalive.Set(true) + } + } +} + +/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ +func (peer *Peer) timersAnyAuthenticatedPacketSent() { + if peer.timersActive() { + peer.timers.sendKeepalive.Del() + } +} + +/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ +func (peer *Peer) timersAnyAuthenticatedPacketReceived() { + if peer.timersActive() { + peer.timers.newHandshake.Del() + } +} + +/* Should be called after a handshake initiation message is sent. */ +func (peer *Peer) timersHandshakeInitiated() { + if peer.timersActive() { + peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) + } +} + +/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ +func (peer *Peer) timersHandshakeComplete() { + if peer.timersActive() { + peer.timers.retransmitHandshake.Del() + } + atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.sentLastMinuteHandshake.Set(false) + atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) +} + +/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ +func (peer *Peer) timersSessionDerived() { + if peer.timersActive() { + peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) + } +} + +/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ +func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { + if peer.persistentKeepaliveInterval > 0 && peer.timersActive() { + peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second) + } +} + +func (peer *Peer) timersInit() { + peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake) + peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive) + peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) + peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) + peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) + atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.sentLastMinuteHandshake.Set(false) + peer.timers.needAnotherKeepalive.Set(false) +} + +func (peer *Peer) timersStop() { + peer.timers.retransmitHandshake.DelSync() + peer.timers.sendKeepalive.DelSync() + peer.timers.newHandshake.DelSync() + peer.timers.zeroKeyMaterial.DelSync() + peer.timers.persistentKeepalive.DelSync() +} diff --git a/device/tun.go b/device/tun.go new file mode 100644 index 0000000..bc5f1f1 --- /dev/null +++ b/device/tun.go @@ -0,0 +1,55 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "golang.zx2c4.com/wireguard/tun" + "sync/atomic" +) + +const DefaultMTU = 1420 + +func (device *Device) RoutineTUNEventReader() { + setUp := false + logDebug := device.log.Debug + logInfo := device.log.Info + logError := device.log.Error + + logDebug.Println("Routine: event worker - started") + device.state.starting.Done() + + for event := range device.tun.device.Events() { + if event&tun.TUNEventMTUUpdate != 0 { + mtu, err := device.tun.device.MTU() + old := atomic.LoadInt32(&device.tun.mtu) + if err != nil { + logError.Println("Failed to load updated MTU of device:", err) + } else if int(old) != mtu { + if mtu+MessageTransportSize > MaxMessageSize { + logInfo.Println("MTU updated:", mtu, "(too large)") + } else { + logInfo.Println("MTU updated:", mtu) + } + atomic.StoreInt32(&device.tun.mtu, int32(mtu)) + } + } + + if event&tun.TUNEventUp != 0 && !setUp { + logInfo.Println("Interface set up") + setUp = true + device.Up() + } + + if event&tun.TUNEventDown != 0 && setUp { + logInfo.Println("Interface set down") + setUp = false + device.Down() + } + } + + logDebug.Println("Routine: event worker - stopped") + device.state.stopping.Done() +} diff --git a/device/uapi.go b/device/uapi.go new file mode 100644 index 0000000..5c65917 --- /dev/null +++ b/device/uapi.go @@ -0,0 +1,426 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "bufio" + "fmt" + "golang.zx2c4.com/wireguard/ipc" + "io" + "net" + "strconv" + "strings" + "sync/atomic" + "time" +) + +type IPCError struct { + int64 +} + +func (s *IPCError) Error() string { + return fmt.Sprintf("IPC error: %d", s.int64) +} + +func (s *IPCError) ErrorCode() int64 { + return s.int64 +} + +func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError { + + device.log.Debug.Println("UAPI: Processing get operation") + + // create lines + + lines := make([]string, 0, 100) + send := func(line string) { + lines = append(lines, line) + } + + func() { + + // lock required resources + + device.net.RLock() + defer device.net.RUnlock() + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + device.peers.RLock() + defer device.peers.RUnlock() + + // serialize device related values + + if !device.staticIdentity.privateKey.IsZero() { + send("private_key=" + device.staticIdentity.privateKey.ToHex()) + } + + if device.net.port != 0 { + send(fmt.Sprintf("listen_port=%d", device.net.port)) + } + + if device.net.fwmark != 0 { + send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) + } + + // serialize each peer state + + for _, peer := range device.peers.keyMap { + peer.RLock() + defer peer.RUnlock() + + send("public_key=" + peer.handshake.remoteStatic.ToHex()) + send("preshared_key=" + peer.handshake.presharedKey.ToHex()) + send("protocol_version=1") + if peer.endpoint != nil { + send("endpoint=" + peer.endpoint.DstToString()) + } + + nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) + secs := nano / time.Second.Nanoseconds() + nano %= time.Second.Nanoseconds() + + send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) + send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) + send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))) + send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))) + send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) + + for _, ip := range device.allowedips.EntriesForPeer(peer) { + send("allowed_ip=" + ip.String()) + } + + } + }() + + // send lines (does not require resource locks) + + for _, line := range lines { + _, err := socket.WriteString(line + "\n") + if err != nil { + return &IPCError{ipc.IpcErrorIO} + } + } + + return nil +} + +func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError { + scanner := bufio.NewScanner(socket) + logError := device.log.Error + logDebug := device.log.Debug + + var peer *Peer + + dummy := false + deviceConfig := true + + for scanner.Scan() { + + // parse line + + line := scanner.Text() + if line == "" { + return nil + } + parts := strings.Split(line, "=") + if len(parts) != 2 { + return &IPCError{ipc.IpcErrorProtocol} + } + key := parts[0] + value := parts[1] + + /* device configuration */ + + if deviceConfig { + + switch key { + case "private_key": + var sk NoisePrivateKey + err := sk.FromHex(value) + if err != nil { + logError.Println("Failed to set private_key:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + logDebug.Println("UAPI: Updating private key") + device.SetPrivateKey(sk) + + case "listen_port": + + // parse port number + + port, err := strconv.ParseUint(value, 10, 16) + if err != nil { + logError.Println("Failed to parse listen_port:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + // update port and rebind + + logDebug.Println("UAPI: Updating listen port") + + device.net.Lock() + device.net.port = uint16(port) + device.net.Unlock() + + if err := device.BindUpdate(); err != nil { + logError.Println("Failed to set listen_port:", err) + return &IPCError{ipc.IpcErrorPortInUse} + } + + case "fwmark": + + // parse fwmark field + + fwmark, err := func() (uint32, error) { + if value == "" { + return 0, nil + } + mark, err := strconv.ParseUint(value, 10, 32) + return uint32(mark), err + }() + + if err != nil { + logError.Println("Invalid fwmark", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + logDebug.Println("UAPI: Updating fwmark") + + if err := device.BindSetMark(uint32(fwmark)); err != nil { + logError.Println("Failed to update fwmark:", err) + return &IPCError{ipc.IpcErrorPortInUse} + } + + case "public_key": + // switch to peer configuration + logDebug.Println("UAPI: Transition to peer configuration") + deviceConfig = false + + case "replace_peers": + if value != "true" { + logError.Println("Failed to set replace_peers, invalid value:", value) + return &IPCError{ipc.IpcErrorInvalid} + } + logDebug.Println("UAPI: Removing all peers") + device.RemoveAllPeers() + + default: + logError.Println("Invalid UAPI device key:", key) + return &IPCError{ipc.IpcErrorInvalid} + } + } + + /* peer configuration */ + + if !deviceConfig { + + switch key { + + case "public_key": + var publicKey NoisePublicKey + err := publicKey.FromHex(value) + if err != nil { + logError.Println("Failed to get peer by public key:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + // ignore peer with public key of device + + device.staticIdentity.RLock() + dummy = device.staticIdentity.publicKey.Equals(publicKey) + device.staticIdentity.RUnlock() + + if dummy { + peer = &Peer{} + } else { + peer = device.LookupPeer(publicKey) + } + + if peer == nil { + peer, err = device.NewPeer(publicKey) + if err != nil { + logError.Println("Failed to create new peer:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + logDebug.Println(peer, "- UAPI: Created") + } + + case "remove": + + // remove currently selected peer from device + + if value != "true" { + logError.Println("Failed to set remove, invalid value:", value) + return &IPCError{ipc.IpcErrorInvalid} + } + if !dummy { + logDebug.Println(peer, "- UAPI: Removing") + device.RemovePeer(peer.handshake.remoteStatic) + } + peer = &Peer{} + dummy = true + + case "preshared_key": + + // update PSK + + logDebug.Println(peer, "- UAPI: Updating preshared key") + + peer.handshake.mutex.Lock() + err := peer.handshake.presharedKey.FromHex(value) + peer.handshake.mutex.Unlock() + + if err != nil { + logError.Println("Failed to set preshared key:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + case "endpoint": + + // set endpoint destination + + logDebug.Println(peer, "- UAPI: Updating endpoint") + + err := func() error { + peer.Lock() + defer peer.Unlock() + endpoint, err := CreateEndpoint(value) + if err != nil { + return err + } + peer.endpoint = endpoint + return nil + }() + + if err != nil { + logError.Println("Failed to set endpoint:", value) + return &IPCError{ipc.IpcErrorInvalid} + } + + case "persistent_keepalive_interval": + + // update persistent keepalive interval + + logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval") + + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { + logError.Println("Failed to set persistent keepalive interval:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + old := peer.persistentKeepaliveInterval + peer.persistentKeepaliveInterval = uint16(secs) + + // send immediate keepalive if we're turning it on and before it wasn't on + + if old == 0 && secs != 0 { + if err != nil { + logError.Println("Failed to get tun device status:", err) + return &IPCError{ipc.IpcErrorIO} + } + if device.isUp.Get() && !dummy { + peer.SendKeepalive() + } + } + + case "replace_allowed_ips": + + logDebug.Println(peer, "- UAPI: Removing all allowedips") + + if value != "true" { + logError.Println("Failed to replace allowedips, invalid value:", value) + return &IPCError{ipc.IpcErrorInvalid} + } + + if dummy { + continue + } + + device.allowedips.RemoveByPeer(peer) + + case "allowed_ip": + + logDebug.Println(peer, "- UAPI: Adding allowedip") + + _, network, err := net.ParseCIDR(value) + if err != nil { + logError.Println("Failed to set allowed ip:", err) + return &IPCError{ipc.IpcErrorInvalid} + } + + if dummy { + continue + } + + ones, _ := network.Mask.Size() + device.allowedips.Insert(network.IP, uint(ones), peer) + + case "protocol_version": + + if value != "1" { + logError.Println("Invalid protocol version:", value) + return &IPCError{ipc.IpcErrorInvalid} + } + + default: + logError.Println("Invalid UAPI peer key:", key) + return &IPCError{ipc.IpcErrorInvalid} + } + } + } + + return nil +} + +func (device *Device) IpcHandle(socket net.Conn) { + + // create buffered read/writer + + defer socket.Close() + + buffered := func(s io.ReadWriter) *bufio.ReadWriter { + reader := bufio.NewReader(s) + writer := bufio.NewWriter(s) + return bufio.NewReadWriter(reader, writer) + }(socket) + + defer buffered.Flush() + + op, err := buffered.ReadString('\n') + if err != nil { + return + } + + // handle operation + + var status *IPCError + + switch op { + case "set=1\n": + device.log.Debug.Println("UAPI: Set operation") + status = device.IpcSetOperation(buffered.Reader) + + case "get=1\n": + device.log.Debug.Println("UAPI: Get operation") + status = device.IpcGetOperation(buffered.Writer) + + default: + device.log.Error.Println("Invalid UAPI operation:", op) + return + } + + // write status + + if status != nil { + device.log.Error.Println(status) + fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") + } +} diff --git a/device/version.go b/device/version.go new file mode 100644 index 0000000..9077cdc --- /dev/null +++ b/device/version.go @@ -0,0 +1,3 @@ +package device + +const WireGuardGoVersion = "0.0.20181222" -- cgit v1.2.3