diff options
Diffstat (limited to 'src/allowedips.c')
-rw-r--r-- | src/allowedips.c | 100 |
1 files changed, 30 insertions, 70 deletions
diff --git a/src/allowedips.c b/src/allowedips.c index 30b66f4..bfb6020 100644 --- a/src/allowedips.c +++ b/src/allowedips.c @@ -6,18 +6,6 @@ #include "allowedips.h" #include "peer.h" -struct allowedips_node { - struct wg_peer __rcu *peer; - struct rcu_head rcu; - struct allowedips_node __rcu *bit[2]; - /* While it may seem scandalous that we waste space for v4, - * we're alloc'ing to the nearest power of 2 anyway, so this - * doesn't actually make a difference. - */ - u8 bits[16] __aligned(__alignof(u64)); - u8 cidr, bit_at_a, bit_at_b; -}; - static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits) { if (bits == 32) { @@ -37,6 +25,7 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, node->bit_at_a ^= (bits / 8U - 1U) % 8U; #endif node->bit_at_b = 7U - (cidr % 8U); + node->bitlen = bits; memcpy(node->bits, src, bits / 8U); } #define CHOOSE_NODE(parent, key) \ @@ -69,43 +58,17 @@ static void root_free_rcu(struct rcu_head *rcu) } } -static int -walk_by_peer(struct allowedips_node __rcu *top, u8 bits, - struct allowedips_cursor *cursor, struct wg_peer *peer, - int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), - void *ctx, struct mutex *lock) +static void root_remove_peer_lists(struct allowedips_node *root) { - const int address_family = bits == 32 ? AF_INET : AF_INET6; - /* Aligned so it can be treated as u64 */ - u8 ip[16] __aligned(__alignof(u64)); - struct allowedips_node *node; - int ret; - - if (!rcu_access_pointer(top)) - return 0; - - if (!cursor->len) - push_rcu(cursor->stack, top, &cursor->len); - - for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); - --cursor->len, push_rcu(cursor->stack, node->bit[0], &cursor->len), - push_rcu(cursor->stack, node->bit[1], &cursor->len)) { - const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); - - if (rcu_dereference_protected(node->peer, - lockdep_is_held(lock)) != peer) - continue; - - swap_endian(ip, node->bits, bits); - memset(ip + cidr_bytes, 0, bits / 8U - cidr_bytes); - if (node->cidr) - ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); + struct allowedips_node *node, *stack[128] = { root }; + unsigned int len = 1; - ret = func(ctx, ip, node->cidr, address_family); - if (ret) - return ret; + while (len > 0 && (node = stack[--len])) { + push_rcu(stack, node->bit[0], &len); + push_rcu(stack, node->bit[1], &len); + if (rcu_access_pointer(node->peer)) + list_del(&node->peer_list); } - return 0; } static void walk_remove_by_peer(struct allowedips_node __rcu **top, @@ -145,6 +108,7 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) == peer) { RCU_INIT_POINTER(node->peer, NULL); + list_del(&node->peer_list); if (!node->bit[0] || !node->bit[1]) { rcu_assign_pointer(*nptr, DEREF( &node->bit[!REF(node->bit[0])])); @@ -263,12 +227,14 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, if (unlikely(!node)) return -ENOMEM; RCU_INIT_POINTER(node->peer, peer); + list_add_tail(&node->peer_list, &peer->allowedips_list); copy_and_assign_cidr(node, key, cidr, bits); rcu_assign_pointer(*trie, node); return 0; } if (node_placement(*trie, key, cidr, bits, &node, lock)) { rcu_assign_pointer(node->peer, peer); + list_move_tail(&node->peer_list, &peer->allowedips_list); return 0; } @@ -276,6 +242,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, if (unlikely(!newnode)) return -ENOMEM; RCU_INIT_POINTER(newnode->peer, peer); + list_add_tail(&newnode->peer_list, &peer->allowedips_list); copy_and_assign_cidr(newnode, key, cidr, bits); if (!node) { @@ -304,6 +271,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, kfree(newnode); return -ENOMEM; } + INIT_LIST_HEAD(&node->peer_list); copy_and_assign_cidr(node, newnode->bits, cidr, bits); rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); @@ -326,15 +294,20 @@ void wg_allowedips_init(struct allowedips *table) void wg_allowedips_free(struct allowedips *table, struct mutex *lock) { struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6; + ++table->seq; RCU_INIT_POINTER(table->root4, NULL); RCU_INIT_POINTER(table->root6, NULL); - if (rcu_access_pointer(old4)) + if (rcu_access_pointer(old4)) { + root_remove_peer_lists(old4); call_rcu_bh(&rcu_dereference_protected(old4, lockdep_is_held(lock))->rcu, root_free_rcu); - if (rcu_access_pointer(old6)) + } + if (rcu_access_pointer(old6)) { + root_remove_peer_lists(old6); call_rcu_bh(&rcu_dereference_protected(old6, lockdep_is_held(lock))->rcu, root_free_rcu); + } } int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, @@ -367,29 +340,16 @@ void wg_allowedips_remove_by_peer(struct allowedips *table, walk_remove_by_peer(&table->root6, peer, lock); } -int wg_allowedips_walk_by_peer(struct allowedips *table, - struct allowedips_cursor *cursor, - struct wg_peer *peer, - int (*func)(void *ctx, const u8 *ip, u8 cidr, - int family), - void *ctx, struct mutex *lock) +int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) { - int ret; - - if (!cursor->seq) - cursor->seq = table->seq; - else if (cursor->seq != table->seq) - return 0; - - if (!cursor->second_half) { - ret = walk_by_peer(table->root4, 32, cursor, peer, func, ctx, - lock); - if (ret) - return ret; - cursor->len = 0; - cursor->second_half = true; - } - return walk_by_peer(table->root6, 128, cursor, peer, func, ctx, lock); + const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); + swap_endian(ip, node->bits, node->bitlen); + memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes); + if (node->cidr) + ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); + + *cidr = node->cidr; + return node->bitlen == 32 ? AF_INET : AF_INET6; } /* Returns a strong reference to a peer */ |