diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-07-31 07:03:07 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-07-31 07:19:52 +0200 |
commit | 488b7dcba7a5dcf5d65349992c7fc32b3d9c17d1 (patch) | |
tree | 847677cfc90dee98bfe86d64a991f3ff0d8f428d | |
parent | 0c942d003c4291fe05d0de296ac040c7b0d0503c (diff) |
peer: simplify rcu reference counts
Use RCU reference counts only when we must, and otherwise use a more
reasonably named function.
Reported-by: Jann Horn <jann@thejh.net>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r-- | src/allowedips.c | 2 | ||||
-rw-r--r-- | src/hashtables.c | 4 | ||||
-rw-r--r-- | src/netlink.c | 8 | ||||
-rw-r--r-- | src/peer.c | 10 | ||||
-rw-r--r-- | src/peer.h | 8 | ||||
-rw-r--r-- | src/queueing.h | 10 | ||||
-rw-r--r-- | src/receive.c | 2 | ||||
-rw-r--r-- | src/send.c | 4 | ||||
-rw-r--r-- | src/timers.c | 5 |
9 files changed, 23 insertions, 30 deletions
diff --git a/src/allowedips.c b/src/allowedips.c index 07b2a3c..aecb390 100644 --- a/src/allowedips.c +++ b/src/allowedips.c @@ -182,7 +182,7 @@ static __always_inline struct wireguard_peer *lookup(struct allowedips_node __rc rcu_read_lock_bh(); node = find_node(rcu_dereference_bh(root), bits, ip); if (node) - peer = peer_get(node->peer); + peer = peer_get_maybe_zero(node->peer); rcu_read_unlock_bh(); return peer; } diff --git a/src/hashtables.c b/src/hashtables.c index 0e5235d..ab0f622 100644 --- a/src/hashtables.c +++ b/src/hashtables.c @@ -48,7 +48,7 @@ struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, c break; } } - peer = peer_get(peer); + peer = peer_get_maybe_zero(peer); rcu_read_unlock_bh(); return peer; } @@ -159,7 +159,7 @@ struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *tab } } if (likely(entry)) { - entry->peer = peer_get(entry->peer); + entry->peer = peer_get_maybe_zero(entry->peer); if (unlikely(!entry->peer)) entry = NULL; } diff --git a/src/netlink.c b/src/netlink.c index aef743f..90c7aa2 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -221,8 +221,8 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) out: peer_put(last_peer_cursor); - if (!ret && !done) - next_peer_cursor = peer_rcu_get(next_peer_cursor); + if (!ret && !done && next_peer_cursor) + peer_get(next_peer_cursor); mutex_unlock(&wg->device_update_lock); rtnl_unlock(); @@ -326,9 +326,11 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) up_read(&wg->static_identity.lock); ret = -ENOMEM; - peer = peer_rcu_get(peer_create(wg, public_key, preshared_key)); + peer = peer_create(wg, public_key, preshared_key); if (!peer) goto out; + /* Take additional reference, as though we've just been looked up. */ + peer_get(peer); } ret = 0; @@ -63,7 +63,7 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_ return peer; } -struct wireguard_peer *peer_get(struct wireguard_peer *peer) +struct wireguard_peer *peer_get_maybe_zero(struct wireguard_peer *peer) { RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Taking peer reference without holding the RCU read lock"); if (unlikely(!peer || !kref_get_unless_zero(&peer->refcount))) @@ -71,14 +71,6 @@ struct wireguard_peer *peer_get(struct wireguard_peer *peer) return peer; } -struct wireguard_peer *peer_rcu_get(struct wireguard_peer *peer) -{ - rcu_read_lock_bh(); - peer = peer_get(peer); - rcu_read_unlock_bh(); - return peer; -} - /* We have a separate "remove" function to get rid of the final reference because * peer_list, clearing handshakes, and flushing all require mutexes which requires * sleeping, which must only be done from certain contexts. @@ -62,9 +62,11 @@ struct wireguard_peer { struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]); -struct wireguard_peer *peer_get(struct wireguard_peer *peer); -struct wireguard_peer *peer_rcu_get(struct wireguard_peer *peer); - +struct wireguard_peer * __must_check peer_get_maybe_zero(struct wireguard_peer *peer); +static inline void peer_get(struct wireguard_peer *peer) +{ + kref_get(&peer->refcount); +} void peer_put(struct wireguard_peer *peer); void peer_remove(struct wireguard_peer *peer); void peer_remove_all(struct wireguard_device *wg); diff --git a/src/queueing.h b/src/queueing.h index c17b0d8..3fb7b5c 100644 --- a/src/queueing.h +++ b/src/queueing.h @@ -132,20 +132,14 @@ static inline int queue_enqueue_per_device_and_peer(struct crypt_queue *device_q static inline void queue_enqueue_per_peer(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state) { - struct wireguard_peer *peer = peer_rcu_get(PACKET_PEER(skb)); - atomic_set(&PACKET_CB(skb)->state, state); - queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &queue->work); - peer_put(peer); + queue_work_on(cpumask_choose_online(&PACKET_PEER(skb)->serial_work_cpu, PACKET_PEER(skb)->internal_id), PACKET_PEER(skb)->device->packet_crypt_wq, &queue->work); } static inline void queue_enqueue_per_peer_napi(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state) { - struct wireguard_peer *peer = peer_rcu_get(PACKET_PEER(skb)); - atomic_set(&PACKET_CB(skb)->state, state); - napi_schedule(&peer->napi); - peer_put(peer); + napi_schedule(&PACKET_PEER(skb)->napi); } #ifdef DEBUG diff --git a/src/receive.c b/src/receive.c index 732ac2b..5e231c9 100644 --- a/src/receive.c +++ b/src/receive.c @@ -448,7 +448,7 @@ static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb return; } - /* The call to index_hashtable_lookup gives us a reference to its underlying peer, so we don't need to call peer_rcu_get(). */ + /* The call to index_hashtable_lookup gives us a reference to its underlying peer, so we don't need to call peer_get(). */ peer = PACKET_PEER(skb); ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu); @@ -61,7 +61,7 @@ void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool i if (!has_expired(peer->last_sent_handshake, REKEY_TIMEOUT)) return; - peer = peer_rcu_get(peer); + peer_get(peer); /* Queues up calling packet_send_queued_handshakes(peer), where we do a peer_put(peer) after: */ if (!queue_work(peer->device->handshake_send_wq, &peer->transmit_handshake_work)) peer_put(peer); /* If the work was already queued, we want to drop the extra reference */ @@ -320,7 +320,7 @@ void packet_send_staged_packets(struct wireguard_peer *peer) } packets.prev->next = NULL; - peer_rcu_get(keypair->entry.peer); + peer_get(keypair->entry.peer); PACKET_CB(packets.next)->keypair = keypair; packet_create_data(packets.next); return; diff --git a/src/timers.c b/src/timers.c index f1549a7..e8bb101 100644 --- a/src/timers.c +++ b/src/timers.c @@ -20,7 +20,10 @@ */ #define peer_get_from_timer(timer_name) \ - struct wireguard_peer *peer = peer_rcu_get(from_timer(peer, timer, timer_name)); \ + struct wireguard_peer *peer; \ + rcu_read_lock_bh(); \ + peer = peer_get_maybe_zero(from_timer(peer, timer, timer_name)); \ + rcu_read_unlock_bh(); \ if (unlikely(!peer)) \ return; |