diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/allowedips.c | 133 |
1 files changed, 70 insertions, 63 deletions
diff --git a/src/allowedips.c b/src/allowedips.c index bc43b71..b99078d 100644 --- a/src/allowedips.c +++ b/src/allowedips.c @@ -39,8 +39,7 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, node->bit_at_b = 7U - (cidr % 8U); memcpy(node->bits, src, bits / 8U); } - -#define choose_node(parent, key) \ +#define CHOOSE_NODE(parent, key) \ parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] static void node_free_rcu(struct rcu_head *rcu) @@ -48,23 +47,26 @@ static void node_free_rcu(struct rcu_head *rcu) kfree(container_of(rcu, struct allowedips_node, rcu)); } -#define push_rcu(stack, p, len) ({ \ - if (rcu_access_pointer(p)) { \ - WARN_ON(IS_ENABLED(DEBUG) && (len) >= 128); \ - stack[(len)++] = rcu_dereference_raw(p); \ - } \ - true; \ - }) +static void push_rcu(struct allowedips_node **stack, + struct allowedips_node __rcu *p, unsigned int *len) +{ + if (rcu_access_pointer(p)) { + WARN_ON(IS_ENABLED(DEBUG) && *len >= 128); + stack[(*len)++] = rcu_dereference_raw(p); + } +} + static void root_free_rcu(struct rcu_head *rcu) { struct allowedips_node *node, *stack[128] = { container_of(rcu, struct allowedips_node, rcu) }; unsigned int len = 1; - while (len > 0 && (node = stack[--len]) && - push_rcu(stack, node->bit[0], len) && - push_rcu(stack, node->bit[1], len)) + while (len > 0 && (node = stack[--len])) { + push_rcu(stack, node->bit[0], &len); + push_rcu(stack, node->bit[1], &len); kfree(node); + } } static int @@ -74,6 +76,7 @@ walk_by_peer(struct allowedips_node __rcu *top, u8 bits, void *ctx, struct mutex *lock) { const int address_family = bits == 32 ? AF_INET : AF_INET6; + /* Aligned so it can be treated as u64 */ u8 ip[16] __aligned(__alignof(u64)); struct allowedips_node *node; int ret; @@ -82,11 +85,11 @@ walk_by_peer(struct allowedips_node __rcu *top, u8 bits, return 0; if (!cursor->len) - push_rcu(cursor->stack, top, cursor->len); + push_rcu(cursor->stack, top, &cursor->len); for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); - --cursor->len, push_rcu(cursor->stack, node->bit[0], cursor->len), - push_rcu(cursor->stack, node->bit[1], cursor->len)) { + --cursor->len, push_rcu(cursor->stack, node->bit[0], &cursor->len), + push_rcu(cursor->stack, node->bit[1], &cursor->len)) { const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); if (rcu_dereference_protected(node->peer, @@ -105,62 +108,60 @@ walk_by_peer(struct allowedips_node __rcu *top, u8 bits, return 0; } -#undef push_rcu - -#define ref(p) rcu_access_pointer(p) -#define deref(p) rcu_dereference_protected(*(p), lockdep_is_held(lock)) -#define push(p) ({ \ +static void walk_remove_by_peer(struct allowedips_node __rcu **top, + struct wg_peer *peer, struct mutex *lock) +{ +#define REF(p) rcu_access_pointer(p) +#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock)) +#define PUSH(p) ({ \ WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \ stack[len++] = p; \ }) -static void walk_remove_by_peer(struct allowedips_node __rcu **top, - struct wg_peer *peer, struct mutex *lock) -{ struct allowedips_node __rcu **stack[128], **nptr; struct allowedips_node *node, *prev; unsigned int len; - if (unlikely(!peer || !ref(*top))) + if (unlikely(!peer || !REF(*top))) return; - for (prev = NULL, len = 0, push(top); len > 0; prev = node) { + for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) { nptr = stack[len - 1]; - node = deref(nptr); + node = DEREF(nptr); if (!node) { --len; continue; } - if (!prev || ref(prev->bit[0]) == node || - ref(prev->bit[1]) == node) { - if (ref(node->bit[0])) - push(&node->bit[0]); - else if (ref(node->bit[1])) - push(&node->bit[1]); - } else if (ref(node->bit[0]) == prev) { - if (ref(node->bit[1])) - push(&node->bit[1]); + if (!prev || REF(prev->bit[0]) == node || + REF(prev->bit[1]) == node) { + if (REF(node->bit[0])) + PUSH(&node->bit[0]); + else if (REF(node->bit[1])) + PUSH(&node->bit[1]); + } else if (REF(node->bit[0]) == prev) { + if (REF(node->bit[1])) + PUSH(&node->bit[1]); } else { if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) == peer) { RCU_INIT_POINTER(node->peer, NULL); if (!node->bit[0] || !node->bit[1]) { - rcu_assign_pointer(*nptr, - deref(&node->bit[!ref(node->bit[0])])); + rcu_assign_pointer(*nptr, DEREF( + &node->bit[!REF(node->bit[0])])); call_rcu_bh(&node->rcu, node_free_rcu); - node = deref(nptr); + node = DEREF(nptr); } } --len; } } -} -#undef ref -#undef deref -#undef push +#undef REF +#undef DEREF +#undef PUSH +} -static __always_inline unsigned int fls128(u64 a, u64 b) +static unsigned int fls128(u64 a, u64 b) { return a ? fls64(a) + 64U : fls64(b); } @@ -177,14 +178,17 @@ static __always_inline u8 common_bits(const struct allowedips_node *node, return 0; } -/* This could be much faster if it actually just compared the common bits - * properly, by precomputing a mask bswap(~0 << (32 - cidr)), and the rest, but - * it turns out that common_bits is already super fast on modern processors, - * even taking into account the unfortunate bswap. So, we just inline it like - * this instead. - */ -#define prefix_matches(node, key, bits) \ - (common_bits(node, key, bits) >= (node)->cidr) +static __always_inline bool prefix_matches(const struct allowedips_node *node, + const u8 *key, u8 bits) +{ + /* This could be much faster if it actually just compared the common + * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and + * the rest, but it turns out that common_bits is already super fast on + * modern processors, even taking into account the unfortunate bswap. + * So, we just inline it like this instead. + */ + return common_bits(node, key, bits) >= node->cidr; +} static __always_inline struct allowedips_node * find_node(struct allowedips_node *trie, u8 bits, const u8 *key) @@ -196,7 +200,7 @@ find_node(struct allowedips_node *trie, u8 bits, const u8 *key) found = node; if (node->cidr == bits) break; - node = rcu_dereference_bh(choose_node(node, key)); + node = rcu_dereference_bh(CHOOSE_NODE(node, key)); } return found; } @@ -205,6 +209,7 @@ find_node(struct allowedips_node *trie, u8 bits, const u8 *key) static __always_inline struct wg_peer * lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip) { + /* Aligned so it can be passed to fls/fls64 */ u8 ip[16] __aligned(__alignof(u64)); struct allowedips_node *node; struct wg_peer *peer = NULL; @@ -223,9 +228,9 @@ retry: return peer; } -__attribute__((nonnull(1))) static bool -node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr, - u8 bits, struct allowedips_node **rnode, struct mutex *lock) +static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, + u8 cidr, u8 bits, struct allowedips_node **rnode, + struct mutex *lock) { struct allowedips_node *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); @@ -238,7 +243,7 @@ node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr, exact = true; break; } - node = rcu_dereference_protected(choose_node(parent, key), + node = rcu_dereference_protected(CHOOSE_NODE(parent, key), lockdep_is_held(lock)); } *rnode = parent; @@ -276,10 +281,10 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, if (!node) { down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); } else { - down = rcu_dereference_protected(choose_node(node, key), + down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock)); if (!down) { - rcu_assign_pointer(choose_node(node, key), newnode); + rcu_assign_pointer(CHOOSE_NODE(node, key), newnode); return 0; } } @@ -287,11 +292,11 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, parent = node; if (newnode->cidr == cidr) { - rcu_assign_pointer(choose_node(newnode, down->bits), down); + rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down); if (!parent) rcu_assign_pointer(*trie, newnode); else - rcu_assign_pointer(choose_node(parent, newnode->bits), + rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode); } else { node = kzalloc(sizeof(*node), GFP_KERNEL); @@ -301,12 +306,12 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, } copy_and_assign_cidr(node, newnode->bits, cidr, bits); - rcu_assign_pointer(choose_node(node, down->bits), down); - rcu_assign_pointer(choose_node(node, newnode->bits), newnode); + rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); + rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode); if (!parent) rcu_assign_pointer(*trie, node); else - rcu_assign_pointer(choose_node(parent, node->bits), + rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node); } return 0; @@ -336,6 +341,7 @@ int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wg_peer *peer, struct mutex *lock) { + /* Aligned so it can be passed to fls */ u8 key[4] __aligned(__alignof(u32)); ++table->seq; @@ -347,6 +353,7 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wg_peer *peer, struct mutex *lock) { + /* Aligned so it can be passed to fls64 */ u8 key[16] __aligned(__alignof(u64)); ++table->seq; |