summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/netlink.c37
-rw-r--r--src/routingtable.c76
-rw-r--r--src/routingtable.h17
-rwxr-xr-xsrc/tests/netns.sh4
4 files changed, 81 insertions, 53 deletions
diff --git a/src/netlink.c b/src/netlink.c
index fc27b7f..e727669 100644
--- a/src/netlink.c
+++ b/src/netlink.c
@@ -63,22 +63,20 @@ static struct wireguard_device *lookup_interface(struct nlattr **attrs, struct s
struct allowedips_ctx {
struct sk_buff *skb;
- unsigned int idx_cursor, idx;
+ unsigned int i;
};
-static int get_allowedips(void *ctx, union nf_inet_addr ip, u8 cidr, int family)
+static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family)
{
struct nlattr *allowedip_nest;
struct allowedips_ctx *actx = ctx;
- if (++actx->idx < actx->idx_cursor)
- return 0;
- allowedip_nest = nla_nest_start(actx->skb, actx->idx - 1);
+ allowedip_nest = nla_nest_start(actx->skb, actx->i++);
if (!allowedip_nest)
return -EMSGSIZE;
if (nla_put_u8(actx->skb, WGALLOWEDIP_A_CIDR_MASK, cidr) || nla_put_u16(actx->skb, WGALLOWEDIP_A_FAMILY, family) ||
- nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? sizeof(struct in6_addr) : sizeof(struct in_addr), &ip)) {
+ nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) {
nla_nest_cancel(actx->skb, allowedip_nest);
return -EMSGSIZE;
}
@@ -87,9 +85,9 @@ static int get_allowedips(void *ctx, union nf_inet_addr ip, u8 cidr, int family)
return 0;
}
-static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned int *allowedips_idx_cursor, struct sk_buff *skb)
+static int get_peer(struct wireguard_peer *peer, unsigned int index, struct routing_table_cursor *rt_cursor, struct sk_buff *skb)
{
- struct allowedips_ctx ctx = { .skb = skb, .idx_cursor = *allowedips_idx_cursor };
+ struct allowedips_ctx ctx = { .skb = skb };
struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, index);
bool fail;
@@ -102,7 +100,7 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned in
if (fail)
goto err;
- if (!ctx.idx_cursor) {
+ if (!rt_cursor->seq) {
down_read(&peer->handshake.lock);
fail = nla_put(skb, WGPEER_A_PRESHARED_KEY, NOISE_SYMMETRIC_KEY_LEN, peer->handshake.preshared_key);
up_read(&peer->handshake.lock);
@@ -126,13 +124,12 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned in
allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
if (!allowedips_nest)
goto err;
- if (routing_table_walk_ips_by_peer(&peer->device->peer_routing_table, &ctx, peer, get_allowedips, &peer->device->device_update_lock)) {
- *allowedips_idx_cursor = ctx.idx;
+ if (routing_table_walk_by_peer(&peer->device->peer_routing_table, rt_cursor, peer, get_allowedips, &ctx, &peer->device->device_update_lock)) {
nla_nest_end(skb, allowedips_nest);
nla_nest_end(skb, peer_nest);
return -EMSGSIZE;
}
- *allowedips_idx_cursor = 0;
+ memset(rt_cursor, 0, sizeof(*rt_cursor));
nla_nest_end(skb, allowedips_nest);
nla_nest_end(skb, peer_nest);
return 0;
@@ -149,9 +146,15 @@ static int get_device_start(struct netlink_callback *cb)
if (ret < 0)
return ret;
+ cb->args[2] = (long)kzalloc(sizeof(struct routing_table_cursor), GFP_KERNEL);
+ if (!cb->args[2])
+ return -ENOMEM;
wg = lookup_interface(attrs, cb->skb);
- if (IS_ERR(wg))
+ if (IS_ERR(wg)) {
+ kfree((void *)cb->args[2]);
+ cb->args[2] = 0;
return PTR_ERR(wg);
+ }
cb->args[0] = (long)wg;
return 0;
}
@@ -160,7 +163,8 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
struct wireguard_device *wg = (struct wireguard_device *)cb->args[0];
struct wireguard_peer *peer, *next_peer_cursor = NULL, *last_peer_cursor = (struct wireguard_peer *)cb->args[1];
- unsigned int peer_idx = 0, allowedips_idx_cursor = (unsigned int)cb->args[2];
+ struct routing_table_cursor *rt_cursor = (struct routing_table_cursor *)cb->args[2];
+ unsigned int peer_idx = 0;
struct nlattr *peers_nest;
bool done = true;
void *hdr;
@@ -203,7 +207,7 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
lockdep_assert_held(&wg->device_update_lock);
peer = list_prepare_entry(last_peer_cursor, &wg->peer_list, peer_list);
list_for_each_entry_continue (peer, &wg->peer_list, peer_list) {
- if (get_peer(peer, peer_idx++, &allowedips_idx_cursor, skb)) {
+ if (get_peer(peer, peer_idx++, rt_cursor, skb)) {
done = false;
break;
}
@@ -228,7 +232,6 @@ out:
return 0;
}
cb->args[1] = (long)next_peer_cursor;
- cb->args[2] = (long)allowedips_idx_cursor;
return skb->len;
/* At this point, we can't really deal ourselves with safely zeroing out
@@ -240,9 +243,11 @@ static int get_device_done(struct netlink_callback *cb)
{
struct wireguard_device *wg = (struct wireguard_device *)cb->args[0];
struct wireguard_peer *peer = (struct wireguard_peer *)cb->args[1];
+ struct routing_table_cursor *rt_cursor = (struct routing_table_cursor *)cb->args[2];
if (wg)
dev_put(wg->dev);
+ kfree(rt_cursor);
peer_put(peer);
return 0;
}
diff --git a/src/routingtable.c b/src/routingtable.c
index e884383..6cbea3d 100644
--- a/src/routingtable.c
+++ b/src/routingtable.c
@@ -25,40 +25,38 @@ static void node_free_rcu(struct rcu_head *rcu)
{
kfree(container_of(rcu, struct routing_table_node, rcu));
}
-#define push(p, lock) ({ \
+
+#define push(stack, p, len) ({ \
if (rcu_access_pointer(p)) { \
BUG_ON(len >= 128); \
- stack[len++] = lock ? rcu_dereference_protected(p, lockdep_is_held((struct mutex *)lock)) : rcu_dereference_bh(p); \
+ stack[len++] = rcu_dereference_protected(p, lockdep_is_held(lock)); \
} \
true; \
})
-#define walk_prep \
- struct routing_table_node *stack[128], *node; \
- unsigned int len;
-#define walk(top, lock) for (len = 0, push(top, lock); len > 0 && (node = stack[--len]) && push(node->bit[0], lock) && push(node->bit[1], lock);)
-
static void free_root_node(struct routing_table_node __rcu *top, struct mutex *lock)
{
- walk_prep;
+ struct routing_table_node *stack[128], *node;
+ unsigned int len;
- walk (top, lock)
+ for (len = 0, push(stack, top, len); len > 0 && (node = stack[--len]) && push(stack, node->bit[0], len) && push(stack, node->bit[1], len);)
call_rcu_bh(&node->rcu, node_free_rcu);
}
-static int walk_ips_by_peer(struct routing_table_node __rcu *top, int family, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock)
+static int walk_by_peer(struct routing_table_node __rcu *top, int family, struct routing_table_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock)
{
+ struct routing_table_node *node;
int ret;
- union nf_inet_addr ip = { .all = { 0 } };
- walk_prep;
- if (unlikely(!peer))
+ if (!rcu_access_pointer(top))
return 0;
- walk (top, maybe_lock) {
+ if (!cursor->len)
+ push(cursor->stack, top, cursor->len);
+
+ for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); --cursor->len, push(cursor->stack, node->bit[0], cursor->len), push(cursor->stack, node->bit[1], cursor->len)) {
if (node->peer != peer)
continue;
- memcpy(ip.all, node->bits, family == AF_INET6 ? 16 : 4);
- ret = func(ctx, ip, node->cidr, family);
+ ret = func(ctx, node->bits, node->cidr, family);
if (ret)
return ret;
}
@@ -234,37 +232,55 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
void routing_table_init(struct routing_table *table)
{
- memset(table, 0, sizeof(struct routing_table));
+ table->root4 = table->root6 = NULL;
+ table->seq = 1;
}
-void routing_table_free(struct routing_table *table, struct mutex *mutex)
+void routing_table_free(struct routing_table *table, struct mutex *lock)
{
- free_root_node(table->root4, mutex);
+ ++table->seq;
+ free_root_node(table->root4, lock);
rcu_assign_pointer(table->root4, NULL);
- free_root_node(table->root6, mutex);
+ free_root_node(table->root6, lock);
rcu_assign_pointer(table->root6, NULL);
}
-int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex)
+int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
{
- return add(&table->root4, 32, (const u8 *)ip, cidr, peer, mutex);
+ ++table->seq;
+ return add(&table->root4, 32, (const u8 *)ip, cidr, peer, lock);
}
-int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex)
+int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
{
- return add(&table->root6, 128, (const u8 *)ip, cidr, peer, mutex);
+ ++table->seq;
+ return add(&table->root6, 128, (const u8 *)ip, cidr, peer, lock);
}
-void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *mutex)
+void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *lock)
{
- walk_remove_by_peer(&table->root4, peer, mutex);
- walk_remove_by_peer(&table->root6, peer, mutex);
+ ++table->seq;
+ walk_remove_by_peer(&table->root4, peer, lock);
+ walk_remove_by_peer(&table->root6, peer, lock);
}
-int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *mutex)
+int routing_table_walk_by_peer(struct routing_table *table, struct routing_table_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock)
{
- return walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, mutex) ?:
- walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, mutex);
+ int ret;
+
+ if (!cursor->seq)
+ cursor->seq = table->seq;
+ else if (cursor->seq != table->seq)
+ return 0;
+
+ if (!cursor->second_half) {
+ ret = walk_by_peer(table->root4, AF_INET, cursor, peer, func, ctx, lock);
+ if (ret)
+ return ret;
+ cursor->len = 0;
+ cursor->second_half = true;
+ }
+ return walk_by_peer(table->root6, AF_INET6, cursor, peer, func, ctx, lock);
}
/* Returns a strong reference to a peer */
diff --git a/src/routingtable.h b/src/routingtable.h
index 815118c..d8666f8 100644
--- a/src/routingtable.h
+++ b/src/routingtable.h
@@ -13,15 +13,22 @@ struct routing_table_node;
struct routing_table {
struct routing_table_node __rcu *root4;
struct routing_table_node __rcu *root6;
+ u64 seq;
+};
+
+struct routing_table_cursor {
+ u64 seq;
+ struct routing_table_node *stack[128];
+ unsigned int len;
+ bool second_half;
};
void routing_table_init(struct routing_table *table);
void routing_table_free(struct routing_table *table, struct mutex *mutex);
-int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex);
-int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex);
-void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *mutex);
-
-int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *mutex);
+int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock);
+int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock);
+void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *lock);
+int routing_table_walk_by_peer(struct routing_table *table, struct routing_table_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock);
/* These return a strong reference to a peer: */
struct wireguard_peer *routing_table_lookup_dst(struct routing_table *table, struct sk_buff *skb);
diff --git a/src/tests/netns.sh b/src/tests/netns.sh
index a3145b8..375868e 100755
--- a/src/tests/netns.sh
+++ b/src/tests/netns.sh
@@ -376,7 +376,7 @@ ip0 link add dev wg0 type wireguard
config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" )
for a in {1..255}; do
for b in {0..255}; do
- config+=( "AllowedIPs=$a.$b.0.0/16" )
+ config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" )
done
done
n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
@@ -384,7 +384,7 @@ i=0
for ip in $(n0 wg show wg0 allowed-ips); do
((++i))
done
-((i == 65281))
+((i == 255*256*2+1))
ip0 link del wg0
ip0 link add dev wg0 type wireguard
config=( "[Interface]" "PrivateKey=$(wg genkey)" )