-rw-r--r--   src/main.c       9
-rw-r--r--   src/queueing.c  16
-rw-r--r--   src/queueing.h  34
-rw-r--r--   src/receive.c  108
-rw-r--r--   src/send.c     103
5 files changed, 119 insertions(+), 151 deletions(-)
diff --git a/src/main.c b/src/main.c
--- a/src/main.c
+++ b/src/main.c
@@ -30,13 +30,9 @@ static int __init mod_init(void)
 #endif
 	noise_init();
 
-	ret = crypt_ctx_cache_init();
-	if (ret < 0)
-		goto err_packet;
-
 	ret = device_init();
 	if (ret < 0)
-		goto err_device;
+		goto err_packet;
 
 	ret = netlink_init();
 	if (ret < 0)
@@ -49,8 +45,6 @@ static int __init mod_init(void)
 
 err_netlink:
 	device_uninit();
-err_device:
-	crypt_ctx_cache_uninit();
 err_packet:
 	return ret;
 }
@@ -59,7 +53,6 @@ static void __exit mod_exit(void)
 {
 	netlink_uninit();
 	device_uninit();
-	crypt_ctx_cache_uninit();
 	pr_debug("WireGuard unloaded\n");
 }
 
diff --git a/src/queueing.c b/src/queueing.c
index f1ae4f1..fa50511 100644
--- a/src/queueing.c
+++ b/src/queueing.c
@@ -1,9 +1,6 @@
 /* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
 
 #include "queueing.h"
-#include <linux/slab.h>
-
-struct kmem_cache *crypt_ctx_cache __read_mostly;
 
 struct multicore_worker __percpu *packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr)
 {
@@ -44,16 +41,3 @@ void packet_queue_free(struct crypt_queue *queue, bool multicore)
 	WARN_ON(!ptr_ring_empty_bh(&queue->ring));
 	ptr_ring_cleanup(&queue->ring, NULL);
 }
-
-int __init crypt_ctx_cache_init(void)
-{
-	crypt_ctx_cache = KMEM_CACHE(crypt_ctx, 0);
-	if (!crypt_ctx_cache)
-		return -ENOMEM;
-	return 0;
-}
-
-void crypt_ctx_cache_uninit(void)
-{
-	kmem_cache_destroy(crypt_ctx_cache);
-}
diff --git a/src/queueing.h b/src/queueing.h
index 9b9b6a6..62dee51 100644
--- a/src/queueing.h
+++ b/src/queueing.h
@@ -16,9 +16,6 @@ struct crypt_queue;
 struct sk_buff;
 
 /* queueing.c APIs: */
-extern struct kmem_cache *crypt_ctx_cache __read_mostly;
-int crypt_ctx_cache_init(void);
-void crypt_ctx_cache_uninit(void);
 int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool multicore, unsigned int len);
 void packet_queue_free(struct crypt_queue *queue, bool multicore);
 struct multicore_worker __percpu *packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr);
@@ -41,22 +38,16 @@ void packet_handshake_send_worker(struct work_struct *work);
 void packet_tx_worker(struct work_struct *work);
 void packet_encrypt_worker(struct work_struct *work);
 
+enum packet_state { PACKET_STATE_UNCRYPTED, PACKET_STATE_CRYPTED, PACKET_STATE_DEAD };
 struct packet_cb {
 	u64 nonce;
+	struct noise_keypair *keypair;
+	atomic_t state;
 	u8 ds;
 };
+#define PACKET_PEER(skb) ((struct packet_cb *)skb->cb)->keypair->entry.peer
 #define PACKET_CB(skb) ((struct packet_cb *)skb->cb)
 
-struct crypt_ctx {
-	union {
-		struct sk_buff_head packets;
-		struct sk_buff *skb;
-	};
-	atomic_t is_finished;
-	struct wireguard_peer *peer;
-	struct noise_keypair *keypair;
-};
-
 /* Returns either the correct skb->protocol value, or 0 if invalid. */
 static inline __be16 skb_examine_untrusted_ip_hdr(struct sk_buff *skb)
 {
@@ -130,19 +121,30 @@ static inline int cpumask_next_online(int *next)
 	return cpu;
 }
 
-static inline int queue_enqueue_per_device_and_peer(struct crypt_queue *device_queue, struct crypt_queue *peer_queue, struct crypt_ctx *ctx, struct workqueue_struct *wq, int *next_cpu)
+static inline int queue_enqueue_per_device_and_peer(struct crypt_queue *device_queue, struct crypt_queue *peer_queue, struct sk_buff *skb, struct workqueue_struct *wq, int *next_cpu)
 {
 	int cpu;
 
-	if (unlikely(ptr_ring_produce_bh(&peer_queue->ring, ctx)))
+	atomic_set(&PACKET_CB(skb)->state, PACKET_STATE_UNCRYPTED);
+	if (unlikely(ptr_ring_produce_bh(&peer_queue->ring, skb)))
 		return -ENOSPC;
 	cpu = cpumask_next_online(next_cpu);
-	if (unlikely(ptr_ring_produce_bh(&device_queue->ring, ctx)))
+	if (unlikely(ptr_ring_produce_bh(&device_queue->ring, skb)))
 		return -EPIPE;
 	queue_work_on(cpu, wq, &per_cpu_ptr(device_queue->worker, cpu)->work);
 	return 0;
 }
 
+static inline void queue_enqueue_per_peer(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state)
+{
+	struct wireguard_peer *peer = peer_rcu_get(PACKET_PEER(skb));
+
+	atomic_set(&PACKET_CB(skb)->state, state);
+	queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &queue->work);
+	peer_put(peer);
+
+}
+
 #ifdef DEBUG
 bool packet_counter_selftest(void);
 #endif
diff --git a/src/receive.c b/src/receive.c
index 9e03bcf..b3ad9c3 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -277,15 +277,15 @@ out:
 }
 
 #include "selftest/counter.h"
 
-static void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key)
+static void packet_consume_data_done(struct sk_buff *skb, struct endpoint *endpoint)
 {
+	struct wireguard_peer *peer = PACKET_PEER(skb), *routed_peer;
 	struct net_device *dev = peer->device->dev;
-	struct wireguard_peer *routed_peer;
 	unsigned int len;
 
 	socket_set_peer_endpoint(peer, endpoint);
 
-	if (unlikely(used_new_key)) {
+	if (unlikely(noise_received_with_keypair(&peer->keypairs, PACKET_CB(skb)->keypair))) {
 		timers_handshake_complete(peer);
 		packet_send_staged_packets(peer);
 	}
@@ -364,27 +364,42 @@ continue_processing:
 
 void packet_rx_worker(struct work_struct *work)
 {
-	struct endpoint endpoint;
-	struct crypt_ctx *ctx;
 	struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
+	struct wireguard_peer *peer;
+	struct noise_keypair *keypair;
+	struct sk_buff *skb;
+	struct endpoint endpoint;
+	enum packet_state state;
+	bool free;
 
 	local_bh_disable();
 	spin_lock_bh(&queue->ring.consumer_lock);
-	while ((ctx = __ptr_ring_peek(&queue->ring)) != NULL && atomic_read(&ctx->is_finished)) {
+	while ((skb = __ptr_ring_peek(&queue->ring)) != NULL && (state = atomic_read(&PACKET_CB(skb)->state)) != PACKET_STATE_UNCRYPTED) {
 		__ptr_ring_discard_one(&queue->ring);
-		if (likely(ctx->skb)) {
-			if (likely(counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(ctx->skb)->nonce) && !socket_endpoint_from_skb(&endpoint, ctx->skb))) {
-				skb_reset(ctx->skb);
-				packet_consume_data_done(ctx->skb, ctx->peer, &endpoint, noise_received_with_keypair(&ctx->peer->keypairs, ctx->keypair));
-			}
-			else {
-				net_dbg_ratelimited("%s: Packet has invalid nonce %Lu (max %Lu)\n", ctx->peer->device->dev->name, PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter);
-				dev_kfree_skb(ctx->skb);
-			}
+		peer = PACKET_PEER(skb);
+		keypair = PACKET_CB(skb)->keypair;
+		free = true;
+
+		if (unlikely(state != PACKET_STATE_CRYPTED))
+			goto next;
+
+		if (unlikely(!counter_validate(&keypair->receiving.counter, PACKET_CB(skb)->nonce))) {
+			net_dbg_ratelimited("%s: Packet has invalid nonce %Lu (max %Lu)\n", peer->device->dev->name, PACKET_CB(skb)->nonce, keypair->receiving.counter.receive.counter);
+			goto next;
 		}
-		noise_keypair_put(ctx->keypair);
-		peer_put(ctx->peer);
-		kmem_cache_free(crypt_ctx_cache, ctx);
+
+		if (unlikely(socket_endpoint_from_skb(&endpoint, skb)))
+			goto next;
+
+		skb_reset(skb);
+		packet_consume_data_done(skb, &endpoint);
+		free = false;
+
+next:
+		noise_keypair_put(keypair);
+		peer_put(peer);
+		if (unlikely(free))
+			dev_kfree_skb(skb);
 	}
 	spin_unlock_bh(&queue->ring.consumer_lock);
 	local_bh_enable();
@@ -392,65 +407,44 @@ void packet_rx_worker(struct work_struct *work)
 
 void packet_decrypt_worker(struct work_struct *work)
 {
-	struct crypt_ctx *ctx;
 	struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr;
-	struct wireguard_peer *peer;
+	struct sk_buff *skb;
 
-	while ((ctx = ptr_ring_consume_bh(&queue->ring)) != NULL) {
-		if (unlikely(!skb_decrypt(ctx->skb, &ctx->keypair->receiving))) {
-			dev_kfree_skb(ctx->skb);
-			ctx->skb = NULL;
-		}
-		/* Dereferencing ctx is unsafe once ctx->is_finished == true, so
-		 * we take a reference here first. */
-		peer = peer_rcu_get(ctx->peer);
-		atomic_set(&ctx->is_finished, true);
-		queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &peer->rx_queue.work);
-		peer_put(peer);
+	while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
+		if (likely(skb_decrypt(skb, &PACKET_CB(skb)->keypair->receiving)))
+			queue_enqueue_per_peer(&PACKET_PEER(skb)->rx_queue, skb, PACKET_STATE_CRYPTED);
+		else
+			queue_enqueue_per_peer(&PACKET_PEER(skb)->rx_queue, skb, PACKET_STATE_DEAD);
 	}
 }
 
 static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb)
 {
-	struct crypt_ctx *ctx;
-	struct noise_keypair *keypair;
+	struct wireguard_peer *peer;
 	__le32 idx = ((struct message_data *)skb->data)->key_idx;
 	int ret;
 
 	rcu_read_lock_bh();
-	keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
+	PACKET_CB(skb)->keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
 	rcu_read_unlock_bh();
-	if (unlikely(!keypair)) {
+	if (unlikely(!PACKET_CB(skb)->keypair)) {
 		dev_kfree_skb(skb);
 		return;
 	}
 
-	ctx = kmem_cache_alloc(crypt_ctx_cache, GFP_ATOMIC);
-	if (unlikely(!ctx)) {
-		dev_kfree_skb(skb);
-		peer_put(keypair->entry.peer);
-		noise_keypair_put(keypair);
-		return;
-	}
-	atomic_set(&ctx->is_finished, false);
-	ctx->keypair = keypair;
-	ctx->skb = skb;
-	/* We already have a reference to peer from index_hashtable_lookup. */
-	ctx->peer = ctx->keypair->entry.peer;
+	/* The call to index_hashtable_lookup gives us a reference to its underlying peer, so we don't need to call peer_rcu_get(). */
+	peer = PACKET_PEER(skb);
 
-	ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &ctx->peer->rx_queue, ctx, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu);
+	ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu);
 	if (likely(!ret))
 		return; /* Successful. No need to drop references below. */
 
-	dev_kfree_skb(ctx->skb);
-	if (ret == -EPIPE) {
-		ctx->skb = NULL;
-		atomic_set(&ctx->is_finished, true);
-		queue_work_on(cpumask_choose_online(&ctx->peer->serial_work_cpu, ctx->peer->internal_id), ctx->peer->device->packet_crypt_wq, &ctx->peer->rx_queue.work);
-	} else {
-		noise_keypair_put(ctx->keypair);
-		peer_put(ctx->peer);
-		kmem_cache_free(crypt_ctx_cache, ctx);
+	if (ret == -EPIPE)
+		queue_enqueue_per_peer(&peer->rx_queue, skb, PACKET_STATE_DEAD);
+	else {
+		peer_put(peer);
+		noise_keypair_put(PACKET_CB(skb)->keypair);
+		dev_kfree_skb(skb);
 	}
 }
diff --git a/src/send.c b/src/send.c
--- a/src/send.c
+++ b/src/send.c
@@ -165,20 +165,26 @@ void packet_send_keepalive(struct wireguard_peer *peer)
 	packet_send_staged_packets(peer);
 }
 
-static void packet_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *peer)
+#define skb_walk_null_queue_safe(first, skb, next) for (skb = first, next = skb->next; skb; skb = next, next = skb ? skb->next : NULL)
+static inline void skb_free_null_queue(struct sk_buff *first)
 {
-	struct sk_buff *skb, *tmp;
-	bool is_keepalive, data_sent = false;
+	struct sk_buff *skb, *next;
+	skb_walk_null_queue_safe (first, skb, next)
+		dev_kfree_skb(skb);
+}
 
-	if (unlikely(skb_queue_empty(queue)))
-		return;
+static void packet_create_data_done(struct sk_buff *first, struct wireguard_peer *peer)
+{
+	struct sk_buff *skb, *next;
+	bool is_keepalive, data_sent = false;
 
 	timers_any_authenticated_packet_traversal(peer);
-	skb_queue_walk_safe (queue, skb, tmp) {
+	skb_walk_null_queue_safe (first, skb, next) {
 		is_keepalive = skb->len == message_data_len(0);
 		if (likely(!socket_send_skb_to_peer(peer, skb, PACKET_CB(skb)->ds) && !is_keepalive))
 			data_sent = true;
 	}
+
 	if (likely(data_sent))
 		timers_data_sent(peer);
 
@@ -188,78 +194,65 @@ void packet_tx_worker(struct work_struct *work)
 {
 	struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
-	struct crypt_ctx *ctx;
+	struct wireguard_peer *peer;
+	struct noise_keypair *keypair;
+	struct sk_buff *first;
+	enum packet_state state;
 
 	spin_lock_bh(&queue->ring.consumer_lock);
-	while ((ctx = __ptr_ring_peek(&queue->ring)) != NULL && atomic_read(&ctx->is_finished)) {
+	while ((first = __ptr_ring_peek(&queue->ring)) != NULL && (state = atomic_read(&PACKET_CB(first)->state)) != PACKET_STATE_UNCRYPTED) {
 		__ptr_ring_discard_one(&queue->ring);
-		packet_create_data_done(&ctx->packets, ctx->peer);
-		noise_keypair_put(ctx->keypair);
-		peer_put(ctx->peer);
-		kmem_cache_free(crypt_ctx_cache, ctx);
+		peer = PACKET_PEER(first);
+		keypair = PACKET_CB(first)->keypair;
+
+		if (likely(state == PACKET_STATE_CRYPTED))
+			packet_create_data_done(first, peer);
+		else
+			skb_free_null_queue(first);
+
+		noise_keypair_put(keypair);
+		peer_put(peer);
 	}
 	spin_unlock_bh(&queue->ring.consumer_lock);
 }
 
 void packet_encrypt_worker(struct work_struct *work)
 {
-	struct crypt_ctx *ctx;
 	struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr;
-	struct sk_buff *skb, *tmp;
-	struct wireguard_peer *peer;
+	struct sk_buff *first, *skb, *next;
 	bool have_simd = chacha20poly1305_init_simd();
 
-	while ((ctx = ptr_ring_consume_bh(&queue->ring)) != NULL) {
-		skb_queue_walk_safe(&ctx->packets, skb, tmp) {
-			if (likely(skb_encrypt(skb, ctx->keypair, have_simd))) {
+	while ((first = ptr_ring_consume_bh(&queue->ring)) != NULL) {
+		skb_walk_null_queue_safe (first, skb, next) {
+			if (likely(skb_encrypt(skb, PACKET_CB(first)->keypair, have_simd)))
 				skb_reset(skb);
-			} else {
-				__skb_unlink(skb, &ctx->packets);
-				dev_kfree_skb(skb);
+			else {
+				queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first, PACKET_STATE_DEAD);
+				continue;
+			}
 		}
-		/* Dereferencing ctx is unsafe once ctx->is_finished == true, so
-		 * we grab an additional reference to peer. */
-		peer = peer_rcu_get(ctx->peer);
-		atomic_set(&ctx->is_finished, true);
-		queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &peer->tx_queue.work);
-		peer_put(peer);
+		queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first, PACKET_STATE_CRYPTED);
 	}
 	chacha20poly1305_deinit_simd(have_simd);
 }
 
-static void packet_create_data(struct wireguard_peer *peer, struct sk_buff_head *packets, struct noise_keypair *keypair)
+static void packet_create_data(struct sk_buff *first)
 {
-	struct crypt_ctx *ctx;
+	struct wireguard_peer *peer = PACKET_PEER(first);
 	struct wireguard_device *wg = peer->device;
 	int ret;
 
-	ctx = kmem_cache_alloc(crypt_ctx_cache, GFP_ATOMIC);
-	if (unlikely(!ctx)) {
-		__skb_queue_purge(packets);
-		goto err_drop_refs;
-	}
-	/* This function consumes the passed references to peer and keypair. */
-	atomic_set(&ctx->is_finished, false);
-	ctx->keypair = keypair;
-	ctx->peer = peer;
-	__skb_queue_head_init(&ctx->packets);
-	skb_queue_splice_tail(packets, &ctx->packets);
-	ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, ctx, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu);
+	ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu);
 	if (likely(!ret))
 		return; /* Successful. No need to fall through to drop references below. */
 
-	__skb_queue_purge(&ctx->packets);
-	if (ret == -EPIPE) {
-		atomic_set(&ctx->is_finished, true);
-		queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &peer->tx_queue.work);
-		return;
-	} else
-		kmem_cache_free(crypt_ctx_cache, ctx);
-
-err_drop_refs:
-	noise_keypair_put(keypair);
-	peer_put(peer);
+	if (ret == -EPIPE)
+		queue_enqueue_per_peer(&peer->tx_queue, first, PACKET_STATE_DEAD);
+	else {
+		peer_put(peer);
+		noise_keypair_put(PACKET_CB(first)->keypair);
+		skb_free_null_queue(first);
+	}
 }
 
 void packet_send_staged_packets(struct wireguard_peer *peer)
 {
@@ -299,8 +292,10 @@ void packet_send_staged_packets(struct wireguard_peer *peer)
 		goto out_invalid;
 	}
 
-	/* We pass off our peer and keypair references to the data subsystem and return. */
-	packet_create_data(peer_rcu_get(peer), &packets, keypair);
+	packets.prev->next = NULL;
+	peer_rcu_get(keypair->entry.peer);
+	PACKET_CB(packets.next)->keypair = keypair;
+	packet_create_data(packets.next);
 	return;
 
 out_invalid:
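
The pattern this patch settles on is worth spelling out: packets enter the per-peer ring in submission order marked PACKET_STATE_UNCRYPTED, parallel crypto workers flip each packet's atomic state to CRYPTED or DEAD in whatever order they finish, and the serial per-peer worker (packet_tx_worker / packet_rx_worker) only ever pops the head of the ring, and only once the head has left UNCRYPTED. Delivery order is therefore preserved even though encryption completes out of order, with no extra allocation per packet. Below is a minimal userspace sketch of that ordering trick, not code from the tree: the names (crypt_worker, STATE_*, N) are illustrative, pthreads and C11 atomics stand in for kernel workqueues and atomic_t, and a plain array stands in for the ptr_ring.

/* ordering.c — build with: gcc -pthread ordering.c */
#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>

enum state { STATE_UNCRYPTED, STATE_CRYPTED, STATE_DEAD };

struct item {
	int seq;        /* submission order */
	_Atomic int st; /* stands in for the atomic_t state in struct packet_cb */
};

#define N 8
static struct item ring[N]; /* stands in for the per-peer ptr_ring */

/* Parallel worker: completes one item, possibly out of submission order. */
static void *crypt_worker(void *arg)
{
	struct item *it = arg;

	/* Pretend encryption failed for seq 3, as when skb_encrypt() fails. */
	atomic_store(&it->st, it->seq == 3 ? STATE_DEAD : STATE_CRYPTED);
	return NULL;
}

int main(void)
{
	pthread_t threads[N];
	int head = 0, i;

	/* Producer: items enter the ring in order, marked UNCRYPTED first,
	 * as queue_enqueue_per_device_and_peer() does before fanning work
	 * out to the multicore workers. */
	for (i = 0; i < N; ++i) {
		ring[i].seq = i;
		atomic_init(&ring[i].st, STATE_UNCRYPTED);
		pthread_create(&threads[i], NULL, crypt_worker, &ring[i]);
	}
	for (i = 0; i < N; ++i)
		pthread_join(threads[i], NULL);

	/* Serial consumer: like packet_tx_worker(), only ever pop the head,
	 * and stop as soon as the head is still UNCRYPTED. Items therefore
	 * leave in seq order no matter when their workers finished. */
	while (head < N) {
		int st = atomic_load(&ring[head].st);

		if (st == STATE_UNCRYPTED)
			break; /* head not ready; a later wakeup resumes here */
		printf("deliver seq %d: %s\n", ring[head].seq,
		       st == STATE_CRYPTED ? "send" : "drop");
		head++;
	}
	return 0;
}

Keeping the state flag inside the queued object itself is what lets the patch delete the crypt_ctx kmem_cache: the keypair pointer (which also reaches the peer via entry.peer) and the state now ride along in the 48-byte skb->cb area, so the DEAD state replaces the old "skb == NULL" convention and the is_finished flag at once.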