diff options
-rw-r--r-- | src/noise.c | 46 | ||||
-rw-r--r-- | src/noise.h | 2 | ||||
-rw-r--r-- | src/peer.c | 2 |
3 files changed, 29 insertions, 21 deletions
diff --git a/src/noise.c b/src/noise.c index 531306b..cbe3f39 100644 --- a/src/noise.c +++ b/src/noise.c @@ -123,7 +123,7 @@ struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair) void noise_keypairs_clear(struct noise_keypairs *keypairs) { struct noise_keypair *old; - mutex_lock(&keypairs->keypair_update_lock); + spin_lock_bh(&keypairs->keypair_update_lock); old = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); rcu_assign_pointer(keypairs->previous_keypair, NULL); noise_keypair_put(old); @@ -133,14 +133,14 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs) old = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); rcu_assign_pointer(keypairs->current_keypair, NULL); noise_keypair_put(old); - mutex_unlock(&keypairs->keypair_update_lock); + spin_unlock_bh(&keypairs->keypair_update_lock); } static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypair *new_keypair) { struct noise_keypair *previous_keypair, *next_keypair, *current_keypair; - mutex_lock(&keypairs->keypair_update_lock); + spin_lock_bh(&keypairs->keypair_update_lock); previous_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); next_keypair = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); current_keypair = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); @@ -174,31 +174,39 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai rcu_assign_pointer(keypairs->previous_keypair, NULL); noise_keypair_put(previous_keypair); } - mutex_unlock(&keypairs->keypair_update_lock); + spin_unlock_bh(&keypairs->keypair_update_lock); } bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair) { - bool ret = false; + bool key_is_new; struct noise_keypair *old_keypair; - /* TODO: probably this needs the actual mutex, but we're in atomic context, - * so we can't take it here. Instead we just rely on RCU for the lookups. */ + /* We first check without taking the spinlock but just RCU. */ rcu_read_lock_bh(); - if (unlikely(received_keypair == rcu_dereference_bh(keypairs->next_keypair))) { - ret = true; - /* When we've finally received the confirmation, we slide the next - * into the current, the current into the previous, and get rid of - * the old previous. */ - old_keypair = rcu_dereference_bh(keypairs->previous_keypair); - rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_bh(keypairs->current_keypair)); - noise_keypair_put(old_keypair); - rcu_assign_pointer(keypairs->current_keypair, received_keypair); - rcu_assign_pointer(keypairs->next_keypair, NULL); - } + key_is_new = received_keypair == rcu_dereference_bh(keypairs->next_keypair); rcu_read_unlock_bh(); + if (likely(!key_is_new)) + return false; - return ret; + spin_lock_bh(&keypairs->keypair_update_lock); + /* After locking, we double check that things didn't change from beneath us. */ + if (unlikely(received_keypair != rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)))) { + spin_unlock_bh(&keypairs->keypair_update_lock); + return false; + } + + /* When we've finally received the confirmation, we slide the next + * into the current, the current into the previous, and get rid of + * the old previous. */ + old_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock))); + noise_keypair_put(old_keypair); + rcu_assign_pointer(keypairs->current_keypair, received_keypair); + rcu_assign_pointer(keypairs->next_keypair, NULL); + + spin_unlock_bh(&keypairs->keypair_update_lock); + return true; } void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN]) diff --git a/src/noise.h b/src/noise.h index 2024b80..f6c68ea 100644 --- a/src/noise.h +++ b/src/noise.h @@ -49,7 +49,7 @@ struct noise_keypairs { struct noise_keypair __rcu *current_keypair; struct noise_keypair __rcu *previous_keypair; struct noise_keypair __rcu *next_keypair; - struct mutex keypair_update_lock; + spinlock_t keypair_update_lock; }; struct noise_static_identity { @@ -41,7 +41,7 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_ } timers_init(peer); cookie_checker_precompute_peer_keys(peer); - mutex_init(&peer->keypairs.keypair_update_lock); + spin_lock_init(&peer->keypairs.keypair_update_lock); INIT_WORK(&peer->transmit_handshake_work, packet_handshake_send_worker); rwlock_init(&peer->endpoint_lock); kref_init(&peer->refcount); |