summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2018-07-31 07:03:07 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2018-07-31 07:19:52 +0200
commit488b7dcba7a5dcf5d65349992c7fc32b3d9c17d1 (patch)
tree847677cfc90dee98bfe86d64a991f3ff0d8f428d
parent0c942d003c4291fe05d0de296ac040c7b0d0503c (diff)
peer: simplify rcu reference counts
Use RCU reference counts only when we must, and otherwise use a more reasonably named function. Reported-by: Jann Horn <jann@thejh.net> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--src/allowedips.c2
-rw-r--r--src/hashtables.c4
-rw-r--r--src/netlink.c8
-rw-r--r--src/peer.c10
-rw-r--r--src/peer.h8
-rw-r--r--src/queueing.h10
-rw-r--r--src/receive.c2
-rw-r--r--src/send.c4
-rw-r--r--src/timers.c5
9 files changed, 23 insertions, 30 deletions
diff --git a/src/allowedips.c b/src/allowedips.c
index 07b2a3c..aecb390 100644
--- a/src/allowedips.c
+++ b/src/allowedips.c
@@ -182,7 +182,7 @@ static __always_inline struct wireguard_peer *lookup(struct allowedips_node __rc
rcu_read_lock_bh();
node = find_node(rcu_dereference_bh(root), bits, ip);
if (node)
- peer = peer_get(node->peer);
+ peer = peer_get_maybe_zero(node->peer);
rcu_read_unlock_bh();
return peer;
}
diff --git a/src/hashtables.c b/src/hashtables.c
index 0e5235d..ab0f622 100644
--- a/src/hashtables.c
+++ b/src/hashtables.c
@@ -48,7 +48,7 @@ struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, c
break;
}
}
- peer = peer_get(peer);
+ peer = peer_get_maybe_zero(peer);
rcu_read_unlock_bh();
return peer;
}
@@ -159,7 +159,7 @@ struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *tab
}
}
if (likely(entry)) {
- entry->peer = peer_get(entry->peer);
+ entry->peer = peer_get_maybe_zero(entry->peer);
if (unlikely(!entry->peer))
entry = NULL;
}
diff --git a/src/netlink.c b/src/netlink.c
index aef743f..90c7aa2 100644
--- a/src/netlink.c
+++ b/src/netlink.c
@@ -221,8 +221,8 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
out:
peer_put(last_peer_cursor);
- if (!ret && !done)
- next_peer_cursor = peer_rcu_get(next_peer_cursor);
+ if (!ret && !done && next_peer_cursor)
+ peer_get(next_peer_cursor);
mutex_unlock(&wg->device_update_lock);
rtnl_unlock();
@@ -326,9 +326,11 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs)
up_read(&wg->static_identity.lock);
ret = -ENOMEM;
- peer = peer_rcu_get(peer_create(wg, public_key, preshared_key));
+ peer = peer_create(wg, public_key, preshared_key);
if (!peer)
goto out;
+ /* Take additional reference, as though we've just been looked up. */
+ peer_get(peer);
}
ret = 0;
diff --git a/src/peer.c b/src/peer.c
index e8081f5..e115ac1 100644
--- a/src/peer.c
+++ b/src/peer.c
@@ -63,7 +63,7 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_
return peer;
}
-struct wireguard_peer *peer_get(struct wireguard_peer *peer)
+struct wireguard_peer *peer_get_maybe_zero(struct wireguard_peer *peer)
{
RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Taking peer reference without holding the RCU read lock");
if (unlikely(!peer || !kref_get_unless_zero(&peer->refcount)))
@@ -71,14 +71,6 @@ struct wireguard_peer *peer_get(struct wireguard_peer *peer)
return peer;
}
-struct wireguard_peer *peer_rcu_get(struct wireguard_peer *peer)
-{
- rcu_read_lock_bh();
- peer = peer_get(peer);
- rcu_read_unlock_bh();
- return peer;
-}
-
/* We have a separate "remove" function to get rid of the final reference because
* peer_list, clearing handshakes, and flushing all require mutexes which requires
* sleeping, which must only be done from certain contexts.
diff --git a/src/peer.h b/src/peer.h
index 088a6ee..70120ee 100644
--- a/src/peer.h
+++ b/src/peer.h
@@ -62,9 +62,11 @@ struct wireguard_peer {
struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]);
-struct wireguard_peer *peer_get(struct wireguard_peer *peer);
-struct wireguard_peer *peer_rcu_get(struct wireguard_peer *peer);
-
+struct wireguard_peer * __must_check peer_get_maybe_zero(struct wireguard_peer *peer);
+static inline void peer_get(struct wireguard_peer *peer)
+{
+ kref_get(&peer->refcount);
+}
void peer_put(struct wireguard_peer *peer);
void peer_remove(struct wireguard_peer *peer);
void peer_remove_all(struct wireguard_device *wg);
diff --git a/src/queueing.h b/src/queueing.h
index c17b0d8..3fb7b5c 100644
--- a/src/queueing.h
+++ b/src/queueing.h
@@ -132,20 +132,14 @@ static inline int queue_enqueue_per_device_and_peer(struct crypt_queue *device_q
static inline void queue_enqueue_per_peer(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state)
{
- struct wireguard_peer *peer = peer_rcu_get(PACKET_PEER(skb));
-
atomic_set(&PACKET_CB(skb)->state, state);
- queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &queue->work);
- peer_put(peer);
+ queue_work_on(cpumask_choose_online(&PACKET_PEER(skb)->serial_work_cpu, PACKET_PEER(skb)->internal_id), PACKET_PEER(skb)->device->packet_crypt_wq, &queue->work);
}
static inline void queue_enqueue_per_peer_napi(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state)
{
- struct wireguard_peer *peer = peer_rcu_get(PACKET_PEER(skb));
-
atomic_set(&PACKET_CB(skb)->state, state);
- napi_schedule(&peer->napi);
- peer_put(peer);
+ napi_schedule(&PACKET_PEER(skb)->napi);
}
#ifdef DEBUG
diff --git a/src/receive.c b/src/receive.c
index 732ac2b..5e231c9 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -448,7 +448,7 @@ static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb
return;
}
- /* The call to index_hashtable_lookup gives us a reference to its underlying peer, so we don't need to call peer_rcu_get(). */
+ /* The call to index_hashtable_lookup gives us a reference to its underlying peer, so we don't need to call peer_get(). */
peer = PACKET_PEER(skb);
ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu);
diff --git a/src/send.c b/src/send.c
index 7a8bea1..90dab85 100644
--- a/src/send.c
+++ b/src/send.c
@@ -61,7 +61,7 @@ void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool i
if (!has_expired(peer->last_sent_handshake, REKEY_TIMEOUT))
return;
- peer = peer_rcu_get(peer);
+ peer_get(peer);
/* Queues up calling packet_send_queued_handshakes(peer), where we do a peer_put(peer) after: */
if (!queue_work(peer->device->handshake_send_wq, &peer->transmit_handshake_work))
peer_put(peer); /* If the work was already queued, we want to drop the extra reference */
@@ -320,7 +320,7 @@ void packet_send_staged_packets(struct wireguard_peer *peer)
}
packets.prev->next = NULL;
- peer_rcu_get(keypair->entry.peer);
+ peer_get(keypair->entry.peer);
PACKET_CB(packets.next)->keypair = keypair;
packet_create_data(packets.next);
return;
diff --git a/src/timers.c b/src/timers.c
index f1549a7..e8bb101 100644
--- a/src/timers.c
+++ b/src/timers.c
@@ -20,7 +20,10 @@
*/
#define peer_get_from_timer(timer_name) \
- struct wireguard_peer *peer = peer_rcu_get(from_timer(peer, timer, timer_name)); \
+ struct wireguard_peer *peer; \
+ rcu_read_lock_bh(); \
+ peer = peer_get_maybe_zero(from_timer(peer, timer, timer_name)); \
+ rcu_read_unlock_bh(); \
if (unlikely(!peer)) \
return;