summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/messages.h23
-rw-r--r--src/receive.c52
2 files changed, 33 insertions, 42 deletions
diff --git a/src/messages.h b/src/messages.h
index 6da04e5..acf248c 100644
--- a/src/messages.h
+++ b/src/messages.h
@@ -125,27 +125,4 @@ enum {
HANDSHAKE_DSCP = 0b10001000 /* AF41, plus 00 ECN */
};
-static const unsigned int message_header_sizes[MESSAGE_TOTAL] = {
- [MESSAGE_HANDSHAKE_INITIATION] = sizeof(struct message_handshake_initiation),
- [MESSAGE_HANDSHAKE_RESPONSE] = sizeof(struct message_handshake_response),
- [MESSAGE_HANDSHAKE_COOKIE] = sizeof(struct message_handshake_cookie),
- [MESSAGE_DATA] = sizeof(struct message_data)
-};
-
-static inline enum message_type message_determine_type(struct sk_buff *skb)
-{
- struct message_header *header = (struct message_header *)skb->data;
- if (unlikely(skb->len < sizeof(struct message_header)))
- return MESSAGE_INVALID;
- if (header->type == cpu_to_le32(MESSAGE_DATA) && skb->len >= MESSAGE_MINIMUM_LENGTH)
- return MESSAGE_DATA;
- if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && skb->len == sizeof(struct message_handshake_initiation))
- return MESSAGE_HANDSHAKE_INITIATION;
- if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && skb->len == sizeof(struct message_handshake_response))
- return MESSAGE_HANDSHAKE_RESPONSE;
- if (header->type == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && skb->len == sizeof(struct message_handshake_cookie))
- return MESSAGE_HANDSHAKE_COOKIE;
- return MESSAGE_INVALID;
-}
-
#endif
diff --git a/src/receive.c b/src/receive.c
index 68683c0..7177c59 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -31,11 +31,27 @@ static inline void update_latest_addr(struct wireguard_peer *peer, struct sk_buf
socket_set_peer_endpoint(peer, &endpoint);
}
+#define SKB_TYPE_LE32(skb) ((struct message_header *)(skb)->data)->type
+
+static inline size_t validate_header_len(struct sk_buff *skb)
+{
+ if (unlikely(skb->len < sizeof(struct message_header)))
+ return 0;
+ if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) && skb->len >= MESSAGE_MINIMUM_LENGTH)
+ return sizeof(struct message_data);
+ if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && skb->len == sizeof(struct message_handshake_initiation))
+ return sizeof(struct message_handshake_initiation);
+ if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && skb->len == sizeof(struct message_handshake_response))
+ return sizeof(struct message_handshake_response);
+ if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && skb->len == sizeof(struct message_handshake_cookie))
+ return sizeof(struct message_handshake_cookie);
+ return 0;
+}
+
static inline int skb_prepare_header(struct sk_buff *skb, struct wireguard_device *wg)
{
struct udphdr *udp;
- size_t data_offset, data_len;
- enum message_type message_type;
+ size_t data_offset, data_len, header_len;
if (unlikely(skb_examine_untrusted_ip_hdr(skb) != skb->protocol || skb_transport_header(skb) < skb->head || (skb_transport_header(skb) + sizeof(struct udphdr)) > skb_tail_pointer(skb)))
return -EINVAL; /* Bogus IP header */
udp = udp_hdr(skb);
@@ -52,26 +68,25 @@ static inline int skb_prepare_header(struct sk_buff *skb, struct wireguard_devic
skb_pull(skb, data_offset);
if (unlikely(skb->len != data_len))
return -EINVAL; /* Final len does not agree with calculated len */
- message_type = message_determine_type(skb);
+ header_len = validate_header_len(skb);
+ if (unlikely(!header_len))
+ return -EINVAL;
__skb_push(skb, data_offset);
- if (unlikely(!pskb_may_pull(skb, data_offset + message_header_sizes[message_type])))
+ if (unlikely(!pskb_may_pull(skb, data_offset + header_len)))
return -EINVAL;
__skb_pull(skb, data_offset);
- return message_type;
+ return 0;
}
static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff *skb)
{
static unsigned long last_under_load = 0; /* Yes this is global, so that our load calculation applies to the whole system. */
struct wireguard_peer *peer = NULL;
- enum message_type message_type;
bool under_load;
enum cookie_mac_state mac_state;
bool packet_needs_cookie;
- message_type = message_determine_type(skb);
-
- if (message_type == MESSAGE_HANDSHAKE_COOKIE) {
+ if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE)) {
net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n", wg->dev->name, skb);
cookie_message_consume((struct message_handshake_cookie *)skb->data, wg);
return;
@@ -92,8 +107,8 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff
return;
}
- switch (message_type) {
- case MESSAGE_HANDSHAKE_INITIATION: {
+ switch (SKB_TYPE_LE32(skb)) {
+ case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): {
struct message_handshake_initiation *message = (struct message_handshake_initiation *)skb->data;
if (packet_needs_cookie) {
packet_send_handshake_cookie(wg, skb, message->sender_index);
@@ -109,7 +124,7 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff
packet_send_handshake_response(peer);
break;
}
- case MESSAGE_HANDSHAKE_RESPONSE: {
+ case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): {
struct message_handshake_response *message = (struct message_handshake_response *)skb->data;
if (packet_needs_cookie) {
packet_send_handshake_cookie(wg, skb, message->sender_index);
@@ -411,13 +426,12 @@ static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb
void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
{
- int message_type = skb_prepare_header(skb, wg);
- if (unlikely(message_type < 0))
+ if (unlikely(skb_prepare_header(skb, wg) < 0))
goto err;
- switch (message_type) {
- case MESSAGE_HANDSHAKE_INITIATION:
- case MESSAGE_HANDSHAKE_RESPONSE:
- case MESSAGE_HANDSHAKE_COOKIE: {
+ switch (SKB_TYPE_LE32(skb)) {
+ case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION):
+ case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE):
+ case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): {
int cpu;
if (skb_queue_len(&wg->incoming_handshakes) > MAX_QUEUED_INCOMING_HANDSHAKES) {
net_dbg_skb_ratelimited("%s: Too many handshakes queued, dropping packet from %pISpfsc\n", wg->dev->name, skb);
@@ -429,7 +443,7 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
queue_work_on(cpu, wg->handshake_receive_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work);
break;
}
- case MESSAGE_DATA:
+ case cpu_to_le32(MESSAGE_DATA):
PACKET_CB(skb)->ds = ip_tunnel_get_dsfield(ip_hdr(skb), skb);
packet_consume_data(wg, skb);
break;