summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/allowedips.c61
1 files changed, 40 insertions, 21 deletions
diff --git a/src/allowedips.c b/src/allowedips.c
index b9cd0aa..c1d9bb8 100644
--- a/src/allowedips.c
+++ b/src/allowedips.c
@@ -14,23 +14,29 @@ struct allowedips_node {
* we're alloc'ing to the nearest power of 2 anyway, so this
* doesn't actually make a difference.
*/
- union {
- __be64 v6[2];
- __be32 v4;
- u8 bits[16];
- };
+ u8 bits[16] __aligned(__alignof(u64));
u8 cidr, bit_at_a, bit_at_b;
};
-static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8 cidr)
+static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits)
+{
+ if (bits == 32)
+ *(u32 *)dst = be32_to_cpu(*(const __be32 *)src);
+ else if (bits == 128) {
+ ((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]);
+ ((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]);
+ }
+}
+
+static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8 cidr, u8 bits)
{
node->cidr = cidr;
node->bit_at_a = cidr / 8;
+#ifdef __LITTLE_ENDIAN
+ node->bit_at_a ^= (bits / 8 - 1) % 8;
+#endif
node->bit_at_b = 7 - (cidr % 8);
- if (cidr) {
- memcpy(node->bits, src, (cidr + 7) / 8);
- node->bits[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8);
- }
+ memcpy(node->bits, src, bits / 8);
}
#define choose_node(parent, key) parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
@@ -56,10 +62,11 @@ static void free_root_node(struct allowedips_node __rcu *top, struct mutex *lock
call_rcu_bh(&node->rcu, node_free_rcu);
}
-static int walk_by_peer(struct allowedips_node __rcu *top, int family, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock)
+static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock)
{
struct allowedips_node *node;
int ret;
+ u8 ip[16] __aligned(__alignof(u64));
if (!rcu_access_pointer(top))
return 0;
@@ -70,7 +77,13 @@ static int walk_by_peer(struct allowedips_node __rcu *top, int family, struct al
for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); --cursor->len, push(cursor->stack, node->bit[0], cursor->len), push(cursor->stack, node->bit[1], cursor->len)) {
if (node->peer != peer)
continue;
- ret = func(ctx, node->bits, node->cidr, family);
+
+ swap_endian(ip, node->bits, bits);
+ memset(ip + (node->cidr + 7) / 8, 0, bits / 8 - (node->cidr + 7) / 8);
+ if (node->cidr)
+ ip[(node->cidr + 7) / 8 - 1] &= ~0U << ((8 - (node->cidr % 8)) % 8);
+
+ ret = func(ctx, ip, node->cidr, bits == 32 ? AF_INET : AF_INET6);
if (ret)
return ret;
}
@@ -130,9 +143,9 @@ static __always_inline unsigned int fls128(u64 a, u64 b)
static __always_inline u8 common_bits(const struct allowedips_node *node, const u8 *key, u8 bits)
{
if (bits == 32)
- return 32 - fls(be32_to_cpu(*(const __be32 *)node->bits ^ *(const __be32 *)key));
+ return 32 - fls(*(const u32 *)node->bits ^ *(const u32 *)key);
else if (bits == 128)
- return 128 - fls128(be64_to_cpu(*(const __be64 *)&node->bits[0] ^ *(const __be64 *)&key[0]), be64_to_cpu(*(const __be64 *)&node->bits[8] ^ *(const __be64 *)&key[8]));
+ return 128 - fls128(*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
return 0;
}
@@ -158,10 +171,13 @@ static __always_inline struct allowedips_node *find_node(struct allowedips_node
}
/* Returns a strong reference to a peer */
-static __always_inline struct wireguard_peer *lookup(struct allowedips_node __rcu *root, u8 bits, const void *ip)
+static __always_inline struct wireguard_peer *lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip)
{
struct wireguard_peer *peer = NULL;
struct allowedips_node *node;
+ u8 ip[16] __aligned(__alignof(u64));
+
+ swap_endian(ip, be_ip, bits);
rcu_read_lock_bh();
node = find_node(rcu_dereference_bh(root), bits, ip);
@@ -189,19 +205,22 @@ static inline bool node_placement(struct allowedips_node __rcu *trie, const u8 *
return exact;
}
-static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
+static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
{
struct allowedips_node *node, *parent, *down, *newnode;
+ u8 key[16] __aligned(__alignof(u64));
if (unlikely(cidr > bits || !peer))
return -EINVAL;
+ swap_endian(key, be_key, bits);
+
if (!rcu_access_pointer(*trie)) {
node = kzalloc(sizeof(*node), GFP_KERNEL);
if (!node)
return -ENOMEM;
node->peer = peer;
- copy_and_assign_cidr(node, key, cidr);
+ copy_and_assign_cidr(node, key, cidr, bits);
rcu_assign_pointer(*trie, node);
return 0;
}
@@ -214,7 +233,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, u8 c
if (!newnode)
return -ENOMEM;
newnode->peer = peer;
- copy_and_assign_cidr(newnode, key, cidr);
+ copy_and_assign_cidr(newnode, key, cidr, bits);
if (!node)
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
@@ -240,7 +259,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, u8 c
kfree(newnode);
return -ENOMEM;
}
- copy_and_assign_cidr(node, newnode->bits, cidr);
+ copy_and_assign_cidr(node, newnode->bits, cidr, bits);
rcu_assign_pointer(choose_node(node, down->bits), down);
rcu_assign_pointer(choose_node(node, newnode->bits), newnode);
@@ -296,13 +315,13 @@ int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor *
return 0;
if (!cursor->second_half) {
- ret = walk_by_peer(table->root4, AF_INET, cursor, peer, func, ctx, lock);
+ ret = walk_by_peer(table->root4, 32, cursor, peer, func, ctx, lock);
if (ret)
return ret;
cursor->len = 0;
cursor->second_half = true;
}
- return walk_by_peer(table->root6, AF_INET6, cursor, peer, func, ctx, lock);
+ return walk_by_peer(table->root6, 128, cursor, peer, func, ctx, lock);
}
/* Returns a strong reference to a peer */