Diffstat (limited to 'src/allowedips.c')
-rw-r--r--  src/allowedips.c  132
1 file changed, 54 insertions(+), 78 deletions(-)
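
This change drops the stack-based walk_remove_by_peer() helper. Instead, every trie node records the address of the __rcu pointer slot, in its parent node or in the trie root, that references it (the new parent_bit member assigned throughout add() below). wg_allowedips_remove_by_peer() can then iterate peer->allowedips_list directly and unlink each node in place, rather than rediscovering the nodes by walking root4 and root6. A simplified model of that unlink step is sketched after the diff.
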
diff --git a/src/allowedips.c b/src/allowedips.c
index 23c2285..e407c69 100644
--- a/src/allowedips.c
+++ b/src/allowedips.c
@@ -71,60 +71,6 @@ static void root_remove_peer_lists(struct allowedips_node *root)
}
}
-static void walk_remove_by_peer(struct allowedips_node __rcu **top,
- struct wg_peer *peer, struct mutex *lock)
-{
-#define REF(p) rcu_access_pointer(p)
-#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
-#define PUSH(p) ({ \
- WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \
- stack[len++] = p; \
- })
-
- struct allowedips_node __rcu **stack[128], **nptr;
- struct allowedips_node *node, *prev;
- unsigned int len;
-
- if (unlikely(!peer || !REF(*top)))
- return;
-
- for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
- nptr = stack[len - 1];
- node = DEREF(nptr);
- if (!node) {
- --len;
- continue;
- }
- if (!prev || REF(prev->bit[0]) == node ||
- REF(prev->bit[1]) == node) {
- if (REF(node->bit[0]))
- PUSH(&node->bit[0]);
- else if (REF(node->bit[1]))
- PUSH(&node->bit[1]);
- } else if (REF(node->bit[0]) == prev) {
- if (REF(node->bit[1]))
- PUSH(&node->bit[1]);
- } else {
- if (rcu_dereference_protected(node->peer,
- lockdep_is_held(lock)) == peer) {
- RCU_INIT_POINTER(node->peer, NULL);
- list_del_init(&node->peer_list);
- if (!node->bit[0] || !node->bit[1]) {
- rcu_assign_pointer(*nptr, DEREF(
- &node->bit[!REF(node->bit[0])]));
- call_rcu(&node->rcu, node_free_rcu);
- node = DEREF(nptr);
- }
- }
- --len;
- }
- }
-
-#undef REF
-#undef DEREF
-#undef PUSH
-}
-
static unsigned int fls128(u64 a, u64 b)
{
return a ? fls64(a) + 64U : fls64(b);
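
As context for the hunks below, the fls128() helper shown above treats a 128-bit quantity as a high and a low u64 half and, following the kernel convention that fls64() returns a 1-based bit index (0 when no bit is set), yields the position of the most significant set bit. A rough user-space equivalent, purely for illustration:

#include <stdint.h>

/* fls128_approx(0, 1) == 1, fls128_approx(1, 0) == 65, fls128_approx(0, 0) == 0 */
static unsigned int fls64_approx(uint64_t x)
{
        return x ? 64U - (unsigned int)__builtin_clzll(x) : 0U;
}

static unsigned int fls128_approx(uint64_t a, uint64_t b)
{
        return a ? fls64_approx(a) + 64U : fls64_approx(b);
}
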
@@ -229,6 +175,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
RCU_INIT_POINTER(node->peer, peer);
list_add_tail(&node->peer_list, &peer->allowedips_list);
copy_and_assign_cidr(node, key, cidr, bits);
+ rcu_assign_pointer(node->parent_bit, trie);
rcu_assign_pointer(*trie, node);
return 0;
}
@@ -248,9 +195,9 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
if (!node) {
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
} else {
- down = rcu_dereference_protected(CHOOSE_NODE(node, key),
- lockdep_is_held(lock));
+ down = rcu_dereference_protected(CHOOSE_NODE(node, key), lockdep_is_held(lock));
if (!down) {
+ rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, key));
rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
return 0;
}
@@ -259,29 +206,37 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
parent = node;
if (newnode->cidr == cidr) {
+ rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(newnode, down->bits));
rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
- if (!parent)
+ if (!parent) {
+ rcu_assign_pointer(newnode->parent_bit, trie);
rcu_assign_pointer(*trie, newnode);
- else
- rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
- newnode);
- } else {
- node = kzalloc(sizeof(*node), GFP_KERNEL);
- if (unlikely(!node)) {
- list_del(&newnode->peer_list);
- kfree(newnode);
- return -ENOMEM;
+ } else {
+ rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(parent, newnode->bits));
+ rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits), newnode);
}
- INIT_LIST_HEAD(&node->peer_list);
- copy_and_assign_cidr(node, newnode->bits, cidr, bits);
-
- rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
- rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
- if (!parent)
- rcu_assign_pointer(*trie, node);
- else
- rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
- node);
+ return 0;
+ }
+
+ node = kzalloc(sizeof(*node), GFP_KERNEL);
+ if (unlikely(!node)) {
+ list_del(&newnode->peer_list);
+ kfree(newnode);
+ return -ENOMEM;
+ }
+ INIT_LIST_HEAD(&node->peer_list);
+ copy_and_assign_cidr(node, newnode->bits, cidr, bits);
+
+ rcu_assign_pointer(down->parent_bit, &CHOOSE_NODE(node, down->bits));
+ rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
+ rcu_assign_pointer(newnode->parent_bit, &CHOOSE_NODE(node, newnode->bits));
+ rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
+ if (!parent) {
+ rcu_assign_pointer(node->parent_bit, trie);
+ rcu_assign_pointer(*trie, node);
+ } else {
+ rcu_assign_pointer(node->parent_bit, &CHOOSE_NODE(parent, node->bits));
+ rcu_assign_pointer(CHOOSE_NODE(parent, node->bits), node);
}
return 0;
}
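
A pattern worth noting in the rewritten add(): every rcu_assign_pointer() that links a node into a parent slot is now preceded by an assignment recording that slot's address in the child's parent_bit, and the back-pointer is always written before the link is published. A hypothetical helper, not part of this patch and with __rcu/sparse annotations elided, that captures the repeated pair:

static void connect_node(struct allowedips_node **parent_slot,
                         struct allowedips_node *node)
{
        /* Remember which slot will point at us... */
        rcu_assign_pointer(node->parent_bit, parent_slot);
        /* ...then publish the node into that slot. */
        rcu_assign_pointer(*parent_slot, node);
}
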
@@ -340,9 +295,30 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
void wg_allowedips_remove_by_peer(struct allowedips *table,
struct wg_peer *peer, struct mutex *lock)
{
+ struct allowedips_node *node, *child, *tmp;
+
+ if (list_empty(&peer->allowedips_list))
+ return;
++table->seq;
- walk_remove_by_peer(&table->root4, peer, lock);
- walk_remove_by_peer(&table->root6, peer, lock);
+ list_for_each_entry_safe(node, tmp, &peer->allowedips_list, peer_list) {
+ list_del_init(&node->peer_list);
+ RCU_INIT_POINTER(node->peer, NULL);
+ if (node->bit[0] && node->bit[1])
+ continue;
+ child = rcu_dereference_protected(
+ node->bit[!rcu_access_pointer(node->bit[0])],
+ lockdep_is_held(lock));
+ if (child)
+ child->parent_bit = node->parent_bit;
+ *rcu_dereference_protected(node->parent_bit, lockdep_is_held(lock)) = child;
+ kfree_rcu(node, rcu);
+
+ /* TODO: Note that we currently don't walk up and down in order to
+ * free any potential filler nodes. This means that this function
+ * doesn't free up as much as it could, which could be revisited
+ * at some point.
+ */
+ }
}
int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr)
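
Taken together, the new wg_allowedips_remove_by_peer() relies on one property: a node whose parent slot address is known, and which has at most one child, can be spliced out of the trie in O(1). Below is a minimal, self-contained user-space sketch of that splice, with RCU, locking, and the real struct layout deliberately left out; names such as tnode, link_node() and unlink_node() are illustrative only.

#include <assert.h>
#include <stdlib.h>

struct tnode {
        struct tnode *bit[2];
        struct tnode **parent_bit;      /* slot in the parent (or root) that points at us */
};

static void link_node(struct tnode **slot, struct tnode *n)
{
        n->parent_bit = slot;
        *slot = n;
}

static void unlink_node(struct tnode *n)
{
        /* Promote whichever child exists (NULL if n is a leaf)... */
        struct tnode *child = n->bit[0] ? n->bit[0] : n->bit[1];

        if (child)
                child->parent_bit = n->parent_bit;
        /* ...and splice it into the slot that used to reference n. */
        *n->parent_bit = child;
        free(n);
}

int main(void)
{
        struct tnode *root = NULL;
        struct tnode *a = calloc(1, sizeof(*a)), *b = calloc(1, sizeof(*b));

        assert(a && b);
        link_node(&root, a);            /* root -> a */
        link_node(&a->bit[1], b);       /* a -> b */
        unlink_node(a);                 /* a has a single child, so b takes its slot */
        assert(root == b && b->parent_bit == &root);
        free(b);
        return 0;
}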