summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/routingtable.c25
1 files changed, 9 insertions, 16 deletions
diff --git a/src/routingtable.c b/src/routingtable.c
index dee5e89..1df8561 100644
--- a/src/routingtable.c
+++ b/src/routingtable.c
@@ -17,8 +17,10 @@ static inline u8 bit_at(const u8 *key, u8 a, u8 b)
{
return (key[a] >> b) & 1;
}
-static inline void assign_cidr(struct routing_table_node *node, u8 cidr)
+static inline void copy_and_assign_cidr(struct routing_table_node *node, const u8 *src, u8 cidr)
{
+ memcpy(node->bits, src, (cidr + 7) / 8);
+ node->bits[(cidr + 7) / 8 - 1] &= 0xff << ((8 - (cidr % 8)) % 8);
node->cidr = cidr;
node->bit_at_a = cidr / 8;
node->bit_at_b = 7 - (cidr % 8);
@@ -208,17 +210,13 @@ static inline bool node_placement(struct routing_table_node __rcu *trie, const u
static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
{
struct routing_table_node *node, *parent, *down, *newnode;
- int bits_in_common;
if (!rcu_access_pointer(*trie)) {
node = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL);
if (!node)
return -ENOMEM;
node->peer = peer;
- memcpy(node->bits, key, (cidr + 7) / 8);
- /* Not strictly neccessary for the data structure, but helps keep the data cleaner: */
- node->bits[(cidr + 7) / 8 - 1] &= 0xff << ((8 - (cidr % 8)) % 8);
- assign_cidr(node, cidr);
+ copy_and_assign_cidr(node, key, cidr);
rcu_assign_pointer(*trie, node);
return 0;
}
@@ -233,10 +231,7 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
if (!newnode)
return -ENOMEM;
newnode->peer = peer;
- memcpy(newnode->bits, key, (cidr + 7) / 8);
- /* Not strictly neccessary for the data structure, but helps keep the data cleaner: */
- newnode->bits[(cidr + 7) / 8 - 1] &= 0xff << ((8 - (cidr % 8)) % 8);
- assign_cidr(newnode, cidr);
+ copy_and_assign_cidr(newnode, key, cidr);
if (!node)
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
@@ -247,15 +242,13 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
return 0;
}
/* here we must be inserting between node and down */
- bits_in_common = common_bits(down, key, cidr);
+ cidr = min(cidr, common_bits(down, key, cidr));
parent = node;
- if (bits_in_common > cidr)
- bits_in_common = cidr;
/* we either need to make a new branch above down and newnode
* or newnode can be the branch. newnode can be the branch if
* its cidr == bits_in_common */
- if (newnode->cidr == bits_in_common) {
+ if (newnode->cidr == cidr) {
/* newnode can be the branch */
rcu_assign_pointer(newnode->bit[bit_at(down->bits, newnode->bit_at_a, newnode->bit_at_b)], down);
if (!parent)
@@ -269,9 +262,9 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
kfree(newnode);
return -ENOMEM;
}
- assign_cidr(node, bits_in_common);
node->incidental = true;
- memcpy(node->bits, newnode->bits, (bits + 7) / 8);
+ copy_and_assign_cidr(node, newnode->bits, cidr);
+
rcu_assign_pointer(node->bit[bit_at(down->bits, node->bit_at_a, node->bit_at_b)], down);
rcu_assign_pointer(node->bit[bit_at(newnode->bits, node->bit_at_a, node->bit_at_b)], newnode);
if (!parent)