diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/allowedips.c | 148 | ||||
-rw-r--r-- | src/allowedips.h | 23 | ||||
-rw-r--r-- | src/cookie.c | 102 | ||||
-rw-r--r-- | src/cookie.h | 19 | ||||
-rw-r--r-- | src/device.c | 95 | ||||
-rw-r--r-- | src/device.h | 3 | ||||
-rw-r--r-- | src/hashtables.c | 105 | ||||
-rw-r--r-- | src/hashtables.h | 25 | ||||
-rw-r--r-- | src/main.c | 5 | ||||
-rw-r--r-- | src/messages.h | 14 | ||||
-rw-r--r-- | src/netlink.c | 215 | ||||
-rw-r--r-- | src/noise.c | 360 | ||||
-rw-r--r-- | src/noise.h | 33 | ||||
-rw-r--r-- | src/peer.c | 85 | ||||
-rw-r--r-- | src/peer.h | 21 | ||||
-rw-r--r-- | src/queueing.c | 14 | ||||
-rw-r--r-- | src/queueing.h | 58 | ||||
-rw-r--r-- | src/ratelimiter.c | 47 | ||||
-rw-r--r-- | src/receive.c | 257 | ||||
-rw-r--r-- | src/selftest/allowedips.h | 276 | ||||
-rw-r--r-- | src/selftest/counter.h | 25 | ||||
-rw-r--r-- | src/selftest/ratelimiter.h | 48 | ||||
-rw-r--r-- | src/send.c | 203 | ||||
-rw-r--r-- | src/socket.c | 108 | ||||
-rw-r--r-- | src/socket.h | 33 | ||||
-rw-r--r-- | src/timers.c | 122 | ||||
-rw-r--r-- | src/timers.h | 3 |
27 files changed, 1653 insertions, 794 deletions
diff --git a/src/allowedips.c b/src/allowedips.c index 1442bf4..4616645 100644 --- a/src/allowedips.c +++ b/src/allowedips.c @@ -28,7 +28,8 @@ static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits) } } -static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8 cidr, u8 bits) +static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, + u8 cidr, u8 bits) { node->cidr = cidr; node->bit_at_a = cidr / 8U; @@ -39,34 +40,43 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8 memcpy(node->bits, src, bits / 8U); } -#define choose_node(parent, key) parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] +#define choose_node(parent, key) \ + parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] static void node_free_rcu(struct rcu_head *rcu) { kfree(container_of(rcu, struct allowedips_node, rcu)); } -#define push_rcu(stack, p, len) ({ \ - if (rcu_access_pointer(p)) { \ - BUG_ON(len >= 128); \ - stack[len++] = rcu_dereference_raw(p); \ - } \ - true; \ -}) +#define push_rcu(stack, p, len) ({ \ + if (rcu_access_pointer(p)) { \ + BUG_ON(len >= 128); \ + stack[len++] = rcu_dereference_raw(p); \ + } \ + true; \ + }) static void root_free_rcu(struct rcu_head *rcu) { - struct allowedips_node *node, *stack[128] = { container_of(rcu, struct allowedips_node, rcu) }; + struct allowedips_node *node, *stack[128] = + { container_of(rcu, struct allowedips_node, rcu) }; unsigned int len = 1; - while (len > 0 && (node = stack[--len]) && push_rcu(stack, node->bit[0], len) && push_rcu(stack, node->bit[1], len)) + while (len > 0 && (node = stack[--len]) && + push_rcu(stack, node->bit[0], len) && + push_rcu(stack, node->bit[1], len)) kfree(node); } -static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock) +static int +walk_by_peer(struct allowedips_node __rcu *top, u8 bits, + struct allowedips_cursor *cursor, struct wireguard_peer *peer, + int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), + void *ctx, struct mutex *lock) { + const int address_family = bits == 32 ? AF_INET : AF_INET6; + u8 ip[16] __aligned(__alignof(u64)); struct allowedips_node *node; int ret; - u8 ip[16] __aligned(__alignof(u64)); if (!rcu_access_pointer(top)) return 0; @@ -74,16 +84,21 @@ static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allow if (!cursor->len) push_rcu(cursor->stack, top, cursor->len); - for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); --cursor->len, push_rcu(cursor->stack, node->bit[0], cursor->len), push_rcu(cursor->stack, node->bit[1], cursor->len)) { - if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) != peer) + for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); + --cursor->len, push_rcu(cursor->stack, node->bit[0], cursor->len), + push_rcu(cursor->stack, node->bit[1], cursor->len)) { + const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); + + if (rcu_dereference_protected(node->peer, + lockdep_is_held(lock)) != peer) continue; swap_endian(ip, node->bits, bits); - memset(ip + (node->cidr + 7U) / 8U, 0, (bits / 8U) - ((node->cidr + 7U) / 8U)); + memset(ip + cidr_bytes, 0, bits / 8U - cidr_bytes); if (node->cidr) - ip[(node->cidr + 7U) / 8U - 1U] &= ~0U << (-node->cidr % 8U); + ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); - ret = func(ctx, ip, node->cidr, bits == 32 ? AF_INET : AF_INET6); + ret = func(ctx, ip, node->cidr, address_family); if (ret) return ret; } @@ -93,8 +108,12 @@ static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allow #define ref(p) rcu_access_pointer(p) #define deref(p) rcu_dereference_protected(*p, lockdep_is_held(lock)) -#define push(p) ({ BUG_ON(len >= 128); stack[len++] = p; }) -static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireguard_peer *peer, struct mutex *lock) +#define push(p) ({ \ + BUG_ON(len >= 128); \ + stack[len++] = p; \ + }) +static void walk_remove_by_peer(struct allowedips_node __rcu **top, + struct wireguard_peer *peer, struct mutex *lock) { struct allowedips_node __rcu **stack[128], **nptr; struct allowedips_node *node, *prev; @@ -110,7 +129,8 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireg --len; continue; } - if (!prev || ref(prev->bit[0]) == node || ref(prev->bit[1]) == node) { + if (!prev || ref(prev->bit[0]) == node || + ref(prev->bit[1]) == node) { if (ref(node->bit[0])) push(&node->bit[0]); else if (ref(node->bit[1])) @@ -119,10 +139,12 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireg if (ref(node->bit[1])) push(&node->bit[1]); } else { - if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) == peer) { + if (rcu_dereference_protected(node->peer, + lockdep_is_held(lock)) == peer) { RCU_INIT_POINTER(node->peer, NULL); if (!node->bit[0] || !node->bit[1]) { - rcu_assign_pointer(*nptr, deref(&node->bit[!ref(node->bit[0])])); + rcu_assign_pointer(*nptr, + deref(&node->bit[!ref(node->bit[0])])); call_rcu_bh(&node->rcu, node_free_rcu); node = deref(nptr); } @@ -140,23 +162,29 @@ static __always_inline unsigned int fls128(u64 a, u64 b) return a ? fls64(a) + 64U : fls64(b); } -static __always_inline u8 common_bits(const struct allowedips_node *node, const u8 *key, u8 bits) +static __always_inline u8 common_bits(const struct allowedips_node *node, + const u8 *key, u8 bits) { if (bits == 32) return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key); else if (bits == 128) - return 128U - fls128(*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]); + return 128U - fls128( + *(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], + *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]); return 0; } -/* This could be much faster if it actually just compared the common bits properly, - * by precomputing a mask bswap(~0 << (32 - cidr)), and the rest, but it turns out that - * common_bits is already super fast on modern processors, even taking into account - * the unfortunate bswap. So, we just inline it like this instead. +/* This could be much faster if it actually just compared the common bits + * properly, by precomputing a mask bswap(~0 << (32 - cidr)), and the rest, but + * it turns out that common_bits is already super fast on modern processors, + * even taking into account the unfortunate bswap. So, we just inline it like + * this instead. */ -#define prefix_matches(node, key, bits) (common_bits(node, key, bits) >= node->cidr) +#define prefix_matches(node, key, bits) \ + (common_bits(node, key, bits) >= node->cidr) -static __always_inline struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits, const u8 *key) +static __always_inline struct allowedips_node * +find_node(struct allowedips_node *trie, u8 bits, const u8 *key) { struct allowedips_node *node = trie, *found = NULL; @@ -171,11 +199,12 @@ static __always_inline struct allowedips_node *find_node(struct allowedips_node } /* Returns a strong reference to a peer */ -static __always_inline struct wireguard_peer *lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip) +static __always_inline struct wireguard_peer * +lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip) { + u8 ip[16] __aligned(__alignof(u64)); struct wireguard_peer *peer = NULL; struct allowedips_node *node; - u8 ip[16] __aligned(__alignof(u64)); swap_endian(ip, be_ip, bits); @@ -191,11 +220,14 @@ retry: return peer; } -__attribute__((nonnull(1))) -static inline bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr, u8 bits, struct allowedips_node **rnode, struct mutex *lock) +__attribute__((nonnull(1))) static inline bool +node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr, + u8 bits, struct allowedips_node **rnode, struct mutex *lock) { + struct allowedips_node *node = rcu_dereference_protected(trie, + lockdep_is_held(lock)); + struct allowedips_node *parent = NULL; bool exact = false; - struct allowedips_node *parent = NULL, *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) { parent = node; @@ -203,13 +235,15 @@ static inline bool node_placement(struct allowedips_node __rcu *trie, const u8 * exact = true; break; } - node = rcu_dereference_protected(choose_node(parent, key), lockdep_is_held(lock)); + node = rcu_dereference_protected(choose_node(parent, key), + lockdep_is_held(lock)); } *rnode = parent; return exact; } -static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) +static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, + u8 cidr, struct wireguard_peer *peer, struct mutex *lock) { struct allowedips_node *node, *parent, *down, *newnode; u8 key[16] __aligned(__alignof(u64)); @@ -242,7 +276,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u if (!node) down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); else { - down = rcu_dereference_protected(choose_node(node, key), lockdep_is_held(lock)); + down = rcu_dereference_protected(choose_node(node, key), + lockdep_is_held(lock)); if (!down) { rcu_assign_pointer(choose_node(node, key), newnode); return 0; @@ -256,7 +291,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u if (!parent) rcu_assign_pointer(*trie, newnode); else - rcu_assign_pointer(choose_node(parent, newnode->bits), newnode); + rcu_assign_pointer(choose_node(parent, newnode->bits), + newnode); } else { node = kzalloc(sizeof(*node), GFP_KERNEL); if (!node) { @@ -270,7 +306,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u if (!parent) rcu_assign_pointer(*trie, node); else - rcu_assign_pointer(choose_node(parent, node->bits), node); + rcu_assign_pointer(choose_node(parent, node->bits), + node); } return 0; } @@ -288,31 +325,42 @@ void allowedips_free(struct allowedips *table, struct mutex *lock) RCU_INIT_POINTER(table->root4, NULL); RCU_INIT_POINTER(table->root6, NULL); if (rcu_access_pointer(old4)) - call_rcu_bh(&rcu_dereference_protected(old4, lockdep_is_held(lock))->rcu, root_free_rcu); + call_rcu_bh(&rcu_dereference_protected(old4, + lockdep_is_held(lock))->rcu, root_free_rcu); if (rcu_access_pointer(old6)) - call_rcu_bh(&rcu_dereference_protected(old6, lockdep_is_held(lock))->rcu, root_free_rcu); + call_rcu_bh(&rcu_dereference_protected(old6, + lockdep_is_held(lock))->rcu, root_free_rcu); } -int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) +int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, + u8 cidr, struct wireguard_peer *peer, + struct mutex *lock) { ++table->seq; return add(&table->root4, 32, (const u8 *)ip, cidr, peer, lock); } -int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) +int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, + u8 cidr, struct wireguard_peer *peer, + struct mutex *lock) { ++table->seq; return add(&table->root6, 128, (const u8 *)ip, cidr, peer, lock); } -void allowedips_remove_by_peer(struct allowedips *table, struct wireguard_peer *peer, struct mutex *lock) +void allowedips_remove_by_peer(struct allowedips *table, + struct wireguard_peer *peer, struct mutex *lock) { ++table->seq; walk_remove_by_peer(&table->root4, peer, lock); walk_remove_by_peer(&table->root6, peer, lock); } -int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock) +int allowedips_walk_by_peer(struct allowedips *table, + struct allowedips_cursor *cursor, + struct wireguard_peer *peer, + int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), + void *ctx, struct mutex *lock) { int ret; @@ -332,7 +380,8 @@ int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor * } /* Returns a strong reference to a peer */ -struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk_buff *skb) +struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, + struct sk_buff *skb) { if (skb->protocol == htons(ETH_P_IP)) return lookup(table->root4, 32, &ip_hdr(skb)->daddr); @@ -342,7 +391,8 @@ struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk } /* Returns a strong reference to a peer */ -struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, struct sk_buff *skb) +struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, + struct sk_buff *skb) { if (skb->protocol == htons(ETH_P_IP)) return lookup(table->root4, 32, &ip_hdr(skb)->saddr); diff --git a/src/allowedips.h b/src/allowedips.h index 97ecf69..d5ba1be 100644 --- a/src/allowedips.h +++ b/src/allowedips.h @@ -28,14 +28,25 @@ struct allowedips_cursor { void allowedips_init(struct allowedips *table); void allowedips_free(struct allowedips *table, struct mutex *mutex); -int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock); -int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock); -void allowedips_remove_by_peer(struct allowedips *table, struct wireguard_peer *peer, struct mutex *lock); -int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock); +int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, + u8 cidr, struct wireguard_peer *peer, + struct mutex *lock); +int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, + u8 cidr, struct wireguard_peer *peer, + struct mutex *lock); +void allowedips_remove_by_peer(struct allowedips *table, + struct wireguard_peer *peer, struct mutex *lock); +int allowedips_walk_by_peer(struct allowedips *table, + struct allowedips_cursor *cursor, + struct wireguard_peer *peer, + int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), + void *ctx, struct mutex *lock); /* These return a strong reference to a peer: */ -struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk_buff *skb); -struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, struct sk_buff *skb); +struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, + struct sk_buff *skb); +struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, + struct sk_buff *skb); #ifdef DEBUG bool allowedips_selftest(void); diff --git a/src/cookie.c b/src/cookie.c index 9268630..7cf0693 100644 --- a/src/cookie.c +++ b/src/cookie.c @@ -15,7 +15,8 @@ #include <net/ipv6.h> #include <crypto/algapi.h> -void cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg) +void cookie_checker_init(struct cookie_checker *checker, + struct wireguard_device *wg) { init_rwsem(&checker->secret_lock); checker->secret_birthdate = ktime_get_boot_fast_ns(); @@ -27,7 +28,9 @@ enum { COOKIE_KEY_LABEL_LEN = 8 }; static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----"; static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--"; -static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 pubkey[NOISE_PUBLIC_KEY_LEN], const u8 label[COOKIE_KEY_LABEL_LEN]) +static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], + const u8 pubkey[NOISE_PUBLIC_KEY_LEN], + const u8 label[COOKIE_KEY_LABEL_LEN]) { struct blake2s_state blake; @@ -41,18 +44,25 @@ static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 pubkey[NOIS void cookie_checker_precompute_device_keys(struct cookie_checker *checker) { if (likely(checker->device->static_identity.has_identity)) { - precompute_key(checker->cookie_encryption_key, checker->device->static_identity.static_public, cookie_key_label); - precompute_key(checker->message_mac1_key, checker->device->static_identity.static_public, mac1_key_label); + precompute_key(checker->cookie_encryption_key, + checker->device->static_identity.static_public, + cookie_key_label); + precompute_key(checker->message_mac1_key, + checker->device->static_identity.static_public, + mac1_key_label); } else { - memset(checker->cookie_encryption_key, 0, NOISE_SYMMETRIC_KEY_LEN); + memset(checker->cookie_encryption_key, 0, + NOISE_SYMMETRIC_KEY_LEN); memset(checker->message_mac1_key, 0, NOISE_SYMMETRIC_KEY_LEN); } } void cookie_checker_precompute_peer_keys(struct wireguard_peer *peer) { - precompute_key(peer->latest_cookie.cookie_decryption_key, peer->handshake.remote_static, cookie_key_label); - precompute_key(peer->latest_cookie.message_mac1_key, peer->handshake.remote_static, mac1_key_label); + precompute_key(peer->latest_cookie.cookie_decryption_key, + peer->handshake.remote_static, cookie_key_label); + precompute_key(peer->latest_cookie.message_mac1_key, + peer->handshake.remote_static, mac1_key_label); } void cookie_init(struct cookie *cookie) @@ -61,19 +71,24 @@ void cookie_init(struct cookie *cookie) init_rwsem(&cookie->lock); } -static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len, const u8 key[NOISE_SYMMETRIC_KEY_LEN]) +static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len, + const u8 key[NOISE_SYMMETRIC_KEY_LEN]) { - len = len - sizeof(struct message_macs) + offsetof(struct message_macs, mac1); + len = len - sizeof(struct message_macs) + + offsetof(struct message_macs, mac1); blake2s(mac1, message, key, COOKIE_LEN, len, NOISE_SYMMETRIC_KEY_LEN); } -static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len, const u8 cookie[COOKIE_LEN]) +static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len, + const u8 cookie[COOKIE_LEN]) { - len = len - sizeof(struct message_macs) + offsetof(struct message_macs, mac2); + len = len - sizeof(struct message_macs) + + offsetof(struct message_macs, mac2); blake2s(mac2, message, cookie, COOKIE_LEN, len, COOKIE_LEN); } -static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, struct cookie_checker *checker) +static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, + struct cookie_checker *checker) { struct blake2s_state state; @@ -88,24 +103,30 @@ static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, struct cooki blake2s_init_key(&state, COOKIE_LEN, checker->secret, NOISE_HASH_LEN); if (skb->protocol == htons(ETH_P_IP)) - blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr, sizeof(struct in_addr)); + blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr, + sizeof(struct in_addr)); else if (skb->protocol == htons(ETH_P_IPV6)) - blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr, sizeof(struct in6_addr)); + blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr, + sizeof(struct in6_addr)); blake2s_update(&state, (u8 *)&udp_hdr(skb)->source, sizeof(__be16)); blake2s_final(&state, cookie, COOKIE_LEN); up_read(&checker->secret_lock); } -enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie) +enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, + struct sk_buff *skb, + bool check_cookie) { + struct message_macs *macs = (struct message_macs *) + (skb->data + skb->len - sizeof(struct message_macs)); + enum cookie_mac_state ret; u8 computed_mac[COOKIE_LEN]; u8 cookie[COOKIE_LEN]; - enum cookie_mac_state ret; - struct message_macs *macs = (struct message_macs *)(skb->data + skb->len - sizeof(struct message_macs)); ret = INVALID_MAC; - compute_mac1(computed_mac, skb->data, skb->len, checker->message_mac1_key); + compute_mac1(computed_mac, skb->data, skb->len, + checker->message_mac1_key); if (crypto_memneq(computed_mac, macs->mac1, COOKIE_LEN)) goto out; @@ -130,27 +151,36 @@ out: return ret; } -void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer *peer) +void cookie_add_mac_to_packet(void *message, size_t len, + struct wireguard_peer *peer) { - struct message_macs *macs = (struct message_macs *)((u8 *)message + len - sizeof(struct message_macs)); + struct message_macs *macs = (struct message_macs *) + ((u8 *)message + len - sizeof(struct message_macs)); down_write(&peer->latest_cookie.lock); - compute_mac1(macs->mac1, message, len, peer->latest_cookie.message_mac1_key); + compute_mac1(macs->mac1, message, len, + peer->latest_cookie.message_mac1_key); memcpy(peer->latest_cookie.last_mac1_sent, macs->mac1, COOKIE_LEN); peer->latest_cookie.have_sent_mac1 = true; up_write(&peer->latest_cookie.lock); down_read(&peer->latest_cookie.lock); - if (peer->latest_cookie.is_valid && !has_expired(peer->latest_cookie.birthdate, COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY)) - compute_mac2(macs->mac2, message, len, peer->latest_cookie.cookie); + if (peer->latest_cookie.is_valid && + !has_expired(peer->latest_cookie.birthdate, + COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY)) + compute_mac2(macs->mac2, message, len, + peer->latest_cookie.cookie); else memset(macs->mac2, 0, COOKIE_LEN); up_read(&peer->latest_cookie.lock); } -void cookie_message_create(struct message_handshake_cookie *dst, struct sk_buff *skb, __le32 index, struct cookie_checker *checker) +void cookie_message_create(struct message_handshake_cookie *dst, + struct sk_buff *skb, __le32 index, + struct cookie_checker *checker) { - struct message_macs *macs = (struct message_macs *)((u8 *)skb->data + skb->len - sizeof(struct message_macs)); + struct message_macs *macs = (struct message_macs *) + ((u8 *)skb->data + skb->len - sizeof(struct message_macs)); u8 cookie[COOKIE_LEN]; dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE); @@ -158,16 +188,22 @@ void cookie_message_create(struct message_handshake_cookie *dst, struct sk_buff get_random_bytes_wait(dst->nonce, COOKIE_NONCE_LEN); make_cookie(cookie, skb, checker); - xchacha20poly1305_encrypt(dst->encrypted_cookie, cookie, COOKIE_LEN, macs->mac1, COOKIE_LEN, dst->nonce, checker->cookie_encryption_key); + xchacha20poly1305_encrypt(dst->encrypted_cookie, cookie, COOKIE_LEN, + macs->mac1, COOKIE_LEN, dst->nonce, + checker->cookie_encryption_key); } -void cookie_message_consume(struct message_handshake_cookie *src, struct wireguard_device *wg) +void cookie_message_consume(struct message_handshake_cookie *src, + struct wireguard_device *wg) { - u8 cookie[COOKIE_LEN]; struct wireguard_peer *peer = NULL; + u8 cookie[COOKIE_LEN]; bool ret; - if (unlikely(!index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE | INDEX_HASHTABLE_KEYPAIR, src->receiver_index, &peer))) + if (unlikely(!index_hashtable_lookup(&wg->index_hashtable, + INDEX_HASHTABLE_HANDSHAKE | + INDEX_HASHTABLE_KEYPAIR, + src->receiver_index, &peer))) return; down_read(&peer->latest_cookie.lock); @@ -175,7 +211,10 @@ void cookie_message_consume(struct message_handshake_cookie *src, struct wiregua up_read(&peer->latest_cookie.lock); goto out; } - ret = xchacha20poly1305_decrypt(cookie, src->encrypted_cookie, sizeof(src->encrypted_cookie), peer->latest_cookie.last_mac1_sent, COOKIE_LEN, src->nonce, peer->latest_cookie.cookie_decryption_key); + ret = xchacha20poly1305_decrypt( + cookie, src->encrypted_cookie, sizeof(src->encrypted_cookie), + peer->latest_cookie.last_mac1_sent, COOKIE_LEN, src->nonce, + peer->latest_cookie.cookie_decryption_key); up_read(&peer->latest_cookie.lock); if (ret) { @@ -186,7 +225,8 @@ void cookie_message_consume(struct message_handshake_cookie *src, struct wiregua peer->latest_cookie.have_sent_mac1 = false; up_write(&peer->latest_cookie.lock); } else - net_dbg_ratelimited("%s: Could not decrypt invalid cookie response\n", wg->dev->name); + net_dbg_ratelimited("%s: Could not decrypt invalid cookie response\n", + wg->dev->name); out: peer_put(peer); diff --git a/src/cookie.h b/src/cookie.h index 9f519ef..7802f61 100644 --- a/src/cookie.h +++ b/src/cookie.h @@ -38,15 +38,22 @@ enum cookie_mac_state { VALID_MAC_WITH_COOKIE }; -void cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg); +void cookie_checker_init(struct cookie_checker *checker, + struct wireguard_device *wg); void cookie_checker_precompute_device_keys(struct cookie_checker *checker); void cookie_checker_precompute_peer_keys(struct wireguard_peer *peer); void cookie_init(struct cookie *cookie); -enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie); -void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer *peer); - -void cookie_message_create(struct message_handshake_cookie *src, struct sk_buff *skb, __le32 index, struct cookie_checker *checker); -void cookie_message_consume(struct message_handshake_cookie *src, struct wireguard_device *wg); +enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, + struct sk_buff *skb, + bool check_cookie); +void cookie_add_mac_to_packet(void *message, size_t len, + struct wireguard_peer *peer); + +void cookie_message_create(struct message_handshake_cookie *src, + struct sk_buff *skb, __le32 index, + struct cookie_checker *checker); +void cookie_message_consume(struct message_handshake_cookie *src, + struct wireguard_device *wg); #endif /* _WG_COOKIE_H */ diff --git a/src/device.c b/src/device.c index 297125a..3dfa794 100644 --- a/src/device.c +++ b/src/device.c @@ -28,19 +28,18 @@ static LIST_HEAD(device_list); static int open(struct net_device *dev) { - int ret; - struct wireguard_peer *peer; + struct in_device *dev_v4 = __in_dev_get_rtnl(dev); struct wireguard_device *wg = netdev_priv(dev); #ifndef COMPAT_CANNOT_USE_IN6_DEV_GET struct inet6_dev *dev_v6 = __in6_dev_get(dev); #endif - struct in_device *dev_v4 = __in_dev_get_rtnl(dev); + struct wireguard_peer *peer; + int ret; if (dev_v4) { - /* TODO: at some point we might put this check near the ip_rt_send_redirect - * call of ip_forward in net/ipv4/ip_forward.c, similar to the current secpath - * check, rather than turning it off like this. This is just a stop gap solution - * while we're an out of tree module. + /* At some point we might put this check near the ip_rt_send_ + * redirect call of ip_forward in net/ipv4/ip_forward.c, similar + * to the current secpath check. */ IN_DEV_CONF_SET(dev_v4, SEND_REDIRECTS, false); IPV4_DEVCONF_ALL(dev_net(dev), SEND_REDIRECTS) = false; @@ -58,7 +57,7 @@ static int open(struct net_device *dev) if (ret < 0) return ret; mutex_lock(&wg->device_update_lock); - list_for_each_entry(peer, &wg->peer_list, peer_list) { + list_for_each_entry (peer, &wg->peer_list, peer_list) { packet_send_staged_packets(peer); if (peer->persistent_keepalive_interval) packet_send_keepalive(peer); @@ -68,7 +67,8 @@ static int open(struct net_device *dev) } #if defined(CONFIG_PM_SLEEP) && !defined(CONFIG_ANDROID) -static int pm_notification(struct notifier_block *nb, unsigned long action, void *data) +static int pm_notification(struct notifier_block *nb, unsigned long action, + void *data) { struct wireguard_device *wg; struct wireguard_peer *peer; @@ -77,9 +77,9 @@ static int pm_notification(struct notifier_block *nb, unsigned long action, void return 0; rtnl_lock(); - list_for_each_entry(wg, &device_list, device_list) { + list_for_each_entry (wg, &device_list, device_list) { mutex_lock(&wg->device_update_lock); - list_for_each_entry(peer, &wg->peer_list, peer_list) { + list_for_each_entry (peer, &wg->peer_list, peer_list) { noise_handshake_clear(&peer->handshake); noise_keypairs_clear(&peer->keypairs); if (peer->timers_enabled) @@ -100,12 +100,14 @@ static int stop(struct net_device *dev) struct wireguard_peer *peer; mutex_lock(&wg->device_update_lock); - list_for_each_entry(peer, &wg->peer_list, peer_list) { + list_for_each_entry (peer, &wg->peer_list, peer_list) { skb_queue_purge(&peer->staged_packet_queue); timers_stop(peer); noise_handshake_clear(&peer->handshake); noise_keypairs_clear(&peer->keypairs); - atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns() - (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC); + atomic64_set(&peer->last_sent_handshake, + ktime_get_boot_fast_ns() - + (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC); } mutex_unlock(&wg->device_update_lock); skb_queue_purge(&wg->incoming_handshakes); @@ -133,16 +135,19 @@ static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev) if (unlikely(!peer)) { ret = -ENOKEY; if (skb->protocol == htons(ETH_P_IP)) - net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI4\n", dev->name, &ip_hdr(skb)->daddr); + net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI4\n", + dev->name, &ip_hdr(skb)->daddr); else if (skb->protocol == htons(ETH_P_IPV6)) - net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI6\n", dev->name, &ipv6_hdr(skb)->daddr); + net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI6\n", + dev->name, &ipv6_hdr(skb)->daddr); goto err; } family = READ_ONCE(peer->endpoint.addr.sa_family); if (unlikely(family != AF_INET && family != AF_INET6)) { ret = -EDESTADDRREQ; - net_dbg_ratelimited("%s: No valid endpoint has been configured or discovered for peer %llu\n", dev->name, peer->internal_id); + net_dbg_ratelimited("%s: No valid endpoint has been configured or discovered for peer %llu\n", + dev->name, peer->internal_id); goto err_peer; } @@ -180,8 +185,9 @@ static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev) } while ((skb = next) != NULL); spin_lock_bh(&peer->staged_packet_queue.lock); - /* If the queue is getting too big, we start removing the oldest packets until it's small again. - * We do this before adding the new packet, so we don't remove GSO segments that are in excess. + /* If the queue is getting too big, we start removing the oldest packets + * until it's small again. We do this before adding the new packet, so + * we don't remove GSO segments that are in excess. */ while (skb_queue_len(&peer->staged_packet_queue) > MAX_STAGED_PACKETS) dev_kfree_skb(__skb_dequeue(&peer->staged_packet_queue)); @@ -223,7 +229,8 @@ static void destruct(struct net_device *dev) wg->incoming_port = 0; socket_reinit(wg, NULL, NULL); allowedips_free(&wg->peer_allowedips, &wg->device_update_lock); - peer_remove_all(wg); /* The final references are cleared in the below calls to destroy_workqueue. */ + /* The final references are cleared in the below calls to destroy_workqueue. */ + peer_remove_all(wg); destroy_workqueue(wg->handshake_receive_wq); destroy_workqueue(wg->handshake_send_wq); destroy_workqueue(wg->packet_crypt_wq); @@ -231,7 +238,8 @@ static void destruct(struct net_device *dev) packet_queue_free(&wg->encrypt_queue, true); rcu_barrier_bh(); /* Wait for all the peers to be actually freed. */ ratelimiter_uninit(); - memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity)); + memzero_explicit(&wg->static_identity, + sizeof(struct noise_static_identity)); skb_queue_purge(&wg->incoming_handshakes); free_percpu(dev->tstats); free_percpu(wg->incoming_handshakes_worker); @@ -243,14 +251,14 @@ static void destruct(struct net_device *dev) free_netdev(dev); } -static const struct device_type device_type = { - .name = KBUILD_MODNAME -}; +static const struct device_type device_type = { .name = KBUILD_MODNAME }; static void setup(struct net_device *dev) { struct wireguard_device *wg = netdev_priv(dev); - enum { WG_NETDEV_FEATURES = NETIF_F_HW_CSUM | NETIF_F_RXCSUM | NETIF_F_SG | NETIF_F_GSO | NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA }; + enum { WG_NETDEV_FEATURES = NETIF_F_HW_CSUM | NETIF_F_RXCSUM | + NETIF_F_SG | NETIF_F_GSO | + NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA }; dev->netdev_ops = &netdev_ops; dev->hard_header_len = 0; @@ -268,7 +276,9 @@ static void setup(struct net_device *dev) dev->features |= WG_NETDEV_FEATURES; dev->hw_features |= WG_NETDEV_FEATURES; dev->hw_enc_features |= WG_NETDEV_FEATURES; - dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH - sizeof(struct udphdr) - max(sizeof(struct ipv6hdr), sizeof(struct iphdr)); + dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH - + sizeof(struct udphdr) - + max(sizeof(struct ipv6hdr), sizeof(struct iphdr)); SET_NETDEV_DEVTYPE(dev, &device_type); @@ -279,7 +289,9 @@ static void setup(struct net_device *dev) wg->dev = dev; } -static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *tb[], struct nlattr *data[], struct netlink_ext_ack *extack) +static int newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) { int ret = -ENOMEM; struct wireguard_device *wg = netdev_priv(dev); @@ -300,26 +312,32 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t if (!dev->tstats) goto error_1; - wg->incoming_handshakes_worker = packet_alloc_percpu_multicore_worker(packet_handshake_receive_worker, wg); + wg->incoming_handshakes_worker = packet_alloc_percpu_multicore_worker( + packet_handshake_receive_worker, wg); if (!wg->incoming_handshakes_worker) goto error_2; - wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s", WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name); + wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s", + WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name); if (!wg->handshake_receive_wq) goto error_3; - wg->handshake_send_wq = alloc_workqueue("wg-kex-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); + wg->handshake_send_wq = alloc_workqueue("wg-kex-%s", + WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); if (!wg->handshake_send_wq) goto error_4; - wg->packet_crypt_wq = alloc_workqueue("wg-crypt-%s", WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, dev->name); + wg->packet_crypt_wq = alloc_workqueue("wg-crypt-%s", + WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, dev->name); if (!wg->packet_crypt_wq) goto error_5; - if (packet_queue_init(&wg->encrypt_queue, packet_encrypt_worker, true, MAX_QUEUED_PACKETS) < 0) + if (packet_queue_init(&wg->encrypt_queue, packet_encrypt_worker, true, + MAX_QUEUED_PACKETS) < 0) goto error_6; - if (packet_queue_init(&wg->decrypt_queue, packet_decrypt_worker, true, MAX_QUEUED_PACKETS) < 0) + if (packet_queue_init(&wg->decrypt_queue, packet_decrypt_worker, true, + MAX_QUEUED_PACKETS) < 0) goto error_7; ret = ratelimiter_init(); @@ -332,8 +350,8 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t list_add(&wg->device_list, &device_list); - /* We wait until the end to assign priv_destructor, so that register_netdevice doesn't - * call it for us if it fails. + /* We wait until the end to assign priv_destructor, so that + * register_netdevice doesn't call it for us if it fails. */ dev->priv_destructor = destruct; @@ -367,7 +385,8 @@ static struct rtnl_link_ops link_ops __read_mostly = { .newlink = newlink, }; -static int netdevice_notification(struct notifier_block *nb, unsigned long action, void *data) +static int netdevice_notification(struct notifier_block *nb, + unsigned long action, void *data) { struct net_device *dev = ((struct netdev_notifier_info *)data)->dev; struct wireguard_device *wg = netdev_priv(dev); @@ -380,14 +399,16 @@ static int netdevice_notification(struct notifier_block *nb, unsigned long actio if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) { put_net(wg->creating_net); wg->have_creating_net_ref = false; - } else if (dev_net(dev) != wg->creating_net && !wg->have_creating_net_ref) { + } else if (dev_net(dev) != wg->creating_net && + !wg->have_creating_net_ref) { wg->have_creating_net_ref = true; get_net(wg->creating_net); } return 0; } -static struct notifier_block netdevice_notifier = { .notifier_call = netdevice_notification }; +static struct notifier_block netdevice_notifier = + { .notifier_call = netdevice_notification }; int __init device_init(void) { diff --git a/src/device.h b/src/device.h index 2a0e2c7..2499782 100644 --- a/src/device.h +++ b/src/device.h @@ -42,7 +42,8 @@ struct wireguard_device { struct sock __rcu *sock4, *sock6; struct net *creating_net; struct noise_static_identity static_identity; - struct workqueue_struct *handshake_receive_wq, *handshake_send_wq, *packet_crypt_wq; + struct workqueue_struct *handshake_receive_wq, *handshake_send_wq; + struct workqueue_struct *packet_crypt_wq; struct sk_buff_head incoming_handshakes; int incoming_handshake_cpu; struct multicore_worker __percpu *incoming_handshakes_worker; diff --git a/src/hashtables.c b/src/hashtables.c index ac6df59..4ba2288 100644 --- a/src/hashtables.c +++ b/src/hashtables.c @@ -7,12 +7,16 @@ #include "peer.h" #include "noise.h" -static inline struct hlist_head *pubkey_bucket(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN]) +static inline struct hlist_head *pubkey_bucket(struct pubkey_hashtable *table, + const u8 pubkey[NOISE_PUBLIC_KEY_LEN]) { - /* siphash gives us a secure 64bit number based on a random key. Since the bits are - * uniformly distributed, we can then mask off to get the bits we need. + /* siphash gives us a secure 64bit number based on a random key. Since + * the bits are uniformly distributed, we can then mask off to get the + * bits we need. */ - return &table->hashtable[siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key) & (HASH_SIZE(table->hashtable) - 1)]; + return &table->hashtable[ + siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key) & + (HASH_SIZE(table->hashtable) - 1)]; } void pubkey_hashtable_init(struct pubkey_hashtable *table) @@ -22,14 +26,17 @@ void pubkey_hashtable_init(struct pubkey_hashtable *table) mutex_init(&table->lock); } -void pubkey_hashtable_add(struct pubkey_hashtable *table, struct wireguard_peer *peer) +void pubkey_hashtable_add(struct pubkey_hashtable *table, + struct wireguard_peer *peer) { mutex_lock(&table->lock); - hlist_add_head_rcu(&peer->pubkey_hash, pubkey_bucket(table, peer->handshake.remote_static)); + hlist_add_head_rcu(&peer->pubkey_hash, + pubkey_bucket(table, peer->handshake.remote_static)); mutex_unlock(&table->lock); } -void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_peer *peer) +void pubkey_hashtable_remove(struct pubkey_hashtable *table, + struct wireguard_peer *peer) { mutex_lock(&table->lock); hlist_del_init_rcu(&peer->pubkey_hash); @@ -37,13 +44,17 @@ void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_pe } /* Returns a strong reference to a peer */ -struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN]) +struct wireguard_peer * +pubkey_hashtable_lookup(struct pubkey_hashtable *table, + const u8 pubkey[NOISE_PUBLIC_KEY_LEN]) { struct wireguard_peer *iter_peer, *peer = NULL; rcu_read_lock_bh(); - hlist_for_each_entry_rcu_bh(iter_peer, pubkey_bucket(table, pubkey), pubkey_hash) { - if (!memcmp(pubkey, iter_peer->handshake.remote_static, NOISE_PUBLIC_KEY_LEN)) { + hlist_for_each_entry_rcu_bh (iter_peer, pubkey_bucket(table, pubkey), + pubkey_hash) { + if (!memcmp(pubkey, iter_peer->handshake.remote_static, + NOISE_PUBLIC_KEY_LEN)) { peer = iter_peer; break; } @@ -53,12 +64,14 @@ struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, c return peer; } -static inline struct hlist_head *index_bucket(struct index_hashtable *table, const __le32 index) +static inline struct hlist_head *index_bucket(struct index_hashtable *table, + const __le32 index) { - /* Since the indices are random and thus all bits are uniformly distributed, - * we can find its bucket simply by masking. + /* Since the indices are random and thus all bits are uniformly + * distributed, we can find its bucket simply by masking. */ - return &table->hashtable[(__force u32)index & (HASH_SIZE(table->hashtable) - 1)]; + return &table->hashtable[(__force u32)index & + (HASH_SIZE(table->hashtable) - 1)]; } void index_hashtable_init(struct index_hashtable *table) @@ -67,9 +80,10 @@ void index_hashtable_init(struct index_hashtable *table) spin_lock_init(&table->lock); } -/* At the moment, we limit ourselves to 2^20 total peers, which generally might amount to 2^20*3 - * items in this hashtable. The algorithm below works by picking a random number and testing it. - * We can see that these limits mean we usually succeed pretty quickly: +/* At the moment, we limit ourselves to 2^20 total peers, which generally might + * amount to 2^20*3 items in this hashtable. The algorithm below works by + * picking a random number and testing it. We can see that these limits mean we + * usually succeed pretty quickly: * * >>> def calculation(tries, size): * ... return (size / 2**32)**(tries - 1) * (1 - (size / 2**32)) @@ -83,13 +97,15 @@ void index_hashtable_init(struct index_hashtable *table) * >>> calculation(4, 2**20 * 3) * 3.9261394135792216e-10 * - * At the moment, we don't do any masking, so this algorithm isn't exactly constant time in - * either the random guessing or in the hash list lookup. We could require a minimum of 3 - * tries, which would successfully mask the guessing. TODO: this would not, however, help - * with the growing hash lengths. + * At the moment, we don't do any masking, so this algorithm isn't exactly + * constant time in either the random guessing or in the hash list lookup. We + * could require a minimum of 3 tries, which would successfully mask the + * guessing. this would not, however, help with the growing hash lengths, which + * is another thing to consider moving forward. */ -__le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry) +__le32 index_hashtable_insert(struct index_hashtable *table, + struct index_hashtable_entry *entry) { struct index_hashtable_entry *existing_entry; @@ -102,23 +118,32 @@ __le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashta search_unused_slot: /* First we try to find an unused slot, randomly, while unlocked. */ entry->index = (__force __le32)get_random_u32(); - hlist_for_each_entry_rcu_bh(existing_entry, index_bucket(table, entry->index), index_hash) { + hlist_for_each_entry_rcu_bh (existing_entry, + index_bucket(table, entry->index), + index_hash) { if (existing_entry->index == entry->index) - goto search_unused_slot; /* If it's already in use, we continue searching. */ + /* If it's already in use, we continue searching. */ + goto search_unused_slot; } /* Once we've found an unused slot, we lock it, and then double-check * that nobody else stole it from us. */ spin_lock_bh(&table->lock); - hlist_for_each_entry_rcu_bh(existing_entry, index_bucket(table, entry->index), index_hash) { + hlist_for_each_entry_rcu_bh (existing_entry, + index_bucket(table, entry->index), + index_hash) { if (existing_entry->index == entry->index) { spin_unlock_bh(&table->lock); - goto search_unused_slot; /* If it was stolen, we start over. */ + /* If it was stolen, we start over. */ + goto search_unused_slot; } } - /* Otherwise, we know we have it exclusively (since we're locked), so we insert. */ - hlist_add_head_rcu(&entry->index_hash, index_bucket(table, entry->index)); + /* Otherwise, we know we have it exclusively (since we're locked), + * so we insert. + */ + hlist_add_head_rcu(&entry->index_hash, + index_bucket(table, entry->index)); spin_unlock_bh(&table->lock); rcu_read_unlock_bh(); @@ -126,7 +151,9 @@ search_unused_slot: return entry->index; } -bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new) +bool index_hashtable_replace(struct index_hashtable *table, + struct index_hashtable_entry *old, + struct index_hashtable_entry *new) { if (unlikely(hlist_unhashed(&old->index_hash))) return false; @@ -134,17 +161,19 @@ bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtab new->index = old->index; hlist_replace_rcu(&old->index_hash, &new->index_hash); - /* Calling init here NULLs out index_hash, and in fact after this function returns, - * it's theoretically possible for this to get reinserted elsewhere. That means - * the RCU lookup below might either terminate early or jump between buckets, in which - * case the packet simply gets dropped, which isn't terrible. + /* Calling init here NULLs out index_hash, and in fact after this + * function returns, it's theoretically possible for this to get + * reinserted elsewhere. That means the RCU lookup below might either + * terminate early or jump between buckets, in which case the packet + * simply gets dropped, which isn't terrible. */ INIT_HLIST_NODE(&old->index_hash); spin_unlock_bh(&table->lock); return true; } -void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry) +void index_hashtable_remove(struct index_hashtable *table, + struct index_hashtable_entry *entry) { spin_lock_bh(&table->lock); hlist_del_init_rcu(&entry->index_hash); @@ -152,12 +181,16 @@ void index_hashtable_remove(struct index_hashtable *table, struct index_hashtabl } /* Returns a strong reference to a entry->peer */ -struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer) +struct index_hashtable_entry * +index_hashtable_lookup(struct index_hashtable *table, + const enum index_hashtable_type type_mask, + const __le32 index, struct wireguard_peer **peer) { struct index_hashtable_entry *iter_entry, *entry = NULL; rcu_read_lock_bh(); - hlist_for_each_entry_rcu_bh(iter_entry, index_bucket(table, index), index_hash) { + hlist_for_each_entry_rcu_bh (iter_entry, index_bucket(table, index), + index_hash) { if (iter_entry->index == index) { if (likely(iter_entry->type & type_mask)) entry = iter_entry; diff --git a/src/hashtables.h b/src/hashtables.h index f64cd24..62858c5 100644 --- a/src/hashtables.h +++ b/src/hashtables.h @@ -22,9 +22,13 @@ struct pubkey_hashtable { }; void pubkey_hashtable_init(struct pubkey_hashtable *table); -void pubkey_hashtable_add(struct pubkey_hashtable *table, struct wireguard_peer *peer); -void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_peer *peer); -struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN]); +void pubkey_hashtable_add(struct pubkey_hashtable *table, + struct wireguard_peer *peer); +void pubkey_hashtable_remove(struct pubkey_hashtable *table, + struct wireguard_peer *peer); +struct wireguard_peer * +pubkey_hashtable_lookup(struct pubkey_hashtable *table, + const u8 pubkey[NOISE_PUBLIC_KEY_LEN]); struct index_hashtable { /* TODO: move to rhashtable */ @@ -44,9 +48,16 @@ struct index_hashtable_entry { __le32 index; }; void index_hashtable_init(struct index_hashtable *table); -__le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry); -bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new); -void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry); -struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer); +__le32 index_hashtable_insert(struct index_hashtable *table, + struct index_hashtable_entry *entry); +bool index_hashtable_replace(struct index_hashtable *table, + struct index_hashtable_entry *old, + struct index_hashtable_entry *new); +void index_hashtable_remove(struct index_hashtable *table, + struct index_hashtable_entry *entry); +struct index_hashtable_entry * +index_hashtable_lookup(struct index_hashtable *table, + const enum index_hashtable_type type_mask, + const __le32 index, struct wireguard_peer **peer); #endif /* _WG_HASHTABLES_H */ @@ -31,7 +31,10 @@ static int __init mod_init(void) blake2s_fpu_init(); curve25519_fpu_init(); #ifdef DEBUG - if (!allowedips_selftest() || !packet_counter_selftest() || !curve25519_selftest() || !poly1305_selftest() || !chacha20poly1305_selftest() || !blake2s_selftest() || !ratelimiter_selftest()) + if (!allowedips_selftest() || !packet_counter_selftest() || + !curve25519_selftest() || !poly1305_selftest() || + !chacha20poly1305_selftest() || !blake2s_selftest() || + !ratelimiter_selftest()) return -ENOTRECOVERABLE; #endif noise_init(); diff --git a/src/messages.h b/src/messages.h index 9425a26..bb19957 100644 --- a/src/messages.h +++ b/src/messages.h @@ -109,18 +109,20 @@ struct message_data { u8 encrypted_data[]; }; -#define message_data_len(plain_len) (noise_encrypted_len(plain_len) + sizeof(struct message_data)) +#define message_data_len(plain_len) \ + (noise_encrypted_len(plain_len) + sizeof(struct message_data)) enum message_alignments { MESSAGE_PADDING_MULTIPLE = 16, MESSAGE_MINIMUM_LENGTH = message_data_len(0) }; -#define SKB_HEADER_LEN (max(sizeof(struct iphdr), sizeof(struct ipv6hdr)) + sizeof(struct udphdr) + NET_SKB_PAD) -#define DATA_PACKET_HEAD_ROOM ALIGN(sizeof(struct message_data) + SKB_HEADER_LEN, 4) +#define SKB_HEADER_LEN \ + (max(sizeof(struct iphdr), sizeof(struct ipv6hdr)) + \ + sizeof(struct udphdr) + NET_SKB_PAD) +#define DATA_PACKET_HEAD_ROOM \ + ALIGN(sizeof(struct message_data) + SKB_HEADER_LEN, 4) -enum { - HANDSHAKE_DSCP = 0x88 /* AF41, plus 00 ECN */ -}; +enum { HANDSHAKE_DSCP = 0x88 /* AF41, plus 00 ECN */ }; #endif /* _WG_MESSAGES_H */ diff --git a/src/netlink.c b/src/netlink.c index 3147587..5390498 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -45,19 +45,23 @@ static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = { [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 } }; -static struct wireguard_device *lookup_interface(struct nlattr **attrs, struct sk_buff *skb) +static struct wireguard_device *lookup_interface(struct nlattr **attrs, + struct sk_buff *skb) { struct net_device *dev = NULL; if (!attrs[WGDEVICE_A_IFINDEX] == !attrs[WGDEVICE_A_IFNAME]) return ERR_PTR(-EBADR); if (attrs[WGDEVICE_A_IFINDEX]) - dev = dev_get_by_index(sock_net(skb->sk), nla_get_u32(attrs[WGDEVICE_A_IFINDEX])); + dev = dev_get_by_index(sock_net(skb->sk), + nla_get_u32(attrs[WGDEVICE_A_IFINDEX])); else if (attrs[WGDEVICE_A_IFNAME]) - dev = dev_get_by_name(sock_net(skb->sk), nla_data(attrs[WGDEVICE_A_IFNAME])); + dev = dev_get_by_name(sock_net(skb->sk), + nla_data(attrs[WGDEVICE_A_IFNAME])); if (!dev) return ERR_PTR(-ENODEV); - if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind || strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) { + if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind || + strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) { dev_put(dev); return ERR_PTR(-EOPNOTSUPP); } @@ -71,15 +75,17 @@ struct allowedips_ctx { static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) { - struct nlattr *allowedip_nest; struct allowedips_ctx *actx = ctx; + struct nlattr *allowedip_nest; allowedip_nest = nla_nest_start(actx->skb, actx->i++); if (!allowedip_nest) return -EMSGSIZE; - if (nla_put_u8(actx->skb, WGALLOWEDIP_A_CIDR_MASK, cidr) || nla_put_u16(actx->skb, WGALLOWEDIP_A_FAMILY, family) || - nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) { + if (nla_put_u8(actx->skb, WGALLOWEDIP_A_CIDR_MASK, cidr) || + nla_put_u16(actx->skb, WGALLOWEDIP_A_FAMILY, family) || + nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? + sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) { nla_nest_cancel(actx->skb, allowedip_nest); return -EMSGSIZE; } @@ -88,37 +94,52 @@ static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) return 0; } -static int get_peer(struct wireguard_peer *peer, unsigned int index, struct allowedips_cursor *rt_cursor, struct sk_buff *skb) +static int get_peer(struct wireguard_peer *peer, unsigned int index, + struct allowedips_cursor *rt_cursor, struct sk_buff *skb) { - struct allowedips_ctx ctx = { .skb = skb }; struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, index); + struct allowedips_ctx ctx = { .skb = skb }; bool fail; if (!peer_nest) return -EMSGSIZE; down_read(&peer->handshake.lock); - fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN, peer->handshake.remote_static); + fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN, + peer->handshake.remote_static); up_read(&peer->handshake.lock); if (fail) goto err; if (!rt_cursor->seq) { down_read(&peer->handshake.lock); - fail = nla_put(skb, WGPEER_A_PRESHARED_KEY, NOISE_SYMMETRIC_KEY_LEN, peer->handshake.preshared_key); + fail = nla_put(skb, WGPEER_A_PRESHARED_KEY, + NOISE_SYMMETRIC_KEY_LEN, + peer->handshake.preshared_key); up_read(&peer->handshake.lock); if (fail) goto err; - if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME, sizeof(struct timespec), &peer->walltime_last_handshake) || nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval) || - nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes, WGPEER_A_UNSPEC) || nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes, WGPEER_A_UNSPEC)) + if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME, + sizeof(struct timespec), + &peer->walltime_last_handshake) || + nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, + peer->persistent_keepalive_interval) || + nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes, + WGPEER_A_UNSPEC) || + nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes, + WGPEER_A_UNSPEC)) goto err; read_lock_bh(&peer->endpoint_lock); if (peer->endpoint.addr.sa_family == AF_INET) - fail = nla_put(skb, WGPEER_A_ENDPOINT, sizeof(struct sockaddr_in), &peer->endpoint.addr4); + fail = nla_put(skb, WGPEER_A_ENDPOINT, + sizeof(struct sockaddr_in), + &peer->endpoint.addr4); else if (peer->endpoint.addr.sa_family == AF_INET6) - fail = nla_put(skb, WGPEER_A_ENDPOINT, sizeof(struct sockaddr_in6), &peer->endpoint.addr6); + fail = nla_put(skb, WGPEER_A_ENDPOINT, + sizeof(struct sockaddr_in6), + &peer->endpoint.addr6); read_unlock_bh(&peer->endpoint_lock); if (fail) goto err; @@ -127,7 +148,9 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, struct allo allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); if (!allowedips_nest) goto err; - if (allowedips_walk_by_peer(&peer->device->peer_allowedips, rt_cursor, peer, get_allowedips, &ctx, &peer->device->device_update_lock)) { + if (allowedips_walk_by_peer(&peer->device->peer_allowedips, rt_cursor, + peer, get_allowedips, &ctx, + &peer->device->device_update_lock)) { nla_nest_end(skb, allowedips_nest); nla_nest_end(skb, peer_nest); return -EMSGSIZE; @@ -143,13 +166,15 @@ err: static int get_device_start(struct netlink_callback *cb) { - struct wireguard_device *wg; struct nlattr **attrs = genl_family_attrbuf(&genl_family); - int ret = nlmsg_parse(cb->nlh, GENL_HDRLEN + genl_family.hdrsize, attrs, genl_family.maxattr, device_policy, NULL); + int ret = nlmsg_parse(cb->nlh, GENL_HDRLEN + genl_family.hdrsize, attrs, + genl_family.maxattr, device_policy, NULL); + struct wireguard_device *wg; if (ret < 0) return ret; - cb->args[2] = (long)kzalloc(sizeof(struct allowedips_cursor), GFP_KERNEL); + cb->args[2] = + (long)kzalloc(sizeof(struct allowedips_cursor), GFP_KERNEL); if (!cb->args[2]) return -ENOMEM; wg = lookup_interface(attrs, cb->skb); @@ -164,33 +189,46 @@ static int get_device_start(struct netlink_callback *cb) static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) { - struct wireguard_device *wg = (struct wireguard_device *)cb->args[0]; struct wireguard_peer *peer, *next_peer_cursor, *last_peer_cursor; - struct allowedips_cursor *rt_cursor = (struct allowedips_cursor *)cb->args[2]; + struct allowedips_cursor *rt_cursor; + struct wireguard_device *wg; unsigned int peer_idx = 0; struct nlattr *peers_nest; bool done = true; void *hdr; int ret = -EMSGSIZE; - next_peer_cursor = last_peer_cursor = (struct wireguard_peer *)cb->args[1]; + wg = (struct wireguard_device *)cb->args[0]; + next_peer_cursor = (struct wireguard_peer *)cb->args[1]; + last_peer_cursor = (struct wireguard_peer *)cb->args[1]; + rt_cursor = (struct allowedips_cursor *)cb->args[2]; rtnl_lock(); mutex_lock(&wg->device_update_lock); cb->seq = wg->device_update_gen; - hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE); + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE); if (!hdr) goto out; genl_dump_check_consistent(cb, hdr); if (!last_peer_cursor) { - if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT, wg->incoming_port) || nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) || nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) || nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name)) + if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT, + wg->incoming_port) || + nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) || + nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) || + nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name)) goto out; down_read(&wg->static_identity.lock); if (wg->static_identity.has_identity) { - if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY, NOISE_PUBLIC_KEY_LEN, wg->static_identity.static_private) || nla_put(skb, WGDEVICE_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN, wg->static_identity.static_public)) { + if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY, + NOISE_PUBLIC_KEY_LEN, + wg->static_identity.static_private) || + nla_put(skb, WGDEVICE_A_PUBLIC_KEY, + NOISE_PUBLIC_KEY_LEN, + wg->static_identity.static_public)) { up_read(&wg->static_identity.lock); goto out; } @@ -202,17 +240,19 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) if (!peers_nest) goto out; ret = 0; - /* If the last cursor was removed via list_del_init in peer_remove, then we just treat - * this the same as there being no more peers left. The reason is that seq_nr should - * indicate to userspace that this isn't a coherent dump anyway, so they'll try again. + /* If the last cursor was removed via list_del_init in peer_remove, then + * we just treat this the same as there being no more peers left. The + * reason is that seq_nr should indicate to userspace that this isn't a + * coherent dump anyway, so they'll try again. */ - if (list_empty(&wg->peer_list) || (last_peer_cursor && list_empty(&last_peer_cursor->peer_list))) { + if (list_empty(&wg->peer_list) || + (last_peer_cursor && list_empty(&last_peer_cursor->peer_list))) { nla_nest_cancel(skb, peers_nest); goto out; } lockdep_assert_held(&wg->device_update_lock); peer = list_prepare_entry(last_peer_cursor, &wg->peer_list, peer_list); - list_for_each_entry_continue(peer, &wg->peer_list, peer_list) { + list_for_each_entry_continue (peer, &wg->peer_list, peer_list) { if (get_peer(peer, peer_idx++, rt_cursor, skb)) { done = false; break; @@ -250,7 +290,8 @@ static int get_device_done(struct netlink_callback *cb) { struct wireguard_device *wg = (struct wireguard_device *)cb->args[0]; struct wireguard_peer *peer = (struct wireguard_peer *)cb->args[1]; - struct allowedips_cursor *rt_cursor = (struct allowedips_cursor *)cb->args[2]; + struct allowedips_cursor *rt_cursor = + (struct allowedips_cursor *)cb->args[2]; if (wg) dev_put(wg->dev); @@ -265,7 +306,7 @@ static int set_port(struct wireguard_device *wg, u16 port) if (wg->incoming_port == port) return 0; - list_for_each_entry(peer, &wg->peer_list, peer_list) + list_for_each_entry (peer, &wg->peer_list, peer_list) socket_clear_peer_endpoint_src(peer); if (!netif_running(wg->dev)) { wg->incoming_port = port; @@ -280,15 +321,25 @@ static int set_allowedip(struct wireguard_peer *peer, struct nlattr **attrs) u16 family; u8 cidr; - if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] || !attrs[WGALLOWEDIP_A_CIDR_MASK]) + if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] || + !attrs[WGALLOWEDIP_A_CIDR_MASK]) return ret; family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]); cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]); - if (family == AF_INET && cidr <= 32 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) - ret = allowedips_insert_v4(&peer->device->peer_allowedips, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, &peer->device->device_update_lock); - else if (family == AF_INET6 && cidr <= 128 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) - ret = allowedips_insert_v6(&peer->device->peer_allowedips, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, &peer->device->device_update_lock); + if (family == AF_INET && cidr <= 32 && + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) + ret = allowedips_insert_v4( + &peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, + &peer->device->device_update_lock); + else if (family == AF_INET6 && cidr <= 128 && + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == + sizeof(struct in6_addr)) + ret = allowedips_insert_v6( + &peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, + &peer->device->device_update_lock); return ret; } @@ -301,25 +352,33 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) u8 *public_key = NULL, *preshared_key = NULL; ret = -EINVAL; - if (attrs[WGPEER_A_PUBLIC_KEY] && nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN) + if (attrs[WGPEER_A_PUBLIC_KEY] && + nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN) public_key = nla_data(attrs[WGPEER_A_PUBLIC_KEY]); else goto out; - if (attrs[WGPEER_A_PRESHARED_KEY] && nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN) + if (attrs[WGPEER_A_PRESHARED_KEY] && + nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN) preshared_key = nla_data(attrs[WGPEER_A_PRESHARED_KEY]); if (attrs[WGPEER_A_FLAGS]) flags = nla_get_u32(attrs[WGPEER_A_FLAGS]); - peer = pubkey_hashtable_lookup(&wg->peer_hashtable, nla_data(attrs[WGPEER_A_PUBLIC_KEY])); + peer = pubkey_hashtable_lookup(&wg->peer_hashtable, + nla_data(attrs[WGPEER_A_PUBLIC_KEY])); if (!peer) { /* Peer doesn't exist yet. Add a new one. */ ret = -ENODEV; if (flags & WGPEER_F_REMOVE_ME) goto out; /* Tried to remove a non-existing peer. */ down_read(&wg->static_identity.lock); - if (wg->static_identity.has_identity && !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]), wg->static_identity.static_public, NOISE_PUBLIC_KEY_LEN)) { - /* We silently ignore peers that have the same public key as the device. The reason we do it silently - * is that we'd like for people to be able to reuse the same set of API calls across peers. + if (wg->static_identity.has_identity && + !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]), + wg->static_identity.static_public, + NOISE_PUBLIC_KEY_LEN)) { + /* We silently ignore peers that have the same public + * key as the device. The reason we do it silently is + * that we'd like for people to be able to reuse the + * same set of API calls across peers. */ up_read(&wg->static_identity.lock); ret = 0; @@ -331,7 +390,9 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) peer = peer_create(wg, public_key, preshared_key); if (!peer) goto out; - /* Take additional reference, as though we've just been looked up. */ + /* Take additional reference, as though we've just been + * looked up. + */ peer_get(peer); } @@ -343,7 +404,8 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) if (preshared_key) { down_write(&peer->handshake.lock); - memcpy(&peer->handshake.preshared_key, preshared_key, NOISE_SYMMETRIC_KEY_LEN); + memcpy(&peer->handshake.preshared_key, preshared_key, + NOISE_SYMMETRIC_KEY_LEN); up_write(&peer->handshake.lock); } @@ -351,7 +413,10 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]); size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]); - if ((len == sizeof(struct sockaddr_in) && addr->sa_family == AF_INET) || (len == sizeof(struct sockaddr_in6) && addr->sa_family == AF_INET6)) { + if ((len == sizeof(struct sockaddr_in) && + addr->sa_family == AF_INET) || + (len == sizeof(struct sockaddr_in6) && + addr->sa_family == AF_INET6)) { struct endpoint endpoint = { { { 0 } } }; memcpy(&endpoint.addr, addr, len); @@ -360,14 +425,16 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) } if (flags & WGPEER_F_REPLACE_ALLOWEDIPS) - allowedips_remove_by_peer(&wg->peer_allowedips, peer, &wg->device_update_lock); + allowedips_remove_by_peer(&wg->peer_allowedips, peer, + &wg->device_update_lock); if (attrs[WGPEER_A_ALLOWEDIPS]) { - int rem; struct nlattr *attr, *allowedip[WGALLOWEDIP_A_MAX + 1]; + int rem; - nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) { - ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX, attr, allowedip_policy, NULL); + nla_for_each_nested (attr, attrs[WGPEER_A_ALLOWEDIPS], rem) { + ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX, + attr, allowedip_policy, NULL); if (ret < 0) goto out; ret = set_allowedip(peer, allowedip); @@ -377,8 +444,12 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) } if (attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]) { - const u16 persistent_keepalive_interval = nla_get_u16(attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]); - const bool send_keepalive = !peer->persistent_keepalive_interval && persistent_keepalive_interval && netif_running(wg->dev); + const u16 persistent_keepalive_interval = nla_get_u16( + attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]); + const bool send_keepalive = + !peer->persistent_keepalive_interval && + persistent_keepalive_interval && + netif_running(wg->dev); peer->persistent_keepalive_interval = persistent_keepalive_interval; if (send_keepalive) @@ -391,7 +462,8 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) out: peer_put(peer); if (attrs[WGPEER_A_PRESHARED_KEY]) - memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]), nla_len(attrs[WGPEER_A_PRESHARED_KEY])); + memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]), + nla_len(attrs[WGPEER_A_PRESHARED_KEY])); return ret; } @@ -413,26 +485,35 @@ static int set_device(struct sk_buff *skb, struct genl_info *info) struct wireguard_peer *peer; wg->fwmark = nla_get_u32(info->attrs[WGDEVICE_A_FWMARK]); - list_for_each_entry(peer, &wg->peer_list, peer_list) + list_for_each_entry (peer, &wg->peer_list, peer_list) socket_clear_peer_endpoint_src(peer); } if (info->attrs[WGDEVICE_A_LISTEN_PORT]) { - ret = set_port(wg, nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT])); + ret = set_port( + wg, nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT])); if (ret) goto out; } - if (info->attrs[WGDEVICE_A_FLAGS] && nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]) & WGDEVICE_F_REPLACE_PEERS) + if (info->attrs[WGDEVICE_A_FLAGS] && + nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]) & + WGDEVICE_F_REPLACE_PEERS) peer_remove_all(wg); - if (info->attrs[WGDEVICE_A_PRIVATE_KEY] && nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) == NOISE_PUBLIC_KEY_LEN) { + if (info->attrs[WGDEVICE_A_PRIVATE_KEY] && + nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) == + NOISE_PUBLIC_KEY_LEN) { + u8 *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]); + u8 public_key[NOISE_PUBLIC_KEY_LEN]; struct wireguard_peer *peer, *temp; - u8 public_key[NOISE_PUBLIC_KEY_LEN], *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]); - /* We remove before setting, to prevent race, which means doing two 25519-genpub ops. */ + /* We remove before setting, to prevent race, which means doing + * two 25519-genpub ops. + */ if (curve25519_generate_public(public_key, private_key)) { - peer = pubkey_hashtable_lookup(&wg->peer_hashtable, public_key); + peer = pubkey_hashtable_lookup(&wg->peer_hashtable, + public_key); if (peer) { peer_put(peer); peer_remove(peer); @@ -440,8 +521,10 @@ static int set_device(struct sk_buff *skb, struct genl_info *info) } down_write(&wg->static_identity.lock); - noise_set_static_identity_private_key(&wg->static_identity, private_key); - list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) { + noise_set_static_identity_private_key(&wg->static_identity, + private_key); + list_for_each_entry_safe (peer, temp, &wg->peer_list, + peer_list) { if (!noise_precompute_static_static(peer)) peer_remove(peer); } @@ -453,8 +536,9 @@ static int set_device(struct sk_buff *skb, struct genl_info *info) int rem; struct nlattr *attr, *peer[WGPEER_A_MAX + 1]; - nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) { - ret = nla_parse_nested(peer, WGPEER_A_MAX, attr, peer_policy, NULL); + nla_for_each_nested (attr, info->attrs[WGDEVICE_A_PEERS], rem) { + ret = nla_parse_nested(peer, WGPEER_A_MAX, attr, + peer_policy, NULL); if (ret < 0) goto out; ret = set_peer(wg, peer); @@ -470,7 +554,8 @@ out: dev_put(wg->dev); out_nodev: if (info->attrs[WGDEVICE_A_PRIVATE_KEY]) - memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]), nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY])); + memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]), + nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY])); return ret; } diff --git a/src/noise.c b/src/noise.c index 0f6e51b..70b53a6 100644 --- a/src/noise.c +++ b/src/noise.c @@ -35,7 +35,8 @@ void __init noise_init(void) { struct blake2s_state blake; - blake2s(handshake_init_chaining_key, handshake_name, NULL, NOISE_HASH_LEN, sizeof(handshake_name), 0); + blake2s(handshake_init_chaining_key, handshake_name, NULL, + NOISE_HASH_LEN, sizeof(handshake_name), 0); blake2s_init(&blake, NOISE_HASH_LEN); blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN); blake2s_update(&blake, identifier_name, sizeof(identifier_name)); @@ -46,16 +47,25 @@ void __init noise_init(void) bool noise_precompute_static_static(struct wireguard_peer *peer) { bool ret = true; + down_write(&peer->handshake.lock); if (peer->handshake.static_identity->has_identity) - ret = curve25519(peer->handshake.precomputed_static_static, peer->handshake.static_identity->static_private, peer->handshake.remote_static); + ret = curve25519( + peer->handshake.precomputed_static_static, + peer->handshake.static_identity->static_private, + peer->handshake.remote_static); else - memset(peer->handshake.precomputed_static_static, 0, NOISE_PUBLIC_KEY_LEN); + memset(peer->handshake.precomputed_static_static, 0, + NOISE_PUBLIC_KEY_LEN); up_write(&peer->handshake.lock); return ret; } -bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer) +bool noise_handshake_init(struct noise_handshake *handshake, + struct noise_static_identity *static_identity, + const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], + const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], + struct wireguard_peer *peer) { memset(handshake, 0, sizeof(struct noise_handshake)); init_rwsem(&handshake->lock); @@ -63,7 +73,8 @@ bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static handshake->entry.peer = peer; memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN); if (peer_preshared_key) - memcpy(handshake->preshared_key, peer_preshared_key, NOISE_SYMMETRIC_KEY_LEN); + memcpy(handshake->preshared_key, peer_preshared_key, + NOISE_SYMMETRIC_KEY_LEN); handshake->static_identity = static_identity; handshake->state = HANDSHAKE_ZEROED; return noise_precompute_static_static(peer); @@ -81,16 +92,19 @@ static void handshake_zero(struct noise_handshake *handshake) void noise_handshake_clear(struct noise_handshake *handshake) { - index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, + &handshake->entry); down_write(&handshake->lock); handshake_zero(handshake); up_write(&handshake->lock); - index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, + &handshake->entry); } static struct noise_keypair *keypair_create(struct wireguard_peer *peer) { - struct noise_keypair *keypair = kzalloc(sizeof(struct noise_keypair), GFP_KERNEL); + struct noise_keypair *keypair = + kzalloc(sizeof(struct noise_keypair), GFP_KERNEL); if (unlikely(!keypair)) return NULL; @@ -108,9 +122,14 @@ static void keypair_free_rcu(struct rcu_head *rcu) static void keypair_free_kref(struct kref *kref) { - struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount); - net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id); - index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); + struct noise_keypair *keypair = + container_of(kref, struct noise_keypair, refcount); + net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", + keypair->entry.peer->device->dev->name, + keypair->internal_id, + keypair->entry.peer->internal_id); + index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, + &keypair->entry); call_rcu_bh(&keypair->rcu, keypair_free_rcu); } @@ -119,13 +138,16 @@ void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now) if (unlikely(!keypair)) return; if (unlikely(unreference_now)) - index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); + index_hashtable_remove( + &keypair->entry.peer->device->index_hashtable, + &keypair->entry); kref_put(&keypair->refcount, keypair_free_kref); } struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair) { - RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Taking noise keypair reference without holding the RCU BH read lock"); + RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), + "Taking noise keypair reference without holding the RCU BH read lock"); if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount))) return NULL; return keypair; @@ -136,55 +158,67 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs) struct noise_keypair *old; spin_lock_bh(&keypairs->keypair_update_lock); - old = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + old = rcu_dereference_protected(keypairs->previous_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->previous_keypair, NULL); noise_keypair_put(old, true); - old = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + old = rcu_dereference_protected(keypairs->next_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->next_keypair, NULL); noise_keypair_put(old, true); - old = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + old = rcu_dereference_protected(keypairs->current_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->current_keypair, NULL); noise_keypair_put(old, true); spin_unlock_bh(&keypairs->keypair_update_lock); } -static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypair *new_keypair) +static void add_new_keypair(struct noise_keypairs *keypairs, + struct noise_keypair *new_keypair) { struct noise_keypair *previous_keypair, *next_keypair, *current_keypair; spin_lock_bh(&keypairs->keypair_update_lock); - previous_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); - next_keypair = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); - current_keypair = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + previous_keypair = rcu_dereference_protected(keypairs->previous_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); + next_keypair = rcu_dereference_protected(keypairs->next_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); + current_keypair = rcu_dereference_protected(keypairs->current_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); if (new_keypair->i_am_the_initiator) { - /* If we're the initiator, it means we've sent a handshake, and received - * a confirmation response, which means this new keypair can now be used. + /* If we're the initiator, it means we've sent a handshake, and + * received a confirmation response, which means this new + * keypair can now be used. */ if (next_keypair) { - /* If there already was a next keypair pending, we demote it to be - * the previous keypair, and free the existing current. - * TODO: note that this means KCI can result in this transition. It - * would perhaps be more sound to always just get rid of the unused - * next keypair instead of putting it in the previous slot, but this - * might be a bit less robust. Something to think about and decide on. + /* If there already was a next keypair pending, we + * demote it to be the previous keypair, and free the + * existing current. Note that this means KCI can result + * in this transition. It would perhaps be more sound to + * always just get rid of the unused next keypair + * instead of putting it in the previous slot, but this + * might be a bit less robust. Something to think about + * for the future. */ RCU_INIT_POINTER(keypairs->next_keypair, NULL); - rcu_assign_pointer(keypairs->previous_keypair, next_keypair); + rcu_assign_pointer(keypairs->previous_keypair, + next_keypair); noise_keypair_put(current_keypair, true); - } else /* If there wasn't an existing next keypair, we replace the - * previous with the current one. + } else /* If there wasn't an existing next keypair, we replace + * the previous with the current one. */ - rcu_assign_pointer(keypairs->previous_keypair, current_keypair); - /* At this point we can get rid of the old previous keypair, and set up - * the new keypair. + rcu_assign_pointer(keypairs->previous_keypair, + current_keypair); + /* At this point we can get rid of the old previous keypair, and + * set up the new keypair. */ noise_keypair_put(previous_keypair, true); rcu_assign_pointer(keypairs->current_keypair, new_keypair); } else { - /* If we're the responder, it means we can't use the new keypair until - * we receive confirmation via the first data packet, so we get rid of - * the existing previous one, the possibly existing next one, and slide - * in the new next one. + /* If we're the responder, it means we can't use the new keypair + * until we receive confirmation via the first data packet, so + * we get rid of the existing previous one, the possibly + * existing next one, and slide in the new next one. */ rcu_assign_pointer(keypairs->next_keypair, new_keypair); noise_keypair_put(next_keypair, true); @@ -194,19 +228,25 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai spin_unlock_bh(&keypairs->keypair_update_lock); } -bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair) +bool noise_received_with_keypair(struct noise_keypairs *keypairs, + struct noise_keypair *received_keypair) { - bool key_is_new; struct noise_keypair *old_keypair; + bool key_is_new; /* We first check without taking the spinlock. */ - key_is_new = received_keypair == rcu_access_pointer(keypairs->next_keypair); + key_is_new = received_keypair == + rcu_access_pointer(keypairs->next_keypair); if (likely(!key_is_new)) return false; spin_lock_bh(&keypairs->keypair_update_lock); - /* After locking, we double check that things didn't change from beneath us. */ - if (unlikely(received_keypair != rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)))) { + /* After locking, we double check that things didn't change from + * beneath us. + */ + if (unlikely(received_keypair != + rcu_dereference_protected(keypairs->next_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)))) { spin_unlock_bh(&keypairs->keypair_update_lock); return false; } @@ -215,8 +255,11 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k * into the current, the current into the previous, and get rid of * the old previous. */ - old_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); - rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock))); + old_keypair = rcu_dereference_protected(keypairs->previous_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); + rcu_assign_pointer(keypairs->previous_keypair, + rcu_dereference_protected(keypairs->current_keypair, + lockdep_is_held(&keypairs->keypair_update_lock))); noise_keypair_put(old_keypair, true); rcu_assign_pointer(keypairs->current_keypair, received_keypair); RCU_INIT_POINTER(keypairs->next_keypair, NULL); @@ -226,34 +269,46 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k } /* Must hold static_identity->lock */ -void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN]) +void noise_set_static_identity_private_key( + struct noise_static_identity *static_identity, + const u8 private_key[NOISE_PUBLIC_KEY_LEN]) { - memcpy(static_identity->static_private, private_key, NOISE_PUBLIC_KEY_LEN); - static_identity->has_identity = curve25519_generate_public(static_identity->static_public, private_key); + memcpy(static_identity->static_private, private_key, + NOISE_PUBLIC_KEY_LEN); + static_identity->has_identity = curve25519_generate_public( + static_identity->static_public, private_key); } /* This is Hugo Krawczyk's HKDF: * - https://eprint.iacr.org/2010/264.pdf * - https://tools.ietf.org/html/rfc5869 */ -static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, size_t first_len, size_t second_len, size_t third_len, size_t data_len, const u8 chaining_key[NOISE_HASH_LEN]) +static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, + size_t first_len, size_t second_len, size_t third_len, + size_t data_len, const u8 chaining_key[NOISE_HASH_LEN]) { - u8 secret[BLAKE2S_OUTBYTES]; u8 output[BLAKE2S_OUTBYTES + 1]; + u8 secret[BLAKE2S_OUTBYTES]; #ifdef DEBUG - BUG_ON(first_len > BLAKE2S_OUTBYTES || second_len > BLAKE2S_OUTBYTES || third_len > BLAKE2S_OUTBYTES || ((second_len || second_dst || third_len || third_dst) && (!first_len || !first_dst)) || ((third_len || third_dst) && (!second_len || !second_dst))); + BUG_ON(first_len > BLAKE2S_OUTBYTES || second_len > BLAKE2S_OUTBYTES || + third_len > BLAKE2S_OUTBYTES || + ((second_len || second_dst || third_len || third_dst) && + (!first_len || !first_dst)) || + ((third_len || third_dst) && (!second_len || !second_dst))); #endif /* Extract entropy from data into secret */ - blake2s_hmac(secret, data, chaining_key, BLAKE2S_OUTBYTES, data_len, NOISE_HASH_LEN); + blake2s_hmac(secret, data, chaining_key, BLAKE2S_OUTBYTES, data_len, + NOISE_HASH_LEN); if (!first_dst || !first_len) goto out; /* Expand first key: key = secret, data = 0x1 */ output[0] = 1; - blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, 1, BLAKE2S_OUTBYTES); + blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, 1, + BLAKE2S_OUTBYTES); memcpy(first_dst, output, first_len); if (!second_dst || !second_len) @@ -261,7 +316,8 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, si /* Expand second key: key = secret, data = first-key || 0x2 */ output[BLAKE2S_OUTBYTES] = 2; - blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); + blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, + BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); memcpy(second_dst, output, second_len); if (!third_dst || !third_len) @@ -269,7 +325,8 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, si /* Expand third key: key = secret, data = second-key || 0x3 */ output[BLAKE2S_OUTBYTES] = 3; - blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); + blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, + BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); memcpy(third_dst, output, third_len); out: @@ -282,25 +339,34 @@ static void symmetric_key_init(struct noise_symmetric_key *key) { spin_lock_init(&key->counter.receive.lock); atomic64_set(&key->counter.counter, 0); - memset(key->counter.receive.backtrack, 0, sizeof(key->counter.receive.backtrack)); + memset(key->counter.receive.backtrack, 0, + sizeof(key->counter.receive.backtrack)); key->birthdate = ktime_get_boot_fast_ns(); key->is_valid = true; } -static void derive_keys(struct noise_symmetric_key *first_dst, struct noise_symmetric_key *second_dst, const u8 chaining_key[NOISE_HASH_LEN]) +static void derive_keys(struct noise_symmetric_key *first_dst, + struct noise_symmetric_key *second_dst, + const u8 chaining_key[NOISE_HASH_LEN]) { - kdf(first_dst->key, second_dst->key, NULL, NULL, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, chaining_key); + kdf(first_dst->key, second_dst->key, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + chaining_key); symmetric_key_init(first_dst); symmetric_key_init(second_dst); } -static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN], u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 private[NOISE_PUBLIC_KEY_LEN], const u8 public[NOISE_PUBLIC_KEY_LEN]) +static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN], + u8 key[NOISE_SYMMETRIC_KEY_LEN], + const u8 private[NOISE_PUBLIC_KEY_LEN], + const u8 public[NOISE_PUBLIC_KEY_LEN]) { u8 dh_calculation[NOISE_PUBLIC_KEY_LEN]; if (unlikely(!curve25519(dh_calculation, private, public))) return false; - kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); + kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN, + NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN); return true; } @@ -315,42 +381,59 @@ static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len) blake2s_final(&blake, hash, NOISE_HASH_LEN); } -static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 psk[NOISE_SYMMETRIC_KEY_LEN]) +static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], + u8 key[NOISE_SYMMETRIC_KEY_LEN], + const u8 psk[NOISE_SYMMETRIC_KEY_LEN]) { u8 temp_hash[NOISE_HASH_LEN]; - kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key); + kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key); mix_hash(hash, temp_hash, NOISE_HASH_LEN); memzero_explicit(temp_hash, NOISE_HASH_LEN); } -static void handshake_init(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], const u8 remote_static[NOISE_PUBLIC_KEY_LEN]) +static void handshake_init(u8 chaining_key[NOISE_HASH_LEN], + u8 hash[NOISE_HASH_LEN], + const u8 remote_static[NOISE_PUBLIC_KEY_LEN]) { memcpy(hash, handshake_init_hash, NOISE_HASH_LEN); memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN); mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN); } -static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN]) +static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext, + size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], + u8 hash[NOISE_HASH_LEN]) { - chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash, NOISE_HASH_LEN, 0 /* Always zero for Noise_IK */, key); + chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash, + NOISE_HASH_LEN, + 0 /* Always zero for Noise_IK */, key); mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len)); } -static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN]) +static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext, + size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], + u8 hash[NOISE_HASH_LEN]) { - if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len, hash, NOISE_HASH_LEN, 0 /* Always zero for Noise_IK */, key)) + if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len, + hash, NOISE_HASH_LEN, + 0 /* Always zero for Noise_IK */, key)) return false; mix_hash(hash, src_ciphertext, src_len); return true; } -static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN], const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN], u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN]) +static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN], + const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN], + u8 chaining_key[NOISE_HASH_LEN], + u8 hash[NOISE_HASH_LEN]) { if (ephemeral_dst != ephemeral_src) memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN); mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN); - kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); + kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0, + NOISE_PUBLIC_KEY_LEN, chaining_key); } static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN]) @@ -363,14 +446,15 @@ static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN]) *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec); } -bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake) +bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, + struct noise_handshake *handshake) { u8 timestamp[NOISE_TIMESTAMP_LEN]; u8 key[NOISE_SYMMETRIC_KEY_LEN]; bool ret = false; - /* We need to wait for crng _before_ taking any locks, since curve25519_generate_secret - * uses get_random_bytes_wait. + /* We need to wait for crng _before_ taking any locks, since + * curve25519_generate_secret uses get_random_bytes_wait. */ wait_for_random_bytes(); @@ -382,29 +466,42 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION); - handshake_init(handshake->chaining_key, handshake->hash, handshake->remote_static); + handshake_init(handshake->chaining_key, handshake->hash, + handshake->remote_static); /* e */ curve25519_generate_secret(handshake->ephemeral_private); - if (!curve25519_generate_public(dst->unencrypted_ephemeral, handshake->ephemeral_private)) + if (!curve25519_generate_public(dst->unencrypted_ephemeral, + handshake->ephemeral_private)) goto out; - message_ephemeral(dst->unencrypted_ephemeral, dst->unencrypted_ephemeral, handshake->chaining_key, handshake->hash); + message_ephemeral(dst->unencrypted_ephemeral, + dst->unencrypted_ephemeral, handshake->chaining_key, + handshake->hash); /* es */ - if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private, handshake->remote_static)) + if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private, + handshake->remote_static)) goto out; /* s */ - message_encrypt(dst->encrypted_static, handshake->static_identity->static_public, NOISE_PUBLIC_KEY_LEN, key, handshake->hash); + message_encrypt(dst->encrypted_static, + handshake->static_identity->static_public, + NOISE_PUBLIC_KEY_LEN, key, handshake->hash); /* ss */ - kdf(handshake->chaining_key, key, NULL, handshake->precomputed_static_static, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, handshake->chaining_key); + kdf(handshake->chaining_key, key, NULL, + handshake->precomputed_static_static, NOISE_HASH_LEN, + NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, + handshake->chaining_key); /* {t} */ tai64n_now(timestamp); - message_encrypt(dst->encrypted_timestamp, timestamp, NOISE_TIMESTAMP_LEN, key, handshake->hash); + message_encrypt(dst->encrypted_timestamp, timestamp, + NOISE_TIMESTAMP_LEN, key, handshake->hash); - dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + dst->sender_index = index_hashtable_insert( + &handshake->entry.peer->device->index_hashtable, + &handshake->entry); handshake->state = HANDSHAKE_CREATED_INITIATION; ret = true; @@ -416,17 +513,19 @@ out: return ret; } -struct wireguard_peer *noise_handshake_consume_initiation(struct message_handshake_initiation *src, struct wireguard_device *wg) +struct wireguard_peer * +noise_handshake_consume_initiation(struct message_handshake_initiation *src, + struct wireguard_device *wg) { + struct wireguard_peer *peer = NULL, *ret_peer = NULL; + struct noise_handshake *handshake; bool replay_attack, flood_attack; + u8 key[NOISE_SYMMETRIC_KEY_LEN]; + u8 chaining_key[NOISE_HASH_LEN]; + u8 hash[NOISE_HASH_LEN]; u8 s[NOISE_PUBLIC_KEY_LEN]; u8 e[NOISE_PUBLIC_KEY_LEN]; u8 t[NOISE_TIMESTAMP_LEN]; - struct noise_handshake *handshake; - struct wireguard_peer *peer = NULL, *ret_peer = NULL; - u8 key[NOISE_SYMMETRIC_KEY_LEN]; - u8 hash[NOISE_HASH_LEN]; - u8 chaining_key[NOISE_HASH_LEN]; down_read(&wg->static_identity.lock); if (unlikely(!wg->static_identity.has_identity)) @@ -442,7 +541,8 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha goto out; /* s */ - if (!message_decrypt(s, src->encrypted_static, sizeof(src->encrypted_static), key, hash)) + if (!message_decrypt(s, src->encrypted_static, + sizeof(src->encrypted_static), key, hash)) goto out; /* Lookup which peer we're actually talking to */ @@ -452,15 +552,21 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha handshake = &peer->handshake; /* ss */ - kdf(chaining_key, key, NULL, handshake->precomputed_static_static, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); + kdf(chaining_key, key, NULL, handshake->precomputed_static_static, + NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, + chaining_key); /* {t} */ - if (!message_decrypt(t, src->encrypted_timestamp, sizeof(src->encrypted_timestamp), key, hash)) + if (!message_decrypt(t, src->encrypted_timestamp, + sizeof(src->encrypted_timestamp), key, hash)) goto out; down_read(&handshake->lock); - replay_attack = memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) <= 0; - flood_attack = handshake->last_initiation_consumption + NSEC_PER_SEC / INITIATIONS_PER_SECOND > ktime_get_boot_fast_ns(); + replay_attack = memcmp(t, handshake->latest_timestamp, + NOISE_TIMESTAMP_LEN) <= 0; + flood_attack = handshake->last_initiation_consumption + + NSEC_PER_SEC / INITIATIONS_PER_SECOND > + ktime_get_boot_fast_ns(); up_read(&handshake->lock); if (replay_attack || flood_attack) goto out; @@ -487,13 +593,14 @@ out: return ret_peer; } -bool noise_handshake_create_response(struct message_handshake_response *dst, struct noise_handshake *handshake) +bool noise_handshake_create_response(struct message_handshake_response *dst, + struct noise_handshake *handshake) { bool ret = false; u8 key[NOISE_SYMMETRIC_KEY_LEN]; - /* We need to wait for crng _before_ taking any locks, since curve25519_generate_secret - * uses get_random_bytes_wait. + /* We need to wait for crng _before_ taking any locks, since + * curve25519_generate_secret uses get_random_bytes_wait. */ wait_for_random_bytes(); @@ -508,25 +615,33 @@ bool noise_handshake_create_response(struct message_handshake_response *dst, str /* e */ curve25519_generate_secret(handshake->ephemeral_private); - if (!curve25519_generate_public(dst->unencrypted_ephemeral, handshake->ephemeral_private)) + if (!curve25519_generate_public(dst->unencrypted_ephemeral, + handshake->ephemeral_private)) goto out; - message_ephemeral(dst->unencrypted_ephemeral, dst->unencrypted_ephemeral, handshake->chaining_key, handshake->hash); + message_ephemeral(dst->unencrypted_ephemeral, + dst->unencrypted_ephemeral, handshake->chaining_key, + handshake->hash); /* ee */ - if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, handshake->remote_ephemeral)) + if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, + handshake->remote_ephemeral)) goto out; /* se */ - if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, handshake->remote_static)) + if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, + handshake->remote_static)) goto out; /* psk */ - mix_psk(handshake->chaining_key, handshake->hash, key, handshake->preshared_key); + mix_psk(handshake->chaining_key, handshake->hash, key, + handshake->preshared_key); /* {} */ message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash); - dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + dst->sender_index = index_hashtable_insert( + &handshake->entry.peer->device->index_hashtable, + &handshake->entry); handshake->state = HANDSHAKE_CREATED_RESPONSE; ret = true; @@ -538,7 +653,9 @@ out: return ret; } -struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg) +struct wireguard_peer * +noise_handshake_consume_response(struct message_handshake_response *src, + struct wireguard_device *wg) { struct noise_handshake *handshake; struct wireguard_peer *peer = NULL, *ret_peer = NULL; @@ -555,7 +672,9 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake if (unlikely(!wg->static_identity.has_identity)) goto out; - handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index, &peer); + handshake = (struct noise_handshake *)index_hashtable_lookup( + &wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, + src->receiver_index, &peer); if (unlikely(!handshake)) goto out; @@ -563,7 +682,8 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake state = handshake->state; memcpy(hash, handshake->hash, NOISE_HASH_LEN); memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN); - memcpy(ephemeral_private, handshake->ephemeral_private, NOISE_PUBLIC_KEY_LEN); + memcpy(ephemeral_private, handshake->ephemeral_private, + NOISE_PUBLIC_KEY_LEN); up_read(&handshake->lock); if (state != HANDSHAKE_CREATED_INITIATION) @@ -584,12 +704,15 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake mix_psk(chaining_key, hash, key, handshake->preshared_key); /* {} */ - if (!message_decrypt(NULL, src->encrypted_nothing, sizeof(src->encrypted_nothing), key, hash)) + if (!message_decrypt(NULL, src->encrypted_nothing, + sizeof(src->encrypted_nothing), key, hash)) goto fail; /* Success! Copy everything to peer */ down_write(&handshake->lock); - /* It's important to check that the state is still the same, while we have an exclusive lock */ + /* It's important to check that the state is still the same, while we + * have an exclusive lock. + */ if (handshake->state != state) { up_write(&handshake->lock); goto fail; @@ -615,32 +738,43 @@ out: return ret_peer; } -bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs) +bool noise_handshake_begin_session(struct noise_handshake *handshake, + struct noise_keypairs *keypairs) { struct noise_keypair *new_keypair; bool ret = false; down_write(&handshake->lock); - if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE) + if (handshake->state != HANDSHAKE_CREATED_RESPONSE && + handshake->state != HANDSHAKE_CONSUMED_RESPONSE) goto out; new_keypair = keypair_create(handshake->entry.peer); if (!new_keypair) goto out; - new_keypair->i_am_the_initiator = handshake->state == HANDSHAKE_CONSUMED_RESPONSE; + new_keypair->i_am_the_initiator = handshake->state == + HANDSHAKE_CONSUMED_RESPONSE; new_keypair->remote_index = handshake->remote_index; if (new_keypair->i_am_the_initiator) - derive_keys(&new_keypair->sending, &new_keypair->receiving, handshake->chaining_key); + derive_keys(&new_keypair->sending, &new_keypair->receiving, + handshake->chaining_key); else - derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key); + derive_keys(&new_keypair->receiving, &new_keypair->sending, + handshake->chaining_key); handshake_zero(handshake); rcu_read_lock_bh(); - if (likely(!container_of(handshake, struct wireguard_peer, handshake)->is_dead)) { + if (likely(!container_of(handshake, struct wireguard_peer, + handshake)->is_dead)) { add_new_keypair(keypairs, new_keypair); - net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id); - ret = index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry); + net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", + handshake->entry.peer->device->dev->name, + new_keypair->internal_id, + handshake->entry.peer->internal_id); + ret = index_hashtable_replace( + &handshake->entry.peer->device->index_hashtable, + &handshake->entry, &new_keypair->entry); } else kzfree(new_keypair); rcu_read_unlock_bh(); diff --git a/src/noise.h b/src/noise.h index be59587..6a563ce 100644 --- a/src/noise.h +++ b/src/noise.h @@ -86,29 +86,44 @@ struct noise_handshake { u8 latest_timestamp[NOISE_TIMESTAMP_LEN]; __le32 remote_index; - /* Protects all members except the immutable (after noise_handshake_init): remote_static, precomputed_static_static, static_identity */ + /* Protects all members except the immutable (after noise_handshake_ + * init): remote_static, precomputed_static_static, static_identity. */ struct rw_semaphore lock; }; struct wireguard_device; void noise_init(void); -bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer); +bool noise_handshake_init(struct noise_handshake *handshake, + struct noise_static_identity *static_identity, + const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], + const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], + struct wireguard_peer *peer); void noise_handshake_clear(struct noise_handshake *handshake); void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now); struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair); void noise_keypairs_clear(struct noise_keypairs *keypairs); -bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair); +bool noise_received_with_keypair(struct noise_keypairs *keypairs, + struct noise_keypair *received_keypair); -void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN]); +void noise_set_static_identity_private_key( + struct noise_static_identity *static_identity, + const u8 private_key[NOISE_PUBLIC_KEY_LEN]); bool noise_precompute_static_static(struct wireguard_peer *peer); -bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake); -struct wireguard_peer *noise_handshake_consume_initiation(struct message_handshake_initiation *src, struct wireguard_device *wg); +bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, + struct noise_handshake *handshake); +struct wireguard_peer * +noise_handshake_consume_initiation(struct message_handshake_initiation *src, + struct wireguard_device *wg); -bool noise_handshake_create_response(struct message_handshake_response *dst, struct noise_handshake *handshake); -struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg); +bool noise_handshake_create_response(struct message_handshake_response *dst, + struct noise_handshake *handshake); +struct wireguard_peer * +noise_handshake_consume_response(struct message_handshake_response *src, + struct wireguard_device *wg); -bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs); +bool noise_handshake_begin_session(struct noise_handshake *handshake, + struct noise_keypairs *keypairs); #endif /* _WG_NOISE_H */ @@ -17,7 +17,10 @@ static atomic64_t peer_counter = ATOMIC64_INIT(0); -struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]) +struct wireguard_peer * +peer_create(struct wireguard_device *wg, + const u8 public_key[NOISE_PUBLIC_KEY_LEN], + const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]) { struct wireguard_peer *peer; @@ -31,11 +34,13 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_ return NULL; peer->device = wg; - if (!noise_handshake_init(&peer->handshake, &wg->static_identity, public_key, preshared_key, peer)) + if (!noise_handshake_init(&peer->handshake, &wg->static_identity, + public_key, preshared_key, peer)) goto err_1; if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)) goto err_1; - if (packet_queue_init(&peer->tx_queue, packet_tx_worker, false, MAX_QUEUED_PACKETS)) + if (packet_queue_init(&peer->tx_queue, packet_tx_worker, false, + MAX_QUEUED_PACKETS)) goto err_2; if (packet_queue_init(&peer->rx_queue, NULL, false, MAX_QUEUED_PACKETS)) goto err_3; @@ -50,7 +55,9 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_ rwlock_init(&peer->endpoint_lock); kref_init(&peer->refcount); skb_queue_head_init(&peer->staged_packet_queue); - atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns() - (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC); + atomic64_set(&peer->last_sent_handshake, + ktime_get_boot_fast_ns() - + (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC); set_bit(NAPI_STATE_NO_BUSY_POLL, &peer->napi.state); netif_napi_add(wg->dev, &peer->napi, packet_rx_poll, NAPI_POLL_WEIGHT); napi_enable(&peer->napi); @@ -71,15 +78,16 @@ err_1: struct wireguard_peer *peer_get_maybe_zero(struct wireguard_peer *peer) { - RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Taking peer reference without holding the RCU read lock"); + RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), + "Taking peer reference without holding the RCU read lock"); if (unlikely(!peer || !kref_get_unless_zero(&peer->refcount))) return NULL; return peer; } -/* We have a separate "remove" function to get rid of the final reference because - * peer_list, clearing handshakes, and flushing all require mutexes which requires - * sleeping, which must only be done from certain contexts. +/* We have a separate "remove" function to get rid of the final reference + * because peer_list, clearing handshakes, and flushing all require mutexes + * which requires sleeping, which must only be done from certain contexts. */ void peer_remove(struct wireguard_peer *peer) { @@ -87,37 +95,49 @@ void peer_remove(struct wireguard_peer *peer) return; lockdep_assert_held(&peer->device->device_update_lock); - /* Remove from configuration-time lookup structures so new packets can't enter. */ + /* Remove from configuration-time lookup structures so new packets + * can't enter. + */ list_del_init(&peer->peer_list); - allowedips_remove_by_peer(&peer->device->peer_allowedips, peer, &peer->device->device_update_lock); + allowedips_remove_by_peer(&peer->device->peer_allowedips, peer, + &peer->device->device_update_lock); pubkey_hashtable_remove(&peer->device->peer_hashtable, peer); /* Mark as dead, so that we don't allow jumping contexts after. */ WRITE_ONCE(peer->is_dead, true); synchronize_rcu_bh(); - /* Now that no more keypairs can be created for this peer, we destroy existing ones. */ + /* Now that no more keypairs can be created for this peer, we destroy + * existing ones. + */ noise_keypairs_clear(&peer->keypairs); - /* Destroy all ongoing timers that were in-flight at the beginning of this function. */ + /* Destroy all ongoing timers that were in-flight at the beginning of + * this function. + */ timers_stop(peer); - /* The transition between packet encryption/decryption queues isn't guarded - * by is_dead, but each reference's life is strictly bounded by two - * generations: once for parallel crypto and once for serial ingestion, - * so we can simply flush twice, and be sure that we no longer have references - * inside these queues. - * - * a) For encrypt/decrypt. */ + /* The transition between packet encryption/decryption queues isn't + * guarded by is_dead, but each reference's life is strictly bounded by + * two generations: once for parallel crypto and once for serial + * ingestion, so we can simply flush twice, and be sure that we no + * longer have references inside these queues. + */ + + /* a) For encrypt/decrypt. */ flush_workqueue(peer->device->packet_crypt_wq); /* b.1) For send (but not receive, since that's napi). */ flush_workqueue(peer->device->packet_crypt_wq); /* b.2.1) For receive (but not send, since that's wq). */ napi_disable(&peer->napi); - /* b.2.1) It's now safe to remove the napi struct, which must be done here from process context. */ + /* b.2.1) It's now safe to remove the napi struct, which must be done + * here from process context. + */ netif_napi_del(&peer->napi); - /* Ensure any workstructs we own (like transmit_handshake_work or clear_peer_work) no longer are in use. */ + /* Ensure any workstructs we own (like transmit_handshake_work or + * clear_peer_work) no longer are in use. + */ flush_workqueue(peer->device->handshake_send_wq); --peer->device->num_peers; @@ -126,7 +146,8 @@ void peer_remove(struct wireguard_peer *peer) static void rcu_release(struct rcu_head *rcu) { - struct wireguard_peer *peer = container_of(rcu, struct wireguard_peer, rcu); + struct wireguard_peer *peer = + container_of(rcu, struct wireguard_peer, rcu); dst_cache_destroy(&peer->endpoint_cache); packet_queue_free(&peer->rx_queue, false); packet_queue_free(&peer->tx_queue, false); @@ -135,11 +156,19 @@ static void rcu_release(struct rcu_head *rcu) static void kref_release(struct kref *refcount) { - struct wireguard_peer *peer = container_of(refcount, struct wireguard_peer, refcount); - pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); - /* Remove ourself from dynamic runtime lookup structures, now that the last reference is gone. */ - index_hashtable_remove(&peer->device->index_hashtable, &peer->handshake.entry); - /* Remove any lingering packets that didn't have a chance to be transmitted. */ + struct wireguard_peer *peer = + container_of(refcount, struct wireguard_peer, refcount); + pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr); + /* Remove ourself from dynamic runtime lookup structures, now that the + * last reference is gone. + */ + index_hashtable_remove(&peer->device->index_hashtable, + &peer->handshake.entry); + /* Remove any lingering packets that didn't have a chance to be + * transmitted. + */ skb_queue_purge(&peer->staged_packet_queue); /* Free the memory used. */ call_rcu_bh(&peer->rcu, rcu_release); @@ -157,6 +186,6 @@ void peer_remove_all(struct wireguard_device *wg) struct wireguard_peer *peer, *temp; lockdep_assert_held(&wg->device_update_lock); - list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) + list_for_each_entry_safe (peer, temp, &wg->peer_list, peer_list) peer_remove(peer); } @@ -27,7 +27,8 @@ struct endpoint { union { struct { struct in_addr src4; - int src_if4; /* Essentially the same as addr6->scope_id */ + /* Essentially the same as addr6->scope_id */ + int src_if4; }; struct in6_addr src6; }; @@ -48,10 +49,13 @@ struct wireguard_peer { struct cookie latest_cookie; struct hlist_node pubkey_hash; u64 rx_bytes, tx_bytes; - struct timer_list timer_retransmit_handshake, timer_send_keepalive, timer_new_handshake, timer_zero_key_material, timer_persistent_keepalive; + struct timer_list timer_retransmit_handshake, timer_send_keepalive; + struct timer_list timer_new_handshake, timer_zero_key_material; + struct timer_list timer_persistent_keepalive; unsigned int timer_handshake_attempts; u16 persistent_keepalive_interval; - bool timers_enabled, timer_need_another_keepalive, sent_lastminute_handshake; + bool timers_enabled, timer_need_another_keepalive; + bool sent_lastminute_handshake; struct timespec walltime_last_handshake; struct kref refcount; struct rcu_head rcu; @@ -61,9 +65,13 @@ struct wireguard_peer { bool is_dead; }; -struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]); +struct wireguard_peer * +peer_create(struct wireguard_device *wg, + const u8 public_key[NOISE_PUBLIC_KEY_LEN], + const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]); -struct wireguard_peer * __must_check peer_get_maybe_zero(struct wireguard_peer *peer); +struct wireguard_peer *__must_check +peer_get_maybe_zero(struct wireguard_peer *peer); static inline struct wireguard_peer *peer_get(struct wireguard_peer *peer) { kref_get(&peer->refcount); @@ -73,6 +81,7 @@ void peer_put(struct wireguard_peer *peer); void peer_remove(struct wireguard_peer *peer); void peer_remove_all(struct wireguard_device *wg); -struct wireguard_peer *peer_lookup_by_index(struct wireguard_device *wg, u32 index); +struct wireguard_peer *peer_lookup_by_index(struct wireguard_device *wg, + u32 index); #endif /* _WG_PEER_H */ diff --git a/src/queueing.c b/src/queueing.c index c8394fc..9ec6588 100644 --- a/src/queueing.c +++ b/src/queueing.c @@ -5,22 +5,25 @@ #include "queueing.h" -struct multicore_worker __percpu *packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr) +struct multicore_worker __percpu * +packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr) { int cpu; - struct multicore_worker __percpu *worker = alloc_percpu(struct multicore_worker); + struct multicore_worker __percpu *worker = + alloc_percpu(struct multicore_worker); if (!worker) return NULL; - for_each_possible_cpu(cpu) { + for_each_possible_cpu (cpu) { per_cpu_ptr(worker, cpu)->ptr = ptr; INIT_WORK(&per_cpu_ptr(worker, cpu)->work, function); } return worker; } -int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool multicore, unsigned int len) +int packet_queue_init(struct crypt_queue *queue, work_func_t function, + bool multicore, unsigned int len) { int ret; @@ -30,7 +33,8 @@ int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool mult return ret; if (function) { if (multicore) { - queue->worker = packet_alloc_percpu_multicore_worker(function, queue); + queue->worker = packet_alloc_percpu_multicore_worker( + function, queue); if (!queue->worker) return -ENOMEM; } else diff --git a/src/queueing.h b/src/queueing.h index 6a1de33..66b7134 100644 --- a/src/queueing.h +++ b/src/queueing.h @@ -19,9 +19,11 @@ struct crypt_queue; struct sk_buff; /* queueing.c APIs: */ -int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool multicore, unsigned int len); +int packet_queue_init(struct crypt_queue *queue, work_func_t function, + bool multicore, unsigned int len); void packet_queue_free(struct crypt_queue *queue, bool multicore); -struct multicore_worker __percpu *packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr); +struct multicore_worker __percpu * +packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr); /* receive.c APIs: */ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb); @@ -32,9 +34,12 @@ int packet_rx_poll(struct napi_struct *napi, int budget); void packet_decrypt_worker(struct work_struct *work); /* send.c APIs: */ -void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool is_retry); +void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, + bool is_retry); void packet_send_handshake_response(struct wireguard_peer *peer); -void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index); +void packet_send_handshake_cookie(struct wireguard_device *wg, + struct sk_buff *initiating_skb, + __le32 sender_index); void packet_send_keepalive(struct wireguard_peer *peer); void packet_send_staged_packets(struct wireguard_peer *peer); /* Workqueue workers: */ @@ -42,7 +47,12 @@ void packet_handshake_send_worker(struct work_struct *work); void packet_tx_worker(struct work_struct *work); void packet_encrypt_worker(struct work_struct *work); -enum packet_state { PACKET_STATE_UNCRYPTED, PACKET_STATE_CRYPTED, PACKET_STATE_DEAD }; +enum packet_state { + PACKET_STATE_UNCRYPTED, + PACKET_STATE_CRYPTED, + PACKET_STATE_DEAD +}; + struct packet_cb { u64 nonce; struct noise_keypair *keypair; @@ -50,15 +60,22 @@ struct packet_cb { u32 mtu; u8 ds; }; + #define PACKET_PEER(skb) (((struct packet_cb *)skb->cb)->keypair->entry.peer) #define PACKET_CB(skb) ((struct packet_cb *)skb->cb) /* Returns either the correct skb->protocol value, or 0 if invalid. */ static inline __be16 skb_examine_untrusted_ip_hdr(struct sk_buff *skb) { - if (skb_network_header(skb) >= skb->head && (skb_network_header(skb) + sizeof(struct iphdr)) <= skb_tail_pointer(skb) && ip_hdr(skb)->version == 4) + if (skb_network_header(skb) >= skb->head && + (skb_network_header(skb) + sizeof(struct iphdr)) <= + skb_tail_pointer(skb) && + ip_hdr(skb)->version == 4) return htons(ETH_P_IP); - if (skb_network_header(skb) >= skb->head && (skb_network_header(skb) + sizeof(struct ipv6hdr)) <= skb_tail_pointer(skb) && ipv6_hdr(skb)->version == 6) + if (skb_network_header(skb) >= skb->head && + (skb_network_header(skb) + sizeof(struct ipv6hdr)) <= + skb_tail_pointer(skb) && + ipv6_hdr(skb)->version == 6) return htons(ETH_P_IPV6); return 0; } @@ -67,7 +84,9 @@ static inline void skb_reset(struct sk_buff *skb) { const int pfmemalloc = skb->pfmemalloc; skb_scrub_packet(skb, true); - memset(&skb->headers_start, 0, offsetof(struct sk_buff, headers_end) - offsetof(struct sk_buff, headers_start)); + memset(&skb->headers_start, 0, + offsetof(struct sk_buff, headers_end) - + offsetof(struct sk_buff, headers_start)); skb->pfmemalloc = pfmemalloc; skb->queue_mapping = 0; skb->nohdr = 0; @@ -89,7 +108,8 @@ static inline int cpumask_choose_online(int *stored_cpu, unsigned int id) { unsigned int cpu = *stored_cpu, cpu_index, i; - if (unlikely(cpu == nr_cpumask_bits || !cpumask_test_cpu(cpu, cpu_online_mask))) { + if (unlikely(cpu == nr_cpumask_bits || + !cpumask_test_cpu(cpu, cpu_online_mask))) { cpu_index = id % cpumask_weight(cpu_online_mask); cpu = cpumask_first(cpu_online_mask); for (i = 0; i < cpu_index; ++i) @@ -103,8 +123,8 @@ static inline int cpumask_choose_online(int *stored_cpu, unsigned int id) * the same CPU twice. A race-free version of this would be to instead store an * atomic sequence number, do an increment-and-return, and then iterate through * every possible CPU until we get to that index -- choose_cpu. However that's - * a bit slower, and it doesn't seem like this potential race actually introduces - * any performance loss, so we live with it. + * a bit slower, and it doesn't seem like this potential race actually + * introduces any performance loss, so we live with it. */ static inline int cpumask_next_online(int *next) { @@ -116,7 +136,9 @@ static inline int cpumask_next_online(int *next) return cpu; } -static inline int queue_enqueue_per_device_and_peer(struct crypt_queue *device_queue, struct crypt_queue *peer_queue, struct sk_buff *skb, struct workqueue_struct *wq, int *next_cpu) +static inline int queue_enqueue_per_device_and_peer( + struct crypt_queue *device_queue, struct crypt_queue *peer_queue, + struct sk_buff *skb, struct workqueue_struct *wq, int *next_cpu) { int cpu; @@ -136,18 +158,24 @@ static inline int queue_enqueue_per_device_and_peer(struct crypt_queue *device_q return 0; } -static inline void queue_enqueue_per_peer(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state) +static inline void queue_enqueue_per_peer(struct crypt_queue *queue, + struct sk_buff *skb, + enum packet_state state) { /* We take a reference, because as soon as we call atomic_set, the * peer can be freed from below us. */ struct wireguard_peer *peer = peer_get(PACKET_PEER(skb)); atomic_set_release(&PACKET_CB(skb)->state, state); - queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &queue->work); + queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, + peer->internal_id), + peer->device->packet_crypt_wq, &queue->work); peer_put(peer); } -static inline void queue_enqueue_per_peer_napi(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state) +static inline void queue_enqueue_per_peer_napi(struct crypt_queue *queue, + struct sk_buff *skb, + enum packet_state state) { /* We take a reference, because as soon as we call atomic_set, the * peer can be freed from below us. diff --git a/src/ratelimiter.c b/src/ratelimiter.c index 638da2b..3966ce8 100644 --- a/src/ratelimiter.c +++ b/src/ratelimiter.c @@ -41,7 +41,8 @@ enum { static void entry_free(struct rcu_head *rcu) { - kmem_cache_free(entry_cache, container_of(rcu, struct ratelimiter_entry, rcu)); + kmem_cache_free(entry_cache, + container_of(rcu, struct ratelimiter_entry, rcu)); atomic_dec(&total_entries); } @@ -54,20 +55,22 @@ static void entry_uninit(struct ratelimiter_entry *entry) /* Calling this function with a NULL work uninits all entries. */ static void gc_entries(struct work_struct *work) { - unsigned int i; + const u64 now = ktime_get_boot_fast_ns(); struct ratelimiter_entry *entry; struct hlist_node *temp; - const u64 now = ktime_get_boot_fast_ns(); + unsigned int i; for (i = 0; i < table_size; ++i) { spin_lock(&table_lock); - hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) { - if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC) + hlist_for_each_entry_safe (entry, temp, &table_v4[i], hash) { + if (unlikely(!work) || + now - entry->last_time_ns > NSEC_PER_SEC) entry_uninit(entry); } #if IS_ENABLED(CONFIG_IPV6) - hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) { - if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC) + hlist_for_each_entry_safe (entry, temp, &table_v6[i], hash) { + if (unlikely(!work) || + now - entry->last_time_ns > NSEC_PER_SEC) entry_uninit(entry); } #endif @@ -81,34 +84,41 @@ static void gc_entries(struct work_struct *work) bool ratelimiter_allow(struct sk_buff *skb, struct net *net) { + struct { __be64 ip; u32 net; } data = + { .net = (unsigned long)net & 0xffffffff }; struct ratelimiter_entry *entry; struct hlist_head *bucket; - struct { __be64 ip; u32 net; } data = { .net = (unsigned long)net & 0xffffffff }; if (skb->protocol == htons(ETH_P_IP)) { data.ip = (__force __be64)ip_hdr(skb)->saddr; - bucket = &table_v4[hsiphash(&data, sizeof(u32) * 3, &key) & (table_size - 1)]; + bucket = &table_v4[hsiphash(&data, sizeof(u32) * 3, &key) & + (table_size - 1)]; } #if IS_ENABLED(CONFIG_IPV6) else if (skb->protocol == htons(ETH_P_IPV6)) { - memcpy(&data.ip, &ipv6_hdr(skb)->saddr, sizeof(__be64)); /* Only 64 bits */ - bucket = &table_v6[hsiphash(&data, sizeof(u32) * 3, &key) & (table_size - 1)]; + memcpy(&data.ip, &ipv6_hdr(skb)->saddr, + sizeof(__be64)); /* Only 64 bits */ + bucket = &table_v6[hsiphash(&data, sizeof(u32) * 3, &key) & + (table_size - 1)]; } #endif else return false; rcu_read_lock(); - hlist_for_each_entry_rcu(entry, bucket, hash) { + hlist_for_each_entry_rcu (entry, bucket, hash) { if (entry->net == net && entry->ip == data.ip) { u64 now, tokens; bool ret; - /* Inspired by nft_limit.c, but this is actually a slightly different - * algorithm. Namely, we incorporate the burst as part of the maximum - * tokens, rather than as part of the rate. + /* Quasi-inspired by nft_limit.c, but this is actually a + * slightly different algorithm. Namely, we incorporate + * the burst as part of the maximum tokens, rather than + * as part of the rate. */ spin_lock(&entry->lock); now = ktime_get_boot_fast_ns(); - tokens = min_t(u64, TOKEN_MAX, entry->tokens + now - entry->last_time_ns); + tokens = min_t(u64, TOKEN_MAX, + entry->tokens + now - + entry->last_time_ns); entry->last_time_ns = now; ret = tokens >= PACKET_COST; entry->tokens = ret ? tokens - PACKET_COST : tokens; @@ -157,7 +167,10 @@ int ratelimiter_init(void) * we borrow their wisdom about good table sizes on different systems * dependent on RAM. This calculation here comes from there. */ - table_size = (totalram_pages > (1U << 30) / PAGE_SIZE) ? 8192 : max_t(unsigned long, 16, roundup_pow_of_two((totalram_pages << PAGE_SHIFT) / (1U << 14) / sizeof(struct hlist_head))); + table_size = (totalram_pages > (1U << 30) / PAGE_SIZE) ? 8192 : + max_t(unsigned long, 16, roundup_pow_of_two( + (totalram_pages << PAGE_SHIFT) / + (1U << 14) / sizeof(struct hlist_head))); max_entries = table_size * 8; table_v4 = kvzalloc(table_size * sizeof(struct hlist_head), GFP_KERNEL); diff --git a/src/receive.c b/src/receive.c index 4e73da1..a2d1d9d 100644 --- a/src/receive.c +++ b/src/receive.c @@ -20,7 +20,8 @@ /* Must be called with bh disabled. */ static inline void rx_stats(struct wireguard_peer *peer, size_t len) { - struct pcpu_sw_netstats *tstats = get_cpu_ptr(peer->device->dev->tstats); + struct pcpu_sw_netstats *tstats = + get_cpu_ptr(peer->device->dev->tstats); u64_stats_update_begin(&tstats->syncp); ++tstats->rx_packets; @@ -36,38 +37,57 @@ static inline size_t validate_header_len(struct sk_buff *skb) { if (unlikely(skb->len < sizeof(struct message_header))) return 0; - if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) && skb->len >= MESSAGE_MINIMUM_LENGTH) + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) && + skb->len >= MESSAGE_MINIMUM_LENGTH) return sizeof(struct message_data); - if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && skb->len == sizeof(struct message_handshake_initiation)) + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && + skb->len == sizeof(struct message_handshake_initiation)) return sizeof(struct message_handshake_initiation); - if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && skb->len == sizeof(struct message_handshake_response)) + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && + skb->len == sizeof(struct message_handshake_response)) return sizeof(struct message_handshake_response); - if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && skb->len == sizeof(struct message_handshake_cookie)) + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && + skb->len == sizeof(struct message_handshake_cookie)) return sizeof(struct message_handshake_cookie); return 0; } -static inline int skb_prepare_header(struct sk_buff *skb, struct wireguard_device *wg) +static inline int skb_prepare_header(struct sk_buff *skb, + struct wireguard_device *wg) { - struct udphdr *udp; size_t data_offset, data_len, header_len; + struct udphdr *udp; - if (unlikely(skb_examine_untrusted_ip_hdr(skb) != skb->protocol || skb_transport_header(skb) < skb->head || (skb_transport_header(skb) + sizeof(struct udphdr)) > skb_tail_pointer(skb))) + if (unlikely(skb_examine_untrusted_ip_hdr(skb) != skb->protocol || + skb_transport_header(skb) < skb->head || + (skb_transport_header(skb) + sizeof(struct udphdr)) > + skb_tail_pointer(skb))) return -EINVAL; /* Bogus IP header */ udp = udp_hdr(skb); data_offset = (u8 *)udp - skb->data; - if (unlikely(data_offset > U16_MAX || data_offset + sizeof(struct udphdr) > skb->len)) - return -EINVAL; /* Packet has offset at impossible location or isn't big enough to have UDP fields */ + if (unlikely(data_offset > U16_MAX || + data_offset + sizeof(struct udphdr) > skb->len)) + /* Packet has offset at impossible location or isn't big enough + * to have UDP fields. + */ + return -EINVAL; data_len = ntohs(udp->len); - if (unlikely(data_len < sizeof(struct udphdr) || data_len > skb->len - data_offset)) - return -EINVAL; /* UDP packet is reporting too small of a size or lying about its size */ + if (unlikely(data_len < sizeof(struct udphdr) || + data_len > skb->len - data_offset)) + /* UDP packet is reporting too small of a size or lying about + * its size. + */ + return -EINVAL; data_len -= sizeof(struct udphdr); data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data; - if (unlikely(!pskb_may_pull(skb, data_offset + sizeof(struct message_header)) || pskb_trim(skb, data_len + data_offset) < 0)) + if (unlikely(!pskb_may_pull(skb, + data_offset + sizeof(struct message_header)) || + pskb_trim(skb, data_len + data_offset) < 0)) return -EINVAL; skb_pull(skb, data_offset); if (unlikely(skb->len != data_len)) - return -EINVAL; /* Final len does not agree with calculated len */ + /* Final len does not agree with calculated len */ + return -EINVAL; header_len = validate_header_len(skb); if (unlikely(!header_len)) return -EINVAL; @@ -78,74 +98,96 @@ static inline int skb_prepare_header(struct sk_buff *skb, struct wireguard_devic return 0; } -static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff *skb) +static void receive_handshake_packet(struct wireguard_device *wg, + struct sk_buff *skb) { - static u64 last_under_load; /* Yes this is global, so that our load calculation applies to the whole system. */ struct wireguard_peer *peer = NULL; - bool under_load; enum cookie_mac_state mac_state; + /* This is global, so that our load calculation applies to + * the whole system. + */ + static u64 last_under_load; bool packet_needs_cookie; + bool under_load; if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE)) { - net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n", wg->dev->name, skb); - cookie_message_consume((struct message_handshake_cookie *)skb->data, wg); + net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n", + wg->dev->name, skb); + cookie_message_consume( + (struct message_handshake_cookie *)skb->data, wg); return; } - under_load = skb_queue_len(&wg->incoming_handshakes) >= MAX_QUEUED_INCOMING_HANDSHAKES / 8; + under_load = skb_queue_len(&wg->incoming_handshakes) >= + MAX_QUEUED_INCOMING_HANDSHAKES / 8; if (under_load) last_under_load = ktime_get_boot_fast_ns(); else if (last_under_load) under_load = !has_expired(last_under_load, 1); - mac_state = cookie_validate_packet(&wg->cookie_checker, skb, under_load); - if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) + mac_state = cookie_validate_packet(&wg->cookie_checker, skb, + under_load); + if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) || + (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) packet_needs_cookie = false; else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE) packet_needs_cookie = true; else { - net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n", wg->dev->name, skb); + net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n", + wg->dev->name, skb); return; } switch (SKB_TYPE_LE32(skb)) { case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): { - struct message_handshake_initiation *message = (struct message_handshake_initiation *)skb->data; + struct message_handshake_initiation *message = + (struct message_handshake_initiation *)skb->data; if (packet_needs_cookie) { - packet_send_handshake_cookie(wg, skb, message->sender_index); + packet_send_handshake_cookie(wg, skb, + message->sender_index); return; } peer = noise_handshake_consume_initiation(message, wg); if (unlikely(!peer)) { - net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n", wg->dev->name, skb); + net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n", + wg->dev->name, skb); return; } socket_set_peer_endpoint_from_skb(peer, skb); - net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n", wg->dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n", + wg->dev->name, peer->internal_id, + &peer->endpoint.addr); packet_send_handshake_response(peer); break; } case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): { - struct message_handshake_response *message = (struct message_handshake_response *)skb->data; + struct message_handshake_response *message = + (struct message_handshake_response *)skb->data; if (packet_needs_cookie) { - packet_send_handshake_cookie(wg, skb, message->sender_index); + packet_send_handshake_cookie(wg, skb, + message->sender_index); return; } peer = noise_handshake_consume_response(message, wg); if (unlikely(!peer)) { - net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n", wg->dev->name, skb); + net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n", + wg->dev->name, skb); return; } socket_set_peer_endpoint_from_skb(peer, skb); - net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n", wg->dev->name, peer->internal_id, &peer->endpoint.addr); - if (noise_handshake_begin_session(&peer->handshake, &peer->keypairs)) { + net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n", + wg->dev->name, peer->internal_id, + &peer->endpoint.addr); + if (noise_handshake_begin_session(&peer->handshake, + &peer->keypairs)) { timers_session_derived(peer); timers_handshake_complete(peer); - /* Calling this function will either send any existing packets in the queue - * and not send a keepalive, which is the best case, Or, if there's nothing - * in the queue, it will send a keepalive, in order to give immediate - * confirmation of the session. + /* Calling this function will either send any existing + * packets in the queue and not send a keepalive, which + * is the best case, Or, if there's nothing in the + * queue, it will send a keepalive, in order to give + * immediate confirmation of the session. */ packet_send_keepalive(peer); } @@ -169,7 +211,8 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff void packet_handshake_receive_worker(struct work_struct *work) { - struct wireguard_device *wg = container_of(work, struct multicore_worker, work)->ptr; + struct wireguard_device *wg = + container_of(work, struct multicore_worker, work)->ptr; struct sk_buff *skb; while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) { @@ -189,8 +232,10 @@ static inline void keep_key_fresh(struct wireguard_peer *peer) rcu_read_lock_bh(); keypair = rcu_dereference_bh(peer->keypairs.current_keypair); - if (likely(keypair && keypair->sending.is_valid) && keypair->i_am_the_initiator && - unlikely(has_expired(keypair->sending.birthdate, REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT))) + if (likely(keypair && keypair->sending.is_valid) && + keypair->i_am_the_initiator && + unlikely(has_expired(keypair->sending.birthdate, + REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT))) send = true; rcu_read_unlock_bh(); @@ -200,7 +245,9 @@ static inline void keep_key_fresh(struct wireguard_peer *peer) } } -static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *key, simd_context_t simd_context) +static inline bool skb_decrypt(struct sk_buff *skb, + struct noise_symmetric_key *key, + simd_context_t simd_context) { struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1]; struct sk_buff *trailer; @@ -210,12 +257,15 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key * if (unlikely(!key)) return false; - if (unlikely(!key->is_valid || has_expired(key->birthdate, REJECT_AFTER_TIME) || key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) { + if (unlikely(!key->is_valid || + has_expired(key->birthdate, REJECT_AFTER_TIME) || + key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) { key->is_valid = false; return false; } - PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter); + PACKET_CB(skb)->nonce = + le64_to_cpu(((struct message_data *)skb->data)->counter); /* We ensure that the network header is part of the packet before we * call skb_cow_data, so that there's no chance that data is removed @@ -233,7 +283,9 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key * if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0) return false; - if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key, simd_context)) + if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, + PACKET_CB(skb)->nonce, key->key, + simd_context)) return false; /* Another ugly situation of pushing and pulling the header so as to @@ -248,33 +300,39 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key * } /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */ -static inline bool counter_validate(union noise_counter *counter, u64 their_counter) +static inline bool counter_validate(union noise_counter *counter, + u64 their_counter) { - bool ret = false; unsigned long index, index_current, top, i; + bool ret = false; spin_lock_bh(&counter->receive.lock); - if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || their_counter >= REJECT_AFTER_MESSAGES)) + if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || + their_counter >= REJECT_AFTER_MESSAGES)) goto out; ++their_counter; - if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < counter->receive.counter)) + if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < + counter->receive.counter)) goto out; index = their_counter >> ilog2(BITS_PER_LONG); if (likely(their_counter > counter->receive.counter)) { index_current = counter->receive.counter >> ilog2(BITS_PER_LONG); - top = min_t(unsigned long, index - index_current, COUNTER_BITS_TOTAL / BITS_PER_LONG); + top = min_t(unsigned long, index - index_current, + COUNTER_BITS_TOTAL / BITS_PER_LONG); for (i = 1; i <= top; ++i) - counter->receive.backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0; + counter->receive.backtrack[(i + index_current) & + ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0; counter->receive.counter = their_counter; } index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1; - ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), &counter->receive.backtrack[index]); + ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), + &counter->receive.backtrack[index]); out: spin_unlock_bh(&counter->receive.lock); @@ -282,15 +340,18 @@ out: } #include "selftest/counter.h" -static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff *skb, struct endpoint *endpoint) +static void packet_consume_data_done(struct wireguard_peer *peer, + struct sk_buff *skb, + struct endpoint *endpoint) { - struct wireguard_peer *routed_peer; struct net_device *dev = peer->device->dev; + struct wireguard_peer *routed_peer; unsigned int len, len_before_trim; socket_set_peer_endpoint(peer, endpoint); - if (unlikely(noise_received_with_keypair(&peer->keypairs, PACKET_CB(skb)->keypair))) { + if (unlikely(noise_received_with_keypair(&peer->keypairs, + PACKET_CB(skb)->keypair))) { timers_handshake_complete(peer); packet_send_staged_packets(peer); } @@ -303,7 +364,9 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff /* A packet with length 0 is a keepalive packet */ if (unlikely(!skb->len)) { rx_stats(peer, message_data_len(0)); - net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n", + dev->name, peer->internal_id, + &peer->endpoint.addr); goto packet_processed; } @@ -311,7 +374,10 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff if (unlikely(skb_network_header(skb) < skb->head)) goto dishonest_packet_size; - if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) && (ip_hdr(skb)->version == 4 || (ip_hdr(skb)->version == 6 && pskb_network_may_pull(skb, sizeof(struct ipv6hdr))))))) + if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) && + (ip_hdr(skb)->version == 4 || + (ip_hdr(skb)->version == 6 && + pskb_network_may_pull(skb, sizeof(struct ipv6hdr))))))) goto dishonest_packet_type; skb->dev = dev; @@ -324,7 +390,8 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff if (INET_ECN_is_ce(PACKET_CB(skb)->ds)) IP_ECN_set_ce(ip_hdr(skb)); } else if (skb->protocol == htons(ETH_P_IPV6)) { - len = ntohs(ipv6_hdr(skb)->payload_len) + sizeof(struct ipv6hdr); + len = ntohs(ipv6_hdr(skb)->payload_len) + + sizeof(struct ipv6hdr); if (INET_ECN_is_ce(PACKET_CB(skb)->ds)) IP6_ECN_set_ce(skb, ipv6_hdr(skb)); } else @@ -344,23 +411,29 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff if (unlikely(napi_gro_receive(&peer->napi, skb) == GRO_DROP)) { ++dev->stats.rx_dropped; - net_dbg_ratelimited("%s: Failed to give packet to userspace from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Failed to give packet to userspace from peer %llu (%pISpfsc)\n", + dev->name, peer->internal_id, + &peer->endpoint.addr); } else rx_stats(peer, message_data_len(len_before_trim)); return; dishonest_packet_peer: - net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n", dev->name, skb, peer->internal_id, &peer->endpoint.addr); + net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n", + dev->name, skb, peer->internal_id, + &peer->endpoint.addr); ++dev->stats.rx_errors; ++dev->stats.rx_frame_errors; goto packet_processed; dishonest_packet_type: - net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n", + dev->name, peer->internal_id, &peer->endpoint.addr); ++dev->stats.rx_errors; ++dev->stats.rx_frame_errors; goto packet_processed; dishonest_packet_size: - net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n", + dev->name, peer->internal_id, &peer->endpoint.addr); ++dev->stats.rx_errors; ++dev->stats.rx_length_errors; goto packet_processed; @@ -370,19 +443,22 @@ packet_processed: int packet_rx_poll(struct napi_struct *napi, int budget) { - struct wireguard_peer *peer = container_of(napi, struct wireguard_peer, napi); + struct wireguard_peer *peer = + container_of(napi, struct wireguard_peer, napi); struct crypt_queue *queue = &peer->rx_queue; struct noise_keypair *keypair; - struct sk_buff *skb; struct endpoint endpoint; enum packet_state state; + struct sk_buff *skb; int work_done = 0; bool free; if (unlikely(budget <= 0)) return 0; - while ((skb = __ptr_ring_peek(&queue->ring)) != NULL && (state = atomic_read_acquire(&PACKET_CB(skb)->state)) != PACKET_STATE_UNCRYPTED) { + while ((skb = __ptr_ring_peek(&queue->ring)) != NULL && + (state = atomic_read_acquire(&PACKET_CB(skb)->state)) != + PACKET_STATE_UNCRYPTED) { __ptr_ring_discard_one(&queue->ring); peer = PACKET_PEER(skb); keypair = PACKET_CB(skb)->keypair; @@ -391,8 +467,12 @@ int packet_rx_poll(struct napi_struct *napi, int budget) if (unlikely(state != PACKET_STATE_CRYPTED)) goto next; - if (unlikely(!counter_validate(&keypair->receiving.counter, PACKET_CB(skb)->nonce))) { - net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", peer->device->dev->name, PACKET_CB(skb)->nonce, keypair->receiving.counter.receive.counter); + if (unlikely(!counter_validate(&keypair->receiving.counter, + PACKET_CB(skb)->nonce))) { + net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", + peer->device->dev->name, + PACKET_CB(skb)->nonce, + keypair->receiving.counter.receive.counter); goto next; } @@ -403,7 +483,7 @@ int packet_rx_poll(struct napi_struct *napi, int budget) packet_consume_data_done(peer, skb, &endpoint); free = false; -next: + next: noise_keypair_put(keypair, false); peer_put(peer); if (unlikely(free)) @@ -421,34 +501,46 @@ next: void packet_decrypt_worker(struct work_struct *work) { - struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr; - struct sk_buff *skb; + struct crypt_queue *queue = + container_of(work, struct multicore_worker, work)->ptr; simd_context_t simd_context = simd_get(); + struct sk_buff *skb; while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) { - enum packet_state state = likely(skb_decrypt(skb, &PACKET_CB(skb)->keypair->receiving, simd_context)) ? PACKET_STATE_CRYPTED : PACKET_STATE_DEAD; - queue_enqueue_per_peer_napi(&PACKET_PEER(skb)->rx_queue, skb, state); + enum packet_state state = likely(skb_decrypt(skb, + &PACKET_CB(skb)->keypair->receiving, + simd_context)) ? + PACKET_STATE_CRYPTED : PACKET_STATE_DEAD; + queue_enqueue_per_peer_napi(&PACKET_PEER(skb)->rx_queue, skb, + state); simd_context = simd_relax(simd_context); } simd_put(simd_context); } -static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb) +static void packet_consume_data(struct wireguard_device *wg, + struct sk_buff *skb) { - struct wireguard_peer *peer = NULL; __le32 idx = ((struct message_data *)skb->data)->key_idx; + struct wireguard_peer *peer = NULL; int ret; rcu_read_lock_bh(); - PACKET_CB(skb)->keypair = (struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx, &peer); + PACKET_CB(skb)->keypair = + (struct noise_keypair *)index_hashtable_lookup( + &wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx, + &peer); if (unlikely(!noise_keypair_get(PACKET_CB(skb)->keypair))) goto err_keypair; if (unlikely(peer->is_dead)) goto err; - ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu); + ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, + &peer->rx_queue, skb, + wg->packet_crypt_wq, + &wg->decrypt_queue.last_cpu); if (unlikely(ret == -EPIPE)) queue_enqueue_per_peer(&peer->rx_queue, skb, PACKET_STATE_DEAD); if (likely(!ret || ret == -EPIPE)) { @@ -473,14 +565,20 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): { int cpu; - if (skb_queue_len(&wg->incoming_handshakes) > MAX_QUEUED_INCOMING_HANDSHAKES || unlikely(!rng_is_initialized())) { - net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n", wg->dev->name, skb); + if (skb_queue_len(&wg->incoming_handshakes) > + MAX_QUEUED_INCOMING_HANDSHAKES || + unlikely(!rng_is_initialized())) { + net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n", + wg->dev->name, skb); goto err; } skb_queue_tail(&wg->incoming_handshakes, skb); - /* Queues up a call to packet_process_queued_handshake_packets(skb): */ + /* Queues up a call to packet_process_queued_handshake_ + * packets(skb): + */ cpu = cpumask_next_online(&wg->incoming_handshake_cpu); - queue_work_on(cpu, wg->handshake_receive_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work); + queue_work_on(cpu, wg->handshake_receive_wq, + &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work); break; } case cpu_to_le32(MESSAGE_DATA): @@ -488,7 +586,8 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) packet_consume_data(wg, skb); break; default: - net_dbg_skb_ratelimited("%s: Invalid packet from %pISpfsc\n", wg->dev->name, skb); + net_dbg_skb_ratelimited("%s: Invalid packet from %pISpfsc\n", + wg->dev->name, skb); goto err; } return; diff --git a/src/selftest/allowedips.h b/src/selftest/allowedips.h index 1cf8c8f..28461d7 100644 --- a/src/selftest/allowedips.h +++ b/src/selftest/allowedips.h @@ -8,7 +8,8 @@ #ifdef DEBUG_PRINT_TRIE_GRAPHVIZ #include <linux/siphash.h> -static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, u8 cidr) +static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, + u8 cidr) { swap_endian(dst, src, bits); memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8); @@ -18,34 +19,44 @@ static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, u static __init void print_node(struct allowedips_node *node, u8 bits) { - u32 color = 0; - char *style = "dotted"; char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n"; - char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; + char *fmt_declaration = KERN_DEBUG + "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; + char *style = "dotted"; u8 ip1[16], ip2[16]; + u32 color = 0; + if (bits == 32) { fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n"; - fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; + fmt_declaration = KERN_DEBUG + "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; } else if (bits == 128) { fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n"; - fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; + fmt_declaration = KERN_DEBUG + "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; } if (node->peer) { hsiphash_key_t key = { 0 }; memcpy(&key, &node->peer, sizeof(node->peer)); - color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 | hsiphash_1u32(0xbabecafe, &key) % 200 << 8 | hsiphash_1u32(0xabad1dea, &key) % 200; + color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 | + hsiphash_1u32(0xbabecafe, &key) % 200 << 8 | + hsiphash_1u32(0xabad1dea, &key) % 200; style = "bold"; } swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr); printk(fmt_declaration, ip1, node->cidr, style, color); if (node->bit[0]) { - swap_endian_and_apply_cidr(ip2, node->bit[0]->bits, bits, node->cidr); - printk(fmt_connection, ip1, node->cidr, ip2, node->bit[0]->cidr); + swap_endian_and_apply_cidr(ip2, node->bit[0]->bits, bits, + node->cidr); + printk(fmt_connection, ip1, node->cidr, ip2, + node->bit[0]->cidr); print_node(node->bit[0], bits); } if (node->bit[1]) { - swap_endian_and_apply_cidr(ip2, node->bit[1]->bits, bits, node->cidr); - printk(fmt_connection, ip1, node->cidr, ip2, node->bit[1]->cidr); + swap_endian_and_apply_cidr(ip2, node->bit[1]->bits, bits, + node->cidr); + printk(fmt_connection, ip1, node->cidr, ip2, + node->bit[1]->cidr); print_node(node->bit[1], bits); } } @@ -79,9 +90,10 @@ static __init void horrible_allowedips_init(struct horrible_allowedips *table) } static __init void horrible_allowedips_free(struct horrible_allowedips *table) { - struct hlist_node *h; struct horrible_allowedips_node *node; - hlist_for_each_entry_safe(node, h, &table->head, table) { + struct hlist_node *h; + + hlist_for_each_entry_safe (node, h, &table->head, table) { hlist_del(&node->table); kfree(node); } @@ -89,20 +101,21 @@ static __init void horrible_allowedips_free(struct horrible_allowedips *table) static __init inline union nf_inet_addr horrible_cidr_to_mask(uint8_t cidr) { union nf_inet_addr mask; + memset(&mask, 0x00, 128 / 8); memset(&mask, 0xff, cidr / 8); if (cidr % 32) - mask.all[cidr / 32] = htonl((0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); + mask.all[cidr / 32] = htonl( + (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); return mask; } static __init inline uint8_t horrible_mask_to_cidr(union nf_inet_addr subnet) { - return hweight32(subnet.all[0]) - + hweight32(subnet.all[1]) - + hweight32(subnet.all[2]) - + hweight32(subnet.all[3]); + return hweight32(subnet.all[0]) + hweight32(subnet.all[1]) + + hweight32(subnet.all[2]) + hweight32(subnet.all[3]); } -static __init inline void horrible_mask_self(struct horrible_allowedips_node *node) +static __init inline void +horrible_mask_self(struct horrible_allowedips_node *node) { if (node->ip_version == 4) node->ip.ip &= node->mask.ip; @@ -113,24 +126,36 @@ static __init inline void horrible_mask_self(struct horrible_allowedips_node *no node->ip.ip6[3] &= node->mask.ip6[3]; } } -static __init inline bool horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip) +static __init inline bool +horrible_match_v4(const struct horrible_allowedips_node *node, + struct in_addr *ip) { return (ip->s_addr & node->mask.ip) == node->ip.ip; } -static __init inline bool horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip) +static __init inline bool +horrible_match_v6(const struct horrible_allowedips_node *node, + struct in6_addr *ip) { - return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] && - (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] && - (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] && - (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3]; + return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == + node->ip.ip6[0] && + (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == + node->ip.ip6[1] && + (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == + node->ip.ip6[2] && + (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3]; } -static __init void horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node) +static __init void +horrible_insert_ordered(struct horrible_allowedips *table, + struct horrible_allowedips_node *node) { struct horrible_allowedips_node *other = NULL, *where = NULL; uint8_t my_cidr = horrible_mask_to_cidr(node->mask); - hlist_for_each_entry(other, &table->head, table) { - if (!memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) && - !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr)) && + + hlist_for_each_entry (other, &table->head, table) { + if (!memcmp(&other->mask, &node->mask, + sizeof(union nf_inet_addr)) && + !memcmp(&other->ip, &node->ip, + sizeof(union nf_inet_addr)) && other->ip_version == node->ip_version) { other->value = node->value; kfree(node); @@ -147,9 +172,13 @@ static __init void horrible_insert_ordered(struct horrible_allowedips *table, st else hlist_add_before(&node->table, &where->table); } -static __init int horrible_allowedips_insert_v4(struct horrible_allowedips *table, struct in_addr *ip, uint8_t cidr, void *value) +static __init int +horrible_allowedips_insert_v4(struct horrible_allowedips *table, + struct in_addr *ip, uint8_t cidr, void *value) { - struct horrible_allowedips_node *node = kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL); + struct horrible_allowedips_node *node = + kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL); + if (!node) return -ENOMEM; node->ip.in = *ip; @@ -160,9 +189,13 @@ static __init int horrible_allowedips_insert_v4(struct horrible_allowedips *tabl horrible_insert_ordered(table, node); return 0; } -static __init int horrible_allowedips_insert_v6(struct horrible_allowedips *table, struct in6_addr *ip, uint8_t cidr, void *value) +static __init int +horrible_allowedips_insert_v6(struct horrible_allowedips *table, + struct in6_addr *ip, uint8_t cidr, void *value) { - struct horrible_allowedips_node *node = kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL); + struct horrible_allowedips_node *node = + kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL); + if (!node) return -ENOMEM; node->ip.in6 = *ip; @@ -173,11 +206,14 @@ static __init int horrible_allowedips_insert_v6(struct horrible_allowedips *tabl horrible_insert_ordered(table, node); return 0; } -static __init void *horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip) +static __init void * +horrible_allowedips_lookup_v4(struct horrible_allowedips *table, + struct in_addr *ip) { struct horrible_allowedips_node *node; void *ret = NULL; - hlist_for_each_entry(node, &table->head, table) { + + hlist_for_each_entry (node, &table->head, table) { if (node->ip_version != 4) continue; if (horrible_match_v4(node, ip)) { @@ -187,11 +223,14 @@ static __init void *horrible_allowedips_lookup_v4(struct horrible_allowedips *ta } return ret; } -static __init void *horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip) +static __init void * +horrible_allowedips_lookup_v6(struct horrible_allowedips *table, + struct in6_addr *ip) { struct horrible_allowedips_node *node; void *ret = NULL; - hlist_for_each_entry(node, &table->head, table) { + + hlist_for_each_entry (node, &table->head, table) { if (node->ip_version != 6) continue; if (horrible_match_v6(node, ip)) { @@ -204,13 +243,13 @@ static __init void *horrible_allowedips_lookup_v6(struct horrible_allowedips *ta static __init bool randomized_test(void) { - DEFINE_MUTEX(mutex); - bool ret = false; unsigned int i, j, k, mutate_amount, cidr; + u8 ip[16], mutate_mask[16], mutated[16]; struct wireguard_peer **peers, *peer; - struct allowedips t; struct horrible_allowedips h; - u8 ip[16], mutate_mask[16], mutated[16]; + DEFINE_MUTEX(mutex); + struct allowedips t; + bool ret = false; mutex_init(&mutex); @@ -237,11 +276,13 @@ static __init bool randomized_test(void) prandom_bytes(ip, 4); cidr = prandom_u32_max(32) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (allowedips_insert_v4(&t, (struct in_addr *)ip, cidr, peer, &mutex) < 0) { + if (allowedips_insert_v4(&t, (struct in_addr *)ip, cidr, peer, + &mutex) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } - if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip, cidr, peer) < 0) { + if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip, + cidr, peer) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } @@ -251,18 +292,23 @@ static __init bool randomized_test(void) mutate_amount = prandom_u32_max(32); for (k = 0; k < mutate_amount / 8; ++k) mutate_mask[k] = 0xff; - mutate_mask[k] = 0xff << ((8 - (mutate_amount % 8)) % 8); + mutate_mask[k] = 0xff + << ((8 - (mutate_amount % 8)) % 8); for (; k < 4; ++k) mutate_mask[k] = 0; for (k = 0; k < 4; ++k) - mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & prandom_u32_max(256)); + mutated[k] = (mutated[k] & mutate_mask[k]) | + (~mutate_mask[k] & + prandom_u32_max(256)); cidr = prandom_u32_max(32) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (allowedips_insert_v4(&t, (struct in_addr *)mutated, cidr, peer, &mutex) < 0) { + if (allowedips_insert_v4(&t, (struct in_addr *)mutated, + cidr, peer, &mutex) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } - if (horrible_allowedips_insert_v4(&h, (struct in_addr *)mutated, cidr, peer)) { + if (horrible_allowedips_insert_v4(&h, + (struct in_addr *)mutated, cidr, peer)) { pr_info("allowedips random self-test: out of memory\n"); goto free; } @@ -273,11 +319,13 @@ static __init bool randomized_test(void) prandom_bytes(ip, 16); cidr = prandom_u32_max(128) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr, peer, &mutex) < 0) { + if (allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr, peer, + &mutex) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } - if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip, cidr, peer) < 0) { + if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip, + cidr, peer) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } @@ -287,18 +335,24 @@ static __init bool randomized_test(void) mutate_amount = prandom_u32_max(128); for (k = 0; k < mutate_amount / 8; ++k) mutate_mask[k] = 0xff; - mutate_mask[k] = 0xff << ((8 - (mutate_amount % 8)) % 8); + mutate_mask[k] = 0xff + << ((8 - (mutate_amount % 8)) % 8); for (; k < 4; ++k) mutate_mask[k] = 0; for (k = 0; k < 4; ++k) - mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & prandom_u32_max(256)); + mutated[k] = (mutated[k] & mutate_mask[k]) | + (~mutate_mask[k] & + prandom_u32_max(256)); cidr = prandom_u32_max(128) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (allowedips_insert_v6(&t, (struct in6_addr *)mutated, cidr, peer, &mutex) < 0) { + if (allowedips_insert_v6(&t, (struct in6_addr *)mutated, + cidr, peer, &mutex) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } - if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)mutated, cidr, peer)) { + if (horrible_allowedips_insert_v6( + &h, (struct in6_addr *)mutated, cidr, + peer)) { pr_info("allowedips random self-test: out of memory\n"); goto free; } @@ -314,7 +368,8 @@ static __init bool randomized_test(void) for (i = 0; i < NUM_QUERIES; ++i) { prandom_bytes(ip, 4); - if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { + if (lookup(t.root4, 32, ip) != + horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { pr_info("allowedips random self-test: FAIL\n"); goto free; } @@ -322,7 +377,8 @@ static __init bool randomized_test(void) for (i = 0; i < NUM_QUERIES; ++i) { prandom_bytes(ip, 16); - if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { + if (lookup(t.root6, 128, ip) != + horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { pr_info("allowedips random self-test: FAIL\n"); goto free; } @@ -376,15 +432,22 @@ static __init int walk_callback(void *ctx, const u8 *ip, u8 cidr, int family) wctx->count++; - if (cidr == 27 && !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr))) + if (cidr == 27 && + !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr))) wctx->found_a = true; - else if (cidr == 128 && !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543), sizeof(struct in6_addr))) + else if (cidr == 128 && + !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543), + sizeof(struct in6_addr))) wctx->found_b = true; - else if (cidr == 29 && !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr))) + else if (cidr == 29 && + !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr))) wctx->found_c = true; - else if (cidr == 83 && !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0), sizeof(struct in6_addr))) + else if (cidr == 83 && + !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0), + sizeof(struct in6_addr))) wctx->found_d = true; - else if (cidr == 21 && !memcmp(ip, ip6(0x26075000, 0, 0, 0), sizeof(struct in6_addr))) + else if (cidr == 21 && + !memcmp(ip, ip6(0x26075000, 0, 0, 0), sizeof(struct in6_addr))) wctx->found_e = true; else wctx->found_other = true; @@ -392,50 +455,55 @@ static __init int walk_callback(void *ctx, const u8 *ip, u8 cidr, int family) return 0; } -#define init_peer(name) do { \ - name = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL); \ - if (!name) { \ - pr_info("allowedips self-test: out of memory\n"); \ - goto free; \ - } \ - kref_init(&name->refcount); \ -} while (0) - -#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \ - allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem, &mutex) - -#define maybe_fail \ - ++i; \ - if (!_s) { \ - pr_info("allowedips self-test %zu: FAIL\n", i); \ - success = false; \ - } - -#define test(version, mem, ipa, ipb, ipc, ipd) do { \ - bool _s = lookup(t.root##version, version == 4 ? 32 : 128, ip##version(ipa, ipb, ipc, ipd)) == mem; \ - maybe_fail \ -} while (0) - -#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \ - bool _s = lookup(t.root##version, version == 4 ? 32 : 128, ip##version(ipa, ipb, ipc, ipd)) != mem; \ - maybe_fail \ -} while (0) - -#define test_boolean(cond) do { \ - bool _s = (cond); \ - maybe_fail \ -} while (0) +#define init_peer(name) do { \ + name = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL); \ + if (!name) { \ + pr_info("allowedips self-test: out of memory\n"); \ + goto free; \ + } \ + kref_init(&name->refcount); \ + } while (0) + +#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \ + allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \ + cidr, mem, &mutex) + +#define maybe_fail() do { \ + ++i; \ + if (!_s) { \ + pr_info("allowedips self-test %zu: FAIL\n", i); \ + success = false; \ + } \ + } while (0) + +#define test(version, mem, ipa, ipb, ipc, ipd) do { \ + bool _s = lookup(t.root##version, version == 4 ? 32 : 128, \ + ip##version(ipa, ipb, ipc, ipd)) == mem; \ + maybe_fail(); \ + } while (0) + +#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \ + bool _s = lookup(t.root##version, version == 4 ? 32 : 128, \ + ip##version(ipa, ipb, ipc, ipd)) != mem; \ + maybe_fail(); \ + } while (0) + +#define test_boolean(cond) do { \ + bool _s = (cond); \ + maybe_fail(); \ + } while (0) bool __init allowedips_selftest(void) { - DEFINE_MUTEX(mutex); - struct allowedips t; - struct walk_ctx wctx = { 0 }; + struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, + *e = NULL, *f = NULL, *g = NULL, *h = NULL; struct allowedips_cursor cursor = { 0 }; - struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, *e = NULL, *f = NULL, *g = NULL, *h = NULL; - size_t i = 0; + struct walk_ctx wctx = { 0 }; bool success = false; + struct allowedips t; + DEFINE_MUTEX(mutex); struct in6_addr ip; + size_t i = 0; __be64 part; mutex_init(&mutex); @@ -455,19 +523,23 @@ bool __init allowedips_selftest(void) insert(4, b, 192, 168, 4, 4, 32); insert(4, c, 192, 168, 0, 0, 16); insert(4, d, 192, 95, 5, 64, 27); - insert(4, c, 192, 95, 5, 65, 27); /* replaces previous entry, and maskself is required */ + /* replaces previous entry, and maskself is required */ + insert(4, c, 192, 95, 5, 65, 27); insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128); insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64); insert(4, e, 0, 0, 0, 0, 0); insert(6, e, 0, 0, 0, 0, 0); - insert(6, f, 0, 0, 0, 0, 0); /* replaces previous entry */ + /* replaces previous entry */ + insert(6, f, 0, 0, 0, 0, 0); insert(6, g, 0x24046800, 0, 0, 0, 32); - insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64); /* maskself is required */ + /* maskself is required */ + insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64); insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128); insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128); insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); insert(4, g, 64, 15, 112, 0, 20); - insert(4, h, 64, 15, 123, 211, 25); /* maskself is required */ + /* maskself is required */ + insert(4, h, 64, 15, 123, 211, 25); insert(4, a, 10, 0, 0, 0, 25); insert(4, b, 10, 0, 0, 128, 25); insert(4, a, 10, 1, 0, 0, 30); diff --git a/src/selftest/counter.h b/src/selftest/counter.h index 5344075..1c2a3b4 100644 --- a/src/selftest/counter.h +++ b/src/selftest/counter.h @@ -6,13 +6,24 @@ #ifdef DEBUG bool __init packet_counter_selftest(void) { - bool success = true; unsigned int test_num = 0, i; union noise_counter counter; + bool success = true; -#define T_INIT do { memset(&counter, 0, sizeof(union noise_counter)); spin_lock_init(&counter.receive.lock); } while (0) +#define T_INIT do { \ + memset(&counter, 0, sizeof(union noise_counter)); \ + spin_lock_init(&counter.receive.lock); \ + } while (0) #define T_LIM (COUNTER_WINDOW_SIZE + 1) -#define T(n, v) do { ++test_num; if (counter_validate(&counter, n) != v) { pr_info("nonce counter self-test %u: FAIL\n", test_num); success = false; } } while (0) +#define T(n, v) do { \ + ++test_num; \ + if (counter_validate(&counter, n) != v) { \ + pr_info("nonce counter self-test %u: FAIL\n", \ + test_num); \ + success = false; \ + } \ + } while (0) + T_INIT; /* 1 */ T(0, true); /* 2 */ T(1, true); @@ -62,22 +73,22 @@ bool __init packet_counter_selftest(void) T(0, false); T_INIT; - for (i = COUNTER_WINDOW_SIZE + 1; i-- > 0 ;) + for (i = COUNTER_WINDOW_SIZE + 1; i-- > 0;) T(i, true); T_INIT; - for (i = COUNTER_WINDOW_SIZE + 2; i-- > 1 ;) + for (i = COUNTER_WINDOW_SIZE + 2; i-- > 1;) T(i, true); T(0, false); T_INIT; - for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1 ;) + for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1;) T(i, true); T(COUNTER_WINDOW_SIZE + 1, true); T(0, false); T_INIT; - for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1 ;) + for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1;) T(i, true); T(0, true); T(COUNTER_WINDOW_SIZE + 1, true); diff --git a/src/selftest/ratelimiter.h b/src/selftest/ratelimiter.h index c05eac7..a71ddb1 100644 --- a/src/selftest/ratelimiter.h +++ b/src/selftest/ratelimiter.h @@ -7,7 +7,10 @@ #include <linux/jiffies.h> -static const struct { bool result; unsigned int msec_to_sleep_before; } expected_results[] __initconst = { +static const struct { + bool result; + unsigned int msec_to_sleep_before; +} expected_results[] __initconst = { [0 ... PACKETS_BURSTABLE - 1] = { true, 0 }, [PACKETS_BURSTABLE] = { false, 0 }, [PACKETS_BURSTABLE + 1] = { true, MSEC_PER_SEC / PACKETS_PER_SECOND }, @@ -29,14 +32,14 @@ static __init unsigned int maximum_jiffies_at_index(int index) bool __init ratelimiter_selftest(void) { - struct sk_buff *skb4; - struct iphdr *hdr4; + int i, test = 0, tries = 0, ret = false; + unsigned long loop_start_time; #if IS_ENABLED(CONFIG_IPV6) struct sk_buff *skb6; struct ipv6hdr *hdr6; #endif - int i, test = 0, tries = 0, ret = false; - unsigned long loop_start_time; + struct sk_buff *skb4; + struct iphdr *hdr4; BUILD_BUG_ON(MSEC_PER_SEC % PACKETS_PER_SECOND != 0); @@ -81,21 +84,24 @@ bool __init ratelimiter_selftest(void) restart: loop_start_time = jiffies; for (i = 0; i < ARRAY_SIZE(expected_results); ++i) { -#define ensure_time do {\ - if (time_is_before_jiffies(loop_start_time + maximum_jiffies_at_index(i))) { \ - if (++tries >= 5000) \ - goto err; \ - gc_entries(NULL); \ - rcu_barrier(); \ - msleep(500); \ - goto restart; \ - }} while (0) +#define ensure_time do { \ + if (time_is_before_jiffies(loop_start_time + \ + maximum_jiffies_at_index(i))) { \ + if (++tries >= 5000) \ + goto err; \ + gc_entries(NULL); \ + rcu_barrier(); \ + msleep(500); \ + goto restart; \ + } \ + } while (0) if (expected_results[i].msec_to_sleep_before) msleep(expected_results[i].msec_to_sleep_before); ensure_time; - if (ratelimiter_allow(skb4, &init_net) != expected_results[i].result) + if (ratelimiter_allow(skb4, &init_net) != + expected_results[i].result) goto err; ++test; hdr4->saddr = htonl(ntohl(hdr4->saddr) + i + 1); @@ -106,17 +112,21 @@ restart: hdr4->saddr = htonl(ntohl(hdr4->saddr) - i - 1); #if IS_ENABLED(CONFIG_IPV6) - hdr6->saddr.in6_u.u6_addr32[2] = hdr6->saddr.in6_u.u6_addr32[3] = htonl(i); + hdr6->saddr.in6_u.u6_addr32[2] = + hdr6->saddr.in6_u.u6_addr32[3] = htonl(i); ensure_time; - if (ratelimiter_allow(skb6, &init_net) != expected_results[i].result) + if (ratelimiter_allow(skb6, &init_net) != + expected_results[i].result) goto err; ++test; - hdr6->saddr.in6_u.u6_addr32[0] = htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) + i + 1); + hdr6->saddr.in6_u.u6_addr32[0] = + htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) + i + 1); ensure_time; if (!ratelimiter_allow(skb6, &init_net)) goto err; ++test; - hdr6->saddr.in6_u.u6_addr32[0] = htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) - i - 1); + hdr6->saddr.in6_u.u6_addr32[0] = + htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) - i - 1); ensure_time; #endif } @@ -23,46 +23,63 @@ static void packet_send_handshake_initiation(struct wireguard_peer *peer) { struct message_handshake_initiation packet; - if (!has_expired(atomic64_read(&peer->last_sent_handshake), REKEY_TIMEOUT)) + if (!has_expired(atomic64_read(&peer->last_sent_handshake), + REKEY_TIMEOUT)) return; /* This function is rate limited. */ atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns()); - net_dbg_ratelimited("%s: Sending handshake initiation to peer %llu (%pISpfsc)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Sending handshake initiation to peer %llu (%pISpfsc)\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr); if (noise_handshake_create_initiation(&packet, &peer->handshake)) { cookie_add_mac_to_packet(&packet, sizeof(packet), peer); timers_any_authenticated_packet_traversal(peer); timers_any_authenticated_packet_sent(peer); - atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns()); - socket_send_buffer_to_peer(peer, &packet, sizeof(struct message_handshake_initiation), HANDSHAKE_DSCP); + atomic64_set(&peer->last_sent_handshake, + ktime_get_boot_fast_ns()); + socket_send_buffer_to_peer( + peer, &packet, + sizeof(struct message_handshake_initiation), + HANDSHAKE_DSCP); timers_handshake_initiated(peer); } } void packet_handshake_send_worker(struct work_struct *work) { - struct wireguard_peer *peer = container_of(work, struct wireguard_peer, transmit_handshake_work); + struct wireguard_peer *peer = container_of(work, struct wireguard_peer, + transmit_handshake_work); packet_send_handshake_initiation(peer); peer_put(peer); } -void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool is_retry) +void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, + bool is_retry) { if (!is_retry) peer->timer_handshake_attempts = 0; rcu_read_lock_bh(); - /* We check last_sent_handshake here in addition to the actual function we're queueing - * up, so that we don't queue things if not strictly necessary. + /* We check last_sent_handshake here in addition to the actual function + * we're queueing up, so that we don't queue things if not strictly + * necessary: */ - if (!has_expired(atomic64_read(&peer->last_sent_handshake), REKEY_TIMEOUT) || unlikely(peer->is_dead)) + if (!has_expired(atomic64_read(&peer->last_sent_handshake), + REKEY_TIMEOUT) || unlikely(peer->is_dead)) goto out; peer_get(peer); - /* Queues up calling packet_send_queued_handshakes(peer), where we do a peer_put(peer) after: */ - if (!queue_work(peer->device->handshake_send_wq, &peer->transmit_handshake_work)) - peer_put(peer); /* If the work was already queued, we want to drop the extra reference */ + /* Queues up calling packet_send_queued_handshakes(peer), where we do a + * peer_put(peer) after: + */ + if (!queue_work(peer->device->handshake_send_wq, + &peer->transmit_handshake_work)) + /* If the work was already queued, we want to drop the + * extra reference: + */ + peer_put(peer); out: rcu_read_unlock_bh(); } @@ -72,27 +89,39 @@ void packet_send_handshake_response(struct wireguard_peer *peer) struct message_handshake_response packet; atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns()); - net_dbg_ratelimited("%s: Sending handshake response to peer %llu (%pISpfsc)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Sending handshake response to peer %llu (%pISpfsc)\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr); if (noise_handshake_create_response(&packet, &peer->handshake)) { cookie_add_mac_to_packet(&packet, sizeof(packet), peer); - if (noise_handshake_begin_session(&peer->handshake, &peer->keypairs)) { + if (noise_handshake_begin_session(&peer->handshake, + &peer->keypairs)) { timers_session_derived(peer); timers_any_authenticated_packet_traversal(peer); timers_any_authenticated_packet_sent(peer); - atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns()); - socket_send_buffer_to_peer(peer, &packet, sizeof(struct message_handshake_response), HANDSHAKE_DSCP); + atomic64_set(&peer->last_sent_handshake, + ktime_get_boot_fast_ns()); + socket_send_buffer_to_peer( + peer, &packet, + sizeof(struct message_handshake_response), + HANDSHAKE_DSCP); } } } -void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index) +void packet_send_handshake_cookie(struct wireguard_device *wg, + struct sk_buff *initiating_skb, + __le32 sender_index) { struct message_handshake_cookie packet; - net_dbg_skb_ratelimited("%s: Sending cookie response for denied handshake message for %pISpfsc\n", wg->dev->name, initiating_skb); - cookie_message_create(&packet, initiating_skb, sender_index, &wg->cookie_checker); - socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet, sizeof(packet)); + net_dbg_skb_ratelimited("%s: Sending cookie response for denied handshake message for %pISpfsc\n", + wg->dev->name, initiating_skb); + cookie_message_create(&packet, initiating_skb, sender_index, + &wg->cookie_checker); + socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet, + sizeof(packet)); } static inline void keep_key_fresh(struct wireguard_peer *peer) @@ -103,8 +132,11 @@ static inline void keep_key_fresh(struct wireguard_peer *peer) rcu_read_lock_bh(); keypair = rcu_dereference_bh(peer->keypairs.current_keypair); if (likely(keypair && keypair->sending.is_valid) && - (unlikely(atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES) || - (keypair->i_am_the_initiator && unlikely(has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME))))) + (unlikely(atomic64_read(&keypair->sending.counter.counter) > + REKEY_AFTER_MESSAGES) || + (keypair->i_am_the_initiator && + unlikely(has_expired(keypair->sending.birthdate, + REKEY_AFTER_TIME))))) send = true; rcu_read_unlock_bh(); @@ -114,9 +146,10 @@ static inline void keep_key_fresh(struct wireguard_peer *peer) static inline unsigned int skb_padding(struct sk_buff *skb) { - /* We do this modulo business with the MTU, just in case the networking layer - * gives us a packet that's bigger than the MTU. In that case, we wouldn't want - * the final subtraction to overflow in the case of the padded_size being clamped. + /* We do this modulo business with the MTU, just in case the networking + * layer gives us a packet that's bigger than the MTU. In that case, we + * wouldn't want the final subtraction to overflow in the case of the + * padded_size being clamped. */ unsigned int last_unit = skb->len % PACKET_CB(skb)->mtu; unsigned int padded_size = ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE); @@ -126,38 +159,49 @@ static inline unsigned int skb_padding(struct sk_buff *skb) return padded_size - last_unit; } -static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, simd_context_t simd_context) +static inline bool skb_encrypt(struct sk_buff *skb, + struct noise_keypair *keypair, + simd_context_t simd_context) { + unsigned int padding_len, plaintext_len, trailer_len; struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1]; struct message_data *header; - unsigned int padding_len, plaintext_len, trailer_len; - int num_frags; struct sk_buff *trailer; + int num_frags; - /* Calculate lengths */ + /* Calculate lengths. */ padding_len = skb_padding(skb); trailer_len = padding_len + noise_encrypted_len(0); plaintext_len = skb->len + padding_len; - /* Expand data section to have room for padding and auth tag */ + /* Expand data section to have room for padding and auth tag. */ num_frags = skb_cow_data(skb, trailer_len, &trailer); if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg))) return false; - /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */ + /* Set the padding to zeros, and make sure it and the auth tag are part + * of the skb. + */ memset(skb_tail_pointer(trailer), 0, padding_len); - /* Expand head section to have room for our header and the network stack's headers. */ + /* Expand head section to have room for our header and the network + * stack's headers. + */ if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0)) return false; - /* We have to remember to add the checksum to the innerpacket, in case the receiver forwards it. */ + /* We have to remember to add the checksum to the innerpacket, in case + * the receiver forwards it. + */ if (likely(!skb_checksum_setup(skb, true))) skb_checksum_help(skb); - /* Only after checksumming can we safely add on the padding at the end and the header. */ + /* Only after checksumming can we safely add on the padding at the end + * and the header. + */ skb_set_inner_network_header(skb, 0); - header = (struct message_data *)skb_push(skb, sizeof(struct message_data)); + header = (struct message_data *)skb_push(skb, + sizeof(struct message_data)); header->header.type = cpu_to_le32(MESSAGE_DATA); header->key_idx = keypair->remote_index; header->counter = cpu_to_le64(PACKET_CB(skb)->nonce); @@ -165,9 +209,12 @@ static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypai /* Now we can encrypt the scattergather segments */ sg_init_table(sg, num_frags); - if (skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(plaintext_len)) <= 0) + if (skb_to_sgvec(skb, sg, sizeof(struct message_data), + noise_encrypted_len(plaintext_len)) <= 0) return false; - return chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0, PACKET_CB(skb)->nonce, keypair->sending.key, simd_context); + return chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0, + PACKET_CB(skb)->nonce, + keypair->sending.key, simd_context); } void packet_send_keepalive(struct wireguard_peer *peer) @@ -175,38 +222,45 @@ void packet_send_keepalive(struct wireguard_peer *peer) struct sk_buff *skb; if (skb_queue_empty(&peer->staged_packet_queue)) { - skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH, GFP_ATOMIC); + skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH, + GFP_ATOMIC); if (unlikely(!skb)) return; skb_reserve(skb, DATA_PACKET_HEAD_ROOM); skb->dev = peer->device->dev; PACKET_CB(skb)->mtu = skb->dev->mtu; skb_queue_tail(&peer->staged_packet_queue, skb); - net_dbg_ratelimited("%s: Sending keepalive packet to peer %llu (%pISpfsc)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Sending keepalive packet to peer %llu (%pISpfsc)\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr); } packet_send_staged_packets(peer); } -#define skb_walk_null_queue_safe(first, skb, next) for (skb = first, next = skb->next; skb; skb = next, next = skb ? skb->next : NULL) +#define skb_walk_null_queue_safe(first, skb, next) \ + for (skb = first, next = skb->next; skb; \ + skb = next, next = skb ? skb->next : NULL) static inline void skb_free_null_queue(struct sk_buff *first) { struct sk_buff *skb, *next; - skb_walk_null_queue_safe(first, skb, next) + skb_walk_null_queue_safe (first, skb, next) dev_kfree_skb(skb); } -static void packet_create_data_done(struct sk_buff *first, struct wireguard_peer *peer) +static void packet_create_data_done(struct sk_buff *first, + struct wireguard_peer *peer) { struct sk_buff *skb, *next; bool is_keepalive, data_sent = false; timers_any_authenticated_packet_traversal(peer); timers_any_authenticated_packet_sent(peer); - skb_walk_null_queue_safe(first, skb, next) { + skb_walk_null_queue_safe (first, skb, next) { is_keepalive = skb->len == message_data_len(0); - if (likely(!socket_send_skb_to_peer(peer, skb, PACKET_CB(skb)->ds) && !is_keepalive)) + if (likely(!socket_send_skb_to_peer(peer, skb, + PACKET_CB(skb)->ds) && !is_keepalive)) data_sent = true; } @@ -218,13 +272,16 @@ static void packet_create_data_done(struct sk_buff *first, struct wireguard_peer void packet_tx_worker(struct work_struct *work) { - struct crypt_queue *queue = container_of(work, struct crypt_queue, work); + struct crypt_queue *queue = + container_of(work, struct crypt_queue, work); struct wireguard_peer *peer; struct noise_keypair *keypair; struct sk_buff *first; enum packet_state state; - while ((first = __ptr_ring_peek(&queue->ring)) != NULL && (state = atomic_read_acquire(&PACKET_CB(first)->state)) != PACKET_STATE_UNCRYPTED) { + while ((first = __ptr_ring_peek(&queue->ring)) != NULL && + (state = atomic_read_acquire(&PACKET_CB(first)->state)) != + PACKET_STATE_UNCRYPTED) { __ptr_ring_discard_one(&queue->ring); peer = PACKET_PEER(first); keypair = PACKET_CB(first)->keypair; @@ -241,22 +298,25 @@ void packet_tx_worker(struct work_struct *work) void packet_encrypt_worker(struct work_struct *work) { - struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr; + struct crypt_queue *queue = + container_of(work, struct multicore_worker, work)->ptr; struct sk_buff *first, *skb, *next; simd_context_t simd_context = simd_get(); while ((first = ptr_ring_consume_bh(&queue->ring)) != NULL) { enum packet_state state = PACKET_STATE_CRYPTED; - skb_walk_null_queue_safe(first, skb, next) { - if (likely(skb_encrypt(skb, PACKET_CB(first)->keypair, simd_context))) + skb_walk_null_queue_safe (first, skb, next) { + if (likely(skb_encrypt(skb, PACKET_CB(first)->keypair, + simd_context))) skb_reset(skb); else { state = PACKET_STATE_DEAD; break; } } - queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first, state); + queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first, + state); simd_context = simd_relax(simd_context); } @@ -273,9 +333,13 @@ static void packet_create_data(struct sk_buff *first) if (unlikely(peer->is_dead)) goto err; - ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu); + ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, + &peer->tx_queue, first, + wg->packet_crypt_wq, + &wg->encrypt_queue.last_cpu); if (unlikely(ret == -EPIPE)) - queue_enqueue_per_peer(&peer->tx_queue, first, PACKET_STATE_DEAD); + queue_enqueue_per_peer(&peer->tx_queue, first, + PACKET_STATE_DEAD); err: rcu_read_unlock_bh(); if (likely(!ret || ret == -EPIPE)) @@ -287,8 +351,8 @@ err: void packet_send_staged_packets(struct wireguard_peer *peer) { - struct noise_keypair *keypair; struct noise_symmetric_key *key; + struct noise_keypair *keypair; struct sk_buff_head packets; struct sk_buff *skb; @@ -302,7 +366,8 @@ void packet_send_staged_packets(struct wireguard_peer *peer) /* First we make sure we have a valid reference to a valid key. */ rcu_read_lock_bh(); - keypair = noise_keypair_get(rcu_dereference_bh(peer->keypairs.current_keypair)); + keypair = noise_keypair_get( + rcu_dereference_bh(peer->keypairs.current_keypair)); rcu_read_unlock_bh(); if (unlikely(!keypair)) goto out_nokey; @@ -312,13 +377,17 @@ void packet_send_staged_packets(struct wireguard_peer *peer) if (unlikely(has_expired(key->birthdate, REJECT_AFTER_TIME))) goto out_invalid; - /* After we know we have a somewhat valid key, we now try to assign nonces to - * all of the packets in the queue. If we can't assign nonces for all of them, - * we just consider it a failure and wait for the next handshake. + /* After we know we have a somewhat valid key, we now try to assign + * nonces to all of the packets in the queue. If we can't assign nonces + * for all of them, we just consider it a failure and wait for the next + * handshake. */ - skb_queue_walk(&packets, skb) { - PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb); - PACKET_CB(skb)->nonce = atomic64_inc_return(&key->counter.counter) - 1; + skb_queue_walk (&packets, skb) { + /* 0 for no outer TOS: no leak. TODO: should we use flowi->tos + * as outer? */ + PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb); + PACKET_CB(skb)->nonce = + atomic64_inc_return(&key->counter.counter) - 1; if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES)) goto out_invalid; } @@ -337,19 +406,19 @@ out_nokey: /* We orphan the packets if we're waiting on a handshake, so that they * don't block a socket's pool. */ - skb_queue_walk(&packets, skb) + skb_queue_walk (&packets, skb) skb_orphan(skb); - /* Then we put them back on the top of the queue. We're not too concerned about - * accidentally getting things a little out of order if packets are being added - * really fast, because this queue is for before packets can even be sent and - * it's small anyway. + /* Then we put them back on the top of the queue. We're not too + * concerned about accidentally getting things a little out of order if + * packets are being added really fast, because this queue is for before + * packets can even be sent and it's small anyway. */ spin_lock_bh(&peer->staged_packet_queue.lock); skb_queue_splice(&packets, &peer->staged_packet_queue); spin_unlock_bh(&peer->staged_packet_queue.lock); - /* If we're exiting because there's something wrong with the key, it means - * we should initiate a new handshake. + /* If we're exiting because there's something wrong with the key, it + * means we should initiate a new handshake. */ packet_send_queued_handshake_initiation(peer, false); } diff --git a/src/socket.c b/src/socket.c index 663a462..f544dd0 100644 --- a/src/socket.c +++ b/src/socket.c @@ -17,7 +17,9 @@ #include <net/udp_tunnel.h> #include <net/ipv6.h> -static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct endpoint *endpoint, u8 ds, struct dst_cache *cache) +static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, + struct endpoint *endpoint, u8 ds, + struct dst_cache *cache) { struct flowi4 fl = { .saddr = endpoint->src4.s_addr, @@ -49,14 +51,21 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct if (!rt) { security_sk_classify_flow(sock, flowi4_to_flowi(&fl)); - if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0, fl.saddr, RT_SCOPE_HOST))) { - endpoint->src4.s_addr = *(__force __be32 *)&endpoint->src_if4 = fl.saddr = 0; + if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0, + fl.saddr, RT_SCOPE_HOST))) { + endpoint->src4.s_addr = 0; + *(__force __be32 *)&endpoint->src_if4 = 0; + fl.saddr = 0; if (cache) dst_cache_reset(cache); } rt = ip_route_output_flow(sock_net(sock), &fl, sock); - if (unlikely(endpoint->src_if4 && ((IS_ERR(rt) && PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) && rt->dst.dev->ifindex != endpoint->src_if4)))) { - endpoint->src4.s_addr = *(__force __be32 *)&endpoint->src_if4 = fl.saddr = 0; + if (unlikely(endpoint->src_if4 && ((IS_ERR(rt) && + PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) && + rt->dst.dev->ifindex != endpoint->src_if4)))) { + endpoint->src4.s_addr = 0; + *(__force __be32 *)&endpoint->src_if4 = 0; + fl.saddr = 0; if (cache) dst_cache_reset(cache); if (!IS_ERR(rt)) @@ -65,18 +74,22 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct } if (unlikely(IS_ERR(rt))) { ret = PTR_ERR(rt); - net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", wg->dev->name, &endpoint->addr, ret); + net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", + wg->dev->name, &endpoint->addr, ret); goto err; } else if (unlikely(rt->dst.dev == skb->dev)) { ip_rt_put(rt); ret = -ELOOP; - net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", wg->dev->name, &endpoint->addr); + net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", + wg->dev->name, &endpoint->addr); goto err; } if (cache) dst_cache_set_ip4(cache, &rt->dst, fl.saddr); } - udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds, ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport, fl.fl4_dport, false, false); + udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds, + ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport, + fl.fl4_dport, false, false); goto out; err: @@ -86,7 +99,9 @@ out: return ret; } -static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct endpoint *endpoint, u8 ds, struct dst_cache *cache) +static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, + struct endpoint *endpoint, u8 ds, + struct dst_cache *cache) { #if IS_ENABLED(CONFIG_IPV6) struct flowi6 fl = { @@ -121,26 +136,32 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct if (!dst) { security_sk_classify_flow(sock, flowi6_to_flowi(&fl)); - if (unlikely(!ipv6_addr_any(&fl.saddr) && !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) { + if (unlikely(!ipv6_addr_any(&fl.saddr) && + !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) { endpoint->src6 = fl.saddr = in6addr_any; if (cache) dst_cache_reset(cache); } - ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock), sock, &dst, &fl); + ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock), sock, &dst, + &fl); if (unlikely(ret)) { - net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", wg->dev->name, &endpoint->addr, ret); + net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", + wg->dev->name, &endpoint->addr, ret); goto err; } else if (unlikely(dst->dev == skb->dev)) { dst_release(dst); ret = -ELOOP; - net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", wg->dev->name, &endpoint->addr); + net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", + wg->dev->name, &endpoint->addr); goto err; } if (cache) dst_cache_set_ip6(cache, dst, &fl.saddr); } - udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds, ip6_dst_hoplimit(dst), 0, fl.fl6_sport, fl.fl6_dport, false); + udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds, + ip6_dst_hoplimit(dst), 0, fl.fl6_sport, + fl.fl6_dport, false); goto out; err: @@ -153,16 +174,19 @@ out: #endif } -int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds) +int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, + u8 ds) { size_t skb_len = skb->len; int ret = -EAFNOSUPPORT; read_lock_bh(&peer->endpoint_lock); if (peer->endpoint.addr.sa_family == AF_INET) - ret = send4(peer->device, skb, &peer->endpoint, ds, &peer->endpoint_cache); + ret = send4(peer->device, skb, &peer->endpoint, ds, + &peer->endpoint_cache); else if (peer->endpoint.addr.sa_family == AF_INET6) - ret = send6(peer->device, skb, &peer->endpoint, ds, &peer->endpoint_cache); + ret = send6(peer->device, skb, &peer->endpoint, ds, + &peer->endpoint_cache); else dev_kfree_skb(skb); if (likely(!ret)) @@ -172,7 +196,8 @@ int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 return ret; } -int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, size_t len, u8 ds) +int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, + size_t len, u8 ds) { struct sk_buff *skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC); @@ -185,7 +210,9 @@ int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, size_t return socket_send_skb_to_peer(peer, skb, ds); } -int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_buff *in_skb, void *buffer, size_t len) +int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, + struct sk_buff *in_skb, void *buffer, + size_t len) { int ret = 0; struct sk_buff *skb; @@ -208,12 +235,15 @@ int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_bu ret = send4(wg, skb, &endpoint, 0, NULL); else if (endpoint.addr.sa_family == AF_INET6) ret = send6(wg, skb, &endpoint, 0, NULL); - /* No other possibilities if the endpoint is valid, which it is, as we checked above. */ + /* No other possibilities if the endpoint is valid, which it is, + * as we checked above. + */ return ret; } -int socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *skb) +int socket_endpoint_from_skb(struct endpoint *endpoint, + const struct sk_buff *skb) { memset(endpoint, 0, sizeof(struct endpoint)); if (skb->protocol == htons(ETH_P_IP)) { @@ -226,25 +256,32 @@ int socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *sk endpoint->addr6.sin6_family = AF_INET6; endpoint->addr6.sin6_port = udp_hdr(skb)->source; endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr; - endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id(&ipv6_hdr(skb)->saddr, skb->skb_iif); + endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id( + &ipv6_hdr(skb)->saddr, skb->skb_iif); endpoint->src6 = ipv6_hdr(skb)->daddr; } else return -EINVAL; return 0; } -static inline bool endpoint_eq(const struct endpoint *a, const struct endpoint *b) +static inline bool endpoint_eq(const struct endpoint *a, + const struct endpoint *b) { return (a->addr.sa_family == AF_INET && b->addr.sa_family == AF_INET && - a->addr4.sin_port == b->addr4.sin_port && a->addr4.sin_addr.s_addr == b->addr4.sin_addr.s_addr && + a->addr4.sin_port == b->addr4.sin_port && + a->addr4.sin_addr.s_addr == b->addr4.sin_addr.s_addr && a->src4.s_addr == b->src4.s_addr && a->src_if4 == b->src_if4) || - (a->addr.sa_family == AF_INET6 && b->addr.sa_family == AF_INET6 && - a->addr6.sin6_port == b->addr6.sin6_port && ipv6_addr_equal(&a->addr6.sin6_addr, &b->addr6.sin6_addr) && - a->addr6.sin6_scope_id == b->addr6.sin6_scope_id && ipv6_addr_equal(&a->src6, &b->src6)) || + (a->addr.sa_family == AF_INET6 && + b->addr.sa_family == AF_INET6 && + a->addr6.sin6_port == b->addr6.sin6_port && + ipv6_addr_equal(&a->addr6.sin6_addr, &b->addr6.sin6_addr) && + a->addr6.sin6_scope_id == b->addr6.sin6_scope_id && + ipv6_addr_equal(&a->src6, &b->src6)) || unlikely(!a->addr.sa_family && !b->addr.sa_family); } -void socket_set_peer_endpoint(struct wireguard_peer *peer, const struct endpoint *endpoint) +void socket_set_peer_endpoint(struct wireguard_peer *peer, + const struct endpoint *endpoint) { /* First we check unlocked, in order to optimize, since it's pretty rare * that an endpoint will change. If we happen to be mid-write, and two @@ -268,7 +305,8 @@ out: write_unlock_bh(&peer->endpoint_lock); } -void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer, const struct sk_buff *skb) +void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer, + const struct sk_buff *skb) { struct endpoint endpoint; @@ -362,7 +400,8 @@ retry: udp_tunnel_sock_release(new4); if (ret == -EADDRINUSE && !port && retries++ < 100) goto retry; - pr_err("%s: Could not create IPv6 socket\n", wg->dev->name); + pr_err("%s: Could not create IPv6 socket\n", + wg->dev->name); return ret; } set_sock_opts(new6); @@ -374,13 +413,16 @@ retry: return 0; } -void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6) +void socket_reinit(struct wireguard_device *wg, struct sock *new4, + struct sock *new6) { struct sock *old4, *old6; mutex_lock(&wg->socket_update_lock); - old4 = rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock)); - old6 = rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock)); + old4 = rcu_dereference_protected(wg->sock4, + lockdep_is_held(&wg->socket_update_lock)); + old6 = rcu_dereference_protected(wg->sock6, + lockdep_is_held(&wg->socket_update_lock)); rcu_assign_pointer(wg->sock4, new4); rcu_assign_pointer(wg->sock6, new6); if (new4) diff --git a/src/socket.h b/src/socket.h index 9faba77..d873ffa 100644 --- a/src/socket.h +++ b/src/socket.h @@ -12,22 +12,31 @@ #include <linux/if_ether.h> int socket_init(struct wireguard_device *wg, u16 port); -void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6); -int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *data, size_t len, u8 ds); -int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds); -int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_buff *in_skb, void *out_buffer, size_t len); +void socket_reinit(struct wireguard_device *wg, struct sock *new4, + struct sock *new6); +int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *data, + size_t len, u8 ds); +int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, + u8 ds); +int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, + struct sk_buff *in_skb, void *out_buffer, + size_t len); -int socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *skb); -void socket_set_peer_endpoint(struct wireguard_peer *peer, const struct endpoint *endpoint); -void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer, const struct sk_buff *skb); +int socket_endpoint_from_skb(struct endpoint *endpoint, + const struct sk_buff *skb); +void socket_set_peer_endpoint(struct wireguard_peer *peer, + const struct endpoint *endpoint); +void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer, + const struct sk_buff *skb); void socket_clear_peer_endpoint_src(struct wireguard_peer *peer); #if defined(CONFIG_DYNAMIC_DEBUG) || defined(DEBUG) -#define net_dbg_skb_ratelimited(fmt, dev, skb, ...) do { \ - struct endpoint __endpoint; \ - socket_endpoint_from_skb(&__endpoint, skb); \ - net_dbg_ratelimited(fmt, dev, &__endpoint.addr, ##__VA_ARGS__); \ -} while (0) +#define net_dbg_skb_ratelimited(fmt, dev, skb, ...) do { \ + struct endpoint __endpoint; \ + socket_endpoint_from_skb(&__endpoint, skb); \ + net_dbg_ratelimited(fmt, dev, &__endpoint.addr, \ + ##__VA_ARGS__); \ + } while (0) #else #define net_dbg_skb_ratelimited(fmt, skb, ...) #endif diff --git a/src/timers.c b/src/timers.c index 7db892a..ad76466 100644 --- a/src/timers.c +++ b/src/timers.c @@ -10,22 +10,33 @@ #include "socket.h" /* - * Timer for retransmitting the handshake if we don't hear back after `REKEY_TIMEOUT + jitter` ms - * Timer for sending empty packet if we have received a packet but after have not sent one for `KEEPALIVE_TIMEOUT` ms - * Timer for initiating new handshake if we have sent a packet but after have not received one (even empty) for `(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT)` ms - * Timer for zeroing out all ephemeral keys after `(REJECT_AFTER_TIME * 3)` ms if no new keys have been received - * Timer for, if enabled, sending an empty authenticated packet every user-specified seconds + * - Timer for retransmitting the handshake if we don't hear back after + * `REKEY_TIMEOUT + jitter` ms. + * + * - Timer for sending empty packet if we have received a packet but after have + * not sent one for `KEEPALIVE_TIMEOUT` ms. + * + * - Timer for initiating new handshake if we have sent a packet but after have + * not received one (even empty) for `(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT)` ms. + * + * - Timer for zeroing out all ephemeral keys after `(REJECT_AFTER_TIME * 3)` ms + * if no new keys have been received. + * + * - Timer for, if enabled, sending an empty authenticated packet every user- + * specified seconds. */ -#define peer_get_from_timer(timer_name) \ - struct wireguard_peer *peer; \ - rcu_read_lock_bh(); \ - peer = peer_get_maybe_zero(from_timer(peer, timer, timer_name)); \ - rcu_read_unlock_bh(); \ - if (unlikely(!peer)) \ +#define peer_get_from_timer(timer_name) \ + struct wireguard_peer *peer; \ + rcu_read_lock_bh(); \ + peer = peer_get_maybe_zero(from_timer(peer, timer, timer_name)); \ + rcu_read_unlock_bh(); \ + if (unlikely(!peer)) \ return; -static inline void mod_peer_timer(struct wireguard_peer *peer, struct timer_list *timer, unsigned long expires) +static inline void mod_peer_timer(struct wireguard_peer *peer, + struct timer_list *timer, + unsigned long expires) { rcu_read_lock_bh(); if (likely(netif_running(peer->device->dev) && !peer->is_dead)) @@ -33,7 +44,8 @@ static inline void mod_peer_timer(struct wireguard_peer *peer, struct timer_list rcu_read_unlock_bh(); } -static inline void del_peer_timer(struct wireguard_peer *peer, struct timer_list *timer) +static inline void del_peer_timer(struct wireguard_peer *peer, + struct timer_list *timer) { rcu_read_lock_bh(); if (likely(netif_running(peer->device->dev) && !peer->is_dead)) @@ -46,7 +58,9 @@ static void expired_retransmit_handshake(struct timer_list *timer) peer_get_from_timer(timer_retransmit_handshake); if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) { - pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2); + pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2); del_peer_timer(peer, &peer->timer_send_keepalive); /* We drop all packets without a keypair and don't try again, @@ -58,12 +72,18 @@ static void expired_retransmit_handshake(struct timer_list *timer) * of a partial exchange. */ if (!timer_pending(&peer->timer_zero_key_material)) - mod_peer_timer(peer, &peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ); + mod_peer_timer(peer, &peer->timer_zero_key_material, + jiffies + REJECT_AFTER_TIME * 3 * HZ); } else { ++peer->timer_handshake_attempts; - pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d seconds, retrying (try %d)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, REKEY_TIMEOUT, peer->timer_handshake_attempts + 1); + pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d seconds, retrying (try %d)\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr, REKEY_TIMEOUT, + peer->timer_handshake_attempts + 1); - /* We clear the endpoint address src address, in case this is the cause of trouble. */ + /* We clear the endpoint address src address, in case this is + * the cause of trouble. + */ socket_clear_peer_endpoint_src(peer); packet_send_queued_handshake_initiation(peer, true); @@ -78,7 +98,8 @@ static void expired_send_keepalive(struct timer_list *timer) packet_send_keepalive(peer); if (peer->timer_need_another_keepalive) { peer->timer_need_another_keepalive = false; - mod_peer_timer(peer, &peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ); + mod_peer_timer(peer, &peer->timer_send_keepalive, + jiffies + KEEPALIVE_TIMEOUT * HZ); } peer_put(peer); } @@ -87,8 +108,12 @@ static void expired_new_handshake(struct timer_list *timer) { peer_get_from_timer(timer_new_handshake); - pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); - /* We clear the endpoint address src address, in case this is the cause of trouble. */ + pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr, KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); + /* We clear the endpoint address src address, in case this is the cause + * of trouble. + */ socket_clear_peer_endpoint_src(peer); packet_send_queued_handshake_initiation(peer, false); peer_put(peer); @@ -100,16 +125,22 @@ static void expired_zero_key_material(struct timer_list *timer) rcu_read_lock_bh(); if (!peer->is_dead) { - if (!queue_work(peer->device->handshake_send_wq, &peer->clear_peer_work)) /* Should take our reference. */ - peer_put(peer); /* If the work was already on the queue, we want to drop the extra reference */ + /* Should take our reference. */ + if (!queue_work(peer->device->handshake_send_wq, + &peer->clear_peer_work)) + /* If the work was already on the queue, we want to drop the extra reference */ + peer_put(peer); } rcu_read_unlock_bh(); } static void queued_expired_zero_key_material(struct work_struct *work) { - struct wireguard_peer *peer = container_of(work, struct wireguard_peer, clear_peer_work); + struct wireguard_peer *peer = + container_of(work, struct wireguard_peer, clear_peer_work); - pr_debug("%s: Zeroing out all keys for peer %llu (%pISpfsc), since we haven't received a new one in %d seconds\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, REJECT_AFTER_TIME * 3); + pr_debug("%s: Zeroing out all keys for peer %llu (%pISpfsc), since we haven't received a new one in %d seconds\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr, REJECT_AFTER_TIME * 3); noise_handshake_clear(&peer->handshake); noise_keypairs_clear(&peer->keypairs); peer_put(peer); @@ -128,7 +159,8 @@ static void expired_send_persistent_keepalive(struct timer_list *timer) void timers_data_sent(struct wireguard_peer *peer) { if (!timer_pending(&peer->timer_new_handshake)) - mod_peer_timer(peer, &peer->timer_new_handshake, jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ); + mod_peer_timer(peer, &peer->timer_new_handshake, + jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ); } /* Should be called after an authenticated data packet is received. */ @@ -136,19 +168,24 @@ void timers_data_received(struct wireguard_peer *peer) { if (likely(netif_running(peer->device->dev))) { if (!timer_pending(&peer->timer_send_keepalive)) - mod_peer_timer(peer, &peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ); + mod_peer_timer(peer, &peer->timer_send_keepalive, + jiffies + KEEPALIVE_TIMEOUT * HZ); else peer->timer_need_another_keepalive = true; } } -/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ +/* Should be called after any type of authenticated packet is sent, whether + * keepalive, data, or handshake. +*/ void timers_any_authenticated_packet_sent(struct wireguard_peer *peer) { del_peer_timer(peer, &peer->timer_send_keepalive); } -/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ +/* Should be called after any type of authenticated packet is received, whether + * keepalive, data, or handshake. + */ void timers_any_authenticated_packet_received(struct wireguard_peer *peer) { del_peer_timer(peer, &peer->timer_new_handshake); @@ -157,10 +194,15 @@ void timers_any_authenticated_packet_received(struct wireguard_peer *peer) /* Should be called after a handshake initiation message is sent. */ void timers_handshake_initiated(struct wireguard_peer *peer) { - mod_peer_timer(peer, &peer->timer_retransmit_handshake, jiffies + REKEY_TIMEOUT * HZ + prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES)); + mod_peer_timer( + peer, &peer->timer_retransmit_handshake, + jiffies + REKEY_TIMEOUT * HZ + + prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES)); } -/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ +/* Should be called after a handshake response message is received and processed + * or when getting key confirmation via the first data message. + */ void timers_handshake_complete(struct wireguard_peer *peer) { del_peer_timer(peer, &peer->timer_retransmit_handshake); @@ -169,26 +211,34 @@ void timers_handshake_complete(struct wireguard_peer *peer) getnstimeofday(&peer->walltime_last_handshake); } -/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ +/* Should be called after an ephemeral key is created, which is before sending a + * handshake response or after receiving a handshake response. + */ void timers_session_derived(struct wireguard_peer *peer) { - mod_peer_timer(peer, &peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ); + mod_peer_timer(peer, &peer->timer_zero_key_material, + jiffies + REJECT_AFTER_TIME * 3 * HZ); } -/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ +/* Should be called before a packet with authentication, whether + * keepalive, data, or handshakem is sent, or after one is received. + */ void timers_any_authenticated_packet_traversal(struct wireguard_peer *peer) { if (peer->persistent_keepalive_interval) - mod_peer_timer(peer, &peer->timer_persistent_keepalive, jiffies + peer->persistent_keepalive_interval * HZ); + mod_peer_timer(peer, &peer->timer_persistent_keepalive, + jiffies + peer->persistent_keepalive_interval * HZ); } void timers_init(struct wireguard_peer *peer) { - timer_setup(&peer->timer_retransmit_handshake, expired_retransmit_handshake, 0); + timer_setup(&peer->timer_retransmit_handshake, + expired_retransmit_handshake, 0); timer_setup(&peer->timer_send_keepalive, expired_send_keepalive, 0); timer_setup(&peer->timer_new_handshake, expired_new_handshake, 0); timer_setup(&peer->timer_zero_key_material, expired_zero_key_material, 0); - timer_setup(&peer->timer_persistent_keepalive, expired_send_persistent_keepalive, 0); + timer_setup(&peer->timer_persistent_keepalive, + expired_send_persistent_keepalive, 0); INIT_WORK(&peer->clear_peer_work, queued_expired_zero_key_material); peer->timer_handshake_attempts = 0; peer->sent_lastminute_handshake = false; diff --git a/src/timers.h b/src/timers.h index c95bf31..483529c 100644 --- a/src/timers.h +++ b/src/timers.h @@ -23,7 +23,8 @@ void timers_any_authenticated_packet_traversal(struct wireguard_peer *peer); static inline bool has_expired(u64 birthday_nanoseconds, u64 expiration_seconds) { - return (s64)(birthday_nanoseconds + expiration_seconds * NSEC_PER_SEC) <= (s64)ktime_get_boot_fast_ns(); + return (s64)(birthday_nanoseconds + expiration_seconds * NSEC_PER_SEC) + <= (s64)ktime_get_boot_fast_ns(); } #endif /* _WG_TIMERS_H */ |