diff options
Diffstat (limited to 'src/noise.c')
-rw-r--r-- | src/noise.c | 28 |
1 files changed, 19 insertions, 9 deletions
diff --git a/src/noise.c b/src/noise.c index 7ca2a67..9583ab1 100644 --- a/src/noise.c +++ b/src/noise.c @@ -59,16 +59,21 @@ bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static return noise_precompute_static_static(peer); } -void noise_handshake_clear(struct noise_handshake *handshake) +static void handshake_zero(struct noise_handshake *handshake) { - index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); - down_write(&handshake->lock); memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN); memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN); memset(&handshake->hash, 0, NOISE_HASH_LEN); memset(&handshake->chaining_key, 0, NOISE_HASH_LEN); handshake->remote_index = 0; handshake->state = HANDSHAKE_ZEROED; +} + +void noise_handshake_clear(struct noise_handshake *handshake) +{ + index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + down_write(&handshake->lock); + handshake_zero(handshake); up_write(&handshake->lock); index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); } @@ -371,8 +376,8 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry); - ret = true; handshake->state = HANDSHAKE_CREATED_INITIATION; + ret = true; out: up_write(&handshake->lock); @@ -548,6 +553,11 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake /* Success! Copy everything to peer */ down_write(&handshake->lock); + /* It's important to check that the state is still the same, while we have an exclusive lock */ + if (handshake->state != state) { + up_write(&handshake->lock); + goto fail; + } memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN); memcpy(handshake->hash, hash, NOISE_HASH_LEN); memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN); @@ -573,7 +583,7 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi { struct noise_keypair *new_keypair; - down_read(&handshake->lock); + down_write(&handshake->lock); if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE) goto fail; @@ -587,16 +597,16 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi derive_keys(&new_keypair->sending, &new_keypair->receiving, handshake->chaining_key); else derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key); - up_read(&handshake->lock); + handshake_zero(handshake); add_new_keypair(keypairs, new_keypair); - index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry); - noise_handshake_clear(handshake); net_dbg_ratelimited("%s: Keypair %Lu created for peer %Lu\n", netdev_pub(new_keypair->entry.peer->device)->name, new_keypair->internal_id, new_keypair->entry.peer->internal_id); + WARN_ON(!index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry)); + up_write(&handshake->lock); return true; fail: - up_read(&handshake->lock); + up_write(&handshake->lock); return false; } |