summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/allowedips.c133
1 files changed, 70 insertions, 63 deletions
diff --git a/src/allowedips.c b/src/allowedips.c
index bc43b71..b99078d 100644
--- a/src/allowedips.c
+++ b/src/allowedips.c
@@ -39,8 +39,7 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
node->bit_at_b = 7U - (cidr % 8U);
memcpy(node->bits, src, bits / 8U);
}
-
-#define choose_node(parent, key) \
+#define CHOOSE_NODE(parent, key) \
parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
static void node_free_rcu(struct rcu_head *rcu)
@@ -48,23 +47,26 @@ static void node_free_rcu(struct rcu_head *rcu)
kfree(container_of(rcu, struct allowedips_node, rcu));
}
-#define push_rcu(stack, p, len) ({ \
- if (rcu_access_pointer(p)) { \
- WARN_ON(IS_ENABLED(DEBUG) && (len) >= 128); \
- stack[(len)++] = rcu_dereference_raw(p); \
- } \
- true; \
- })
+static void push_rcu(struct allowedips_node **stack,
+ struct allowedips_node __rcu *p, unsigned int *len)
+{
+ if (rcu_access_pointer(p)) {
+ WARN_ON(IS_ENABLED(DEBUG) && *len >= 128);
+ stack[(*len)++] = rcu_dereference_raw(p);
+ }
+}
+
static void root_free_rcu(struct rcu_head *rcu)
{
struct allowedips_node *node, *stack[128] = {
container_of(rcu, struct allowedips_node, rcu) };
unsigned int len = 1;
- while (len > 0 && (node = stack[--len]) &&
- push_rcu(stack, node->bit[0], len) &&
- push_rcu(stack, node->bit[1], len))
+ while (len > 0 && (node = stack[--len])) {
+ push_rcu(stack, node->bit[0], &len);
+ push_rcu(stack, node->bit[1], &len);
kfree(node);
+ }
}
static int
@@ -74,6 +76,7 @@ walk_by_peer(struct allowedips_node __rcu *top, u8 bits,
void *ctx, struct mutex *lock)
{
const int address_family = bits == 32 ? AF_INET : AF_INET6;
+ /* Aligned so it can be treated as u64 */
u8 ip[16] __aligned(__alignof(u64));
struct allowedips_node *node;
int ret;
@@ -82,11 +85,11 @@ walk_by_peer(struct allowedips_node __rcu *top, u8 bits,
return 0;
if (!cursor->len)
- push_rcu(cursor->stack, top, cursor->len);
+ push_rcu(cursor->stack, top, &cursor->len);
for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]);
- --cursor->len, push_rcu(cursor->stack, node->bit[0], cursor->len),
- push_rcu(cursor->stack, node->bit[1], cursor->len)) {
+ --cursor->len, push_rcu(cursor->stack, node->bit[0], &cursor->len),
+ push_rcu(cursor->stack, node->bit[1], &cursor->len)) {
const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
if (rcu_dereference_protected(node->peer,
@@ -105,62 +108,60 @@ walk_by_peer(struct allowedips_node __rcu *top, u8 bits,
return 0;
}
-#undef push_rcu
-
-#define ref(p) rcu_access_pointer(p)
-#define deref(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
-#define push(p) ({ \
+static void walk_remove_by_peer(struct allowedips_node __rcu **top,
+ struct wg_peer *peer, struct mutex *lock)
+{
+#define REF(p) rcu_access_pointer(p)
+#define DEREF(p) rcu_dereference_protected(*(p), lockdep_is_held(lock))
+#define PUSH(p) ({ \
WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \
stack[len++] = p; \
})
-static void walk_remove_by_peer(struct allowedips_node __rcu **top,
- struct wg_peer *peer, struct mutex *lock)
-{
struct allowedips_node __rcu **stack[128], **nptr;
struct allowedips_node *node, *prev;
unsigned int len;
- if (unlikely(!peer || !ref(*top)))
+ if (unlikely(!peer || !REF(*top)))
return;
- for (prev = NULL, len = 0, push(top); len > 0; prev = node) {
+ for (prev = NULL, len = 0, PUSH(top); len > 0; prev = node) {
nptr = stack[len - 1];
- node = deref(nptr);
+ node = DEREF(nptr);
if (!node) {
--len;
continue;
}
- if (!prev || ref(prev->bit[0]) == node ||
- ref(prev->bit[1]) == node) {
- if (ref(node->bit[0]))
- push(&node->bit[0]);
- else if (ref(node->bit[1]))
- push(&node->bit[1]);
- } else if (ref(node->bit[0]) == prev) {
- if (ref(node->bit[1]))
- push(&node->bit[1]);
+ if (!prev || REF(prev->bit[0]) == node ||
+ REF(prev->bit[1]) == node) {
+ if (REF(node->bit[0]))
+ PUSH(&node->bit[0]);
+ else if (REF(node->bit[1]))
+ PUSH(&node->bit[1]);
+ } else if (REF(node->bit[0]) == prev) {
+ if (REF(node->bit[1]))
+ PUSH(&node->bit[1]);
} else {
if (rcu_dereference_protected(node->peer,
lockdep_is_held(lock)) == peer) {
RCU_INIT_POINTER(node->peer, NULL);
if (!node->bit[0] || !node->bit[1]) {
- rcu_assign_pointer(*nptr,
- deref(&node->bit[!ref(node->bit[0])]));
+ rcu_assign_pointer(*nptr, DEREF(
+ &node->bit[!REF(node->bit[0])]));
call_rcu_bh(&node->rcu, node_free_rcu);
- node = deref(nptr);
+ node = DEREF(nptr);
}
}
--len;
}
}
-}
-#undef ref
-#undef deref
-#undef push
+#undef REF
+#undef DEREF
+#undef PUSH
+}
-static __always_inline unsigned int fls128(u64 a, u64 b)
+static unsigned int fls128(u64 a, u64 b)
{
return a ? fls64(a) + 64U : fls64(b);
}
@@ -177,14 +178,17 @@ static __always_inline u8 common_bits(const struct allowedips_node *node,
return 0;
}
-/* This could be much faster if it actually just compared the common bits
- * properly, by precomputing a mask bswap(~0 << (32 - cidr)), and the rest, but
- * it turns out that common_bits is already super fast on modern processors,
- * even taking into account the unfortunate bswap. So, we just inline it like
- * this instead.
- */
-#define prefix_matches(node, key, bits) \
- (common_bits(node, key, bits) >= (node)->cidr)
+static __always_inline bool prefix_matches(const struct allowedips_node *node,
+ const u8 *key, u8 bits)
+{
+ /* This could be much faster if it actually just compared the common
+ * bits properly, by precomputing a mask bswap(~0 << (32 - cidr)), and
+ * the rest, but it turns out that common_bits is already super fast on
+ * modern processors, even taking into account the unfortunate bswap.
+ * So, we just inline it like this instead.
+ */
+ return common_bits(node, key, bits) >= node->cidr;
+}
static __always_inline struct allowedips_node *
find_node(struct allowedips_node *trie, u8 bits, const u8 *key)
@@ -196,7 +200,7 @@ find_node(struct allowedips_node *trie, u8 bits, const u8 *key)
found = node;
if (node->cidr == bits)
break;
- node = rcu_dereference_bh(choose_node(node, key));
+ node = rcu_dereference_bh(CHOOSE_NODE(node, key));
}
return found;
}
@@ -205,6 +209,7 @@ find_node(struct allowedips_node *trie, u8 bits, const u8 *key)
static __always_inline struct wg_peer *
lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip)
{
+ /* Aligned so it can be passed to fls/fls64 */
u8 ip[16] __aligned(__alignof(u64));
struct allowedips_node *node;
struct wg_peer *peer = NULL;
@@ -223,9 +228,9 @@ retry:
return peer;
}
-__attribute__((nonnull(1))) static bool
-node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr,
- u8 bits, struct allowedips_node **rnode, struct mutex *lock)
+static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
+ u8 cidr, u8 bits, struct allowedips_node **rnode,
+ struct mutex *lock)
{
struct allowedips_node *node = rcu_dereference_protected(trie,
lockdep_is_held(lock));
@@ -238,7 +243,7 @@ node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr,
exact = true;
break;
}
- node = rcu_dereference_protected(choose_node(parent, key),
+ node = rcu_dereference_protected(CHOOSE_NODE(parent, key),
lockdep_is_held(lock));
}
*rnode = parent;
@@ -276,10 +281,10 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
if (!node) {
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
} else {
- down = rcu_dereference_protected(choose_node(node, key),
+ down = rcu_dereference_protected(CHOOSE_NODE(node, key),
lockdep_is_held(lock));
if (!down) {
- rcu_assign_pointer(choose_node(node, key), newnode);
+ rcu_assign_pointer(CHOOSE_NODE(node, key), newnode);
return 0;
}
}
@@ -287,11 +292,11 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
parent = node;
if (newnode->cidr == cidr) {
- rcu_assign_pointer(choose_node(newnode, down->bits), down);
+ rcu_assign_pointer(CHOOSE_NODE(newnode, down->bits), down);
if (!parent)
rcu_assign_pointer(*trie, newnode);
else
- rcu_assign_pointer(choose_node(parent, newnode->bits),
+ rcu_assign_pointer(CHOOSE_NODE(parent, newnode->bits),
newnode);
} else {
node = kzalloc(sizeof(*node), GFP_KERNEL);
@@ -301,12 +306,12 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
}
copy_and_assign_cidr(node, newnode->bits, cidr, bits);
- rcu_assign_pointer(choose_node(node, down->bits), down);
- rcu_assign_pointer(choose_node(node, newnode->bits), newnode);
+ rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down);
+ rcu_assign_pointer(CHOOSE_NODE(node, newnode->bits), newnode);
if (!parent)
rcu_assign_pointer(*trie, node);
else
- rcu_assign_pointer(choose_node(parent, node->bits),
+ rcu_assign_pointer(CHOOSE_NODE(parent, node->bits),
node);
}
return 0;
@@ -336,6 +341,7 @@ int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
u8 cidr, struct wg_peer *peer,
struct mutex *lock)
{
+ /* Aligned so it can be passed to fls */
u8 key[4] __aligned(__alignof(u32));
++table->seq;
@@ -347,6 +353,7 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
u8 cidr, struct wg_peer *peer,
struct mutex *lock)
{
+ /* Aligned so it can be passed to fls64 */
u8 key[16] __aligned(__alignof(u64));
++table->seq;