diff options
-rw-r--r-- | src/routingtable.c | 25 |
1 files changed, 9 insertions, 16 deletions
diff --git a/src/routingtable.c b/src/routingtable.c index dee5e89..1df8561 100644 --- a/src/routingtable.c +++ b/src/routingtable.c @@ -17,8 +17,10 @@ static inline u8 bit_at(const u8 *key, u8 a, u8 b) { return (key[a] >> b) & 1; } -static inline void assign_cidr(struct routing_table_node *node, u8 cidr) +static inline void copy_and_assign_cidr(struct routing_table_node *node, const u8 *src, u8 cidr) { + memcpy(node->bits, src, (cidr + 7) / 8); + node->bits[(cidr + 7) / 8 - 1] &= 0xff << ((8 - (cidr % 8)) % 8); node->cidr = cidr; node->bit_at_a = cidr / 8; node->bit_at_b = 7 - (cidr % 8); @@ -208,17 +210,13 @@ static inline bool node_placement(struct routing_table_node __rcu *trie, const u static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) { struct routing_table_node *node, *parent, *down, *newnode; - int bits_in_common; if (!rcu_access_pointer(*trie)) { node = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL); if (!node) return -ENOMEM; node->peer = peer; - memcpy(node->bits, key, (cidr + 7) / 8); - /* Not strictly neccessary for the data structure, but helps keep the data cleaner: */ - node->bits[(cidr + 7) / 8 - 1] &= 0xff << ((8 - (cidr % 8)) % 8); - assign_cidr(node, cidr); + copy_and_assign_cidr(node, key, cidr); rcu_assign_pointer(*trie, node); return 0; } @@ -233,10 +231,7 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u if (!newnode) return -ENOMEM; newnode->peer = peer; - memcpy(newnode->bits, key, (cidr + 7) / 8); - /* Not strictly neccessary for the data structure, but helps keep the data cleaner: */ - newnode->bits[(cidr + 7) / 8 - 1] &= 0xff << ((8 - (cidr % 8)) % 8); - assign_cidr(newnode, cidr); + copy_and_assign_cidr(newnode, key, cidr); if (!node) down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); @@ -247,15 +242,13 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u return 0; } /* here we must be inserting between node and down */ - bits_in_common = common_bits(down, key, cidr); + cidr = min(cidr, common_bits(down, key, cidr)); parent = node; - if (bits_in_common > cidr) - bits_in_common = cidr; /* we either need to make a new branch above down and newnode * or newnode can be the branch. newnode can be the branch if * its cidr == bits_in_common */ - if (newnode->cidr == bits_in_common) { + if (newnode->cidr == cidr) { /* newnode can be the branch */ rcu_assign_pointer(newnode->bit[bit_at(down->bits, newnode->bit_at_a, newnode->bit_at_b)], down); if (!parent) @@ -269,9 +262,9 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u kfree(newnode); return -ENOMEM; } - assign_cidr(node, bits_in_common); node->incidental = true; - memcpy(node->bits, newnode->bits, (bits + 7) / 8); + copy_and_assign_cidr(node, newnode->bits, cidr); + rcu_assign_pointer(node->bit[bit_at(down->bits, node->bit_at_a, node->bit_at_b)], down); rcu_assign_pointer(node->bit[bit_at(newnode->bits, node->bit_at_a, node->bit_at_b)], newnode); if (!parent) |