diff options
-rw-r--r-- | Makefile | 3 | ||||
-rw-r--r-- | allowedips.go (renamed from trie.go) | 110 | ||||
-rw-r--r-- | allowedips_rand_test.go (renamed from trie_rand_test.go) | 16 | ||||
-rw-r--r-- | allowedips_test.go (renamed from trie_test.go) | 26 | ||||
-rw-r--r-- | device.go | 4 | ||||
-rw-r--r-- | keypair.go | 2 | ||||
-rw-r--r-- | logger.go | 2 | ||||
-rw-r--r-- | noise-helpers.go | 3 | ||||
-rw-r--r-- | noise-types.go | 2 | ||||
-rw-r--r-- | peer.go | 24 | ||||
-rw-r--r-- | receive.go | 26 | ||||
-rw-r--r-- | routing.go | 70 | ||||
-rw-r--r-- | tun_darwin.go | 4 | ||||
-rw-r--r-- | tun_linux.go | 1 | ||||
-rw-r--r-- | tun_windows.go | 4 | ||||
-rw-r--r-- | uapi.go | 6 |
16 files changed, 139 insertions, 164 deletions
@@ -6,7 +6,4 @@ wireguard-go: $(wildcard *.go) clean: rm -f wireguard-go -cloc: - cloc $(filter-out xchacha20.go $(wildcard *_test.go), $(wildcard *.go)) - .PHONY: clean cloc @@ -8,21 +8,12 @@ package main import ( "errors" "net" + "sync" ) -/* Binary trie - * - * The net.IPs used here are not formatted the - * same way as those created by the "net" functions. - * Here the IPs are slices of either 4 or 16 byte (not always 16) - * - * Synchronization done separately - * See: routing.go - */ - -type Trie struct { +type trieEntry struct { cidr uint - child [2]*Trie + child [2]*trieEntry bits []byte peer *Peer @@ -90,15 +81,15 @@ func commonBits(ip1 []byte, ip2 []byte) uint { return i * 8 } -func (node *Trie) RemovePeer(p *Peer) *Trie { +func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { if node == nil { return node } // walk recursively - node.child[0] = node.child[0].RemovePeer(p) - node.child[1] = node.child[1].RemovePeer(p) + node.child[0] = node.child[0].removeByPeer(p) + node.child[1] = node.child[1].removeByPeer(p) if node.peer != p { return node @@ -113,16 +104,16 @@ func (node *Trie) RemovePeer(p *Peer) *Trie { return node.child[0] } -func (node *Trie) choose(ip net.IP) byte { +func (node *trieEntry) choose(ip net.IP) byte { return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 } -func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { +func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { // at leaf if node == nil { - return &Trie{ + return &trieEntry{ bits: ip, peer: peer, cidr: cidr, @@ -140,13 +131,13 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { return node } bit := node.choose(ip) - node.child[bit] = node.child[bit].Insert(ip, cidr, peer) + node.child[bit] = node.child[bit].insert(ip, cidr, peer) return node } // split node - newNode := &Trie{ + newNode := &trieEntry{ bits: ip, peer: peer, cidr: cidr, @@ -166,7 +157,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { // create new parent for node & newNode - parent := &Trie{ + parent := &trieEntry{ bits: ip, peer: nil, cidr: cidr, @@ -181,7 +172,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { return parent } -func (node *Trie) Lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip net.IP) *Peer { var found *Peer size := uint(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { @@ -197,16 +188,7 @@ func (node *Trie) Lookup(ip net.IP) *Peer { return found } -func (node *Trie) Count() uint { - if node == nil { - return 0 - } - l := node.child[0].Count() - r := node.child[1].Count() - return l + r -} - -func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { +func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { if node == nil { return results } @@ -223,11 +205,69 @@ func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) []net.IPNet { } else if len(node.bits) == net.IPv6len { mask.IP = node.bits } else { - panic(errors.New("bug: unexpected address length")) + panic(errors.New("unexpected address length")) } results = append(results, mask) } - results = node.child[0].AllowedIPs(p, results) - results = node.child[1].AllowedIPs(p, results) + results = node.child[0].entriesForPeer(p, results) + results = node.child[1].entriesForPeer(p, results) return results } + +type AllowedIPs struct { + IPv4 *trieEntry + IPv6 *trieEntry + mutex sync.RWMutex +} + +func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { + table.mutex.RLock() + defer table.mutex.RUnlock() + + allowed := make([]net.IPNet, 0, 10) + allowed = table.IPv4.entriesForPeer(peer, allowed) + allowed = table.IPv6.entriesForPeer(peer, allowed) + return allowed +} + +func (table *AllowedIPs) Reset() { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = nil + table.IPv6 = nil +} + +func (table *AllowedIPs) RemoveByPeer(peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + table.IPv4 = table.IPv4.removeByPeer(peer) + table.IPv6 = table.IPv6.removeByPeer(peer) +} + +func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { + table.mutex.Lock() + defer table.mutex.Unlock() + + switch len(ip) { + case net.IPv6len: + table.IPv6 = table.IPv6.insert(ip, cidr, peer) + case net.IPv4len: + table.IPv4 = table.IPv4.insert(ip, cidr, peer) + default: + panic(errors.New("inserting unknown address type")) + } +} + +func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv4.lookup(address) +} + +func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { + table.mutex.RLock() + defer table.mutex.RUnlock() + return table.IPv6.lookup(address) +} diff --git a/trie_rand_test.go b/allowedips_rand_test.go index 157c270..6ec039d 100644 --- a/trie_rand_test.go +++ b/allowedips_rand_test.go @@ -65,7 +65,7 @@ func (r SlowRouter) Lookup(addr []byte) *Peer { } func TestTrieRandomIPv4(t *testing.T) { - var trie *Trie + var trie *trieEntry var slow SlowRouter var peers []*Peer @@ -82,7 +82,7 @@ func TestTrieRandomIPv4(t *testing.T) { rand.Read(addr[:]) cidr := uint(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % NumberOfPeers - trie = trie.Insert(addr[:], cidr, peers[index]) + trie = trie.insert(addr[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index]) } @@ -90,15 +90,15 @@ func TestTrieRandomIPv4(t *testing.T) { var addr [AddressLength]byte rand.Read(addr[:]) peer1 := slow.Lookup(addr[:]) - peer2 := trie.Lookup(addr[:]) + peer2 := trie.lookup(addr[:]) if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) + t.Error("trieEntry did not match naive implementation, for:", addr) } } } func TestTrieRandomIPv6(t *testing.T) { - var trie *Trie + var trie *trieEntry var slow SlowRouter var peers []*Peer @@ -115,7 +115,7 @@ func TestTrieRandomIPv6(t *testing.T) { rand.Read(addr[:]) cidr := uint(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % NumberOfPeers - trie = trie.Insert(addr[:], cidr, peers[index]) + trie = trie.insert(addr[:], cidr, peers[index]) slow = slow.Insert(addr[:], cidr, peers[index]) } @@ -123,9 +123,9 @@ func TestTrieRandomIPv6(t *testing.T) { var addr [AddressLength]byte rand.Read(addr[:]) peer1 := slow.Lookup(addr[:]) - peer2 := trie.Lookup(addr[:]) + peer2 := trie.lookup(addr[:]) if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) + t.Error("trieEntry did not match naive implementation, for:", addr) } } } diff --git a/trie_test.go b/allowedips_test.go index 3c3b5ba..7b73af3 100644 --- a/trie_test.go +++ b/allowedips_test.go @@ -31,7 +31,7 @@ type testPairTrieLookup struct { peer *Peer } -func printTrie(t *testing.T, p *Trie) { +func printTrie(t *testing.T, p *trieEntry) { if p == nil { return } @@ -63,7 +63,7 @@ func TestCommonBits(t *testing.T) { } func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { - var trie *Trie + var trie *trieEntry var peers []*Peer rand.Seed(1) @@ -79,13 +79,13 @@ func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *test rand.Read(addr[:]) cidr := uint(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % peerNumber - trie = trie.Insert(addr[:], cidr, peers[index]) + trie = trie.insert(addr[:], cidr, peers[index]) } for n := 0; n < b.N; n += 1 { var addr [AddressLength]byte rand.Read(addr[:]) - trie.Lookup(addr[:]) + trie.lookup(addr[:]) } } @@ -117,21 +117,21 @@ func TestTrieIPv4(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *Trie + var trie *trieEntry insert := func(peer *Peer, a, b, c, d byte, cidr uint) { - trie = trie.Insert([]byte{a, b, c, d}, cidr, peer) + trie = trie.insert([]byte{a, b, c, d}, cidr, peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.Lookup([]byte{a, b, c, d}) + p := trie.lookup([]byte{a, b, c, d}) if p != peer { t.Error("Assert EQ failed") } } assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.Lookup([]byte{a, b, c, d}) + p := trie.lookup([]byte{a, b, c, d}) if p == peer { t.Error("Assert NEQ failed") } @@ -173,7 +173,7 @@ func TestTrieIPv4(t *testing.T) { assertEQ(a, 192, 0, 0, 0) assertEQ(a, 255, 0, 0, 0) - trie = trie.RemovePeer(a) + trie = trie.removeByPeer(a) assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0) @@ -186,7 +186,7 @@ func TestTrieIPv4(t *testing.T) { insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 24) - trie = trie.RemovePeer(a) + trie = trie.removeByPeer(a) assertNEQ(a, 192, 168, 0, 1) } @@ -204,7 +204,7 @@ func TestTrieIPv6(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *Trie + var trie *trieEntry expand := func(a uint32) []byte { var out [4]byte @@ -221,7 +221,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - trie = trie.Insert(addr, cidr, peer) + trie = trie.insert(addr, cidr, peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { @@ -230,7 +230,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - p := trie.Lookup(addr) + p := trie.lookup(addr) if p != peer { t.Error("Assert EQ failed") } @@ -46,7 +46,7 @@ type Device struct { routing struct { mutex sync.RWMutex - table RoutingTable + table AllowedIPs } peers struct { @@ -95,7 +95,7 @@ func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { // stop routing and processing of packets - device.routing.table.RemovePeer(peer) + device.routing.table.RemoveByPeer(peer) peer.Stop() // remove from peer map @@ -33,7 +33,7 @@ type Keypairs struct { mutex sync.RWMutex current *Keypair previous *Keypair - next *Keypair // not yet "confirmed by transport" + next *Keypair } func (kp *Keypairs) Current() *Keypair { @@ -40,7 +40,7 @@ func NewLogger(level int, prepend string) *Logger { logger.Debug = log.New(logDebug, "DEBUG: "+prepend, - log.Ldate|log.Ltime|log.Lshortfile, + log.Ldate|log.Ltime, ) logger.Info = log.New(logInfo, diff --git a/noise-helpers.go b/noise-helpers.go index 6e23d83..63e45b3 100644 --- a/noise-helpers.go +++ b/noise-helpers.go @@ -71,14 +71,13 @@ func isZero(val []byte) bool { return acc == 1 } +/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */ func setZero(arr []byte) { for i := range arr { arr[i] = 0 } } -/* curve25519 wrappers */ - func newPrivateKey() (sk NoisePrivateKey, err error) { // clamping: https://cr.yp.to/ecdh.html _, err = rand.Read(sk[:]) diff --git a/noise-types.go b/noise-types.go index 58aa0c2..2635e01 100644 --- a/noise-types.go +++ b/noise-types.go @@ -30,7 +30,7 @@ func loadExactHex(dst []byte, src string) error { return err } if len(slice) != len(dst) { - return errors.New("Hex string does not fit the slice") + return errors.New("hex string does not fit the slice") } copy(dst, slice) return nil @@ -61,7 +61,7 @@ type Peer struct { mutex sync.Mutex // held when stopping / starting routines starting sync.WaitGroup // routines pending start stopping sync.WaitGroup // routines pending stop - stop chan struct{} // size 0, stop all go-routines in peer + stop chan struct{} // size 0, stop all go routines in peer } mac CookieGenerator @@ -70,7 +70,7 @@ type Peer struct { func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { if device.isClosed.Get() { - return nil, errors.New("Device closed") + return nil, errors.New("device closed") } // lock resources @@ -87,7 +87,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // check if over limit if len(device.peers.keyMap) >= MaxPeers { - return nil, errors.New("Too many peers") + return nil, errors.New("too many peers") } // create peer @@ -104,7 +104,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { _, ok := device.peers.keyMap[pk] if ok { - return nil, errors.New("Adding existing peer") + return nil, errors.New("adding existing peer") } device.peers.keyMap[pk] = peer @@ -134,26 +134,26 @@ func (peer *Peer) SendBuffer(buffer []byte) error { defer peer.device.net.mutex.RUnlock() if peer.device.net.bind == nil { - return errors.New("No bind") + return errors.New("no bind") } peer.mutex.RLock() defer peer.mutex.RUnlock() if peer.endpoint == nil { - return errors.New("No known endpoint for peer") + return errors.New("no known endpoint for peer") } return peer.device.net.bind.Send(buffer, peer.endpoint) } -/* Returns a short string identifier for logging - */ func (peer *Peer) String() string { - return fmt.Sprintf( - "peer(%s)", - base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), - ) + base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) + abbreviatedKey := "invalid" + if len(base64Key) == 44 { + abbreviatedKey = base64Key[0:4] + "..." + base64Key[40:44] + } + return fmt.Sprintf("peer(%s)", abbreviatedKey) } func (peer *Peer) Start() { @@ -600,20 +600,24 @@ func (peer *Peer) RoutineSequentialReceiver() { // check if using new key-pair kp := &peer.keypairs - kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true if kp.next == elem.keypair { - old := kp.previous - kp.previous = kp.current - device.DeleteKeypair(old) - kp.current = kp.next - kp.next = nil - peer.timersHandshakeComplete() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: + kp.mutex.Lock() + if kp.next != elem.keypair { + kp.mutex.Unlock() + } else { + old := kp.previous + kp.previous = kp.current + device.DeleteKeypair(old) + kp.current = kp.next + kp.next = nil + kp.mutex.Unlock() + peer.timersHandshakeComplete() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } } } - kp.mutex.Unlock() peer.keepKeyFreshReceiving() peer.timersAnyAuthenticatedPacketTraversal() diff --git a/routing.go b/routing.go deleted file mode 100644 index 77c9b1e..0000000 --- a/routing.go +++ /dev/null @@ -1,70 +0,0 @@ -/* SPDX-License-Identifier: GPL-2.0 - * - * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. - */ - -package main - -import ( - "errors" - "net" - "sync" -) - -type RoutingTable struct { - IPv4 *Trie - IPv6 *Trie - mutex sync.RWMutex -} - -func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet { - table.mutex.RLock() - defer table.mutex.RUnlock() - - allowed := make([]net.IPNet, 0, 10) - allowed = table.IPv4.AllowedIPs(peer, allowed) - allowed = table.IPv6.AllowedIPs(peer, allowed) - return allowed -} - -func (table *RoutingTable) Reset() { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = nil - table.IPv6 = nil -} - -func (table *RoutingTable) RemovePeer(peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = table.IPv4.RemovePeer(peer) - table.IPv6 = table.IPv6.RemovePeer(peer) -} - -func (table *RoutingTable) Insert(ip net.IP, cidr uint, peer *Peer) { - table.mutex.Lock() - defer table.mutex.Unlock() - - switch len(ip) { - case net.IPv6len: - table.IPv6 = table.IPv6.Insert(ip, cidr, peer) - case net.IPv4len: - table.IPv4 = table.IPv4.Insert(ip, cidr, peer) - default: - panic(errors.New("Inserting unknown address type")) - } -} - -func (table *RoutingTable) LookupIPv4(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv4.Lookup(address) -} - -func (table *RoutingTable) LookupIPv6(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv6.Lookup(address) -} diff --git a/tun_darwin.go b/tun_darwin.go index 1d66c66..fa8efe0 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -224,7 +224,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { } func (tun *NativeTun) Close() error { - return tun.fd.Close() + err := tun.fd.Close() + close(tun.events) + return err } func (tun *NativeTun) setMTU(n int) error { diff --git a/tun_linux.go b/tun_linux.go index 18994cc..9f60d2b 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -392,6 +392,7 @@ func (tun *NativeTun) Close() error { return err } tun.closingWriter.Write([]byte{0}) + close(tun.events) return nil } diff --git a/tun_windows.go b/tun_windows.go index c0c9ff8..6eea5a3 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -125,7 +125,9 @@ func (f *NativeTUN) Events() chan TUNEvent { } func (f *NativeTUN) Close() error { - return windows.Close(f.fd) + close(f.events) + err := windows.Close(f.fd) + return err } func (f *NativeTUN) Write(b []byte) (int, error) { @@ -91,7 +91,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes)) send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) - for _, ip := range device.routing.table.AllowedIPs(peer) { + for _, ip := range device.routing.table.EntriesForPeer(peer) { send("allowed_ip=" + ip.String()) } @@ -337,7 +337,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { case "replace_allowed_ips": - logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer) + logDebug.Println("UAPI: Removing all allowed EntriesForPeer for peer:", peer) if value != "true" { logError.Println("Failed to set replace_allowed_ips, invalid value:", value) @@ -349,7 +349,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { } device.routing.mutex.Lock() - device.routing.table.RemovePeer(peer) + device.routing.table.RemoveByPeer(peer) device.routing.mutex.Unlock() case "allowed_ip": |