diff options
-rw-r--r-- | src/data.c | 6 | ||||
-rw-r--r-- | src/noise.c | 10 | ||||
-rw-r--r-- | src/noise.h | 1 |
3 files changed, 13 insertions, 4 deletions
@@ -191,10 +191,9 @@ int packet_create_data(struct sk_buff *skb, struct wireguard_peer *peer, void(*c unsigned int num_frags; rcu_read_lock(); - keypair = rcu_dereference(peer->keypairs.current_keypair); + keypair = noise_keypair_get(rcu_dereference(peer->keypairs.current_keypair)); if (unlikely(!keypair)) goto err_rcu; - kref_get(&keypair->refcount); rcu_read_unlock(); if (unlikely(!get_encryption_nonce(&nonce, &keypair->sending))) @@ -367,12 +366,11 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de goto err; ret = -EINVAL; rcu_read_lock(); - keypair = (struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx); + keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx)); if (unlikely(!keypair)) { rcu_read_unlock(); goto err; } - kref_get(&keypair->refcount); rcu_read_unlock(); #ifdef CONFIG_WIREGUARD_PARALLEL if (cpumask_weight(cpu_online_mask) > 1) { diff --git a/src/noise.c b/src/noise.c index b24b483..cefedee 100644 --- a/src/noise.c +++ b/src/noise.c @@ -95,6 +95,16 @@ void noise_keypair_put(struct noise_keypair *keypair) kref_put(&keypair->refcount, keypair_free_kref); } +struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair) +{ + RCU_LOCKDEP_WARN(!rcu_read_lock_held(), "Calling noise_keypair_get without holding the RCU read lock."); + if (unlikely(!keypair)) + return NULL; + if (unlikely(!kref_get_unless_zero(&keypair->refcount))) + return NULL; + return keypair; +} + void noise_keypairs_clear(struct noise_keypairs *keypairs) { struct noise_keypair *old; diff --git a/src/noise.h b/src/noise.h index ca865f8..a849dc9 100644 --- a/src/noise.h +++ b/src/noise.h @@ -105,6 +105,7 @@ void noise_init(void); void noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[static NOISE_PUBLIC_KEY_LEN], struct wireguard_peer *peer); void noise_handshake_clear(struct noise_handshake *handshake); void noise_keypair_put(struct noise_keypair *keypair); +struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair); void noise_keypairs_clear(struct noise_keypairs *keypairs); bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair); |