summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2017-04-14 18:51:15 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2017-04-21 04:31:26 +0200
commitf2c9cd05c8a9f96e6f4073af13f45960a2b29ab1 (patch)
tree801b5c680f531eb5e9be1965d8b32371e534dda9 /src
parent08c09a63fd766e3557c8e63acf3dc90c86549a1e (diff)
routingtable: rewrite core functions
When removing by peer, prev needs to be set to *nptr in order to traverse that part of the trie. The other remove by IP function can simply be removed, as it's not in use. The root freeing function can use pre-order traversal instead of post-order. The pre-order traversal code in general is now a nice iterator macro. The common bits function can use the fast fls instructions and the match function can be rewritten to simply compare common bits. While we're at it, let's add tons of new tests, randomized checking against a dumb implementation, and graphviz output. And in general, it's nice to clean things up. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'src')
-rw-r--r--src/config.c17
-rw-r--r--src/routingtable.c411
-rw-r--r--src/routingtable.h8
-rw-r--r--src/selftest/routing-table.h133
-rw-r--r--src/selftest/routingtable.h504
5 files changed, 634 insertions, 439 deletions
diff --git a/src/config.c b/src/config.c
index 2736377..a5a25c9 100644
--- a/src/config.c
+++ b/src/config.c
@@ -209,20 +209,6 @@ static inline int use_data(struct data_remaining *data, size_t size)
return 0;
}
-static int calculate_ipmasks_size(void *ctx, struct wireguard_peer *peer, union nf_inet_addr ip, u8 cidr, int family)
-{
- size_t *count = ctx;
- *count += sizeof(struct wgipmask);
- return 0;
-}
-
-static size_t calculate_peers_size(struct wireguard_device *wg)
-{
- size_t len = peer_total_count(wg) * sizeof(struct wgpeer);
- routing_table_walk_ips(&wg->peer_routing_table, &len, calculate_ipmasks_size);
- return len;
-}
-
static int populate_ipmask(void *ctx, union nf_inet_addr ip, u8 cidr, int family)
{
int ret;
@@ -305,7 +291,8 @@ int config_get_device(struct wireguard_device *wg, void __user *user_device)
mutex_lock(&wg->device_update_lock);
if (!user_device) {
- ret = calculate_peers_size(wg);
+ ret = peer_total_count(wg) * sizeof(struct wgpeer)
+ + routing_table_count_nodes(&wg->peer_routing_table) * sizeof(struct wgipmask);
goto out;
}
diff --git a/src/routingtable.c b/src/routingtable.c
index 1de7727..f9c3eff 100644
--- a/src/routingtable.c
+++ b/src/routingtable.c
@@ -7,16 +7,10 @@ struct routing_table_node {
struct routing_table_node __rcu *bit[2];
struct rcu_head rcu;
struct wireguard_peer *peer;
- u8 cidr;
- u8 bit_at_a, bit_at_b;
- bool incidental;
- u8 bits[];
+ u8 cidr, bit_at_a, bit_at_b;
+ u8 bits[] __aligned(__alignof__(u64));
};
-static inline u8 bit_at(const u8 *key, u8 a, u8 b)
-{
- return (key[a] >> b) & 1;
-}
static inline void copy_and_assign_cidr(struct routing_table_node *node, const u8 *src, u8 cidr)
{
memcpy(node->bits, src, (cidr + 7) / 8);
@@ -25,67 +19,77 @@ static inline void copy_and_assign_cidr(struct routing_table_node *node, const u
node->bit_at_a = cidr / 8;
node->bit_at_b = 7 - (cidr % 8);
}
+#define choose_node(parent, key) parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
-/* Non-recursive RCU expansion of:
- *
- * free_node(node)
- * {
- * if (!node)
- * return;
- * free_node(node->bit[0]);
- * free_node(node->bit[1]);
- * kfree_rcu_bh(node);
- * }
- */
static void node_free_rcu(struct rcu_head *rcu)
{
kfree(container_of(rcu, struct routing_table_node, rcu));
}
-#define ref(p) rcu_access_pointer(p)
-#define push(p) do { BUG_ON(len >= 128); stack[len++] = rcu_dereference_protected(p, lockdep_is_held(lock)); } while (0)
-static void free_node(struct routing_table_node *top, struct mutex *lock)
+#define push(p, lock) ({ \
+ if (rcu_access_pointer(p)) { \
+ BUG_ON(len >= 128); \
+ stack[len++] = lock ? rcu_dereference_protected(p, lockdep_is_held((struct mutex *)lock)) : rcu_dereference_bh(p); \
+ } \
+ true; \
+})
+#define walk_prep \
+ struct routing_table_node *stack[128], *node; \
+ unsigned int len;
+#define walk(top, lock) for (len = 0, push(top, lock); len > 0 && (node = stack[--len]) && push(node->bit[0], lock) && push(node->bit[1], lock);)
+
+static void free_root_node(struct routing_table_node __rcu *top, struct mutex *lock)
{
- struct routing_table_node *stack[128];
- struct routing_table_node *node = NULL;
- struct routing_table_node *prev = NULL;
- unsigned int len = 0;
+ walk_prep;
+ walk (top, lock)
+ call_rcu_bh(&node->rcu, node_free_rcu);
+}
- if (!top)
- return;
+static size_t count_nodes(struct routing_table_node __rcu *top)
+{
+ size_t ret = 0;
+ walk_prep;
+ walk (top, NULL) {
+ if (node->peer)
+ ++ret;
+ }
+ return ret;
+}
- stack[len++] = top;
- while (len > 0) {
- node = stack[len - 1];
- if (!prev || ref(prev->bit[0]) == node || ref(prev->bit[1]) == node) {
- if (ref(node->bit[0]))
- push(node->bit[0]);
- else if (ref(node->bit[1]))
- push(node->bit[1]);
- } else if (ref(node->bit[0]) == prev) {
- if (ref(node->bit[1]))
- push(node->bit[1]);
- } else {
- call_rcu_bh(&node->rcu, node_free_rcu);
- --len;
- }
- prev = node;
+static int walk_ips_by_peer(struct routing_table_node __rcu *top, int family, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock)
+{
+ int ret;
+ union nf_inet_addr ip = { .all = { 0 } };
+ walk_prep;
+
+ if (unlikely(!peer))
+ return 0;
+
+ walk (top, maybe_lock) {
+ if (node->peer != peer)
+ continue;
+ memcpy(ip.all, node->bits, family == AF_INET6 ? 16 : 4);
+ ret = func(ctx, ip, node->cidr, family);
+ if (ret)
+ return ret;
}
+ return 0;
}
#undef push
-#define push(p) do { BUG_ON(len >= 128); stack[len++] = p; } while (0)
-static bool walk_remove_by_peer(struct routing_table_node __rcu **top, struct wireguard_peer *peer, struct mutex *lock)
+
+#define ref(p) rcu_access_pointer(p)
+#define deref(p) rcu_dereference_protected(*p, lockdep_is_held(lock))
+#define push(p) ({ BUG_ON(len >= 128); stack[len++] = p; })
+static void walk_remove_by_peer(struct routing_table_node __rcu **top, struct wireguard_peer *peer, struct mutex *lock)
{
- struct routing_table_node __rcu **stack[128];
- struct routing_table_node __rcu **nptr;
- struct routing_table_node *node = NULL;
- struct routing_table_node *prev = NULL;
- unsigned int len = 0;
- bool ret = false;
-
- stack[len++] = top;
- while (len > 0) {
+ struct routing_table_node __rcu **stack[128], **nptr, *node, *prev;
+ unsigned int len;
+
+ if (unlikely(!peer || !ref(*top)))
+ return;
+
+ for (prev = NULL, len = 0, push(top); len > 0; prev = node) {
nptr = stack[len - 1];
- node = rcu_dereference_protected(*nptr, lockdep_is_held(lock));
+ node = deref(nptr);
if (!node) {
--len;
continue;
@@ -100,111 +104,76 @@ static bool walk_remove_by_peer(struct routing_table_node __rcu **top, struct wi
push(&node->bit[1]);
} else {
if (node->peer == peer) {
- ret = true;
node->peer = NULL;
- node->incidental = true;
if (!node->bit[0] || !node->bit[1]) {
- /* collapse (even if both are null) */
- rcu_assign_pointer(*nptr, rcu_dereference_protected(node->bit[!node->bit[0]], lockdep_is_held(lock)));
- rcu_assign_pointer(node->bit[0], NULL);
- rcu_assign_pointer(node->bit[1], NULL);
- free_node(node, lock);
+ rcu_assign_pointer(*nptr, deref(&node->bit[!ref(node->bit[0])]));
+ call_rcu_bh(&node->rcu, node_free_rcu);
+ node = deref(nptr);
}
}
--len;
}
- prev = node;
}
-
- return ret;
}
#undef ref
+#undef deref
#undef push
-static inline bool match(const struct routing_table_node *node, const u8 *key, u8 match_len)
+static inline unsigned int fls128(u64 a, u64 b)
{
- u8 full_blocks_to_match = match_len / 8;
- u8 bits_leftover = match_len % 8;
- u8 mask;
- const u8 *a = node->bits, *b = key;
- if (memcmp(a, b, full_blocks_to_match))
- return false;
- if (!bits_leftover)
- return true;
- mask = ~(0xff >> bits_leftover);
- return (a[full_blocks_to_match] & mask) == (b[full_blocks_to_match] & mask);
+ return a ? fls64(a) + 64 : fls64(b);
}
-static inline u8 common_bits(const struct routing_table_node *node, const u8 *key, u8 match_len)
+static inline u8 common_bits(const struct routing_table_node *node, const u8 *key, u8 bits)
{
- u8 max = (((match_len > node->cidr) ? match_len : node->cidr) + 7) / 8;
- u8 bits = 0;
- u8 i, mask;
- const u8 *a = node->bits, *b = key;
- for (i = 0; i < max; ++i, bits += 8) {
- if (a[i] != b[i])
- break;
- }
- if (i == max)
- return bits;
- for (mask = 128; mask > 0; mask /= 2, ++bits) {
- if ((a[i] & mask) != (b[i] & mask))
- return bits;
- }
+ if (bits == 32)
+ return 32 - fls(be32_to_cpu(*(const __be32 *)node->bits ^ *(const __be32 *)key));
+ else if (bits == 128)
+ return 128 - fls128(be64_to_cpu(*(const __be64 *)&node->bits[0] ^ *(const __be64 *)&key[0]), be64_to_cpu(*(const __be64 *)&node->bits[8] ^ *(const __be64 *)&key[8]));
BUG();
- return bits;
-}
-
-static int remove(struct routing_table_node __rcu **trie, const u8 *key, u8 cidr, struct mutex *lock)
-{
- struct routing_table_node *parent = NULL, *node;
- node = rcu_dereference_protected(*trie, lockdep_is_held(lock));
- while (node && node->cidr <= cidr && match(node, key, node->cidr)) {
- if (node->cidr == cidr) {
- /* exact match */
- node->incidental = true;
- node->peer = NULL;
- if (!node->bit[0] || !node->bit[1]) {
- /* collapse (even if both are null) */
- if (parent)
- rcu_assign_pointer(parent->bit[bit_at(key, parent->bit_at_a, parent->bit_at_b)],
- rcu_dereference_protected(node->bit[(!node->bit[0]) ? 1 : 0], lockdep_is_held(lock)));
- rcu_assign_pointer(node->bit[0], NULL);
- rcu_assign_pointer(node->bit[1], NULL);
- free_node(node, lock);
- }
- return 0;
- }
- parent = node;
- node = rcu_dereference_protected(parent->bit[bit_at(key, parent->bit_at_a, parent->bit_at_b)], lockdep_is_held(lock));
- }
- return -ENOENT;
+ return 0;
}
static inline struct routing_table_node *find_node(struct routing_table_node *trie, u8 bits, const u8 *key)
{
struct routing_table_node *node = trie, *found = NULL;
- while (node && match(node, key, node->cidr)) {
- if (!node->incidental)
+
+ while (node && common_bits(node, key, bits) >= node->cidr) {
+ if (node->peer)
found = node;
if (node->cidr == bits)
break;
- node = rcu_dereference_bh(node->bit[bit_at(key, node->bit_at_a, node->bit_at_b)]);
+ node = rcu_dereference_bh(choose_node(node, key));
}
return found;
}
-static inline bool node_placement(struct routing_table_node __rcu *trie, const u8 *key, u8 cidr, struct routing_table_node **rnode, struct mutex *lock)
+/* Returns a strong reference to a peer */
+static inline struct wireguard_peer *lookup(struct routing_table_node __rcu *root, u8 bits, const void *ip)
+{
+ struct wireguard_peer *peer = NULL;
+ struct routing_table_node *node;
+
+ rcu_read_lock_bh();
+ node = find_node(rcu_dereference_bh(root), bits, ip);
+ if (node)
+ peer = peer_get(node->peer);
+ rcu_read_unlock_bh();
+ return peer;
+}
+
+static inline bool node_placement(struct routing_table_node __rcu *trie, const u8 *key, u8 cidr, u8 bits, struct routing_table_node **rnode, struct mutex *lock)
{
bool exact = false;
struct routing_table_node *parent = NULL, *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
- while (node && node->cidr <= cidr && match(node, key, node->cidr)) {
+
+ while (node && node->cidr <= cidr && common_bits(node, key, bits) >= node->cidr) {
parent = node;
if (parent->cidr == cidr) {
exact = true;
break;
}
- node = rcu_dereference_protected(parent->bit[bit_at(key, parent->bit_at_a, parent->bit_at_b)], lockdep_is_held(lock));
+ node = rcu_dereference_protected(choose_node(parent, key), lockdep_is_held(lock));
}
if (rnode)
*rnode = parent;
@@ -224,9 +193,7 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
rcu_assign_pointer(*trie, node);
return 0;
}
- if (node_placement(*trie, key, cidr, &node, lock)) {
- /* exact match */
- node->incidental = false;
+ if (node_placement(*trie, key, cidr, bits, &node, lock)) {
node->peer = peer;
return 0;
}
@@ -239,112 +206,40 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
if (!node)
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
- else
- down = rcu_dereference_protected(node->bit[bit_at(key, node->bit_at_a, node->bit_at_b)], lockdep_is_held(lock));
- if (!down) {
- rcu_assign_pointer(node->bit[bit_at(key, node->bit_at_a, node->bit_at_b)], newnode);
- return 0;
+ else {
+ down = rcu_dereference_protected(choose_node(node, key), lockdep_is_held(lock));
+ if (!down) {
+ rcu_assign_pointer(choose_node(node, key), newnode);
+ return 0;
+ }
}
- /* here we must be inserting between node and down */
- cidr = min(cidr, common_bits(down, key, cidr));
+ cidr = min(cidr, common_bits(down, key, bits));
parent = node;
- /* we either need to make a new branch above down and newnode
- * or newnode can be the branch. newnode can be the branch if
- * its cidr == bits_in_common */
if (newnode->cidr == cidr) {
- /* newnode can be the branch */
- rcu_assign_pointer(newnode->bit[bit_at(down->bits, newnode->bit_at_a, newnode->bit_at_b)], down);
+ rcu_assign_pointer(choose_node(newnode, down->bits), down);
if (!parent)
rcu_assign_pointer(*trie, newnode);
else
- rcu_assign_pointer(parent->bit[bit_at(newnode->bits, parent->bit_at_a, parent->bit_at_b)], newnode);
+ rcu_assign_pointer(choose_node(parent, newnode->bits), newnode);
} else {
- /* reparent */
node = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL);
if (!node) {
kfree(newnode);
return -ENOMEM;
}
- node->incidental = true;
copy_and_assign_cidr(node, newnode->bits, cidr);
- rcu_assign_pointer(node->bit[bit_at(down->bits, node->bit_at_a, node->bit_at_b)], down);
- rcu_assign_pointer(node->bit[bit_at(newnode->bits, node->bit_at_a, node->bit_at_b)], newnode);
+ rcu_assign_pointer(choose_node(node, down->bits), down);
+ rcu_assign_pointer(choose_node(node, newnode->bits), newnode);
if (!parent)
rcu_assign_pointer(*trie, node);
else
- rcu_assign_pointer(parent->bit[bit_at(node->bits, parent->bit_at_a, parent->bit_at_b)], node);
+ rcu_assign_pointer(choose_node(parent, node->bits), node);
}
return 0;
}
-#define push(p) do { \
- struct routing_table_node *next = (maybe_lock ? rcu_dereference_protected(p, lockdep_is_held(maybe_lock)) : rcu_dereference_bh(p)); \
- if (next) { \
- BUG_ON(len >= 128); \
- stack[len++] = next; \
- } \
-} while (0)
-static int walk_ips(struct routing_table_node *top, int family, void *ctx, int (*func)(void *ctx, struct wireguard_peer *peer, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock)
-{
- int ret;
- union nf_inet_addr ip = { .all = { 0 } };
- struct routing_table_node *stack[128];
- struct routing_table_node *node;
- unsigned int len = 0;
- struct wireguard_peer *peer;
-
- if (!top)
- return 0;
-
- stack[len++] = top;
- while (len > 0) {
- node = stack[--len];
-
- peer = peer_get(node->peer);
- if (peer) {
- memcpy(ip.all, node->bits, family == AF_INET6 ? 16 : 4);
- ret = func(ctx, peer, ip, node->cidr, family);
- peer_put(peer);
- if (ret)
- return ret;
- }
-
- push(node->bit[0]);
- push(node->bit[1]);
- }
- return 0;
-}
-static int walk_ips_by_peer(struct routing_table_node *top, int family, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock)
-{
- int ret;
- union nf_inet_addr ip = { .all = { 0 } };
- struct routing_table_node *stack[128];
- struct routing_table_node *node;
- unsigned int len = 0;
-
- if (!top)
- return 0;
-
- stack[len++] = top;
- while (len > 0) {
- node = stack[--len];
-
- if (node->peer == peer) {
- memcpy(ip.all, node->bits, family == AF_INET6 ? 16 : 4);
- ret = func(ctx, ip, node->cidr, family);
- if (ret)
- return ret;
- }
-
- push(node->bit[0]);
- push(node->bit[1]);
- }
- return 0;
-}
-#undef push
-
void routing_table_init(struct routing_table *table)
{
memset(table, 0, sizeof(struct routing_table));
@@ -354,9 +249,9 @@ void routing_table_init(struct routing_table *table)
void routing_table_free(struct routing_table *table)
{
mutex_lock(&table->table_update_lock);
- free_node(rcu_dereference_protected(table->root4, lockdep_is_held(&table->table_update_lock)), &table->table_update_lock);
+ free_root_node(table->root4, &table->table_update_lock);
rcu_assign_pointer(table->root4, NULL);
- free_node(rcu_dereference_protected(table->root6, lockdep_is_held(&table->table_update_lock)), &table->table_update_lock);
+ free_root_node(table->root6, &table->table_update_lock);
rcu_assign_pointer(table->root6, NULL);
mutex_unlock(&table->table_update_lock);
}
@@ -364,7 +259,7 @@ void routing_table_free(struct routing_table *table)
int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer)
{
int ret;
- if (cidr > 32)
+ if (unlikely(cidr > 32 || !peer))
return -EINVAL;
mutex_lock(&table->table_update_lock);
ret = add(&table->root4, 32, (const u8 *)ip, cidr, peer, &table->table_update_lock);
@@ -375,7 +270,7 @@ int routing_table_insert_v4(struct routing_table *table, const struct in_addr *i
int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer)
{
int ret;
- if (cidr > 128)
+ if (unlikely(cidr > 128 || !peer))
return -EINVAL;
mutex_lock(&table->table_update_lock);
ret = add(&table->root6, 128, (const u8 *)ip, cidr, peer, &table->table_update_lock);
@@ -383,73 +278,19 @@ int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *
return ret;
}
-/* Returns a strong reference to a peer */
-inline struct wireguard_peer *routing_table_lookup_v4(struct routing_table *table, const struct in_addr *ip)
+void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer)
{
- struct wireguard_peer *peer = NULL;
- struct routing_table_node *node;
-
- rcu_read_lock_bh();
- node = find_node(rcu_dereference_bh(table->root4), 32, (const u8 *)ip);
- if (node)
- peer = peer_get(node->peer);
- rcu_read_unlock_bh();
- return peer;
-}
-
-/* Returns a strong reference to a peer */
-inline struct wireguard_peer *routing_table_lookup_v6(struct routing_table *table, const struct in6_addr *ip)
-{
- struct wireguard_peer *peer = NULL;
- struct routing_table_node *node;
-
- rcu_read_lock_bh();
- node = find_node(rcu_dereference_bh(table->root6), 128, (const u8 *)ip);
- if (node)
- peer = peer_get(node->peer);
- rcu_read_unlock_bh();
- return peer;
-}
-
-int routing_table_remove_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr)
-{
- int ret;
- mutex_lock(&table->table_update_lock);
- ret = remove(&table->root4, (const u8 *)ip, cidr, &table->table_update_lock);
- mutex_unlock(&table->table_update_lock);
- return ret;
-}
-
-int routing_table_remove_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr)
-{
- int ret;
mutex_lock(&table->table_update_lock);
- ret = remove(&table->root6, (const u8 *)ip, cidr, &table->table_update_lock);
+ walk_remove_by_peer(&table->root4, peer, &table->table_update_lock);
+ walk_remove_by_peer(&table->root6, peer, &table->table_update_lock);
mutex_unlock(&table->table_update_lock);
- return ret;
}
-int routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer)
+size_t routing_table_count_nodes(struct routing_table *table)
{
- bool found;
- mutex_lock(&table->table_update_lock);
- found = walk_remove_by_peer(&table->root4, peer, &table->table_update_lock) | walk_remove_by_peer(&table->root6, peer, &table->table_update_lock);
- mutex_unlock(&table->table_update_lock);
- return found ? 0 : -EINVAL;
-}
-
-/* Calls func with a strong reference to each peer, before putting it when the function has completed.
- * It's thus up to the caller to call peer_put on it if it's going to be used elsewhere after or stored. */
-int routing_table_walk_ips(struct routing_table *table, void *ctx, int (*func)(void *ctx, struct wireguard_peer *peer, union nf_inet_addr ip, u8 cidr, int family))
-{
- int ret;
- rcu_read_lock_bh();
- ret = walk_ips(rcu_dereference_bh(table->root4), AF_INET, ctx, func, NULL);
- rcu_read_unlock_bh();
- if (ret)
- return ret;
+ size_t ret;
rcu_read_lock_bh();
- ret = walk_ips(rcu_dereference_bh(table->root6), AF_INET6, ctx, func, NULL);
+ ret = count_nodes(table->root4) + count_nodes(table->root6);
rcu_read_unlock_bh();
return ret;
}
@@ -458,12 +299,12 @@ int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struc
{
int ret;
rcu_read_lock_bh();
- ret = walk_ips_by_peer(rcu_dereference_bh(table->root4), AF_INET, ctx, peer, func, NULL);
+ ret = walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, NULL);
rcu_read_unlock_bh();
if (ret)
return ret;
rcu_read_lock_bh();
- ret = walk_ips_by_peer(rcu_dereference_bh(table->root6), AF_INET6, ctx, peer, func, NULL);
+ ret = walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, NULL);
rcu_read_unlock_bh();
return ret;
}
@@ -472,12 +313,12 @@ int routing_table_walk_ips_by_peer_sleepable(struct routing_table *table, void *
{
int ret;
mutex_lock(&table->table_update_lock);
- ret = walk_ips_by_peer(rcu_dereference_protected(table->root4, lockdep_is_held(&table->table_update_lock)), AF_INET, ctx, peer, func, &table->table_update_lock);
+ ret = walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, &table->table_update_lock);
mutex_unlock(&table->table_update_lock);
if (ret)
return ret;
mutex_lock(&table->table_update_lock);
- ret = walk_ips_by_peer(rcu_dereference_protected(table->root6, lockdep_is_held(&table->table_update_lock)), AF_INET6, ctx, peer, func, &table->table_update_lock);
+ ret = walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, &table->table_update_lock);
mutex_unlock(&table->table_update_lock);
return ret;
}
@@ -499,9 +340,9 @@ struct wireguard_peer *routing_table_lookup_dst(struct routing_table *table, str
if (unlikely(!has_valid_ip_header(skb)))
return NULL;
if (ip_hdr(skb)->version == 4)
- return routing_table_lookup_v4(table, (struct in_addr *)&ip_hdr(skb)->daddr);
+ return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
else if (ip_hdr(skb)->version == 6)
- return routing_table_lookup_v6(table, &ipv6_hdr(skb)->daddr);
+ return lookup(table->root6, 128, &ipv6_hdr(skb)->daddr);
return NULL;
}
@@ -511,10 +352,10 @@ struct wireguard_peer *routing_table_lookup_src(struct routing_table *table, str
if (unlikely(!has_valid_ip_header(skb)))
return NULL;
if (ip_hdr(skb)->version == 4)
- return routing_table_lookup_v4(table, (struct in_addr *)&ip_hdr(skb)->saddr);
+ return lookup(table->root4, 32, &ip_hdr(skb)->saddr);
else if (ip_hdr(skb)->version == 6)
- return routing_table_lookup_v6(table, &ipv6_hdr(skb)->saddr);
+ return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr);
return NULL;
}
-#include "selftest/routing-table.h"
+#include "selftest/routingtable.h"
diff --git a/src/routingtable.h b/src/routingtable.h
index adcc632..4fdf410 100644
--- a/src/routingtable.h
+++ b/src/routingtable.h
@@ -20,16 +20,12 @@ void routing_table_init(struct routing_table *table);
void routing_table_free(struct routing_table *table);
int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer);
int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer);
-int routing_table_remove_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr);
-int routing_table_remove_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr);
-int routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer);
-int routing_table_walk_ips(struct routing_table *table, void *ctx, int (*func)(void *ctx, struct wireguard_peer *peer, union nf_inet_addr ip, u8 cidr, int family));
+void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer);
+size_t routing_table_count_nodes(struct routing_table *table);
int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family));
int routing_table_walk_ips_by_peer_sleepable(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family));
/* These return a strong reference to a peer: */
-struct wireguard_peer *routing_table_lookup_v4(struct routing_table *table, const struct in_addr *ip);
-struct wireguard_peer *routing_table_lookup_v6(struct routing_table *table, const struct in6_addr *ip);
struct wireguard_peer *routing_table_lookup_dst(struct routing_table *table, struct sk_buff *skb);
struct wireguard_peer *routing_table_lookup_src(struct routing_table *table, struct sk_buff *skb);
diff --git a/src/selftest/routing-table.h b/src/selftest/routing-table.h
deleted file mode 100644
index a603401..0000000
--- a/src/selftest/routing-table.h
+++ /dev/null
@@ -1,133 +0,0 @@
-/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
-
-#ifdef DEBUG
-static inline struct in_addr *ip4(u8 a, u8 b, u8 c, u8 d)
-{
- static struct in_addr ip;
- u8 *split = (u8 *)&ip;
- split[0] = a;
- split[1] = b;
- split[2] = c;
- split[3] = d;
- return &ip;
-}
-static inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
-{
- static struct in6_addr ip;
- __be32 *split = (__be32 *)&ip;
- split[0] = cpu_to_be32(a);
- split[1] = cpu_to_be32(b);
- split[2] = cpu_to_be32(c);
- split[3] = cpu_to_be32(d);
- return &ip;
-}
-
-bool routing_table_selftest(void)
-{
- struct routing_table t;
- struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, *e = NULL, *f = NULL, *g = NULL, *h = NULL;
- size_t i = 0;
- bool success = false;
- struct in6_addr ip;
- __be64 part;
-
- routing_table_init(&t);
-#define init_peer(name) do { name = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL); if (!name) goto free; kref_init(&name->refcount); } while (0)
- init_peer(a);
- init_peer(b);
- init_peer(c);
- init_peer(d);
- init_peer(e);
- init_peer(f);
- init_peer(g);
- init_peer(h);
-#undef init_peer
-
-#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) routing_table_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem)
- insert(4, a, 192, 168, 4, 0, 24);
- insert(4, b, 192, 168, 4, 4, 32);
- insert(4, c, 192, 168, 0, 0, 16);
- insert(4, d, 192, 95, 5, 64, 27);
- insert(4, c, 192, 95, 5, 65, 27); /* replaces previous entry, and maskself is required */
- insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
- insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64);
- insert(4, e, 0, 0, 0, 0, 0);
- insert(6, e, 0, 0, 0, 0, 0);
- insert(6, f, 0, 0, 0, 0, 0); /* replaces previous entry */
- insert(6, g, 0x24046800, 0, 0, 0, 32);
- insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64); /* maskself is required */
- insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128);
- insert(4, g, 64, 15, 112, 0, 20);
- insert(4, h, 64, 15, 123, 211, 25); /* maskself is required */
- insert(4, a, 10, 0, 0, 0, 25);
- insert(4, b, 10, 0, 0, 128, 25);
- insert(4, a, 10, 1, 0, 0, 30);
- insert(4, b, 10, 1, 0, 4, 30);
- insert(4, c, 10, 1, 0, 8, 29);
- insert(4, d, 10, 1, 0, 16, 29);
-#undef insert
-
- success = true;
-#define test(version, mem, ipa, ipb, ipc, ipd) do { \
- bool _s = routing_table_lookup_v##version(&t, ip##version(ipa, ipb, ipc, ipd)) == mem; \
- ++i; \
- if (!_s) { \
- pr_info("routing table self-test %zu: FAIL\n", i); \
- success = false; \
- } \
-} while (0)
- test(4, a, 192, 168, 4, 20);
- test(4, a, 192, 168, 4, 0);
- test(4, b, 192, 168, 4, 4);
- test(4, c, 192, 168, 200, 182);
- test(4, c, 192, 95, 5, 68);
- test(4, e, 192, 95, 5, 96);
- test(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543);
- test(6, c, 0x26075300, 0x60006b00, 0, 0xc02e01ee);
- test(6, f, 0x26075300, 0x60006b01, 0, 0);
- test(6, g, 0x24046800, 0x40040806, 0, 0x1006);
- test(6, g, 0x24046800, 0x40040806, 0x1234, 0x5678);
- test(6, f, 0x240467ff, 0x40040806, 0x1234, 0x5678);
- test(6, f, 0x24046801, 0x40040806, 0x1234, 0x5678);
- test(6, h, 0x24046800, 0x40040800, 0x1234, 0x5678);
- test(6, h, 0x24046800, 0x40040800, 0, 0);
- test(6, h, 0x24046800, 0x40040800, 0x10101010, 0x10101010);
- test(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef);
- test(4, g, 64, 15, 116, 26);
- test(4, g, 64, 15, 127, 3);
- test(4, g, 64, 15, 123, 1);
- test(4, h, 64, 15, 123, 128);
- test(4, h, 64, 15, 123, 129);
- test(4, a, 10, 0, 0, 52);
- test(4, b, 10, 0, 0, 220);
- test(4, a, 10, 1, 0, 2);
- test(4, b, 10, 1, 0, 6);
- test(4, c, 10, 1, 0, 10);
- test(4, d, 10, 1, 0, 20);
-#undef test
-
- /* These will hit the BUG_ON(len >= 128) in free_node if something goes wrong. */
- for (i = 0; i < 128; ++i) {
- part = cpu_to_be64(~(1LLU << (i % 64)));
- memset(&ip, 0xff, 16);
- memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
- routing_table_insert_v6(&t, &ip, 128, a);
- }
-
- if (success)
- pr_info("routing table self-tests: pass\n");
-
-free:
- routing_table_free(&t);
- kfree(a);
- kfree(b);
- kfree(c);
- kfree(d);
- kfree(e);
- kfree(f);
- kfree(g);
- kfree(h);
-
- return success;
-}
-#endif
diff --git a/src/selftest/routingtable.h b/src/selftest/routingtable.h
new file mode 100644
index 0000000..0915e65
--- /dev/null
+++ b/src/selftest/routingtable.h
@@ -0,0 +1,504 @@
+/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
+
+#ifdef DEBUG
+
+#ifdef DEBUG_PRINT_TRIE_GRAPHVIZ
+#include <linux/siphash.h>
+static void print_node(struct routing_table_node *node, u8 bits)
+{
+ u32 color = 0;
+ char *style = "dotted";
+ char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
+ char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
+ if (bits == 32) {
+ fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
+ fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
+ } else if (bits == 128) {
+ fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
+ fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
+ }
+ if (node->peer) {
+ hsiphash_key_t key = { 0 };
+ memcpy(&key, &node->peer, sizeof(node->peer));
+ color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 | hsiphash_1u32(0xbabecafe, &key) % 200 << 8 | hsiphash_1u32(0xabad1dea, &key) % 200;
+ style = "bold";
+ }
+ printk(fmt_declaration, node->bits, node->cidr, style, color);
+ if (node->bit[0]) {
+ printk(fmt_connection, node->bits, node->cidr, node->bit[0]->bits, node->bit[0]->cidr);
+ print_node(node->bit[0], bits);
+ }
+ if (node->bit[1]) {
+ printk(fmt_connection, node->bits, node->cidr, node->bit[1]->bits, node->bit[1]->cidr);
+ print_node(node->bit[1], bits);
+ }
+}
+static void print_tree(struct routing_table_node *top, u8 bits)
+{
+ printk(KERN_DEBUG "digraph trie {\n");
+ print_node(top, bits);
+ printk(KERN_DEBUG "}\n");
+}
+#endif
+
+#ifdef DEBUG_RANDOM_TRIE
+#define NUM_PEERS 2000
+#define NUM_RAND_ROUTES 400
+#define NUM_MUTATED_ROUTES 100
+#define NUM_QUERIES (NUM_RAND_ROUTES * NUM_MUTATED_ROUTES * 30)
+#include <linux/random.h>
+struct horrible_routing_table {
+ struct hlist_head head;
+};
+struct horrible_routing_table_node {
+ struct hlist_node table;
+ union nf_inet_addr ip;
+ union nf_inet_addr mask;
+ uint8_t ip_version;
+ void *value;
+};
+static void horrible_routing_table_init(struct horrible_routing_table *table)
+{
+ INIT_HLIST_HEAD(&table->head);
+}
+static void horrible_routing_table_free(struct horrible_routing_table *table)
+{
+ struct hlist_node *h;
+ struct horrible_routing_table_node *node;
+ hlist_for_each_entry_safe(node, h, &table->head, table) {
+ hlist_del(&node->table);
+ kfree(node);
+ };
+}
+static inline union nf_inet_addr horrible_cidr_to_mask(uint8_t cidr)
+{
+ union nf_inet_addr mask;
+ memset(&mask, 0x00, 128 / 8);
+ memset(&mask, 0xff, cidr / 8);
+ if (cidr % 32)
+ mask.all[cidr / 32] = htonl((0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
+ return mask;
+}
+static inline uint8_t horrible_mask_to_cidr(union nf_inet_addr subnet)
+{
+ return hweight32(subnet.all[0])
+ + hweight32(subnet.all[1])
+ + hweight32(subnet.all[2])
+ + hweight32(subnet.all[3]);
+}
+static inline void horrible_mask_self(struct horrible_routing_table_node *node)
+{
+ if (node->ip_version == 4)
+ node->ip.ip &= node->mask.ip;
+ else if (node->ip_version == 6) {
+ node->ip.ip6[0] &= node->mask.ip6[0];
+ node->ip.ip6[1] &= node->mask.ip6[1];
+ node->ip.ip6[2] &= node->mask.ip6[2];
+ node->ip.ip6[3] &= node->mask.ip6[3];
+ }
+}
+static inline bool horrible_match_v4(const struct horrible_routing_table_node *node, struct in_addr *ip)
+{
+ return (ip->s_addr & node->mask.ip) == node->ip.ip;
+}
+static inline bool horrible_match_v6(const struct horrible_routing_table_node *node, struct in6_addr *ip)
+{
+ return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] &&
+ (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] &&
+ (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] &&
+ (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
+}
+static void horrible_insert_ordered(struct horrible_routing_table *table, struct horrible_routing_table_node *node)
+{
+ struct horrible_routing_table_node *other = NULL, *where = NULL;
+ uint8_t my_cidr = horrible_mask_to_cidr(node->mask);
+ hlist_for_each_entry(other, &table->head, table) {
+ if (!memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) &&
+ !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr)) &&
+ other->ip_version == node->ip_version) {
+ other->value = node->value;
+ kfree(node);
+ return;
+ }
+ where = other;
+ if (horrible_mask_to_cidr(other->mask) <= my_cidr)
+ break;
+ }
+ if (!other && !where)
+ hlist_add_head(&node->table, &table->head);
+ else if (!other)
+ hlist_add_behind(&node->table, &where->table);
+ else
+ hlist_add_before(&node->table, &where->table);
+}
+static int horrible_routing_table_insert_v4(struct horrible_routing_table *table, struct in_addr *ip, uint8_t cidr, void *value)
+{
+ struct horrible_routing_table_node *node = kzalloc(sizeof(struct horrible_routing_table_node), GFP_KERNEL);
+ if (!node)
+ return -ENOMEM;
+ node->ip.in = *ip;
+ node->mask = horrible_cidr_to_mask(cidr);
+ node->ip_version = 4;
+ node->value = value;
+ horrible_mask_self(node);
+ horrible_insert_ordered(table, node);
+ return 0;
+}
+static int horrible_routing_table_insert_v6(struct horrible_routing_table *table, struct in6_addr *ip, uint8_t cidr, void *value)
+{
+ struct horrible_routing_table_node *node = kzalloc(sizeof(struct horrible_routing_table_node), GFP_KERNEL);
+ if (!node)
+ return -ENOMEM;
+ node->ip.in6 = *ip;
+ node->mask = horrible_cidr_to_mask(cidr);
+ node->ip_version = 6;
+ node->value = value;
+ horrible_mask_self(node);
+ horrible_insert_ordered(table, node);
+ return 0;
+}
+static void *horrible_routing_table_lookup_v4(struct horrible_routing_table *table, struct in_addr *ip)
+{
+ struct horrible_routing_table_node *node;
+ void *ret = NULL;
+ hlist_for_each_entry(node, &table->head, table) {
+ if (node->ip_version != 4)
+ continue;
+ if (horrible_match_v4(node, ip)) {
+ ret = node->value;
+ break;
+ }
+ };
+ return ret;
+}
+static void *horrible_routing_table_lookup_v6(struct horrible_routing_table *table, struct in6_addr *ip)
+{
+ struct horrible_routing_table_node *node;
+ void *ret = NULL;
+ hlist_for_each_entry(node, &table->head, table) {
+ if (node->ip_version != 6)
+ continue;
+ if (horrible_match_v6(node, ip)) {
+ ret = node->value;
+ break;
+ }
+ };
+ return ret;
+}
+
+static bool randomized_test(void)
+{
+ bool ret = false;
+ unsigned int i, j, k, mutate_amount, cidr;
+ struct wireguard_peer **peers, *peer;
+ struct routing_table t;
+ struct horrible_routing_table h;
+ u8 ip[16], mutate_mask[16], mutated[16];
+
+ routing_table_init(&t);
+ horrible_routing_table_init(&h);
+
+ peers = kcalloc(NUM_PEERS, sizeof(struct wireguard_peer *), GFP_KERNEL);
+ if (!peers) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ for (i = 0; i < NUM_PEERS; ++i) {
+ peers[i] = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL);
+ if (!peers[i]) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ kref_init(&peers[i]->refcount);
+ }
+
+ for (i = 0; i < NUM_RAND_ROUTES; ++i) {
+ prandom_bytes(ip, 4);
+ cidr = prandom_u32_max(32) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (routing_table_insert_v4(&t, (struct in_addr *)ip, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ if (horrible_routing_table_insert_v4(&h, (struct in_addr *)ip, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
+ memcpy(mutated, ip, 4);
+ prandom_bytes(mutate_mask, 4);
+ mutate_amount = prandom_u32_max(32);
+ for (k = 0; k < mutate_amount / 8; ++k)
+ mutate_mask[k] = 0xff;
+ mutate_mask[k] = 0xff << ((8 - (mutate_amount % 8)) % 8);
+ for (; k < 4; ++k)
+ mutate_mask[k] = 0;
+ for (k = 0; k < 4; ++k)
+ mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & prandom_u32_max(256));
+ cidr = prandom_u32_max(32) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (routing_table_insert_v4(&t, (struct in_addr *)mutated, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ if (horrible_routing_table_insert_v4(&h, (struct in_addr *)mutated, cidr, peer)) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ }
+ }
+
+ for (i = 0; i < NUM_RAND_ROUTES; ++i) {
+ prandom_bytes(ip, 16);
+ cidr = prandom_u32_max(128) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (routing_table_insert_v6(&t, (struct in6_addr *)ip, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ if (horrible_routing_table_insert_v6(&h, (struct in6_addr *)ip, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ for (j = 0; j < NUM_MUTATED_ROUTES; ++j) {
+ memcpy(mutated, ip, 16);
+ prandom_bytes(mutate_mask, 16);
+ mutate_amount = prandom_u32_max(128);
+ for (k = 0; k < mutate_amount / 8; ++k)
+ mutate_mask[k] = 0xff;
+ mutate_mask[k] = 0xff << ((8 - (mutate_amount % 8)) % 8);
+ for (; k < 4; ++k)
+ mutate_mask[k] = 0;
+ for (k = 0; k < 4; ++k)
+ mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & prandom_u32_max(256));
+ cidr = prandom_u32_max(128) + 1;
+ peer = peers[prandom_u32_max(NUM_PEERS)];
+ if (routing_table_insert_v6(&t, (struct in6_addr *)mutated, cidr, peer) < 0) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ if (horrible_routing_table_insert_v6(&h, (struct in6_addr *)mutated, cidr, peer)) {
+ pr_info("routing table random self-test: out of memory\n");
+ goto free;
+ }
+ }
+ }
+
+#ifdef DEBUG_PRINT_TRIE_GRAPHVIZ
+ print_tree(t.root4, 32);
+ print_tree(t.root6, 128);
+#endif
+
+ for (i = 0; i < NUM_QUERIES; ++i) {
+ prandom_bytes(ip, 4);
+ if (lookup(t.root4, 32, ip) != horrible_routing_table_lookup_v4(&h, (struct in_addr *)ip)) {
+ pr_info("routing table random self-test: FAIL\n");
+ goto free;
+ }
+ }
+
+ for (i = 0; i < NUM_QUERIES; ++i) {
+ prandom_bytes(ip, 16);
+ if (lookup(t.root6, 128, ip) != horrible_routing_table_lookup_v6(&h, (struct in6_addr *)ip)) {
+ pr_info("routing table random self-test: FAIL\n");
+ goto free;
+ }
+ }
+ ret = true;
+
+free:
+ routing_table_free(&t);
+ horrible_routing_table_free(&h);
+ if (peers) {
+ for (i = 0; i < NUM_PEERS; ++i)
+ kfree(peers[i]);
+ }
+ kfree(peers);
+ return ret;
+}
+#endif
+
+static inline struct in_addr *ip4(u8 a, u8 b, u8 c, u8 d)
+{
+ static struct in_addr ip;
+ u8 *split = (u8 *)&ip;
+ split[0] = a;
+ split[1] = b;
+ split[2] = c;
+ split[3] = d;
+ return &ip;
+}
+static inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
+{
+ static struct in6_addr ip;
+ __be32 *split = (__be32 *)&ip;
+ split[0] = cpu_to_be32(a);
+ split[1] = cpu_to_be32(b);
+ split[2] = cpu_to_be32(c);
+ split[3] = cpu_to_be32(d);
+ return &ip;
+}
+
+#define init_peer(name) do { \
+ name = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL); \
+ if (!name) { \
+ pr_info("routing table self-test: out of memory\n"); \
+ goto free; \
+ } \
+ kref_init(&name->refcount); \
+} while (0)
+
+#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \
+ routing_table_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem)
+
+#define maybe_fail \
+ ++i; \
+ if (!_s) { \
+ pr_info("routing table self-test %zu: FAIL\n", i); \
+ success = false; \
+ }
+
+#define test(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = lookup(t.root##version, version == 4 ? 32 : 128, ip##version(ipa, ipb, ipc, ipd)) == mem; \
+ maybe_fail \
+} while (0)
+
+#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = lookup(t.root##version, version == 4 ? 32 : 128, ip##version(ipa, ipb, ipc, ipd)) != mem; \
+ maybe_fail \
+} while (0)
+
+bool routing_table_selftest(void)
+{
+ struct routing_table t;
+ struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, *e = NULL, *f = NULL, *g = NULL, *h = NULL;
+ size_t i = 0;
+ bool success = false;
+ struct in6_addr ip;
+ __be64 part;
+
+ routing_table_init(&t);
+ init_peer(a);
+ init_peer(b);
+ init_peer(c);
+ init_peer(d);
+ init_peer(e);
+ init_peer(f);
+ init_peer(g);
+ init_peer(h);
+
+ insert(4, a, 192, 168, 4, 0, 24);
+ insert(4, b, 192, 168, 4, 4, 32);
+ insert(4, c, 192, 168, 0, 0, 16);
+ insert(4, d, 192, 95, 5, 64, 27);
+ insert(4, c, 192, 95, 5, 65, 27); /* replaces previous entry, and maskself is required */
+ insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
+ insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64);
+ insert(4, e, 0, 0, 0, 0, 0);
+ insert(6, e, 0, 0, 0, 0, 0);
+ insert(6, f, 0, 0, 0, 0, 0); /* replaces previous entry */
+ insert(6, g, 0x24046800, 0, 0, 0, 32);
+ insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64); /* maskself is required */
+ insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128);
+ insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
+ insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
+ insert(4, g, 64, 15, 112, 0, 20);
+ insert(4, h, 64, 15, 123, 211, 25); /* maskself is required */
+ insert(4, a, 10, 0, 0, 0, 25);
+ insert(4, b, 10, 0, 0, 128, 25);
+ insert(4, a, 10, 1, 0, 0, 30);
+ insert(4, b, 10, 1, 0, 4, 30);
+ insert(4, c, 10, 1, 0, 8, 29);
+ insert(4, d, 10, 1, 0, 16, 29);
+
+#ifdef DEBUG_PRINT_TRIE_GRAPHVIZ
+ print_tree(t.root4, 32);
+ print_tree(t.root6, 128);
+#endif
+
+ success = true;
+
+ test(4, a, 192, 168, 4, 20);
+ test(4, a, 192, 168, 4, 0);
+ test(4, b, 192, 168, 4, 4);
+ test(4, c, 192, 168, 200, 182);
+ test(4, c, 192, 95, 5, 68);
+ test(4, e, 192, 95, 5, 96);
+ test(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543);
+ test(6, c, 0x26075300, 0x60006b00, 0, 0xc02e01ee);
+ test(6, f, 0x26075300, 0x60006b01, 0, 0);
+ test(6, g, 0x24046800, 0x40040806, 0, 0x1006);
+ test(6, g, 0x24046800, 0x40040806, 0x1234, 0x5678);
+ test(6, f, 0x240467ff, 0x40040806, 0x1234, 0x5678);
+ test(6, f, 0x24046801, 0x40040806, 0x1234, 0x5678);
+ test(6, h, 0x24046800, 0x40040800, 0x1234, 0x5678);
+ test(6, h, 0x24046800, 0x40040800, 0, 0);
+ test(6, h, 0x24046800, 0x40040800, 0x10101010, 0x10101010);
+ test(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef);
+ test(4, g, 64, 15, 116, 26);
+ test(4, g, 64, 15, 127, 3);
+ test(4, g, 64, 15, 123, 1);
+ test(4, h, 64, 15, 123, 128);
+ test(4, h, 64, 15, 123, 129);
+ test(4, a, 10, 0, 0, 52);
+ test(4, b, 10, 0, 0, 220);
+ test(4, a, 10, 1, 0, 2);
+ test(4, b, 10, 1, 0, 6);
+ test(4, c, 10, 1, 0, 10);
+ test(4, d, 10, 1, 0, 20);
+
+ insert(4, a, 1, 0, 0, 0, 32);
+ insert(4, a, 64, 0, 0, 0, 32);
+ insert(4, a, 128, 0, 0, 0, 32);
+ insert(4, a, 192, 0, 0, 0, 32);
+ insert(4, a, 255, 0, 0, 0, 32);
+ routing_table_remove_by_peer(&t, a);
+ test_negative(4, a, 1, 0, 0, 0);
+ test_negative(4, a, 64, 0, 0, 0);
+ test_negative(4, a, 128, 0, 0, 0);
+ test_negative(4, a, 192, 0, 0, 0);
+ test_negative(4, a, 255, 0, 0, 0);
+
+ routing_table_free(&t);
+ routing_table_init(&t);
+ insert(4, a, 192, 168, 0, 0, 16);
+ insert(4, a, 192, 168, 0, 0, 24);
+ routing_table_remove_by_peer(&t, a);
+ test_negative(4, a, 192, 168, 0, 1);
+
+ /* These will hit the BUG_ON(len >= 128) in free_node if something goes wrong. */
+ for (i = 0; i < 128; ++i) {
+ part = cpu_to_be64(~(1LLU << (i % 64)));
+ memset(&ip, 0xff, 16);
+ memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
+ routing_table_insert_v6(&t, &ip, 128, a);
+ }
+
+#ifdef DEBUG_RANDOM_TRIE
+ if (success)
+ success = randomized_test();
+#endif
+
+ if (success)
+ pr_info("routing table self-tests: pass\n");
+
+free:
+ routing_table_free(&t);
+ kfree(a);
+ kfree(b);
+ kfree(c);
+ kfree(d);
+ kfree(e);
+ kfree(f);
+ kfree(g);
+ kfree(h);
+
+ return success;
+}
+#undef test_negative
+#undef test
+#undef remove
+#undef insert
+#undef init_peer
+
+#endif