From a18329341368fc9b9b19055736aa9731bf81ec97 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 23 Aug 2018 11:35:55 -0700 Subject: global: run through clang-format This is the worst commit in the whole repo, making the code much less readable, but so it goes with upstream maintainers. We are now woefully wrapped at 80 columns. Signed-off-by: Jason A. Donenfeld --- src/allowedips.c | 148 +++++++++++++------ src/allowedips.h | 23 ++- src/cookie.c | 102 +++++++++---- src/cookie.h | 19 ++- src/device.c | 95 +++++++----- src/device.h | 3 +- src/hashtables.c | 105 ++++++++----- src/hashtables.h | 25 +++- src/main.c | 5 +- src/messages.h | 14 +- src/netlink.c | 215 +++++++++++++++++++-------- src/noise.c | 360 +++++++++++++++++++++++++++++++-------------- src/noise.h | 33 +++-- src/peer.c | 85 +++++++---- src/peer.h | 21 ++- src/queueing.c | 14 +- src/queueing.h | 58 ++++++-- src/ratelimiter.c | 47 +++--- src/receive.c | 257 ++++++++++++++++++++++---------- src/selftest/allowedips.h | 276 +++++++++++++++++++++------------- src/selftest/counter.h | 25 +++- src/selftest/ratelimiter.h | 48 +++--- src/send.c | 203 ++++++++++++++++--------- src/socket.c | 108 +++++++++----- src/socket.h | 33 +++-- src/timers.c | 122 ++++++++++----- src/timers.h | 3 +- 27 files changed, 1653 insertions(+), 794 deletions(-) (limited to 'src') diff --git a/src/allowedips.c b/src/allowedips.c index 1442bf4..4616645 100644 --- a/src/allowedips.c +++ b/src/allowedips.c @@ -28,7 +28,8 @@ static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits) } } -static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8 cidr, u8 bits) +static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, + u8 cidr, u8 bits) { node->cidr = cidr; node->bit_at_a = cidr / 8U; @@ -39,34 +40,43 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8 memcpy(node->bits, src, bits / 8U); } -#define choose_node(parent, key) parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] +#define choose_node(parent, key) \ + parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1] static void node_free_rcu(struct rcu_head *rcu) { kfree(container_of(rcu, struct allowedips_node, rcu)); } -#define push_rcu(stack, p, len) ({ \ - if (rcu_access_pointer(p)) { \ - BUG_ON(len >= 128); \ - stack[len++] = rcu_dereference_raw(p); \ - } \ - true; \ -}) +#define push_rcu(stack, p, len) ({ \ + if (rcu_access_pointer(p)) { \ + BUG_ON(len >= 128); \ + stack[len++] = rcu_dereference_raw(p); \ + } \ + true; \ + }) static void root_free_rcu(struct rcu_head *rcu) { - struct allowedips_node *node, *stack[128] = { container_of(rcu, struct allowedips_node, rcu) }; + struct allowedips_node *node, *stack[128] = + { container_of(rcu, struct allowedips_node, rcu) }; unsigned int len = 1; - while (len > 0 && (node = stack[--len]) && push_rcu(stack, node->bit[0], len) && push_rcu(stack, node->bit[1], len)) + while (len > 0 && (node = stack[--len]) && + push_rcu(stack, node->bit[0], len) && + push_rcu(stack, node->bit[1], len)) kfree(node); } -static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock) +static int +walk_by_peer(struct allowedips_node __rcu *top, u8 bits, + struct allowedips_cursor *cursor, struct wireguard_peer *peer, + int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), + void *ctx, struct mutex *lock) { + const int address_family = bits == 32 ? AF_INET : AF_INET6; + u8 ip[16] __aligned(__alignof(u64)); struct allowedips_node *node; int ret; - u8 ip[16] __aligned(__alignof(u64)); if (!rcu_access_pointer(top)) return 0; @@ -74,16 +84,21 @@ static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allow if (!cursor->len) push_rcu(cursor->stack, top, cursor->len); - for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); --cursor->len, push_rcu(cursor->stack, node->bit[0], cursor->len), push_rcu(cursor->stack, node->bit[1], cursor->len)) { - if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) != peer) + for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); + --cursor->len, push_rcu(cursor->stack, node->bit[0], cursor->len), + push_rcu(cursor->stack, node->bit[1], cursor->len)) { + const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U); + + if (rcu_dereference_protected(node->peer, + lockdep_is_held(lock)) != peer) continue; swap_endian(ip, node->bits, bits); - memset(ip + (node->cidr + 7U) / 8U, 0, (bits / 8U) - ((node->cidr + 7U) / 8U)); + memset(ip + cidr_bytes, 0, bits / 8U - cidr_bytes); if (node->cidr) - ip[(node->cidr + 7U) / 8U - 1U] &= ~0U << (-node->cidr % 8U); + ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U); - ret = func(ctx, ip, node->cidr, bits == 32 ? AF_INET : AF_INET6); + ret = func(ctx, ip, node->cidr, address_family); if (ret) return ret; } @@ -93,8 +108,12 @@ static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allow #define ref(p) rcu_access_pointer(p) #define deref(p) rcu_dereference_protected(*p, lockdep_is_held(lock)) -#define push(p) ({ BUG_ON(len >= 128); stack[len++] = p; }) -static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireguard_peer *peer, struct mutex *lock) +#define push(p) ({ \ + BUG_ON(len >= 128); \ + stack[len++] = p; \ + }) +static void walk_remove_by_peer(struct allowedips_node __rcu **top, + struct wireguard_peer *peer, struct mutex *lock) { struct allowedips_node __rcu **stack[128], **nptr; struct allowedips_node *node, *prev; @@ -110,7 +129,8 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireg --len; continue; } - if (!prev || ref(prev->bit[0]) == node || ref(prev->bit[1]) == node) { + if (!prev || ref(prev->bit[0]) == node || + ref(prev->bit[1]) == node) { if (ref(node->bit[0])) push(&node->bit[0]); else if (ref(node->bit[1])) @@ -119,10 +139,12 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireg if (ref(node->bit[1])) push(&node->bit[1]); } else { - if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) == peer) { + if (rcu_dereference_protected(node->peer, + lockdep_is_held(lock)) == peer) { RCU_INIT_POINTER(node->peer, NULL); if (!node->bit[0] || !node->bit[1]) { - rcu_assign_pointer(*nptr, deref(&node->bit[!ref(node->bit[0])])); + rcu_assign_pointer(*nptr, + deref(&node->bit[!ref(node->bit[0])])); call_rcu_bh(&node->rcu, node_free_rcu); node = deref(nptr); } @@ -140,23 +162,29 @@ static __always_inline unsigned int fls128(u64 a, u64 b) return a ? fls64(a) + 64U : fls64(b); } -static __always_inline u8 common_bits(const struct allowedips_node *node, const u8 *key, u8 bits) +static __always_inline u8 common_bits(const struct allowedips_node *node, + const u8 *key, u8 bits) { if (bits == 32) return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key); else if (bits == 128) - return 128U - fls128(*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]); + return 128U - fls128( + *(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], + *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]); return 0; } -/* This could be much faster if it actually just compared the common bits properly, - * by precomputing a mask bswap(~0 << (32 - cidr)), and the rest, but it turns out that - * common_bits is already super fast on modern processors, even taking into account - * the unfortunate bswap. So, we just inline it like this instead. +/* This could be much faster if it actually just compared the common bits + * properly, by precomputing a mask bswap(~0 << (32 - cidr)), and the rest, but + * it turns out that common_bits is already super fast on modern processors, + * even taking into account the unfortunate bswap. So, we just inline it like + * this instead. */ -#define prefix_matches(node, key, bits) (common_bits(node, key, bits) >= node->cidr) +#define prefix_matches(node, key, bits) \ + (common_bits(node, key, bits) >= node->cidr) -static __always_inline struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits, const u8 *key) +static __always_inline struct allowedips_node * +find_node(struct allowedips_node *trie, u8 bits, const u8 *key) { struct allowedips_node *node = trie, *found = NULL; @@ -171,11 +199,12 @@ static __always_inline struct allowedips_node *find_node(struct allowedips_node } /* Returns a strong reference to a peer */ -static __always_inline struct wireguard_peer *lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip) +static __always_inline struct wireguard_peer * +lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip) { + u8 ip[16] __aligned(__alignof(u64)); struct wireguard_peer *peer = NULL; struct allowedips_node *node; - u8 ip[16] __aligned(__alignof(u64)); swap_endian(ip, be_ip, bits); @@ -191,11 +220,14 @@ retry: return peer; } -__attribute__((nonnull(1))) -static inline bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr, u8 bits, struct allowedips_node **rnode, struct mutex *lock) +__attribute__((nonnull(1))) static inline bool +node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr, + u8 bits, struct allowedips_node **rnode, struct mutex *lock) { + struct allowedips_node *node = rcu_dereference_protected(trie, + lockdep_is_held(lock)); + struct allowedips_node *parent = NULL; bool exact = false; - struct allowedips_node *parent = NULL, *node = rcu_dereference_protected(trie, lockdep_is_held(lock)); while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) { parent = node; @@ -203,13 +235,15 @@ static inline bool node_placement(struct allowedips_node __rcu *trie, const u8 * exact = true; break; } - node = rcu_dereference_protected(choose_node(parent, key), lockdep_is_held(lock)); + node = rcu_dereference_protected(choose_node(parent, key), + lockdep_is_held(lock)); } *rnode = parent; return exact; } -static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) +static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, + u8 cidr, struct wireguard_peer *peer, struct mutex *lock) { struct allowedips_node *node, *parent, *down, *newnode; u8 key[16] __aligned(__alignof(u64)); @@ -242,7 +276,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u if (!node) down = rcu_dereference_protected(*trie, lockdep_is_held(lock)); else { - down = rcu_dereference_protected(choose_node(node, key), lockdep_is_held(lock)); + down = rcu_dereference_protected(choose_node(node, key), + lockdep_is_held(lock)); if (!down) { rcu_assign_pointer(choose_node(node, key), newnode); return 0; @@ -256,7 +291,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u if (!parent) rcu_assign_pointer(*trie, newnode); else - rcu_assign_pointer(choose_node(parent, newnode->bits), newnode); + rcu_assign_pointer(choose_node(parent, newnode->bits), + newnode); } else { node = kzalloc(sizeof(*node), GFP_KERNEL); if (!node) { @@ -270,7 +306,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u if (!parent) rcu_assign_pointer(*trie, node); else - rcu_assign_pointer(choose_node(parent, node->bits), node); + rcu_assign_pointer(choose_node(parent, node->bits), + node); } return 0; } @@ -288,31 +325,42 @@ void allowedips_free(struct allowedips *table, struct mutex *lock) RCU_INIT_POINTER(table->root4, NULL); RCU_INIT_POINTER(table->root6, NULL); if (rcu_access_pointer(old4)) - call_rcu_bh(&rcu_dereference_protected(old4, lockdep_is_held(lock))->rcu, root_free_rcu); + call_rcu_bh(&rcu_dereference_protected(old4, + lockdep_is_held(lock))->rcu, root_free_rcu); if (rcu_access_pointer(old6)) - call_rcu_bh(&rcu_dereference_protected(old6, lockdep_is_held(lock))->rcu, root_free_rcu); + call_rcu_bh(&rcu_dereference_protected(old6, + lockdep_is_held(lock))->rcu, root_free_rcu); } -int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) +int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, + u8 cidr, struct wireguard_peer *peer, + struct mutex *lock) { ++table->seq; return add(&table->root4, 32, (const u8 *)ip, cidr, peer, lock); } -int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock) +int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, + u8 cidr, struct wireguard_peer *peer, + struct mutex *lock) { ++table->seq; return add(&table->root6, 128, (const u8 *)ip, cidr, peer, lock); } -void allowedips_remove_by_peer(struct allowedips *table, struct wireguard_peer *peer, struct mutex *lock) +void allowedips_remove_by_peer(struct allowedips *table, + struct wireguard_peer *peer, struct mutex *lock) { ++table->seq; walk_remove_by_peer(&table->root4, peer, lock); walk_remove_by_peer(&table->root6, peer, lock); } -int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock) +int allowedips_walk_by_peer(struct allowedips *table, + struct allowedips_cursor *cursor, + struct wireguard_peer *peer, + int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), + void *ctx, struct mutex *lock) { int ret; @@ -332,7 +380,8 @@ int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor * } /* Returns a strong reference to a peer */ -struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk_buff *skb) +struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, + struct sk_buff *skb) { if (skb->protocol == htons(ETH_P_IP)) return lookup(table->root4, 32, &ip_hdr(skb)->daddr); @@ -342,7 +391,8 @@ struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk } /* Returns a strong reference to a peer */ -struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, struct sk_buff *skb) +struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, + struct sk_buff *skb) { if (skb->protocol == htons(ETH_P_IP)) return lookup(table->root4, 32, &ip_hdr(skb)->saddr); diff --git a/src/allowedips.h b/src/allowedips.h index 97ecf69..d5ba1be 100644 --- a/src/allowedips.h +++ b/src/allowedips.h @@ -28,14 +28,25 @@ struct allowedips_cursor { void allowedips_init(struct allowedips *table); void allowedips_free(struct allowedips *table, struct mutex *mutex); -int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock); -int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock); -void allowedips_remove_by_peer(struct allowedips *table, struct wireguard_peer *peer, struct mutex *lock); -int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock); +int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, + u8 cidr, struct wireguard_peer *peer, + struct mutex *lock); +int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, + u8 cidr, struct wireguard_peer *peer, + struct mutex *lock); +void allowedips_remove_by_peer(struct allowedips *table, + struct wireguard_peer *peer, struct mutex *lock); +int allowedips_walk_by_peer(struct allowedips *table, + struct allowedips_cursor *cursor, + struct wireguard_peer *peer, + int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), + void *ctx, struct mutex *lock); /* These return a strong reference to a peer: */ -struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk_buff *skb); -struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, struct sk_buff *skb); +struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, + struct sk_buff *skb); +struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, + struct sk_buff *skb); #ifdef DEBUG bool allowedips_selftest(void); diff --git a/src/cookie.c b/src/cookie.c index 9268630..7cf0693 100644 --- a/src/cookie.c +++ b/src/cookie.c @@ -15,7 +15,8 @@ #include #include -void cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg) +void cookie_checker_init(struct cookie_checker *checker, + struct wireguard_device *wg) { init_rwsem(&checker->secret_lock); checker->secret_birthdate = ktime_get_boot_fast_ns(); @@ -27,7 +28,9 @@ enum { COOKIE_KEY_LABEL_LEN = 8 }; static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----"; static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--"; -static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 pubkey[NOISE_PUBLIC_KEY_LEN], const u8 label[COOKIE_KEY_LABEL_LEN]) +static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], + const u8 pubkey[NOISE_PUBLIC_KEY_LEN], + const u8 label[COOKIE_KEY_LABEL_LEN]) { struct blake2s_state blake; @@ -41,18 +44,25 @@ static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 pubkey[NOIS void cookie_checker_precompute_device_keys(struct cookie_checker *checker) { if (likely(checker->device->static_identity.has_identity)) { - precompute_key(checker->cookie_encryption_key, checker->device->static_identity.static_public, cookie_key_label); - precompute_key(checker->message_mac1_key, checker->device->static_identity.static_public, mac1_key_label); + precompute_key(checker->cookie_encryption_key, + checker->device->static_identity.static_public, + cookie_key_label); + precompute_key(checker->message_mac1_key, + checker->device->static_identity.static_public, + mac1_key_label); } else { - memset(checker->cookie_encryption_key, 0, NOISE_SYMMETRIC_KEY_LEN); + memset(checker->cookie_encryption_key, 0, + NOISE_SYMMETRIC_KEY_LEN); memset(checker->message_mac1_key, 0, NOISE_SYMMETRIC_KEY_LEN); } } void cookie_checker_precompute_peer_keys(struct wireguard_peer *peer) { - precompute_key(peer->latest_cookie.cookie_decryption_key, peer->handshake.remote_static, cookie_key_label); - precompute_key(peer->latest_cookie.message_mac1_key, peer->handshake.remote_static, mac1_key_label); + precompute_key(peer->latest_cookie.cookie_decryption_key, + peer->handshake.remote_static, cookie_key_label); + precompute_key(peer->latest_cookie.message_mac1_key, + peer->handshake.remote_static, mac1_key_label); } void cookie_init(struct cookie *cookie) @@ -61,19 +71,24 @@ void cookie_init(struct cookie *cookie) init_rwsem(&cookie->lock); } -static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len, const u8 key[NOISE_SYMMETRIC_KEY_LEN]) +static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len, + const u8 key[NOISE_SYMMETRIC_KEY_LEN]) { - len = len - sizeof(struct message_macs) + offsetof(struct message_macs, mac1); + len = len - sizeof(struct message_macs) + + offsetof(struct message_macs, mac1); blake2s(mac1, message, key, COOKIE_LEN, len, NOISE_SYMMETRIC_KEY_LEN); } -static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len, const u8 cookie[COOKIE_LEN]) +static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len, + const u8 cookie[COOKIE_LEN]) { - len = len - sizeof(struct message_macs) + offsetof(struct message_macs, mac2); + len = len - sizeof(struct message_macs) + + offsetof(struct message_macs, mac2); blake2s(mac2, message, cookie, COOKIE_LEN, len, COOKIE_LEN); } -static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, struct cookie_checker *checker) +static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, + struct cookie_checker *checker) { struct blake2s_state state; @@ -88,24 +103,30 @@ static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, struct cooki blake2s_init_key(&state, COOKIE_LEN, checker->secret, NOISE_HASH_LEN); if (skb->protocol == htons(ETH_P_IP)) - blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr, sizeof(struct in_addr)); + blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr, + sizeof(struct in_addr)); else if (skb->protocol == htons(ETH_P_IPV6)) - blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr, sizeof(struct in6_addr)); + blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr, + sizeof(struct in6_addr)); blake2s_update(&state, (u8 *)&udp_hdr(skb)->source, sizeof(__be16)); blake2s_final(&state, cookie, COOKIE_LEN); up_read(&checker->secret_lock); } -enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie) +enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, + struct sk_buff *skb, + bool check_cookie) { + struct message_macs *macs = (struct message_macs *) + (skb->data + skb->len - sizeof(struct message_macs)); + enum cookie_mac_state ret; u8 computed_mac[COOKIE_LEN]; u8 cookie[COOKIE_LEN]; - enum cookie_mac_state ret; - struct message_macs *macs = (struct message_macs *)(skb->data + skb->len - sizeof(struct message_macs)); ret = INVALID_MAC; - compute_mac1(computed_mac, skb->data, skb->len, checker->message_mac1_key); + compute_mac1(computed_mac, skb->data, skb->len, + checker->message_mac1_key); if (crypto_memneq(computed_mac, macs->mac1, COOKIE_LEN)) goto out; @@ -130,27 +151,36 @@ out: return ret; } -void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer *peer) +void cookie_add_mac_to_packet(void *message, size_t len, + struct wireguard_peer *peer) { - struct message_macs *macs = (struct message_macs *)((u8 *)message + len - sizeof(struct message_macs)); + struct message_macs *macs = (struct message_macs *) + ((u8 *)message + len - sizeof(struct message_macs)); down_write(&peer->latest_cookie.lock); - compute_mac1(macs->mac1, message, len, peer->latest_cookie.message_mac1_key); + compute_mac1(macs->mac1, message, len, + peer->latest_cookie.message_mac1_key); memcpy(peer->latest_cookie.last_mac1_sent, macs->mac1, COOKIE_LEN); peer->latest_cookie.have_sent_mac1 = true; up_write(&peer->latest_cookie.lock); down_read(&peer->latest_cookie.lock); - if (peer->latest_cookie.is_valid && !has_expired(peer->latest_cookie.birthdate, COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY)) - compute_mac2(macs->mac2, message, len, peer->latest_cookie.cookie); + if (peer->latest_cookie.is_valid && + !has_expired(peer->latest_cookie.birthdate, + COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY)) + compute_mac2(macs->mac2, message, len, + peer->latest_cookie.cookie); else memset(macs->mac2, 0, COOKIE_LEN); up_read(&peer->latest_cookie.lock); } -void cookie_message_create(struct message_handshake_cookie *dst, struct sk_buff *skb, __le32 index, struct cookie_checker *checker) +void cookie_message_create(struct message_handshake_cookie *dst, + struct sk_buff *skb, __le32 index, + struct cookie_checker *checker) { - struct message_macs *macs = (struct message_macs *)((u8 *)skb->data + skb->len - sizeof(struct message_macs)); + struct message_macs *macs = (struct message_macs *) + ((u8 *)skb->data + skb->len - sizeof(struct message_macs)); u8 cookie[COOKIE_LEN]; dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE); @@ -158,16 +188,22 @@ void cookie_message_create(struct message_handshake_cookie *dst, struct sk_buff get_random_bytes_wait(dst->nonce, COOKIE_NONCE_LEN); make_cookie(cookie, skb, checker); - xchacha20poly1305_encrypt(dst->encrypted_cookie, cookie, COOKIE_LEN, macs->mac1, COOKIE_LEN, dst->nonce, checker->cookie_encryption_key); + xchacha20poly1305_encrypt(dst->encrypted_cookie, cookie, COOKIE_LEN, + macs->mac1, COOKIE_LEN, dst->nonce, + checker->cookie_encryption_key); } -void cookie_message_consume(struct message_handshake_cookie *src, struct wireguard_device *wg) +void cookie_message_consume(struct message_handshake_cookie *src, + struct wireguard_device *wg) { - u8 cookie[COOKIE_LEN]; struct wireguard_peer *peer = NULL; + u8 cookie[COOKIE_LEN]; bool ret; - if (unlikely(!index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE | INDEX_HASHTABLE_KEYPAIR, src->receiver_index, &peer))) + if (unlikely(!index_hashtable_lookup(&wg->index_hashtable, + INDEX_HASHTABLE_HANDSHAKE | + INDEX_HASHTABLE_KEYPAIR, + src->receiver_index, &peer))) return; down_read(&peer->latest_cookie.lock); @@ -175,7 +211,10 @@ void cookie_message_consume(struct message_handshake_cookie *src, struct wiregua up_read(&peer->latest_cookie.lock); goto out; } - ret = xchacha20poly1305_decrypt(cookie, src->encrypted_cookie, sizeof(src->encrypted_cookie), peer->latest_cookie.last_mac1_sent, COOKIE_LEN, src->nonce, peer->latest_cookie.cookie_decryption_key); + ret = xchacha20poly1305_decrypt( + cookie, src->encrypted_cookie, sizeof(src->encrypted_cookie), + peer->latest_cookie.last_mac1_sent, COOKIE_LEN, src->nonce, + peer->latest_cookie.cookie_decryption_key); up_read(&peer->latest_cookie.lock); if (ret) { @@ -186,7 +225,8 @@ void cookie_message_consume(struct message_handshake_cookie *src, struct wiregua peer->latest_cookie.have_sent_mac1 = false; up_write(&peer->latest_cookie.lock); } else - net_dbg_ratelimited("%s: Could not decrypt invalid cookie response\n", wg->dev->name); + net_dbg_ratelimited("%s: Could not decrypt invalid cookie response\n", + wg->dev->name); out: peer_put(peer); diff --git a/src/cookie.h b/src/cookie.h index 9f519ef..7802f61 100644 --- a/src/cookie.h +++ b/src/cookie.h @@ -38,15 +38,22 @@ enum cookie_mac_state { VALID_MAC_WITH_COOKIE }; -void cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg); +void cookie_checker_init(struct cookie_checker *checker, + struct wireguard_device *wg); void cookie_checker_precompute_device_keys(struct cookie_checker *checker); void cookie_checker_precompute_peer_keys(struct wireguard_peer *peer); void cookie_init(struct cookie *cookie); -enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie); -void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer *peer); - -void cookie_message_create(struct message_handshake_cookie *src, struct sk_buff *skb, __le32 index, struct cookie_checker *checker); -void cookie_message_consume(struct message_handshake_cookie *src, struct wireguard_device *wg); +enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, + struct sk_buff *skb, + bool check_cookie); +void cookie_add_mac_to_packet(void *message, size_t len, + struct wireguard_peer *peer); + +void cookie_message_create(struct message_handshake_cookie *src, + struct sk_buff *skb, __le32 index, + struct cookie_checker *checker); +void cookie_message_consume(struct message_handshake_cookie *src, + struct wireguard_device *wg); #endif /* _WG_COOKIE_H */ diff --git a/src/device.c b/src/device.c index 297125a..3dfa794 100644 --- a/src/device.c +++ b/src/device.c @@ -28,19 +28,18 @@ static LIST_HEAD(device_list); static int open(struct net_device *dev) { - int ret; - struct wireguard_peer *peer; + struct in_device *dev_v4 = __in_dev_get_rtnl(dev); struct wireguard_device *wg = netdev_priv(dev); #ifndef COMPAT_CANNOT_USE_IN6_DEV_GET struct inet6_dev *dev_v6 = __in6_dev_get(dev); #endif - struct in_device *dev_v4 = __in_dev_get_rtnl(dev); + struct wireguard_peer *peer; + int ret; if (dev_v4) { - /* TODO: at some point we might put this check near the ip_rt_send_redirect - * call of ip_forward in net/ipv4/ip_forward.c, similar to the current secpath - * check, rather than turning it off like this. This is just a stop gap solution - * while we're an out of tree module. + /* At some point we might put this check near the ip_rt_send_ + * redirect call of ip_forward in net/ipv4/ip_forward.c, similar + * to the current secpath check. */ IN_DEV_CONF_SET(dev_v4, SEND_REDIRECTS, false); IPV4_DEVCONF_ALL(dev_net(dev), SEND_REDIRECTS) = false; @@ -58,7 +57,7 @@ static int open(struct net_device *dev) if (ret < 0) return ret; mutex_lock(&wg->device_update_lock); - list_for_each_entry(peer, &wg->peer_list, peer_list) { + list_for_each_entry (peer, &wg->peer_list, peer_list) { packet_send_staged_packets(peer); if (peer->persistent_keepalive_interval) packet_send_keepalive(peer); @@ -68,7 +67,8 @@ static int open(struct net_device *dev) } #if defined(CONFIG_PM_SLEEP) && !defined(CONFIG_ANDROID) -static int pm_notification(struct notifier_block *nb, unsigned long action, void *data) +static int pm_notification(struct notifier_block *nb, unsigned long action, + void *data) { struct wireguard_device *wg; struct wireguard_peer *peer; @@ -77,9 +77,9 @@ static int pm_notification(struct notifier_block *nb, unsigned long action, void return 0; rtnl_lock(); - list_for_each_entry(wg, &device_list, device_list) { + list_for_each_entry (wg, &device_list, device_list) { mutex_lock(&wg->device_update_lock); - list_for_each_entry(peer, &wg->peer_list, peer_list) { + list_for_each_entry (peer, &wg->peer_list, peer_list) { noise_handshake_clear(&peer->handshake); noise_keypairs_clear(&peer->keypairs); if (peer->timers_enabled) @@ -100,12 +100,14 @@ static int stop(struct net_device *dev) struct wireguard_peer *peer; mutex_lock(&wg->device_update_lock); - list_for_each_entry(peer, &wg->peer_list, peer_list) { + list_for_each_entry (peer, &wg->peer_list, peer_list) { skb_queue_purge(&peer->staged_packet_queue); timers_stop(peer); noise_handshake_clear(&peer->handshake); noise_keypairs_clear(&peer->keypairs); - atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns() - (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC); + atomic64_set(&peer->last_sent_handshake, + ktime_get_boot_fast_ns() - + (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC); } mutex_unlock(&wg->device_update_lock); skb_queue_purge(&wg->incoming_handshakes); @@ -133,16 +135,19 @@ static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev) if (unlikely(!peer)) { ret = -ENOKEY; if (skb->protocol == htons(ETH_P_IP)) - net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI4\n", dev->name, &ip_hdr(skb)->daddr); + net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI4\n", + dev->name, &ip_hdr(skb)->daddr); else if (skb->protocol == htons(ETH_P_IPV6)) - net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI6\n", dev->name, &ipv6_hdr(skb)->daddr); + net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI6\n", + dev->name, &ipv6_hdr(skb)->daddr); goto err; } family = READ_ONCE(peer->endpoint.addr.sa_family); if (unlikely(family != AF_INET && family != AF_INET6)) { ret = -EDESTADDRREQ; - net_dbg_ratelimited("%s: No valid endpoint has been configured or discovered for peer %llu\n", dev->name, peer->internal_id); + net_dbg_ratelimited("%s: No valid endpoint has been configured or discovered for peer %llu\n", + dev->name, peer->internal_id); goto err_peer; } @@ -180,8 +185,9 @@ static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev) } while ((skb = next) != NULL); spin_lock_bh(&peer->staged_packet_queue.lock); - /* If the queue is getting too big, we start removing the oldest packets until it's small again. - * We do this before adding the new packet, so we don't remove GSO segments that are in excess. + /* If the queue is getting too big, we start removing the oldest packets + * until it's small again. We do this before adding the new packet, so + * we don't remove GSO segments that are in excess. */ while (skb_queue_len(&peer->staged_packet_queue) > MAX_STAGED_PACKETS) dev_kfree_skb(__skb_dequeue(&peer->staged_packet_queue)); @@ -223,7 +229,8 @@ static void destruct(struct net_device *dev) wg->incoming_port = 0; socket_reinit(wg, NULL, NULL); allowedips_free(&wg->peer_allowedips, &wg->device_update_lock); - peer_remove_all(wg); /* The final references are cleared in the below calls to destroy_workqueue. */ + /* The final references are cleared in the below calls to destroy_workqueue. */ + peer_remove_all(wg); destroy_workqueue(wg->handshake_receive_wq); destroy_workqueue(wg->handshake_send_wq); destroy_workqueue(wg->packet_crypt_wq); @@ -231,7 +238,8 @@ static void destruct(struct net_device *dev) packet_queue_free(&wg->encrypt_queue, true); rcu_barrier_bh(); /* Wait for all the peers to be actually freed. */ ratelimiter_uninit(); - memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity)); + memzero_explicit(&wg->static_identity, + sizeof(struct noise_static_identity)); skb_queue_purge(&wg->incoming_handshakes); free_percpu(dev->tstats); free_percpu(wg->incoming_handshakes_worker); @@ -243,14 +251,14 @@ static void destruct(struct net_device *dev) free_netdev(dev); } -static const struct device_type device_type = { - .name = KBUILD_MODNAME -}; +static const struct device_type device_type = { .name = KBUILD_MODNAME }; static void setup(struct net_device *dev) { struct wireguard_device *wg = netdev_priv(dev); - enum { WG_NETDEV_FEATURES = NETIF_F_HW_CSUM | NETIF_F_RXCSUM | NETIF_F_SG | NETIF_F_GSO | NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA }; + enum { WG_NETDEV_FEATURES = NETIF_F_HW_CSUM | NETIF_F_RXCSUM | + NETIF_F_SG | NETIF_F_GSO | + NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA }; dev->netdev_ops = &netdev_ops; dev->hard_header_len = 0; @@ -268,7 +276,9 @@ static void setup(struct net_device *dev) dev->features |= WG_NETDEV_FEATURES; dev->hw_features |= WG_NETDEV_FEATURES; dev->hw_enc_features |= WG_NETDEV_FEATURES; - dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH - sizeof(struct udphdr) - max(sizeof(struct ipv6hdr), sizeof(struct iphdr)); + dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH - + sizeof(struct udphdr) - + max(sizeof(struct ipv6hdr), sizeof(struct iphdr)); SET_NETDEV_DEVTYPE(dev, &device_type); @@ -279,7 +289,9 @@ static void setup(struct net_device *dev) wg->dev = dev; } -static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *tb[], struct nlattr *data[], struct netlink_ext_ack *extack) +static int newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) { int ret = -ENOMEM; struct wireguard_device *wg = netdev_priv(dev); @@ -300,26 +312,32 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t if (!dev->tstats) goto error_1; - wg->incoming_handshakes_worker = packet_alloc_percpu_multicore_worker(packet_handshake_receive_worker, wg); + wg->incoming_handshakes_worker = packet_alloc_percpu_multicore_worker( + packet_handshake_receive_worker, wg); if (!wg->incoming_handshakes_worker) goto error_2; - wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s", WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name); + wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s", + WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name); if (!wg->handshake_receive_wq) goto error_3; - wg->handshake_send_wq = alloc_workqueue("wg-kex-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); + wg->handshake_send_wq = alloc_workqueue("wg-kex-%s", + WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name); if (!wg->handshake_send_wq) goto error_4; - wg->packet_crypt_wq = alloc_workqueue("wg-crypt-%s", WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, dev->name); + wg->packet_crypt_wq = alloc_workqueue("wg-crypt-%s", + WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, dev->name); if (!wg->packet_crypt_wq) goto error_5; - if (packet_queue_init(&wg->encrypt_queue, packet_encrypt_worker, true, MAX_QUEUED_PACKETS) < 0) + if (packet_queue_init(&wg->encrypt_queue, packet_encrypt_worker, true, + MAX_QUEUED_PACKETS) < 0) goto error_6; - if (packet_queue_init(&wg->decrypt_queue, packet_decrypt_worker, true, MAX_QUEUED_PACKETS) < 0) + if (packet_queue_init(&wg->decrypt_queue, packet_decrypt_worker, true, + MAX_QUEUED_PACKETS) < 0) goto error_7; ret = ratelimiter_init(); @@ -332,8 +350,8 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t list_add(&wg->device_list, &device_list); - /* We wait until the end to assign priv_destructor, so that register_netdevice doesn't - * call it for us if it fails. + /* We wait until the end to assign priv_destructor, so that + * register_netdevice doesn't call it for us if it fails. */ dev->priv_destructor = destruct; @@ -367,7 +385,8 @@ static struct rtnl_link_ops link_ops __read_mostly = { .newlink = newlink, }; -static int netdevice_notification(struct notifier_block *nb, unsigned long action, void *data) +static int netdevice_notification(struct notifier_block *nb, + unsigned long action, void *data) { struct net_device *dev = ((struct netdev_notifier_info *)data)->dev; struct wireguard_device *wg = netdev_priv(dev); @@ -380,14 +399,16 @@ static int netdevice_notification(struct notifier_block *nb, unsigned long actio if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) { put_net(wg->creating_net); wg->have_creating_net_ref = false; - } else if (dev_net(dev) != wg->creating_net && !wg->have_creating_net_ref) { + } else if (dev_net(dev) != wg->creating_net && + !wg->have_creating_net_ref) { wg->have_creating_net_ref = true; get_net(wg->creating_net); } return 0; } -static struct notifier_block netdevice_notifier = { .notifier_call = netdevice_notification }; +static struct notifier_block netdevice_notifier = + { .notifier_call = netdevice_notification }; int __init device_init(void) { diff --git a/src/device.h b/src/device.h index 2a0e2c7..2499782 100644 --- a/src/device.h +++ b/src/device.h @@ -42,7 +42,8 @@ struct wireguard_device { struct sock __rcu *sock4, *sock6; struct net *creating_net; struct noise_static_identity static_identity; - struct workqueue_struct *handshake_receive_wq, *handshake_send_wq, *packet_crypt_wq; + struct workqueue_struct *handshake_receive_wq, *handshake_send_wq; + struct workqueue_struct *packet_crypt_wq; struct sk_buff_head incoming_handshakes; int incoming_handshake_cpu; struct multicore_worker __percpu *incoming_handshakes_worker; diff --git a/src/hashtables.c b/src/hashtables.c index ac6df59..4ba2288 100644 --- a/src/hashtables.c +++ b/src/hashtables.c @@ -7,12 +7,16 @@ #include "peer.h" #include "noise.h" -static inline struct hlist_head *pubkey_bucket(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN]) +static inline struct hlist_head *pubkey_bucket(struct pubkey_hashtable *table, + const u8 pubkey[NOISE_PUBLIC_KEY_LEN]) { - /* siphash gives us a secure 64bit number based on a random key. Since the bits are - * uniformly distributed, we can then mask off to get the bits we need. + /* siphash gives us a secure 64bit number based on a random key. Since + * the bits are uniformly distributed, we can then mask off to get the + * bits we need. */ - return &table->hashtable[siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key) & (HASH_SIZE(table->hashtable) - 1)]; + return &table->hashtable[ + siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key) & + (HASH_SIZE(table->hashtable) - 1)]; } void pubkey_hashtable_init(struct pubkey_hashtable *table) @@ -22,14 +26,17 @@ void pubkey_hashtable_init(struct pubkey_hashtable *table) mutex_init(&table->lock); } -void pubkey_hashtable_add(struct pubkey_hashtable *table, struct wireguard_peer *peer) +void pubkey_hashtable_add(struct pubkey_hashtable *table, + struct wireguard_peer *peer) { mutex_lock(&table->lock); - hlist_add_head_rcu(&peer->pubkey_hash, pubkey_bucket(table, peer->handshake.remote_static)); + hlist_add_head_rcu(&peer->pubkey_hash, + pubkey_bucket(table, peer->handshake.remote_static)); mutex_unlock(&table->lock); } -void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_peer *peer) +void pubkey_hashtable_remove(struct pubkey_hashtable *table, + struct wireguard_peer *peer) { mutex_lock(&table->lock); hlist_del_init_rcu(&peer->pubkey_hash); @@ -37,13 +44,17 @@ void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_pe } /* Returns a strong reference to a peer */ -struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN]) +struct wireguard_peer * +pubkey_hashtable_lookup(struct pubkey_hashtable *table, + const u8 pubkey[NOISE_PUBLIC_KEY_LEN]) { struct wireguard_peer *iter_peer, *peer = NULL; rcu_read_lock_bh(); - hlist_for_each_entry_rcu_bh(iter_peer, pubkey_bucket(table, pubkey), pubkey_hash) { - if (!memcmp(pubkey, iter_peer->handshake.remote_static, NOISE_PUBLIC_KEY_LEN)) { + hlist_for_each_entry_rcu_bh (iter_peer, pubkey_bucket(table, pubkey), + pubkey_hash) { + if (!memcmp(pubkey, iter_peer->handshake.remote_static, + NOISE_PUBLIC_KEY_LEN)) { peer = iter_peer; break; } @@ -53,12 +64,14 @@ struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, c return peer; } -static inline struct hlist_head *index_bucket(struct index_hashtable *table, const __le32 index) +static inline struct hlist_head *index_bucket(struct index_hashtable *table, + const __le32 index) { - /* Since the indices are random and thus all bits are uniformly distributed, - * we can find its bucket simply by masking. + /* Since the indices are random and thus all bits are uniformly + * distributed, we can find its bucket simply by masking. */ - return &table->hashtable[(__force u32)index & (HASH_SIZE(table->hashtable) - 1)]; + return &table->hashtable[(__force u32)index & + (HASH_SIZE(table->hashtable) - 1)]; } void index_hashtable_init(struct index_hashtable *table) @@ -67,9 +80,10 @@ void index_hashtable_init(struct index_hashtable *table) spin_lock_init(&table->lock); } -/* At the moment, we limit ourselves to 2^20 total peers, which generally might amount to 2^20*3 - * items in this hashtable. The algorithm below works by picking a random number and testing it. - * We can see that these limits mean we usually succeed pretty quickly: +/* At the moment, we limit ourselves to 2^20 total peers, which generally might + * amount to 2^20*3 items in this hashtable. The algorithm below works by + * picking a random number and testing it. We can see that these limits mean we + * usually succeed pretty quickly: * * >>> def calculation(tries, size): * ... return (size / 2**32)**(tries - 1) * (1 - (size / 2**32)) @@ -83,13 +97,15 @@ void index_hashtable_init(struct index_hashtable *table) * >>> calculation(4, 2**20 * 3) * 3.9261394135792216e-10 * - * At the moment, we don't do any masking, so this algorithm isn't exactly constant time in - * either the random guessing or in the hash list lookup. We could require a minimum of 3 - * tries, which would successfully mask the guessing. TODO: this would not, however, help - * with the growing hash lengths. + * At the moment, we don't do any masking, so this algorithm isn't exactly + * constant time in either the random guessing or in the hash list lookup. We + * could require a minimum of 3 tries, which would successfully mask the + * guessing. this would not, however, help with the growing hash lengths, which + * is another thing to consider moving forward. */ -__le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry) +__le32 index_hashtable_insert(struct index_hashtable *table, + struct index_hashtable_entry *entry) { struct index_hashtable_entry *existing_entry; @@ -102,23 +118,32 @@ __le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashta search_unused_slot: /* First we try to find an unused slot, randomly, while unlocked. */ entry->index = (__force __le32)get_random_u32(); - hlist_for_each_entry_rcu_bh(existing_entry, index_bucket(table, entry->index), index_hash) { + hlist_for_each_entry_rcu_bh (existing_entry, + index_bucket(table, entry->index), + index_hash) { if (existing_entry->index == entry->index) - goto search_unused_slot; /* If it's already in use, we continue searching. */ + /* If it's already in use, we continue searching. */ + goto search_unused_slot; } /* Once we've found an unused slot, we lock it, and then double-check * that nobody else stole it from us. */ spin_lock_bh(&table->lock); - hlist_for_each_entry_rcu_bh(existing_entry, index_bucket(table, entry->index), index_hash) { + hlist_for_each_entry_rcu_bh (existing_entry, + index_bucket(table, entry->index), + index_hash) { if (existing_entry->index == entry->index) { spin_unlock_bh(&table->lock); - goto search_unused_slot; /* If it was stolen, we start over. */ + /* If it was stolen, we start over. */ + goto search_unused_slot; } } - /* Otherwise, we know we have it exclusively (since we're locked), so we insert. */ - hlist_add_head_rcu(&entry->index_hash, index_bucket(table, entry->index)); + /* Otherwise, we know we have it exclusively (since we're locked), + * so we insert. + */ + hlist_add_head_rcu(&entry->index_hash, + index_bucket(table, entry->index)); spin_unlock_bh(&table->lock); rcu_read_unlock_bh(); @@ -126,7 +151,9 @@ search_unused_slot: return entry->index; } -bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new) +bool index_hashtable_replace(struct index_hashtable *table, + struct index_hashtable_entry *old, + struct index_hashtable_entry *new) { if (unlikely(hlist_unhashed(&old->index_hash))) return false; @@ -134,17 +161,19 @@ bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtab new->index = old->index; hlist_replace_rcu(&old->index_hash, &new->index_hash); - /* Calling init here NULLs out index_hash, and in fact after this function returns, - * it's theoretically possible for this to get reinserted elsewhere. That means - * the RCU lookup below might either terminate early or jump between buckets, in which - * case the packet simply gets dropped, which isn't terrible. + /* Calling init here NULLs out index_hash, and in fact after this + * function returns, it's theoretically possible for this to get + * reinserted elsewhere. That means the RCU lookup below might either + * terminate early or jump between buckets, in which case the packet + * simply gets dropped, which isn't terrible. */ INIT_HLIST_NODE(&old->index_hash); spin_unlock_bh(&table->lock); return true; } -void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry) +void index_hashtable_remove(struct index_hashtable *table, + struct index_hashtable_entry *entry) { spin_lock_bh(&table->lock); hlist_del_init_rcu(&entry->index_hash); @@ -152,12 +181,16 @@ void index_hashtable_remove(struct index_hashtable *table, struct index_hashtabl } /* Returns a strong reference to a entry->peer */ -struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer) +struct index_hashtable_entry * +index_hashtable_lookup(struct index_hashtable *table, + const enum index_hashtable_type type_mask, + const __le32 index, struct wireguard_peer **peer) { struct index_hashtable_entry *iter_entry, *entry = NULL; rcu_read_lock_bh(); - hlist_for_each_entry_rcu_bh(iter_entry, index_bucket(table, index), index_hash) { + hlist_for_each_entry_rcu_bh (iter_entry, index_bucket(table, index), + index_hash) { if (iter_entry->index == index) { if (likely(iter_entry->type & type_mask)) entry = iter_entry; diff --git a/src/hashtables.h b/src/hashtables.h index f64cd24..62858c5 100644 --- a/src/hashtables.h +++ b/src/hashtables.h @@ -22,9 +22,13 @@ struct pubkey_hashtable { }; void pubkey_hashtable_init(struct pubkey_hashtable *table); -void pubkey_hashtable_add(struct pubkey_hashtable *table, struct wireguard_peer *peer); -void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_peer *peer); -struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN]); +void pubkey_hashtable_add(struct pubkey_hashtable *table, + struct wireguard_peer *peer); +void pubkey_hashtable_remove(struct pubkey_hashtable *table, + struct wireguard_peer *peer); +struct wireguard_peer * +pubkey_hashtable_lookup(struct pubkey_hashtable *table, + const u8 pubkey[NOISE_PUBLIC_KEY_LEN]); struct index_hashtable { /* TODO: move to rhashtable */ @@ -44,9 +48,16 @@ struct index_hashtable_entry { __le32 index; }; void index_hashtable_init(struct index_hashtable *table); -__le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry); -bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new); -void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry); -struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer); +__le32 index_hashtable_insert(struct index_hashtable *table, + struct index_hashtable_entry *entry); +bool index_hashtable_replace(struct index_hashtable *table, + struct index_hashtable_entry *old, + struct index_hashtable_entry *new); +void index_hashtable_remove(struct index_hashtable *table, + struct index_hashtable_entry *entry); +struct index_hashtable_entry * +index_hashtable_lookup(struct index_hashtable *table, + const enum index_hashtable_type type_mask, + const __le32 index, struct wireguard_peer **peer); #endif /* _WG_HASHTABLES_H */ diff --git a/src/main.c b/src/main.c index 94f9a7d..bc28516 100644 --- a/src/main.c +++ b/src/main.c @@ -31,7 +31,10 @@ static int __init mod_init(void) blake2s_fpu_init(); curve25519_fpu_init(); #ifdef DEBUG - if (!allowedips_selftest() || !packet_counter_selftest() || !curve25519_selftest() || !poly1305_selftest() || !chacha20poly1305_selftest() || !blake2s_selftest() || !ratelimiter_selftest()) + if (!allowedips_selftest() || !packet_counter_selftest() || + !curve25519_selftest() || !poly1305_selftest() || + !chacha20poly1305_selftest() || !blake2s_selftest() || + !ratelimiter_selftest()) return -ENOTRECOVERABLE; #endif noise_init(); diff --git a/src/messages.h b/src/messages.h index 9425a26..bb19957 100644 --- a/src/messages.h +++ b/src/messages.h @@ -109,18 +109,20 @@ struct message_data { u8 encrypted_data[]; }; -#define message_data_len(plain_len) (noise_encrypted_len(plain_len) + sizeof(struct message_data)) +#define message_data_len(plain_len) \ + (noise_encrypted_len(plain_len) + sizeof(struct message_data)) enum message_alignments { MESSAGE_PADDING_MULTIPLE = 16, MESSAGE_MINIMUM_LENGTH = message_data_len(0) }; -#define SKB_HEADER_LEN (max(sizeof(struct iphdr), sizeof(struct ipv6hdr)) + sizeof(struct udphdr) + NET_SKB_PAD) -#define DATA_PACKET_HEAD_ROOM ALIGN(sizeof(struct message_data) + SKB_HEADER_LEN, 4) +#define SKB_HEADER_LEN \ + (max(sizeof(struct iphdr), sizeof(struct ipv6hdr)) + \ + sizeof(struct udphdr) + NET_SKB_PAD) +#define DATA_PACKET_HEAD_ROOM \ + ALIGN(sizeof(struct message_data) + SKB_HEADER_LEN, 4) -enum { - HANDSHAKE_DSCP = 0x88 /* AF41, plus 00 ECN */ -}; +enum { HANDSHAKE_DSCP = 0x88 /* AF41, plus 00 ECN */ }; #endif /* _WG_MESSAGES_H */ diff --git a/src/netlink.c b/src/netlink.c index 3147587..5390498 100644 --- a/src/netlink.c +++ b/src/netlink.c @@ -45,19 +45,23 @@ static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = { [WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 } }; -static struct wireguard_device *lookup_interface(struct nlattr **attrs, struct sk_buff *skb) +static struct wireguard_device *lookup_interface(struct nlattr **attrs, + struct sk_buff *skb) { struct net_device *dev = NULL; if (!attrs[WGDEVICE_A_IFINDEX] == !attrs[WGDEVICE_A_IFNAME]) return ERR_PTR(-EBADR); if (attrs[WGDEVICE_A_IFINDEX]) - dev = dev_get_by_index(sock_net(skb->sk), nla_get_u32(attrs[WGDEVICE_A_IFINDEX])); + dev = dev_get_by_index(sock_net(skb->sk), + nla_get_u32(attrs[WGDEVICE_A_IFINDEX])); else if (attrs[WGDEVICE_A_IFNAME]) - dev = dev_get_by_name(sock_net(skb->sk), nla_data(attrs[WGDEVICE_A_IFNAME])); + dev = dev_get_by_name(sock_net(skb->sk), + nla_data(attrs[WGDEVICE_A_IFNAME])); if (!dev) return ERR_PTR(-ENODEV); - if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind || strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) { + if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind || + strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) { dev_put(dev); return ERR_PTR(-EOPNOTSUPP); } @@ -71,15 +75,17 @@ struct allowedips_ctx { static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) { - struct nlattr *allowedip_nest; struct allowedips_ctx *actx = ctx; + struct nlattr *allowedip_nest; allowedip_nest = nla_nest_start(actx->skb, actx->i++); if (!allowedip_nest) return -EMSGSIZE; - if (nla_put_u8(actx->skb, WGALLOWEDIP_A_CIDR_MASK, cidr) || nla_put_u16(actx->skb, WGALLOWEDIP_A_FAMILY, family) || - nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) { + if (nla_put_u8(actx->skb, WGALLOWEDIP_A_CIDR_MASK, cidr) || + nla_put_u16(actx->skb, WGALLOWEDIP_A_FAMILY, family) || + nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? + sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) { nla_nest_cancel(actx->skb, allowedip_nest); return -EMSGSIZE; } @@ -88,37 +94,52 @@ static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family) return 0; } -static int get_peer(struct wireguard_peer *peer, unsigned int index, struct allowedips_cursor *rt_cursor, struct sk_buff *skb) +static int get_peer(struct wireguard_peer *peer, unsigned int index, + struct allowedips_cursor *rt_cursor, struct sk_buff *skb) { - struct allowedips_ctx ctx = { .skb = skb }; struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, index); + struct allowedips_ctx ctx = { .skb = skb }; bool fail; if (!peer_nest) return -EMSGSIZE; down_read(&peer->handshake.lock); - fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN, peer->handshake.remote_static); + fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN, + peer->handshake.remote_static); up_read(&peer->handshake.lock); if (fail) goto err; if (!rt_cursor->seq) { down_read(&peer->handshake.lock); - fail = nla_put(skb, WGPEER_A_PRESHARED_KEY, NOISE_SYMMETRIC_KEY_LEN, peer->handshake.preshared_key); + fail = nla_put(skb, WGPEER_A_PRESHARED_KEY, + NOISE_SYMMETRIC_KEY_LEN, + peer->handshake.preshared_key); up_read(&peer->handshake.lock); if (fail) goto err; - if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME, sizeof(struct timespec), &peer->walltime_last_handshake) || nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval) || - nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes, WGPEER_A_UNSPEC) || nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes, WGPEER_A_UNSPEC)) + if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME, + sizeof(struct timespec), + &peer->walltime_last_handshake) || + nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, + peer->persistent_keepalive_interval) || + nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes, + WGPEER_A_UNSPEC) || + nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes, + WGPEER_A_UNSPEC)) goto err; read_lock_bh(&peer->endpoint_lock); if (peer->endpoint.addr.sa_family == AF_INET) - fail = nla_put(skb, WGPEER_A_ENDPOINT, sizeof(struct sockaddr_in), &peer->endpoint.addr4); + fail = nla_put(skb, WGPEER_A_ENDPOINT, + sizeof(struct sockaddr_in), + &peer->endpoint.addr4); else if (peer->endpoint.addr.sa_family == AF_INET6) - fail = nla_put(skb, WGPEER_A_ENDPOINT, sizeof(struct sockaddr_in6), &peer->endpoint.addr6); + fail = nla_put(skb, WGPEER_A_ENDPOINT, + sizeof(struct sockaddr_in6), + &peer->endpoint.addr6); read_unlock_bh(&peer->endpoint_lock); if (fail) goto err; @@ -127,7 +148,9 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, struct allo allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS); if (!allowedips_nest) goto err; - if (allowedips_walk_by_peer(&peer->device->peer_allowedips, rt_cursor, peer, get_allowedips, &ctx, &peer->device->device_update_lock)) { + if (allowedips_walk_by_peer(&peer->device->peer_allowedips, rt_cursor, + peer, get_allowedips, &ctx, + &peer->device->device_update_lock)) { nla_nest_end(skb, allowedips_nest); nla_nest_end(skb, peer_nest); return -EMSGSIZE; @@ -143,13 +166,15 @@ err: static int get_device_start(struct netlink_callback *cb) { - struct wireguard_device *wg; struct nlattr **attrs = genl_family_attrbuf(&genl_family); - int ret = nlmsg_parse(cb->nlh, GENL_HDRLEN + genl_family.hdrsize, attrs, genl_family.maxattr, device_policy, NULL); + int ret = nlmsg_parse(cb->nlh, GENL_HDRLEN + genl_family.hdrsize, attrs, + genl_family.maxattr, device_policy, NULL); + struct wireguard_device *wg; if (ret < 0) return ret; - cb->args[2] = (long)kzalloc(sizeof(struct allowedips_cursor), GFP_KERNEL); + cb->args[2] = + (long)kzalloc(sizeof(struct allowedips_cursor), GFP_KERNEL); if (!cb->args[2]) return -ENOMEM; wg = lookup_interface(attrs, cb->skb); @@ -164,33 +189,46 @@ static int get_device_start(struct netlink_callback *cb) static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) { - struct wireguard_device *wg = (struct wireguard_device *)cb->args[0]; struct wireguard_peer *peer, *next_peer_cursor, *last_peer_cursor; - struct allowedips_cursor *rt_cursor = (struct allowedips_cursor *)cb->args[2]; + struct allowedips_cursor *rt_cursor; + struct wireguard_device *wg; unsigned int peer_idx = 0; struct nlattr *peers_nest; bool done = true; void *hdr; int ret = -EMSGSIZE; - next_peer_cursor = last_peer_cursor = (struct wireguard_peer *)cb->args[1]; + wg = (struct wireguard_device *)cb->args[0]; + next_peer_cursor = (struct wireguard_peer *)cb->args[1]; + last_peer_cursor = (struct wireguard_peer *)cb->args[1]; + rt_cursor = (struct allowedips_cursor *)cb->args[2]; rtnl_lock(); mutex_lock(&wg->device_update_lock); cb->seq = wg->device_update_gen; - hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE); + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE); if (!hdr) goto out; genl_dump_check_consistent(cb, hdr); if (!last_peer_cursor) { - if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT, wg->incoming_port) || nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) || nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) || nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name)) + if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT, + wg->incoming_port) || + nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) || + nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) || + nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name)) goto out; down_read(&wg->static_identity.lock); if (wg->static_identity.has_identity) { - if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY, NOISE_PUBLIC_KEY_LEN, wg->static_identity.static_private) || nla_put(skb, WGDEVICE_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN, wg->static_identity.static_public)) { + if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY, + NOISE_PUBLIC_KEY_LEN, + wg->static_identity.static_private) || + nla_put(skb, WGDEVICE_A_PUBLIC_KEY, + NOISE_PUBLIC_KEY_LEN, + wg->static_identity.static_public)) { up_read(&wg->static_identity.lock); goto out; } @@ -202,17 +240,19 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb) if (!peers_nest) goto out; ret = 0; - /* If the last cursor was removed via list_del_init in peer_remove, then we just treat - * this the same as there being no more peers left. The reason is that seq_nr should - * indicate to userspace that this isn't a coherent dump anyway, so they'll try again. + /* If the last cursor was removed via list_del_init in peer_remove, then + * we just treat this the same as there being no more peers left. The + * reason is that seq_nr should indicate to userspace that this isn't a + * coherent dump anyway, so they'll try again. */ - if (list_empty(&wg->peer_list) || (last_peer_cursor && list_empty(&last_peer_cursor->peer_list))) { + if (list_empty(&wg->peer_list) || + (last_peer_cursor && list_empty(&last_peer_cursor->peer_list))) { nla_nest_cancel(skb, peers_nest); goto out; } lockdep_assert_held(&wg->device_update_lock); peer = list_prepare_entry(last_peer_cursor, &wg->peer_list, peer_list); - list_for_each_entry_continue(peer, &wg->peer_list, peer_list) { + list_for_each_entry_continue (peer, &wg->peer_list, peer_list) { if (get_peer(peer, peer_idx++, rt_cursor, skb)) { done = false; break; @@ -250,7 +290,8 @@ static int get_device_done(struct netlink_callback *cb) { struct wireguard_device *wg = (struct wireguard_device *)cb->args[0]; struct wireguard_peer *peer = (struct wireguard_peer *)cb->args[1]; - struct allowedips_cursor *rt_cursor = (struct allowedips_cursor *)cb->args[2]; + struct allowedips_cursor *rt_cursor = + (struct allowedips_cursor *)cb->args[2]; if (wg) dev_put(wg->dev); @@ -265,7 +306,7 @@ static int set_port(struct wireguard_device *wg, u16 port) if (wg->incoming_port == port) return 0; - list_for_each_entry(peer, &wg->peer_list, peer_list) + list_for_each_entry (peer, &wg->peer_list, peer_list) socket_clear_peer_endpoint_src(peer); if (!netif_running(wg->dev)) { wg->incoming_port = port; @@ -280,15 +321,25 @@ static int set_allowedip(struct wireguard_peer *peer, struct nlattr **attrs) u16 family; u8 cidr; - if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] || !attrs[WGALLOWEDIP_A_CIDR_MASK]) + if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] || + !attrs[WGALLOWEDIP_A_CIDR_MASK]) return ret; family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]); cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]); - if (family == AF_INET && cidr <= 32 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) - ret = allowedips_insert_v4(&peer->device->peer_allowedips, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, &peer->device->device_update_lock); - else if (family == AF_INET6 && cidr <= 128 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr)) - ret = allowedips_insert_v6(&peer->device->peer_allowedips, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, &peer->device->device_update_lock); + if (family == AF_INET && cidr <= 32 && + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr)) + ret = allowedips_insert_v4( + &peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, + &peer->device->device_update_lock); + else if (family == AF_INET6 && cidr <= 128 && + nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == + sizeof(struct in6_addr)) + ret = allowedips_insert_v6( + &peer->device->peer_allowedips, + nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, + &peer->device->device_update_lock); return ret; } @@ -301,25 +352,33 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) u8 *public_key = NULL, *preshared_key = NULL; ret = -EINVAL; - if (attrs[WGPEER_A_PUBLIC_KEY] && nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN) + if (attrs[WGPEER_A_PUBLIC_KEY] && + nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN) public_key = nla_data(attrs[WGPEER_A_PUBLIC_KEY]); else goto out; - if (attrs[WGPEER_A_PRESHARED_KEY] && nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN) + if (attrs[WGPEER_A_PRESHARED_KEY] && + nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN) preshared_key = nla_data(attrs[WGPEER_A_PRESHARED_KEY]); if (attrs[WGPEER_A_FLAGS]) flags = nla_get_u32(attrs[WGPEER_A_FLAGS]); - peer = pubkey_hashtable_lookup(&wg->peer_hashtable, nla_data(attrs[WGPEER_A_PUBLIC_KEY])); + peer = pubkey_hashtable_lookup(&wg->peer_hashtable, + nla_data(attrs[WGPEER_A_PUBLIC_KEY])); if (!peer) { /* Peer doesn't exist yet. Add a new one. */ ret = -ENODEV; if (flags & WGPEER_F_REMOVE_ME) goto out; /* Tried to remove a non-existing peer. */ down_read(&wg->static_identity.lock); - if (wg->static_identity.has_identity && !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]), wg->static_identity.static_public, NOISE_PUBLIC_KEY_LEN)) { - /* We silently ignore peers that have the same public key as the device. The reason we do it silently - * is that we'd like for people to be able to reuse the same set of API calls across peers. + if (wg->static_identity.has_identity && + !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]), + wg->static_identity.static_public, + NOISE_PUBLIC_KEY_LEN)) { + /* We silently ignore peers that have the same public + * key as the device. The reason we do it silently is + * that we'd like for people to be able to reuse the + * same set of API calls across peers. */ up_read(&wg->static_identity.lock); ret = 0; @@ -331,7 +390,9 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) peer = peer_create(wg, public_key, preshared_key); if (!peer) goto out; - /* Take additional reference, as though we've just been looked up. */ + /* Take additional reference, as though we've just been + * looked up. + */ peer_get(peer); } @@ -343,7 +404,8 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) if (preshared_key) { down_write(&peer->handshake.lock); - memcpy(&peer->handshake.preshared_key, preshared_key, NOISE_SYMMETRIC_KEY_LEN); + memcpy(&peer->handshake.preshared_key, preshared_key, + NOISE_SYMMETRIC_KEY_LEN); up_write(&peer->handshake.lock); } @@ -351,7 +413,10 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]); size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]); - if ((len == sizeof(struct sockaddr_in) && addr->sa_family == AF_INET) || (len == sizeof(struct sockaddr_in6) && addr->sa_family == AF_INET6)) { + if ((len == sizeof(struct sockaddr_in) && + addr->sa_family == AF_INET) || + (len == sizeof(struct sockaddr_in6) && + addr->sa_family == AF_INET6)) { struct endpoint endpoint = { { { 0 } } }; memcpy(&endpoint.addr, addr, len); @@ -360,14 +425,16 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) } if (flags & WGPEER_F_REPLACE_ALLOWEDIPS) - allowedips_remove_by_peer(&wg->peer_allowedips, peer, &wg->device_update_lock); + allowedips_remove_by_peer(&wg->peer_allowedips, peer, + &wg->device_update_lock); if (attrs[WGPEER_A_ALLOWEDIPS]) { - int rem; struct nlattr *attr, *allowedip[WGALLOWEDIP_A_MAX + 1]; + int rem; - nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) { - ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX, attr, allowedip_policy, NULL); + nla_for_each_nested (attr, attrs[WGPEER_A_ALLOWEDIPS], rem) { + ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX, + attr, allowedip_policy, NULL); if (ret < 0) goto out; ret = set_allowedip(peer, allowedip); @@ -377,8 +444,12 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) } if (attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]) { - const u16 persistent_keepalive_interval = nla_get_u16(attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]); - const bool send_keepalive = !peer->persistent_keepalive_interval && persistent_keepalive_interval && netif_running(wg->dev); + const u16 persistent_keepalive_interval = nla_get_u16( + attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]); + const bool send_keepalive = + !peer->persistent_keepalive_interval && + persistent_keepalive_interval && + netif_running(wg->dev); peer->persistent_keepalive_interval = persistent_keepalive_interval; if (send_keepalive) @@ -391,7 +462,8 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs) out: peer_put(peer); if (attrs[WGPEER_A_PRESHARED_KEY]) - memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]), nla_len(attrs[WGPEER_A_PRESHARED_KEY])); + memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]), + nla_len(attrs[WGPEER_A_PRESHARED_KEY])); return ret; } @@ -413,26 +485,35 @@ static int set_device(struct sk_buff *skb, struct genl_info *info) struct wireguard_peer *peer; wg->fwmark = nla_get_u32(info->attrs[WGDEVICE_A_FWMARK]); - list_for_each_entry(peer, &wg->peer_list, peer_list) + list_for_each_entry (peer, &wg->peer_list, peer_list) socket_clear_peer_endpoint_src(peer); } if (info->attrs[WGDEVICE_A_LISTEN_PORT]) { - ret = set_port(wg, nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT])); + ret = set_port( + wg, nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT])); if (ret) goto out; } - if (info->attrs[WGDEVICE_A_FLAGS] && nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]) & WGDEVICE_F_REPLACE_PEERS) + if (info->attrs[WGDEVICE_A_FLAGS] && + nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]) & + WGDEVICE_F_REPLACE_PEERS) peer_remove_all(wg); - if (info->attrs[WGDEVICE_A_PRIVATE_KEY] && nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) == NOISE_PUBLIC_KEY_LEN) { + if (info->attrs[WGDEVICE_A_PRIVATE_KEY] && + nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) == + NOISE_PUBLIC_KEY_LEN) { + u8 *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]); + u8 public_key[NOISE_PUBLIC_KEY_LEN]; struct wireguard_peer *peer, *temp; - u8 public_key[NOISE_PUBLIC_KEY_LEN], *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]); - /* We remove before setting, to prevent race, which means doing two 25519-genpub ops. */ + /* We remove before setting, to prevent race, which means doing + * two 25519-genpub ops. + */ if (curve25519_generate_public(public_key, private_key)) { - peer = pubkey_hashtable_lookup(&wg->peer_hashtable, public_key); + peer = pubkey_hashtable_lookup(&wg->peer_hashtable, + public_key); if (peer) { peer_put(peer); peer_remove(peer); @@ -440,8 +521,10 @@ static int set_device(struct sk_buff *skb, struct genl_info *info) } down_write(&wg->static_identity.lock); - noise_set_static_identity_private_key(&wg->static_identity, private_key); - list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) { + noise_set_static_identity_private_key(&wg->static_identity, + private_key); + list_for_each_entry_safe (peer, temp, &wg->peer_list, + peer_list) { if (!noise_precompute_static_static(peer)) peer_remove(peer); } @@ -453,8 +536,9 @@ static int set_device(struct sk_buff *skb, struct genl_info *info) int rem; struct nlattr *attr, *peer[WGPEER_A_MAX + 1]; - nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) { - ret = nla_parse_nested(peer, WGPEER_A_MAX, attr, peer_policy, NULL); + nla_for_each_nested (attr, info->attrs[WGDEVICE_A_PEERS], rem) { + ret = nla_parse_nested(peer, WGPEER_A_MAX, attr, + peer_policy, NULL); if (ret < 0) goto out; ret = set_peer(wg, peer); @@ -470,7 +554,8 @@ out: dev_put(wg->dev); out_nodev: if (info->attrs[WGDEVICE_A_PRIVATE_KEY]) - memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]), nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY])); + memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]), + nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY])); return ret; } diff --git a/src/noise.c b/src/noise.c index 0f6e51b..70b53a6 100644 --- a/src/noise.c +++ b/src/noise.c @@ -35,7 +35,8 @@ void __init noise_init(void) { struct blake2s_state blake; - blake2s(handshake_init_chaining_key, handshake_name, NULL, NOISE_HASH_LEN, sizeof(handshake_name), 0); + blake2s(handshake_init_chaining_key, handshake_name, NULL, + NOISE_HASH_LEN, sizeof(handshake_name), 0); blake2s_init(&blake, NOISE_HASH_LEN); blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN); blake2s_update(&blake, identifier_name, sizeof(identifier_name)); @@ -46,16 +47,25 @@ void __init noise_init(void) bool noise_precompute_static_static(struct wireguard_peer *peer) { bool ret = true; + down_write(&peer->handshake.lock); if (peer->handshake.static_identity->has_identity) - ret = curve25519(peer->handshake.precomputed_static_static, peer->handshake.static_identity->static_private, peer->handshake.remote_static); + ret = curve25519( + peer->handshake.precomputed_static_static, + peer->handshake.static_identity->static_private, + peer->handshake.remote_static); else - memset(peer->handshake.precomputed_static_static, 0, NOISE_PUBLIC_KEY_LEN); + memset(peer->handshake.precomputed_static_static, 0, + NOISE_PUBLIC_KEY_LEN); up_write(&peer->handshake.lock); return ret; } -bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer) +bool noise_handshake_init(struct noise_handshake *handshake, + struct noise_static_identity *static_identity, + const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], + const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], + struct wireguard_peer *peer) { memset(handshake, 0, sizeof(struct noise_handshake)); init_rwsem(&handshake->lock); @@ -63,7 +73,8 @@ bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static handshake->entry.peer = peer; memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN); if (peer_preshared_key) - memcpy(handshake->preshared_key, peer_preshared_key, NOISE_SYMMETRIC_KEY_LEN); + memcpy(handshake->preshared_key, peer_preshared_key, + NOISE_SYMMETRIC_KEY_LEN); handshake->static_identity = static_identity; handshake->state = HANDSHAKE_ZEROED; return noise_precompute_static_static(peer); @@ -81,16 +92,19 @@ static void handshake_zero(struct noise_handshake *handshake) void noise_handshake_clear(struct noise_handshake *handshake) { - index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, + &handshake->entry); down_write(&handshake->lock); handshake_zero(handshake); up_write(&handshake->lock); - index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, + &handshake->entry); } static struct noise_keypair *keypair_create(struct wireguard_peer *peer) { - struct noise_keypair *keypair = kzalloc(sizeof(struct noise_keypair), GFP_KERNEL); + struct noise_keypair *keypair = + kzalloc(sizeof(struct noise_keypair), GFP_KERNEL); if (unlikely(!keypair)) return NULL; @@ -108,9 +122,14 @@ static void keypair_free_rcu(struct rcu_head *rcu) static void keypair_free_kref(struct kref *kref) { - struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount); - net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id); - index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); + struct noise_keypair *keypair = + container_of(kref, struct noise_keypair, refcount); + net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", + keypair->entry.peer->device->dev->name, + keypair->internal_id, + keypair->entry.peer->internal_id); + index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, + &keypair->entry); call_rcu_bh(&keypair->rcu, keypair_free_rcu); } @@ -119,13 +138,16 @@ void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now) if (unlikely(!keypair)) return; if (unlikely(unreference_now)) - index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry); + index_hashtable_remove( + &keypair->entry.peer->device->index_hashtable, + &keypair->entry); kref_put(&keypair->refcount, keypair_free_kref); } struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair) { - RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Taking noise keypair reference without holding the RCU BH read lock"); + RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), + "Taking noise keypair reference without holding the RCU BH read lock"); if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount))) return NULL; return keypair; @@ -136,55 +158,67 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs) struct noise_keypair *old; spin_lock_bh(&keypairs->keypair_update_lock); - old = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + old = rcu_dereference_protected(keypairs->previous_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->previous_keypair, NULL); noise_keypair_put(old, true); - old = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + old = rcu_dereference_protected(keypairs->next_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->next_keypair, NULL); noise_keypair_put(old, true); - old = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + old = rcu_dereference_protected(keypairs->current_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); RCU_INIT_POINTER(keypairs->current_keypair, NULL); noise_keypair_put(old, true); spin_unlock_bh(&keypairs->keypair_update_lock); } -static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypair *new_keypair) +static void add_new_keypair(struct noise_keypairs *keypairs, + struct noise_keypair *new_keypair) { struct noise_keypair *previous_keypair, *next_keypair, *current_keypair; spin_lock_bh(&keypairs->keypair_update_lock); - previous_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); - next_keypair = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); - current_keypair = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); + previous_keypair = rcu_dereference_protected(keypairs->previous_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); + next_keypair = rcu_dereference_protected(keypairs->next_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); + current_keypair = rcu_dereference_protected(keypairs->current_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); if (new_keypair->i_am_the_initiator) { - /* If we're the initiator, it means we've sent a handshake, and received - * a confirmation response, which means this new keypair can now be used. + /* If we're the initiator, it means we've sent a handshake, and + * received a confirmation response, which means this new + * keypair can now be used. */ if (next_keypair) { - /* If there already was a next keypair pending, we demote it to be - * the previous keypair, and free the existing current. - * TODO: note that this means KCI can result in this transition. It - * would perhaps be more sound to always just get rid of the unused - * next keypair instead of putting it in the previous slot, but this - * might be a bit less robust. Something to think about and decide on. + /* If there already was a next keypair pending, we + * demote it to be the previous keypair, and free the + * existing current. Note that this means KCI can result + * in this transition. It would perhaps be more sound to + * always just get rid of the unused next keypair + * instead of putting it in the previous slot, but this + * might be a bit less robust. Something to think about + * for the future. */ RCU_INIT_POINTER(keypairs->next_keypair, NULL); - rcu_assign_pointer(keypairs->previous_keypair, next_keypair); + rcu_assign_pointer(keypairs->previous_keypair, + next_keypair); noise_keypair_put(current_keypair, true); - } else /* If there wasn't an existing next keypair, we replace the - * previous with the current one. + } else /* If there wasn't an existing next keypair, we replace + * the previous with the current one. */ - rcu_assign_pointer(keypairs->previous_keypair, current_keypair); - /* At this point we can get rid of the old previous keypair, and set up - * the new keypair. + rcu_assign_pointer(keypairs->previous_keypair, + current_keypair); + /* At this point we can get rid of the old previous keypair, and + * set up the new keypair. */ noise_keypair_put(previous_keypair, true); rcu_assign_pointer(keypairs->current_keypair, new_keypair); } else { - /* If we're the responder, it means we can't use the new keypair until - * we receive confirmation via the first data packet, so we get rid of - * the existing previous one, the possibly existing next one, and slide - * in the new next one. + /* If we're the responder, it means we can't use the new keypair + * until we receive confirmation via the first data packet, so + * we get rid of the existing previous one, the possibly + * existing next one, and slide in the new next one. */ rcu_assign_pointer(keypairs->next_keypair, new_keypair); noise_keypair_put(next_keypair, true); @@ -194,19 +228,25 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai spin_unlock_bh(&keypairs->keypair_update_lock); } -bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair) +bool noise_received_with_keypair(struct noise_keypairs *keypairs, + struct noise_keypair *received_keypair) { - bool key_is_new; struct noise_keypair *old_keypair; + bool key_is_new; /* We first check without taking the spinlock. */ - key_is_new = received_keypair == rcu_access_pointer(keypairs->next_keypair); + key_is_new = received_keypair == + rcu_access_pointer(keypairs->next_keypair); if (likely(!key_is_new)) return false; spin_lock_bh(&keypairs->keypair_update_lock); - /* After locking, we double check that things didn't change from beneath us. */ - if (unlikely(received_keypair != rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)))) { + /* After locking, we double check that things didn't change from + * beneath us. + */ + if (unlikely(received_keypair != + rcu_dereference_protected(keypairs->next_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)))) { spin_unlock_bh(&keypairs->keypair_update_lock); return false; } @@ -215,8 +255,11 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k * into the current, the current into the previous, and get rid of * the old previous. */ - old_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock)); - rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock))); + old_keypair = rcu_dereference_protected(keypairs->previous_keypair, + lockdep_is_held(&keypairs->keypair_update_lock)); + rcu_assign_pointer(keypairs->previous_keypair, + rcu_dereference_protected(keypairs->current_keypair, + lockdep_is_held(&keypairs->keypair_update_lock))); noise_keypair_put(old_keypair, true); rcu_assign_pointer(keypairs->current_keypair, received_keypair); RCU_INIT_POINTER(keypairs->next_keypair, NULL); @@ -226,34 +269,46 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k } /* Must hold static_identity->lock */ -void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN]) +void noise_set_static_identity_private_key( + struct noise_static_identity *static_identity, + const u8 private_key[NOISE_PUBLIC_KEY_LEN]) { - memcpy(static_identity->static_private, private_key, NOISE_PUBLIC_KEY_LEN); - static_identity->has_identity = curve25519_generate_public(static_identity->static_public, private_key); + memcpy(static_identity->static_private, private_key, + NOISE_PUBLIC_KEY_LEN); + static_identity->has_identity = curve25519_generate_public( + static_identity->static_public, private_key); } /* This is Hugo Krawczyk's HKDF: * - https://eprint.iacr.org/2010/264.pdf * - https://tools.ietf.org/html/rfc5869 */ -static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, size_t first_len, size_t second_len, size_t third_len, size_t data_len, const u8 chaining_key[NOISE_HASH_LEN]) +static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, + size_t first_len, size_t second_len, size_t third_len, + size_t data_len, const u8 chaining_key[NOISE_HASH_LEN]) { - u8 secret[BLAKE2S_OUTBYTES]; u8 output[BLAKE2S_OUTBYTES + 1]; + u8 secret[BLAKE2S_OUTBYTES]; #ifdef DEBUG - BUG_ON(first_len > BLAKE2S_OUTBYTES || second_len > BLAKE2S_OUTBYTES || third_len > BLAKE2S_OUTBYTES || ((second_len || second_dst || third_len || third_dst) && (!first_len || !first_dst)) || ((third_len || third_dst) && (!second_len || !second_dst))); + BUG_ON(first_len > BLAKE2S_OUTBYTES || second_len > BLAKE2S_OUTBYTES || + third_len > BLAKE2S_OUTBYTES || + ((second_len || second_dst || third_len || third_dst) && + (!first_len || !first_dst)) || + ((third_len || third_dst) && (!second_len || !second_dst))); #endif /* Extract entropy from data into secret */ - blake2s_hmac(secret, data, chaining_key, BLAKE2S_OUTBYTES, data_len, NOISE_HASH_LEN); + blake2s_hmac(secret, data, chaining_key, BLAKE2S_OUTBYTES, data_len, + NOISE_HASH_LEN); if (!first_dst || !first_len) goto out; /* Expand first key: key = secret, data = 0x1 */ output[0] = 1; - blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, 1, BLAKE2S_OUTBYTES); + blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, 1, + BLAKE2S_OUTBYTES); memcpy(first_dst, output, first_len); if (!second_dst || !second_len) @@ -261,7 +316,8 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, si /* Expand second key: key = secret, data = first-key || 0x2 */ output[BLAKE2S_OUTBYTES] = 2; - blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); + blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, + BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); memcpy(second_dst, output, second_len); if (!third_dst || !third_len) @@ -269,7 +325,8 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, si /* Expand third key: key = secret, data = second-key || 0x3 */ output[BLAKE2S_OUTBYTES] = 3; - blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); + blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, + BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); memcpy(third_dst, output, third_len); out: @@ -282,25 +339,34 @@ static void symmetric_key_init(struct noise_symmetric_key *key) { spin_lock_init(&key->counter.receive.lock); atomic64_set(&key->counter.counter, 0); - memset(key->counter.receive.backtrack, 0, sizeof(key->counter.receive.backtrack)); + memset(key->counter.receive.backtrack, 0, + sizeof(key->counter.receive.backtrack)); key->birthdate = ktime_get_boot_fast_ns(); key->is_valid = true; } -static void derive_keys(struct noise_symmetric_key *first_dst, struct noise_symmetric_key *second_dst, const u8 chaining_key[NOISE_HASH_LEN]) +static void derive_keys(struct noise_symmetric_key *first_dst, + struct noise_symmetric_key *second_dst, + const u8 chaining_key[NOISE_HASH_LEN]) { - kdf(first_dst->key, second_dst->key, NULL, NULL, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, chaining_key); + kdf(first_dst->key, second_dst->key, NULL, NULL, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, + chaining_key); symmetric_key_init(first_dst); symmetric_key_init(second_dst); } -static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN], u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 private[NOISE_PUBLIC_KEY_LEN], const u8 public[NOISE_PUBLIC_KEY_LEN]) +static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN], + u8 key[NOISE_SYMMETRIC_KEY_LEN], + const u8 private[NOISE_PUBLIC_KEY_LEN], + const u8 public[NOISE_PUBLIC_KEY_LEN]) { u8 dh_calculation[NOISE_PUBLIC_KEY_LEN]; if (unlikely(!curve25519(dh_calculation, private, public))) return false; - kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); + kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN, + NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN); return true; } @@ -315,42 +381,59 @@ static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len) blake2s_final(&blake, hash, NOISE_HASH_LEN); } -static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 psk[NOISE_SYMMETRIC_KEY_LEN]) +static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], + u8 key[NOISE_SYMMETRIC_KEY_LEN], + const u8 psk[NOISE_SYMMETRIC_KEY_LEN]) { u8 temp_hash[NOISE_HASH_LEN]; - kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key); + kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN, + NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key); mix_hash(hash, temp_hash, NOISE_HASH_LEN); memzero_explicit(temp_hash, NOISE_HASH_LEN); } -static void handshake_init(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], const u8 remote_static[NOISE_PUBLIC_KEY_LEN]) +static void handshake_init(u8 chaining_key[NOISE_HASH_LEN], + u8 hash[NOISE_HASH_LEN], + const u8 remote_static[NOISE_PUBLIC_KEY_LEN]) { memcpy(hash, handshake_init_hash, NOISE_HASH_LEN); memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN); mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN); } -static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN]) +static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext, + size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], + u8 hash[NOISE_HASH_LEN]) { - chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash, NOISE_HASH_LEN, 0 /* Always zero for Noise_IK */, key); + chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash, + NOISE_HASH_LEN, + 0 /* Always zero for Noise_IK */, key); mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len)); } -static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN]) +static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext, + size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], + u8 hash[NOISE_HASH_LEN]) { - if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len, hash, NOISE_HASH_LEN, 0 /* Always zero for Noise_IK */, key)) + if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len, + hash, NOISE_HASH_LEN, + 0 /* Always zero for Noise_IK */, key)) return false; mix_hash(hash, src_ciphertext, src_len); return true; } -static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN], const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN], u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN]) +static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN], + const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN], + u8 chaining_key[NOISE_HASH_LEN], + u8 hash[NOISE_HASH_LEN]) { if (ephemeral_dst != ephemeral_src) memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN); mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN); - kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); + kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0, + NOISE_PUBLIC_KEY_LEN, chaining_key); } static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN]) @@ -363,14 +446,15 @@ static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN]) *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec); } -bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake) +bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, + struct noise_handshake *handshake) { u8 timestamp[NOISE_TIMESTAMP_LEN]; u8 key[NOISE_SYMMETRIC_KEY_LEN]; bool ret = false; - /* We need to wait for crng _before_ taking any locks, since curve25519_generate_secret - * uses get_random_bytes_wait. + /* We need to wait for crng _before_ taking any locks, since + * curve25519_generate_secret uses get_random_bytes_wait. */ wait_for_random_bytes(); @@ -382,29 +466,42 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION); - handshake_init(handshake->chaining_key, handshake->hash, handshake->remote_static); + handshake_init(handshake->chaining_key, handshake->hash, + handshake->remote_static); /* e */ curve25519_generate_secret(handshake->ephemeral_private); - if (!curve25519_generate_public(dst->unencrypted_ephemeral, handshake->ephemeral_private)) + if (!curve25519_generate_public(dst->unencrypted_ephemeral, + handshake->ephemeral_private)) goto out; - message_ephemeral(dst->unencrypted_ephemeral, dst->unencrypted_ephemeral, handshake->chaining_key, handshake->hash); + message_ephemeral(dst->unencrypted_ephemeral, + dst->unencrypted_ephemeral, handshake->chaining_key, + handshake->hash); /* es */ - if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private, handshake->remote_static)) + if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private, + handshake->remote_static)) goto out; /* s */ - message_encrypt(dst->encrypted_static, handshake->static_identity->static_public, NOISE_PUBLIC_KEY_LEN, key, handshake->hash); + message_encrypt(dst->encrypted_static, + handshake->static_identity->static_public, + NOISE_PUBLIC_KEY_LEN, key, handshake->hash); /* ss */ - kdf(handshake->chaining_key, key, NULL, handshake->precomputed_static_static, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, handshake->chaining_key); + kdf(handshake->chaining_key, key, NULL, + handshake->precomputed_static_static, NOISE_HASH_LEN, + NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, + handshake->chaining_key); /* {t} */ tai64n_now(timestamp); - message_encrypt(dst->encrypted_timestamp, timestamp, NOISE_TIMESTAMP_LEN, key, handshake->hash); + message_encrypt(dst->encrypted_timestamp, timestamp, + NOISE_TIMESTAMP_LEN, key, handshake->hash); - dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + dst->sender_index = index_hashtable_insert( + &handshake->entry.peer->device->index_hashtable, + &handshake->entry); handshake->state = HANDSHAKE_CREATED_INITIATION; ret = true; @@ -416,17 +513,19 @@ out: return ret; } -struct wireguard_peer *noise_handshake_consume_initiation(struct message_handshake_initiation *src, struct wireguard_device *wg) +struct wireguard_peer * +noise_handshake_consume_initiation(struct message_handshake_initiation *src, + struct wireguard_device *wg) { + struct wireguard_peer *peer = NULL, *ret_peer = NULL; + struct noise_handshake *handshake; bool replay_attack, flood_attack; + u8 key[NOISE_SYMMETRIC_KEY_LEN]; + u8 chaining_key[NOISE_HASH_LEN]; + u8 hash[NOISE_HASH_LEN]; u8 s[NOISE_PUBLIC_KEY_LEN]; u8 e[NOISE_PUBLIC_KEY_LEN]; u8 t[NOISE_TIMESTAMP_LEN]; - struct noise_handshake *handshake; - struct wireguard_peer *peer = NULL, *ret_peer = NULL; - u8 key[NOISE_SYMMETRIC_KEY_LEN]; - u8 hash[NOISE_HASH_LEN]; - u8 chaining_key[NOISE_HASH_LEN]; down_read(&wg->static_identity.lock); if (unlikely(!wg->static_identity.has_identity)) @@ -442,7 +541,8 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha goto out; /* s */ - if (!message_decrypt(s, src->encrypted_static, sizeof(src->encrypted_static), key, hash)) + if (!message_decrypt(s, src->encrypted_static, + sizeof(src->encrypted_static), key, hash)) goto out; /* Lookup which peer we're actually talking to */ @@ -452,15 +552,21 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha handshake = &peer->handshake; /* ss */ - kdf(chaining_key, key, NULL, handshake->precomputed_static_static, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); + kdf(chaining_key, key, NULL, handshake->precomputed_static_static, + NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, + chaining_key); /* {t} */ - if (!message_decrypt(t, src->encrypted_timestamp, sizeof(src->encrypted_timestamp), key, hash)) + if (!message_decrypt(t, src->encrypted_timestamp, + sizeof(src->encrypted_timestamp), key, hash)) goto out; down_read(&handshake->lock); - replay_attack = memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) <= 0; - flood_attack = handshake->last_initiation_consumption + NSEC_PER_SEC / INITIATIONS_PER_SECOND > ktime_get_boot_fast_ns(); + replay_attack = memcmp(t, handshake->latest_timestamp, + NOISE_TIMESTAMP_LEN) <= 0; + flood_attack = handshake->last_initiation_consumption + + NSEC_PER_SEC / INITIATIONS_PER_SECOND > + ktime_get_boot_fast_ns(); up_read(&handshake->lock); if (replay_attack || flood_attack) goto out; @@ -487,13 +593,14 @@ out: return ret_peer; } -bool noise_handshake_create_response(struct message_handshake_response *dst, struct noise_handshake *handshake) +bool noise_handshake_create_response(struct message_handshake_response *dst, + struct noise_handshake *handshake) { bool ret = false; u8 key[NOISE_SYMMETRIC_KEY_LEN]; - /* We need to wait for crng _before_ taking any locks, since curve25519_generate_secret - * uses get_random_bytes_wait. + /* We need to wait for crng _before_ taking any locks, since + * curve25519_generate_secret uses get_random_bytes_wait. */ wait_for_random_bytes(); @@ -508,25 +615,33 @@ bool noise_handshake_create_response(struct message_handshake_response *dst, str /* e */ curve25519_generate_secret(handshake->ephemeral_private); - if (!curve25519_generate_public(dst->unencrypted_ephemeral, handshake->ephemeral_private)) + if (!curve25519_generate_public(dst->unencrypted_ephemeral, + handshake->ephemeral_private)) goto out; - message_ephemeral(dst->unencrypted_ephemeral, dst->unencrypted_ephemeral, handshake->chaining_key, handshake->hash); + message_ephemeral(dst->unencrypted_ephemeral, + dst->unencrypted_ephemeral, handshake->chaining_key, + handshake->hash); /* ee */ - if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, handshake->remote_ephemeral)) + if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, + handshake->remote_ephemeral)) goto out; /* se */ - if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, handshake->remote_static)) + if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, + handshake->remote_static)) goto out; /* psk */ - mix_psk(handshake->chaining_key, handshake->hash, key, handshake->preshared_key); + mix_psk(handshake->chaining_key, handshake->hash, key, + handshake->preshared_key); /* {} */ message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash); - dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry); + dst->sender_index = index_hashtable_insert( + &handshake->entry.peer->device->index_hashtable, + &handshake->entry); handshake->state = HANDSHAKE_CREATED_RESPONSE; ret = true; @@ -538,7 +653,9 @@ out: return ret; } -struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg) +struct wireguard_peer * +noise_handshake_consume_response(struct message_handshake_response *src, + struct wireguard_device *wg) { struct noise_handshake *handshake; struct wireguard_peer *peer = NULL, *ret_peer = NULL; @@ -555,7 +672,9 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake if (unlikely(!wg->static_identity.has_identity)) goto out; - handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index, &peer); + handshake = (struct noise_handshake *)index_hashtable_lookup( + &wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, + src->receiver_index, &peer); if (unlikely(!handshake)) goto out; @@ -563,7 +682,8 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake state = handshake->state; memcpy(hash, handshake->hash, NOISE_HASH_LEN); memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN); - memcpy(ephemeral_private, handshake->ephemeral_private, NOISE_PUBLIC_KEY_LEN); + memcpy(ephemeral_private, handshake->ephemeral_private, + NOISE_PUBLIC_KEY_LEN); up_read(&handshake->lock); if (state != HANDSHAKE_CREATED_INITIATION) @@ -584,12 +704,15 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake mix_psk(chaining_key, hash, key, handshake->preshared_key); /* {} */ - if (!message_decrypt(NULL, src->encrypted_nothing, sizeof(src->encrypted_nothing), key, hash)) + if (!message_decrypt(NULL, src->encrypted_nothing, + sizeof(src->encrypted_nothing), key, hash)) goto fail; /* Success! Copy everything to peer */ down_write(&handshake->lock); - /* It's important to check that the state is still the same, while we have an exclusive lock */ + /* It's important to check that the state is still the same, while we + * have an exclusive lock. + */ if (handshake->state != state) { up_write(&handshake->lock); goto fail; @@ -615,32 +738,43 @@ out: return ret_peer; } -bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs) +bool noise_handshake_begin_session(struct noise_handshake *handshake, + struct noise_keypairs *keypairs) { struct noise_keypair *new_keypair; bool ret = false; down_write(&handshake->lock); - if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE) + if (handshake->state != HANDSHAKE_CREATED_RESPONSE && + handshake->state != HANDSHAKE_CONSUMED_RESPONSE) goto out; new_keypair = keypair_create(handshake->entry.peer); if (!new_keypair) goto out; - new_keypair->i_am_the_initiator = handshake->state == HANDSHAKE_CONSUMED_RESPONSE; + new_keypair->i_am_the_initiator = handshake->state == + HANDSHAKE_CONSUMED_RESPONSE; new_keypair->remote_index = handshake->remote_index; if (new_keypair->i_am_the_initiator) - derive_keys(&new_keypair->sending, &new_keypair->receiving, handshake->chaining_key); + derive_keys(&new_keypair->sending, &new_keypair->receiving, + handshake->chaining_key); else - derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key); + derive_keys(&new_keypair->receiving, &new_keypair->sending, + handshake->chaining_key); handshake_zero(handshake); rcu_read_lock_bh(); - if (likely(!container_of(handshake, struct wireguard_peer, handshake)->is_dead)) { + if (likely(!container_of(handshake, struct wireguard_peer, + handshake)->is_dead)) { add_new_keypair(keypairs, new_keypair); - net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id); - ret = index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry); + net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", + handshake->entry.peer->device->dev->name, + new_keypair->internal_id, + handshake->entry.peer->internal_id); + ret = index_hashtable_replace( + &handshake->entry.peer->device->index_hashtable, + &handshake->entry, &new_keypair->entry); } else kzfree(new_keypair); rcu_read_unlock_bh(); diff --git a/src/noise.h b/src/noise.h index be59587..6a563ce 100644 --- a/src/noise.h +++ b/src/noise.h @@ -86,29 +86,44 @@ struct noise_handshake { u8 latest_timestamp[NOISE_TIMESTAMP_LEN]; __le32 remote_index; - /* Protects all members except the immutable (after noise_handshake_init): remote_static, precomputed_static_static, static_identity */ + /* Protects all members except the immutable (after noise_handshake_ + * init): remote_static, precomputed_static_static, static_identity. */ struct rw_semaphore lock; }; struct wireguard_device; void noise_init(void); -bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer); +bool noise_handshake_init(struct noise_handshake *handshake, + struct noise_static_identity *static_identity, + const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], + const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], + struct wireguard_peer *peer); void noise_handshake_clear(struct noise_handshake *handshake); void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now); struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair); void noise_keypairs_clear(struct noise_keypairs *keypairs); -bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair); +bool noise_received_with_keypair(struct noise_keypairs *keypairs, + struct noise_keypair *received_keypair); -void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN]); +void noise_set_static_identity_private_key( + struct noise_static_identity *static_identity, + const u8 private_key[NOISE_PUBLIC_KEY_LEN]); bool noise_precompute_static_static(struct wireguard_peer *peer); -bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake); -struct wireguard_peer *noise_handshake_consume_initiation(struct message_handshake_initiation *src, struct wireguard_device *wg); +bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, + struct noise_handshake *handshake); +struct wireguard_peer * +noise_handshake_consume_initiation(struct message_handshake_initiation *src, + struct wireguard_device *wg); -bool noise_handshake_create_response(struct message_handshake_response *dst, struct noise_handshake *handshake); -struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg); +bool noise_handshake_create_response(struct message_handshake_response *dst, + struct noise_handshake *handshake); +struct wireguard_peer * +noise_handshake_consume_response(struct message_handshake_response *src, + struct wireguard_device *wg); -bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs); +bool noise_handshake_begin_session(struct noise_handshake *handshake, + struct noise_keypairs *keypairs); #endif /* _WG_NOISE_H */ diff --git a/src/peer.c b/src/peer.c index b9703d5..c079b71 100644 --- a/src/peer.c +++ b/src/peer.c @@ -17,7 +17,10 @@ static atomic64_t peer_counter = ATOMIC64_INIT(0); -struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]) +struct wireguard_peer * +peer_create(struct wireguard_device *wg, + const u8 public_key[NOISE_PUBLIC_KEY_LEN], + const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]) { struct wireguard_peer *peer; @@ -31,11 +34,13 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_ return NULL; peer->device = wg; - if (!noise_handshake_init(&peer->handshake, &wg->static_identity, public_key, preshared_key, peer)) + if (!noise_handshake_init(&peer->handshake, &wg->static_identity, + public_key, preshared_key, peer)) goto err_1; if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL)) goto err_1; - if (packet_queue_init(&peer->tx_queue, packet_tx_worker, false, MAX_QUEUED_PACKETS)) + if (packet_queue_init(&peer->tx_queue, packet_tx_worker, false, + MAX_QUEUED_PACKETS)) goto err_2; if (packet_queue_init(&peer->rx_queue, NULL, false, MAX_QUEUED_PACKETS)) goto err_3; @@ -50,7 +55,9 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_ rwlock_init(&peer->endpoint_lock); kref_init(&peer->refcount); skb_queue_head_init(&peer->staged_packet_queue); - atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns() - (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC); + atomic64_set(&peer->last_sent_handshake, + ktime_get_boot_fast_ns() - + (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC); set_bit(NAPI_STATE_NO_BUSY_POLL, &peer->napi.state); netif_napi_add(wg->dev, &peer->napi, packet_rx_poll, NAPI_POLL_WEIGHT); napi_enable(&peer->napi); @@ -71,15 +78,16 @@ err_1: struct wireguard_peer *peer_get_maybe_zero(struct wireguard_peer *peer) { - RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Taking peer reference without holding the RCU read lock"); + RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), + "Taking peer reference without holding the RCU read lock"); if (unlikely(!peer || !kref_get_unless_zero(&peer->refcount))) return NULL; return peer; } -/* We have a separate "remove" function to get rid of the final reference because - * peer_list, clearing handshakes, and flushing all require mutexes which requires - * sleeping, which must only be done from certain contexts. +/* We have a separate "remove" function to get rid of the final reference + * because peer_list, clearing handshakes, and flushing all require mutexes + * which requires sleeping, which must only be done from certain contexts. */ void peer_remove(struct wireguard_peer *peer) { @@ -87,37 +95,49 @@ void peer_remove(struct wireguard_peer *peer) return; lockdep_assert_held(&peer->device->device_update_lock); - /* Remove from configuration-time lookup structures so new packets can't enter. */ + /* Remove from configuration-time lookup structures so new packets + * can't enter. + */ list_del_init(&peer->peer_list); - allowedips_remove_by_peer(&peer->device->peer_allowedips, peer, &peer->device->device_update_lock); + allowedips_remove_by_peer(&peer->device->peer_allowedips, peer, + &peer->device->device_update_lock); pubkey_hashtable_remove(&peer->device->peer_hashtable, peer); /* Mark as dead, so that we don't allow jumping contexts after. */ WRITE_ONCE(peer->is_dead, true); synchronize_rcu_bh(); - /* Now that no more keypairs can be created for this peer, we destroy existing ones. */ + /* Now that no more keypairs can be created for this peer, we destroy + * existing ones. + */ noise_keypairs_clear(&peer->keypairs); - /* Destroy all ongoing timers that were in-flight at the beginning of this function. */ + /* Destroy all ongoing timers that were in-flight at the beginning of + * this function. + */ timers_stop(peer); - /* The transition between packet encryption/decryption queues isn't guarded - * by is_dead, but each reference's life is strictly bounded by two - * generations: once for parallel crypto and once for serial ingestion, - * so we can simply flush twice, and be sure that we no longer have references - * inside these queues. - * - * a) For encrypt/decrypt. */ + /* The transition between packet encryption/decryption queues isn't + * guarded by is_dead, but each reference's life is strictly bounded by + * two generations: once for parallel crypto and once for serial + * ingestion, so we can simply flush twice, and be sure that we no + * longer have references inside these queues. + */ + + /* a) For encrypt/decrypt. */ flush_workqueue(peer->device->packet_crypt_wq); /* b.1) For send (but not receive, since that's napi). */ flush_workqueue(peer->device->packet_crypt_wq); /* b.2.1) For receive (but not send, since that's wq). */ napi_disable(&peer->napi); - /* b.2.1) It's now safe to remove the napi struct, which must be done here from process context. */ + /* b.2.1) It's now safe to remove the napi struct, which must be done + * here from process context. + */ netif_napi_del(&peer->napi); - /* Ensure any workstructs we own (like transmit_handshake_work or clear_peer_work) no longer are in use. */ + /* Ensure any workstructs we own (like transmit_handshake_work or + * clear_peer_work) no longer are in use. + */ flush_workqueue(peer->device->handshake_send_wq); --peer->device->num_peers; @@ -126,7 +146,8 @@ void peer_remove(struct wireguard_peer *peer) static void rcu_release(struct rcu_head *rcu) { - struct wireguard_peer *peer = container_of(rcu, struct wireguard_peer, rcu); + struct wireguard_peer *peer = + container_of(rcu, struct wireguard_peer, rcu); dst_cache_destroy(&peer->endpoint_cache); packet_queue_free(&peer->rx_queue, false); packet_queue_free(&peer->tx_queue, false); @@ -135,11 +156,19 @@ static void rcu_release(struct rcu_head *rcu) static void kref_release(struct kref *refcount) { - struct wireguard_peer *peer = container_of(refcount, struct wireguard_peer, refcount); - pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); - /* Remove ourself from dynamic runtime lookup structures, now that the last reference is gone. */ - index_hashtable_remove(&peer->device->index_hashtable, &peer->handshake.entry); - /* Remove any lingering packets that didn't have a chance to be transmitted. */ + struct wireguard_peer *peer = + container_of(refcount, struct wireguard_peer, refcount); + pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr); + /* Remove ourself from dynamic runtime lookup structures, now that the + * last reference is gone. + */ + index_hashtable_remove(&peer->device->index_hashtable, + &peer->handshake.entry); + /* Remove any lingering packets that didn't have a chance to be + * transmitted. + */ skb_queue_purge(&peer->staged_packet_queue); /* Free the memory used. */ call_rcu_bh(&peer->rcu, rcu_release); @@ -157,6 +186,6 @@ void peer_remove_all(struct wireguard_device *wg) struct wireguard_peer *peer, *temp; lockdep_assert_held(&wg->device_update_lock); - list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) + list_for_each_entry_safe (peer, temp, &wg->peer_list, peer_list) peer_remove(peer); } diff --git a/src/peer.h b/src/peer.h index 29d2e00..5613ccc 100644 --- a/src/peer.h +++ b/src/peer.h @@ -27,7 +27,8 @@ struct endpoint { union { struct { struct in_addr src4; - int src_if4; /* Essentially the same as addr6->scope_id */ + /* Essentially the same as addr6->scope_id */ + int src_if4; }; struct in6_addr src6; }; @@ -48,10 +49,13 @@ struct wireguard_peer { struct cookie latest_cookie; struct hlist_node pubkey_hash; u64 rx_bytes, tx_bytes; - struct timer_list timer_retransmit_handshake, timer_send_keepalive, timer_new_handshake, timer_zero_key_material, timer_persistent_keepalive; + struct timer_list timer_retransmit_handshake, timer_send_keepalive; + struct timer_list timer_new_handshake, timer_zero_key_material; + struct timer_list timer_persistent_keepalive; unsigned int timer_handshake_attempts; u16 persistent_keepalive_interval; - bool timers_enabled, timer_need_another_keepalive, sent_lastminute_handshake; + bool timers_enabled, timer_need_another_keepalive; + bool sent_lastminute_handshake; struct timespec walltime_last_handshake; struct kref refcount; struct rcu_head rcu; @@ -61,9 +65,13 @@ struct wireguard_peer { bool is_dead; }; -struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]); +struct wireguard_peer * +peer_create(struct wireguard_device *wg, + const u8 public_key[NOISE_PUBLIC_KEY_LEN], + const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]); -struct wireguard_peer * __must_check peer_get_maybe_zero(struct wireguard_peer *peer); +struct wireguard_peer *__must_check +peer_get_maybe_zero(struct wireguard_peer *peer); static inline struct wireguard_peer *peer_get(struct wireguard_peer *peer) { kref_get(&peer->refcount); @@ -73,6 +81,7 @@ void peer_put(struct wireguard_peer *peer); void peer_remove(struct wireguard_peer *peer); void peer_remove_all(struct wireguard_device *wg); -struct wireguard_peer *peer_lookup_by_index(struct wireguard_device *wg, u32 index); +struct wireguard_peer *peer_lookup_by_index(struct wireguard_device *wg, + u32 index); #endif /* _WG_PEER_H */ diff --git a/src/queueing.c b/src/queueing.c index c8394fc..9ec6588 100644 --- a/src/queueing.c +++ b/src/queueing.c @@ -5,22 +5,25 @@ #include "queueing.h" -struct multicore_worker __percpu *packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr) +struct multicore_worker __percpu * +packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr) { int cpu; - struct multicore_worker __percpu *worker = alloc_percpu(struct multicore_worker); + struct multicore_worker __percpu *worker = + alloc_percpu(struct multicore_worker); if (!worker) return NULL; - for_each_possible_cpu(cpu) { + for_each_possible_cpu (cpu) { per_cpu_ptr(worker, cpu)->ptr = ptr; INIT_WORK(&per_cpu_ptr(worker, cpu)->work, function); } return worker; } -int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool multicore, unsigned int len) +int packet_queue_init(struct crypt_queue *queue, work_func_t function, + bool multicore, unsigned int len) { int ret; @@ -30,7 +33,8 @@ int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool mult return ret; if (function) { if (multicore) { - queue->worker = packet_alloc_percpu_multicore_worker(function, queue); + queue->worker = packet_alloc_percpu_multicore_worker( + function, queue); if (!queue->worker) return -ENOMEM; } else diff --git a/src/queueing.h b/src/queueing.h index 6a1de33..66b7134 100644 --- a/src/queueing.h +++ b/src/queueing.h @@ -19,9 +19,11 @@ struct crypt_queue; struct sk_buff; /* queueing.c APIs: */ -int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool multicore, unsigned int len); +int packet_queue_init(struct crypt_queue *queue, work_func_t function, + bool multicore, unsigned int len); void packet_queue_free(struct crypt_queue *queue, bool multicore); -struct multicore_worker __percpu *packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr); +struct multicore_worker __percpu * +packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr); /* receive.c APIs: */ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb); @@ -32,9 +34,12 @@ int packet_rx_poll(struct napi_struct *napi, int budget); void packet_decrypt_worker(struct work_struct *work); /* send.c APIs: */ -void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool is_retry); +void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, + bool is_retry); void packet_send_handshake_response(struct wireguard_peer *peer); -void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index); +void packet_send_handshake_cookie(struct wireguard_device *wg, + struct sk_buff *initiating_skb, + __le32 sender_index); void packet_send_keepalive(struct wireguard_peer *peer); void packet_send_staged_packets(struct wireguard_peer *peer); /* Workqueue workers: */ @@ -42,7 +47,12 @@ void packet_handshake_send_worker(struct work_struct *work); void packet_tx_worker(struct work_struct *work); void packet_encrypt_worker(struct work_struct *work); -enum packet_state { PACKET_STATE_UNCRYPTED, PACKET_STATE_CRYPTED, PACKET_STATE_DEAD }; +enum packet_state { + PACKET_STATE_UNCRYPTED, + PACKET_STATE_CRYPTED, + PACKET_STATE_DEAD +}; + struct packet_cb { u64 nonce; struct noise_keypair *keypair; @@ -50,15 +60,22 @@ struct packet_cb { u32 mtu; u8 ds; }; + #define PACKET_PEER(skb) (((struct packet_cb *)skb->cb)->keypair->entry.peer) #define PACKET_CB(skb) ((struct packet_cb *)skb->cb) /* Returns either the correct skb->protocol value, or 0 if invalid. */ static inline __be16 skb_examine_untrusted_ip_hdr(struct sk_buff *skb) { - if (skb_network_header(skb) >= skb->head && (skb_network_header(skb) + sizeof(struct iphdr)) <= skb_tail_pointer(skb) && ip_hdr(skb)->version == 4) + if (skb_network_header(skb) >= skb->head && + (skb_network_header(skb) + sizeof(struct iphdr)) <= + skb_tail_pointer(skb) && + ip_hdr(skb)->version == 4) return htons(ETH_P_IP); - if (skb_network_header(skb) >= skb->head && (skb_network_header(skb) + sizeof(struct ipv6hdr)) <= skb_tail_pointer(skb) && ipv6_hdr(skb)->version == 6) + if (skb_network_header(skb) >= skb->head && + (skb_network_header(skb) + sizeof(struct ipv6hdr)) <= + skb_tail_pointer(skb) && + ipv6_hdr(skb)->version == 6) return htons(ETH_P_IPV6); return 0; } @@ -67,7 +84,9 @@ static inline void skb_reset(struct sk_buff *skb) { const int pfmemalloc = skb->pfmemalloc; skb_scrub_packet(skb, true); - memset(&skb->headers_start, 0, offsetof(struct sk_buff, headers_end) - offsetof(struct sk_buff, headers_start)); + memset(&skb->headers_start, 0, + offsetof(struct sk_buff, headers_end) - + offsetof(struct sk_buff, headers_start)); skb->pfmemalloc = pfmemalloc; skb->queue_mapping = 0; skb->nohdr = 0; @@ -89,7 +108,8 @@ static inline int cpumask_choose_online(int *stored_cpu, unsigned int id) { unsigned int cpu = *stored_cpu, cpu_index, i; - if (unlikely(cpu == nr_cpumask_bits || !cpumask_test_cpu(cpu, cpu_online_mask))) { + if (unlikely(cpu == nr_cpumask_bits || + !cpumask_test_cpu(cpu, cpu_online_mask))) { cpu_index = id % cpumask_weight(cpu_online_mask); cpu = cpumask_first(cpu_online_mask); for (i = 0; i < cpu_index; ++i) @@ -103,8 +123,8 @@ static inline int cpumask_choose_online(int *stored_cpu, unsigned int id) * the same CPU twice. A race-free version of this would be to instead store an * atomic sequence number, do an increment-and-return, and then iterate through * every possible CPU until we get to that index -- choose_cpu. However that's - * a bit slower, and it doesn't seem like this potential race actually introduces - * any performance loss, so we live with it. + * a bit slower, and it doesn't seem like this potential race actually + * introduces any performance loss, so we live with it. */ static inline int cpumask_next_online(int *next) { @@ -116,7 +136,9 @@ static inline int cpumask_next_online(int *next) return cpu; } -static inline int queue_enqueue_per_device_and_peer(struct crypt_queue *device_queue, struct crypt_queue *peer_queue, struct sk_buff *skb, struct workqueue_struct *wq, int *next_cpu) +static inline int queue_enqueue_per_device_and_peer( + struct crypt_queue *device_queue, struct crypt_queue *peer_queue, + struct sk_buff *skb, struct workqueue_struct *wq, int *next_cpu) { int cpu; @@ -136,18 +158,24 @@ static inline int queue_enqueue_per_device_and_peer(struct crypt_queue *device_q return 0; } -static inline void queue_enqueue_per_peer(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state) +static inline void queue_enqueue_per_peer(struct crypt_queue *queue, + struct sk_buff *skb, + enum packet_state state) { /* We take a reference, because as soon as we call atomic_set, the * peer can be freed from below us. */ struct wireguard_peer *peer = peer_get(PACKET_PEER(skb)); atomic_set_release(&PACKET_CB(skb)->state, state); - queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &queue->work); + queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, + peer->internal_id), + peer->device->packet_crypt_wq, &queue->work); peer_put(peer); } -static inline void queue_enqueue_per_peer_napi(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state) +static inline void queue_enqueue_per_peer_napi(struct crypt_queue *queue, + struct sk_buff *skb, + enum packet_state state) { /* We take a reference, because as soon as we call atomic_set, the * peer can be freed from below us. diff --git a/src/ratelimiter.c b/src/ratelimiter.c index 638da2b..3966ce8 100644 --- a/src/ratelimiter.c +++ b/src/ratelimiter.c @@ -41,7 +41,8 @@ enum { static void entry_free(struct rcu_head *rcu) { - kmem_cache_free(entry_cache, container_of(rcu, struct ratelimiter_entry, rcu)); + kmem_cache_free(entry_cache, + container_of(rcu, struct ratelimiter_entry, rcu)); atomic_dec(&total_entries); } @@ -54,20 +55,22 @@ static void entry_uninit(struct ratelimiter_entry *entry) /* Calling this function with a NULL work uninits all entries. */ static void gc_entries(struct work_struct *work) { - unsigned int i; + const u64 now = ktime_get_boot_fast_ns(); struct ratelimiter_entry *entry; struct hlist_node *temp; - const u64 now = ktime_get_boot_fast_ns(); + unsigned int i; for (i = 0; i < table_size; ++i) { spin_lock(&table_lock); - hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) { - if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC) + hlist_for_each_entry_safe (entry, temp, &table_v4[i], hash) { + if (unlikely(!work) || + now - entry->last_time_ns > NSEC_PER_SEC) entry_uninit(entry); } #if IS_ENABLED(CONFIG_IPV6) - hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) { - if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC) + hlist_for_each_entry_safe (entry, temp, &table_v6[i], hash) { + if (unlikely(!work) || + now - entry->last_time_ns > NSEC_PER_SEC) entry_uninit(entry); } #endif @@ -81,34 +84,41 @@ static void gc_entries(struct work_struct *work) bool ratelimiter_allow(struct sk_buff *skb, struct net *net) { + struct { __be64 ip; u32 net; } data = + { .net = (unsigned long)net & 0xffffffff }; struct ratelimiter_entry *entry; struct hlist_head *bucket; - struct { __be64 ip; u32 net; } data = { .net = (unsigned long)net & 0xffffffff }; if (skb->protocol == htons(ETH_P_IP)) { data.ip = (__force __be64)ip_hdr(skb)->saddr; - bucket = &table_v4[hsiphash(&data, sizeof(u32) * 3, &key) & (table_size - 1)]; + bucket = &table_v4[hsiphash(&data, sizeof(u32) * 3, &key) & + (table_size - 1)]; } #if IS_ENABLED(CONFIG_IPV6) else if (skb->protocol == htons(ETH_P_IPV6)) { - memcpy(&data.ip, &ipv6_hdr(skb)->saddr, sizeof(__be64)); /* Only 64 bits */ - bucket = &table_v6[hsiphash(&data, sizeof(u32) * 3, &key) & (table_size - 1)]; + memcpy(&data.ip, &ipv6_hdr(skb)->saddr, + sizeof(__be64)); /* Only 64 bits */ + bucket = &table_v6[hsiphash(&data, sizeof(u32) * 3, &key) & + (table_size - 1)]; } #endif else return false; rcu_read_lock(); - hlist_for_each_entry_rcu(entry, bucket, hash) { + hlist_for_each_entry_rcu (entry, bucket, hash) { if (entry->net == net && entry->ip == data.ip) { u64 now, tokens; bool ret; - /* Inspired by nft_limit.c, but this is actually a slightly different - * algorithm. Namely, we incorporate the burst as part of the maximum - * tokens, rather than as part of the rate. + /* Quasi-inspired by nft_limit.c, but this is actually a + * slightly different algorithm. Namely, we incorporate + * the burst as part of the maximum tokens, rather than + * as part of the rate. */ spin_lock(&entry->lock); now = ktime_get_boot_fast_ns(); - tokens = min_t(u64, TOKEN_MAX, entry->tokens + now - entry->last_time_ns); + tokens = min_t(u64, TOKEN_MAX, + entry->tokens + now - + entry->last_time_ns); entry->last_time_ns = now; ret = tokens >= PACKET_COST; entry->tokens = ret ? tokens - PACKET_COST : tokens; @@ -157,7 +167,10 @@ int ratelimiter_init(void) * we borrow their wisdom about good table sizes on different systems * dependent on RAM. This calculation here comes from there. */ - table_size = (totalram_pages > (1U << 30) / PAGE_SIZE) ? 8192 : max_t(unsigned long, 16, roundup_pow_of_two((totalram_pages << PAGE_SHIFT) / (1U << 14) / sizeof(struct hlist_head))); + table_size = (totalram_pages > (1U << 30) / PAGE_SIZE) ? 8192 : + max_t(unsigned long, 16, roundup_pow_of_two( + (totalram_pages << PAGE_SHIFT) / + (1U << 14) / sizeof(struct hlist_head))); max_entries = table_size * 8; table_v4 = kvzalloc(table_size * sizeof(struct hlist_head), GFP_KERNEL); diff --git a/src/receive.c b/src/receive.c index 4e73da1..a2d1d9d 100644 --- a/src/receive.c +++ b/src/receive.c @@ -20,7 +20,8 @@ /* Must be called with bh disabled. */ static inline void rx_stats(struct wireguard_peer *peer, size_t len) { - struct pcpu_sw_netstats *tstats = get_cpu_ptr(peer->device->dev->tstats); + struct pcpu_sw_netstats *tstats = + get_cpu_ptr(peer->device->dev->tstats); u64_stats_update_begin(&tstats->syncp); ++tstats->rx_packets; @@ -36,38 +37,57 @@ static inline size_t validate_header_len(struct sk_buff *skb) { if (unlikely(skb->len < sizeof(struct message_header))) return 0; - if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) && skb->len >= MESSAGE_MINIMUM_LENGTH) + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) && + skb->len >= MESSAGE_MINIMUM_LENGTH) return sizeof(struct message_data); - if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && skb->len == sizeof(struct message_handshake_initiation)) + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && + skb->len == sizeof(struct message_handshake_initiation)) return sizeof(struct message_handshake_initiation); - if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && skb->len == sizeof(struct message_handshake_response)) + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && + skb->len == sizeof(struct message_handshake_response)) return sizeof(struct message_handshake_response); - if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && skb->len == sizeof(struct message_handshake_cookie)) + if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && + skb->len == sizeof(struct message_handshake_cookie)) return sizeof(struct message_handshake_cookie); return 0; } -static inline int skb_prepare_header(struct sk_buff *skb, struct wireguard_device *wg) +static inline int skb_prepare_header(struct sk_buff *skb, + struct wireguard_device *wg) { - struct udphdr *udp; size_t data_offset, data_len, header_len; + struct udphdr *udp; - if (unlikely(skb_examine_untrusted_ip_hdr(skb) != skb->protocol || skb_transport_header(skb) < skb->head || (skb_transport_header(skb) + sizeof(struct udphdr)) > skb_tail_pointer(skb))) + if (unlikely(skb_examine_untrusted_ip_hdr(skb) != skb->protocol || + skb_transport_header(skb) < skb->head || + (skb_transport_header(skb) + sizeof(struct udphdr)) > + skb_tail_pointer(skb))) return -EINVAL; /* Bogus IP header */ udp = udp_hdr(skb); data_offset = (u8 *)udp - skb->data; - if (unlikely(data_offset > U16_MAX || data_offset + sizeof(struct udphdr) > skb->len)) - return -EINVAL; /* Packet has offset at impossible location or isn't big enough to have UDP fields */ + if (unlikely(data_offset > U16_MAX || + data_offset + sizeof(struct udphdr) > skb->len)) + /* Packet has offset at impossible location or isn't big enough + * to have UDP fields. + */ + return -EINVAL; data_len = ntohs(udp->len); - if (unlikely(data_len < sizeof(struct udphdr) || data_len > skb->len - data_offset)) - return -EINVAL; /* UDP packet is reporting too small of a size or lying about its size */ + if (unlikely(data_len < sizeof(struct udphdr) || + data_len > skb->len - data_offset)) + /* UDP packet is reporting too small of a size or lying about + * its size. + */ + return -EINVAL; data_len -= sizeof(struct udphdr); data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data; - if (unlikely(!pskb_may_pull(skb, data_offset + sizeof(struct message_header)) || pskb_trim(skb, data_len + data_offset) < 0)) + if (unlikely(!pskb_may_pull(skb, + data_offset + sizeof(struct message_header)) || + pskb_trim(skb, data_len + data_offset) < 0)) return -EINVAL; skb_pull(skb, data_offset); if (unlikely(skb->len != data_len)) - return -EINVAL; /* Final len does not agree with calculated len */ + /* Final len does not agree with calculated len */ + return -EINVAL; header_len = validate_header_len(skb); if (unlikely(!header_len)) return -EINVAL; @@ -78,74 +98,96 @@ static inline int skb_prepare_header(struct sk_buff *skb, struct wireguard_devic return 0; } -static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff *skb) +static void receive_handshake_packet(struct wireguard_device *wg, + struct sk_buff *skb) { - static u64 last_under_load; /* Yes this is global, so that our load calculation applies to the whole system. */ struct wireguard_peer *peer = NULL; - bool under_load; enum cookie_mac_state mac_state; + /* This is global, so that our load calculation applies to + * the whole system. + */ + static u64 last_under_load; bool packet_needs_cookie; + bool under_load; if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE)) { - net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n", wg->dev->name, skb); - cookie_message_consume((struct message_handshake_cookie *)skb->data, wg); + net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n", + wg->dev->name, skb); + cookie_message_consume( + (struct message_handshake_cookie *)skb->data, wg); return; } - under_load = skb_queue_len(&wg->incoming_handshakes) >= MAX_QUEUED_INCOMING_HANDSHAKES / 8; + under_load = skb_queue_len(&wg->incoming_handshakes) >= + MAX_QUEUED_INCOMING_HANDSHAKES / 8; if (under_load) last_under_load = ktime_get_boot_fast_ns(); else if (last_under_load) under_load = !has_expired(last_under_load, 1); - mac_state = cookie_validate_packet(&wg->cookie_checker, skb, under_load); - if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) + mac_state = cookie_validate_packet(&wg->cookie_checker, skb, + under_load); + if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) || + (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)) packet_needs_cookie = false; else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE) packet_needs_cookie = true; else { - net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n", wg->dev->name, skb); + net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n", + wg->dev->name, skb); return; } switch (SKB_TYPE_LE32(skb)) { case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): { - struct message_handshake_initiation *message = (struct message_handshake_initiation *)skb->data; + struct message_handshake_initiation *message = + (struct message_handshake_initiation *)skb->data; if (packet_needs_cookie) { - packet_send_handshake_cookie(wg, skb, message->sender_index); + packet_send_handshake_cookie(wg, skb, + message->sender_index); return; } peer = noise_handshake_consume_initiation(message, wg); if (unlikely(!peer)) { - net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n", wg->dev->name, skb); + net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n", + wg->dev->name, skb); return; } socket_set_peer_endpoint_from_skb(peer, skb); - net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n", wg->dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n", + wg->dev->name, peer->internal_id, + &peer->endpoint.addr); packet_send_handshake_response(peer); break; } case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): { - struct message_handshake_response *message = (struct message_handshake_response *)skb->data; + struct message_handshake_response *message = + (struct message_handshake_response *)skb->data; if (packet_needs_cookie) { - packet_send_handshake_cookie(wg, skb, message->sender_index); + packet_send_handshake_cookie(wg, skb, + message->sender_index); return; } peer = noise_handshake_consume_response(message, wg); if (unlikely(!peer)) { - net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n", wg->dev->name, skb); + net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n", + wg->dev->name, skb); return; } socket_set_peer_endpoint_from_skb(peer, skb); - net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n", wg->dev->name, peer->internal_id, &peer->endpoint.addr); - if (noise_handshake_begin_session(&peer->handshake, &peer->keypairs)) { + net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n", + wg->dev->name, peer->internal_id, + &peer->endpoint.addr); + if (noise_handshake_begin_session(&peer->handshake, + &peer->keypairs)) { timers_session_derived(peer); timers_handshake_complete(peer); - /* Calling this function will either send any existing packets in the queue - * and not send a keepalive, which is the best case, Or, if there's nothing - * in the queue, it will send a keepalive, in order to give immediate - * confirmation of the session. + /* Calling this function will either send any existing + * packets in the queue and not send a keepalive, which + * is the best case, Or, if there's nothing in the + * queue, it will send a keepalive, in order to give + * immediate confirmation of the session. */ packet_send_keepalive(peer); } @@ -169,7 +211,8 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff void packet_handshake_receive_worker(struct work_struct *work) { - struct wireguard_device *wg = container_of(work, struct multicore_worker, work)->ptr; + struct wireguard_device *wg = + container_of(work, struct multicore_worker, work)->ptr; struct sk_buff *skb; while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) { @@ -189,8 +232,10 @@ static inline void keep_key_fresh(struct wireguard_peer *peer) rcu_read_lock_bh(); keypair = rcu_dereference_bh(peer->keypairs.current_keypair); - if (likely(keypair && keypair->sending.is_valid) && keypair->i_am_the_initiator && - unlikely(has_expired(keypair->sending.birthdate, REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT))) + if (likely(keypair && keypair->sending.is_valid) && + keypair->i_am_the_initiator && + unlikely(has_expired(keypair->sending.birthdate, + REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT))) send = true; rcu_read_unlock_bh(); @@ -200,7 +245,9 @@ static inline void keep_key_fresh(struct wireguard_peer *peer) } } -static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *key, simd_context_t simd_context) +static inline bool skb_decrypt(struct sk_buff *skb, + struct noise_symmetric_key *key, + simd_context_t simd_context) { struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1]; struct sk_buff *trailer; @@ -210,12 +257,15 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key * if (unlikely(!key)) return false; - if (unlikely(!key->is_valid || has_expired(key->birthdate, REJECT_AFTER_TIME) || key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) { + if (unlikely(!key->is_valid || + has_expired(key->birthdate, REJECT_AFTER_TIME) || + key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) { key->is_valid = false; return false; } - PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter); + PACKET_CB(skb)->nonce = + le64_to_cpu(((struct message_data *)skb->data)->counter); /* We ensure that the network header is part of the packet before we * call skb_cow_data, so that there's no chance that data is removed @@ -233,7 +283,9 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key * if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0) return false; - if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key, simd_context)) + if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, + PACKET_CB(skb)->nonce, key->key, + simd_context)) return false; /* Another ugly situation of pushing and pulling the header so as to @@ -248,33 +300,39 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key * } /* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */ -static inline bool counter_validate(union noise_counter *counter, u64 their_counter) +static inline bool counter_validate(union noise_counter *counter, + u64 their_counter) { - bool ret = false; unsigned long index, index_current, top, i; + bool ret = false; spin_lock_bh(&counter->receive.lock); - if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || their_counter >= REJECT_AFTER_MESSAGES)) + if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || + their_counter >= REJECT_AFTER_MESSAGES)) goto out; ++their_counter; - if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < counter->receive.counter)) + if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < + counter->receive.counter)) goto out; index = their_counter >> ilog2(BITS_PER_LONG); if (likely(their_counter > counter->receive.counter)) { index_current = counter->receive.counter >> ilog2(BITS_PER_LONG); - top = min_t(unsigned long, index - index_current, COUNTER_BITS_TOTAL / BITS_PER_LONG); + top = min_t(unsigned long, index - index_current, + COUNTER_BITS_TOTAL / BITS_PER_LONG); for (i = 1; i <= top; ++i) - counter->receive.backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0; + counter->receive.backtrack[(i + index_current) & + ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0; counter->receive.counter = their_counter; } index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1; - ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), &counter->receive.backtrack[index]); + ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), + &counter->receive.backtrack[index]); out: spin_unlock_bh(&counter->receive.lock); @@ -282,15 +340,18 @@ out: } #include "selftest/counter.h" -static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff *skb, struct endpoint *endpoint) +static void packet_consume_data_done(struct wireguard_peer *peer, + struct sk_buff *skb, + struct endpoint *endpoint) { - struct wireguard_peer *routed_peer; struct net_device *dev = peer->device->dev; + struct wireguard_peer *routed_peer; unsigned int len, len_before_trim; socket_set_peer_endpoint(peer, endpoint); - if (unlikely(noise_received_with_keypair(&peer->keypairs, PACKET_CB(skb)->keypair))) { + if (unlikely(noise_received_with_keypair(&peer->keypairs, + PACKET_CB(skb)->keypair))) { timers_handshake_complete(peer); packet_send_staged_packets(peer); } @@ -303,7 +364,9 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff /* A packet with length 0 is a keepalive packet */ if (unlikely(!skb->len)) { rx_stats(peer, message_data_len(0)); - net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n", + dev->name, peer->internal_id, + &peer->endpoint.addr); goto packet_processed; } @@ -311,7 +374,10 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff if (unlikely(skb_network_header(skb) < skb->head)) goto dishonest_packet_size; - if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) && (ip_hdr(skb)->version == 4 || (ip_hdr(skb)->version == 6 && pskb_network_may_pull(skb, sizeof(struct ipv6hdr))))))) + if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) && + (ip_hdr(skb)->version == 4 || + (ip_hdr(skb)->version == 6 && + pskb_network_may_pull(skb, sizeof(struct ipv6hdr))))))) goto dishonest_packet_type; skb->dev = dev; @@ -324,7 +390,8 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff if (INET_ECN_is_ce(PACKET_CB(skb)->ds)) IP_ECN_set_ce(ip_hdr(skb)); } else if (skb->protocol == htons(ETH_P_IPV6)) { - len = ntohs(ipv6_hdr(skb)->payload_len) + sizeof(struct ipv6hdr); + len = ntohs(ipv6_hdr(skb)->payload_len) + + sizeof(struct ipv6hdr); if (INET_ECN_is_ce(PACKET_CB(skb)->ds)) IP6_ECN_set_ce(skb, ipv6_hdr(skb)); } else @@ -344,23 +411,29 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff if (unlikely(napi_gro_receive(&peer->napi, skb) == GRO_DROP)) { ++dev->stats.rx_dropped; - net_dbg_ratelimited("%s: Failed to give packet to userspace from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Failed to give packet to userspace from peer %llu (%pISpfsc)\n", + dev->name, peer->internal_id, + &peer->endpoint.addr); } else rx_stats(peer, message_data_len(len_before_trim)); return; dishonest_packet_peer: - net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n", dev->name, skb, peer->internal_id, &peer->endpoint.addr); + net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n", + dev->name, skb, peer->internal_id, + &peer->endpoint.addr); ++dev->stats.rx_errors; ++dev->stats.rx_frame_errors; goto packet_processed; dishonest_packet_type: - net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n", + dev->name, peer->internal_id, &peer->endpoint.addr); ++dev->stats.rx_errors; ++dev->stats.rx_frame_errors; goto packet_processed; dishonest_packet_size: - net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n", + dev->name, peer->internal_id, &peer->endpoint.addr); ++dev->stats.rx_errors; ++dev->stats.rx_length_errors; goto packet_processed; @@ -370,19 +443,22 @@ packet_processed: int packet_rx_poll(struct napi_struct *napi, int budget) { - struct wireguard_peer *peer = container_of(napi, struct wireguard_peer, napi); + struct wireguard_peer *peer = + container_of(napi, struct wireguard_peer, napi); struct crypt_queue *queue = &peer->rx_queue; struct noise_keypair *keypair; - struct sk_buff *skb; struct endpoint endpoint; enum packet_state state; + struct sk_buff *skb; int work_done = 0; bool free; if (unlikely(budget <= 0)) return 0; - while ((skb = __ptr_ring_peek(&queue->ring)) != NULL && (state = atomic_read_acquire(&PACKET_CB(skb)->state)) != PACKET_STATE_UNCRYPTED) { + while ((skb = __ptr_ring_peek(&queue->ring)) != NULL && + (state = atomic_read_acquire(&PACKET_CB(skb)->state)) != + PACKET_STATE_UNCRYPTED) { __ptr_ring_discard_one(&queue->ring); peer = PACKET_PEER(skb); keypair = PACKET_CB(skb)->keypair; @@ -391,8 +467,12 @@ int packet_rx_poll(struct napi_struct *napi, int budget) if (unlikely(state != PACKET_STATE_CRYPTED)) goto next; - if (unlikely(!counter_validate(&keypair->receiving.counter, PACKET_CB(skb)->nonce))) { - net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", peer->device->dev->name, PACKET_CB(skb)->nonce, keypair->receiving.counter.receive.counter); + if (unlikely(!counter_validate(&keypair->receiving.counter, + PACKET_CB(skb)->nonce))) { + net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", + peer->device->dev->name, + PACKET_CB(skb)->nonce, + keypair->receiving.counter.receive.counter); goto next; } @@ -403,7 +483,7 @@ int packet_rx_poll(struct napi_struct *napi, int budget) packet_consume_data_done(peer, skb, &endpoint); free = false; -next: + next: noise_keypair_put(keypair, false); peer_put(peer); if (unlikely(free)) @@ -421,34 +501,46 @@ next: void packet_decrypt_worker(struct work_struct *work) { - struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr; - struct sk_buff *skb; + struct crypt_queue *queue = + container_of(work, struct multicore_worker, work)->ptr; simd_context_t simd_context = simd_get(); + struct sk_buff *skb; while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) { - enum packet_state state = likely(skb_decrypt(skb, &PACKET_CB(skb)->keypair->receiving, simd_context)) ? PACKET_STATE_CRYPTED : PACKET_STATE_DEAD; - queue_enqueue_per_peer_napi(&PACKET_PEER(skb)->rx_queue, skb, state); + enum packet_state state = likely(skb_decrypt(skb, + &PACKET_CB(skb)->keypair->receiving, + simd_context)) ? + PACKET_STATE_CRYPTED : PACKET_STATE_DEAD; + queue_enqueue_per_peer_napi(&PACKET_PEER(skb)->rx_queue, skb, + state); simd_context = simd_relax(simd_context); } simd_put(simd_context); } -static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb) +static void packet_consume_data(struct wireguard_device *wg, + struct sk_buff *skb) { - struct wireguard_peer *peer = NULL; __le32 idx = ((struct message_data *)skb->data)->key_idx; + struct wireguard_peer *peer = NULL; int ret; rcu_read_lock_bh(); - PACKET_CB(skb)->keypair = (struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx, &peer); + PACKET_CB(skb)->keypair = + (struct noise_keypair *)index_hashtable_lookup( + &wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx, + &peer); if (unlikely(!noise_keypair_get(PACKET_CB(skb)->keypair))) goto err_keypair; if (unlikely(peer->is_dead)) goto err; - ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu); + ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, + &peer->rx_queue, skb, + wg->packet_crypt_wq, + &wg->decrypt_queue.last_cpu); if (unlikely(ret == -EPIPE)) queue_enqueue_per_peer(&peer->rx_queue, skb, PACKET_STATE_DEAD); if (likely(!ret || ret == -EPIPE)) { @@ -473,14 +565,20 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): { int cpu; - if (skb_queue_len(&wg->incoming_handshakes) > MAX_QUEUED_INCOMING_HANDSHAKES || unlikely(!rng_is_initialized())) { - net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n", wg->dev->name, skb); + if (skb_queue_len(&wg->incoming_handshakes) > + MAX_QUEUED_INCOMING_HANDSHAKES || + unlikely(!rng_is_initialized())) { + net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n", + wg->dev->name, skb); goto err; } skb_queue_tail(&wg->incoming_handshakes, skb); - /* Queues up a call to packet_process_queued_handshake_packets(skb): */ + /* Queues up a call to packet_process_queued_handshake_ + * packets(skb): + */ cpu = cpumask_next_online(&wg->incoming_handshake_cpu); - queue_work_on(cpu, wg->handshake_receive_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work); + queue_work_on(cpu, wg->handshake_receive_wq, + &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work); break; } case cpu_to_le32(MESSAGE_DATA): @@ -488,7 +586,8 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) packet_consume_data(wg, skb); break; default: - net_dbg_skb_ratelimited("%s: Invalid packet from %pISpfsc\n", wg->dev->name, skb); + net_dbg_skb_ratelimited("%s: Invalid packet from %pISpfsc\n", + wg->dev->name, skb); goto err; } return; diff --git a/src/selftest/allowedips.h b/src/selftest/allowedips.h index 1cf8c8f..28461d7 100644 --- a/src/selftest/allowedips.h +++ b/src/selftest/allowedips.h @@ -8,7 +8,8 @@ #ifdef DEBUG_PRINT_TRIE_GRAPHVIZ #include -static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, u8 cidr) +static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, + u8 cidr) { swap_endian(dst, src, bits); memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8); @@ -18,34 +19,44 @@ static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, u static __init void print_node(struct allowedips_node *node, u8 bits) { - u32 color = 0; - char *style = "dotted"; char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n"; - char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; + char *fmt_declaration = KERN_DEBUG + "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n"; + char *style = "dotted"; u8 ip1[16], ip2[16]; + u32 color = 0; + if (bits == 32) { fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n"; - fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; + fmt_declaration = KERN_DEBUG + "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n"; } else if (bits == 128) { fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n"; - fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; + fmt_declaration = KERN_DEBUG + "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n"; } if (node->peer) { hsiphash_key_t key = { 0 }; memcpy(&key, &node->peer, sizeof(node->peer)); - color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 | hsiphash_1u32(0xbabecafe, &key) % 200 << 8 | hsiphash_1u32(0xabad1dea, &key) % 200; + color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 | + hsiphash_1u32(0xbabecafe, &key) % 200 << 8 | + hsiphash_1u32(0xabad1dea, &key) % 200; style = "bold"; } swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr); printk(fmt_declaration, ip1, node->cidr, style, color); if (node->bit[0]) { - swap_endian_and_apply_cidr(ip2, node->bit[0]->bits, bits, node->cidr); - printk(fmt_connection, ip1, node->cidr, ip2, node->bit[0]->cidr); + swap_endian_and_apply_cidr(ip2, node->bit[0]->bits, bits, + node->cidr); + printk(fmt_connection, ip1, node->cidr, ip2, + node->bit[0]->cidr); print_node(node->bit[0], bits); } if (node->bit[1]) { - swap_endian_and_apply_cidr(ip2, node->bit[1]->bits, bits, node->cidr); - printk(fmt_connection, ip1, node->cidr, ip2, node->bit[1]->cidr); + swap_endian_and_apply_cidr(ip2, node->bit[1]->bits, bits, + node->cidr); + printk(fmt_connection, ip1, node->cidr, ip2, + node->bit[1]->cidr); print_node(node->bit[1], bits); } } @@ -79,9 +90,10 @@ static __init void horrible_allowedips_init(struct horrible_allowedips *table) } static __init void horrible_allowedips_free(struct horrible_allowedips *table) { - struct hlist_node *h; struct horrible_allowedips_node *node; - hlist_for_each_entry_safe(node, h, &table->head, table) { + struct hlist_node *h; + + hlist_for_each_entry_safe (node, h, &table->head, table) { hlist_del(&node->table); kfree(node); } @@ -89,20 +101,21 @@ static __init void horrible_allowedips_free(struct horrible_allowedips *table) static __init inline union nf_inet_addr horrible_cidr_to_mask(uint8_t cidr) { union nf_inet_addr mask; + memset(&mask, 0x00, 128 / 8); memset(&mask, 0xff, cidr / 8); if (cidr % 32) - mask.all[cidr / 32] = htonl((0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); + mask.all[cidr / 32] = htonl( + (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL); return mask; } static __init inline uint8_t horrible_mask_to_cidr(union nf_inet_addr subnet) { - return hweight32(subnet.all[0]) - + hweight32(subnet.all[1]) - + hweight32(subnet.all[2]) - + hweight32(subnet.all[3]); + return hweight32(subnet.all[0]) + hweight32(subnet.all[1]) + + hweight32(subnet.all[2]) + hweight32(subnet.all[3]); } -static __init inline void horrible_mask_self(struct horrible_allowedips_node *node) +static __init inline void +horrible_mask_self(struct horrible_allowedips_node *node) { if (node->ip_version == 4) node->ip.ip &= node->mask.ip; @@ -113,24 +126,36 @@ static __init inline void horrible_mask_self(struct horrible_allowedips_node *no node->ip.ip6[3] &= node->mask.ip6[3]; } } -static __init inline bool horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip) +static __init inline bool +horrible_match_v4(const struct horrible_allowedips_node *node, + struct in_addr *ip) { return (ip->s_addr & node->mask.ip) == node->ip.ip; } -static __init inline bool horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip) +static __init inline bool +horrible_match_v6(const struct horrible_allowedips_node *node, + struct in6_addr *ip) { - return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] && - (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] && - (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] && - (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3]; + return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == + node->ip.ip6[0] && + (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == + node->ip.ip6[1] && + (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == + node->ip.ip6[2] && + (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3]; } -static __init void horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node) +static __init void +horrible_insert_ordered(struct horrible_allowedips *table, + struct horrible_allowedips_node *node) { struct horrible_allowedips_node *other = NULL, *where = NULL; uint8_t my_cidr = horrible_mask_to_cidr(node->mask); - hlist_for_each_entry(other, &table->head, table) { - if (!memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) && - !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr)) && + + hlist_for_each_entry (other, &table->head, table) { + if (!memcmp(&other->mask, &node->mask, + sizeof(union nf_inet_addr)) && + !memcmp(&other->ip, &node->ip, + sizeof(union nf_inet_addr)) && other->ip_version == node->ip_version) { other->value = node->value; kfree(node); @@ -147,9 +172,13 @@ static __init void horrible_insert_ordered(struct horrible_allowedips *table, st else hlist_add_before(&node->table, &where->table); } -static __init int horrible_allowedips_insert_v4(struct horrible_allowedips *table, struct in_addr *ip, uint8_t cidr, void *value) +static __init int +horrible_allowedips_insert_v4(struct horrible_allowedips *table, + struct in_addr *ip, uint8_t cidr, void *value) { - struct horrible_allowedips_node *node = kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL); + struct horrible_allowedips_node *node = + kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL); + if (!node) return -ENOMEM; node->ip.in = *ip; @@ -160,9 +189,13 @@ static __init int horrible_allowedips_insert_v4(struct horrible_allowedips *tabl horrible_insert_ordered(table, node); return 0; } -static __init int horrible_allowedips_insert_v6(struct horrible_allowedips *table, struct in6_addr *ip, uint8_t cidr, void *value) +static __init int +horrible_allowedips_insert_v6(struct horrible_allowedips *table, + struct in6_addr *ip, uint8_t cidr, void *value) { - struct horrible_allowedips_node *node = kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL); + struct horrible_allowedips_node *node = + kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL); + if (!node) return -ENOMEM; node->ip.in6 = *ip; @@ -173,11 +206,14 @@ static __init int horrible_allowedips_insert_v6(struct horrible_allowedips *tabl horrible_insert_ordered(table, node); return 0; } -static __init void *horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip) +static __init void * +horrible_allowedips_lookup_v4(struct horrible_allowedips *table, + struct in_addr *ip) { struct horrible_allowedips_node *node; void *ret = NULL; - hlist_for_each_entry(node, &table->head, table) { + + hlist_for_each_entry (node, &table->head, table) { if (node->ip_version != 4) continue; if (horrible_match_v4(node, ip)) { @@ -187,11 +223,14 @@ static __init void *horrible_allowedips_lookup_v4(struct horrible_allowedips *ta } return ret; } -static __init void *horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip) +static __init void * +horrible_allowedips_lookup_v6(struct horrible_allowedips *table, + struct in6_addr *ip) { struct horrible_allowedips_node *node; void *ret = NULL; - hlist_for_each_entry(node, &table->head, table) { + + hlist_for_each_entry (node, &table->head, table) { if (node->ip_version != 6) continue; if (horrible_match_v6(node, ip)) { @@ -204,13 +243,13 @@ static __init void *horrible_allowedips_lookup_v6(struct horrible_allowedips *ta static __init bool randomized_test(void) { - DEFINE_MUTEX(mutex); - bool ret = false; unsigned int i, j, k, mutate_amount, cidr; + u8 ip[16], mutate_mask[16], mutated[16]; struct wireguard_peer **peers, *peer; - struct allowedips t; struct horrible_allowedips h; - u8 ip[16], mutate_mask[16], mutated[16]; + DEFINE_MUTEX(mutex); + struct allowedips t; + bool ret = false; mutex_init(&mutex); @@ -237,11 +276,13 @@ static __init bool randomized_test(void) prandom_bytes(ip, 4); cidr = prandom_u32_max(32) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (allowedips_insert_v4(&t, (struct in_addr *)ip, cidr, peer, &mutex) < 0) { + if (allowedips_insert_v4(&t, (struct in_addr *)ip, cidr, peer, + &mutex) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } - if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip, cidr, peer) < 0) { + if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip, + cidr, peer) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } @@ -251,18 +292,23 @@ static __init bool randomized_test(void) mutate_amount = prandom_u32_max(32); for (k = 0; k < mutate_amount / 8; ++k) mutate_mask[k] = 0xff; - mutate_mask[k] = 0xff << ((8 - (mutate_amount % 8)) % 8); + mutate_mask[k] = 0xff + << ((8 - (mutate_amount % 8)) % 8); for (; k < 4; ++k) mutate_mask[k] = 0; for (k = 0; k < 4; ++k) - mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & prandom_u32_max(256)); + mutated[k] = (mutated[k] & mutate_mask[k]) | + (~mutate_mask[k] & + prandom_u32_max(256)); cidr = prandom_u32_max(32) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (allowedips_insert_v4(&t, (struct in_addr *)mutated, cidr, peer, &mutex) < 0) { + if (allowedips_insert_v4(&t, (struct in_addr *)mutated, + cidr, peer, &mutex) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } - if (horrible_allowedips_insert_v4(&h, (struct in_addr *)mutated, cidr, peer)) { + if (horrible_allowedips_insert_v4(&h, + (struct in_addr *)mutated, cidr, peer)) { pr_info("allowedips random self-test: out of memory\n"); goto free; } @@ -273,11 +319,13 @@ static __init bool randomized_test(void) prandom_bytes(ip, 16); cidr = prandom_u32_max(128) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr, peer, &mutex) < 0) { + if (allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr, peer, + &mutex) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } - if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip, cidr, peer) < 0) { + if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip, + cidr, peer) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } @@ -287,18 +335,24 @@ static __init bool randomized_test(void) mutate_amount = prandom_u32_max(128); for (k = 0; k < mutate_amount / 8; ++k) mutate_mask[k] = 0xff; - mutate_mask[k] = 0xff << ((8 - (mutate_amount % 8)) % 8); + mutate_mask[k] = 0xff + << ((8 - (mutate_amount % 8)) % 8); for (; k < 4; ++k) mutate_mask[k] = 0; for (k = 0; k < 4; ++k) - mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & prandom_u32_max(256)); + mutated[k] = (mutated[k] & mutate_mask[k]) | + (~mutate_mask[k] & + prandom_u32_max(256)); cidr = prandom_u32_max(128) + 1; peer = peers[prandom_u32_max(NUM_PEERS)]; - if (allowedips_insert_v6(&t, (struct in6_addr *)mutated, cidr, peer, &mutex) < 0) { + if (allowedips_insert_v6(&t, (struct in6_addr *)mutated, + cidr, peer, &mutex) < 0) { pr_info("allowedips random self-test: out of memory\n"); goto free; } - if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)mutated, cidr, peer)) { + if (horrible_allowedips_insert_v6( + &h, (struct in6_addr *)mutated, cidr, + peer)) { pr_info("allowedips random self-test: out of memory\n"); goto free; } @@ -314,7 +368,8 @@ static __init bool randomized_test(void) for (i = 0; i < NUM_QUERIES; ++i) { prandom_bytes(ip, 4); - if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { + if (lookup(t.root4, 32, ip) != + horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) { pr_info("allowedips random self-test: FAIL\n"); goto free; } @@ -322,7 +377,8 @@ static __init bool randomized_test(void) for (i = 0; i < NUM_QUERIES; ++i) { prandom_bytes(ip, 16); - if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { + if (lookup(t.root6, 128, ip) != + horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) { pr_info("allowedips random self-test: FAIL\n"); goto free; } @@ -376,15 +432,22 @@ static __init int walk_callback(void *ctx, const u8 *ip, u8 cidr, int family) wctx->count++; - if (cidr == 27 && !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr))) + if (cidr == 27 && + !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr))) wctx->found_a = true; - else if (cidr == 128 && !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543), sizeof(struct in6_addr))) + else if (cidr == 128 && + !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543), + sizeof(struct in6_addr))) wctx->found_b = true; - else if (cidr == 29 && !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr))) + else if (cidr == 29 && + !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr))) wctx->found_c = true; - else if (cidr == 83 && !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0), sizeof(struct in6_addr))) + else if (cidr == 83 && + !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0), + sizeof(struct in6_addr))) wctx->found_d = true; - else if (cidr == 21 && !memcmp(ip, ip6(0x26075000, 0, 0, 0), sizeof(struct in6_addr))) + else if (cidr == 21 && + !memcmp(ip, ip6(0x26075000, 0, 0, 0), sizeof(struct in6_addr))) wctx->found_e = true; else wctx->found_other = true; @@ -392,50 +455,55 @@ static __init int walk_callback(void *ctx, const u8 *ip, u8 cidr, int family) return 0; } -#define init_peer(name) do { \ - name = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL); \ - if (!name) { \ - pr_info("allowedips self-test: out of memory\n"); \ - goto free; \ - } \ - kref_init(&name->refcount); \ -} while (0) - -#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \ - allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem, &mutex) - -#define maybe_fail \ - ++i; \ - if (!_s) { \ - pr_info("allowedips self-test %zu: FAIL\n", i); \ - success = false; \ - } - -#define test(version, mem, ipa, ipb, ipc, ipd) do { \ - bool _s = lookup(t.root##version, version == 4 ? 32 : 128, ip##version(ipa, ipb, ipc, ipd)) == mem; \ - maybe_fail \ -} while (0) - -#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \ - bool _s = lookup(t.root##version, version == 4 ? 32 : 128, ip##version(ipa, ipb, ipc, ipd)) != mem; \ - maybe_fail \ -} while (0) - -#define test_boolean(cond) do { \ - bool _s = (cond); \ - maybe_fail \ -} while (0) +#define init_peer(name) do { \ + name = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL); \ + if (!name) { \ + pr_info("allowedips self-test: out of memory\n"); \ + goto free; \ + } \ + kref_init(&name->refcount); \ + } while (0) + +#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \ + allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \ + cidr, mem, &mutex) + +#define maybe_fail() do { \ + ++i; \ + if (!_s) { \ + pr_info("allowedips self-test %zu: FAIL\n", i); \ + success = false; \ + } \ + } while (0) + +#define test(version, mem, ipa, ipb, ipc, ipd) do { \ + bool _s = lookup(t.root##version, version == 4 ? 32 : 128, \ + ip##version(ipa, ipb, ipc, ipd)) == mem; \ + maybe_fail(); \ + } while (0) + +#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \ + bool _s = lookup(t.root##version, version == 4 ? 32 : 128, \ + ip##version(ipa, ipb, ipc, ipd)) != mem; \ + maybe_fail(); \ + } while (0) + +#define test_boolean(cond) do { \ + bool _s = (cond); \ + maybe_fail(); \ + } while (0) bool __init allowedips_selftest(void) { - DEFINE_MUTEX(mutex); - struct allowedips t; - struct walk_ctx wctx = { 0 }; + struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, + *e = NULL, *f = NULL, *g = NULL, *h = NULL; struct allowedips_cursor cursor = { 0 }; - struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, *e = NULL, *f = NULL, *g = NULL, *h = NULL; - size_t i = 0; + struct walk_ctx wctx = { 0 }; bool success = false; + struct allowedips t; + DEFINE_MUTEX(mutex); struct in6_addr ip; + size_t i = 0; __be64 part; mutex_init(&mutex); @@ -455,19 +523,23 @@ bool __init allowedips_selftest(void) insert(4, b, 192, 168, 4, 4, 32); insert(4, c, 192, 168, 0, 0, 16); insert(4, d, 192, 95, 5, 64, 27); - insert(4, c, 192, 95, 5, 65, 27); /* replaces previous entry, and maskself is required */ + /* replaces previous entry, and maskself is required */ + insert(4, c, 192, 95, 5, 65, 27); insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128); insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64); insert(4, e, 0, 0, 0, 0, 0); insert(6, e, 0, 0, 0, 0, 0); - insert(6, f, 0, 0, 0, 0, 0); /* replaces previous entry */ + /* replaces previous entry */ + insert(6, f, 0, 0, 0, 0, 0); insert(6, g, 0x24046800, 0, 0, 0, 32); - insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64); /* maskself is required */ + /* maskself is required */ + insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64); insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128); insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128); insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98); insert(4, g, 64, 15, 112, 0, 20); - insert(4, h, 64, 15, 123, 211, 25); /* maskself is required */ + /* maskself is required */ + insert(4, h, 64, 15, 123, 211, 25); insert(4, a, 10, 0, 0, 0, 25); insert(4, b, 10, 0, 0, 128, 25); insert(4, a, 10, 1, 0, 0, 30); diff --git a/src/selftest/counter.h b/src/selftest/counter.h index 5344075..1c2a3b4 100644 --- a/src/selftest/counter.h +++ b/src/selftest/counter.h @@ -6,13 +6,24 @@ #ifdef DEBUG bool __init packet_counter_selftest(void) { - bool success = true; unsigned int test_num = 0, i; union noise_counter counter; + bool success = true; -#define T_INIT do { memset(&counter, 0, sizeof(union noise_counter)); spin_lock_init(&counter.receive.lock); } while (0) +#define T_INIT do { \ + memset(&counter, 0, sizeof(union noise_counter)); \ + spin_lock_init(&counter.receive.lock); \ + } while (0) #define T_LIM (COUNTER_WINDOW_SIZE + 1) -#define T(n, v) do { ++test_num; if (counter_validate(&counter, n) != v) { pr_info("nonce counter self-test %u: FAIL\n", test_num); success = false; } } while (0) +#define T(n, v) do { \ + ++test_num; \ + if (counter_validate(&counter, n) != v) { \ + pr_info("nonce counter self-test %u: FAIL\n", \ + test_num); \ + success = false; \ + } \ + } while (0) + T_INIT; /* 1 */ T(0, true); /* 2 */ T(1, true); @@ -62,22 +73,22 @@ bool __init packet_counter_selftest(void) T(0, false); T_INIT; - for (i = COUNTER_WINDOW_SIZE + 1; i-- > 0 ;) + for (i = COUNTER_WINDOW_SIZE + 1; i-- > 0;) T(i, true); T_INIT; - for (i = COUNTER_WINDOW_SIZE + 2; i-- > 1 ;) + for (i = COUNTER_WINDOW_SIZE + 2; i-- > 1;) T(i, true); T(0, false); T_INIT; - for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1 ;) + for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1;) T(i, true); T(COUNTER_WINDOW_SIZE + 1, true); T(0, false); T_INIT; - for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1 ;) + for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1;) T(i, true); T(0, true); T(COUNTER_WINDOW_SIZE + 1, true); diff --git a/src/selftest/ratelimiter.h b/src/selftest/ratelimiter.h index c05eac7..a71ddb1 100644 --- a/src/selftest/ratelimiter.h +++ b/src/selftest/ratelimiter.h @@ -7,7 +7,10 @@ #include -static const struct { bool result; unsigned int msec_to_sleep_before; } expected_results[] __initconst = { +static const struct { + bool result; + unsigned int msec_to_sleep_before; +} expected_results[] __initconst = { [0 ... PACKETS_BURSTABLE - 1] = { true, 0 }, [PACKETS_BURSTABLE] = { false, 0 }, [PACKETS_BURSTABLE + 1] = { true, MSEC_PER_SEC / PACKETS_PER_SECOND }, @@ -29,14 +32,14 @@ static __init unsigned int maximum_jiffies_at_index(int index) bool __init ratelimiter_selftest(void) { - struct sk_buff *skb4; - struct iphdr *hdr4; + int i, test = 0, tries = 0, ret = false; + unsigned long loop_start_time; #if IS_ENABLED(CONFIG_IPV6) struct sk_buff *skb6; struct ipv6hdr *hdr6; #endif - int i, test = 0, tries = 0, ret = false; - unsigned long loop_start_time; + struct sk_buff *skb4; + struct iphdr *hdr4; BUILD_BUG_ON(MSEC_PER_SEC % PACKETS_PER_SECOND != 0); @@ -81,21 +84,24 @@ bool __init ratelimiter_selftest(void) restart: loop_start_time = jiffies; for (i = 0; i < ARRAY_SIZE(expected_results); ++i) { -#define ensure_time do {\ - if (time_is_before_jiffies(loop_start_time + maximum_jiffies_at_index(i))) { \ - if (++tries >= 5000) \ - goto err; \ - gc_entries(NULL); \ - rcu_barrier(); \ - msleep(500); \ - goto restart; \ - }} while (0) +#define ensure_time do { \ + if (time_is_before_jiffies(loop_start_time + \ + maximum_jiffies_at_index(i))) { \ + if (++tries >= 5000) \ + goto err; \ + gc_entries(NULL); \ + rcu_barrier(); \ + msleep(500); \ + goto restart; \ + } \ + } while (0) if (expected_results[i].msec_to_sleep_before) msleep(expected_results[i].msec_to_sleep_before); ensure_time; - if (ratelimiter_allow(skb4, &init_net) != expected_results[i].result) + if (ratelimiter_allow(skb4, &init_net) != + expected_results[i].result) goto err; ++test; hdr4->saddr = htonl(ntohl(hdr4->saddr) + i + 1); @@ -106,17 +112,21 @@ restart: hdr4->saddr = htonl(ntohl(hdr4->saddr) - i - 1); #if IS_ENABLED(CONFIG_IPV6) - hdr6->saddr.in6_u.u6_addr32[2] = hdr6->saddr.in6_u.u6_addr32[3] = htonl(i); + hdr6->saddr.in6_u.u6_addr32[2] = + hdr6->saddr.in6_u.u6_addr32[3] = htonl(i); ensure_time; - if (ratelimiter_allow(skb6, &init_net) != expected_results[i].result) + if (ratelimiter_allow(skb6, &init_net) != + expected_results[i].result) goto err; ++test; - hdr6->saddr.in6_u.u6_addr32[0] = htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) + i + 1); + hdr6->saddr.in6_u.u6_addr32[0] = + htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) + i + 1); ensure_time; if (!ratelimiter_allow(skb6, &init_net)) goto err; ++test; - hdr6->saddr.in6_u.u6_addr32[0] = htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) - i - 1); + hdr6->saddr.in6_u.u6_addr32[0] = + htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) - i - 1); ensure_time; #endif } diff --git a/src/send.c b/src/send.c index 3af7ef3..ec0074e 100644 --- a/src/send.c +++ b/src/send.c @@ -23,46 +23,63 @@ static void packet_send_handshake_initiation(struct wireguard_peer *peer) { struct message_handshake_initiation packet; - if (!has_expired(atomic64_read(&peer->last_sent_handshake), REKEY_TIMEOUT)) + if (!has_expired(atomic64_read(&peer->last_sent_handshake), + REKEY_TIMEOUT)) return; /* This function is rate limited. */ atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns()); - net_dbg_ratelimited("%s: Sending handshake initiation to peer %llu (%pISpfsc)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Sending handshake initiation to peer %llu (%pISpfsc)\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr); if (noise_handshake_create_initiation(&packet, &peer->handshake)) { cookie_add_mac_to_packet(&packet, sizeof(packet), peer); timers_any_authenticated_packet_traversal(peer); timers_any_authenticated_packet_sent(peer); - atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns()); - socket_send_buffer_to_peer(peer, &packet, sizeof(struct message_handshake_initiation), HANDSHAKE_DSCP); + atomic64_set(&peer->last_sent_handshake, + ktime_get_boot_fast_ns()); + socket_send_buffer_to_peer( + peer, &packet, + sizeof(struct message_handshake_initiation), + HANDSHAKE_DSCP); timers_handshake_initiated(peer); } } void packet_handshake_send_worker(struct work_struct *work) { - struct wireguard_peer *peer = container_of(work, struct wireguard_peer, transmit_handshake_work); + struct wireguard_peer *peer = container_of(work, struct wireguard_peer, + transmit_handshake_work); packet_send_handshake_initiation(peer); peer_put(peer); } -void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool is_retry) +void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, + bool is_retry) { if (!is_retry) peer->timer_handshake_attempts = 0; rcu_read_lock_bh(); - /* We check last_sent_handshake here in addition to the actual function we're queueing - * up, so that we don't queue things if not strictly necessary. + /* We check last_sent_handshake here in addition to the actual function + * we're queueing up, so that we don't queue things if not strictly + * necessary: */ - if (!has_expired(atomic64_read(&peer->last_sent_handshake), REKEY_TIMEOUT) || unlikely(peer->is_dead)) + if (!has_expired(atomic64_read(&peer->last_sent_handshake), + REKEY_TIMEOUT) || unlikely(peer->is_dead)) goto out; peer_get(peer); - /* Queues up calling packet_send_queued_handshakes(peer), where we do a peer_put(peer) after: */ - if (!queue_work(peer->device->handshake_send_wq, &peer->transmit_handshake_work)) - peer_put(peer); /* If the work was already queued, we want to drop the extra reference */ + /* Queues up calling packet_send_queued_handshakes(peer), where we do a + * peer_put(peer) after: + */ + if (!queue_work(peer->device->handshake_send_wq, + &peer->transmit_handshake_work)) + /* If the work was already queued, we want to drop the + * extra reference: + */ + peer_put(peer); out: rcu_read_unlock_bh(); } @@ -72,27 +89,39 @@ void packet_send_handshake_response(struct wireguard_peer *peer) struct message_handshake_response packet; atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns()); - net_dbg_ratelimited("%s: Sending handshake response to peer %llu (%pISpfsc)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Sending handshake response to peer %llu (%pISpfsc)\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr); if (noise_handshake_create_response(&packet, &peer->handshake)) { cookie_add_mac_to_packet(&packet, sizeof(packet), peer); - if (noise_handshake_begin_session(&peer->handshake, &peer->keypairs)) { + if (noise_handshake_begin_session(&peer->handshake, + &peer->keypairs)) { timers_session_derived(peer); timers_any_authenticated_packet_traversal(peer); timers_any_authenticated_packet_sent(peer); - atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns()); - socket_send_buffer_to_peer(peer, &packet, sizeof(struct message_handshake_response), HANDSHAKE_DSCP); + atomic64_set(&peer->last_sent_handshake, + ktime_get_boot_fast_ns()); + socket_send_buffer_to_peer( + peer, &packet, + sizeof(struct message_handshake_response), + HANDSHAKE_DSCP); } } } -void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index) +void packet_send_handshake_cookie(struct wireguard_device *wg, + struct sk_buff *initiating_skb, + __le32 sender_index) { struct message_handshake_cookie packet; - net_dbg_skb_ratelimited("%s: Sending cookie response for denied handshake message for %pISpfsc\n", wg->dev->name, initiating_skb); - cookie_message_create(&packet, initiating_skb, sender_index, &wg->cookie_checker); - socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet, sizeof(packet)); + net_dbg_skb_ratelimited("%s: Sending cookie response for denied handshake message for %pISpfsc\n", + wg->dev->name, initiating_skb); + cookie_message_create(&packet, initiating_skb, sender_index, + &wg->cookie_checker); + socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet, + sizeof(packet)); } static inline void keep_key_fresh(struct wireguard_peer *peer) @@ -103,8 +132,11 @@ static inline void keep_key_fresh(struct wireguard_peer *peer) rcu_read_lock_bh(); keypair = rcu_dereference_bh(peer->keypairs.current_keypair); if (likely(keypair && keypair->sending.is_valid) && - (unlikely(atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES) || - (keypair->i_am_the_initiator && unlikely(has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME))))) + (unlikely(atomic64_read(&keypair->sending.counter.counter) > + REKEY_AFTER_MESSAGES) || + (keypair->i_am_the_initiator && + unlikely(has_expired(keypair->sending.birthdate, + REKEY_AFTER_TIME))))) send = true; rcu_read_unlock_bh(); @@ -114,9 +146,10 @@ static inline void keep_key_fresh(struct wireguard_peer *peer) static inline unsigned int skb_padding(struct sk_buff *skb) { - /* We do this modulo business with the MTU, just in case the networking layer - * gives us a packet that's bigger than the MTU. In that case, we wouldn't want - * the final subtraction to overflow in the case of the padded_size being clamped. + /* We do this modulo business with the MTU, just in case the networking + * layer gives us a packet that's bigger than the MTU. In that case, we + * wouldn't want the final subtraction to overflow in the case of the + * padded_size being clamped. */ unsigned int last_unit = skb->len % PACKET_CB(skb)->mtu; unsigned int padded_size = ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE); @@ -126,38 +159,49 @@ static inline unsigned int skb_padding(struct sk_buff *skb) return padded_size - last_unit; } -static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, simd_context_t simd_context) +static inline bool skb_encrypt(struct sk_buff *skb, + struct noise_keypair *keypair, + simd_context_t simd_context) { + unsigned int padding_len, plaintext_len, trailer_len; struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1]; struct message_data *header; - unsigned int padding_len, plaintext_len, trailer_len; - int num_frags; struct sk_buff *trailer; + int num_frags; - /* Calculate lengths */ + /* Calculate lengths. */ padding_len = skb_padding(skb); trailer_len = padding_len + noise_encrypted_len(0); plaintext_len = skb->len + padding_len; - /* Expand data section to have room for padding and auth tag */ + /* Expand data section to have room for padding and auth tag. */ num_frags = skb_cow_data(skb, trailer_len, &trailer); if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg))) return false; - /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */ + /* Set the padding to zeros, and make sure it and the auth tag are part + * of the skb. + */ memset(skb_tail_pointer(trailer), 0, padding_len); - /* Expand head section to have room for our header and the network stack's headers. */ + /* Expand head section to have room for our header and the network + * stack's headers. + */ if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0)) return false; - /* We have to remember to add the checksum to the innerpacket, in case the receiver forwards it. */ + /* We have to remember to add the checksum to the innerpacket, in case + * the receiver forwards it. + */ if (likely(!skb_checksum_setup(skb, true))) skb_checksum_help(skb); - /* Only after checksumming can we safely add on the padding at the end and the header. */ + /* Only after checksumming can we safely add on the padding at the end + * and the header. + */ skb_set_inner_network_header(skb, 0); - header = (struct message_data *)skb_push(skb, sizeof(struct message_data)); + header = (struct message_data *)skb_push(skb, + sizeof(struct message_data)); header->header.type = cpu_to_le32(MESSAGE_DATA); header->key_idx = keypair->remote_index; header->counter = cpu_to_le64(PACKET_CB(skb)->nonce); @@ -165,9 +209,12 @@ static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypai /* Now we can encrypt the scattergather segments */ sg_init_table(sg, num_frags); - if (skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(plaintext_len)) <= 0) + if (skb_to_sgvec(skb, sg, sizeof(struct message_data), + noise_encrypted_len(plaintext_len)) <= 0) return false; - return chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0, PACKET_CB(skb)->nonce, keypair->sending.key, simd_context); + return chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0, + PACKET_CB(skb)->nonce, + keypair->sending.key, simd_context); } void packet_send_keepalive(struct wireguard_peer *peer) @@ -175,38 +222,45 @@ void packet_send_keepalive(struct wireguard_peer *peer) struct sk_buff *skb; if (skb_queue_empty(&peer->staged_packet_queue)) { - skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH, GFP_ATOMIC); + skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH, + GFP_ATOMIC); if (unlikely(!skb)) return; skb_reserve(skb, DATA_PACKET_HEAD_ROOM); skb->dev = peer->device->dev; PACKET_CB(skb)->mtu = skb->dev->mtu; skb_queue_tail(&peer->staged_packet_queue, skb); - net_dbg_ratelimited("%s: Sending keepalive packet to peer %llu (%pISpfsc)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); + net_dbg_ratelimited("%s: Sending keepalive packet to peer %llu (%pISpfsc)\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr); } packet_send_staged_packets(peer); } -#define skb_walk_null_queue_safe(first, skb, next) for (skb = first, next = skb->next; skb; skb = next, next = skb ? skb->next : NULL) +#define skb_walk_null_queue_safe(first, skb, next) \ + for (skb = first, next = skb->next; skb; \ + skb = next, next = skb ? skb->next : NULL) static inline void skb_free_null_queue(struct sk_buff *first) { struct sk_buff *skb, *next; - skb_walk_null_queue_safe(first, skb, next) + skb_walk_null_queue_safe (first, skb, next) dev_kfree_skb(skb); } -static void packet_create_data_done(struct sk_buff *first, struct wireguard_peer *peer) +static void packet_create_data_done(struct sk_buff *first, + struct wireguard_peer *peer) { struct sk_buff *skb, *next; bool is_keepalive, data_sent = false; timers_any_authenticated_packet_traversal(peer); timers_any_authenticated_packet_sent(peer); - skb_walk_null_queue_safe(first, skb, next) { + skb_walk_null_queue_safe (first, skb, next) { is_keepalive = skb->len == message_data_len(0); - if (likely(!socket_send_skb_to_peer(peer, skb, PACKET_CB(skb)->ds) && !is_keepalive)) + if (likely(!socket_send_skb_to_peer(peer, skb, + PACKET_CB(skb)->ds) && !is_keepalive)) data_sent = true; } @@ -218,13 +272,16 @@ static void packet_create_data_done(struct sk_buff *first, struct wireguard_peer void packet_tx_worker(struct work_struct *work) { - struct crypt_queue *queue = container_of(work, struct crypt_queue, work); + struct crypt_queue *queue = + container_of(work, struct crypt_queue, work); struct wireguard_peer *peer; struct noise_keypair *keypair; struct sk_buff *first; enum packet_state state; - while ((first = __ptr_ring_peek(&queue->ring)) != NULL && (state = atomic_read_acquire(&PACKET_CB(first)->state)) != PACKET_STATE_UNCRYPTED) { + while ((first = __ptr_ring_peek(&queue->ring)) != NULL && + (state = atomic_read_acquire(&PACKET_CB(first)->state)) != + PACKET_STATE_UNCRYPTED) { __ptr_ring_discard_one(&queue->ring); peer = PACKET_PEER(first); keypair = PACKET_CB(first)->keypair; @@ -241,22 +298,25 @@ void packet_tx_worker(struct work_struct *work) void packet_encrypt_worker(struct work_struct *work) { - struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr; + struct crypt_queue *queue = + container_of(work, struct multicore_worker, work)->ptr; struct sk_buff *first, *skb, *next; simd_context_t simd_context = simd_get(); while ((first = ptr_ring_consume_bh(&queue->ring)) != NULL) { enum packet_state state = PACKET_STATE_CRYPTED; - skb_walk_null_queue_safe(first, skb, next) { - if (likely(skb_encrypt(skb, PACKET_CB(first)->keypair, simd_context))) + skb_walk_null_queue_safe (first, skb, next) { + if (likely(skb_encrypt(skb, PACKET_CB(first)->keypair, + simd_context))) skb_reset(skb); else { state = PACKET_STATE_DEAD; break; } } - queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first, state); + queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first, + state); simd_context = simd_relax(simd_context); } @@ -273,9 +333,13 @@ static void packet_create_data(struct sk_buff *first) if (unlikely(peer->is_dead)) goto err; - ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu); + ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, + &peer->tx_queue, first, + wg->packet_crypt_wq, + &wg->encrypt_queue.last_cpu); if (unlikely(ret == -EPIPE)) - queue_enqueue_per_peer(&peer->tx_queue, first, PACKET_STATE_DEAD); + queue_enqueue_per_peer(&peer->tx_queue, first, + PACKET_STATE_DEAD); err: rcu_read_unlock_bh(); if (likely(!ret || ret == -EPIPE)) @@ -287,8 +351,8 @@ err: void packet_send_staged_packets(struct wireguard_peer *peer) { - struct noise_keypair *keypair; struct noise_symmetric_key *key; + struct noise_keypair *keypair; struct sk_buff_head packets; struct sk_buff *skb; @@ -302,7 +366,8 @@ void packet_send_staged_packets(struct wireguard_peer *peer) /* First we make sure we have a valid reference to a valid key. */ rcu_read_lock_bh(); - keypair = noise_keypair_get(rcu_dereference_bh(peer->keypairs.current_keypair)); + keypair = noise_keypair_get( + rcu_dereference_bh(peer->keypairs.current_keypair)); rcu_read_unlock_bh(); if (unlikely(!keypair)) goto out_nokey; @@ -312,13 +377,17 @@ void packet_send_staged_packets(struct wireguard_peer *peer) if (unlikely(has_expired(key->birthdate, REJECT_AFTER_TIME))) goto out_invalid; - /* After we know we have a somewhat valid key, we now try to assign nonces to - * all of the packets in the queue. If we can't assign nonces for all of them, - * we just consider it a failure and wait for the next handshake. + /* After we know we have a somewhat valid key, we now try to assign + * nonces to all of the packets in the queue. If we can't assign nonces + * for all of them, we just consider it a failure and wait for the next + * handshake. */ - skb_queue_walk(&packets, skb) { - PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb); - PACKET_CB(skb)->nonce = atomic64_inc_return(&key->counter.counter) - 1; + skb_queue_walk (&packets, skb) { + /* 0 for no outer TOS: no leak. TODO: should we use flowi->tos + * as outer? */ + PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb); + PACKET_CB(skb)->nonce = + atomic64_inc_return(&key->counter.counter) - 1; if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES)) goto out_invalid; } @@ -337,19 +406,19 @@ out_nokey: /* We orphan the packets if we're waiting on a handshake, so that they * don't block a socket's pool. */ - skb_queue_walk(&packets, skb) + skb_queue_walk (&packets, skb) skb_orphan(skb); - /* Then we put them back on the top of the queue. We're not too concerned about - * accidentally getting things a little out of order if packets are being added - * really fast, because this queue is for before packets can even be sent and - * it's small anyway. + /* Then we put them back on the top of the queue. We're not too + * concerned about accidentally getting things a little out of order if + * packets are being added really fast, because this queue is for before + * packets can even be sent and it's small anyway. */ spin_lock_bh(&peer->staged_packet_queue.lock); skb_queue_splice(&packets, &peer->staged_packet_queue); spin_unlock_bh(&peer->staged_packet_queue.lock); - /* If we're exiting because there's something wrong with the key, it means - * we should initiate a new handshake. + /* If we're exiting because there's something wrong with the key, it + * means we should initiate a new handshake. */ packet_send_queued_handshake_initiation(peer, false); } diff --git a/src/socket.c b/src/socket.c index 663a462..f544dd0 100644 --- a/src/socket.c +++ b/src/socket.c @@ -17,7 +17,9 @@ #include #include -static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct endpoint *endpoint, u8 ds, struct dst_cache *cache) +static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, + struct endpoint *endpoint, u8 ds, + struct dst_cache *cache) { struct flowi4 fl = { .saddr = endpoint->src4.s_addr, @@ -49,14 +51,21 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct if (!rt) { security_sk_classify_flow(sock, flowi4_to_flowi(&fl)); - if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0, fl.saddr, RT_SCOPE_HOST))) { - endpoint->src4.s_addr = *(__force __be32 *)&endpoint->src_if4 = fl.saddr = 0; + if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0, + fl.saddr, RT_SCOPE_HOST))) { + endpoint->src4.s_addr = 0; + *(__force __be32 *)&endpoint->src_if4 = 0; + fl.saddr = 0; if (cache) dst_cache_reset(cache); } rt = ip_route_output_flow(sock_net(sock), &fl, sock); - if (unlikely(endpoint->src_if4 && ((IS_ERR(rt) && PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) && rt->dst.dev->ifindex != endpoint->src_if4)))) { - endpoint->src4.s_addr = *(__force __be32 *)&endpoint->src_if4 = fl.saddr = 0; + if (unlikely(endpoint->src_if4 && ((IS_ERR(rt) && + PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) && + rt->dst.dev->ifindex != endpoint->src_if4)))) { + endpoint->src4.s_addr = 0; + *(__force __be32 *)&endpoint->src_if4 = 0; + fl.saddr = 0; if (cache) dst_cache_reset(cache); if (!IS_ERR(rt)) @@ -65,18 +74,22 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct } if (unlikely(IS_ERR(rt))) { ret = PTR_ERR(rt); - net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", wg->dev->name, &endpoint->addr, ret); + net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", + wg->dev->name, &endpoint->addr, ret); goto err; } else if (unlikely(rt->dst.dev == skb->dev)) { ip_rt_put(rt); ret = -ELOOP; - net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", wg->dev->name, &endpoint->addr); + net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", + wg->dev->name, &endpoint->addr); goto err; } if (cache) dst_cache_set_ip4(cache, &rt->dst, fl.saddr); } - udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds, ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport, fl.fl4_dport, false, false); + udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds, + ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport, + fl.fl4_dport, false, false); goto out; err: @@ -86,7 +99,9 @@ out: return ret; } -static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct endpoint *endpoint, u8 ds, struct dst_cache *cache) +static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, + struct endpoint *endpoint, u8 ds, + struct dst_cache *cache) { #if IS_ENABLED(CONFIG_IPV6) struct flowi6 fl = { @@ -121,26 +136,32 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct if (!dst) { security_sk_classify_flow(sock, flowi6_to_flowi(&fl)); - if (unlikely(!ipv6_addr_any(&fl.saddr) && !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) { + if (unlikely(!ipv6_addr_any(&fl.saddr) && + !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) { endpoint->src6 = fl.saddr = in6addr_any; if (cache) dst_cache_reset(cache); } - ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock), sock, &dst, &fl); + ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock), sock, &dst, + &fl); if (unlikely(ret)) { - net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", wg->dev->name, &endpoint->addr, ret); + net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", + wg->dev->name, &endpoint->addr, ret); goto err; } else if (unlikely(dst->dev == skb->dev)) { dst_release(dst); ret = -ELOOP; - net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", wg->dev->name, &endpoint->addr); + net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", + wg->dev->name, &endpoint->addr); goto err; } if (cache) dst_cache_set_ip6(cache, dst, &fl.saddr); } - udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds, ip6_dst_hoplimit(dst), 0, fl.fl6_sport, fl.fl6_dport, false); + udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds, + ip6_dst_hoplimit(dst), 0, fl.fl6_sport, + fl.fl6_dport, false); goto out; err: @@ -153,16 +174,19 @@ out: #endif } -int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds) +int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, + u8 ds) { size_t skb_len = skb->len; int ret = -EAFNOSUPPORT; read_lock_bh(&peer->endpoint_lock); if (peer->endpoint.addr.sa_family == AF_INET) - ret = send4(peer->device, skb, &peer->endpoint, ds, &peer->endpoint_cache); + ret = send4(peer->device, skb, &peer->endpoint, ds, + &peer->endpoint_cache); else if (peer->endpoint.addr.sa_family == AF_INET6) - ret = send6(peer->device, skb, &peer->endpoint, ds, &peer->endpoint_cache); + ret = send6(peer->device, skb, &peer->endpoint, ds, + &peer->endpoint_cache); else dev_kfree_skb(skb); if (likely(!ret)) @@ -172,7 +196,8 @@ int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 return ret; } -int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, size_t len, u8 ds) +int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, + size_t len, u8 ds) { struct sk_buff *skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC); @@ -185,7 +210,9 @@ int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, size_t return socket_send_skb_to_peer(peer, skb, ds); } -int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_buff *in_skb, void *buffer, size_t len) +int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, + struct sk_buff *in_skb, void *buffer, + size_t len) { int ret = 0; struct sk_buff *skb; @@ -208,12 +235,15 @@ int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_bu ret = send4(wg, skb, &endpoint, 0, NULL); else if (endpoint.addr.sa_family == AF_INET6) ret = send6(wg, skb, &endpoint, 0, NULL); - /* No other possibilities if the endpoint is valid, which it is, as we checked above. */ + /* No other possibilities if the endpoint is valid, which it is, + * as we checked above. + */ return ret; } -int socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *skb) +int socket_endpoint_from_skb(struct endpoint *endpoint, + const struct sk_buff *skb) { memset(endpoint, 0, sizeof(struct endpoint)); if (skb->protocol == htons(ETH_P_IP)) { @@ -226,25 +256,32 @@ int socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *sk endpoint->addr6.sin6_family = AF_INET6; endpoint->addr6.sin6_port = udp_hdr(skb)->source; endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr; - endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id(&ipv6_hdr(skb)->saddr, skb->skb_iif); + endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id( + &ipv6_hdr(skb)->saddr, skb->skb_iif); endpoint->src6 = ipv6_hdr(skb)->daddr; } else return -EINVAL; return 0; } -static inline bool endpoint_eq(const struct endpoint *a, const struct endpoint *b) +static inline bool endpoint_eq(const struct endpoint *a, + const struct endpoint *b) { return (a->addr.sa_family == AF_INET && b->addr.sa_family == AF_INET && - a->addr4.sin_port == b->addr4.sin_port && a->addr4.sin_addr.s_addr == b->addr4.sin_addr.s_addr && + a->addr4.sin_port == b->addr4.sin_port && + a->addr4.sin_addr.s_addr == b->addr4.sin_addr.s_addr && a->src4.s_addr == b->src4.s_addr && a->src_if4 == b->src_if4) || - (a->addr.sa_family == AF_INET6 && b->addr.sa_family == AF_INET6 && - a->addr6.sin6_port == b->addr6.sin6_port && ipv6_addr_equal(&a->addr6.sin6_addr, &b->addr6.sin6_addr) && - a->addr6.sin6_scope_id == b->addr6.sin6_scope_id && ipv6_addr_equal(&a->src6, &b->src6)) || + (a->addr.sa_family == AF_INET6 && + b->addr.sa_family == AF_INET6 && + a->addr6.sin6_port == b->addr6.sin6_port && + ipv6_addr_equal(&a->addr6.sin6_addr, &b->addr6.sin6_addr) && + a->addr6.sin6_scope_id == b->addr6.sin6_scope_id && + ipv6_addr_equal(&a->src6, &b->src6)) || unlikely(!a->addr.sa_family && !b->addr.sa_family); } -void socket_set_peer_endpoint(struct wireguard_peer *peer, const struct endpoint *endpoint) +void socket_set_peer_endpoint(struct wireguard_peer *peer, + const struct endpoint *endpoint) { /* First we check unlocked, in order to optimize, since it's pretty rare * that an endpoint will change. If we happen to be mid-write, and two @@ -268,7 +305,8 @@ out: write_unlock_bh(&peer->endpoint_lock); } -void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer, const struct sk_buff *skb) +void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer, + const struct sk_buff *skb) { struct endpoint endpoint; @@ -362,7 +400,8 @@ retry: udp_tunnel_sock_release(new4); if (ret == -EADDRINUSE && !port && retries++ < 100) goto retry; - pr_err("%s: Could not create IPv6 socket\n", wg->dev->name); + pr_err("%s: Could not create IPv6 socket\n", + wg->dev->name); return ret; } set_sock_opts(new6); @@ -374,13 +413,16 @@ retry: return 0; } -void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6) +void socket_reinit(struct wireguard_device *wg, struct sock *new4, + struct sock *new6) { struct sock *old4, *old6; mutex_lock(&wg->socket_update_lock); - old4 = rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock)); - old6 = rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock)); + old4 = rcu_dereference_protected(wg->sock4, + lockdep_is_held(&wg->socket_update_lock)); + old6 = rcu_dereference_protected(wg->sock6, + lockdep_is_held(&wg->socket_update_lock)); rcu_assign_pointer(wg->sock4, new4); rcu_assign_pointer(wg->sock6, new6); if (new4) diff --git a/src/socket.h b/src/socket.h index 9faba77..d873ffa 100644 --- a/src/socket.h +++ b/src/socket.h @@ -12,22 +12,31 @@ #include int socket_init(struct wireguard_device *wg, u16 port); -void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6); -int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *data, size_t len, u8 ds); -int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds); -int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_buff *in_skb, void *out_buffer, size_t len); +void socket_reinit(struct wireguard_device *wg, struct sock *new4, + struct sock *new6); +int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *data, + size_t len, u8 ds); +int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, + u8 ds); +int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, + struct sk_buff *in_skb, void *out_buffer, + size_t len); -int socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *skb); -void socket_set_peer_endpoint(struct wireguard_peer *peer, const struct endpoint *endpoint); -void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer, const struct sk_buff *skb); +int socket_endpoint_from_skb(struct endpoint *endpoint, + const struct sk_buff *skb); +void socket_set_peer_endpoint(struct wireguard_peer *peer, + const struct endpoint *endpoint); +void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer, + const struct sk_buff *skb); void socket_clear_peer_endpoint_src(struct wireguard_peer *peer); #if defined(CONFIG_DYNAMIC_DEBUG) || defined(DEBUG) -#define net_dbg_skb_ratelimited(fmt, dev, skb, ...) do { \ - struct endpoint __endpoint; \ - socket_endpoint_from_skb(&__endpoint, skb); \ - net_dbg_ratelimited(fmt, dev, &__endpoint.addr, ##__VA_ARGS__); \ -} while (0) +#define net_dbg_skb_ratelimited(fmt, dev, skb, ...) do { \ + struct endpoint __endpoint; \ + socket_endpoint_from_skb(&__endpoint, skb); \ + net_dbg_ratelimited(fmt, dev, &__endpoint.addr, \ + ##__VA_ARGS__); \ + } while (0) #else #define net_dbg_skb_ratelimited(fmt, skb, ...) #endif diff --git a/src/timers.c b/src/timers.c index 7db892a..ad76466 100644 --- a/src/timers.c +++ b/src/timers.c @@ -10,22 +10,33 @@ #include "socket.h" /* - * Timer for retransmitting the handshake if we don't hear back after `REKEY_TIMEOUT + jitter` ms - * Timer for sending empty packet if we have received a packet but after have not sent one for `KEEPALIVE_TIMEOUT` ms - * Timer for initiating new handshake if we have sent a packet but after have not received one (even empty) for `(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT)` ms - * Timer for zeroing out all ephemeral keys after `(REJECT_AFTER_TIME * 3)` ms if no new keys have been received - * Timer for, if enabled, sending an empty authenticated packet every user-specified seconds + * - Timer for retransmitting the handshake if we don't hear back after + * `REKEY_TIMEOUT + jitter` ms. + * + * - Timer for sending empty packet if we have received a packet but after have + * not sent one for `KEEPALIVE_TIMEOUT` ms. + * + * - Timer for initiating new handshake if we have sent a packet but after have + * not received one (even empty) for `(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT)` ms. + * + * - Timer for zeroing out all ephemeral keys after `(REJECT_AFTER_TIME * 3)` ms + * if no new keys have been received. + * + * - Timer for, if enabled, sending an empty authenticated packet every user- + * specified seconds. */ -#define peer_get_from_timer(timer_name) \ - struct wireguard_peer *peer; \ - rcu_read_lock_bh(); \ - peer = peer_get_maybe_zero(from_timer(peer, timer, timer_name)); \ - rcu_read_unlock_bh(); \ - if (unlikely(!peer)) \ +#define peer_get_from_timer(timer_name) \ + struct wireguard_peer *peer; \ + rcu_read_lock_bh(); \ + peer = peer_get_maybe_zero(from_timer(peer, timer, timer_name)); \ + rcu_read_unlock_bh(); \ + if (unlikely(!peer)) \ return; -static inline void mod_peer_timer(struct wireguard_peer *peer, struct timer_list *timer, unsigned long expires) +static inline void mod_peer_timer(struct wireguard_peer *peer, + struct timer_list *timer, + unsigned long expires) { rcu_read_lock_bh(); if (likely(netif_running(peer->device->dev) && !peer->is_dead)) @@ -33,7 +44,8 @@ static inline void mod_peer_timer(struct wireguard_peer *peer, struct timer_list rcu_read_unlock_bh(); } -static inline void del_peer_timer(struct wireguard_peer *peer, struct timer_list *timer) +static inline void del_peer_timer(struct wireguard_peer *peer, + struct timer_list *timer) { rcu_read_lock_bh(); if (likely(netif_running(peer->device->dev) && !peer->is_dead)) @@ -46,7 +58,9 @@ static void expired_retransmit_handshake(struct timer_list *timer) peer_get_from_timer(timer_retransmit_handshake); if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) { - pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2); + pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2); del_peer_timer(peer, &peer->timer_send_keepalive); /* We drop all packets without a keypair and don't try again, @@ -58,12 +72,18 @@ static void expired_retransmit_handshake(struct timer_list *timer) * of a partial exchange. */ if (!timer_pending(&peer->timer_zero_key_material)) - mod_peer_timer(peer, &peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ); + mod_peer_timer(peer, &peer->timer_zero_key_material, + jiffies + REJECT_AFTER_TIME * 3 * HZ); } else { ++peer->timer_handshake_attempts; - pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d seconds, retrying (try %d)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, REKEY_TIMEOUT, peer->timer_handshake_attempts + 1); + pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d seconds, retrying (try %d)\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr, REKEY_TIMEOUT, + peer->timer_handshake_attempts + 1); - /* We clear the endpoint address src address, in case this is the cause of trouble. */ + /* We clear the endpoint address src address, in case this is + * the cause of trouble. + */ socket_clear_peer_endpoint_src(peer); packet_send_queued_handshake_initiation(peer, true); @@ -78,7 +98,8 @@ static void expired_send_keepalive(struct timer_list *timer) packet_send_keepalive(peer); if (peer->timer_need_another_keepalive) { peer->timer_need_another_keepalive = false; - mod_peer_timer(peer, &peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ); + mod_peer_timer(peer, &peer->timer_send_keepalive, + jiffies + KEEPALIVE_TIMEOUT * HZ); } peer_put(peer); } @@ -87,8 +108,12 @@ static void expired_new_handshake(struct timer_list *timer) { peer_get_from_timer(timer_new_handshake); - pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); - /* We clear the endpoint address src address, in case this is the cause of trouble. */ + pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr, KEEPALIVE_TIMEOUT + REKEY_TIMEOUT); + /* We clear the endpoint address src address, in case this is the cause + * of trouble. + */ socket_clear_peer_endpoint_src(peer); packet_send_queued_handshake_initiation(peer, false); peer_put(peer); @@ -100,16 +125,22 @@ static void expired_zero_key_material(struct timer_list *timer) rcu_read_lock_bh(); if (!peer->is_dead) { - if (!queue_work(peer->device->handshake_send_wq, &peer->clear_peer_work)) /* Should take our reference. */ - peer_put(peer); /* If the work was already on the queue, we want to drop the extra reference */ + /* Should take our reference. */ + if (!queue_work(peer->device->handshake_send_wq, + &peer->clear_peer_work)) + /* If the work was already on the queue, we want to drop the extra reference */ + peer_put(peer); } rcu_read_unlock_bh(); } static void queued_expired_zero_key_material(struct work_struct *work) { - struct wireguard_peer *peer = container_of(work, struct wireguard_peer, clear_peer_work); + struct wireguard_peer *peer = + container_of(work, struct wireguard_peer, clear_peer_work); - pr_debug("%s: Zeroing out all keys for peer %llu (%pISpfsc), since we haven't received a new one in %d seconds\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, REJECT_AFTER_TIME * 3); + pr_debug("%s: Zeroing out all keys for peer %llu (%pISpfsc), since we haven't received a new one in %d seconds\n", + peer->device->dev->name, peer->internal_id, + &peer->endpoint.addr, REJECT_AFTER_TIME * 3); noise_handshake_clear(&peer->handshake); noise_keypairs_clear(&peer->keypairs); peer_put(peer); @@ -128,7 +159,8 @@ static void expired_send_persistent_keepalive(struct timer_list *timer) void timers_data_sent(struct wireguard_peer *peer) { if (!timer_pending(&peer->timer_new_handshake)) - mod_peer_timer(peer, &peer->timer_new_handshake, jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ); + mod_peer_timer(peer, &peer->timer_new_handshake, + jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ); } /* Should be called after an authenticated data packet is received. */ @@ -136,19 +168,24 @@ void timers_data_received(struct wireguard_peer *peer) { if (likely(netif_running(peer->device->dev))) { if (!timer_pending(&peer->timer_send_keepalive)) - mod_peer_timer(peer, &peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ); + mod_peer_timer(peer, &peer->timer_send_keepalive, + jiffies + KEEPALIVE_TIMEOUT * HZ); else peer->timer_need_another_keepalive = true; } } -/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ +/* Should be called after any type of authenticated packet is sent, whether + * keepalive, data, or handshake. +*/ void timers_any_authenticated_packet_sent(struct wireguard_peer *peer) { del_peer_timer(peer, &peer->timer_send_keepalive); } -/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ +/* Should be called after any type of authenticated packet is received, whether + * keepalive, data, or handshake. + */ void timers_any_authenticated_packet_received(struct wireguard_peer *peer) { del_peer_timer(peer, &peer->timer_new_handshake); @@ -157,10 +194,15 @@ void timers_any_authenticated_packet_received(struct wireguard_peer *peer) /* Should be called after a handshake initiation message is sent. */ void timers_handshake_initiated(struct wireguard_peer *peer) { - mod_peer_timer(peer, &peer->timer_retransmit_handshake, jiffies + REKEY_TIMEOUT * HZ + prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES)); + mod_peer_timer( + peer, &peer->timer_retransmit_handshake, + jiffies + REKEY_TIMEOUT * HZ + + prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES)); } -/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ +/* Should be called after a handshake response message is received and processed + * or when getting key confirmation via the first data message. + */ void timers_handshake_complete(struct wireguard_peer *peer) { del_peer_timer(peer, &peer->timer_retransmit_handshake); @@ -169,26 +211,34 @@ void timers_handshake_complete(struct wireguard_peer *peer) getnstimeofday(&peer->walltime_last_handshake); } -/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ +/* Should be called after an ephemeral key is created, which is before sending a + * handshake response or after receiving a handshake response. + */ void timers_session_derived(struct wireguard_peer *peer) { - mod_peer_timer(peer, &peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ); + mod_peer_timer(peer, &peer->timer_zero_key_material, + jiffies + REJECT_AFTER_TIME * 3 * HZ); } -/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ +/* Should be called before a packet with authentication, whether + * keepalive, data, or handshakem is sent, or after one is received. + */ void timers_any_authenticated_packet_traversal(struct wireguard_peer *peer) { if (peer->persistent_keepalive_interval) - mod_peer_timer(peer, &peer->timer_persistent_keepalive, jiffies + peer->persistent_keepalive_interval * HZ); + mod_peer_timer(peer, &peer->timer_persistent_keepalive, + jiffies + peer->persistent_keepalive_interval * HZ); } void timers_init(struct wireguard_peer *peer) { - timer_setup(&peer->timer_retransmit_handshake, expired_retransmit_handshake, 0); + timer_setup(&peer->timer_retransmit_handshake, + expired_retransmit_handshake, 0); timer_setup(&peer->timer_send_keepalive, expired_send_keepalive, 0); timer_setup(&peer->timer_new_handshake, expired_new_handshake, 0); timer_setup(&peer->timer_zero_key_material, expired_zero_key_material, 0); - timer_setup(&peer->timer_persistent_keepalive, expired_send_persistent_keepalive, 0); + timer_setup(&peer->timer_persistent_keepalive, + expired_send_persistent_keepalive, 0); INIT_WORK(&peer->clear_peer_work, queued_expired_zero_key_material); peer->timer_handshake_attempts = 0; peer->sent_lastminute_handshake = false; diff --git a/src/timers.h b/src/timers.h index c95bf31..483529c 100644 --- a/src/timers.h +++ b/src/timers.h @@ -23,7 +23,8 @@ void timers_any_authenticated_packet_traversal(struct wireguard_peer *peer); static inline bool has_expired(u64 birthday_nanoseconds, u64 expiration_seconds) { - return (s64)(birthday_nanoseconds + expiration_seconds * NSEC_PER_SEC) <= (s64)ktime_get_boot_fast_ns(); + return (s64)(birthday_nanoseconds + expiration_seconds * NSEC_PER_SEC) + <= (s64)ktime_get_boot_fast_ns(); } #endif /* _WG_TIMERS_H */ -- cgit v1.2.3