diff options
-rw-r--r-- | src/cookie.c | 12 | ||||
-rw-r--r-- | src/cookie.h | 4 | ||||
-rw-r--r-- | src/data.c | 167 | ||||
-rw-r--r-- | src/messages.h | 22 | ||||
-rw-r--r-- | src/packets.h | 10 | ||||
-rw-r--r-- | src/ratelimiter.c | 2 | ||||
-rw-r--r-- | src/receive.c | 91 | ||||
-rw-r--r-- | src/send.c | 9 |
8 files changed, 158 insertions, 159 deletions
diff --git a/src/cookie.c b/src/cookie.c index 1c188c6..66f5d45 100644 --- a/src/cookie.c +++ b/src/cookie.c @@ -103,12 +103,12 @@ static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, struct cooki up_read(&checker->secret_lock); } -enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, void *data_start, size_t data_len, bool check_cookie) +enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie) { u8 computed_mac[COOKIE_LEN]; u8 cookie[COOKIE_LEN]; enum cookie_mac_state ret; - struct message_macs *macs = (struct message_macs *)((u8 *)data_start + data_len - sizeof(struct message_macs)); + struct message_macs *macs = (struct message_macs *)(skb->data + skb->len - sizeof(struct message_macs)); ret = INVALID_MAC; down_read(&checker->device->static_identity.lock); @@ -116,7 +116,7 @@ enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, str up_read(&checker->device->static_identity.lock); goto out; } - compute_mac1(computed_mac, data_start, data_len, checker->device->static_identity.static_public, checker->device->static_identity.has_psk ? checker->device->static_identity.preshared_key : NULL); + compute_mac1(computed_mac, skb->data, skb->len, checker->device->static_identity.static_public, checker->device->static_identity.has_psk ? checker->device->static_identity.preshared_key : NULL); up_read(&checker->device->static_identity.lock); if (crypto_memneq(computed_mac, macs->mac1, COOKIE_LEN)) goto out; @@ -128,7 +128,7 @@ enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, str make_cookie(cookie, skb, checker); - compute_mac2(computed_mac, data_start, data_len, cookie); + compute_mac2(computed_mac, skb->data, skb->len, cookie); if (crypto_memneq(computed_mac, macs->mac2, COOKIE_LEN)) goto out; @@ -168,9 +168,9 @@ void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer * up_read(&peer->latest_cookie.lock); } -void cookie_message_create(struct message_handshake_cookie *dst, struct sk_buff *skb, void *data_start, size_t data_len, __le32 index, struct cookie_checker *checker) +void cookie_message_create(struct message_handshake_cookie *dst, struct sk_buff *skb, __le32 index, struct cookie_checker *checker) { - struct message_macs *macs = (struct message_macs *)((u8 *)data_start + data_len - sizeof(struct message_macs)); + struct message_macs *macs = (struct message_macs *)((u8 *)skb->data + skb->len - sizeof(struct message_macs)); u8 cookie[COOKIE_LEN]; dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE); diff --git a/src/cookie.h b/src/cookie.h index e1c8d8e..87a0e5a 100644 --- a/src/cookie.h +++ b/src/cookie.h @@ -42,10 +42,10 @@ void cookie_checker_uninit(struct cookie_checker *checker); void cookie_checker_precompute_keys(struct cookie_checker *checker, struct wireguard_peer *peer); void cookie_init(struct cookie *cookie); -enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, void *data_start, size_t data_len, bool check_cookie); +enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie); void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer *peer); -void cookie_message_create(struct message_handshake_cookie *src, struct sk_buff *skb, void *data_start, size_t data_len, __le32 index, struct cookie_checker *checker); +void cookie_message_create(struct message_handshake_cookie *src, struct sk_buff *skb, __le32 index, struct cookie_checker *checker); void cookie_message_consume(struct message_handshake_cookie *src, struct wireguard_device *wg); #endif @@ -15,14 +15,6 @@ #include <net/xfrm.h> #include <crypto/algapi.h> -struct encryption_skb_cb { - u8 ds; - u8 num_frags; - unsigned int plaintext_len, trailer_len; - struct sk_buff *trailer; - u64 nonce; -}; - struct encryption_ctx { struct padata_priv padata; struct sk_buff_head queue; @@ -33,13 +25,11 @@ struct encryption_ctx { struct decryption_ctx { struct padata_priv padata; + struct endpoint endpoint; struct sk_buff *skb; packet_consume_data_callback_t callback; struct noise_keypair *keypair; - struct endpoint endpoint; - u64 nonce; int ret; - u8 num_frags; }; #ifdef CONFIG_WIREGUARD_PARALLEL @@ -48,7 +38,6 @@ static struct kmem_cache *decryption_ctx_cache __read_mostly; int packet_init_data_caches(void) { - BUILD_BUG_ON(sizeof(struct encryption_skb_cb) > sizeof(((struct sk_buff *)0)->cb)); encryption_ctx_cache = kmem_cache_create("wireguard_encryption_ctx", sizeof(struct encryption_ctx), 0, 0, NULL); if (!encryption_ctx_cache) return -ENOMEM; @@ -130,13 +119,36 @@ static inline void skb_reset(struct sk_buff *skb) skb_reset_mac_header(skb); skb_reset_network_header(skb); skb_probe_transport_header(skb, 0); + skb_reset_inner_headers(skb); } -static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, bool have_simd) +static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, bool have_simd) { - struct encryption_skb_cb *cb = (struct encryption_skb_cb *)skb->cb; - struct scatterlist sg[cb->num_frags]; /* This should be bound to at most 128 by the caller. */ + struct scatterlist *sg; struct message_data *header; + unsigned int padding_len, plaintext_len, trailer_len; + int num_frags; + struct sk_buff *trailer; + + /* Store the ds bit in the cb */ + PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb); + + /* Calculate lengths */ + padding_len = skb_padding(skb); + trailer_len = padding_len + noise_encrypted_len(0); + plaintext_len = skb->len + padding_len; + + /* Expand data section to have room for padding and auth tag */ + num_frags = skb_cow_data(skb, trailer_len, &trailer); + if (unlikely(num_frags < 0 || num_frags > 128)) + return false; + + /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */ + memset(skb_tail_pointer(trailer), 0, padding_len); + + /* Expand head section to have room for our header and the network stack's headers. */ + if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0)) + return false; /* We have to remember to add the checksum to the innerpacket, in case the receiver forwards it. */ if (likely(!skb_checksum_setup(skb, true))) @@ -146,18 +158,23 @@ static inline void skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypai header = (struct message_data *)skb_push(skb, sizeof(struct message_data)); header->header.type = cpu_to_le32(MESSAGE_DATA); header->key_idx = keypair->remote_index; - header->counter = cpu_to_le64(cb->nonce); - pskb_put(skb, cb->trailer, cb->trailer_len); + header->counter = cpu_to_le64(PACKET_CB(skb)->nonce); + pskb_put(skb, trailer, trailer_len); /* Now we can encrypt the scattergather segments */ - sg_init_table(sg, cb->num_frags); - skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(cb->plaintext_len)); - chacha20poly1305_encrypt_sg(sg, sg, cb->plaintext_len, NULL, 0, cb->nonce, keypair->sending.key, have_simd); + sg = __builtin_alloca(num_frags * sizeof(struct scatterlist)); /* bounded to 128 */ + sg_init_table(sg, num_frags); + skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(plaintext_len)); + chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0, PACKET_CB(skb)->nonce, keypair->sending.key, have_simd); + + return true; } -static inline bool skb_decrypt(struct sk_buff *skb, u8 num_frags, u64 nonce, struct noise_symmetric_key *key) +static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *key) { - struct scatterlist sg[num_frags]; /* This should be bound to at most 128 by the caller. */ + struct scatterlist *sg; + struct sk_buff *trailer; + int num_frags; if (unlikely(!key)) return false; @@ -167,10 +184,17 @@ static inline bool skb_decrypt(struct sk_buff *skb, u8 num_frags, u64 nonce, str return false; } + PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter); + skb_pull(skb, sizeof(struct message_data)); + num_frags = skb_cow_data(skb, 0, &trailer); + if (unlikely(num_frags < 0 || num_frags > 128)) + return false; + sg = __builtin_alloca(num_frags * sizeof(struct scatterlist)); /* bounded to 128 */ + sg_init_table(sg, num_frags); skb_to_sgvec(skb, sg, 0, skb->len); - if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, nonce, key->key)) + if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key)) return false; return pskb_trim(skb, skb->len - noise_encrypted_len(0)) == 0; @@ -197,10 +221,14 @@ static inline bool get_encryption_nonce(u64 *nonce, struct noise_symmetric_key * static inline void queue_encrypt_reset(struct sk_buff_head *queue, struct noise_keypair *keypair) { - struct sk_buff *skb; + struct sk_buff *skb, *tmp; bool have_simd = chacha20poly1305_init_simd(); - skb_queue_walk(queue, skb) { - skb_encrypt(skb, keypair, have_simd); + skb_queue_walk_safe(queue, skb, tmp) { + if (unlikely(!skb_encrypt(skb, keypair, have_simd))) { + skb_unlink(skb, queue); + kfree_skb(skb); + continue; + } skb_reset(skb); } chacha20poly1305_deinit_simd(have_simd); @@ -261,35 +289,7 @@ int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, rcu_read_unlock(); skb_queue_walk(queue, skb) { - struct encryption_skb_cb *cb = (struct encryption_skb_cb *)skb->cb; - unsigned int padding_len, num_frags; - - if (unlikely(!get_encryption_nonce(&cb->nonce, &keypair->sending))) - goto err; - - padding_len = skb_padding(skb); - cb->trailer_len = padding_len + noise_encrypted_len(0); - cb->plaintext_len = skb->len + padding_len; - - /* Store the ds bit in the cb */ - cb->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb); - - /* Expand data section to have room for padding and auth tag */ - ret = skb_cow_data(skb, cb->trailer_len, &cb->trailer); - if (unlikely(ret < 0)) - goto err; - num_frags = ret; - ret = -ENOMEM; - if (unlikely(num_frags > 128)) - goto err; - cb->num_frags = num_frags; - - /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */ - memset(skb_tail_pointer(cb->trailer), 0, padding_len); - - /* Expand head section to have room for our header and the network stack's headers. */ - ret = skb_cow_head(skb, DATA_PACKET_HEAD_ROOM); - if (unlikely(ret < 0)) + if (unlikely(!get_encryption_nonce(&PACKET_CB(skb)->nonce, &keypair->sending))) goto err; /* After the first time through the loop, if we've suceeded with a legitimate nonce, @@ -344,15 +344,18 @@ err_rcu: static void begin_decrypt_packet(struct decryption_ctx *ctx) { - if (unlikely(!skb_decrypt(ctx->skb, ctx->num_frags, ctx->nonce, &ctx->keypair->receiving))) + ctx->ret = socket_endpoint_from_skb(&ctx->endpoint, ctx->skb); + if (unlikely(ctx->ret < 0)) + goto err; + + ctx->ret = -ENOKEY; + if (unlikely(!skb_decrypt(ctx->skb, &ctx->keypair->receiving))) goto err; - skb_reset(ctx->skb); ctx->ret = 0; return; err: - ctx->ret = -ENOKEY; peer_put(ctx->keypair->entry.peer); } @@ -360,22 +363,25 @@ static void finish_decrypt_packet(struct decryption_ctx *ctx) { struct noise_keypairs *keypairs; bool used_new_key = false; + u64 nonce = PACKET_CB(ctx->skb)->nonce; int ret = ctx->ret; if (ret) goto err; keypairs = &ctx->keypair->entry.peer->keypairs; - ret = counter_validate(&ctx->keypair->receiving.counter, ctx->nonce) ? 0 : -ERANGE; + ret = counter_validate(&ctx->keypair->receiving.counter, nonce) ? 0 : -ERANGE; if (likely(!ret)) used_new_key = noise_received_with_keypair(&ctx->keypair->entry.peer->keypairs, ctx->keypair); else { - net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", ctx->nonce, ctx->keypair->receiving.counter.receive.counter); + net_dbg_ratelimited("Packet has invalid nonce %Lu (max %Lu)\n", nonce, ctx->keypair->receiving.counter.receive.counter); peer_put(ctx->keypair->entry.peer); goto err; } noise_keypair_put(ctx->keypair); + + skb_reset(ctx->skb); ctx->callback(ctx->skb, ctx->keypair->entry.peer, &ctx->endpoint, used_new_key, 0); return; @@ -401,51 +407,26 @@ static void finish_decryption(struct padata_priv *padata) static inline int start_decryption(struct padata_instance *padata, struct padata_priv *priv, int cb_cpu) { + memset(priv, 0, sizeof(struct padata_priv)); priv->parallel = do_decryption; priv->serial = finish_decryption; return padata_do_parallel(padata, priv, cb_cpu); } #endif -void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_device *wg, packet_consume_data_callback_t callback) +void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packet_consume_data_callback_t callback) { int ret; - struct endpoint endpoint; - unsigned int num_frags; - struct sk_buff *trailer; - struct message_data *header; struct noise_keypair *keypair; - u64 nonce; - __le32 idx; + __le32 idx = ((struct message_data *)skb->data)->key_idx; - ret = socket_endpoint_from_skb(&endpoint, skb); - if (unlikely(ret < 0)) - goto err; - - ret = -ENOMEM; - if (unlikely(!pskb_may_pull(skb, offset + sizeof(struct message_data)))) - goto err; - - header = (struct message_data *)(skb->data + offset); - offset += sizeof(struct message_data); - skb_pull(skb, offset); - - idx = header->key_idx; - nonce = le64_to_cpu(header->counter); - - ret = skb_cow_data(skb, 0, &trailer); - if (unlikely(ret < 0)) - goto err; - num_frags = ret; - ret = -ENOMEM; - if (unlikely(num_frags > 128)) - goto err; ret = -EINVAL; rcu_read_lock(); keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx)); rcu_read_unlock(); if (unlikely(!keypair)) goto err; + #ifdef CONFIG_WIREGUARD_PARALLEL if (cpumask_weight(cpu_online_mask) > 1) { unsigned int cpu = choose_cpu(idx); @@ -459,9 +440,6 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de ctx->skb = skb; ctx->keypair = keypair; ctx->callback = callback; - ctx->nonce = nonce; - ctx->num_frags = num_frags; - ctx->endpoint = endpoint; ret = start_decryption(wg->parallel_receive, &ctx->padata, cpu); if (unlikely(ret)) { kmem_cache_free(decryption_ctx_cache, ctx); @@ -473,10 +451,7 @@ void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_de struct decryption_ctx ctx = { .skb = skb, .keypair = keypair, - .callback = callback, - .nonce = nonce, - .num_frags = num_frags, - .endpoint = endpoint + .callback = callback }; begin_decrypt_packet(&ctx); finish_decrypt_packet(&ctx); diff --git a/src/messages.h b/src/messages.h index 7dc09aa..defc831 100644 --- a/src/messages.h +++ b/src/messages.h @@ -13,6 +13,7 @@ #include <linux/kernel.h> #include <linux/param.h> +#include <linux/skbuff.h> enum noise_lengths { NOISE_PUBLIC_KEY_LEN = CURVE25519_POINT_SIZE, @@ -124,18 +125,25 @@ enum { HANDSHAKE_DSCP = 0b10001000 /* AF41, plus 00 ECN */ }; -static inline enum message_type message_determine_type(void *src, size_t src_len) +static const unsigned int message_header_sizes[MESSAGE_TOTAL] = { + [MESSAGE_HANDSHAKE_INITIATION] = sizeof(struct message_handshake_initiation), + [MESSAGE_HANDSHAKE_RESPONSE] = sizeof(struct message_handshake_response), + [MESSAGE_HANDSHAKE_COOKIE] = sizeof(struct message_handshake_cookie), + [MESSAGE_DATA] = sizeof(struct message_data) +}; + +static inline enum message_type message_determine_type(struct sk_buff *skb) { - struct message_header *header = src; - if (unlikely(src_len < sizeof(struct message_header))) + struct message_header *header = (struct message_header *)skb->data; + if (unlikely(skb->len < sizeof(struct message_header))) return MESSAGE_INVALID; - if (header->type == cpu_to_le32(MESSAGE_DATA) && src_len >= MESSAGE_MINIMUM_LENGTH) + if (header->type == cpu_to_le32(MESSAGE_DATA) && skb->len >= MESSAGE_MINIMUM_LENGTH) return MESSAGE_DATA; - if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && src_len == sizeof(struct message_handshake_initiation)) + if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && skb->len == sizeof(struct message_handshake_initiation)) return MESSAGE_HANDSHAKE_INITIATION; - if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && src_len == sizeof(struct message_handshake_response)) + if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && skb->len == sizeof(struct message_handshake_response)) return MESSAGE_HANDSHAKE_RESPONSE; - if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && src_len == sizeof(struct message_handshake_cookie)) + if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && skb->len == sizeof(struct message_handshake_cookie)) return MESSAGE_HANDSHAKE_COOKIE; return MESSAGE_INVALID; } diff --git a/src/packets.h b/src/packets.h index 6530048..a640847 100644 --- a/src/packets.h +++ b/src/packets.h @@ -14,6 +14,12 @@ struct wireguard_device; struct wireguard_peer; struct sk_buff; +struct packet_cb { + u64 nonce; + u8 ds; +}; +#define PACKET_CB(skb) ((struct packet_cb *)skb->cb) + /* receive.c */ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb); void packet_process_queued_handshake_packets(struct work_struct *work); @@ -24,13 +30,13 @@ void packet_send_keepalive(struct wireguard_peer *peer); void packet_queue_handshake_initiation(struct wireguard_peer *peer); void packet_send_queued_handshakes(struct work_struct *work); void packet_send_handshake_response(struct wireguard_peer *peer); -void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, void *data, size_t data_len, __le32 sender_index); +void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index); /* data.c */ typedef void (*packet_create_data_callback_t)(struct sk_buff_head *, struct wireguard_peer *); typedef void (*packet_consume_data_callback_t)(struct sk_buff *skb, struct wireguard_peer *, struct endpoint *, bool used_new_key, int err); int packet_create_data(struct sk_buff_head *queue, struct wireguard_peer *peer, packet_create_data_callback_t callback); -void packet_consume_data(struct sk_buff *skb, size_t offset, struct wireguard_device *wg, packet_consume_data_callback_t callback); +void packet_consume_data(struct sk_buff *skb, struct wireguard_device *wg, packet_consume_data_callback_t callback); #ifdef CONFIG_WIREGUARD_PARALLEL int packet_init_data_caches(void); diff --git a/src/ratelimiter.c b/src/ratelimiter.c index 12282fd..ab8f93d 100644 --- a/src/ratelimiter.c +++ b/src/ratelimiter.c @@ -25,7 +25,7 @@ static inline void cfg_init(struct hashlimit_cfg1 *cfg, int family) cfg->srcmask = 32; else if (family == NFPROTO_IPV6) cfg->srcmask = 96; - cfg->mode = XT_HASHLIMIT_HASH_SIP; /* source IP only -- we could also do source port by ORing this with XT_HASHLIMIT_HASH_SPT */ + cfg->mode = XT_HASHLIMIT_HASH_SIP; /* source IP only -- we could also do source port by ORing this with XT_HASHLIMIT_HASH_SPT, but we don't really want to do that. It would also cause problems since we skb_pull early on, and hashlimit's nexthdr stuff isn't so nice. */ cfg->avg = XT_HASHLIMIT_SCALE / RATELIMITER_PACKETS_PER_SECOND; /* 30 per second per IP */ cfg->burst = RATELIMITER_PACKETS_BURSTABLE; /* Allow bursts of 5 at a time */ cfg->gc_interval = 1000; /* same as expiration date */ diff --git a/src/receive.c b/src/receive.c index 5707ab2..f791a2e 100644 --- a/src/receive.c +++ b/src/receive.c @@ -30,9 +30,11 @@ static inline void update_latest_addr(struct wireguard_peer *peer, struct sk_buf socket_set_peer_endpoint(peer, &endpoint); } -static inline int skb_data_offset(struct sk_buff *skb, size_t *data_offset, size_t *data_len) +static inline int skb_prepare_header(struct sk_buff *skb) { struct udphdr *udp; + size_t data_offset, data_len; + enum message_type message_type; if (unlikely(skb->len < sizeof(struct iphdr))) return -EINVAL; @@ -42,35 +44,50 @@ static inline int skb_data_offset(struct sk_buff *skb, size_t *data_offset, size return -EINVAL; udp = udp_hdr(skb); - *data_offset = (u8 *)udp - skb->data; - if (unlikely(*data_offset > U16_MAX)) { + data_offset = (u8 *)udp - skb->data; + if (unlikely(data_offset > U16_MAX)) { net_dbg_skb_ratelimited("Packet has offset at impossible location from %pISpfsc\n", skb); return -EINVAL; } - if (unlikely(*data_offset + sizeof(struct udphdr) > skb->len)) { + if (unlikely(data_offset + sizeof(struct udphdr) > skb->len)) { net_dbg_skb_ratelimited("Packet isn't big enough to have UDP fields from %pISpfsc\n", skb); return -EINVAL; } - *data_len = ntohs(udp->len); - if (unlikely(*data_len < sizeof(struct udphdr))) { + data_len = ntohs(udp->len); + if (unlikely(data_len < sizeof(struct udphdr))) { net_dbg_skb_ratelimited("UDP packet is reporting too small of a size from %pISpfsc\n", skb); return -EINVAL; } - if (unlikely(*data_len > skb->len - *data_offset)) { + if (unlikely(data_len > skb->len - data_offset)) { net_dbg_skb_ratelimited("UDP packet is lying about its size from %pISpfsc\n", skb); return -EINVAL; } - *data_len -= sizeof(struct udphdr); - *data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data; - if (!pskb_may_pull(skb, *data_offset + sizeof(struct message_header))) { + data_len -= sizeof(struct udphdr); + data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data; + if (unlikely(!pskb_may_pull(skb, data_offset + sizeof(struct message_header)))) { net_dbg_skb_ratelimited("Could not pull header into data section from %pISpfsc\n", skb); return -EINVAL; } - - return 0; + if (pskb_trim(skb, data_len + data_offset) < 0) { + net_dbg_skb_ratelimited("Could not trim packet from %pISpfsc\n", skb); + return -EINVAL; + } + skb_pull(skb, data_offset); + if (unlikely(skb->len != data_len)) { + net_dbg_skb_ratelimited("Final len does not agree with calculated len from %pISpfsc\n", skb); + return -EINVAL; + } + message_type = message_determine_type(skb); + __skb_push(skb, data_offset); + if (unlikely(!pskb_may_pull(skb, data_offset + message_header_sizes[message_type]))) { + net_dbg_skb_ratelimited("Could not pull full header into data section from %pISpfsc\n", skb); + return -EINVAL; + } + __skb_pull(skb, data_offset); + return message_type; } -static void receive_handshake_packet(struct wireguard_device *wg, void *data, size_t len, struct sk_buff *skb) +static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff *skb) { struct wireguard_peer *peer = NULL; enum message_type message_type; @@ -78,16 +95,16 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si enum cookie_mac_state mac_state; bool packet_needs_cookie; - message_type = message_determine_type(data, len); + message_type = message_determine_type(skb); if (message_type == MESSAGE_HANDSHAKE_COOKIE) { net_dbg_skb_ratelimited("Receiving cookie response from %pISpfsc\n", skb); - cookie_message_consume(data, wg); + cookie_message_consume((struct message_handshake_cookie *)skb->data, wg); return; } under_load = skb_queue_len(&wg->incoming_handshakes) >= MAX_QUEUED_INCOMING_HANDSHAKES / 2; - mac_state = cookie_validate_packet(&wg->cookie_checker, skb, data, len, under_load); + mac_state = cookie_validate_packet(&wg->cookie_checker, skb, under_load); if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) packet_needs_cookie = false; else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE) @@ -98,13 +115,13 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si } switch (message_type) { - case MESSAGE_HANDSHAKE_INITIATION: + case MESSAGE_HANDSHAKE_INITIATION: { + struct message_handshake_initiation *message = (struct message_handshake_initiation *)skb->data; if (packet_needs_cookie) { - struct message_handshake_initiation *message = data; - packet_send_handshake_cookie(wg, skb, message, sizeof(*message), message->sender_index); + packet_send_handshake_cookie(wg, skb, message->sender_index); return; } - peer = noise_handshake_consume_initiation(data, wg); + peer = noise_handshake_consume_initiation(message, wg); if (unlikely(!peer)) { net_dbg_skb_ratelimited("Invalid handshake initiation from %pISpfsc\n", skb); return; @@ -113,13 +130,14 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si net_dbg_ratelimited("Receiving handshake initiation from peer %Lu (%pISpfsc)\n", peer->internal_id, &peer->endpoint.addr); packet_send_handshake_response(peer); break; - case MESSAGE_HANDSHAKE_RESPONSE: + } + case MESSAGE_HANDSHAKE_RESPONSE: { + struct message_handshake_response *message = (struct message_handshake_response *)skb->data; if (packet_needs_cookie) { - struct message_handshake_response *message = data; - packet_send_handshake_cookie(wg, skb, message, sizeof(*message), message->sender_index); + packet_send_handshake_cookie(wg, skb, message->sender_index); return; } - peer = noise_handshake_consume_response(data, wg); + peer = noise_handshake_consume_response(message, wg); if (unlikely(!peer)) { net_dbg_skb_ratelimited("Invalid handshake response from %pISpfsc\n", skb); return; @@ -137,6 +155,7 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si packet_send_keepalive(peer); } break; + } default: WARN(1, "Somehow a wrong type of packet wound up in the handshake queue!\n"); return; @@ -144,7 +163,7 @@ static void receive_handshake_packet(struct wireguard_device *wg, void *data, si BUG_ON(!peer); - rx_stats(peer, len); + rx_stats(peer, skb->len); timers_any_authenticated_packet_received(peer); timers_any_authenticated_packet_traversal(peer); peer_put(peer); @@ -154,12 +173,10 @@ void packet_process_queued_handshake_packets(struct work_struct *work) { struct wireguard_device *wg = container_of(work, struct wireguard_device, incoming_handshakes_work); struct sk_buff *skb; - size_t len, offset; size_t num_processed = 0; while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) { - if (!skb_data_offset(skb, &offset, &len)) - receive_handshake_packet(wg, skb->data + offset, len, skb); + receive_handshake_packet(wg, skb); dev_kfree_skb(skb); if (++num_processed == MAX_BURST_INCOMING_HANDSHAKES) { queue_work(wg->workqueue, &wg->incoming_handshakes_work); @@ -188,11 +205,6 @@ static void keep_key_fresh(struct wireguard_peer *peer) } } -struct packet_cb { - u8 ds; -}; -#define PACKET_CB(skb) ((struct packet_cb *)skb->cb) - static void receive_data_packet(struct sk_buff *skb, struct wireguard_peer *peer, struct endpoint *endpoint, bool used_new_key, int err) { struct net_device *dev; @@ -276,11 +288,10 @@ continue_processing: void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) { - size_t len, offset; - - if (unlikely(skb_data_offset(skb, &offset, &len) < 0)) + int message_type = skb_prepare_header(skb); + if (unlikely(message_type < 0)) goto err; - switch (message_determine_type(skb->data + offset, len)) { + switch (message_type) { case MESSAGE_HANDSHAKE_INITIATION: case MESSAGE_HANDSHAKE_RESPONSE: case MESSAGE_HANDSHAKE_COOKIE: @@ -288,17 +299,13 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) net_dbg_skb_ratelimited("Too many handshakes queued, dropping packet from %pISpfsc\n", skb); goto err; } - if (skb_linearize(skb) < 0) { - net_dbg_skb_ratelimited("Unable to linearize handshake skb from %pISpfsc\n", skb); - goto err; - } skb_queue_tail(&wg->incoming_handshakes, skb); /* Queues up a call to packet_process_queued_handshake_packets(skb): */ queue_work(wg->workqueue, &wg->incoming_handshakes_work); break; case MESSAGE_DATA: PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb); - packet_consume_data(skb, offset, wg, receive_data_packet); + packet_consume_data(skb, wg, receive_data_packet); break; default: net_dbg_skb_ratelimited("Invalid packet from %pISpfsc\n", skb); @@ -77,12 +77,12 @@ void packet_send_handshake_response(struct wireguard_peer *peer) } } -void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, void *data, size_t data_len, __le32 sender_index) +void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index) { struct message_handshake_cookie packet; net_dbg_skb_ratelimited("Sending cookie response for denied handshake message for %pISpfsc\n", initiating_skb); - cookie_message_create(&packet, initiating_skb, data, data_len, sender_index, &wg->cookie_checker); + cookie_message_create(&packet, initiating_skb, sender_index, &wg->cookie_checker); socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet, sizeof(packet)); } @@ -123,10 +123,13 @@ static void message_create_data_done(struct sk_buff_head *queue, struct wireguar struct sk_buff *skb, *tmp; bool is_keepalive, data_sent = false; + if (unlikely(!skb_queue_len(queue))) + return; + timers_any_authenticated_packet_traversal(peer); skb_queue_walk_safe(queue, skb, tmp) { is_keepalive = skb->len == message_data_len(0); - if (likely(!socket_send_skb_to_peer(peer, skb, *(u8 *)skb->cb) && !is_keepalive)) + if (likely(!socket_send_skb_to_peer(peer, skb, PACKET_CB(skb)->ds) && !is_keepalive)) data_sent = true; } if (likely(data_sent)) |