diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-08-01 15:59:37 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-08-03 00:14:18 +0200 |
commit | 9f4cc375561f34a6512f3789edbbd3d7a900145f (patch) | |
tree | 06c28c1a36d2b03c79fa6bbffb1a39cd49858dac | |
parent | 7dbb9eec2638b41a0b95f3ca196d3c3a8550173c (diff) |
peer: ensure destruction doesn't race
Completely rework peer removal to ensure peers don't jump between
contexts and create races.
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | src/compat/compat.h | 3 | ||||
-rw-r--r-- | src/cookie.c | 8 | ||||
-rw-r--r-- | src/hashtables.c | 6 | ||||
-rw-r--r-- | src/hashtables.h | 2 | ||||
-rw-r--r-- | src/noise.c | 58 | ||||
-rw-r--r-- | src/noise.h | 2 | ||||
-rw-r--r-- | src/peer.c | 43 | ||||
-rw-r--r-- | src/peer.h | 1 | ||||
-rw-r--r-- | src/receive.c | 41 | ||||
-rw-r--r-- | src/send.c | 34 | ||||
-rw-r--r-- | src/timers.c | 60 |
11 files changed, 148 insertions, 110 deletions
diff --git a/src/compat/compat.h b/src/compat/compat.h index 5b3075b..86df5f3 100644 --- a/src/compat/compat.h +++ b/src/compat/compat.h @@ -51,6 +51,9 @@ #ifndef READ_ONCE #define READ_ONCE ACCESS_ONCE #endif +#ifndef WRITE_ONCE +#define WRITE_ONCE(p, v) (ACCESS_ONCE(p) = (v)) +#endif #if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 17, 0) #include "udp_tunnel/udp_tunnel_partial_compat.h" diff --git a/src/cookie.c b/src/cookie.c index bc6d8be..9268630 100644 --- a/src/cookie.c +++ b/src/cookie.c @@ -165,15 +165,9 @@ void cookie_message_consume(struct message_handshake_cookie *src, struct wiregua { u8 cookie[COOKIE_LEN]; struct wireguard_peer *peer = NULL; - struct index_hashtable_entry *entry; bool ret; - rcu_read_lock_bh(); - entry = index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE | INDEX_HASHTABLE_KEYPAIR, src->receiver_index); - if (likely(entry)) - peer = entry->peer; - rcu_read_unlock_bh(); - if (unlikely(!peer)) + if (unlikely(!index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE | INDEX_HASHTABLE_KEYPAIR, src->receiver_index, &peer))) return; down_read(&peer->latest_cookie.lock); diff --git a/src/hashtables.c b/src/hashtables.c index 03b9e21..ac6df59 100644 --- a/src/hashtables.c +++ b/src/hashtables.c @@ -152,7 +152,7 @@ void index_hashtable_remove(struct index_hashtable *table, struct index_hashtabl } /* Returns a strong reference to a entry->peer */ -struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index) +struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer) { struct index_hashtable_entry *iter_entry, *entry = NULL; @@ -166,7 +166,9 @@ struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *tab } if (likely(entry)) { entry->peer = peer_get_maybe_zero(entry->peer); - if (unlikely(!entry->peer)) + if (likely(entry->peer)) + *peer = entry->peer; + else entry = NULL; } rcu_read_unlock_bh(); diff --git a/src/hashtables.h b/src/hashtables.h index a2ef6f0..f64cd24 100644 --- a/src/hashtables.h +++ b/src/hashtables.h @@ -47,6 +47,6 @@ void index_hashtable_init(struct index_hashtable *table); __le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry); bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new); void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry); -struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index); +struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer); #endif /* _WG_HASHTABLES_H */ diff --git a/src/noise.c b/src/noise.c index a1e094b..0f6e51b 100644 --- a/src/noise.c +++ b/src/noise.c @@ -103,24 +103,23 @@ static struct noise_keypair *keypair_create(struct wireguard_peer *peer) static void keypair_free_rcu(struct rcu_head *rcu) { - struct noise_keypair *keypair = container_of(rcu, struct noise_keypair, rcu); - - net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id); - kzfree(keypair); + kzfree(container_of(rcu, struct noise_keypair, rcu)); } static void keypair_free_kref(struct kref *kref) { struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount); - + net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id); index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); call_rcu_bh(&keypair->rcu, keypair_free_rcu); } -void noise_keypair_put(struct noise_keypair *keypair) +void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now) { if (unlikely(!keypair)) return; + if (unlikely(unreference_now)) + index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); kref_put(&keypair->refcount, keypair_free_kref); } @@ -139,13 +138,13 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs) spin_lock_bh(&keypairs->keypair_update_lock); old = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->previous_keypair, NULL); - noise_keypair_put(old); + noise_keypair_put(old, true); old = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->next_keypair, NULL); - noise_keypair_put(old); + noise_keypair_put(old, true); old = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->current_keypair, NULL); - noise_keypair_put(old); + noise_keypair_put(old, true); spin_unlock_bh(&keypairs->keypair_update_lock); } @@ -171,7 +170,7 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai */ RCU_INIT_POINTER(keypairs->next_keypair, NULL); rcu_assign_pointer(keypairs->previous_keypair, next_keypair); - noise_keypair_put(current_keypair); + noise_keypair_put(current_keypair, true); } else /* If there wasn't an existing next keypair, we replace the * previous with the current one. */ @@ -179,7 +178,7 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai /* At this point we can get rid of the old previous keypair, and set up * the new keypair. */ - noise_keypair_put(previous_keypair); + noise_keypair_put(previous_keypair, true); rcu_assign_pointer(keypairs->current_keypair, new_keypair); } else { /* If we're the responder, it means we can't use the new keypair until @@ -188,9 +187,9 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai * in the new next one. */ rcu_assign_pointer(keypairs->next_keypair, new_keypair); - noise_keypair_put(next_keypair); + noise_keypair_put(next_keypair, true); RCU_INIT_POINTER(keypairs->previous_keypair, NULL); - noise_keypair_put(previous_keypair); + noise_keypair_put(previous_keypair, true); } spin_unlock_bh(&keypairs->keypair_update_lock); } @@ -218,7 +217,7 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k */ old_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock))); - noise_keypair_put(old_keypair); + noise_keypair_put(old_keypair, true); rcu_assign_pointer(keypairs->current_keypair, received_keypair); RCU_INIT_POINTER(keypairs->next_keypair, NULL); @@ -542,7 +541,7 @@ out: struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg) { struct noise_handshake *handshake; - struct wireguard_peer *ret_peer = NULL; + struct wireguard_peer *peer = NULL, *ret_peer = NULL; u8 key[NOISE_SYMMETRIC_KEY_LEN]; u8 hash[NOISE_HASH_LEN]; u8 chaining_key[NOISE_HASH_LEN]; @@ -556,7 +555,7 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake if (unlikely(!wg->static_identity.has_identity)) goto out; - handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index); + handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index, &peer); if (unlikely(!handshake)) goto out; @@ -601,11 +600,11 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake handshake->remote_index = src->sender_index; handshake->state = HANDSHAKE_CONSUMED_RESPONSE; up_write(&handshake->lock); - ret_peer = handshake->entry.peer; + ret_peer = peer; goto out; fail: - peer_put(handshake->entry.peer); + peer_put(peer); out: memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN); memzero_explicit(hash, NOISE_HASH_LEN); @@ -619,14 +618,15 @@ out: bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs) { struct noise_keypair *new_keypair; + bool ret = false; down_write(&handshake->lock); if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE) - goto fail; + goto out; new_keypair = keypair_create(handshake->entry.peer); if (!new_keypair) - goto fail; + goto out; new_keypair->i_am_the_initiator = handshake->state == HANDSHAKE_CONSUMED_RESPONSE; new_keypair->remote_index = handshake->remote_index; @@ -636,14 +636,16 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key); handshake_zero(handshake); - add_new_keypair(keypairs, new_keypair); - net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id); - WARN_ON(!index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry)); - up_write(&handshake->lock); - - return true; + rcu_read_lock_bh(); + if (likely(!container_of(handshake, struct wireguard_peer, handshake)->is_dead)) { + add_new_keypair(keypairs, new_keypair); + net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id); + ret = index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry); + } else + kzfree(new_keypair); + rcu_read_unlock_bh(); -fail: +out: up_write(&handshake->lock); - return false; + return ret; } diff --git a/src/noise.h b/src/noise.h index 5804acf..be59587 100644 --- a/src/noise.h +++ b/src/noise.h @@ -95,7 +95,7 @@ struct wireguard_device; void noise_init(void); bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer); void noise_handshake_clear(struct noise_handshake *handshake); -void noise_keypair_put(struct noise_keypair *keypair); +void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now); struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair); void noise_keypairs_clear(struct noise_keypairs *keypairs); bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair); @@ -86,18 +86,40 @@ void peer_remove(struct wireguard_peer *peer) if (unlikely(!peer)) return; lockdep_assert_held(&peer->device->device_update_lock); + + /* Remove from configuration-time lookup structures so new packets can't enter. */ + list_del_init(&peer->peer_list); allowedips_remove_by_peer(&peer->device->peer_allowedips, peer, &peer->device->device_update_lock); pubkey_hashtable_remove(&peer->device->peer_hashtable, peer); - skb_queue_purge(&peer->staged_packet_queue); - noise_handshake_clear(&peer->handshake); + + /* Mark as dead, so that we don't allow jumping contexts after. */ + WRITE_ONCE(peer->is_dead, true); + synchronize_rcu_bh(); + + /* Now that no more keypairs can be created for this peer, we destroy existing ones. */ noise_keypairs_clear(&peer->keypairs); - list_del_init(&peer->peer_list); + + /* Destroy all ongoing timers that were in-flight at the beginning of this function. */ timers_stop(peer); - flush_workqueue(peer->device->packet_crypt_wq); /* The first flush is for encrypt/decrypt. */ - flush_workqueue(peer->device->packet_crypt_wq); /* The second.1 flush is for send (but not receive, since that's napi). */ - napi_disable(&peer->napi); /* The second.2 flush is for receive (but not send, since that's wq). */ - flush_workqueue(peer->device->handshake_send_wq); + + /* The transition between packet encryption/decryption queues isn't guarded + * by is_dead, but each reference's life is strictly bounded by two + * generations: once for parallel crypto and once for serial ingestion, + * so we can simply flush twice, and be sure that we no longer have references + * inside these queues. + * + * a) For encrypt/decrypt. */ + flush_workqueue(peer->device->packet_crypt_wq); + /* b.1) For send (but not receive, since that's napi). */ + flush_workqueue(peer->device->packet_crypt_wq); + /* b.2.1) For receive (but not send, since that's wq). */ + napi_disable(&peer->napi); + /* b.2.1) It's now safe to remove the napi struct, which must be done here from process context. */ netif_napi_del(&peer->napi); + + /* Ensure any workstructs we own (like transmit_handshake_work or clear_peer_work) no longer are in use. */ + flush_workqueue(peer->device->handshake_send_wq); + --peer->device->num_peers; peer_put(peer); } @@ -105,8 +127,6 @@ void peer_remove(struct wireguard_peer *peer) static void rcu_release(struct rcu_head *rcu) { struct wireguard_peer *peer = container_of(rcu, struct wireguard_peer, rcu); - - pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); dst_cache_destroy(&peer->endpoint_cache); packet_queue_free(&peer->rx_queue, false); packet_queue_free(&peer->tx_queue, false); @@ -116,9 +136,12 @@ static void rcu_release(struct rcu_head *rcu) static void kref_release(struct kref *refcount) { struct wireguard_peer *peer = container_of(refcount, struct wireguard_peer, refcount); - + pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); + /* Remove ourself from dynamic runtime lookup structures, now that the last reference is gone. */ index_hashtable_remove(&peer->device->index_hashtable, &peer->handshake.entry); + /* Remove any lingering packets that didn't have a chance to be transmitted. */ skb_queue_purge(&peer->staged_packet_queue); + /* Free the memory used. */ call_rcu_bh(&peer->rcu, rcu_release); } @@ -58,6 +58,7 @@ struct wireguard_peer { struct list_head peer_list; u64 internal_id; struct napi_struct napi; + bool is_dead; }; struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]); diff --git a/src/receive.c b/src/receive.c index 12af8ed..d3a698a 100644 --- a/src/receive.c +++ b/src/receive.c @@ -282,9 +282,9 @@ out: } #include "selftest/counter.h" -static void packet_consume_data_done(struct sk_buff *skb, struct endpoint *endpoint) +static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff *skb, struct endpoint *endpoint) { - struct wireguard_peer *peer = PACKET_PEER(skb), *routed_peer; + struct wireguard_peer *routed_peer; struct net_device *dev = peer->device->dev; unsigned int len, len_before_trim; @@ -400,11 +400,11 @@ int packet_rx_poll(struct napi_struct *napi, int budget) goto next; skb_reset(skb); - packet_consume_data_done(skb, &endpoint); + packet_consume_data_done(peer, skb, &endpoint); free = false; next: - noise_keypair_put(keypair); + noise_keypair_put(keypair, false); peer_put(peer); if (unlikely(free)) dev_kfree_skb(skb); @@ -436,32 +436,31 @@ void packet_decrypt_worker(struct work_struct *work) static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb) { - struct wireguard_peer *peer; + struct wireguard_peer *peer = NULL; __le32 idx = ((struct message_data *)skb->data)->key_idx; int ret; rcu_read_lock_bh(); - PACKET_CB(skb)->keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx)); - rcu_read_unlock_bh(); - if (unlikely(!PACKET_CB(skb)->keypair)) { - dev_kfree_skb(skb); - return; - } + PACKET_CB(skb)->keypair = (struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx, &peer); + if (unlikely(!noise_keypair_get(PACKET_CB(skb)->keypair))) + goto err_keypair; - /* The call to index_hashtable_lookup gives us a reference to its underlying peer, so we don't need to call peer_get(). */ - peer = PACKET_PEER(skb); + if (unlikely(peer->is_dead)) + goto err; ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu); - if (likely(!ret)) - return; /* Successful. No need to drop references below. */ - - if (ret == -EPIPE) + if (unlikely(ret == -EPIPE)) queue_enqueue_per_peer(&peer->rx_queue, skb, PACKET_STATE_DEAD); - else { - peer_put(peer); - noise_keypair_put(PACKET_CB(skb)->keypair); - dev_kfree_skb(skb); + if (likely(!ret || ret == -EPIPE)) { + rcu_read_unlock_bh(); + return; } +err: + noise_keypair_put(PACKET_CB(skb)->keypair, false); +err_keypair: + rcu_read_unlock_bh(); + peer_put(peer); + dev_kfree_skb(skb); } void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) @@ -58,13 +58,16 @@ void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool i /* First checking the timestamp here is just an optimization; it will * be caught while properly locked inside the actual work queue. */ - if (!has_expired(peer->last_sent_handshake, REKEY_TIMEOUT)) - return; + rcu_read_lock_bh(); + if (!has_expired(peer->last_sent_handshake, REKEY_TIMEOUT) || unlikely(peer->is_dead)) + goto out; peer_get(peer); /* Queues up calling packet_send_queued_handshakes(peer), where we do a peer_put(peer) after: */ if (!queue_work(peer->device->handshake_send_wq, &peer->transmit_handshake_work)) peer_put(peer); /* If the work was already queued, we want to drop the extra reference */ +out: + rcu_read_unlock_bh(); } void packet_send_handshake_response(struct wireguard_peer *peer) @@ -233,7 +236,7 @@ void packet_tx_worker(struct work_struct *work) else skb_free_null_queue(first); - noise_keypair_put(keypair); + noise_keypair_put(keypair, false); peer_put(peer); } } @@ -266,19 +269,22 @@ static void packet_create_data(struct sk_buff *first) { struct wireguard_peer *peer = PACKET_PEER(first); struct wireguard_device *wg = peer->device; - int ret; + int ret = -EINVAL; - ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu); - if (likely(!ret)) - return; /* Successful. No need to fall through to drop references below. */ + rcu_read_lock_bh(); + if (unlikely(peer->is_dead)) + goto err; - if (ret == -EPIPE) + ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu); + if (unlikely(ret == -EPIPE)) queue_enqueue_per_peer(&peer->tx_queue, first, PACKET_STATE_DEAD); - else { - peer_put(peer); - noise_keypair_put(PACKET_CB(first)->keypair); - skb_free_null_queue(first); - } +err: + rcu_read_unlock_bh(); + if (likely(!ret || ret == -EPIPE)) + return; + noise_keypair_put(PACKET_CB(first)->keypair, false); + peer_put(peer); + skb_free_null_queue(first); } void packet_send_staged_packets(struct wireguard_peer *peer) @@ -328,7 +334,7 @@ void packet_send_staged_packets(struct wireguard_peer *peer) out_invalid: key->is_valid = false; out_nokey: - noise_keypair_put(keypair); + noise_keypair_put(keypair, false); /* We orphan the packets if we're waiting on a handshake, so that they * don't block a socket's pool. diff --git a/src/timers.c b/src/timers.c index e8bb101..762152a 100644 --- a/src/timers.c +++ b/src/timers.c @@ -27,9 +27,20 @@ if (unlikely(!peer)) \ return; -static inline bool timers_active(struct wireguard_peer *peer) +static inline void mod_peer_timer(struct wireguard_peer *peer, struct timer_list *timer, unsigned long expires) { - return netif_running(peer->device->dev) && !list_empty(&peer->peer_list); + rcu_read_lock_bh(); + if (likely(netif_running(peer->device->dev) && !peer->is_dead)) + mod_timer(timer, expires); + rcu_read_unlock_bh(); +} + +static inline void del_peer_timer(struct wireguard_peer *peer, struct timer_list *timer) +{ + rcu_read_lock_bh(); + if (likely(netif_running(peer->device->dev) && !peer->is_dead)) + del_timer(timer); + rcu_read_unlock_bh(); } static void expired_retransmit_handshake(struct timer_list *timer) @@ -39,8 +50,7 @@ static void expired_retransmit_handshake(struct timer_list *timer) if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) { pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2); - if (likely(timers_active(peer))) - del_timer(&peer->timer_send_keepalive); + del_peer_timer(peer, &peer->timer_send_keepalive); /* We drop all packets without a keypair and don't try again, * if we try unsuccessfully for too long to make a handshake. */ @@ -49,8 +59,8 @@ static void expired_retransmit_handshake(struct timer_list *timer) /* We set a timer for destroying any residue that might be left * of a partial exchange. */ - if (likely(timers_active(peer)) && !timer_pending(&peer->timer_zero_key_material)) - mod_timer(&peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ); + if (!timer_pending(&peer->timer_zero_key_material)) + mod_peer_timer(peer, &peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ); } else { ++peer->timer_handshake_attempts; pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d seconds, retrying (try %d)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, REKEY_TIMEOUT, peer->timer_handshake_attempts + 1); @@ -70,8 +80,7 @@ static void expired_send_keepalive(struct timer_list *timer) packet_send_keepalive(peer); if (peer->timer_need_another_keepalive) { peer->timer_need_another_keepalive = false; - if (likely(timers_active(peer))) - mod_timer(&peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ); + mod_peer_timer(peer, &peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ); } peer_put(peer); } @@ -91,8 +100,12 @@ static void expired_zero_key_material(struct timer_list *timer) { peer_get_from_timer(timer_zero_key_material); - if (!queue_work(peer->device->handshake_send_wq, &peer->clear_peer_work)) /* Takes our reference. */ - peer_put(peer); /* If the work was already on the queue, we want to drop the extra reference */ + rcu_read_lock_bh(); + if (!peer->is_dead) { + if (!queue_work(peer->device->handshake_send_wq, &peer->clear_peer_work)) /* Should take our reference. */ + peer_put(peer); /* If the work was already on the queue, we want to drop the extra reference */ + } + rcu_read_unlock_bh(); } static void queued_expired_zero_key_material(struct work_struct *work) { @@ -116,16 +129,16 @@ static void expired_send_persistent_keepalive(struct timer_list *timer) /* Should be called after an authenticated data packet is sent. */ void timers_data_sent(struct wireguard_peer *peer) { - if (likely(timers_active(peer)) && !timer_pending(&peer->timer_new_handshake)) - mod_timer(&peer->timer_new_handshake, jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ); + if (!timer_pending(&peer->timer_new_handshake)) + mod_peer_timer(peer, &peer->timer_new_handshake, jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ); } /* Should be called after an authenticated data packet is received. */ void timers_data_received(struct wireguard_peer *peer) { - if (likely(timers_active(peer))) { + if (likely(netif_running(peer->device->dev))) { if (!timer_pending(&peer->timer_send_keepalive)) - mod_timer(&peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ); + mod_peer_timer(peer, &peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ); else peer->timer_need_another_keepalive = true; } @@ -134,29 +147,25 @@ void timers_data_received(struct wireguard_peer *peer) /* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ void timers_any_authenticated_packet_sent(struct wireguard_peer *peer) { - if (likely(timers_active(peer))) - del_timer(&peer->timer_send_keepalive); + del_peer_timer(peer, &peer->timer_send_keepalive); } /* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ void timers_any_authenticated_packet_received(struct wireguard_peer *peer) { - if (likely(timers_active(peer))) - del_timer(&peer->timer_new_handshake); + del_peer_timer(peer, &peer->timer_new_handshake); } /* Should be called after a handshake initiation message is sent. */ void timers_handshake_initiated(struct wireguard_peer *peer) { - if (likely(timers_active(peer))) - mod_timer(&peer->timer_retransmit_handshake, jiffies + REKEY_TIMEOUT * HZ + prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES)); + mod_peer_timer(peer, &peer->timer_retransmit_handshake, jiffies + REKEY_TIMEOUT * HZ + prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES)); } /* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ void timers_handshake_complete(struct wireguard_peer *peer) { - if (likely(timers_active(peer))) - del_timer(&peer->timer_retransmit_handshake); + del_peer_timer(peer, &peer->timer_retransmit_handshake); peer->timer_handshake_attempts = 0; peer->sent_lastminute_handshake = false; getnstimeofday(&peer->walltime_last_handshake); @@ -165,15 +174,14 @@ void timers_handshake_complete(struct wireguard_peer *peer) /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ void timers_session_derived(struct wireguard_peer *peer) { - if (likely(timers_active(peer))) - mod_timer(&peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ); + mod_peer_timer(peer, &peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ); } /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ void timers_any_authenticated_packet_traversal(struct wireguard_peer *peer) { - if (peer->persistent_keepalive_interval && likely(timers_active(peer))) - mod_timer(&peer->timer_persistent_keepalive, jiffies + peer->persistent_keepalive_interval * HZ); + if (peer->persistent_keepalive_interval) + mod_peer_timer(peer, &peer->timer_persistent_keepalive, jiffies + peer->persistent_keepalive_interval * HZ); } void timers_init(struct wireguard_peer *peer) |