diff options
-rw-r--r-- | src/allowedips.c | 100 | ||||
-rw-r--r-- | src/allowedips.h | 33 | ||||
-rw-r--r-- | src/netlink.c | 61 | ||||
-rw-r--r-- | src/peer.c | 1 | ||||
-rw-r--r-- | src/peer.h | 1 | ||||
-rw-r--r-- | src/selftest/allowedips.c | 93 |
6 files changed, 129 insertions, 160 deletions
diff --git a/src/allowedips.c b/src/allowedips.c index 30b66f4..bfb6020 100644 --- a/src/allowedips.c +++ b/src/allowedips.c @@ -6,18 +6,6 @@ #include "allowedips.h" #include "peer.h" -struct allowedips_node { - struct wg_peer __rcu *peer; - struct rcu_head rcu; - struct allowedips_node __rcu *bit[2]; - /* While it may seem scandalous that we waste space for v4, - * we're alloc'ing to the nearest power of 2 anyway, so this - * doesn't actually make a difference. - */ - u8 bits[16] __aligned(__alignof(u64)); - u8 cidr, bit_at_a, bit_at_b; -}; - static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits) { if (bits == 32) { @@ -37,6 +25,7 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, node->bit_at_a ^= (bits / 8U - 1U) % 8U; #endif node->bit_at_b = 7U - (cidr % 8U); + node->bitlen = bits; memcpy(node->bits, src, bits / 8U); } #define CHOOSE_NODE(parent, key) \ @@ -69,43 +58,17 @@ static void root_free_rcu(struct rcu_head *rcu) } } -static int -walk_by_peer(struct allowedips_node __rcu *top, u8 bits, - struct allowedips_cursor *cursor, struct wg_peer *peer, - int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), - void *ctx, struct mutex *lock) +static void root_remove_peer_lists(struct allowedips_node *root) { - const int address_family = bits == 32 ? AF_INET : AF_INET6; - /* Aligned so it can be treated as u64 */ - u8 ip[16] __aligned(__alignof(u64)); - struct allowedips_node *node; - int ret; - - if (!rcu_access_pointer(top)) - return 0; - - if (!cursor->len) - push_rcu(cursor->stack, top, &cursor->len); - - for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); - --cursor->len, push_rcu(cursor->stack, node->bit[0], &cursor->len), - push_rcu(cursor->stack, node->bit[1], &cursor->len)) { - const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); - - if (rcu_dereference_protected(node->peer, - lockdep_is_held(lock)) != peer) - continue; - - swap_endian(ip, node->bits, bits); - memset(ip + cidr_bytes, 0, bits / 8U - cidr_bytes); - if (node->cidr) - ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); + struct allowedips_node *node, *stack[128] = { root }; + unsigned int len = 1; - ret = func(ctx, ip, node->cidr, address_family); - if (ret) - return ret; + while (len > 0 && (node = stack[--len])) { + push_rcu(stack, node->bit[0], &len); + push_rcu(stack, node->bit[1], &len); + if (rcu_access_pointer(node->peer)) + list_del(&node->peer_list); } - return 0; } static void walk_remove_by_peer(struct allowedips_node __rcu **top, @@ -145,6 +108,7 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) == peer) { RCU_INIT_POINTER(node->peer, NULL); + list_del(&node->peer_list); if (!node->bit[0] || !node->bit[1]) { rcu_assign_pointer(*nptr, DEREF( &node->bit[!REF(node->bit[0])])); @@ -263,12 +227,14 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, if (unlikely(!node)) return -ENOMEM; RCU_INIT_POINTER(node->peer, peer); + list_add_tail(&node->peer_list, &peer->allowedips_list); copy_and_assign_cidr(node, key, cidr, bits); rcu_assign_pointer(*trie, node); return 0; } if (node_placement(*trie, key, cidr, bits, &node, lock)) { rcu_assign_pointer(node->peer, peer); + list_move_tail(&node->peer_list, &peer->allowedips_list); return 0; } @@ -276,6 +242,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, if (unlikely(!newnode)) return -ENOMEM; RCU_INIT_POINTER(newnode->peer, peer); + list_add_tail(&newnode->peer_list, &peer->allowedips_list); copy_and_assign_cidr(newnode, key, cidr, bits); if (!node) { @@ -304,6 +271,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key, kfree(newnode); return -ENOMEM; } + INIT_LIST_HEAD(&node->peer_list); copy_and_assign_cidr(node, newnode->bits, cidr, bits); rcu_assign_pointer(CHOOSE_NODE(node, down->bits), down); @@ -326,15 +294,20 @@ void wg_allowedips_init(struct allowedips *table) void wg_allowedips_free(struct allowedips *table, struct mutex *lock) { struct allowedips_node __rcu *old4 = table->root4, *old6 = table->root6; + ++table->seq; RCU_INIT_POINTER(table->root4, NULL); RCU_INIT_POINTER(table->root6, NULL); - if (rcu_access_pointer(old4)) + if (rcu_access_pointer(old4)) { + root_remove_peer_lists(old4); call_rcu_bh(&rcu_dereference_protected(old4, lockdep_is_held(lock))->rcu, root_free_rcu); - if (rcu_access_pointer(old6)) + } + if (rcu_access_pointer(old6)) { + root_remove_peer_lists(old6); call_rcu_bh(&rcu_dereference_protected(old6, lockdep_is_held(lock))->rcu, root_free_rcu); + } } int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, @@ -367,29 +340,16 @@ void wg_allowedips_remove_by_peer(struct allowedips *table, walk_remove_by_peer(&table->root6, peer, lock); } -int wg_allowedips_walk_by_peer(struct allowedips *table, - struct allowedips_cursor *cursor, - struct wg_peer *peer, - int (*func)(void *ctx, const u8 *ip, u8 cidr, - int family), - void *ctx, struct mutex *lock) +int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr) { - int ret; - - if (!cursor->seq) - cursor->seq = table->seq; - else if (cursor->seq != table->seq) - return 0; - - if (!cursor->second_half) { - ret = walk_by_peer(table->root4, 32, cursor, peer, func, ctx, - lock); - if (ret) - return ret; - cursor->len = 0; - cursor->second_half = true; - } - return walk_by_peer(table->root6, 128, cursor, peer, func, ctx, lock); + const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); + swap_endian(ip, node->bits, node->bitlen); + memset(ip + cidr_bytes, 0, node->bitlen / 8U - cidr_bytes); + if (node->cidr) + ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); + + *cidr = node->cidr; + return node->bitlen == 32 ? AF_INET : AF_INET6; } /* Returns a strong reference to a peer */ diff --git a/src/allowedips.h b/src/allowedips.h index 29e15a2..e5c83ca 100644 --- a/src/allowedips.h +++ b/src/allowedips.h @@ -11,7 +11,23 @@ #include <linux/ipv6.h> struct wg_peer; -struct allowedips_node; + +struct allowedips_node { + struct wg_peer __rcu *peer; + struct allowedips_node __rcu *bit[2]; + /* While it may seem scandalous that we waste space for v4, + * we're alloc'ing to the nearest power of 2 anyway, so this + * doesn't actually make a difference. + */ + u8 bits[16] __aligned(__alignof(u64)); + u8 cidr, bit_at_a, bit_at_b, bitlen; + + /* Keep rarely used list at bottom to be beyond cache line. */ + union { + struct list_head peer_list; + struct rcu_head rcu; + }; +}; struct allowedips { struct allowedips_node __rcu *root4; @@ -19,13 +35,6 @@ struct allowedips { u64 seq; }; -struct allowedips_cursor { - u64 seq; - struct allowedips_node *stack[128]; - unsigned int len; - bool second_half; -}; - void wg_allowedips_init(struct allowedips *table); void wg_allowedips_free(struct allowedips *table, struct mutex *mutex); int wg_allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, @@ -34,12 +43,8 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wg_peer *peer, struct mutex *lock); void wg_allowedips_remove_by_peer(struct allowedips *table, struct wg_peer *peer, struct mutex *lock); -int wg_allowedips_walk_by_peer(struct allowedips *table, - struct allowedips_cursor *cursor, - struct wg_peer *peer, - int (*func)(void *ctx, const u8 *ip, u8 cidr, - int family), - void *ctx, struct mutex *lock); +/* The ip input pointer should be __aligned(__alignof(u64))) */ +int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr); /* These return a strong reference to a peer: */ struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table, diff --git a/src/netlink.c b/src/netlink.c index f44f211..b179b31 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -69,9 +69,9 @@ static struct wg_device *lookup_interface(struct nlattr **attrs, return netdev_priv(dev); } -static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) +static int get_allowedips(struct sk_buff *skb, const u8 *ip, u8 cidr, + int family) { - struct sk_buff *skb = ctx; struct nlattr *allowedip_nest; allowedip_nest = nla_nest_start(skb, 0); @@ -90,10 +90,12 @@ static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) return 0; } -static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor, - struct sk_buff *skb) +static int +get_peer(struct wg_peer *peer, struct allowedips_node **next_allowedips_node, + u64 *allowedips_seq, struct sk_buff *skb) { struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, 0); + struct allowedips_node *allowedips_node = *next_allowedips_node; bool fail; if (!peer_nest) @@ -106,7 +108,7 @@ static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor, if (fail) goto err; - if (!rt_cursor->seq) { + if (!allowedips_node) { const struct __kernel_timespec last_handshake = { .tv_sec = peer->walltime_last_handshake.tv_sec, .tv_nsec = peer->walltime_last_handshake.tv_nsec @@ -143,21 +145,39 @@ static int get_peer(struct wg_peer *peer, struct allowedips_cursor *rt_cursor, read_unlock_bh(&peer->endpoint_lock); if (fail) goto err; + allowedips_node = + list_first_entry_or_null(&peer->allowedips_list, + struct allowedips_node, peer_list); } + if (!allowedips_node) + goto no_allowedips; + if (!*allowedips_seq) + *allowedips_seq = peer->device->peer_allowedips.seq; + else if (*allowedips_seq != peer->device->peer_allowedips.seq) + goto no_allowedips; allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); if (!allowedips_nest) goto err; - if (wg_allowedips_walk_by_peer(&peer->device->peer_allowedips, - rt_cursor, peer, get_allowedips, skb, - &peer->device->device_update_lock)) { - nla_nest_end(skb, allowedips_nest); - nla_nest_end(skb, peer_nest); - return -EMSGSIZE; + + list_for_each_entry_from(allowedips_node, &peer->allowedips_list, + peer_list) { + u8 cidr, ip[16] __aligned(__alignof(u64)); + int family; + + family = wg_allowedips_read_node(allowedips_node, ip, &cidr); + if (get_allowedips(skb, ip, cidr, family)) { + nla_nest_end(skb, allowedips_nest); + nla_nest_end(skb, peer_nest); + *next_allowedips_node = allowedips_node; + return -EMSGSIZE; + } } - memset(rt_cursor, 0, sizeof(*rt_cursor)); nla_nest_end(skb, allowedips_nest); +no_allowedips: nla_nest_end(skb, peer_nest); + *next_allowedips_node = NULL; + *allowedips_seq = 0; return 0; err: nla_nest_cancel(skb, peer_nest); @@ -174,16 +194,9 @@ static int wg_get_device_start(struct netlink_callback *cb) genl_family.maxattr, device_policy, NULL); if (ret < 0) return ret; - cb->args[2] = (long)kzalloc(sizeof(struct allowedips_cursor), - GFP_KERNEL); - if (unlikely(!cb->args[2])) - return -ENOMEM; wg = lookup_interface(attrs, cb->skb); - if (IS_ERR(wg)) { - kfree((void *)cb->args[2]); - cb->args[2] = 0; + if (IS_ERR(wg)) return PTR_ERR(wg); - } cb->args[0] = (long)wg; return 0; } @@ -191,7 +204,6 @@ static int wg_get_device_start(struct netlink_callback *cb) static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) { struct wg_peer *peer, *next_peer_cursor, *last_peer_cursor; - struct allowedips_cursor *rt_cursor; struct nlattr *peers_nest; struct wg_device *wg; int ret = -EMSGSIZE; @@ -201,7 +213,6 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) wg = (struct wg_device *)cb->args[0]; next_peer_cursor = (struct wg_peer *)cb->args[1]; last_peer_cursor = (struct wg_peer *)cb->args[1]; - rt_cursor = (struct allowedips_cursor *)cb->args[2]; rtnl_lock(); mutex_lock(&wg->device_update_lock); @@ -253,7 +264,8 @@ static int wg_get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) lockdep_assert_held(&wg->device_update_lock); peer = list_prepare_entry(last_peer_cursor, &wg->peer_list, peer_list); list_for_each_entry_continue(peer, &wg->peer_list, peer_list) { - if (get_peer(peer, rt_cursor, skb)) { + if (get_peer(peer, (struct allowedips_node **)&cb->args[2], + (u64 *)&cb->args[4] /* and args[5] */, skb)) { done = false; break; } @@ -290,12 +302,9 @@ static int wg_get_device_done(struct netlink_callback *cb) { struct wg_device *wg = (struct wg_device *)cb->args[0]; struct wg_peer *peer = (struct wg_peer *)cb->args[1]; - struct allowedips_cursor *rt_cursor = - (struct allowedips_cursor *)cb->args[2]; if (wg) dev_put(wg->dev); - kfree(rt_cursor); wg_peer_put(peer); return 0; } @@ -64,6 +64,7 @@ struct wg_peer *wg_peer_create(struct wg_device *wg, NAPI_POLL_WEIGHT); napi_enable(&peer->napi); list_add_tail(&peer->peer_list, &wg->peer_list); + INIT_LIST_HEAD(&peer->allowedips_list); wg_pubkey_hashtable_add(wg->peer_hashtable, peer); ++wg->num_peers; pr_debug("%s: Peer %llu created\n", wg->dev->name, peer->internal_id); @@ -60,6 +60,7 @@ struct wg_peer { struct kref refcount; struct rcu_head rcu; struct list_head peer_list; + struct list_head allowedips_list; u64 internal_id; struct napi_struct napi; bool is_dead; diff --git a/src/selftest/allowedips.c b/src/selftest/allowedips.c index 379ac31..6e244a9 100644 --- a/src/selftest/allowedips.c +++ b/src/selftest/allowedips.c @@ -452,47 +452,14 @@ static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d) return &ip; } -struct walk_ctx { - int count; - bool found_a, found_b, found_c, found_d, found_e; - bool found_other; -}; - -static __init int walk_callback(void *ctx, const u8 *ip, u8 cidr, int family) -{ - struct walk_ctx *wctx = ctx; - - wctx->count++; - - if (cidr == 27 && - !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr))) - wctx->found_a = true; - else if (cidr == 128 && - !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543), - sizeof(struct in6_addr))) - wctx->found_b = true; - else if (cidr == 29 && - !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr))) - wctx->found_c = true; - else if (cidr == 83 && - !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0), - sizeof(struct in6_addr))) - wctx->found_d = true; - else if (cidr == 21 && - !memcmp(ip, ip6(0x26075000, 0, 0, 0), sizeof(struct in6_addr))) - wctx->found_e = true; - else - wctx->found_other = true; - - return 0; -} - static __init struct wg_peer *init_peer(void) { struct wg_peer *peer = kzalloc(sizeof(*peer), GFP_KERNEL); - if (peer) - kref_init(&peer->refcount); + if (!peer) + return NULL; + kref_init(&peer->refcount); + INIT_LIST_HEAD(&peer->allowedips_list); return peer; } @@ -527,23 +494,24 @@ static __init struct wg_peer *init_peer(void) bool __init wg_allowedips_selftest(void) { - struct allowedips_cursor *cursor = kzalloc(sizeof(*cursor), GFP_KERNEL); + bool found_a = false, found_b = false, found_c = false, found_d = false, + found_e = false, found_other = false; struct wg_peer *a = init_peer(), *b = init_peer(), *c = init_peer(), *d = init_peer(), *e = init_peer(), *f = init_peer(), *g = init_peer(), *h = init_peer(); - struct walk_ctx wctx = { 0 }; + struct allowedips_node *iter_node; bool success = false; struct allowedips t; DEFINE_MUTEX(mutex); struct in6_addr ip; - size_t i = 0; + size_t i = 0, count = 0; __be64 part; mutex_init(&mutex); mutex_lock(&mutex); wg_allowedips_init(&t); - if (!cursor || !a || !b || !c || !d || !e || !f || !g || !h) { + if (!a || !b || !c || !d || !e || !f || !g || !h) { pr_err("allowedips self-test malloc: FAIL\n"); goto free; } @@ -649,14 +617,40 @@ bool __init wg_allowedips_selftest(void) insert(4, a, 10, 1, 0, 20, 29); insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 83); insert(6, a, 0x26075300, 0x6d8a6bf8, 0xdab1f1df, 0xc05f1523, 21); - wg_allowedips_walk_by_peer(&t, cursor, a, walk_callback, &wctx, &mutex); - test_boolean(wctx.count == 5); - test_boolean(wctx.found_a); - test_boolean(wctx.found_b); - test_boolean(wctx.found_c); - test_boolean(wctx.found_d); - test_boolean(wctx.found_e); - test_boolean(!wctx.found_other); + list_for_each_entry(iter_node, &a->allowedips_list, peer_list) { + u8 cidr, ip[16] __aligned(__alignof(u64)); + int family = wg_allowedips_read_node(iter_node, ip, &cidr); + + count++; + + if (cidr == 27 && family == AF_INET && + !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr))) + found_a = true; + else if (cidr == 128 && family == AF_INET6 && + !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543), + sizeof(struct in6_addr))) + found_b = true; + else if (cidr == 29 && family == AF_INET && + !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr))) + found_c = true; + else if (cidr == 83 && family == AF_INET6 && + !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0), + sizeof(struct in6_addr))) + found_d = true; + else if (cidr == 21 && family == AF_INET6 && + !memcmp(ip, ip6(0x26075000, 0, 0, 0), + sizeof(struct in6_addr))) + found_e = true; + else + found_other = true; + } + test_boolean(count == 5); + test_boolean(found_a); + test_boolean(found_b); + test_boolean(found_c); + test_boolean(found_d); + test_boolean(found_e); + test_boolean(!found_other); if (IS_ENABLED(DEBUG_RANDOM_TRIE) && success) success = randomized_test(); @@ -675,7 +669,6 @@ free: kfree(g); kfree(h); mutex_unlock(&mutex); - kfree(cursor); return success; } |