summaryrefslogtreecommitdiffhomepage
path: root/src/receive.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/receive.c')
-rw-r--r--src/receive.c26
1 files changed, 21 insertions, 5 deletions
diff --git a/src/receive.c b/src/receive.c
index 25349f9..9e03bcf 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -201,6 +201,7 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *
{
struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1];
struct sk_buff *trailer;
+ unsigned int offset;
int num_frags;
if (unlikely(!key))
@@ -212,8 +213,15 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *
}
PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter);
- skb_pull(skb, sizeof(struct message_data));
+
+ /* We ensure that the network header is part of the packet before we
+ * call skb_cow_data, so that there's no chance that data is removed
+ * from the skb, so that later we can extract the original endpoint. */
+ offset = skb->data - skb_network_header(skb);
+ skb_push(skb, offset);
num_frags = skb_cow_data(skb, 0, &trailer);
+ offset += sizeof(struct message_data);
+ skb_pull(skb, offset);
if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
return false;
@@ -224,7 +232,14 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *
if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key))
return false;
- return !pskb_trim(skb, skb->len - noise_encrypted_len(0));
+ /* Another ugly situation of pushing and pulling the header so as to
+ * keep endpoint information intact. */
+ skb_push(skb, offset);
+ if (pskb_trim(skb, skb->len - noise_encrypted_len(0)))
+ return false;
+ skb_pull(skb, offset);
+
+ return true;
}
/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
@@ -349,6 +364,7 @@ continue_processing:
void packet_rx_worker(struct work_struct *work)
{
+ struct endpoint endpoint;
struct crypt_ctx *ctx;
struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
@@ -357,9 +373,9 @@ void packet_rx_worker(struct work_struct *work)
while ((ctx = __ptr_ring_peek(&queue->ring)) != NULL && atomic_read(&ctx->is_finished)) {
__ptr_ring_discard_one(&queue->ring);
if (likely(ctx->skb)) {
- if (likely(counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(ctx->skb)->nonce))) {
+ if (likely(counter_validate(&ctx->keypair->receiving.counter, PACKET_CB(ctx->skb)->nonce) && !socket_endpoint_from_skb(&endpoint, ctx->skb))) {
skb_reset(ctx->skb);
- packet_consume_data_done(ctx->skb, ctx->peer, &ctx->endpoint, noise_received_with_keypair(&ctx->peer->keypairs, ctx->keypair));
+ packet_consume_data_done(ctx->skb, ctx->peer, &endpoint, noise_received_with_keypair(&ctx->peer->keypairs, ctx->keypair));
}
else {
net_dbg_ratelimited("%s: Packet has invalid nonce %Lu (max %Lu)\n", ctx->peer->device->dev->name, PACKET_CB(ctx->skb)->nonce, ctx->keypair->receiving.counter.receive.counter);
@@ -381,7 +397,7 @@ void packet_decrypt_worker(struct work_struct *work)
struct wireguard_peer *peer;
while ((ctx = ptr_ring_consume_bh(&queue->ring)) != NULL) {
- if (unlikely(socket_endpoint_from_skb(&ctx->endpoint, ctx->skb) < 0 || !skb_decrypt(ctx->skb, &ctx->keypair->receiving))) {
+ if (unlikely(!skb_decrypt(ctx->skb, &ctx->keypair->receiving))) {
dev_kfree_skb(ctx->skb);
ctx->skb = NULL;
}