diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/queueing.h | 1 | ||||
-rw-r--r-- | src/receive.c | 26 |
2 files changed, 21 insertions, 6 deletions
diff --git a/src/queueing.h b/src/queueing.h index 68640b7..9b9b6a6 100644 --- a/src/queueing.h +++ b/src/queueing.h @@ -55,7 +55,6 @@ struct crypt_ctx { atomic_t is_finished; struct wireguard_peer *peer; struct noise_keypair *keypair; - struct endpoint endpoint; }; /* Returns either the correct skb->protocol value, or 0 if invalid. */ diff --git a/src/receive.c b/src/receive.c index 25349f9..9e03bcf 100644 --- a/src/receive.c +++ b/src/receive.c @@ -201,6 +201,7 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key * { struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1]; struct sk_buff *trailer; + unsigned int offset; int num_frags; if (unlikely(!key)) @@ -212,8 +213,15 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key * } PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter); - skb_pull(skb, sizeof(struct message_data)); + + /* We ensure that the network header is part of the packet before we + * call skb_cow_data, so that there's no chance that data is removed + * from the skb, so that later we can extract the original endpoint. */ + offset = skb->data - skb_network_header(skb); + skb_push(skb, offset); num_frags = skb_cow_data(skb, 0, &trailer); + offset += sizeof(struct message_data); + skb_pull(skb, offset); if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg))) return false; @@ -224,7 +232,14 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key * if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key)) return false; - return !pskb_trim(skb, skb->len - noise_encrypted_len(0)); + /* Another ugly situation of pushing and pulling the header so as to + * keep endpoint information intact. */ + skb_push(skb, offset); + if (pskb_trim(skb, skb->len - noise_encrypted_len(0))) + return false; + skb_pull(skb, offset); + + return true; } /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */ @@ -349,6 +364,7 @@ continue_processing: void packet_rx_worker(struct work_struct *work) { + struct endpoint endpoint; struct crypt_ctx *ctx; struct crypt_queue *queue = container_of(work, struct crypt_queue, work); @@ -357,9 +373,9 @@ void packet_rx_worker(struct work_struct *work) while ((ctx = __ptr_ring_peek(&queue->ring)) != NULL && atomic_read(&ctx->is_finished)) { __ptr_ring_discard_one(&queue->ring); if (likely(ctx->skb)) { - if (likely(counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(ctx->skb)->nonce))) { + if (likely(counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(ctx->skb)->nonce) && !socket_endpoint_from_skb(&endpoint, ctx->skb))) { skb_reset(ctx->skb); - packet_consume_data_done(ctx->skb, ctx->peer, &ctx->endpoint, noise_received_with_keypair(&ctx->peer->keypairs, ctx->keypair)); + packet_consume_data_done(ctx->skb, ctx->peer, &endpoint, noise_received_with_keypair(&ctx->peer->keypairs, ctx->keypair)); } else { net_dbg_ratelimited("%s: Packet has invalid nonce %Lu (max %Lu)\n", ctx->peer->device->dev->name, PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter); @@ -381,7 +397,7 @@ void packet_decrypt_worker(struct work_struct *work) struct wireguard_peer *peer; while ((ctx = ptr_ring_consume_bh(&queue->ring)) != NULL) { - if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) { + if (unlikely(!skb_decrypt(ctx->skb, &ctx->keypair->receiving))) { dev_kfree_skb(ctx->skb); ctx->skb = NULL; } |