diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/data.c | 117 | ||||
-rw-r--r-- | src/packets.h | 9 | ||||
-rw-r--r-- | src/receive.c | 9 | ||||
-rw-r--r-- | src/send.c | 4 |
4 files changed, 45 insertions, 94 deletions
@@ -18,7 +18,6 @@ struct encryption_ctx { struct padata_priv padata; struct sk_buff_head queue; - packet_create_data_callback_t callback; struct wireguard_peer *peer; struct noise_keypair *keypair; }; @@ -27,9 +26,7 @@ struct decryption_ctx { struct padata_priv padata; struct endpoint endpoint; struct sk_buff *skb; - packet_consume_data_callback_t callback; struct noise_keypair *keypair; - int ret; }; #ifdef CONFIG_WIREGUARD_PARALLEL @@ -225,7 +222,7 @@ static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_ bool have_simd = chacha20poly1305_init_simd(); skb_queue_walk_safe(queue, skb, tmp) { if (unlikely(!skb_encrypt(skb, keypair, have_simd))) { - skb_unlink(skb, queue); + __skb_unlink(skb, queue); kfree_skb(skb); continue; } @@ -236,32 +233,22 @@ static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_ } #ifdef CONFIG_WIREGUARD_PARALLEL -static void do_encryption(struct padata_priv *padata) +static void begin_parallel_encryption(struct padata_priv *padata) { struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata); - queue_encrypt_reset(&ctx->queue, ctx->keypair); padata_do_serial(padata); } -static void finish_encryption(struct padata_priv *padata) +static void finish_parallel_encryption(struct padata_priv *padata) { struct encryption_ctx *ctx = container_of(padata, struct encryption_ctx, padata); - - ctx->callback(&ctx->queue, ctx->peer); + packet_create_data_done(&ctx->queue, ctx->peer); atomic_dec(&ctx->peer->parallel_encryption_inflight); peer_put(ctx->peer); kmem_cache_free(encryption_ctx_cache, ctx); } -static inline int start_encryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu) -{ - memset(priv, 0, sizeof(struct padata_priv)); - priv->parallel = do_encryption; - priv->serial = finish_encryption; - return padata_do_parallel(padata, priv, cb_cpu); -} - static inline unsigned int choose_cpu(__le32 key) { unsigned int cpu_index, cpu, cb_cpu; @@ -276,7 +263,7 @@ static inline unsigned int choose_cpu(__le32 key) } #endif -int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, packet_create_data_callback_t callback) +int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer) { int ret = -ENOKEY; struct noise_keypair *keypair; @@ -303,21 +290,21 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, #ifdef CONFIG_WIREGUARD_PARALLEL if ((skb_queue_len(queue) > 1 || queue->next->len > 256 || atomic_read(&peer->parallel_encryption_inflight) > 0) && cpumask_weight(cpu_online_mask) > 1) { - unsigned int cpu = choose_cpu(keypair->remote_index); struct encryption_ctx *ctx = kmem_cache_alloc(encryption_ctx_cache, GFP_ATOMIC); if (!ctx) goto serial_encrypt; skb_queue_head_init(&ctx->queue); skb_queue_splice_init(queue, &ctx->queue); - ctx->callback = callback; + memset(&ctx->padata, 0, sizeof(ctx->padata)); + ctx->padata.parallel = begin_parallel_encryption; + ctx->padata.serial = finish_parallel_encryption; ctx->keypair = keypair; ctx->peer = peer_rcu_get(peer); ret = -EBUSY; if (unlikely(!ctx->peer)) goto err_parallel; atomic_inc(&peer->parallel_encryption_inflight); - ret = start_encryption(peer->device->parallel_send, &ctx->padata, cpu); - if (unlikely(ret < 0)) { + if (unlikely(padata_do_parallel(peer->device->parallel_send, &ctx->padata, choose_cpu(keypair->remote_index)))) { atomic_dec(&peer->parallel_encryption_inflight); peer_put(ctx->peer); err_parallel: @@ -330,7 +317,7 @@ serial_encrypt: #endif { queue_encrypt_reset(queue, keypair); - callback(queue, peer); + packet_create_data_done(queue, peer); } return 0; @@ -344,83 +331,56 @@ err_rcu: static void begin_decrypt_packet(struct decryption_ctx *ctx) { - ctx->ret = socket_endpoint_from_skb(&ctx->endpoint, ctx->skb); - if (unlikely(ctx->ret < 0)) - goto err; - - ctx->ret = -ENOKEY; - if (unlikely(!skb_decrypt(ctx->skb, &ctx->keypair->receiving))) - goto err; - - ctx->ret = 0; - return; - -err: - peer_put(ctx->keypair->entry.peer); + if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) { + peer_put(ctx->keypair->entry.peer); + noise_keypair_put(ctx->keypair); + dev_kfree_skb(ctx->skb); + ctx->skb = NULL; + } } static void finish_decrypt_packet(struct decryption_ctx *ctx) { - struct noise_keypairs *keypairs; - bool used_new_key = false; - u64 nonce = PACKET_CB(ctx->skb)->nonce; - int ret = ctx->ret; - if (ret) - goto err; + bool used_new_key; - keypairs = &ctx->keypair->entry.peer->keypairs; - ret = counter_validate(&ctx->keypair->receiving.counter, nonce) ? 0 : -ERANGE; + if (!ctx->skb) + return; - if (likely(!ret)) - used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair); - else { - net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", nonce, ctx->keypair->receiving.counter.receive.counter); + if (unlikely(!counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(ctx->skb)->nonce))) { + net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter); peer_put(ctx->keypair->entry.peer); - goto err; + noise_keypair_put(ctx->keypair); + dev_kfree_skb(ctx->skb); + return; } - noise_keypair_put(ctx->keypair); - + used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair); skb_reset(ctx->skb); - ctx->callback(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key, 0); - return; - -err: + packet_consume_data_done(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key); noise_keypair_put(ctx->keypair); - ctx->callback(ctx->skb, NULL, NULL, false, ret); } #ifdef CONFIG_WIREGUARD_PARALLEL -static void do_decryption(struct padata_priv *padata) +static void begin_parallel_decryption(struct padata_priv *padata) { struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata); begin_decrypt_packet(ctx); padata_do_serial(padata); } -static void finish_decryption(struct padata_priv *padata) +static void finish_parallel_decryption(struct padata_priv *padata) { struct decryption_ctx *ctx = container_of(padata, struct decryption_ctx, padata); finish_decrypt_packet(ctx); kmem_cache_free(decryption_ctx_cache, ctx); } - -static inline int start_decryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu) -{ - memset(priv, 0, sizeof(struct padata_priv)); - priv->parallel = do_decryption; - priv->serial = finish_decryption; - return padata_do_parallel(padata, priv, cb_cpu); -} #endif -void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packet_consume_data_callback_t callback) +void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg) { - int ret; struct noise_keypair *keypair; __le32 idx = ((struct message_data *)skb->data)->key_idx; - ret = -EINVAL; rcu_read_lock_bh(); keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx)); rcu_read_unlock_bh(); @@ -429,19 +389,15 @@ void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packe #ifdef CONFIG_WIREGUARD_PARALLEL if (cpumask_weight(cpu_online_mask) > 1) { - unsigned int cpu = choose_cpu(idx); - struct decryption_ctx *ctx; - - ret = -ENOMEM; - ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC); + struct decryption_ctx *ctx = kmem_cache_alloc(decryption_ctx_cache, GFP_ATOMIC); if (unlikely(!ctx)) goto err_peer; - ctx->skb = skb; ctx->keypair = keypair; - ctx->callback = callback; - ret = start_decryption(wg->parallel_receive, &ctx->padata, cpu); - if (unlikely(ret)) { + memset(&ctx->padata, 0, sizeof(ctx->padata)); + ctx->padata.parallel = begin_parallel_decryption; + ctx->padata.serial = finish_parallel_decryption; + if (unlikely(padata_do_parallel(wg->parallel_receive, &ctx->padata, choose_cpu(idx)))) { kmem_cache_free(decryption_ctx_cache, ctx); goto err_peer; } @@ -450,8 +406,7 @@ void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packe { struct decryption_ctx ctx = { .skb = skb, - .keypair = keypair, - .callback = callback + .keypair = keypair }; begin_decrypt_packet(&ctx); finish_decrypt_packet(&ctx); @@ -464,5 +419,5 @@ err_peer: noise_keypair_put(keypair); #endif err: - callback(skb, NULL, NULL, false, ret); + dev_kfree_skb(skb); } diff --git a/src/packets.h b/src/packets.h index a640847..be9cfd7 100644 --- a/src/packets.h +++ b/src/packets.h @@ -23,6 +23,7 @@ struct packet_cb { /* receive.c */ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb); void packet_process_queued_handshake_packets(struct work_struct *work); +void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key); /* send.c */ void packet_send_queue(struct wireguard_peer *peer); @@ -31,12 +32,12 @@ void packet_queue_handshake_initiation(struct wireguard_peer *peer); void packet_send_queued_handshakes(struct work_struct *work); void packet_send_handshake_response(struct wireguard_peer *peer); void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index); +void packet_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *peer); + /* data.c */ -typedef void (*packet_create_data_callback_t)(struct sk_buff_head *, struct wireguard_peer *); -typedef void (*packet_consume_data_callback_t)(struct sk_buff *skb, struct wireguard_peer *, struct endpoint *, bool used_new_key, int err); -int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, packet_create_data_callback_t callback); -void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packet_consume_data_callback_t callback); +int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer); +void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg); #ifdef CONFIG_WIREGUARD_PARALLEL int packet_init_data_caches(void); diff --git a/src/receive.c b/src/receive.c index 3b375ae..929d723 100644 --- a/src/receive.c +++ b/src/receive.c @@ -205,17 +205,12 @@ static void keep_key_fresh(struct wireguard_peer *peer) } } -static void receive_data_packet(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key, int err) +void packet_consume_data_done(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key) { struct net_device *dev; struct wireguard_peer *routed_peer; struct wireguard_device *wg; - if (unlikely(err < 0 || !peer || !endpoint)) { - dev_kfree_skb(skb); - return; - } - socket_set_peer_endpoint(peer, endpoint); wg = peer->device; @@ -305,7 +300,7 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) break; case MESSAGE_DATA: PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb); - packet_consume_data(skb, wg, receive_data_packet); + packet_consume_data(skb, wg); break; default: net_dbg_skb_ratelimited("Invalid packet from %pISpfsc\n", skb); @@ -118,7 +118,7 @@ void packet_send_keepalive(struct wireguard_peer *peer) packet_send_queue(peer); } -static void message_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *peer) +void packet_create_data_done(struct sk_buff_head *queue, struct wireguard_peer *peer) { struct sk_buff *skb, *tmp; bool is_keepalive, data_sent = false; @@ -157,7 +157,7 @@ void packet_send_queue(struct wireguard_peer *peer) return; /* We submit it for encryption and sending. */ - switch (packet_create_data(&queue, peer, message_create_data_done)) { + switch (packet_create_data(&queue, peer)) { case 0: break; case -EBUSY: |