diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-02-26 03:38:24 +0100 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-02-26 23:01:12 +0100 |
commit | cc58f5878702332bd5b619f39b2c0ab9dbdef292 (patch) | |
tree | 92945e984ef48d79bf928042540fcd3a2335e075 | |
parent | 4c25956de55aefc96f16ddc5324c8a304b570aff (diff) |
allowedips: maintain per-peer list of allowedips
This makes `wg show`, `wg showconf`, and the like significantly
faster, since we no longer have to iterate through every node of the trie
for every single peer. It also makes netlink cursor resumption much less
problematic, since we're just iterating through a list rather than
having to save a traversal stack.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | src/allowedips.c | 100 | ||||
-rw-r--r-- | src/allowedips.h | 33 | ||||
-rw-r--r-- | src/netlink.c | 61 | ||||
-rw-r--r-- | src/peer.c | 1 | ||||
-rw-r--r-- | src/peer.h | 1 | ||||
-rw-r--r-- | src/selftest/allowedips.c | 93 |
6 files changed, 129 insertions, 160 deletions
diff --git a/src/allowedips.c b/src/allowedips.c index 30b66f4..bfb6020 100644 --- a/src/allowedips.c +++ b/src/allowedips.c @@ -6,18 +6,6 @@ #include "allowedips.h" #include "peer.h" -struct allowedips_node { - struct wg_peer __rcu *peer; - struct rcu_head rcu; - struct allowedips_node __rcu *bit[2]; - /* While it may seem scandalous that we waste space for v4, - * we're alloc'ing to the nearest power of 2 anyway, so this - * doesn't actually make a difference. - */ - u8 bits[16] __aligned(__alignof(u64)); - u8 cidr, bit_at_a, bit_at_b; -}; - static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits) { if (bits == 32) { @@ -37,6 +25,7 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, node->bit_at_a ^= (bits / 8U - 1U) % 8U; #endif node->bit_at_b = 7U - (cidr % 8U); + node->bitlen = bits; memcpy(node->bits, src, bits / 8U); } #define CHOOSE_NODE(parent, key) \ @@ -69,43 +58,17 @@ static void root_free_rcu(struct rcu_head *rcu) } } -static int -walk_by_peer(struct allowedips_node __rcu *top, u8 bits, - struct allowedips_cursor *cursor, struct wg_peer *peer, - int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), - void *ctx, struct mutex *lock) +static void root_remove_peer_lists(struct allowedips_node *root) { - const int address_family = bits == 32 ? 
AF_INET : AF_INET6; - /* Aligned so it can be treated as u64 */ - u8 ip[16] __aligned(__alignof(u64)); - struct allowedips_node *node; - int ret; - - if (!rcu_access_pointer(top)) - return 0; - - if (!cursor->len) - push_rcu(cursor->stack, top, &cursor->len); - - for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); - --cursor->len, push_rcu(cursor->stack, node->bit[0], &cursor->len), - push_rcu(cursor->stack, node->bit[1], &cursor->len)) { - const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); - - if (rcu_dereference_protected(node->peer, - lockdep_is_held(lock)) != peer) - continue; - - swap_endian(ip, node->bits, bits); - memset(ip + cidr_bytes, 0, bits / 8U - cidr_bytes); - if (node->cidr) - ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); + struct allowedips_node *node, *stack[128] = { root }; + unsigned int len = 1; - ret = func(ctx, ip, node->cidr, address_family); - if (ret) - return ret; + while (len > 0 && (node = stack[--len])) { + push_rcu(stack, node->bit[0], &len); + push_rcu(stack, node->bit[1], &len); + if (rcu_access_pointer(node->peer)) + list_del(&node->peer_list); } - return 0; } static void walk_remove_by_peer(struct allowedips_node __rcu **top, @@ -145,6 +108,7 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) == peer) { RCU_INIT_POINTER(node->peer, NULL); + list_del(&node->peer_list); if (!node->bit[0] || !node->bit[1]) { rcu_assign_pointer(*nptr, DEREF( &node->bit[!REF(node->bit[0])])); @@ -263,12 +227,14 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, if (unlikely(!node)) return -ENOMEM; RCU_INIT_POINTER(node->peer, peer); + list_add_tail(&node->peer_list, &peer->allowedips_list); copy_and_assign_cidr(node, key, cidr, bits); rcu_assign_pointer(*trie, node); return 0; } if (node_placement(*trie, key, cidr, bits, &node, lock)) { rcu_assign_pointer(node->peer, peer); + list_move_tail(&node->peer_list, 
&peer->allowedips_list); return 0; } @@ -276,6 +242,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, if (unlikely(!newnode)) return -ENOMEM; RCU_INIT_POINTER(newnode->peer, peer); + list_add_tail(&newnode->peer_list, &peer->allowedips_list); copy_and_assign_cidr(newnode, key, cidr, bits); if (!node) { @@ -304,6 +271,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, kfree(newnode); return -ENOMEM; } + INIT_LIST_HEAD(&node->peer_list); copy_and_assign_cidr(node, newnode->bits, cidr, bits); rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); @@ -326,15 +294,20 @@ void wg_allowedips_init(struct allowedips *table) void wg_allowedips_free(struct allowedips *table, struct mutex *lock) { struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6; + ++table->seq; RCU_INIT_POINTER(table->root4, NULL); RCU_INIT_POINTER(table->root6, NULL); - if (rcu_access_pointer(old4)) + if (rcu_access_pointer(old4)) { + root_remove_peer_lists(old4); call_rcu_bh(&rcu_dereference_protected(old4, lockdep_is_held(lock))->rcu, root_free_rcu); - if (rcu_access_pointer(old6)) + } + if (rcu_access_pointer(old6)) { + root_remove_peer_lists(old6); call_rcu_bh(&rcu_dereference_protected(old6, lockdep_is_held(lock))->rcu, root_free_rcu); + } } int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, @@ -367,29 +340,16 @@ void wg_allowedips_remove_by_peer(struct allowedips *table, walk_remove_by_peer(&table->root6, peer, lock); } -int wg_allowedips_walk_by_peer(struct allowedips *table, - struct allowedips_cursor *cursor, - struct wg_peer *peer, - int (*func)(void *ctx, const u8 *ip, u8 cidr, - int family), - void *ctx, struct mutex *lock) +int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) { - int ret; - - if (!cursor->seq) - cursor->seq = table->seq; - else if (cursor->seq != table->seq) - return 0; - - if (!cursor->second_half) { - ret = walk_by_peer(table->root4, 32, 
cursor, peer, func, ctx, - lock); - if (ret) - return ret; - cursor->len = 0; - cursor->second_half = true; - } - return walk_by_peer(table->root6, 128, cursor, peer, func, ctx, lock); + const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); + swap_endian(ip, node->bits, node->bitlen); + memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes); + if (node->cidr) + ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); + + *cidr = node->cidr; + return node->bitlen == 32 ? AF_INET : AF_INET6; } /* Returns a strong reference to a peer */ diff --git a/src/allowedips.h b/src/allowedips.h index 29e15a2..e5c83ca 100644 --- a/src/allowedips.h +++ b/src/allowedips.h @@ -11,7 +11,23 @@ #include <linux/ipv6.h> struct wg_peer; -struct allowedips_node; + +struct allowedips_node { + struct wg_peer __rcu *peer; + struct allowedips_node __rcu *bit[2]; + /* While it may seem scandalous that we waste space for v4, + * we're alloc'ing to the nearest power of 2 anyway, so this + * doesn't actually make a difference. + */ + u8 bits[16] __aligned(__alignof(u64)); + u8 cidr, bit_at_a, bit_at_b, bitlen; + + /* Keep rarely used list at bottom to be beyond cache line. 
*/ + union { + struct list_head peer_list; + struct rcu_head rcu; + }; +}; struct allowedips { struct allowedips_node __rcu *root4; @@ -19,13 +35,6 @@ struct allowedips { u64 seq; }; -struct allowedips_cursor { - u64 seq; - struct allowedips_node *stack[128]; - unsigned int len; - bool second_half; -}; - void wg_allowedips_init(struct allowedips *table); void wg_allowedips_free(struct allowedips *table, struct mutex *mutex); int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, @@ -34,12 +43,8 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wg_peer *peer, struct mutex *lock); void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer, struct mutex *lock); -int wg_allowedips_walk_by_peer(struct allowedips *table, - struct allowedips_cursor *cursor, - struct wg_peer *peer, - int (*func)(void *ctx, const u8 *ip, u8 cidr, - int family), - void *ctx, struct mutex *lock); +/* The ip input pointer should be __aligned(__alignof(u64))) */ +int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr); /* These return a strong reference to a peer: */ struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table, diff --git a/src/netlink.c b/src/netlink.c index f44f211..b179b31 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -69,9 +69,9 @@ static struct wg_device *lookup_interface(struct nlattr **attrs, return netdev_priv(dev); } -static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) +static int get_allowedips(struct sk_buff *skb, const u8 *ip, u8 cidr, + int family) { - struct sk_buff *skb = ctx; struct nlattr *allowedip_nest; allowedip_nest = nla_nest_start(skb, 0); @@ -90,10 +90,12 @@ static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) return 0; } -static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor, - struct sk_buff *skb) +static int +get_peer(struct wg_peer *peer, struct allowedips_node 
**next_allowedips_node, + u64 *allowedips_seq, struct sk_buff *skb) { struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, 0); + struct allowedips_node *allowedips_node = *next_allowedips_node; bool fail; if (!peer_nest) @@ -106,7 +108,7 @@ static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor, if (fail) goto err; - if (!rt_cursor->seq) { + if (!allowedips_node) { const struct __kernel_timespec last_handshake = { .tv_sec = peer->walltime_last_handshake.tv_sec, .tv_nsec = peer->walltime_last_handshake.tv_nsec @@ -143,21 +145,39 @@ static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor, read_unlock_bh(&peer->endpoint_lock); if (fail) goto err; + allowedips_node = + list_first_entry_or_null(&peer->allowedips_list, + struct allowedips_node, peer_list); } + if (!allowedips_node) + goto no_allowedips; + if (!*allowedips_seq) + *allowedips_seq = peer->device->peer_allowedips.seq; + else if (*allowedips_seq != peer->device->peer_allowedips.seq) + goto no_allowedips; allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); if (!allowedips_nest) goto err; - if (wg_allowedips_walk_by_peer(&peer->device->peer_allowedips, - rt_cursor, peer, get_allowedips, skb, - &peer->device->device_update_lock)) { - nla_nest_end(skb, allowedips_nest); - nla_nest_end(skb, peer_nest); - return -EMSGSIZE; + + list_for_each_entry_from(allowedips_node, &peer->allowedips_list, + peer_list) { + u8 cidr, ip[16] __aligned(__alignof(u64)); + int family; + + family = wg_allowedips_read_node(allowedips_node, ip, &cidr); + if (get_allowedips(skb, ip, cidr, family)) { + nla_nest_end(skb, allowedips_nest); + nla_nest_end(skb, peer_nest); + *next_allowedips_node = allowedips_node; + return -EMSGSIZE; + } } - memset(rt_cursor, 0, sizeof(*rt_cursor)); nla_nest_end(skb, allowedips_nest); +no_allowedips: nla_nest_end(skb, peer_nest); + *next_allowedips_node = NULL; + *allowedips_seq = 0; return 0; err: nla_nest_cancel(skb, peer_nest); @@ -174,16 
+194,9 @@ static int wg_get_device_start(struct netlink_callback *cb) genl_family.maxattr, device_policy, NULL); if (ret < 0) return ret; - cb->args[2] = (long)kzalloc(sizeof(struct allowedips_cursor), - GFP_KERNEL); - if (unlikely(!cb->args[2])) - return -ENOMEM; wg = lookup_interface(attrs, cb->skb); - if (IS_ERR(wg)) { - kfree((void *)cb->args[2]); - cb->args[2] = 0; + if (IS_ERR(wg)) return PTR_ERR(wg); - } cb->args[0] = (long)wg; return 0; } @@ -191,7 +204,6 @@ static int wg_get_device_start(struct netlink_callback *cb) static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) { struct wg_peer *peer, *next_peer_cursor, *last_peer_cursor; - struct allowedips_cursor *rt_cursor; struct nlattr *peers_nest; struct wg_device *wg; int ret = -EMSGSIZE; @@ -201,7 +213,6 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) wg = (struct wg_device *)cb->args[0]; next_peer_cursor = (struct wg_peer *)cb->args[1]; last_peer_cursor = (struct wg_peer *)cb->args[1]; - rt_cursor = (struct allowedips_cursor *)cb->args[2]; rtnl_lock(); mutex_lock(&wg->device_update_lock); @@ -253,7 +264,8 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) lockdep_assert_held(&wg->device_update_lock); peer = list_prepare_entry(last_peer_cursor, &wg->peer_list, peer_list); list_for_each_entry_continue(peer, &wg->peer_list, peer_list) { - if (get_peer(peer, rt_cursor, skb)) { + if (get_peer(peer, (struct allowedips_node **)&cb->args[2], + (u64 *)&cb->args[4] /* and args[5] */, skb)) { done = false; break; } @@ -290,12 +302,9 @@ static int wg_get_device_done(struct netlink_callback *cb) { struct wg_device *wg = (struct wg_device *)cb->args[0]; struct wg_peer *peer = (struct wg_peer *)cb->args[1]; - struct allowedips_cursor *rt_cursor = - (struct allowedips_cursor *)cb->args[2]; if (wg) dev_put(wg->dev); - kfree(rt_cursor); wg_peer_put(peer); return 0; } @@ -64,6 +64,7 @@ struct wg_peer *wg_peer_create(struct wg_device 
*wg, NAPI_POLL_WEIGHT); napi_enable(&peer->napi); list_add_tail(&peer->peer_list, &wg->peer_list); + INIT_LIST_HEAD(&peer->allowedips_list); wg_pubkey_hashtable_add(wg->peer_hashtable, peer); ++wg->num_peers; pr_debug("%s: Peer %llu created\n", wg->dev->name, peer->internal_id); @@ -60,6 +60,7 @@ struct wg_peer { struct kref refcount; struct rcu_head rcu; struct list_head peer_list; + struct list_head allowedips_list; u64 internal_id; struct napi_struct napi; bool is_dead; diff --git a/src/selftest/allowedips.c b/src/selftest/allowedips.c index 379ac31..6e244a9 100644 --- a/src/selftest/allowedips.c +++ b/src/selftest/allowedips.c @@ -452,47 +452,14 @@ static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d) return &ip; } -struct walk_ctx { - int count; - bool found_a, found_b, found_c, found_d, found_e; - bool found_other; -}; - -static __init int walk_callback(void *ctx, const u8 *ip, u8 cidr, int family) -{ - struct walk_ctx *wctx = ctx; - - wctx->count++; - - if (cidr == 27 && - !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr))) - wctx->found_a = true; - else if (cidr == 128 && - !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543), - sizeof(struct in6_addr))) - wctx->found_b = true; - else if (cidr == 29 && - !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr))) - wctx->found_c = true; - else if (cidr == 83 && - !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0), - sizeof(struct in6_addr))) - wctx->found_d = true; - else if (cidr == 21 && - !memcmp(ip, ip6(0x26075000, 0, 0, 0), sizeof(struct in6_addr))) - wctx->found_e = true; - else - wctx->found_other = true; - - return 0; -} - static __init struct wg_peer *init_peer(void) { struct wg_peer *peer = kzalloc(sizeof(*peer), GFP_KERNEL); - if (peer) - kref_init(&peer->refcount); + if (!peer) + return NULL; + kref_init(&peer->refcount); + INIT_LIST_HEAD(&peer->allowedips_list); return peer; } @@ -527,23 +494,24 @@ static __init struct wg_peer *init_peer(void) bool __init 
wg_allowedips_selftest(void) { - struct allowedips_cursor *cursor = kzalloc(sizeof(*cursor), GFP_KERNEL); + bool found_a = false, found_b = false, found_c = false, found_d = false, + found_e = false, found_other = false; struct wg_peer *a = init_peer(), *b = init_peer(), *c = init_peer(), *d = init_peer(), *e = init_peer(), *f = init_peer(), *g = init_peer(), *h = init_peer(); - struct walk_ctx wctx = { 0 }; + struct allowedips_node *iter_node; bool success = false; struct allowedips t; DEFINE_MUTEX(mutex); struct in6_addr ip; - size_t i = 0; + size_t i = 0, count = 0; __be64 part; mutex_init(&mutex); mutex_lock(&mutex); wg_allowedips_init(&t); - if (!cursor || !a || !b || !c || !d || !e || !f || !g || !h) { + if (!a || !b || !c || !d || !e || !f || !g || !h) { pr_err("allowedips self-test malloc: FAIL\n"); goto free; } @@ -649,14 +617,40 @@ bool __init wg_allowedips_selftest(void) insert(4, a, 10, 1, 0, 20, 29); insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 83); insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 21); - wg_allowedips_walk_by_peer(&t, cursor, a, walk_callback, &wctx, &mutex); - test_boolean(wctx.count == 5); - test_boolean(wctx.found_a); - test_boolean(wctx.found_b); - test_boolean(wctx.found_c); - test_boolean(wctx.found_d); - test_boolean(wctx.found_e); - test_boolean(!wctx.found_other); + list_for_each_entry(iter_node, &a->allowedips_list, peer_list) { + u8 cidr, ip[16] __aligned(__alignof(u64)); + int family = wg_allowedips_read_node(iter_node, ip, &cidr); + + count++; + + if (cidr == 27 && family == AF_INET && + !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr))) + found_a = true; + else if (cidr == 128 && family == AF_INET6 && + !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543), + sizeof(struct in6_addr))) + found_b = true; + else if (cidr == 29 && family == AF_INET && + !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr))) + found_c = true; + else if (cidr == 83 && family == AF_INET6 && + !memcmp(ip, 
ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0), + sizeof(struct in6_addr))) + found_d = true; + else if (cidr == 21 && family == AF_INET6 && + !memcmp(ip, ip6(0x26075000, 0, 0, 0), + sizeof(struct in6_addr))) + found_e = true; + else + found_other = true; + } + test_boolean(count == 5); + test_boolean(found_a); + test_boolean(found_b); + test_boolean(found_c); + test_boolean(found_d); + test_boolean(found_e); + test_boolean(!found_other); if (IS_ENABLED(DEBUG_RANDOM_TRIE) && success) success = randomized_test(); @@ -675,7 +669,6 @@ free: kfree(g); kfree(h); mutex_unlock(&mutex); - kfree(cursor); return success; } |