diff options
-rw-r--r-- | src/config.c | 32 | ||||
-rw-r--r-- | src/device.c | 50 | ||||
-rw-r--r-- | src/noise.c | 8 | ||||
-rw-r--r-- | src/noise.h | 2 | ||||
-rw-r--r-- | src/peer.c | 27 | ||||
-rw-r--r-- | src/peer.h | 31 |
6 files changed, 73 insertions, 77 deletions
diff --git a/src/config.c b/src/config.c index c3fe154..48afae1 100644 --- a/src/config.c +++ b/src/config.c @@ -8,20 +8,15 @@ #include "hashtables.h" #include "peer.h" #include "uapi.h" - -static int clear_peer_endpoint_src(struct wireguard_peer *peer, void *data) -{ - socket_clear_peer_endpoint_src(peer); - return 0; -} - static int set_device_port(struct wireguard_device *wg, u16 port) { + struct wireguard_peer *peer, *temp; socket_uninit(wg); wg->incoming_port = port; if (!(netdev_pub(wg)->flags & IFF_UP)) return 0; - peer_for_each_unlocked(wg, clear_peer_endpoint_src, NULL); + peer_for_each (wg, peer, temp, false) + socket_clear_peer_endpoint_src(peer); return socket_init(wg); } @@ -133,6 +128,7 @@ int config_set_device(struct wireguard_device *wg, void __user *user_device) { int ret; size_t i, offset; + struct wireguard_peer *peer, *temp; struct wgdevice in_device; void __user *user_peer; bool modified_static_identity = false; @@ -152,7 +148,8 @@ int config_set_device(struct wireguard_device *wg, void __user *user_device) if (in_device.fwmark || (!in_device.fwmark && (in_device.flags & WGDEVICE_REMOVE_FWMARK))) { wg->fwmark = in_device.fwmark; - peer_for_each_unlocked(wg, clear_peer_endpoint_src, NULL); + peer_for_each (wg, peer, temp, false) + socket_clear_peer_endpoint_src(peer); } if (in_device.port) { @@ -183,8 +180,10 @@ int config_set_device(struct wireguard_device *wg, void __user *user_device) } if (modified_static_identity) { - if (peer_for_each_unlocked(wg, noise_precompute_static_static, NULL) < 0) - noise_set_static_identity_private_key(&wg->static_identity, NULL); + peer_for_each (wg, peer, temp, false) { + if (!noise_precompute_static_static(peer)) + peer_remove(peer); + } cookie_checker_precompute_device_keys(&wg->cookie_checker); } @@ -242,10 +241,9 @@ static int populate_ipmask(void *ctx, union nf_inet_addr ip, u8 cidr, int family return ret; } -static int populate_peer(struct wireguard_peer *peer, void *ctx) +static int populate_peer(struct wireguard_peer *peer, struct data_remaining *data) { int ret = 0; - struct data_remaining *data = ctx; void __user *upeer = data->data; struct wgpeer out_peer; struct data_remaining ipmasks_data = { NULL }; @@ -289,6 +287,7 @@ static int populate_peer(struct wireguard_peer *peer, void *ctx) int config_get_device(struct wireguard_device *wg, void __user *user_device) { int ret; + struct wireguard_peer *peer, *temp; struct net_device *dev = netdev_pub(wg); struct data_remaining peer_data = { NULL }; struct wgdevice out_device; @@ -330,7 +329,12 @@ int config_get_device(struct wireguard_device *wg, void __user *user_device) peer_data.out_len = in_device.peers_size; peer_data.data = user_device + sizeof(struct wgdevice); - ret = peer_for_each_unlocked(wg, populate_peer, &peer_data); + + peer_for_each (wg, peer, temp, false) { + ret = populate_peer(peer, &peer_data); + if (ret) + break; + } if (ret) goto out; out_device.num_peers = peer_data.count; diff --git a/src/device.c b/src/device.c index e10aeed..a06750a 100644 --- a/src/device.c +++ b/src/device.c @@ -26,18 +26,10 @@ #include <net/netfilter/nf_nat_core.h> #endif -static int open_peer(struct wireguard_peer *peer, void *data) -{ - timers_init_peer(peer); - packet_send_queue(peer); - if (peer->persistent_keepalive_interval) - packet_send_keepalive(peer); - return 0; -} - static int open(struct net_device *dev) { int ret; + struct wireguard_peer *peer, *temp; struct wireguard_device *wg = netdev_priv(dev); #if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0) struct inet6_dev *dev_v6 = __in6_dev_get(dev); @@ -64,16 +56,12 @@ static int open(struct net_device *dev) ret = socket_init(wg); if (ret < 0) return ret; - peer_for_each(wg, open_peer, NULL); - return 0; -} - -static int clear_noise_peer(struct wireguard_peer *peer, void *data) -{ - noise_handshake_clear(&peer->handshake); - noise_keypairs_clear(&peer->keypairs); - if (peer->timers_enabled) - del_timer(&peer->timer_kill_ephemerals); + peer_for_each (wg, peer, temp, true) { + timers_init_peer(peer); + packet_send_queue(peer); + if (peer->persistent_keepalive_interval) + packet_send_keepalive(peer); + } return 0; } @@ -81,25 +69,31 @@ static int clear_noise_peer(struct wireguard_peer *peer, void *data) static int suspending_clear_noise_peers(struct notifier_block *nb, unsigned long action, void *data) { struct wireguard_device *wg = container_of(nb, struct wireguard_device, clear_peers_on_suspend); + struct wireguard_peer *peer, *temp; if (action == PM_HIBERNATION_PREPARE || action == PM_SUSPEND_PREPARE) { - peer_for_each(wg, clear_noise_peer, NULL); + peer_for_each (wg, peer, temp, true) { + noise_handshake_clear(&peer->handshake); + noise_keypairs_clear(&peer->keypairs); + if (peer->timers_enabled) + del_timer(&peer->timer_kill_ephemerals); + } rcu_barrier_bh(); } return 0; } #endif -static int stop_peer(struct wireguard_peer *peer, void *data) -{ - timers_uninit_peer(peer); - clear_noise_peer(peer, data); - return 0; -} - static int stop(struct net_device *dev) { struct wireguard_device *wg = netdev_priv(dev); - peer_for_each(wg, stop_peer, NULL); + struct wireguard_peer *peer, *temp; + peer_for_each (wg, peer, temp, true) { + timers_uninit_peer(peer); + noise_handshake_clear(&peer->handshake); + noise_keypairs_clear(&peer->keypairs); + if (peer->timers_enabled) + del_timer(&peer->timer_kill_ephemerals); + } skb_queue_purge(&wg->incoming_handshakes); socket_uninit(wg); return 0; diff --git a/src/noise.c b/src/noise.c index 9e7fab0..c9d8148 100644 --- a/src/noise.c +++ b/src/noise.c @@ -38,12 +38,12 @@ void noise_init(void) blake2s_final(&blake, handshake_init_hash, NOISE_HASH_LEN); } -int noise_precompute_static_static(struct wireguard_peer *peer, void *ctx) +bool noise_precompute_static_static(struct wireguard_peer *peer) { if (peer->handshake.static_identity->has_identity) - return curve25519(peer->handshake.precomputed_static_static, peer->handshake.static_identity->static_private, peer->handshake.remote_static) ? 0 : -EINVAL; + return curve25519(peer->handshake.precomputed_static_static, peer->handshake.static_identity->static_private, peer->handshake.remote_static); memset(peer->handshake.precomputed_static_static, 0, NOISE_PUBLIC_KEY_LEN); - return 0; + return true; } bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer) @@ -56,7 +56,7 @@ bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static memcpy(handshake->preshared_key, peer_preshared_key, NOISE_SYMMETRIC_KEY_LEN); handshake->static_identity = static_identity; handshake->state = HANDSHAKE_ZEROED; - return !noise_precompute_static_static(peer, static_identity); + return noise_precompute_static_static(peer); } void noise_handshake_clear(struct noise_handshake *handshake) diff --git a/src/noise.h b/src/noise.h index 5e4d9af..c2d7e63 100644 --- a/src/noise.h +++ b/src/noise.h @@ -109,7 +109,7 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs); bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair); void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN]); -int noise_precompute_static_static(struct wireguard_peer *peer, void *ctx); +bool noise_precompute_static_static(struct wireguard_peer *peer); bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake); struct wireguard_peer *noise_handshake_consume_initiation(struct message_handshake_initiation *src, struct wireguard_device *wg); @@ -108,33 +108,6 @@ void peer_put(struct wireguard_peer *peer) kref_put(&peer->refcount, kref_release); } -int peer_for_each_unlocked(struct wireguard_device *wg, int (*fn)(struct wireguard_peer *peer, void *ctx), void *data) -{ - struct wireguard_peer *peer, *temp; - int ret = 0; - - lockdep_assert_held(&wg->device_update_lock); - list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) { - peer = peer_rcu_get(peer); - if (unlikely(!peer)) - continue; - ret = fn(peer, data); - peer_put(peer); - if (ret < 0) - break; - } - return ret; -} - -int peer_for_each(struct wireguard_device *wg, int (*fn)(struct wireguard_peer *peer, void *ctx), void *data) -{ - int ret; - mutex_lock(&wg->device_update_lock); - ret = peer_for_each_unlocked(wg, fn, data); - mutex_unlock(&wg->device_update_lock); - return ret; -} - void peer_remove_all(struct wireguard_device *wg) { struct wireguard_peer *peer, *temp; @@ -67,9 +67,34 @@ void peer_remove_all(struct wireguard_device *wg); struct wireguard_peer *peer_lookup_by_index(struct wireguard_device *wg, u32 index); -int peer_for_each_unlocked(struct wireguard_device *wg, int (*fn)(struct wireguard_peer *peer, void *ctx), void *data); -int peer_for_each(struct wireguard_device *wg, int (*fn)(struct wireguard_peer *peer, void *ctx), void *data); - unsigned int peer_total_count(struct wireguard_device *wg); +/* This is a macro iterator of essentially this: + * + * if (__should_lock) + * mutex_lock(&(__wg)->device_update_lock); + * else + * lockdep_assert_held(&(__wg)->device_update_lock) + * list_for_each_entry_safe (__peer, __temp, &(__wg)->peer_list, peer_list) { + * __peer = peer_rcu_get(__peer); + * if (!__peer) + * continue; + * ITERATOR_BODY + * peer_put(__peer); + * } + * if (__should_lock) + * mutex_unlock(&(__wg)->device_update_lock); + * + * While it's really ugly to look at, the code gcc produces from it is actually perfect. + */ +#define pfe_label(n) __PASTE(__PASTE(pfe_label_, n ## _), __LINE__) +#define peer_for_each(__wg, __peer, __temp, __should_lock) \ + if (1) { if (__should_lock) mutex_lock(&(__wg)->device_update_lock); else lockdep_assert_held(&(__wg)->device_update_lock); goto pfe_label(1); } else pfe_label(1): \ + if (1) goto pfe_label(2); else while (1) if (1) { if (__should_lock) mutex_unlock(&(__wg)->device_update_lock); break; } else pfe_label(2): \ + list_for_each_entry_safe (__peer, __temp, &(__wg)->peer_list, peer_list) \ + if (0) pfe_label(3): break; else \ + if (0); else for (__peer = peer_rcu_get(peer); __peer;) if (1) { goto pfe_label(4); pfe_label(5): break; } else while (1) if (1) goto pfe_label(5); else pfe_label(4): \ + if (1) { goto pfe_label(6); pfe_label(7):; } else while (1) if (1) goto pfe_label(3); else while (1) if (1) goto pfe_label(7); else pfe_label(6): \ + if (1) { goto pfe_label(8); pfe_label(9): peer_put(__peer); break; pfe_label(10): peer_put(__peer); } else while (1) if (1) goto pfe_label(9); else while (1) if (1) goto pfe_label(10); else pfe_label(8): + #endif |