summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/data.c6
-rw-r--r--src/noise.c10
-rw-r--r--src/noise.h1
3 files changed, 13 insertions, 4 deletions
diff --git a/src/data.c b/src/data.c
index 2da5ddd..8738e67 100644
--- a/src/data.c
+++ b/src/data.c
@@ -191,10 +191,9 @@ int packet_create_data(struct sk_buff *skb, struct wireguard_peer *peer, void(*c
unsigned int num_frags;
rcu_read_lock();
- keypair = rcu_dereference(peer->keypairs.current_keypair);
+ keypair = noise_keypair_get(rcu_dereference(peer->keypairs.current_keypair));
if (unlikely(!keypair))
goto err_rcu;
- kref_get(&keypair->refcount);
rcu_read_unlock();
if (unlikely(!get_encryption_nonce(&nonce, &keypair->sending)))
@@ -367,12 +366,11 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de
goto err;
ret = -EINVAL;
rcu_read_lock();
- keypair = (struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx);
+ keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
if (unlikely(!keypair)) {
rcu_read_unlock();
goto err;
}
- kref_get(&keypair->refcount);
rcu_read_unlock();
#ifdef CONFIG_WIREGUARD_PARALLEL
if (cpumask_weight(cpu_online_mask) > 1) {
diff --git a/src/noise.c b/src/noise.c
index b24b483..cefedee 100644
--- a/src/noise.c
+++ b/src/noise.c
@@ -95,6 +95,16 @@ void noise_keypair_put(struct noise_keypair *keypair)
kref_put(&keypair->refcount, keypair_free_kref);
}
+struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair)
+{
+ RCU_LOCKDEP_WARN(!rcu_read_lock_held(), "Calling noise_keypair_get without holding the RCU read lock.");
+ if (unlikely(!keypair))
+ return NULL;
+ if (unlikely(!kref_get_unless_zero(&keypair->refcount)))
+ return NULL;
+ return keypair;
+}
+
void noise_keypairs_clear(struct noise_keypairs *keypairs)
{
struct noise_keypair *old;
diff --git a/src/noise.h b/src/noise.h
index ca865f8..a849dc9 100644
--- a/src/noise.h
+++ b/src/noise.h
@@ -105,6 +105,7 @@ void noise_init(void);
void noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[static NOISE_PUBLIC_KEY_LEN], struct wireguard_peer *peer);
void noise_handshake_clear(struct noise_handshake *handshake);
void noise_keypair_put(struct noise_keypair *keypair);
+struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair);
void noise_keypairs_clear(struct noise_keypairs *keypairs);
bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair);