diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2017-04-03 21:40:45 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2017-04-04 03:44:35 +0200 |
commit | ad05185de7870e47ae1b5b6a4b32cb31d2d9e155 (patch) | |
tree | 9bccb31dc60559183442d0062ab815d4543951b3 /src/data.c | |
parent | 0656a29de11c3c9ab2cd3d187c8bfc8507e78aaa (diff) |
data: simplify flow
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'src/data.c')
-rw-r--r-- | src/data.c | 117 |
1 files changed, 36 insertions, 81 deletions
@@ -18,7 +18,6 @@ struct encryption_ctx { struct padata_priv padata; struct sk_buff_head queue; - packet_create_data_callback_t callback; struct wireguard_peer *peer; struct noise_keypair *keypair; }; @@ -27,9 +26,7 @@ struct decryption_ctx { struct padata_priv padata; struct endpoint endpoint; struct sk_buff *skb; - packet_consume_data_callback_t callback; struct noise_keypair *keypair; - int ret; }; #ifdef CONFIG_WIREGUARD_PARALLEL @@ -225,7 +222,7 @@ static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_ bool have_simd = chacha20poly1305_init_simd(); skb_queue_walk_safe(queue, skb, tmp) { if (unlikely(!skb_encrypt(skb, keypair, have_simd))) { - skb_unlink(skb, queue); + __skb_unlink(skb, queue); kfree_skb(skb); continue; } @@ -236,32 +233,22 @@ static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_ } #ifdef CONFIG_WIREGUARD_PARALLEL -static void do_encryption(struct padata_priv *padata) +static void begin_parallel_encryption(struct padata_priv *padata) { struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata); - queue_encrypt_reset(&ctx->queue, ctx->keypair); padata_do_serial(padata); } -static void finish_encryption(struct padata_priv *padata) +static void finish_parallel_encryption(struct padata_priv *padata) { struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata); - - ctx->callback(&ctx->queue, ctx->peer); + packet_create_data_done(&ctx->queue, ctx->peer); atomic_dec(&ctx->peer->parallel_encryption_inflight); peer_put(ctx->peer); kmem_cache_free(encryption_ctx_cache, ctx); } -static inline int start_encryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu) -{ - memset(priv, 0, sizeof(struct padata_priv)); - priv->parallel = do_encryption; - priv->serial = finish_encryption; - return padata_do_parallel(padata, priv, cb_cpu); -} - static inline unsigned int choose_cpu(__le32 key) { unsigned int cpu_index, cpu, cb_cpu; @@ -276,7 +263,7 @@ static inline unsigned int choose_cpu(__le32 key) } #endif -int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, packet_create_data_callback_t callback) +int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer) { int ret = -ENOKEY; struct noise_keypair *keypair; @@ -303,21 +290,21 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, #ifdef CONFIG_WIREGUARD_PARALLEL if ((skb_queue_len(queue) > 1 || queue->next->len > 256 || atomic_read(&peer->parallel_encryption_inflight) > 0) && cpumask_weight(cpu_online_mask) > 1) { - unsigned int cpu = choose_cpu(keypair->remote_index); struct encryption_ctx *ctx = kmem_cache_alloc(encryption_ctx_cache, GFP_ATOMIC); if (!ctx) goto serial_encrypt; skb_queue_head_init(&ctx->queue); skb_queue_splice_init(queue, &ctx->queue); - ctx->callback = callback; + memset(&ctx->padata, 0, sizeof(ctx->padata)); + ctx->padata.parallel = begin_parallel_encryption; + ctx->padata.serial = finish_parallel_encryption; ctx->keypair = keypair; ctx->peer = peer_rcu_get(peer); ret = -EBUSY; if (unlikely(!ctx->peer)) goto err_parallel; atomic_inc(&peer->parallel_encryption_inflight); - ret = start_encryption(peer->device->parallel_send, &ctx->padata, cpu); - if (unlikely(ret < 0)) { + if (unlikely(padata_do_parallel(peer->device->parallel_send, &ctx->padata, choose_cpu(keypair->remote_index)))) { atomic_dec(&peer->parallel_encryption_inflight); peer_put(ctx->peer); err_parallel: @@ -330,7 +317,7 @@ serial_encrypt: #endif { queue_encrypt_reset(queue, keypair); - callback(queue, peer); + packet_create_data_done(queue, peer); } return 0; @@ -344,83 +331,56 @@ err_rcu: static void begin_decrypt_packet(struct decryption_ctx *ctx) { - ctx->ret = socket_endpoint_from_skb(&ctx->endpoint, ctx->skb); - if (unlikely(ctx->ret < 0)) - goto err; - - ctx->ret = -ENOKEY; - if (unlikely(!skb_decrypt(ctx->skb, &ctx->keypair->receiving))) - goto err; - - ctx->ret = 0; - return; - -err: - peer_put(ctx->keypair->entry.peer); + if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) { + peer_put(ctx->keypair->entry.peer); + noise_keypair_put(ctx->keypair); + dev_kfree_skb(ctx->skb); + ctx->skb = NULL; + } } static void finish_decrypt_packet(struct decryption_ctx *ctx) { - struct noise_keypairs *keypairs; - bool used_new_key = false; - u64 nonce = PACKET_CB(ctx->skb)->nonce; - int ret = ctx->ret; - if (ret) - goto err; + bool used_new_key; - keypairs = &ctx->keypair->entry.peer->keypairs; - ret = counter_validate(&ctx->keypair->receiving.counter, nonce) ? 0 : -ERANGE; + if (!ctx->skb) + return; - if (likely(!ret)) - used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair); - else { - net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", nonce, ctx->keypair->receiving.counter.receive.counter); + if (unlikely(!counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(ctx->skb)->nonce))) { + net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter); peer_put(ctx->keypair->entry.peer); - goto err; + noise_keypair_put(ctx->keypair); + dev_kfree_skb(ctx->skb); + return; } - noise_keypair_put(ctx->keypair); - + used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair); skb_reset(ctx->skb); - ctx->callback(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key, 0); - return; - -err: + packet_consume_data_done(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key); noise_keypair_put(ctx->keypair); - ctx->callback(ctx->skb, NULL, NULL, false, ret); } #ifdef CONFIG_WIREGUARD_PARALLEL -static void do_decryption(struct padata_priv *padata) +static void begin_parallel_decryption(struct padata_priv *padata) { struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata); begin_decrypt_packet(ctx); padata_do_serial(padata); } -static void finish_decryption(struct padata_priv *padata) +static void finish_parallel_decryption(struct padata_priv *padata) { struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata); finish_decrypt_packet(ctx); kmem_cache_free(decryption_ctx_cache, ctx); } - -static inline int start_decryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu) -{ - memset(priv, 0, sizeof(struct padata_priv)); - priv->parallel = do_decryption; - priv->serial = finish_decryption; - return padata_do_parallel(padata, priv, cb_cpu); -} #endif -void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packet_consume_data_callback_t callback) +void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg) { - int ret; struct noise_keypair *keypair; __le32 idx = ((struct message_data *)skb->data)->key_idx; - ret = -EINVAL; rcu_read_lock_bh(); keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx)); rcu_read_unlock_bh(); @@ -429,19 +389,15 @@ void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packe #ifdef CONFIG_WIREGUARD_PARALLEL if (cpumask_weight(cpu_online_mask) > 1) { - unsigned int cpu = choose_cpu(idx); - struct decryption_ctx *ctx; - - ret = -ENOMEM; - ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC); + struct decryption_ctx *ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC); if (unlikely(!ctx)) goto err_peer; - ctx->skb = skb; ctx->keypair = keypair; - ctx->callback = callback; - ret = start_decryption(wg->parallel_receive, &ctx->padata, cpu); - if (unlikely(ret)) { + memset(&ctx->padata, 0, sizeof(ctx->padata)); + ctx->padata.parallel = begin_parallel_decryption; + ctx->padata.serial = finish_parallel_decryption; + if (unlikely(padata_do_parallel(wg->parallel_receive, &ctx->padata, choose_cpu(idx)))) { kmem_cache_free(decryption_ctx_cache, ctx); goto err_peer; } @@ -450,8 +406,7 @@ void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packe { struct decryption_ctx ctx = { .skb = skb, - .keypair = keypair, - .callback = callback + .keypair = keypair }; begin_decrypt_packet(&ctx); finish_decrypt_packet(&ctx); @@ -464,5 +419,5 @@ err_peer: noise_keypair_put(keypair); #endif err: - callback(skb, NULL, NULL, false, ret); + dev_kfree_skb(skb); } |