diff options
-rw-r--r-- | src/hashtables.c | 5 | ||||
-rw-r--r-- | src/hashtables.h | 2 | ||||
-rw-r--r-- | src/noise.c | 28 |
3 files changed, 24 insertions, 11 deletions
diff --git a/src/hashtables.c b/src/hashtables.c index db97f7e..a01a899 100644 --- a/src/hashtables.c +++ b/src/hashtables.c @@ -97,13 +97,16 @@ search_unused_slot: return entry->index; } -void index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new) +bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new) { + if (unlikely(hlist_unhashed(&old->index_hash))) + return false; spin_lock_bh(&table->lock); new->index = old->index; hlist_replace_rcu(&old->index_hash, &new->index_hash); INIT_HLIST_NODE(&old->index_hash); spin_unlock_bh(&table->lock); + return true; } void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry) diff --git a/src/hashtables.h b/src/hashtables.h index 9fa47d5..08a2a5d 100644 --- a/src/hashtables.h +++ b/src/hashtables.h @@ -40,7 +40,7 @@ struct index_hashtable_entry { }; void index_hashtable_init(struct index_hashtable *table); __le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry); -void index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new); +bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new); void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry); struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index); diff --git a/src/noise.c b/src/noise.c index 7ca2a67..9583ab1 100644 --- a/src/noise.c +++ b/src/noise.c @@ -59,16 +59,21 @@ bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static return noise_precompute_static_static(peer); } -void noise_handshake_clear(struct noise_handshake *handshake) +static void handshake_zero(struct noise_handshake *handshake) { - index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); - down_write(&handshake->lock); memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN); memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN); memset(&handshake->hash, 0, NOISE_HASH_LEN); memset(&handshake->chaining_key, 0, NOISE_HASH_LEN); handshake->remote_index = 0; handshake->state = HANDSHAKE_ZEROED; +} + +void noise_handshake_clear(struct noise_handshake *handshake) +{ + index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + down_write(&handshake->lock); + handshake_zero(handshake); up_write(&handshake->lock); index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); } @@ -371,8 +376,8 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry); - ret = true; handshake->state = HANDSHAKE_CREATED_INITIATION; + ret = true; out: up_write(&handshake->lock); @@ -548,6 +553,11 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake /* Success! Copy everything to peer */ down_write(&handshake->lock); + /* It's important to check that the state is still the same, while we have an exclusive lock */ + if (handshake->state != state) { + up_write(&handshake->lock); + goto fail; + } memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN); memcpy(handshake->hash, hash, NOISE_HASH_LEN); memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN); @@ -573,7 +583,7 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi { struct noise_keypair *new_keypair; - down_read(&handshake->lock); + down_write(&handshake->lock); if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE) goto fail; @@ -587,16 +597,16 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi derive_keys(&new_keypair->sending, &new_keypair->receiving, handshake->chaining_key); else derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key); - up_read(&handshake->lock); + handshake_zero(handshake); add_new_keypair(keypairs, new_keypair); - index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry); - noise_handshake_clear(handshake); net_dbg_ratelimited("%s: Keypair %Lu created for peer %Lu\n", netdev_pub(new_keypair->entry.peer->device)->name, new_keypair->internal_id, new_keypair->entry.peer->internal_id); + WARN_ON(!index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry)); + up_write(&handshake->lock); return true; fail: - up_read(&handshake->lock); + up_write(&handshake->lock); return false; } |