summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2017-10-09 02:48:33 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2017-10-09 04:40:22 +0200
commite000d5c96d2fbe8864fa313b203e2210f24ed18c (patch)
treed96dc5123b4b49d811da3ad3d6614db3ee5fc641
parent89db52f3fba403124701fdda0503a3443ca6016e (diff)
routingtable: only use device's mutex, not a special rt one
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--src/device.c2
-rw-r--r--src/netlink.c8
-rw-r--r--src/peer.c2
-rw-r--r--src/routingtable.c96
-rw-r--r--src/routingtable.h14
-rw-r--r--src/selftest/routingtable.h16
6 files changed, 38 insertions, 100 deletions
diff --git a/src/device.c b/src/device.c
index 5102acc..0fb5dcd 100644
--- a/src/device.c
+++ b/src/device.c
@@ -212,7 +212,7 @@ static void destruct(struct net_device *dev)
packet_queue_free(&wg->decrypt_queue, true);
packet_queue_free(&wg->encrypt_queue, true);
destroy_workqueue(wg->packet_crypt_wq);
- routing_table_free(&wg->peer_routing_table);
+ routing_table_free(&wg->peer_routing_table, &wg->device_update_lock);
ratelimiter_uninit();
memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity));
skb_queue_purge(&wg->incoming_handshakes);
diff --git a/src/netlink.c b/src/netlink.c
index b813508..fc27b7f 100644
--- a/src/netlink.c
+++ b/src/netlink.c
@@ -126,7 +126,7 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, unsigned in
allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
if (!allowedips_nest)
goto err;
- if (routing_table_walk_ips_by_peer_sleepable(&peer->device->peer_routing_table, &ctx, peer, get_allowedips)) {
+ if (routing_table_walk_ips_by_peer(&peer->device->peer_routing_table, &ctx, peer, get_allowedips, &peer->device->device_update_lock)) {
*allowedips_idx_cursor = ctx.idx;
nla_nest_end(skb, allowedips_nest);
nla_nest_end(skb, peer_nest);
@@ -274,9 +274,9 @@ static int set_allowedip(struct wireguard_peer *peer, struct nlattr **attrs)
cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]);
if (family == AF_INET && cidr <= 32 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr))
- ret = routing_table_insert_v4(&peer->device->peer_routing_table, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer);
+ ret = routing_table_insert_v4(&peer->device->peer_routing_table, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, &peer->device->device_update_lock);
else if (family == AF_INET6 && cidr <= 128 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr))
- ret = routing_table_insert_v6(&peer->device->peer_routing_table, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer);
+ ret = routing_table_insert_v6(&peer->device->peer_routing_table, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, &peer->device->device_update_lock);
return ret;
}
@@ -343,7 +343,7 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs)
}
if (flags & WGPEER_F_REPLACE_ALLOWEDIPS)
- routing_table_remove_by_peer(&wg->peer_routing_table, peer);
+ routing_table_remove_by_peer(&wg->peer_routing_table, peer, &wg->device_update_lock);
if (attrs[WGPEER_A_ALLOWEDIPS]) {
int rem;
diff --git a/src/peer.c b/src/peer.c
index 4408201..8cef1f9 100644
--- a/src/peer.c
+++ b/src/peer.c
@@ -79,7 +79,7 @@ void peer_remove(struct wireguard_peer *peer)
if (unlikely(!peer))
return;
lockdep_assert_held(&peer->device->device_update_lock);
- routing_table_remove_by_peer(&peer->device->peer_routing_table, peer);
+ routing_table_remove_by_peer(&peer->device->peer_routing_table, peer, &peer->device->device_update_lock);
pubkey_hashtable_remove(&peer->device->peer_hashtable, peer);
skb_queue_purge(&peer->staged_packet_queue);
noise_handshake_clear(&peer->handshake);
diff --git a/src/routingtable.c b/src/routingtable.c
index 781c758..e884383 100644
--- a/src/routingtable.c
+++ b/src/routingtable.c
@@ -45,18 +45,6 @@ static void free_root_node(struct routing_table_node __rcu *top, struct mutex *l
call_rcu_bh(&node->rcu, node_free_rcu);
}
-static size_t count_nodes(struct routing_table_node __rcu *top)
-{
- size_t ret = 0;
- walk_prep;
-
- walk (top, NULL) {
- if (node->peer)
- ++ret;
- }
- return ret;
-}
-
static int walk_ips_by_peer(struct routing_table_node __rcu *top, int family, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *maybe_lock)
{
int ret;
@@ -185,6 +173,9 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
{
struct routing_table_node *node, *parent, *down, *newnode;
+ if (unlikely(cidr > bits || !peer))
+ return -EINVAL;
+
if (!rcu_access_pointer(*trie)) {
node = kzalloc(sizeof(*node) + (bits + 7) / 8, GFP_KERNEL);
if (!node)
@@ -244,91 +235,36 @@ static int add(struct routing_table_node __rcu **trie, u8 bits, const u8 *key, u
void routing_table_init(struct routing_table *table)
{
memset(table, 0, sizeof(struct routing_table));
- mutex_init(&table->table_update_lock);
}
-void routing_table_free(struct routing_table *table)
+void routing_table_free(struct routing_table *table, struct mutex *mutex)
{
- mutex_lock(&table->table_update_lock);
- free_root_node(table->root4, &table->table_update_lock);
+ free_root_node(table->root4, mutex);
rcu_assign_pointer(table->root4, NULL);
- free_root_node(table->root6, &table->table_update_lock);
+ free_root_node(table->root6, mutex);
rcu_assign_pointer(table->root6, NULL);
- mutex_unlock(&table->table_update_lock);
-}
-
-int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer)
-{
- int ret;
-
- if (unlikely(cidr > 32 || !peer))
- return -EINVAL;
- mutex_lock(&table->table_update_lock);
- ret = add(&table->root4, 32, (const u8 *)ip, cidr, peer, &table->table_update_lock);
- mutex_unlock(&table->table_update_lock);
- return ret;
-}
-
-int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer)
-{
- int ret;
-
- if (unlikely(cidr > 128 || !peer))
- return -EINVAL;
- mutex_lock(&table->table_update_lock);
- ret = add(&table->root6, 128, (const u8 *)ip, cidr, peer, &table->table_update_lock);
- mutex_unlock(&table->table_update_lock);
- return ret;
}
-void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer)
+int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex)
{
- mutex_lock(&table->table_update_lock);
- walk_remove_by_peer(&table->root4, peer, &table->table_update_lock);
- walk_remove_by_peer(&table->root6, peer, &table->table_update_lock);
- mutex_unlock(&table->table_update_lock);
+ return add(&table->root4, 32, (const u8 *)ip, cidr, peer, mutex);
}
-size_t routing_table_count_nodes(struct routing_table *table)
+int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex)
{
- size_t ret;
-
- rcu_read_lock_bh();
- ret = count_nodes(table->root4) + count_nodes(table->root6);
- rcu_read_unlock_bh();
- return ret;
+ return add(&table->root6, 128, (const u8 *)ip, cidr, peer, mutex);
}
-int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family))
+void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *mutex)
{
- int ret;
-
- rcu_read_lock_bh();
- ret = walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, NULL);
- rcu_read_unlock_bh();
- if (ret)
- return ret;
-
- rcu_read_lock_bh();
- ret = walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, NULL);
- rcu_read_unlock_bh();
- return ret;
+ walk_remove_by_peer(&table->root4, peer, mutex);
+ walk_remove_by_peer(&table->root6, peer, mutex);
}
-int routing_table_walk_ips_by_peer_sleepable(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family))
+int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *mutex)
{
- int ret;
-
- mutex_lock(&table->table_update_lock);
- ret = walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, &table->table_update_lock);
- mutex_unlock(&table->table_update_lock);
- if (ret)
- return ret;
-
- mutex_lock(&table->table_update_lock);
- ret = walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, &table->table_update_lock);
- mutex_unlock(&table->table_update_lock);
- return ret;
+ return walk_ips_by_peer(table->root4, AF_INET, ctx, peer, func, mutex) ?:
+ walk_ips_by_peer(table->root6, AF_INET6, ctx, peer, func, mutex);
}
/* Returns a strong reference to a peer */
diff --git a/src/routingtable.h b/src/routingtable.h
index c251354..815118c 100644
--- a/src/routingtable.h
+++ b/src/routingtable.h
@@ -13,17 +13,15 @@ struct routing_table_node;
struct routing_table {
struct routing_table_node __rcu *root4;
struct routing_table_node __rcu *root6;
- struct mutex table_update_lock;
};
void routing_table_init(struct routing_table *table);
-void routing_table_free(struct routing_table *table);
-int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer);
-int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer);
-void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer);
-size_t routing_table_count_nodes(struct routing_table *table);
-int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family));
-int routing_table_walk_ips_by_peer_sleepable(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family));
+void routing_table_free(struct routing_table *table, struct mutex *mutex);
+int routing_table_insert_v4(struct routing_table *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex);
+int routing_table_insert_v6(struct routing_table *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *mutex);
+void routing_table_remove_by_peer(struct routing_table *table, struct wireguard_peer *peer, struct mutex *mutex);
+
+int routing_table_walk_ips_by_peer(struct routing_table *table, void *ctx, struct wireguard_peer *peer, int (*func)(void *ctx, union nf_inet_addr ip, u8 cidr, int family), struct mutex *mutex);
/* These return a strong reference to a peer: */
struct wireguard_peer *routing_table_lookup_dst(struct routing_table *table, struct sk_buff *skb);
diff --git a/src/selftest/routingtable.h b/src/selftest/routingtable.h
index 951eb59..4e30b98 100644
--- a/src/selftest/routingtable.h
+++ b/src/selftest/routingtable.h
@@ -349,7 +349,7 @@ static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
} while (0)
#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \
- routing_table_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem)
+ routing_table_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem, &mutex)
#define maybe_fail \
++i; \
@@ -370,6 +370,7 @@ static __init inline struct in6_addr *ip6(u32 a, u32 b, u32 c, u32 d)
bool __init routing_table_selftest(void)
{
+ DEFINE_MUTEX(mutex);
struct routing_table t;
struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, *e = NULL, *f = NULL, *g = NULL, *h = NULL;
size_t i = 0;
@@ -377,6 +378,8 @@ bool __init routing_table_selftest(void)
struct in6_addr ip;
__be64 part;
+ mutex_lock(&mutex);
+
routing_table_init(&t);
init_peer(a);
init_peer(b);
@@ -452,18 +455,18 @@ bool __init routing_table_selftest(void)
insert(4, a, 128, 0, 0, 0, 32);
insert(4, a, 192, 0, 0, 0, 32);
insert(4, a, 255, 0, 0, 0, 32);
- routing_table_remove_by_peer(&t, a);
+ routing_table_remove_by_peer(&t, a, &mutex);
test_negative(4, a, 1, 0, 0, 0);
test_negative(4, a, 64, 0, 0, 0);
test_negative(4, a, 128, 0, 0, 0);
test_negative(4, a, 192, 0, 0, 0);
test_negative(4, a, 255, 0, 0, 0);
- routing_table_free(&t);
+ routing_table_free(&t, &mutex);
routing_table_init(&t);
insert(4, a, 192, 168, 0, 0, 16);
insert(4, a, 192, 168, 0, 0, 24);
- routing_table_remove_by_peer(&t, a);
+ routing_table_remove_by_peer(&t, a, &mutex);
test_negative(4, a, 192, 168, 0, 1);
/* These will hit the BUG_ON(len >= 128) in free_node if something goes wrong. */
@@ -471,7 +474,7 @@ bool __init routing_table_selftest(void)
part = cpu_to_be64(~(1LLU << (i % 64)));
memset(&ip, 0xff, 16);
memcpy((u8 *)&ip + (i < 64) * 8, &part, 8);
- routing_table_insert_v6(&t, &ip, 128, a);
+ routing_table_insert_v6(&t, &ip, 128, a, &mutex);
}
#ifdef DEBUG_RANDOM_TRIE
@@ -483,7 +486,7 @@ bool __init routing_table_selftest(void)
pr_info("routing table self-tests: pass\n");
free:
- routing_table_free(&t);
+ routing_table_free(&t, &mutex);
kfree(a);
kfree(b);
kfree(c);
@@ -492,6 +495,7 @@ free:
kfree(f);
kfree(g);
kfree(h);
+ mutex_unlock(&mutex);
return success;
}