diff options
-rw-r--r-- | src/messages.h | 23 | ||||
-rw-r--r-- | src/receive.c | 52 |
2 files changed, 33 insertions, 42 deletions
diff --git a/src/messages.h b/src/messages.h index 6da04e5..acf248c 100644 --- a/src/messages.h +++ b/src/messages.h @@ -125,27 +125,4 @@ enum { HANDSHAKE_DSCP = 0b10001000 /* AF41, plus 00 ECN */ }; -static const unsigned int message_header_sizes[MESSAGE_TOTAL] = { - [MESSAGE_HANDSHAKE_INITIATION] = sizeof(struct message_handshake_initiation), - [MESSAGE_HANDSHAKE_RESPONSE] = sizeof(struct message_handshake_response), - [MESSAGE_HANDSHAKE_COOKIE] = sizeof(struct message_handshake_cookie), - [MESSAGE_DATA] = sizeof(struct message_data) -}; - -static inline enum message_type message_determine_type(struct sk_buff *skb) -{ - struct message_header *header = (struct message_header *)skb->data; - if (unlikely(skb->len < sizeof(struct message_header))) - return MESSAGE_INVALID; - if (header->type == cpu_to_le32(MESSAGE_DATA) && skb->len >= MESSAGE_MINIMUM_LENGTH) - return MESSAGE_DATA; - if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && skb->len == sizeof(struct message_handshake_initiation)) - return MESSAGE_HANDSHAKE_INITIATION; - if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && skb->len == sizeof(struct message_handshake_response)) - return MESSAGE_HANDSHAKE_RESPONSE; - if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && skb->len == sizeof(struct message_handshake_cookie)) - return MESSAGE_HANDSHAKE_COOKIE; - return MESSAGE_INVALID; -} - #endif diff --git a/src/receive.c b/src/receive.c index 68683c0..7177c59 100644 --- a/src/receive.c +++ b/src/receive.c @@ -31,11 +31,27 @@ static inline void update_latest_addr(struct wireguard_peer *peer, struct sk_buf socket_set_peer_endpoint(peer, &endpoint); } +#define SKB_TYPE_LE32(skb) ((struct message_header *)(skb)->data)->type + +static inline size_t validate_header_len(struct sk_buff *skb) +{ + if (unlikely(skb->len < sizeof(struct message_header))) + return 0; + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) && skb->len >= MESSAGE_MINIMUM_LENGTH) + return sizeof(struct message_data); + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && skb->len == sizeof(struct message_handshake_initiation)) + return sizeof(struct message_handshake_initiation); + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && skb->len == sizeof(struct message_handshake_response)) + return sizeof(struct message_handshake_response); + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && skb->len == sizeof(struct message_handshake_cookie)) + return sizeof(struct message_handshake_cookie); + return 0; +} + static inline int skb_prepare_header(struct sk_buff *skb, struct wireguard_device *wg) { struct udphdr *udp; - size_t data_offset, data_len; - enum message_type message_type; + size_t data_offset, data_len, header_len; if (unlikely(skb_examine_untrusted_ip_hdr(skb) != skb->protocol || skb_transport_header(skb) < skb->head || (skb_transport_header(skb) + sizeof(struct udphdr)) > skb_tail_pointer(skb))) return -EINVAL; /* Bogus IP header */ udp = udp_hdr(skb); @@ -52,26 +68,25 @@ static inline int skb_prepare_header(struct sk_buff *skb, struct wireguard_devic skb_pull(skb, data_offset); if (unlikely(skb->len != data_len)) return -EINVAL; /* Final len does not agree with calculated len */ - message_type = message_determine_type(skb); + header_len = validate_header_len(skb); + if (unlikely(!header_len)) + return -EINVAL; __skb_push(skb, data_offset); - if (unlikely(!pskb_may_pull(skb, data_offset + message_header_sizes[message_type]))) + if (unlikely(!pskb_may_pull(skb, data_offset + header_len))) return -EINVAL; __skb_pull(skb, data_offset); - return message_type; + return 0; } static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff *skb) { static unsigned long last_under_load = 0; /* Yes this is global, so that our load calculation applies to the whole system. */ struct wireguard_peer *peer = NULL; - enum message_type message_type; bool under_load; enum cookie_mac_state mac_state; bool packet_needs_cookie; - message_type = message_determine_type(skb); - - if (message_type == MESSAGE_HANDSHAKE_COOKIE) { + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE)) { net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n", wg->dev->name, skb); cookie_message_consume((struct message_handshake_cookie *)skb->data, wg); return; @@ -92,8 +107,8 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff return; } - switch (message_type) { - case MESSAGE_HANDSHAKE_INITIATION: { + switch (SKB_TYPE_LE32(skb)) { + case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): { struct message_handshake_initiation *message = (struct message_handshake_initiation *)skb->data; if (packet_needs_cookie) { packet_send_handshake_cookie(wg, skb, message->sender_index); @@ -109,7 +124,7 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff packet_send_handshake_response(peer); break; } - case MESSAGE_HANDSHAKE_RESPONSE: { + case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): { struct message_handshake_response *message = (struct message_handshake_response *)skb->data; if (packet_needs_cookie) { packet_send_handshake_cookie(wg, skb, message->sender_index); @@ -411,13 +426,12 @@ static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) { - int message_type = skb_prepare_header(skb, wg); - if (unlikely(message_type < 0)) + if (unlikely(skb_prepare_header(skb, wg) < 0)) goto err; - switch (message_type) { - case MESSAGE_HANDSHAKE_INITIATION: - case MESSAGE_HANDSHAKE_RESPONSE: - case MESSAGE_HANDSHAKE_COOKIE: { + switch (SKB_TYPE_LE32(skb)) { + case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): + case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): + case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): { int cpu; if (skb_queue_len(&wg->incoming_handshakes) > MAX_QUEUED_INCOMING_HANDSHAKES) { net_dbg_skb_ratelimited("%s: Too many handshakes queued, dropping packet from %pISpfsc\n", wg->dev->name, skb); @@ -429,7 +443,7 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) queue_work_on(cpu, wg->handshake_receive_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work); break; } - case MESSAGE_DATA: + case cpu_to_le32(MESSAGE_DATA): PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb); packet_consume_data(wg, skb); break; |