diff options
-rw-r--r-- | src/allowedips.c | 61 |
1 files changed, 40 insertions, 21 deletions
diff --git a/src/allowedips.c b/src/allowedips.c index b9cd0aa..c1d9bb8 100644 --- a/src/allowedips.c +++ b/src/allowedips.c @@ -14,23 +14,29 @@ struct allowedips_node { * we're alloc'ing to the nearest power of 2 anyway, so this * doesn't actually make a difference. */ - union { - __be64 v6[2]; - __be32 v4; - u8 bits[16]; - }; + u8 bits[16] __aligned(__alignof(u64)); u8 cidr, bit_at_a, bit_at_b; }; -static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8 cidr) +static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits) +{ + if (bits == 32) + *(u32 *)dst = be32_to_cpu(*(const __be32 *)src); + else if (bits == 128) { + ((u64 *)dst)[0] = be64_to_cpu(((const __be64 *)src)[0]); + ((u64 *)dst)[1] = be64_to_cpu(((const __be64 *)src)[1]); + } +} + +static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8 cidr, u8 bits) { node->cidr = cidr; node->bit_at_a = cidr / 8; +#ifdef __LITTLE_ENDIAN + node->bit_at_a ^= (bits / 8 - 1) % 8; +#endif node->bit_at_b = 7 - (cidr % 8); - if (cidr) { - memcpy(node->bits, src, (cidr + 7) / 8); - node->bits[(cidr + 7) / 8 - 1] &= ~0U << ((8 - (cidr % 8)) % 8); - } + memcpy(node->bits, src, bits / 8); } #define choose_node(parent, key) parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] @@ -56,10 +62,11 @@ static void free_root_node(struct allowedips_node __rcu *top, struct mutex *lock call_rcu_bh(&node->rcu, node_free_rcu); } -static int walk_by_peer(struct allowedips_node __rcu *top, int family, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock) +static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock) { struct allowedips_node *node; int ret; + u8 ip[16] __aligned(__alignof(u64)); if (!rcu_access_pointer(top)) return 0; @@ -70,7 +77,13 @@ static int walk_by_peer(struct allowedips_node __rcu *top, int family, struct al for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); --cursor->len, push(cursor->stack, node->bit[0], cursor->len), push(cursor->stack, node->bit[1], cursor->len)) { if (node->peer != peer) continue; - ret = func(ctx, node->bits, node->cidr, family); + + swap_endian(ip, node->bits, bits); + memset(ip + (node->cidr + 7) / 8, 0, bits / 8 - (node->cidr + 7) / 8); + if (node->cidr) + ip[(node->cidr + 7) / 8 - 1] &= ~0U << ((8 - (node->cidr % 8)) % 8); + + ret = func(ctx, ip, node->cidr, bits == 32 ? AF_INET : AF_INET6); if (ret) return ret; } @@ -130,9 +143,9 @@ static __always_inline unsigned int fls128(u64 a, u64 b) static __always_inline u8 common_bits(const struct allowedips_node *node, const u8 *key, u8 bits) { if (bits == 32) - return 32 - fls(be32_to_cpu(*(const __be32 *)node->bits ^ *(const __be32 *)key)); + return 32 - fls(*(const u32 *)node->bits ^ *(const u32 *)key); else if (bits == 128) - return 128 - fls128(be64_to_cpu(*(const __be64 *)&node->bits[0] ^ *(const __be64 *)&key[0]), be64_to_cpu(*(const __be64 *)&node->bits[8] ^ *(const __be64 *)&key[8])); + return 128 - fls128(*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]); return 0; } @@ -158,10 +171,13 @@ static __always_inline struct allowedips_node *find_node(struct allowedips_node } /* Returns a strong reference to a peer */ -static __always_inline struct wireguard_peer *lookup(struct allowedips_node __rcu *root, u8 bits, const void *ip) +static __always_inline struct wireguard_peer *lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip) { struct wireguard_peer *peer = NULL; struct allowedips_node *node; + u8 ip[16] __aligned(__alignof(u64)); + + swap_endian(ip, be_ip, bits); rcu_read_lock_bh(); node = find_node(rcu_dereference_bh(root), bits, ip); @@ -189,19 +205,22 @@ static inline bool node_placement(struct allowedips_node __rcu *trie, const u8 * return exact; } -static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) +static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) { struct allowedips_node *node, *parent, *down, *newnode; + u8 key[16] __aligned(__alignof(u64)); if (unlikely(cidr > bits || !peer)) return -EINVAL; + swap_endian(key, be_key, bits); + if (!rcu_access_pointer(*trie)) { node = kzalloc(sizeof(*node), GFP_KERNEL); if (!node) return -ENOMEM; node->peer = peer; - copy_and_assign_cidr(node, key, cidr); + copy_and_assign_cidr(node, key, cidr, bits); rcu_assign_pointer(*trie, node); return 0; } @@ -214,7 +233,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, u8 c if (!newnode) return -ENOMEM; newnode->peer = peer; - copy_and_assign_cidr(newnode, key, cidr); + copy_and_assign_cidr(newnode, key, cidr, bits); if (!node) down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); @@ -240,7 +259,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, u8 c kfree(newnode); return -ENOMEM; } - copy_and_assign_cidr(node, newnode->bits, cidr); + copy_and_assign_cidr(node, newnode->bits, cidr, bits); rcu_assign_pointer(choose_node(node, down->bits), down); rcu_assign_pointer(choose_node(node, newnode->bits), newnode); @@ -296,13 +315,13 @@ int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor * return 0; if (!cursor->second_half) { - ret = walk_by_peer(table->root4, AF_INET, cursor, peer, func, ctx, lock); + ret = walk_by_peer(table->root4, 32, cursor, peer, func, ctx, lock); if (ret) return ret; cursor->len = 0; cursor->second_half = true; } - return walk_by_peer(table->root6, AF_INET6, cursor, peer, func, ctx, lock); + return walk_by_peer(table->root6, 128, cursor, peer, func, ctx, lock); } /* Returns a strong reference to a peer */ |