diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/netlink.c | 37 | ||||
-rw-r--r-- | src/routingtable.c | 76 | ||||
-rw-r--r-- | src/routingtable.h | 17 | ||||
-rwxr-xr-x | src/tests/netns.sh | 4 |
4 files changed, 81 insertions, 53 deletions
diff --git a/src/netlink.c b/src/netlink.c index fc27b7f..e727669 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -63,22 +63,20 @@ static struct wireguard_device *lookup_interface(struct nlattr **attrs, struct s struct allowedips_ctx { struct sk_buff *skb; - unsigned int idx_cursor, idx; + unsigned int i; }; -static int get_allowedips(void *ctx, union nf_inet_addr ip, u8 cidr, int family) +static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) { struct nlattr *allowedip_nest; struct allowedips_ctx *actx = ctx; - if (++actx->idx < actx->idx_cursor) - return 0; - allowedip_nest = nla_nest_start(actx->skb, actx->idx - 1); + allowedip_nest = nla_nest_start(actx->skb, actx->i++); if (!allowedip_nest) return -EMSGSIZE; if (nla_put_u8(actx->skb, WGALLOWEDIP_A_CIDR_MASK, cidr) || nla_put_u16(actx->skb, WGALLOWEDIP_A_FAMILY, family) || - nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? sizeof(struct in6_addr) : sizeof(struct in_addr), &ip)) { + nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) { nla_nest_cancel(actx->skb, allowedip_nest); return -EMSGSIZE; } @@ -87,9 +85,9 @@ static int get_allowedips(void *ctx, union nf_inet_addr ip, u8 cidr, int family) return 0; } -static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned int *allowedips_idx_cursor, struct sk_buff *skb) +static int get_peer(struct wireguard_peer *peer, unsigned int index, struct routing_table_cursor *rt_cursor, struct sk_buff *skb) { - struct allowedips_ctx ctx = { .skb = skb, .idx_cursor = *allowedips_idx_cursor }; + struct allowedips_ctx ctx = { .skb = skb }; struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, index); bool fail; @@ -102,7 +100,7 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned in if (fail) goto err; - if (!ctx.idx_cursor) { + if (!rt_cursor->seq) { down_read(&peer->handshake.lock); fail = nla_put(skb, WGPEER_A_PRESHARED_KEY, NOISE_SYMMETRIC_KEY_LEN, peer->handshake.preshared_key); up_read(&peer->handshake.lock); @@ -126,13 +124,12 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned in allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); if (!allowedips_nest) goto err; - if (routing_table_walk_ips_by_peer(&peer->device->peer_routing_table, &ctx, peer, get_allowedips, &peer->device->device_update_lock)) { - *allowedips_idx_cursor = ctx.idx; + if (routing_table_walk_by_peer(&peer->device->peer_routing_table, rt_cursor, peer, get_allowedips, &ctx, &peer->device->device_update_lock)) { nla_nest_end(skb, allowedips_nest); nla_nest_end(skb, peer_nest); return -EMSGSIZE; } - *allowedips_idx_cursor = 0; + memset(rt_cursor, 0, sizeof(*rt_cursor)); nla_nest_end(skb, allowedips_nest); nla_nest_end(skb, peer_nest); return 0; @@ -149,9 +146,15 @@ static int get_device_start(struct netlink_callback *cb) if (ret < 0) return ret; + cb->args[2] = (long)kzalloc(sizeof(struct routing_table_cursor), GFP_KERNEL); + if (!cb->args[2]) + return -ENOMEM; wg = lookup_interface(attrs, cb->skb); - if (IS_ERR(wg)) + if (IS_ERR(wg)) { + kfree((void *)cb->args[2]); + cb->args[2] = 0; return PTR_ERR(wg); + } cb->args[0] = (long)wg; return 0; } @@ -160,7 +163,8 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) { struct wireguard_device *wg = (struct wireguard_device *)cb->args[0]; struct wireguard_peer *peer, *next_peer_cursor = NULL, *last_peer_cursor = (struct wireguard_peer *)cb->args[1]; - unsigned int peer_idx = 0, allowedips_idx_cursor = (unsigned int)cb->args[2]; + struct routing_table_cursor *rt_cursor = (struct routing_table_cursor *)cb->args[2]; + unsigned int peer_idx = 0; struct nlattr *peers_nest; bool done = true; void *hdr; @@ -203,7 +207,7 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) lockdep_assert_held(&wg->device_update_lock); peer = list_prepare_entry(last_peer_cursor, &wg->peer_list, peer_list); list_for_each_entry_continue (peer, &wg->peer_list, peer_list) { - if (get_peer(peer, peer_idx++, &allowedips_idx_cursor, skb)) { + if (get_peer(peer, peer_idx++, rt_cursor, skb)) { done = false; break; } @@ -228,7 +232,6 @@ out: return 0; } cb->args[1] = (long)next_peer_cursor; - cb->args[2] = (long)allowedips_idx_cursor; return skb->len; /* At this point, we can't really deal ourselves with safely zeroing out @@ -240,9 +243,11 @@ static int get_device_done(struct netlink_callback *cb) { struct wireguard_device *wg = (struct wireguard_device *)cb->args[0]; struct wireguard_peer *peer = (struct wireguard_peer *)cb->args[1]; + struct routing_table_cursor *rt_cursor = (struct routing_table_cursor *)cb->args[2]; if (wg) dev_put(wg->dev); + kfree(rt_cursor); peer_put(peer); return 0; } diff --git a/src/routingtable.c b/src/routingtable.c index e884383..6cbea3d 100644 --- a/src/routingtable.c +++ b/src/routingtable.c @@ -25,40 +25,38 @@ static void node_free_rcu(struct rcu_head *rcu) { kfree(container_of(rcu, struct routing_table_node, rcu)); } -#define push(p, lock) ({ \ + +#define push(stack, p, len) ({ \ if (rcu_access_pointer(p)) { \ BUG_ON(len >= 128); \ - stack[len++] = lock ? rcu_dereference_protected(p, lockdep_is_held((struct mutex *)lock)) : rcu_dereference_bh(p); \ + stack[len++] = rcu_dereference_protected(p, lockdep_is_held(lock)); \ } \ true; \ }) -#define walk_prep \ - struct routing_table_node *stack[128], *node; \ - unsigned int len; -#define walk(top, lock) for (len = 0, push(top, lock); len > 0 && (node = stack[--len]) && push(node->bit[0], lock) && push(node->bit[1], lock);) - static void free_root_node(struct routing_table_node __rcu *top, struct mutex *lock) { - walk_prep; + struct routing_table_node *stack[128], *node; + unsigned int len; - walk (top, lock) + for (len = 0, push(stack, top, len); len > 0 && (node = stack[--len]) && push(stack, node->bit[0], len) && push(stack, node->bit[1], len);) call_rcu_bh(&node->rcu, node_free_rcu); } -static int walk_ips_by_peer(struct routing_table_node __rcu *top, int family, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock) +static int walk_by_peer(struct routing_table_node __rcu *top, int family, struct routing_table_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock) { + struct routing_table_node *node; int ret; - union nf_inet_addr ip = { .all = { 0 } }; - walk_prep; - if (unlikely(!peer)) + if (!rcu_access_pointer(top)) return 0; - walk (top, maybe_lock) { + if (!cursor->len) + push(cursor->stack, top, cursor->len); + + for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); --cursor->len, push(cursor->stack, node->bit[0], cursor->len), push(cursor->stack, node->bit[1], cursor->len)) { if (node->peer != peer) continue; - memcpy(ip.all, node->bits, family == AF_INET6 ? 16 : 4); - ret = func(ctx, ip, node->cidr, family); + ret = func(ctx, node->bits, node->cidr, family); if (ret) return ret; } @@ -234,37 +232,55 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u void routing_table_init(struct routing_table *table) { - memset(table, 0, sizeof(struct routing_table)); + table->root4 = table->root6 = NULL; + table->seq = 1; } -void routing_table_free(struct routing_table *table, struct mutex *mutex) +void routing_table_free(struct routing_table *table, struct mutex *lock) { - free_root_node(table->root4, mutex); + ++table->seq; + free_root_node(table->root4, lock); rcu_assign_pointer(table->root4, NULL); - free_root_node(table->root6, mutex); + free_root_node(table->root6, lock); rcu_assign_pointer(table->root6, NULL); } -int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex) +int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) { - return add(&table->root4, 32, (const u8 *)ip, cidr, peer, mutex); + ++table->seq; + return add(&table->root4, 32, (const u8 *)ip, cidr, peer, lock); } -int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex) +int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) { - return add(&table->root6, 128, (const u8 *)ip, cidr, peer, mutex); + ++table->seq; + return add(&table->root6, 128, (const u8 *)ip, cidr, peer, lock); } -void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *mutex) +void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *lock) { - walk_remove_by_peer(&table->root4, peer, mutex); - walk_remove_by_peer(&table->root6, peer, mutex); + ++table->seq; + walk_remove_by_peer(&table->root4, peer, lock); + walk_remove_by_peer(&table->root6, peer, lock); } -int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *mutex) +int routing_table_walk_by_peer(struct routing_table *table, struct routing_table_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock) { - return walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, mutex) ?: - walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, mutex); + int ret; + + if (!cursor->seq) + cursor->seq = table->seq; + else if (cursor->seq != table->seq) + return 0; + + if (!cursor->second_half) { + ret = walk_by_peer(table->root4, AF_INET, cursor, peer, func, ctx, lock); + if (ret) + return ret; + cursor->len = 0; + cursor->second_half = true; + } + return walk_by_peer(table->root6, AF_INET6, cursor, peer, func, ctx, lock); } /* Returns a strong reference to a peer */ diff --git a/src/routingtable.h b/src/routingtable.h index 815118c..d8666f8 100644 --- a/src/routingtable.h +++ b/src/routingtable.h @@ -13,15 +13,22 @@ struct routing_table_node; struct routing_table { struct routing_table_node __rcu *root4; struct routing_table_node __rcu *root6; + u64 seq; +}; + +struct routing_table_cursor { + u64 seq; + struct routing_table_node *stack[128]; + unsigned int len; + bool second_half; }; void routing_table_init(struct routing_table *table); void routing_table_free(struct routing_table *table, struct mutex *mutex); -int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex); -int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex); -void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *mutex); - -int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *mutex); +int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock); +int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock); +void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *lock); +int routing_table_walk_by_peer(struct routing_table *table, struct routing_table_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock); /* These return a strong reference to a peer: */ struct wireguard_peer *routing_table_lookup_dst(struct routing_table *table, struct sk_buff *skb); diff --git a/src/tests/netns.sh b/src/tests/netns.sh index a3145b8..375868e 100755 --- a/src/tests/netns.sh +++ b/src/tests/netns.sh @@ -376,7 +376,7 @@ ip0 link add dev wg0 type wireguard config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" ) for a in {1..255}; do for b in {0..255}; do - config+=( "AllowedIPs=$a.$b.0.0/16" ) + config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" ) done done n0 wg setconf wg0 <(printf '%s\n' "${config[@]}") @@ -384,7 +384,7 @@ i=0 for ip in $(n0 wg show wg0 allowed-ips); do ((++i)) done -((i == 65281)) +((i == 255*256*2+1)) ip0 link del wg0 ip0 link add dev wg0 type wireguard config=( "[Interface]" "PrivateKey=$(wg genkey)" ) |