author    | Jason A. Donenfeld <Jason@zx2c4.com> | 2017-03-15 19:20:58 +0100
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2017-03-20 01:02:06 +0100
commit    | 2e6e03366543069811c9ea189340a73cd000a29b (patch)
tree      | d3fcd8e802587ee94dafb01ce3c5b56b97710528 /src/data.c
parent    | 05acbf5bbbf5f6a377dc001ac945ea8e214c87b8 (diff)
data: big refactoring
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'src/data.c')
-rw-r--r-- | src/data.c | 167
1 file changed, 71 insertions, 96 deletions
@@ -15,14 +15,6 @@
 #include <net/xfrm.h>
 #include <crypto/algapi.h>
 
-struct encryption_skb_cb {
-	u8 ds;
-	u8 num_frags;
-	unsigned int plaintext_len, trailer_len;
-	struct sk_buff *trailer;
-	u64 nonce;
-};
-
 struct encryption_ctx {
 	struct padata_priv padata;
 	struct sk_buff_head queue;
@@ -33,13 +25,11 @@ struct encryption_ctx {
 
 struct decryption_ctx {
 	struct padata_priv padata;
+	struct endpoint endpoint;
 	struct sk_buff *skb;
 	packet_consume_data_callback_t callback;
 	struct noise_keypair *keypair;
-	struct endpoint endpoint;
-	u64 nonce;
 	int ret;
-	u8 num_frags;
 };
 
 #ifdef CONFIG_WIREGUARD_PARALLEL
@@ -48,7 +38,6 @@ static struct kmem_cache *decryption_ctx_cache __read_mostly;
 
 int packet_init_data_caches(void)
 {
-	BUILD_BUG_ON(sizeof(struct encryption_skb_cb) > sizeof(((struct sk_buff *)0)->cb));
 	encryption_ctx_cache = kmem_cache_create("wireguard_encryption_ctx", sizeof(struct encryption_ctx), 0, 0, NULL);
 	if (!encryption_ctx_cache)
 		return -ENOMEM;
@@ -130,13 +119,36 @@ static inline void skb_reset(struct sk_buff *skb)
 	skb_reset_mac_header(skb);
 	skb_reset_network_header(skb);
 	skb_probe_transport_header(skb, 0);
+	skb_reset_inner_headers(skb);
 }
 
-static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, bool have_simd)
+static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, bool have_simd)
 {
-	struct encryption_skb_cb *cb = (struct encryption_skb_cb *)skb->cb;
-	struct scatterlist sg[cb->num_frags]; /* This should be bound to at most 128 by the caller. */
+	struct scatterlist *sg;
 	struct message_data *header;
+	unsigned int padding_len, plaintext_len, trailer_len;
+	int num_frags;
+	struct sk_buff *trailer;
+
+	/* Store the ds bit in the cb */
+	PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb);
+
+	/* Calculate lengths */
+	padding_len = skb_padding(skb);
+	trailer_len = padding_len + noise_encrypted_len(0);
+	plaintext_len = skb->len + padding_len;
+
+	/* Expand data section to have room for padding and auth tag */
+	num_frags = skb_cow_data(skb, trailer_len, &trailer);
+	if (unlikely(num_frags < 0 || num_frags > 128))
+		return false;
+
+	/* Set the padding to zeros, and make sure it and the auth tag are part of the skb */
+	memset(skb_tail_pointer(trailer), 0, padding_len);
+
+	/* Expand head section to have room for our header and the network stack's headers. */
+	if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0))
+		return false;
 
 	/* We have to remember to add the checksum to the innerpacket, in case the receiver forwards it. */
 	if (likely(!skb_checksum_setup(skb, true)))
@@ -146,18 +158,23 @@ static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypai
 	header = (struct message_data *)skb_push(skb, sizeof(struct message_data));
 	header->header.type = cpu_to_le32(MESSAGE_DATA);
 	header->key_idx = keypair->remote_index;
-	header->counter = cpu_to_le64(cb->nonce);
-	pskb_put(skb, cb->trailer, cb->trailer_len);
+	header->counter = cpu_to_le64(PACKET_CB(skb)->nonce);
+	pskb_put(skb, trailer, trailer_len);
 
 	/* Now we can encrypt the scattergather segments */
-	sg_init_table(sg, cb->num_frags);
-	skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(cb->plaintext_len));
-	chacha20poly1305_encrypt_sg(sg, sg, cb->plaintext_len, NULL, 0, cb->nonce, keypair->sending.key, have_simd);
+	sg = __builtin_alloca(num_frags * sizeof(struct scatterlist)); /* bounded to 128 */
+	sg_init_table(sg, num_frags);
+	skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(plaintext_len));
+	chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0, PACKET_CB(skb)->nonce, keypair->sending.key, have_simd);
+
+	return true;
 }
 
-static inline bool skb_decrypt(struct sk_buff *skb, u8 num_frags, u64 nonce, struct noise_symmetric_key *key)
+static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *key)
 {
-	struct scatterlist sg[num_frags]; /* This should be bound to at most 128 by the caller. */
+	struct scatterlist *sg;
+	struct sk_buff *trailer;
+	int num_frags;
 
 	if (unlikely(!key))
 		return false;
@@ -167,10 +184,17 @@ static inline bool skb_decrypt(struct sk_buff *skb, u8 num_frags, u64 nonce, str
 		return false;
 	}
 
+	PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter);
+	skb_pull(skb, sizeof(struct message_data));
+	num_frags = skb_cow_data(skb, 0, &trailer);
+	if (unlikely(num_frags < 0 || num_frags > 128))
+		return false;
+
+	sg = __builtin_alloca(num_frags * sizeof(struct scatterlist)); /* bounded to 128 */
+	sg_init_table(sg, num_frags);
 	skb_to_sgvec(skb, sg, 0, skb->len);
-	if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, nonce, key->key))
+	if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key))
 		return false;
 
 	return pskb_trim(skb, skb->len - noise_encrypted_len(0)) == 0;
@@ -197,10 +221,14 @@ static inline bool get_encryption_nonce(u64 *nonce, struct noise_symmetric_key *
 
 static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_keypair *keypair)
 {
-	struct sk_buff *skb;
+	struct sk_buff *skb, *tmp;
 	bool have_simd = chacha20poly1305_init_simd();
-	skb_queue_walk(queue, skb) {
-		skb_encrypt(skb, keypair, have_simd);
+	skb_queue_walk_safe(queue, skb, tmp) {
+		if (unlikely(!skb_encrypt(skb, keypair, have_simd))) {
+			skb_unlink(skb, queue);
+			kfree_skb(skb);
+			continue;
+		}
 		skb_reset(skb);
 	}
 	chacha20poly1305_deinit_simd(have_simd);
@@ -261,35 +289,7 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, 
 	rcu_read_unlock();
 
 	skb_queue_walk(queue, skb) {
-		struct encryption_skb_cb *cb = (struct encryption_skb_cb *)skb->cb;
-		unsigned int padding_len, num_frags;
-
-		if (unlikely(!get_encryption_nonce(&cb->nonce, &keypair->sending)))
-			goto err;
-
-		padding_len = skb_padding(skb);
-		cb->trailer_len = padding_len + noise_encrypted_len(0);
-		cb->plaintext_len = skb->len + padding_len;
-
-		/* Store the ds bit in the cb */
-		cb->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb);
-
-		/* Expand data section to have room for padding and auth tag */
-		ret = skb_cow_data(skb, cb->trailer_len, &cb->trailer);
-		if (unlikely(ret < 0))
-			goto err;
-		num_frags = ret;
-		ret = -ENOMEM;
-		if (unlikely(num_frags > 128))
-			goto err;
-		cb->num_frags = num_frags;
-
-		/* Set the padding to zeros, and make sure it and the auth tag are part of the skb */
-		memset(skb_tail_pointer(cb->trailer), 0, padding_len);
-
-		/* Expand head section to have room for our header and the network stack's headers. */
-		ret = skb_cow_head(skb, DATA_PACKET_HEAD_ROOM);
-		if (unlikely(ret < 0))
+		if (unlikely(!get_encryption_nonce(&PACKET_CB(skb)->nonce, &keypair->sending)))
 			goto err;
 
 		/* After the first time through the loop, if we've suceeded with a legitimate nonce,
@@ -344,15 +344,18 @@ err_rcu:
 
 static void begin_decrypt_packet(struct decryption_ctx *ctx)
 {
-	if (unlikely(!skb_decrypt(ctx->skb, ctx->num_frags, ctx->nonce, &ctx->keypair->receiving)))
+	ctx->ret = socket_endpoint_from_skb(&ctx->endpoint, ctx->skb);
+	if (unlikely(ctx->ret < 0))
+		goto err;
+
+	ctx->ret = -ENOKEY;
+	if (unlikely(!skb_decrypt(ctx->skb, &ctx->keypair->receiving)))
 		goto err;
 
-	skb_reset(ctx->skb);
 	ctx->ret = 0;
 	return;
 
 err:
-	ctx->ret = -ENOKEY;
 	peer_put(ctx->keypair->entry.peer);
 }
 
@@ -360,22 +363,25 @@ static void finish_decrypt_packet(struct decryption_ctx *ctx)
 {
 	struct noise_keypairs *keypairs;
 	bool used_new_key = false;
+	u64 nonce = PACKET_CB(ctx->skb)->nonce;
 	int ret = ctx->ret;
 
 	if (ret)
 		goto err;
 
 	keypairs = &ctx->keypair->entry.peer->keypairs;
-	ret = counter_validate(&ctx->keypair->receiving.counter, ctx->nonce) ? 0 : -ERANGE;
+	ret = counter_validate(&ctx->keypair->receiving.counter, nonce) ? 0 : -ERANGE;
 	if (likely(!ret))
 		used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair);
 	else {
-		net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", ctx->nonce, ctx->keypair->receiving.counter.receive.counter);
+		net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", nonce, ctx->keypair->receiving.counter.receive.counter);
 		peer_put(ctx->keypair->entry.peer);
 		goto err;
 	}
 
 	noise_keypair_put(ctx->keypair);
+
+	skb_reset(ctx->skb);
 	ctx->callback(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key, 0);
 	return;
 
@@ -401,51 +407,26 @@ static void finish_decryption(struct padata_priv *padata)
 
 static inline int start_decryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu)
 {
+	memset(priv, 0, sizeof(struct padata_priv));
 	priv->parallel = do_decryption;
 	priv->serial = finish_decryption;
 	return padata_do_parallel(padata, priv, cb_cpu);
 }
 #endif
 
-void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_device *wg, packet_consume_data_callback_t callback)
+void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packet_consume_data_callback_t callback)
 {
 	int ret;
-	struct endpoint endpoint;
-	unsigned int num_frags;
-	struct sk_buff *trailer;
-	struct message_data *header;
 	struct noise_keypair *keypair;
-	u64 nonce;
-	__le32 idx;
+	__le32 idx = ((struct message_data *)skb->data)->key_idx;
 
-	ret = socket_endpoint_from_skb(&endpoint, skb);
-	if (unlikely(ret < 0))
-		goto err;
-
-	ret = -ENOMEM;
-	if (unlikely(!pskb_may_pull(skb, offset + sizeof(struct message_data))))
-		goto err;
-
-	header = (struct message_data *)(skb->data + offset);
-	offset += sizeof(struct message_data);
-	skb_pull(skb, offset);
-
-	idx = header->key_idx;
-	nonce = le64_to_cpu(header->counter);
-
-	ret = skb_cow_data(skb, 0, &trailer);
-	if (unlikely(ret < 0))
-		goto err;
-	num_frags = ret;
-	ret = -ENOMEM;
-	if (unlikely(num_frags > 128))
-		goto err;
 	ret = -EINVAL;
 	rcu_read_lock();
 	keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
 	rcu_read_unlock();
 	if (unlikely(!keypair))
 		goto err;
+
 #ifdef CONFIG_WIREGUARD_PARALLEL
 	if (cpumask_weight(cpu_online_mask) > 1) {
 		unsigned int cpu = choose_cpu(idx);
@@ -459,9 +440,6 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de
 			ctx->skb = skb;
 			ctx->keypair = keypair;
 			ctx->callback = callback;
-			ctx->nonce = nonce;
-			ctx->num_frags = num_frags;
-			ctx->endpoint = endpoint;
 			ret = start_decryption(wg->parallel_receive, &ctx->padata, cpu);
 			if (unlikely(ret)) {
 				kmem_cache_free(decryption_ctx_cache, ctx);
@@ -473,10 +451,7 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de
 		struct decryption_ctx ctx = {
 			.skb = skb,
 			.keypair = keypair,
-			.callback = callback,
-			.nonce = nonce,
-			.num_frags = num_frags,
-			.endpoint = endpoint
+			.callback = callback
 		};
 		begin_decrypt_packet(&ctx);
 		finish_decrypt_packet(&ctx);
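
The central move in this refactoring is dropping the file-local struct encryption_skb_cb (and the nonce/num_frags/endpoint fields in decryption_ctx) in favor of keeping per-packet state in the skb's own control block via PACKET_CB(). The macro and its backing struct are defined outside this file and are not shown in this diff, so the sketch below is only a minimal userspace model of the pattern, with assumed struct and field names (fake_sk_buff, packet_cb) rather than WireGuard's actual definitions.

/*
 * Userspace model of the skb->cb pattern: per-packet metadata (nonce, DS bit)
 * travels in a small scratch area carried by the packet itself, so no separate
 * per-packet context has to be allocated or threaded through callers.
 * Names here are illustrative, not WireGuard's exact ones.
 */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

struct fake_sk_buff {
	unsigned char data[128];
	char cb[48];            /* like sk_buff::cb: 48 bytes of scratch space */
};

struct packet_cb {
	uint64_t nonce;
	uint8_t ds;
};

#define PACKET_CB(skb) ((struct packet_cb *)(skb)->cb)

int main(void)
{
	struct fake_sk_buff skb = { { 0 }, { 0 } };

	/* The removed BUILD_BUG_ON becomes a plain assert in this model. */
	assert(sizeof(struct packet_cb) <= sizeof(skb.cb));

	/* Producer side: stash per-packet state directly in the packet. */
	PACKET_CB(&skb)->nonce = 42;
	PACKET_CB(&skb)->ds = 0;

	/* Consumer side: read it back later without extra bookkeeping. */
	printf("nonce=%llu ds=%u\n",
	       (unsigned long long)PACKET_CB(&skb)->nonce, PACKET_CB(&skb)->ds);
	return 0;
}

Because the metadata rides along inside the skb, functions such as skb_encrypt(), skb_decrypt(), and the decryption contexts above no longer need to pass nonce, fragment counts, or trailers explicitly, which is what allows the large deletions in this commit.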