summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/allowedips.c148
-rw-r--r--src/allowedips.h23
-rw-r--r--src/cookie.c102
-rw-r--r--src/cookie.h19
-rw-r--r--src/device.c95
-rw-r--r--src/device.h3
-rw-r--r--src/hashtables.c105
-rw-r--r--src/hashtables.h25
-rw-r--r--src/main.c5
-rw-r--r--src/messages.h14
-rw-r--r--src/netlink.c215
-rw-r--r--src/noise.c360
-rw-r--r--src/noise.h33
-rw-r--r--src/peer.c85
-rw-r--r--src/peer.h21
-rw-r--r--src/queueing.c14
-rw-r--r--src/queueing.h58
-rw-r--r--src/ratelimiter.c47
-rw-r--r--src/receive.c257
-rw-r--r--src/selftest/allowedips.h276
-rw-r--r--src/selftest/counter.h25
-rw-r--r--src/selftest/ratelimiter.h48
-rw-r--r--src/send.c203
-rw-r--r--src/socket.c108
-rw-r--r--src/socket.h33
-rw-r--r--src/timers.c122
-rw-r--r--src/timers.h3
27 files changed, 1653 insertions, 794 deletions
diff --git a/src/allowedips.c b/src/allowedips.c
index 1442bf4..4616645 100644
--- a/src/allowedips.c
+++ b/src/allowedips.c
@@ -28,7 +28,8 @@ static __always_inline void swap_endian(u8 *dst, const u8 *src, u8 bits)
}
}
-static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8 cidr, u8 bits)
+static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
+ u8 cidr, u8 bits)
{
node->cidr = cidr;
node->bit_at_a = cidr / 8U;
@@ -39,34 +40,43 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src, u8
memcpy(node->bits, src, bits / 8U);
}
-#define choose_node(parent, key) parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
+#define choose_node(parent, key) \
+ parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
static void node_free_rcu(struct rcu_head *rcu)
{
kfree(container_of(rcu, struct allowedips_node, rcu));
}
-#define push_rcu(stack, p, len) ({ \
- if (rcu_access_pointer(p)) { \
- BUG_ON(len >= 128); \
- stack[len++] = rcu_dereference_raw(p); \
- } \
- true; \
-})
+#define push_rcu(stack, p, len) ({ \
+ if (rcu_access_pointer(p)) { \
+ BUG_ON(len >= 128); \
+ stack[len++] = rcu_dereference_raw(p); \
+ } \
+ true; \
+ })
static void root_free_rcu(struct rcu_head *rcu)
{
- struct allowedips_node *node, *stack[128] = { container_of(rcu, struct allowedips_node, rcu) };
+ struct allowedips_node *node, *stack[128] =
+ { container_of(rcu, struct allowedips_node, rcu) };
unsigned int len = 1;
- while (len > 0 && (node = stack[--len]) && push_rcu(stack, node->bit[0], len) && push_rcu(stack, node->bit[1], len))
+ while (len > 0 && (node = stack[--len]) &&
+ push_rcu(stack, node->bit[0], len) &&
+ push_rcu(stack, node->bit[1], len))
kfree(node);
}
-static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock)
+static int
+walk_by_peer(struct allowedips_node __rcu *top, u8 bits,
+ struct allowedips_cursor *cursor, struct wireguard_peer *peer,
+ int (*func)(void *ctx, const u8 *ip, u8 cidr, int family),
+ void *ctx, struct mutex *lock)
{
+ const int address_family = bits == 32 ? AF_INET : AF_INET6;
+ u8 ip[16] __aligned(__alignof(u64));
struct allowedips_node *node;
int ret;
- u8 ip[16] __aligned(__alignof(u64));
if (!rcu_access_pointer(top))
return 0;
@@ -74,16 +84,21 @@ static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allow
if (!cursor->len)
push_rcu(cursor->stack, top, cursor->len);
- for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]); --cursor->len, push_rcu(cursor->stack, node->bit[0], cursor->len), push_rcu(cursor->stack, node->bit[1], cursor->len)) {
- if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) != peer)
+ for (; cursor->len > 0 && (node = cursor->stack[cursor->len - 1]);
+ --cursor->len, push_rcu(cursor->stack, node->bit[0], cursor->len),
+ push_rcu(cursor->stack, node->bit[1], cursor->len)) {
+ const unsigned int cidr_bytes = DIV_ROUND_UP(node->cidr, 8U);
+
+ if (rcu_dereference_protected(node->peer,
+ lockdep_is_held(lock)) != peer)
continue;
swap_endian(ip, node->bits, bits);
- memset(ip + (node->cidr + 7U) / 8U, 0, (bits / 8U) - ((node->cidr + 7U) / 8U));
+ memset(ip + cidr_bytes, 0, bits / 8U - cidr_bytes);
if (node->cidr)
- ip[(node->cidr + 7U) / 8U - 1U] &= ~0U << (-node->cidr % 8U);
+ ip[cidr_bytes - 1U] &= ~0U << (-node->cidr % 8U);
- ret = func(ctx, ip, node->cidr, bits == 32 ? AF_INET : AF_INET6);
+ ret = func(ctx, ip, node->cidr, address_family);
if (ret)
return ret;
}
@@ -93,8 +108,12 @@ static int walk_by_peer(struct allowedips_node __rcu *top, u8 bits, struct allow
#define ref(p) rcu_access_pointer(p)
#define deref(p) rcu_dereference_protected(*p, lockdep_is_held(lock))
-#define push(p) ({ BUG_ON(len >= 128); stack[len++] = p; })
-static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireguard_peer *peer, struct mutex *lock)
+#define push(p) ({ \
+ BUG_ON(len >= 128); \
+ stack[len++] = p; \
+ })
+static void walk_remove_by_peer(struct allowedips_node __rcu **top,
+ struct wireguard_peer *peer, struct mutex *lock)
{
struct allowedips_node __rcu **stack[128], **nptr;
struct allowedips_node *node, *prev;
@@ -110,7 +129,8 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireg
--len;
continue;
}
- if (!prev || ref(prev->bit[0]) == node || ref(prev->bit[1]) == node) {
+ if (!prev || ref(prev->bit[0]) == node ||
+ ref(prev->bit[1]) == node) {
if (ref(node->bit[0]))
push(&node->bit[0]);
else if (ref(node->bit[1]))
@@ -119,10 +139,12 @@ static void walk_remove_by_peer(struct allowedips_node __rcu **top, struct wireg
if (ref(node->bit[1]))
push(&node->bit[1]);
} else {
- if (rcu_dereference_protected(node->peer, lockdep_is_held(lock)) == peer) {
+ if (rcu_dereference_protected(node->peer,
+ lockdep_is_held(lock)) == peer) {
RCU_INIT_POINTER(node->peer, NULL);
if (!node->bit[0] || !node->bit[1]) {
- rcu_assign_pointer(*nptr, deref(&node->bit[!ref(node->bit[0])]));
+ rcu_assign_pointer(*nptr,
+ deref(&node->bit[!ref(node->bit[0])]));
call_rcu_bh(&node->rcu, node_free_rcu);
node = deref(nptr);
}
@@ -140,23 +162,29 @@ static __always_inline unsigned int fls128(u64 a, u64 b)
return a ? fls64(a) + 64U : fls64(b);
}
-static __always_inline u8 common_bits(const struct allowedips_node *node, const u8 *key, u8 bits)
+static __always_inline u8 common_bits(const struct allowedips_node *node,
+ const u8 *key, u8 bits)
{
if (bits == 32)
return 32U - fls(*(const u32 *)node->bits ^ *(const u32 *)key);
else if (bits == 128)
- return 128U - fls128(*(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0], *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
+ return 128U - fls128(
+ *(const u64 *)&node->bits[0] ^ *(const u64 *)&key[0],
+ *(const u64 *)&node->bits[8] ^ *(const u64 *)&key[8]);
return 0;
}
-/* This could be much faster if it actually just compared the common bits properly,
- * by precomputing a mask bswap(~0 << (32 - cidr)), and the rest, but it turns out that
- * common_bits is already super fast on modern processors, even taking into account
- * the unfortunate bswap. So, we just inline it like this instead.
+/* This could be much faster if it actually just compared the common bits
+ * properly, by precomputing a mask bswap(~0 << (32 - cidr)), and the rest, but
+ * it turns out that common_bits is already super fast on modern processors,
+ * even taking into account the unfortunate bswap. So, we just inline it like
+ * this instead.
*/
-#define prefix_matches(node, key, bits) (common_bits(node, key, bits) >= node->cidr)
+#define prefix_matches(node, key, bits) \
+ (common_bits(node, key, bits) >= node->cidr)
-static __always_inline struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits, const u8 *key)
+static __always_inline struct allowedips_node *
+find_node(struct allowedips_node *trie, u8 bits, const u8 *key)
{
struct allowedips_node *node = trie, *found = NULL;
@@ -171,11 +199,12 @@ static __always_inline struct allowedips_node *find_node(struct allowedips_node
}
/* Returns a strong reference to a peer */
-static __always_inline struct wireguard_peer *lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip)
+static __always_inline struct wireguard_peer *
+lookup(struct allowedips_node __rcu *root, u8 bits, const void *be_ip)
{
+ u8 ip[16] __aligned(__alignof(u64));
struct wireguard_peer *peer = NULL;
struct allowedips_node *node;
- u8 ip[16] __aligned(__alignof(u64));
swap_endian(ip, be_ip, bits);
@@ -191,11 +220,14 @@ retry:
return peer;
}
-__attribute__((nonnull(1)))
-static inline bool node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr, u8 bits, struct allowedips_node **rnode, struct mutex *lock)
+__attribute__((nonnull(1))) static inline bool
+node_placement(struct allowedips_node __rcu *trie, const u8 *key, u8 cidr,
+ u8 bits, struct allowedips_node **rnode, struct mutex *lock)
{
+ struct allowedips_node *node = rcu_dereference_protected(trie,
+ lockdep_is_held(lock));
+ struct allowedips_node *parent = NULL;
bool exact = false;
- struct allowedips_node *parent = NULL, *node = rcu_dereference_protected(trie, lockdep_is_held(lock));
while (node && node->cidr <= cidr && prefix_matches(node, key, bits)) {
parent = node;
@@ -203,13 +235,15 @@ static inline bool node_placement(struct allowedips_node __rcu *trie, const u8 *
exact = true;
break;
}
- node = rcu_dereference_protected(choose_node(parent, key), lockdep_is_held(lock));
+ node = rcu_dereference_protected(choose_node(parent, key),
+ lockdep_is_held(lock));
}
*rnode = parent;
return exact;
}
-static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
+static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key,
+ u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
{
struct allowedips_node *node, *parent, *down, *newnode;
u8 key[16] __aligned(__alignof(u64));
@@ -242,7 +276,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u
if (!node)
down = rcu_dereference_protected(*trie, lockdep_is_held(lock));
else {
- down = rcu_dereference_protected(choose_node(node, key), lockdep_is_held(lock));
+ down = rcu_dereference_protected(choose_node(node, key),
+ lockdep_is_held(lock));
if (!down) {
rcu_assign_pointer(choose_node(node, key), newnode);
return 0;
@@ -256,7 +291,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u
if (!parent)
rcu_assign_pointer(*trie, newnode);
else
- rcu_assign_pointer(choose_node(parent, newnode->bits), newnode);
+ rcu_assign_pointer(choose_node(parent, newnode->bits),
+ newnode);
} else {
node = kzalloc(sizeof(*node), GFP_KERNEL);
if (!node) {
@@ -270,7 +306,8 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *be_key, u
if (!parent)
rcu_assign_pointer(*trie, node);
else
- rcu_assign_pointer(choose_node(parent, node->bits), node);
+ rcu_assign_pointer(choose_node(parent, node->bits),
+ node);
}
return 0;
}
@@ -288,31 +325,42 @@ void allowedips_free(struct allowedips *table, struct mutex *lock)
RCU_INIT_POINTER(table->root4, NULL);
RCU_INIT_POINTER(table->root6, NULL);
if (rcu_access_pointer(old4))
- call_rcu_bh(&rcu_dereference_protected(old4, lockdep_is_held(lock))->rcu, root_free_rcu);
+ call_rcu_bh(&rcu_dereference_protected(old4,
+ lockdep_is_held(lock))->rcu, root_free_rcu);
if (rcu_access_pointer(old6))
- call_rcu_bh(&rcu_dereference_protected(old6, lockdep_is_held(lock))->rcu, root_free_rcu);
+ call_rcu_bh(&rcu_dereference_protected(old6,
+ lockdep_is_held(lock))->rcu, root_free_rcu);
}
-int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
+int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
+ u8 cidr, struct wireguard_peer *peer,
+ struct mutex *lock)
{
++table->seq;
return add(&table->root4, 32, (const u8 *)ip, cidr, peer, lock);
}
-int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock)
+int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
+ u8 cidr, struct wireguard_peer *peer,
+ struct mutex *lock)
{
++table->seq;
return add(&table->root6, 128, (const u8 *)ip, cidr, peer, lock);
}
-void allowedips_remove_by_peer(struct allowedips *table, struct wireguard_peer *peer, struct mutex *lock)
+void allowedips_remove_by_peer(struct allowedips *table,
+ struct wireguard_peer *peer, struct mutex *lock)
{
++table->seq;
walk_remove_by_peer(&table->root4, peer, lock);
walk_remove_by_peer(&table->root6, peer, lock);
}
-int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock)
+int allowedips_walk_by_peer(struct allowedips *table,
+ struct allowedips_cursor *cursor,
+ struct wireguard_peer *peer,
+ int (*func)(void *ctx, const u8 *ip, u8 cidr, int family),
+ void *ctx, struct mutex *lock)
{
int ret;
@@ -332,7 +380,8 @@ int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor *
}
/* Returns a strong reference to a peer */
-struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk_buff *skb)
+struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table,
+ struct sk_buff *skb)
{
if (skb->protocol == htons(ETH_P_IP))
return lookup(table->root4, 32, &ip_hdr(skb)->daddr);
@@ -342,7 +391,8 @@ struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk
}
/* Returns a strong reference to a peer */
-struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, struct sk_buff *skb)
+struct wireguard_peer *allowedips_lookup_src(struct allowedips *table,
+ struct sk_buff *skb)
{
if (skb->protocol == htons(ETH_P_IP))
return lookup(table->root4, 32, &ip_hdr(skb)->saddr);
diff --git a/src/allowedips.h b/src/allowedips.h
index 97ecf69..d5ba1be 100644
--- a/src/allowedips.h
+++ b/src/allowedips.h
@@ -28,14 +28,25 @@ struct allowedips_cursor {
void allowedips_init(struct allowedips *table);
void allowedips_free(struct allowedips *table, struct mutex *mutex);
-int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock);
-int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip, u8 cidr, struct wireguard_peer *peer, struct mutex *lock);
-void allowedips_remove_by_peer(struct allowedips *table, struct wireguard_peer *peer, struct mutex *lock);
-int allowedips_walk_by_peer(struct allowedips *table, struct allowedips_cursor *cursor, struct wireguard_peer *peer, int (*func)(void *ctx, const u8 *ip, u8 cidr, int family), void *ctx, struct mutex *lock);
+int allowedips_insert_v4(struct allowedips *table, const struct in_addr *ip,
+ u8 cidr, struct wireguard_peer *peer,
+ struct mutex *lock);
+int allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
+ u8 cidr, struct wireguard_peer *peer,
+ struct mutex *lock);
+void allowedips_remove_by_peer(struct allowedips *table,
+ struct wireguard_peer *peer, struct mutex *lock);
+int allowedips_walk_by_peer(struct allowedips *table,
+ struct allowedips_cursor *cursor,
+ struct wireguard_peer *peer,
+ int (*func)(void *ctx, const u8 *ip, u8 cidr, int family),
+ void *ctx, struct mutex *lock);
/* These return a strong reference to a peer: */
-struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table, struct sk_buff *skb);
-struct wireguard_peer *allowedips_lookup_src(struct allowedips *table, struct sk_buff *skb);
+struct wireguard_peer *allowedips_lookup_dst(struct allowedips *table,
+ struct sk_buff *skb);
+struct wireguard_peer *allowedips_lookup_src(struct allowedips *table,
+ struct sk_buff *skb);
#ifdef DEBUG
bool allowedips_selftest(void);
diff --git a/src/cookie.c b/src/cookie.c
index 9268630..7cf0693 100644
--- a/src/cookie.c
+++ b/src/cookie.c
@@ -15,7 +15,8 @@
#include <net/ipv6.h>
#include <crypto/algapi.h>
-void cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg)
+void cookie_checker_init(struct cookie_checker *checker,
+ struct wireguard_device *wg)
{
init_rwsem(&checker->secret_lock);
checker->secret_birthdate = ktime_get_boot_fast_ns();
@@ -27,7 +28,9 @@ enum { COOKIE_KEY_LABEL_LEN = 8 };
static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----";
static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--";
-static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 pubkey[NOISE_PUBLIC_KEY_LEN], const u8 label[COOKIE_KEY_LABEL_LEN])
+static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN],
+ const u8 pubkey[NOISE_PUBLIC_KEY_LEN],
+ const u8 label[COOKIE_KEY_LABEL_LEN])
{
struct blake2s_state blake;
@@ -41,18 +44,25 @@ static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 pubkey[NOIS
void cookie_checker_precompute_device_keys(struct cookie_checker *checker)
{
if (likely(checker->device->static_identity.has_identity)) {
- precompute_key(checker->cookie_encryption_key, checker->device->static_identity.static_public, cookie_key_label);
- precompute_key(checker->message_mac1_key, checker->device->static_identity.static_public, mac1_key_label);
+ precompute_key(checker->cookie_encryption_key,
+ checker->device->static_identity.static_public,
+ cookie_key_label);
+ precompute_key(checker->message_mac1_key,
+ checker->device->static_identity.static_public,
+ mac1_key_label);
} else {
- memset(checker->cookie_encryption_key, 0, NOISE_SYMMETRIC_KEY_LEN);
+ memset(checker->cookie_encryption_key, 0,
+ NOISE_SYMMETRIC_KEY_LEN);
memset(checker->message_mac1_key, 0, NOISE_SYMMETRIC_KEY_LEN);
}
}
void cookie_checker_precompute_peer_keys(struct wireguard_peer *peer)
{
- precompute_key(peer->latest_cookie.cookie_decryption_key, peer->handshake.remote_static, cookie_key_label);
- precompute_key(peer->latest_cookie.message_mac1_key, peer->handshake.remote_static, mac1_key_label);
+ precompute_key(peer->latest_cookie.cookie_decryption_key,
+ peer->handshake.remote_static, cookie_key_label);
+ precompute_key(peer->latest_cookie.message_mac1_key,
+ peer->handshake.remote_static, mac1_key_label);
}
void cookie_init(struct cookie *cookie)
@@ -61,19 +71,24 @@ void cookie_init(struct cookie *cookie)
init_rwsem(&cookie->lock);
}
-static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len, const u8 key[NOISE_SYMMETRIC_KEY_LEN])
+static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len,
+ const u8 key[NOISE_SYMMETRIC_KEY_LEN])
{
- len = len - sizeof(struct message_macs) + offsetof(struct message_macs, mac1);
+ len = len - sizeof(struct message_macs) +
+ offsetof(struct message_macs, mac1);
blake2s(mac1, message, key, COOKIE_LEN, len, NOISE_SYMMETRIC_KEY_LEN);
}
-static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len, const u8 cookie[COOKIE_LEN])
+static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len,
+ const u8 cookie[COOKIE_LEN])
{
- len = len - sizeof(struct message_macs) + offsetof(struct message_macs, mac2);
+ len = len - sizeof(struct message_macs) +
+ offsetof(struct message_macs, mac2);
blake2s(mac2, message, cookie, COOKIE_LEN, len, COOKIE_LEN);
}
-static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, struct cookie_checker *checker)
+static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb,
+ struct cookie_checker *checker)
{
struct blake2s_state state;
@@ -88,24 +103,30 @@ static void make_cookie(u8 cookie[COOKIE_LEN], struct sk_buff *skb, struct cooki
blake2s_init_key(&state, COOKIE_LEN, checker->secret, NOISE_HASH_LEN);
if (skb->protocol == htons(ETH_P_IP))
- blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr, sizeof(struct in_addr));
+ blake2s_update(&state, (u8 *)&ip_hdr(skb)->saddr,
+ sizeof(struct in_addr));
else if (skb->protocol == htons(ETH_P_IPV6))
- blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr, sizeof(struct in6_addr));
+ blake2s_update(&state, (u8 *)&ipv6_hdr(skb)->saddr,
+ sizeof(struct in6_addr));
blake2s_update(&state, (u8 *)&udp_hdr(skb)->source, sizeof(__be16));
blake2s_final(&state, cookie, COOKIE_LEN);
up_read(&checker->secret_lock);
}
-enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie)
+enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker,
+ struct sk_buff *skb,
+ bool check_cookie)
{
+ struct message_macs *macs = (struct message_macs *)
+ (skb->data + skb->len - sizeof(struct message_macs));
+ enum cookie_mac_state ret;
u8 computed_mac[COOKIE_LEN];
u8 cookie[COOKIE_LEN];
- enum cookie_mac_state ret;
- struct message_macs *macs = (struct message_macs *)(skb->data + skb->len - sizeof(struct message_macs));
ret = INVALID_MAC;
- compute_mac1(computed_mac, skb->data, skb->len, checker->message_mac1_key);
+ compute_mac1(computed_mac, skb->data, skb->len,
+ checker->message_mac1_key);
if (crypto_memneq(computed_mac, macs->mac1, COOKIE_LEN))
goto out;
@@ -130,27 +151,36 @@ out:
return ret;
}
-void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer *peer)
+void cookie_add_mac_to_packet(void *message, size_t len,
+ struct wireguard_peer *peer)
{
- struct message_macs *macs = (struct message_macs *)((u8 *)message + len - sizeof(struct message_macs));
+ struct message_macs *macs = (struct message_macs *)
+ ((u8 *)message + len - sizeof(struct message_macs));
down_write(&peer->latest_cookie.lock);
- compute_mac1(macs->mac1, message, len, peer->latest_cookie.message_mac1_key);
+ compute_mac1(macs->mac1, message, len,
+ peer->latest_cookie.message_mac1_key);
memcpy(peer->latest_cookie.last_mac1_sent, macs->mac1, COOKIE_LEN);
peer->latest_cookie.have_sent_mac1 = true;
up_write(&peer->latest_cookie.lock);
down_read(&peer->latest_cookie.lock);
- if (peer->latest_cookie.is_valid && !has_expired(peer->latest_cookie.birthdate, COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY))
- compute_mac2(macs->mac2, message, len, peer->latest_cookie.cookie);
+ if (peer->latest_cookie.is_valid &&
+ !has_expired(peer->latest_cookie.birthdate,
+ COOKIE_SECRET_MAX_AGE - COOKIE_SECRET_LATENCY))
+ compute_mac2(macs->mac2, message, len,
+ peer->latest_cookie.cookie);
else
memset(macs->mac2, 0, COOKIE_LEN);
up_read(&peer->latest_cookie.lock);
}
-void cookie_message_create(struct message_handshake_cookie *dst, struct sk_buff *skb, __le32 index, struct cookie_checker *checker)
+void cookie_message_create(struct message_handshake_cookie *dst,
+ struct sk_buff *skb, __le32 index,
+ struct cookie_checker *checker)
{
- struct message_macs *macs = (struct message_macs *)((u8 *)skb->data + skb->len - sizeof(struct message_macs));
+ struct message_macs *macs = (struct message_macs *)
+ ((u8 *)skb->data + skb->len - sizeof(struct message_macs));
u8 cookie[COOKIE_LEN];
dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE);
@@ -158,16 +188,22 @@ void cookie_message_create(struct message_handshake_cookie *dst, struct sk_buff
get_random_bytes_wait(dst->nonce, COOKIE_NONCE_LEN);
make_cookie(cookie, skb, checker);
- xchacha20poly1305_encrypt(dst->encrypted_cookie, cookie, COOKIE_LEN, macs->mac1, COOKIE_LEN, dst->nonce, checker->cookie_encryption_key);
+ xchacha20poly1305_encrypt(dst->encrypted_cookie, cookie, COOKIE_LEN,
+ macs->mac1, COOKIE_LEN, dst->nonce,
+ checker->cookie_encryption_key);
}
-void cookie_message_consume(struct message_handshake_cookie *src, struct wireguard_device *wg)
+void cookie_message_consume(struct message_handshake_cookie *src,
+ struct wireguard_device *wg)
{
- u8 cookie[COOKIE_LEN];
struct wireguard_peer *peer = NULL;
+ u8 cookie[COOKIE_LEN];
bool ret;
- if (unlikely(!index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE | INDEX_HASHTABLE_KEYPAIR, src->receiver_index, &peer)))
+ if (unlikely(!index_hashtable_lookup(&wg->index_hashtable,
+ INDEX_HASHTABLE_HANDSHAKE |
+ INDEX_HASHTABLE_KEYPAIR,
+ src->receiver_index, &peer)))
return;
down_read(&peer->latest_cookie.lock);
@@ -175,7 +211,10 @@ void cookie_message_consume(struct message_handshake_cookie *src, struct wiregua
up_read(&peer->latest_cookie.lock);
goto out;
}
- ret = xchacha20poly1305_decrypt(cookie, src->encrypted_cookie, sizeof(src->encrypted_cookie), peer->latest_cookie.last_mac1_sent, COOKIE_LEN, src->nonce, peer->latest_cookie.cookie_decryption_key);
+ ret = xchacha20poly1305_decrypt(
+ cookie, src->encrypted_cookie, sizeof(src->encrypted_cookie),
+ peer->latest_cookie.last_mac1_sent, COOKIE_LEN, src->nonce,
+ peer->latest_cookie.cookie_decryption_key);
up_read(&peer->latest_cookie.lock);
if (ret) {
@@ -186,7 +225,8 @@ void cookie_message_consume(struct message_handshake_cookie *src, struct wiregua
peer->latest_cookie.have_sent_mac1 = false;
up_write(&peer->latest_cookie.lock);
} else
- net_dbg_ratelimited("%s: Could not decrypt invalid cookie response\n", wg->dev->name);
+ net_dbg_ratelimited("%s: Could not decrypt invalid cookie response\n",
+ wg->dev->name);
out:
peer_put(peer);
diff --git a/src/cookie.h b/src/cookie.h
index 9f519ef..7802f61 100644
--- a/src/cookie.h
+++ b/src/cookie.h
@@ -38,15 +38,22 @@ enum cookie_mac_state {
VALID_MAC_WITH_COOKIE
};
-void cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg);
+void cookie_checker_init(struct cookie_checker *checker,
+ struct wireguard_device *wg);
void cookie_checker_precompute_device_keys(struct cookie_checker *checker);
void cookie_checker_precompute_peer_keys(struct wireguard_peer *peer);
void cookie_init(struct cookie *cookie);
-enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie);
-void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer *peer);
-
-void cookie_message_create(struct message_handshake_cookie *src, struct sk_buff *skb, __le32 index, struct cookie_checker *checker);
-void cookie_message_consume(struct message_handshake_cookie *src, struct wireguard_device *wg);
+enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker,
+ struct sk_buff *skb,
+ bool check_cookie);
+void cookie_add_mac_to_packet(void *message, size_t len,
+ struct wireguard_peer *peer);
+
+void cookie_message_create(struct message_handshake_cookie *src,
+ struct sk_buff *skb, __le32 index,
+ struct cookie_checker *checker);
+void cookie_message_consume(struct message_handshake_cookie *src,
+ struct wireguard_device *wg);
#endif /* _WG_COOKIE_H */
diff --git a/src/device.c b/src/device.c
index 297125a..3dfa794 100644
--- a/src/device.c
+++ b/src/device.c
@@ -28,19 +28,18 @@ static LIST_HEAD(device_list);
static int open(struct net_device *dev)
{
- int ret;
- struct wireguard_peer *peer;
+ struct in_device *dev_v4 = __in_dev_get_rtnl(dev);
struct wireguard_device *wg = netdev_priv(dev);
#ifndef COMPAT_CANNOT_USE_IN6_DEV_GET
struct inet6_dev *dev_v6 = __in6_dev_get(dev);
#endif
- struct in_device *dev_v4 = __in_dev_get_rtnl(dev);
+ struct wireguard_peer *peer;
+ int ret;
if (dev_v4) {
- /* TODO: at some point we might put this check near the ip_rt_send_redirect
- * call of ip_forward in net/ipv4/ip_forward.c, similar to the current secpath
- * check, rather than turning it off like this. This is just a stop gap solution
- * while we're an out of tree module.
+ /* At some point we might put this check near the ip_rt_send_
+ * redirect call of ip_forward in net/ipv4/ip_forward.c, similar
+ * to the current secpath check.
*/
IN_DEV_CONF_SET(dev_v4, SEND_REDIRECTS, false);
IPV4_DEVCONF_ALL(dev_net(dev), SEND_REDIRECTS) = false;
@@ -58,7 +57,7 @@ static int open(struct net_device *dev)
if (ret < 0)
return ret;
mutex_lock(&wg->device_update_lock);
- list_for_each_entry(peer, &wg->peer_list, peer_list) {
+ list_for_each_entry (peer, &wg->peer_list, peer_list) {
packet_send_staged_packets(peer);
if (peer->persistent_keepalive_interval)
packet_send_keepalive(peer);
@@ -68,7 +67,8 @@ static int open(struct net_device *dev)
}
#if defined(CONFIG_PM_SLEEP) && !defined(CONFIG_ANDROID)
-static int pm_notification(struct notifier_block *nb, unsigned long action, void *data)
+static int pm_notification(struct notifier_block *nb, unsigned long action,
+ void *data)
{
struct wireguard_device *wg;
struct wireguard_peer *peer;
@@ -77,9 +77,9 @@ static int pm_notification(struct notifier_block *nb, unsigned long action, void
return 0;
rtnl_lock();
- list_for_each_entry(wg, &device_list, device_list) {
+ list_for_each_entry (wg, &device_list, device_list) {
mutex_lock(&wg->device_update_lock);
- list_for_each_entry(peer, &wg->peer_list, peer_list) {
+ list_for_each_entry (peer, &wg->peer_list, peer_list) {
noise_handshake_clear(&peer->handshake);
noise_keypairs_clear(&peer->keypairs);
if (peer->timers_enabled)
@@ -100,12 +100,14 @@ static int stop(struct net_device *dev)
struct wireguard_peer *peer;
mutex_lock(&wg->device_update_lock);
- list_for_each_entry(peer, &wg->peer_list, peer_list) {
+ list_for_each_entry (peer, &wg->peer_list, peer_list) {
skb_queue_purge(&peer->staged_packet_queue);
timers_stop(peer);
noise_handshake_clear(&peer->handshake);
noise_keypairs_clear(&peer->keypairs);
- atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns() - (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC);
+ atomic64_set(&peer->last_sent_handshake,
+ ktime_get_boot_fast_ns() -
+ (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC);
}
mutex_unlock(&wg->device_update_lock);
skb_queue_purge(&wg->incoming_handshakes);
@@ -133,16 +135,19 @@ static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev)
if (unlikely(!peer)) {
ret = -ENOKEY;
if (skb->protocol == htons(ETH_P_IP))
- net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI4\n", dev->name, &ip_hdr(skb)->daddr);
+ net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI4\n",
+ dev->name, &ip_hdr(skb)->daddr);
else if (skb->protocol == htons(ETH_P_IPV6))
- net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI6\n", dev->name, &ipv6_hdr(skb)->daddr);
+ net_dbg_ratelimited("%s: No peer has allowed IPs matching %pI6\n",
+ dev->name, &ipv6_hdr(skb)->daddr);
goto err;
}
family = READ_ONCE(peer->endpoint.addr.sa_family);
if (unlikely(family != AF_INET && family != AF_INET6)) {
ret = -EDESTADDRREQ;
- net_dbg_ratelimited("%s: No valid endpoint has been configured or discovered for peer %llu\n", dev->name, peer->internal_id);
+ net_dbg_ratelimited("%s: No valid endpoint has been configured or discovered for peer %llu\n",
+ dev->name, peer->internal_id);
goto err_peer;
}
@@ -180,8 +185,9 @@ static netdev_tx_t xmit(struct sk_buff *skb, struct net_device *dev)
} while ((skb = next) != NULL);
spin_lock_bh(&peer->staged_packet_queue.lock);
- /* If the queue is getting too big, we start removing the oldest packets until it's small again.
- * We do this before adding the new packet, so we don't remove GSO segments that are in excess.
+ /* If the queue is getting too big, we start removing the oldest packets
+ * until it's small again. We do this before adding the new packet, so
+ * we don't remove GSO segments that are in excess.
*/
while (skb_queue_len(&peer->staged_packet_queue) > MAX_STAGED_PACKETS)
dev_kfree_skb(__skb_dequeue(&peer->staged_packet_queue));
@@ -223,7 +229,8 @@ static void destruct(struct net_device *dev)
wg->incoming_port = 0;
socket_reinit(wg, NULL, NULL);
allowedips_free(&wg->peer_allowedips, &wg->device_update_lock);
- peer_remove_all(wg); /* The final references are cleared in the below calls to destroy_workqueue. */
+ /* The final references are cleared in the below calls to destroy_workqueue. */
+ peer_remove_all(wg);
destroy_workqueue(wg->handshake_receive_wq);
destroy_workqueue(wg->handshake_send_wq);
destroy_workqueue(wg->packet_crypt_wq);
@@ -231,7 +238,8 @@ static void destruct(struct net_device *dev)
packet_queue_free(&wg->encrypt_queue, true);
rcu_barrier_bh(); /* Wait for all the peers to be actually freed. */
ratelimiter_uninit();
- memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity));
+ memzero_explicit(&wg->static_identity,
+ sizeof(struct noise_static_identity));
skb_queue_purge(&wg->incoming_handshakes);
free_percpu(dev->tstats);
free_percpu(wg->incoming_handshakes_worker);
@@ -243,14 +251,14 @@ static void destruct(struct net_device *dev)
free_netdev(dev);
}
-static const struct device_type device_type = {
- .name = KBUILD_MODNAME
-};
+static const struct device_type device_type = { .name = KBUILD_MODNAME };
static void setup(struct net_device *dev)
{
struct wireguard_device *wg = netdev_priv(dev);
- enum { WG_NETDEV_FEATURES = NETIF_F_HW_CSUM | NETIF_F_RXCSUM | NETIF_F_SG | NETIF_F_GSO | NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA };
+ enum { WG_NETDEV_FEATURES = NETIF_F_HW_CSUM | NETIF_F_RXCSUM |
+ NETIF_F_SG | NETIF_F_GSO |
+ NETIF_F_GSO_SOFTWARE | NETIF_F_HIGHDMA };
dev->netdev_ops = &netdev_ops;
dev->hard_header_len = 0;
@@ -268,7 +276,9 @@ static void setup(struct net_device *dev)
dev->features |= WG_NETDEV_FEATURES;
dev->hw_features |= WG_NETDEV_FEATURES;
dev->hw_enc_features |= WG_NETDEV_FEATURES;
- dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH - sizeof(struct udphdr) - max(sizeof(struct ipv6hdr), sizeof(struct iphdr));
+ dev->mtu = ETH_DATA_LEN - MESSAGE_MINIMUM_LENGTH -
+ sizeof(struct udphdr) -
+ max(sizeof(struct ipv6hdr), sizeof(struct iphdr));
SET_NETDEV_DEVTYPE(dev, &device_type);
@@ -279,7 +289,9 @@ static void setup(struct net_device *dev)
wg->dev = dev;
}
-static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *tb[], struct nlattr *data[], struct netlink_ext_ack *extack)
+static int newlink(struct net *src_net, struct net_device *dev,
+ struct nlattr *tb[], struct nlattr *data[],
+ struct netlink_ext_ack *extack)
{
int ret = -ENOMEM;
struct wireguard_device *wg = netdev_priv(dev);
@@ -300,26 +312,32 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t
if (!dev->tstats)
goto error_1;
- wg->incoming_handshakes_worker = packet_alloc_percpu_multicore_worker(packet_handshake_receive_worker, wg);
+ wg->incoming_handshakes_worker = packet_alloc_percpu_multicore_worker(
+ packet_handshake_receive_worker, wg);
if (!wg->incoming_handshakes_worker)
goto error_2;
- wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s", WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name);
+ wg->handshake_receive_wq = alloc_workqueue("wg-kex-%s",
+ WQ_CPU_INTENSIVE | WQ_FREEZABLE, 0, dev->name);
if (!wg->handshake_receive_wq)
goto error_3;
- wg->handshake_send_wq = alloc_workqueue("wg-kex-%s", WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name);
+ wg->handshake_send_wq = alloc_workqueue("wg-kex-%s",
+ WQ_UNBOUND | WQ_FREEZABLE, 0, dev->name);
if (!wg->handshake_send_wq)
goto error_4;
- wg->packet_crypt_wq = alloc_workqueue("wg-crypt-%s", WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, dev->name);
+ wg->packet_crypt_wq = alloc_workqueue("wg-crypt-%s",
+ WQ_CPU_INTENSIVE | WQ_MEM_RECLAIM, 0, dev->name);
if (!wg->packet_crypt_wq)
goto error_5;
- if (packet_queue_init(&wg->encrypt_queue, packet_encrypt_worker, true, MAX_QUEUED_PACKETS) < 0)
+ if (packet_queue_init(&wg->encrypt_queue, packet_encrypt_worker, true,
+ MAX_QUEUED_PACKETS) < 0)
goto error_6;
- if (packet_queue_init(&wg->decrypt_queue, packet_decrypt_worker, true, MAX_QUEUED_PACKETS) < 0)
+ if (packet_queue_init(&wg->decrypt_queue, packet_decrypt_worker, true,
+ MAX_QUEUED_PACKETS) < 0)
goto error_7;
ret = ratelimiter_init();
@@ -332,8 +350,8 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t
list_add(&wg->device_list, &device_list);
- /* We wait until the end to assign priv_destructor, so that register_netdevice doesn't
- * call it for us if it fails.
+ /* We wait until the end to assign priv_destructor, so that
+ * register_netdevice doesn't call it for us if it fails.
*/
dev->priv_destructor = destruct;
@@ -367,7 +385,8 @@ static struct rtnl_link_ops link_ops __read_mostly = {
.newlink = newlink,
};
-static int netdevice_notification(struct notifier_block *nb, unsigned long action, void *data)
+static int netdevice_notification(struct notifier_block *nb,
+ unsigned long action, void *data)
{
struct net_device *dev = ((struct netdev_notifier_info *)data)->dev;
struct wireguard_device *wg = netdev_priv(dev);
@@ -380,14 +399,16 @@ static int netdevice_notification(struct notifier_block *nb, unsigned long actio
if (dev_net(dev) == wg->creating_net && wg->have_creating_net_ref) {
put_net(wg->creating_net);
wg->have_creating_net_ref = false;
- } else if (dev_net(dev) != wg->creating_net && !wg->have_creating_net_ref) {
+ } else if (dev_net(dev) != wg->creating_net &&
+ !wg->have_creating_net_ref) {
wg->have_creating_net_ref = true;
get_net(wg->creating_net);
}
return 0;
}
-static struct notifier_block netdevice_notifier = { .notifier_call = netdevice_notification };
+static struct notifier_block netdevice_notifier =
+ { .notifier_call = netdevice_notification };
int __init device_init(void)
{
diff --git a/src/device.h b/src/device.h
index 2a0e2c7..2499782 100644
--- a/src/device.h
+++ b/src/device.h
@@ -42,7 +42,8 @@ struct wireguard_device {
struct sock __rcu *sock4, *sock6;
struct net *creating_net;
struct noise_static_identity static_identity;
- struct workqueue_struct *handshake_receive_wq, *handshake_send_wq, *packet_crypt_wq;
+ struct workqueue_struct *handshake_receive_wq, *handshake_send_wq;
+ struct workqueue_struct *packet_crypt_wq;
struct sk_buff_head incoming_handshakes;
int incoming_handshake_cpu;
struct multicore_worker __percpu *incoming_handshakes_worker;
diff --git a/src/hashtables.c b/src/hashtables.c
index ac6df59..4ba2288 100644
--- a/src/hashtables.c
+++ b/src/hashtables.c
@@ -7,12 +7,16 @@
#include "peer.h"
#include "noise.h"
-static inline struct hlist_head *pubkey_bucket(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
+static inline struct hlist_head *pubkey_bucket(struct pubkey_hashtable *table,
+ const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
{
- /* siphash gives us a secure 64bit number based on a random key. Since the bits are
- * uniformly distributed, we can then mask off to get the bits we need.
+ /* siphash gives us a secure 64bit number based on a random key. Since
+ * the bits are uniformly distributed, we can then mask off to get the
+ * bits we need.
*/
- return &table->hashtable[siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key) & (HASH_SIZE(table->hashtable) - 1)];
+ return &table->hashtable[
+ siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key) &
+ (HASH_SIZE(table->hashtable) - 1)];
}
void pubkey_hashtable_init(struct pubkey_hashtable *table)
@@ -22,14 +26,17 @@ void pubkey_hashtable_init(struct pubkey_hashtable *table)
mutex_init(&table->lock);
}
-void pubkey_hashtable_add(struct pubkey_hashtable *table, struct wireguard_peer *peer)
+void pubkey_hashtable_add(struct pubkey_hashtable *table,
+ struct wireguard_peer *peer)
{
mutex_lock(&table->lock);
- hlist_add_head_rcu(&peer->pubkey_hash, pubkey_bucket(table, peer->handshake.remote_static));
+ hlist_add_head_rcu(&peer->pubkey_hash,
+ pubkey_bucket(table, peer->handshake.remote_static));
mutex_unlock(&table->lock);
}
-void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_peer *peer)
+void pubkey_hashtable_remove(struct pubkey_hashtable *table,
+ struct wireguard_peer *peer)
{
mutex_lock(&table->lock);
hlist_del_init_rcu(&peer->pubkey_hash);
@@ -37,13 +44,17 @@ void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_pe
}
/* Returns a strong reference to a peer */
-struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
+struct wireguard_peer *
+pubkey_hashtable_lookup(struct pubkey_hashtable *table,
+ const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
{
struct wireguard_peer *iter_peer, *peer = NULL;
rcu_read_lock_bh();
- hlist_for_each_entry_rcu_bh(iter_peer, pubkey_bucket(table, pubkey), pubkey_hash) {
- if (!memcmp(pubkey, iter_peer->handshake.remote_static, NOISE_PUBLIC_KEY_LEN)) {
+ hlist_for_each_entry_rcu_bh (iter_peer, pubkey_bucket(table, pubkey),
+ pubkey_hash) {
+ if (!memcmp(pubkey, iter_peer->handshake.remote_static,
+ NOISE_PUBLIC_KEY_LEN)) {
peer = iter_peer;
break;
}
@@ -53,12 +64,14 @@ struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, c
return peer;
}
-static inline struct hlist_head *index_bucket(struct index_hashtable *table, const __le32 index)
+static inline struct hlist_head *index_bucket(struct index_hashtable *table,
+ const __le32 index)
{
- /* Since the indices are random and thus all bits are uniformly distributed,
- * we can find its bucket simply by masking.
+ /* Since the indices are random and thus all bits are uniformly
+ * distributed, we can find its bucket simply by masking.
*/
- return &table->hashtable[(__force u32)index & (HASH_SIZE(table->hashtable) - 1)];
+ return &table->hashtable[(__force u32)index &
+ (HASH_SIZE(table->hashtable) - 1)];
}
void index_hashtable_init(struct index_hashtable *table)
@@ -67,9 +80,10 @@ void index_hashtable_init(struct index_hashtable *table)
spin_lock_init(&table->lock);
}
-/* At the moment, we limit ourselves to 2^20 total peers, which generally might amount to 2^20*3
- * items in this hashtable. The algorithm below works by picking a random number and testing it.
- * We can see that these limits mean we usually succeed pretty quickly:
+/* At the moment, we limit ourselves to 2^20 total peers, which generally might
+ * amount to 2^20*3 items in this hashtable. The algorithm below works by
+ * picking a random number and testing it. We can see that these limits mean we
+ * usually succeed pretty quickly:
*
* >>> def calculation(tries, size):
* ... return (size / 2**32)**(tries - 1) * (1 - (size / 2**32))
@@ -83,13 +97,15 @@ void index_hashtable_init(struct index_hashtable *table)
* >>> calculation(4, 2**20 * 3)
* 3.9261394135792216e-10
*
- * At the moment, we don't do any masking, so this algorithm isn't exactly constant time in
- * either the random guessing or in the hash list lookup. We could require a minimum of 3
- * tries, which would successfully mask the guessing. TODO: this would not, however, help
- * with the growing hash lengths.
+ * At the moment, we don't do any masking, so this algorithm isn't exactly
+ * constant time in either the random guessing or in the hash list lookup. We
+ * could require a minimum of 3 tries, which would successfully mask the
+ * guessing. this would not, however, help with the growing hash lengths, which
+ * is another thing to consider moving forward.
*/
-__le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry)
+__le32 index_hashtable_insert(struct index_hashtable *table,
+ struct index_hashtable_entry *entry)
{
struct index_hashtable_entry *existing_entry;
@@ -102,23 +118,32 @@ __le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashta
search_unused_slot:
/* First we try to find an unused slot, randomly, while unlocked. */
entry->index = (__force __le32)get_random_u32();
- hlist_for_each_entry_rcu_bh(existing_entry, index_bucket(table, entry->index), index_hash) {
+ hlist_for_each_entry_rcu_bh (existing_entry,
+ index_bucket(table, entry->index),
+ index_hash) {
if (existing_entry->index == entry->index)
- goto search_unused_slot; /* If it's already in use, we continue searching. */
+ /* If it's already in use, we continue searching. */
+ goto search_unused_slot;
}
/* Once we've found an unused slot, we lock it, and then double-check
* that nobody else stole it from us.
*/
spin_lock_bh(&table->lock);
- hlist_for_each_entry_rcu_bh(existing_entry, index_bucket(table, entry->index), index_hash) {
+ hlist_for_each_entry_rcu_bh (existing_entry,
+ index_bucket(table, entry->index),
+ index_hash) {
if (existing_entry->index == entry->index) {
spin_unlock_bh(&table->lock);
- goto search_unused_slot; /* If it was stolen, we start over. */
+ /* If it was stolen, we start over. */
+ goto search_unused_slot;
}
}
- /* Otherwise, we know we have it exclusively (since we're locked), so we insert. */
- hlist_add_head_rcu(&entry->index_hash, index_bucket(table, entry->index));
+ /* Otherwise, we know we have it exclusively (since we're locked),
+ * so we insert.
+ */
+ hlist_add_head_rcu(&entry->index_hash,
+ index_bucket(table, entry->index));
spin_unlock_bh(&table->lock);
rcu_read_unlock_bh();
@@ -126,7 +151,9 @@ search_unused_slot:
return entry->index;
}
-bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new)
+bool index_hashtable_replace(struct index_hashtable *table,
+ struct index_hashtable_entry *old,
+ struct index_hashtable_entry *new)
{
if (unlikely(hlist_unhashed(&old->index_hash)))
return false;
@@ -134,17 +161,19 @@ bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtab
new->index = old->index;
hlist_replace_rcu(&old->index_hash, &new->index_hash);
- /* Calling init here NULLs out index_hash, and in fact after this function returns,
- * it's theoretically possible for this to get reinserted elsewhere. That means
- * the RCU lookup below might either terminate early or jump between buckets, in which
- * case the packet simply gets dropped, which isn't terrible.
+ /* Calling init here NULLs out index_hash, and in fact after this
+ * function returns, it's theoretically possible for this to get
+ * reinserted elsewhere. That means the RCU lookup below might either
+ * terminate early or jump between buckets, in which case the packet
+ * simply gets dropped, which isn't terrible.
*/
INIT_HLIST_NODE(&old->index_hash);
spin_unlock_bh(&table->lock);
return true;
}
-void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry)
+void index_hashtable_remove(struct index_hashtable *table,
+ struct index_hashtable_entry *entry)
{
spin_lock_bh(&table->lock);
hlist_del_init_rcu(&entry->index_hash);
@@ -152,12 +181,16 @@ void index_hashtable_remove(struct index_hashtable *table, struct index_hashtabl
}
/* Returns a strong reference to a entry->peer */
-struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer)
+struct index_hashtable_entry *
+index_hashtable_lookup(struct index_hashtable *table,
+ const enum index_hashtable_type type_mask,
+ const __le32 index, struct wireguard_peer **peer)
{
struct index_hashtable_entry *iter_entry, *entry = NULL;
rcu_read_lock_bh();
- hlist_for_each_entry_rcu_bh(iter_entry, index_bucket(table, index), index_hash) {
+ hlist_for_each_entry_rcu_bh (iter_entry, index_bucket(table, index),
+ index_hash) {
if (iter_entry->index == index) {
if (likely(iter_entry->type & type_mask))
entry = iter_entry;
diff --git a/src/hashtables.h b/src/hashtables.h
index f64cd24..62858c5 100644
--- a/src/hashtables.h
+++ b/src/hashtables.h
@@ -22,9 +22,13 @@ struct pubkey_hashtable {
};
void pubkey_hashtable_init(struct pubkey_hashtable *table);
-void pubkey_hashtable_add(struct pubkey_hashtable *table, struct wireguard_peer *peer);
-void pubkey_hashtable_remove(struct pubkey_hashtable *table, struct wireguard_peer *peer);
-struct wireguard_peer *pubkey_hashtable_lookup(struct pubkey_hashtable *table, const u8 pubkey[NOISE_PUBLIC_KEY_LEN]);
+void pubkey_hashtable_add(struct pubkey_hashtable *table,
+ struct wireguard_peer *peer);
+void pubkey_hashtable_remove(struct pubkey_hashtable *table,
+ struct wireguard_peer *peer);
+struct wireguard_peer *
+pubkey_hashtable_lookup(struct pubkey_hashtable *table,
+ const u8 pubkey[NOISE_PUBLIC_KEY_LEN]);
struct index_hashtable {
/* TODO: move to rhashtable */
@@ -44,9 +48,16 @@ struct index_hashtable_entry {
__le32 index;
};
void index_hashtable_init(struct index_hashtable *table);
-__le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry);
-bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new);
-void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry);
-struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer);
+__le32 index_hashtable_insert(struct index_hashtable *table,
+ struct index_hashtable_entry *entry);
+bool index_hashtable_replace(struct index_hashtable *table,
+ struct index_hashtable_entry *old,
+ struct index_hashtable_entry *new);
+void index_hashtable_remove(struct index_hashtable *table,
+ struct index_hashtable_entry *entry);
+struct index_hashtable_entry *
+index_hashtable_lookup(struct index_hashtable *table,
+ const enum index_hashtable_type type_mask,
+ const __le32 index, struct wireguard_peer **peer);
#endif /* _WG_HASHTABLES_H */
diff --git a/src/main.c b/src/main.c
index 94f9a7d..bc28516 100644
--- a/src/main.c
+++ b/src/main.c
@@ -31,7 +31,10 @@ static int __init mod_init(void)
blake2s_fpu_init();
curve25519_fpu_init();
#ifdef DEBUG
- if (!allowedips_selftest() || !packet_counter_selftest() || !curve25519_selftest() || !poly1305_selftest() || !chacha20poly1305_selftest() || !blake2s_selftest() || !ratelimiter_selftest())
+ if (!allowedips_selftest() || !packet_counter_selftest() ||
+ !curve25519_selftest() || !poly1305_selftest() ||
+ !chacha20poly1305_selftest() || !blake2s_selftest() ||
+ !ratelimiter_selftest())
return -ENOTRECOVERABLE;
#endif
noise_init();
diff --git a/src/messages.h b/src/messages.h
index 9425a26..bb19957 100644
--- a/src/messages.h
+++ b/src/messages.h
@@ -109,18 +109,20 @@ struct message_data {
u8 encrypted_data[];
};
-#define message_data_len(plain_len) (noise_encrypted_len(plain_len) + sizeof(struct message_data))
+#define message_data_len(plain_len) \
+ (noise_encrypted_len(plain_len) + sizeof(struct message_data))
enum message_alignments {
MESSAGE_PADDING_MULTIPLE = 16,
MESSAGE_MINIMUM_LENGTH = message_data_len(0)
};
-#define SKB_HEADER_LEN (max(sizeof(struct iphdr), sizeof(struct ipv6hdr)) + sizeof(struct udphdr) + NET_SKB_PAD)
-#define DATA_PACKET_HEAD_ROOM ALIGN(sizeof(struct message_data) + SKB_HEADER_LEN, 4)
+#define SKB_HEADER_LEN \
+ (max(sizeof(struct iphdr), sizeof(struct ipv6hdr)) + \
+ sizeof(struct udphdr) + NET_SKB_PAD)
+#define DATA_PACKET_HEAD_ROOM \
+ ALIGN(sizeof(struct message_data) + SKB_HEADER_LEN, 4)
-enum {
- HANDSHAKE_DSCP = 0x88 /* AF41, plus 00 ECN */
-};
+enum { HANDSHAKE_DSCP = 0x88 /* AF41, plus 00 ECN */ };
#endif /* _WG_MESSAGES_H */
diff --git a/src/netlink.c b/src/netlink.c
index 3147587..5390498 100644
--- a/src/netlink.c
+++ b/src/netlink.c
@@ -45,19 +45,23 @@ static const struct nla_policy allowedip_policy[WGALLOWEDIP_A_MAX + 1] = {
[WGALLOWEDIP_A_CIDR_MASK] = { .type = NLA_U8 }
};
-static struct wireguard_device *lookup_interface(struct nlattr **attrs, struct sk_buff *skb)
+static struct wireguard_device *lookup_interface(struct nlattr **attrs,
+ struct sk_buff *skb)
{
struct net_device *dev = NULL;
if (!attrs[WGDEVICE_A_IFINDEX] == !attrs[WGDEVICE_A_IFNAME])
return ERR_PTR(-EBADR);
if (attrs[WGDEVICE_A_IFINDEX])
- dev = dev_get_by_index(sock_net(skb->sk), nla_get_u32(attrs[WGDEVICE_A_IFINDEX]));
+ dev = dev_get_by_index(sock_net(skb->sk),
+ nla_get_u32(attrs[WGDEVICE_A_IFINDEX]));
else if (attrs[WGDEVICE_A_IFNAME])
- dev = dev_get_by_name(sock_net(skb->sk), nla_data(attrs[WGDEVICE_A_IFNAME]));
+ dev = dev_get_by_name(sock_net(skb->sk),
+ nla_data(attrs[WGDEVICE_A_IFNAME]));
if (!dev)
return ERR_PTR(-ENODEV);
- if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind || strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) {
+ if (!dev->rtnl_link_ops || !dev->rtnl_link_ops->kind ||
+ strcmp(dev->rtnl_link_ops->kind, KBUILD_MODNAME)) {
dev_put(dev);
return ERR_PTR(-EOPNOTSUPP);
}
@@ -71,15 +75,17 @@ struct allowedips_ctx {
static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family)
{
- struct nlattr *allowedip_nest;
struct allowedips_ctx *actx = ctx;
+ struct nlattr *allowedip_nest;
allowedip_nest = nla_nest_start(actx->skb, actx->i++);
if (!allowedip_nest)
return -EMSGSIZE;
- if (nla_put_u8(actx->skb, WGALLOWEDIP_A_CIDR_MASK, cidr) || nla_put_u16(actx->skb, WGALLOWEDIP_A_FAMILY, family) ||
- nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ? sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) {
+ if (nla_put_u8(actx->skb, WGALLOWEDIP_A_CIDR_MASK, cidr) ||
+ nla_put_u16(actx->skb, WGALLOWEDIP_A_FAMILY, family) ||
+ nla_put(actx->skb, WGALLOWEDIP_A_IPADDR, family == AF_INET6 ?
+ sizeof(struct in6_addr) : sizeof(struct in_addr), ip)) {
nla_nest_cancel(actx->skb, allowedip_nest);
return -EMSGSIZE;
}
@@ -88,37 +94,52 @@ static int get_allowedips(void *ctx, const u8 *ip, u8 cidr, int family)
return 0;
}
-static int get_peer(struct wireguard_peer *peer, unsigned int index, struct allowedips_cursor *rt_cursor, struct sk_buff *skb)
+static int get_peer(struct wireguard_peer *peer, unsigned int index,
+ struct allowedips_cursor *rt_cursor, struct sk_buff *skb)
{
- struct allowedips_ctx ctx = { .skb = skb };
struct nlattr *allowedips_nest, *peer_nest = nla_nest_start(skb, index);
+ struct allowedips_ctx ctx = { .skb = skb };
bool fail;
if (!peer_nest)
return -EMSGSIZE;
down_read(&peer->handshake.lock);
- fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN, peer->handshake.remote_static);
+ fail = nla_put(skb, WGPEER_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN,
+ peer->handshake.remote_static);
up_read(&peer->handshake.lock);
if (fail)
goto err;
if (!rt_cursor->seq) {
down_read(&peer->handshake.lock);
- fail = nla_put(skb, WGPEER_A_PRESHARED_KEY, NOISE_SYMMETRIC_KEY_LEN, peer->handshake.preshared_key);
+ fail = nla_put(skb, WGPEER_A_PRESHARED_KEY,
+ NOISE_SYMMETRIC_KEY_LEN,
+ peer->handshake.preshared_key);
up_read(&peer->handshake.lock);
if (fail)
goto err;
- if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME, sizeof(struct timespec), &peer->walltime_last_handshake) || nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL, peer->persistent_keepalive_interval) ||
- nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes, WGPEER_A_UNSPEC) || nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes, WGPEER_A_UNSPEC))
+ if (nla_put(skb, WGPEER_A_LAST_HANDSHAKE_TIME,
+ sizeof(struct timespec),
+ &peer->walltime_last_handshake) ||
+ nla_put_u16(skb, WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL,
+ peer->persistent_keepalive_interval) ||
+ nla_put_u64_64bit(skb, WGPEER_A_TX_BYTES, peer->tx_bytes,
+ WGPEER_A_UNSPEC) ||
+ nla_put_u64_64bit(skb, WGPEER_A_RX_BYTES, peer->rx_bytes,
+ WGPEER_A_UNSPEC))
goto err;
read_lock_bh(&peer->endpoint_lock);
if (peer->endpoint.addr.sa_family == AF_INET)
- fail = nla_put(skb, WGPEER_A_ENDPOINT, sizeof(struct sockaddr_in), &peer->endpoint.addr4);
+ fail = nla_put(skb, WGPEER_A_ENDPOINT,
+ sizeof(struct sockaddr_in),
+ &peer->endpoint.addr4);
else if (peer->endpoint.addr.sa_family == AF_INET6)
- fail = nla_put(skb, WGPEER_A_ENDPOINT, sizeof(struct sockaddr_in6), &peer->endpoint.addr6);
+ fail = nla_put(skb, WGPEER_A_ENDPOINT,
+ sizeof(struct sockaddr_in6),
+ &peer->endpoint.addr6);
read_unlock_bh(&peer->endpoint_lock);
if (fail)
goto err;
@@ -127,7 +148,9 @@ static int get_peer(struct wireguard_peer *peer, unsigned int index, struct allo
allowedips_nest = nla_nest_start(skb, WGPEER_A_ALLOWEDIPS);
if (!allowedips_nest)
goto err;
- if (allowedips_walk_by_peer(&peer->device->peer_allowedips, rt_cursor, peer, get_allowedips, &ctx, &peer->device->device_update_lock)) {
+ if (allowedips_walk_by_peer(&peer->device->peer_allowedips, rt_cursor,
+ peer, get_allowedips, &ctx,
+ &peer->device->device_update_lock)) {
nla_nest_end(skb, allowedips_nest);
nla_nest_end(skb, peer_nest);
return -EMSGSIZE;
@@ -143,13 +166,15 @@ err:
static int get_device_start(struct netlink_callback *cb)
{
- struct wireguard_device *wg;
struct nlattr **attrs = genl_family_attrbuf(&genl_family);
- int ret = nlmsg_parse(cb->nlh, GENL_HDRLEN + genl_family.hdrsize, attrs, genl_family.maxattr, device_policy, NULL);
+ int ret = nlmsg_parse(cb->nlh, GENL_HDRLEN + genl_family.hdrsize, attrs,
+ genl_family.maxattr, device_policy, NULL);
+ struct wireguard_device *wg;
if (ret < 0)
return ret;
- cb->args[2] = (long)kzalloc(sizeof(struct allowedips_cursor), GFP_KERNEL);
+ cb->args[2] =
+ (long)kzalloc(sizeof(struct allowedips_cursor), GFP_KERNEL);
if (!cb->args[2])
return -ENOMEM;
wg = lookup_interface(attrs, cb->skb);
@@ -164,33 +189,46 @@ static int get_device_start(struct netlink_callback *cb)
static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
{
- struct wireguard_device *wg = (struct wireguard_device *)cb->args[0];
struct wireguard_peer *peer, *next_peer_cursor, *last_peer_cursor;
- struct allowedips_cursor *rt_cursor = (struct allowedips_cursor *)cb->args[2];
+ struct allowedips_cursor *rt_cursor;
+ struct wireguard_device *wg;
unsigned int peer_idx = 0;
struct nlattr *peers_nest;
bool done = true;
void *hdr;
int ret = -EMSGSIZE;
- next_peer_cursor = last_peer_cursor = (struct wireguard_peer *)cb->args[1];
+ wg = (struct wireguard_device *)cb->args[0];
+ next_peer_cursor = (struct wireguard_peer *)cb->args[1];
+ last_peer_cursor = (struct wireguard_peer *)cb->args[1];
+ rt_cursor = (struct allowedips_cursor *)cb->args[2];
rtnl_lock();
mutex_lock(&wg->device_update_lock);
cb->seq = wg->device_update_gen;
- hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE);
+ hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
+ &genl_family, NLM_F_MULTI, WG_CMD_GET_DEVICE);
if (!hdr)
goto out;
genl_dump_check_consistent(cb, hdr);
if (!last_peer_cursor) {
- if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT, wg->incoming_port) || nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) || nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) || nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name))
+ if (nla_put_u16(skb, WGDEVICE_A_LISTEN_PORT,
+ wg->incoming_port) ||
+ nla_put_u32(skb, WGDEVICE_A_FWMARK, wg->fwmark) ||
+ nla_put_u32(skb, WGDEVICE_A_IFINDEX, wg->dev->ifindex) ||
+ nla_put_string(skb, WGDEVICE_A_IFNAME, wg->dev->name))
goto out;
down_read(&wg->static_identity.lock);
if (wg->static_identity.has_identity) {
- if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY, NOISE_PUBLIC_KEY_LEN, wg->static_identity.static_private) || nla_put(skb, WGDEVICE_A_PUBLIC_KEY, NOISE_PUBLIC_KEY_LEN, wg->static_identity.static_public)) {
+ if (nla_put(skb, WGDEVICE_A_PRIVATE_KEY,
+ NOISE_PUBLIC_KEY_LEN,
+ wg->static_identity.static_private) ||
+ nla_put(skb, WGDEVICE_A_PUBLIC_KEY,
+ NOISE_PUBLIC_KEY_LEN,
+ wg->static_identity.static_public)) {
up_read(&wg->static_identity.lock);
goto out;
}
@@ -202,17 +240,19 @@ static int get_device_dump(struct sk_buff *skb, struct netlink_callback *cb)
if (!peers_nest)
goto out;
ret = 0;
- /* If the last cursor was removed via list_del_init in peer_remove, then we just treat
- * this the same as there being no more peers left. The reason is that seq_nr should
- * indicate to userspace that this isn't a coherent dump anyway, so they'll try again.
+ /* If the last cursor was removed via list_del_init in peer_remove, then
+ * we just treat this the same as there being no more peers left. The
+ * reason is that seq_nr should indicate to userspace that this isn't a
+ * coherent dump anyway, so they'll try again.
*/
- if (list_empty(&wg->peer_list) || (last_peer_cursor && list_empty(&last_peer_cursor->peer_list))) {
+ if (list_empty(&wg->peer_list) ||
+ (last_peer_cursor && list_empty(&last_peer_cursor->peer_list))) {
nla_nest_cancel(skb, peers_nest);
goto out;
}
lockdep_assert_held(&wg->device_update_lock);
peer = list_prepare_entry(last_peer_cursor, &wg->peer_list, peer_list);
- list_for_each_entry_continue(peer, &wg->peer_list, peer_list) {
+ list_for_each_entry_continue (peer, &wg->peer_list, peer_list) {
if (get_peer(peer, peer_idx++, rt_cursor, skb)) {
done = false;
break;
@@ -250,7 +290,8 @@ static int get_device_done(struct netlink_callback *cb)
{
struct wireguard_device *wg = (struct wireguard_device *)cb->args[0];
struct wireguard_peer *peer = (struct wireguard_peer *)cb->args[1];
- struct allowedips_cursor *rt_cursor = (struct allowedips_cursor *)cb->args[2];
+ struct allowedips_cursor *rt_cursor =
+ (struct allowedips_cursor *)cb->args[2];
if (wg)
dev_put(wg->dev);
@@ -265,7 +306,7 @@ static int set_port(struct wireguard_device *wg, u16 port)
if (wg->incoming_port == port)
return 0;
- list_for_each_entry(peer, &wg->peer_list, peer_list)
+ list_for_each_entry (peer, &wg->peer_list, peer_list)
socket_clear_peer_endpoint_src(peer);
if (!netif_running(wg->dev)) {
wg->incoming_port = port;
@@ -280,15 +321,25 @@ static int set_allowedip(struct wireguard_peer *peer, struct nlattr **attrs)
u16 family;
u8 cidr;
- if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] || !attrs[WGALLOWEDIP_A_CIDR_MASK])
+ if (!attrs[WGALLOWEDIP_A_FAMILY] || !attrs[WGALLOWEDIP_A_IPADDR] ||
+ !attrs[WGALLOWEDIP_A_CIDR_MASK])
return ret;
family = nla_get_u16(attrs[WGALLOWEDIP_A_FAMILY]);
cidr = nla_get_u8(attrs[WGALLOWEDIP_A_CIDR_MASK]);
- if (family == AF_INET && cidr <= 32 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr))
- ret = allowedips_insert_v4(&peer->device->peer_allowedips, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, &peer->device->device_update_lock);
- else if (family == AF_INET6 && cidr <= 128 && nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in6_addr))
- ret = allowedips_insert_v6(&peer->device->peer_allowedips, nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer, &peer->device->device_update_lock);
+ if (family == AF_INET && cidr <= 32 &&
+ nla_len(attrs[WGALLOWEDIP_A_IPADDR]) == sizeof(struct in_addr))
+ ret = allowedips_insert_v4(
+ &peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
+ &peer->device->device_update_lock);
+ else if (family == AF_INET6 && cidr <= 128 &&
+ nla_len(attrs[WGALLOWEDIP_A_IPADDR]) ==
+ sizeof(struct in6_addr))
+ ret = allowedips_insert_v6(
+ &peer->device->peer_allowedips,
+ nla_data(attrs[WGALLOWEDIP_A_IPADDR]), cidr, peer,
+ &peer->device->device_update_lock);
return ret;
}
@@ -301,25 +352,33 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs)
u8 *public_key = NULL, *preshared_key = NULL;
ret = -EINVAL;
- if (attrs[WGPEER_A_PUBLIC_KEY] && nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN)
+ if (attrs[WGPEER_A_PUBLIC_KEY] &&
+ nla_len(attrs[WGPEER_A_PUBLIC_KEY]) == NOISE_PUBLIC_KEY_LEN)
public_key = nla_data(attrs[WGPEER_A_PUBLIC_KEY]);
else
goto out;
- if (attrs[WGPEER_A_PRESHARED_KEY] && nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN)
+ if (attrs[WGPEER_A_PRESHARED_KEY] &&
+ nla_len(attrs[WGPEER_A_PRESHARED_KEY]) == NOISE_SYMMETRIC_KEY_LEN)
preshared_key = nla_data(attrs[WGPEER_A_PRESHARED_KEY]);
if (attrs[WGPEER_A_FLAGS])
flags = nla_get_u32(attrs[WGPEER_A_FLAGS]);
- peer = pubkey_hashtable_lookup(&wg->peer_hashtable, nla_data(attrs[WGPEER_A_PUBLIC_KEY]));
+ peer = pubkey_hashtable_lookup(&wg->peer_hashtable,
+ nla_data(attrs[WGPEER_A_PUBLIC_KEY]));
if (!peer) { /* Peer doesn't exist yet. Add a new one. */
ret = -ENODEV;
if (flags & WGPEER_F_REMOVE_ME)
goto out; /* Tried to remove a non-existing peer. */
down_read(&wg->static_identity.lock);
- if (wg->static_identity.has_identity && !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]), wg->static_identity.static_public, NOISE_PUBLIC_KEY_LEN)) {
- /* We silently ignore peers that have the same public key as the device. The reason we do it silently
- * is that we'd like for people to be able to reuse the same set of API calls across peers.
+ if (wg->static_identity.has_identity &&
+ !memcmp(nla_data(attrs[WGPEER_A_PUBLIC_KEY]),
+ wg->static_identity.static_public,
+ NOISE_PUBLIC_KEY_LEN)) {
+ /* We silently ignore peers that have the same public
+ * key as the device. The reason we do it silently is
+ * that we'd like for people to be able to reuse the
+ * same set of API calls across peers.
*/
up_read(&wg->static_identity.lock);
ret = 0;
@@ -331,7 +390,9 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs)
peer = peer_create(wg, public_key, preshared_key);
if (!peer)
goto out;
- /* Take additional reference, as though we've just been looked up. */
+ /* Take additional reference, as though we've just been
+ * looked up.
+ */
peer_get(peer);
}
@@ -343,7 +404,8 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs)
if (preshared_key) {
down_write(&peer->handshake.lock);
- memcpy(&peer->handshake.preshared_key, preshared_key, NOISE_SYMMETRIC_KEY_LEN);
+ memcpy(&peer->handshake.preshared_key, preshared_key,
+ NOISE_SYMMETRIC_KEY_LEN);
up_write(&peer->handshake.lock);
}
@@ -351,7 +413,10 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs)
struct sockaddr *addr = nla_data(attrs[WGPEER_A_ENDPOINT]);
size_t len = nla_len(attrs[WGPEER_A_ENDPOINT]);
- if ((len == sizeof(struct sockaddr_in) && addr->sa_family == AF_INET) || (len == sizeof(struct sockaddr_in6) && addr->sa_family == AF_INET6)) {
+ if ((len == sizeof(struct sockaddr_in) &&
+ addr->sa_family == AF_INET) ||
+ (len == sizeof(struct sockaddr_in6) &&
+ addr->sa_family == AF_INET6)) {
struct endpoint endpoint = { { { 0 } } };
memcpy(&endpoint.addr, addr, len);
@@ -360,14 +425,16 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs)
}
if (flags & WGPEER_F_REPLACE_ALLOWEDIPS)
- allowedips_remove_by_peer(&wg->peer_allowedips, peer, &wg->device_update_lock);
+ allowedips_remove_by_peer(&wg->peer_allowedips, peer,
+ &wg->device_update_lock);
if (attrs[WGPEER_A_ALLOWEDIPS]) {
- int rem;
struct nlattr *attr, *allowedip[WGALLOWEDIP_A_MAX + 1];
+ int rem;
- nla_for_each_nested(attr, attrs[WGPEER_A_ALLOWEDIPS], rem) {
- ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX, attr, allowedip_policy, NULL);
+ nla_for_each_nested (attr, attrs[WGPEER_A_ALLOWEDIPS], rem) {
+ ret = nla_parse_nested(allowedip, WGALLOWEDIP_A_MAX,
+ attr, allowedip_policy, NULL);
if (ret < 0)
goto out;
ret = set_allowedip(peer, allowedip);
@@ -377,8 +444,12 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs)
}
if (attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]) {
- const u16 persistent_keepalive_interval = nla_get_u16(attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]);
- const bool send_keepalive = !peer->persistent_keepalive_interval && persistent_keepalive_interval && netif_running(wg->dev);
+ const u16 persistent_keepalive_interval = nla_get_u16(
+ attrs[WGPEER_A_PERSISTENT_KEEPALIVE_INTERVAL]);
+ const bool send_keepalive =
+ !peer->persistent_keepalive_interval &&
+ persistent_keepalive_interval &&
+ netif_running(wg->dev);
peer->persistent_keepalive_interval = persistent_keepalive_interval;
if (send_keepalive)
@@ -391,7 +462,8 @@ static int set_peer(struct wireguard_device *wg, struct nlattr **attrs)
out:
peer_put(peer);
if (attrs[WGPEER_A_PRESHARED_KEY])
- memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]), nla_len(attrs[WGPEER_A_PRESHARED_KEY]));
+ memzero_explicit(nla_data(attrs[WGPEER_A_PRESHARED_KEY]),
+ nla_len(attrs[WGPEER_A_PRESHARED_KEY]));
return ret;
}
@@ -413,26 +485,35 @@ static int set_device(struct sk_buff *skb, struct genl_info *info)
struct wireguard_peer *peer;
wg->fwmark = nla_get_u32(info->attrs[WGDEVICE_A_FWMARK]);
- list_for_each_entry(peer, &wg->peer_list, peer_list)
+ list_for_each_entry (peer, &wg->peer_list, peer_list)
socket_clear_peer_endpoint_src(peer);
}
if (info->attrs[WGDEVICE_A_LISTEN_PORT]) {
- ret = set_port(wg, nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]));
+ ret = set_port(
+ wg, nla_get_u16(info->attrs[WGDEVICE_A_LISTEN_PORT]));
if (ret)
goto out;
}
- if (info->attrs[WGDEVICE_A_FLAGS] && nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]) & WGDEVICE_F_REPLACE_PEERS)
+ if (info->attrs[WGDEVICE_A_FLAGS] &&
+ nla_get_u32(info->attrs[WGDEVICE_A_FLAGS]) &
+ WGDEVICE_F_REPLACE_PEERS)
peer_remove_all(wg);
- if (info->attrs[WGDEVICE_A_PRIVATE_KEY] && nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) == NOISE_PUBLIC_KEY_LEN) {
+ if (info->attrs[WGDEVICE_A_PRIVATE_KEY] &&
+ nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]) ==
+ NOISE_PUBLIC_KEY_LEN) {
+ u8 *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]);
+ u8 public_key[NOISE_PUBLIC_KEY_LEN];
struct wireguard_peer *peer, *temp;
- u8 public_key[NOISE_PUBLIC_KEY_LEN], *private_key = nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]);
- /* We remove before setting, to prevent race, which means doing two 25519-genpub ops. */
+ /* We remove before setting, to prevent race, which means doing
+ * two 25519-genpub ops.
+ */
if (curve25519_generate_public(public_key, private_key)) {
- peer = pubkey_hashtable_lookup(&wg->peer_hashtable, public_key);
+ peer = pubkey_hashtable_lookup(&wg->peer_hashtable,
+ public_key);
if (peer) {
peer_put(peer);
peer_remove(peer);
@@ -440,8 +521,10 @@ static int set_device(struct sk_buff *skb, struct genl_info *info)
}
down_write(&wg->static_identity.lock);
- noise_set_static_identity_private_key(&wg->static_identity, private_key);
- list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list) {
+ noise_set_static_identity_private_key(&wg->static_identity,
+ private_key);
+ list_for_each_entry_safe (peer, temp, &wg->peer_list,
+ peer_list) {
if (!noise_precompute_static_static(peer))
peer_remove(peer);
}
@@ -453,8 +536,9 @@ static int set_device(struct sk_buff *skb, struct genl_info *info)
int rem;
struct nlattr *attr, *peer[WGPEER_A_MAX + 1];
- nla_for_each_nested(attr, info->attrs[WGDEVICE_A_PEERS], rem) {
- ret = nla_parse_nested(peer, WGPEER_A_MAX, attr, peer_policy, NULL);
+ nla_for_each_nested (attr, info->attrs[WGDEVICE_A_PEERS], rem) {
+ ret = nla_parse_nested(peer, WGPEER_A_MAX, attr,
+ peer_policy, NULL);
if (ret < 0)
goto out;
ret = set_peer(wg, peer);
@@ -470,7 +554,8 @@ out:
dev_put(wg->dev);
out_nodev:
if (info->attrs[WGDEVICE_A_PRIVATE_KEY])
- memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]), nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]));
+ memzero_explicit(nla_data(info->attrs[WGDEVICE_A_PRIVATE_KEY]),
+ nla_len(info->attrs[WGDEVICE_A_PRIVATE_KEY]));
return ret;
}
diff --git a/src/noise.c b/src/noise.c
index 0f6e51b..70b53a6 100644
--- a/src/noise.c
+++ b/src/noise.c
@@ -35,7 +35,8 @@ void __init noise_init(void)
{
struct blake2s_state blake;
- blake2s(handshake_init_chaining_key, handshake_name, NULL, NOISE_HASH_LEN, sizeof(handshake_name), 0);
+ blake2s(handshake_init_chaining_key, handshake_name, NULL,
+ NOISE_HASH_LEN, sizeof(handshake_name), 0);
blake2s_init(&blake, NOISE_HASH_LEN);
blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN);
blake2s_update(&blake, identifier_name, sizeof(identifier_name));
@@ -46,16 +47,25 @@ void __init noise_init(void)
bool noise_precompute_static_static(struct wireguard_peer *peer)
{
bool ret = true;
+
down_write(&peer->handshake.lock);
if (peer->handshake.static_identity->has_identity)
- ret = curve25519(peer->handshake.precomputed_static_static, peer->handshake.static_identity->static_private, peer->handshake.remote_static);
+ ret = curve25519(
+ peer->handshake.precomputed_static_static,
+ peer->handshake.static_identity->static_private,
+ peer->handshake.remote_static);
else
- memset(peer->handshake.precomputed_static_static, 0, NOISE_PUBLIC_KEY_LEN);
+ memset(peer->handshake.precomputed_static_static, 0,
+ NOISE_PUBLIC_KEY_LEN);
up_write(&peer->handshake.lock);
return ret;
}
-bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer)
+bool noise_handshake_init(struct noise_handshake *handshake,
+ struct noise_static_identity *static_identity,
+ const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
+ const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
+ struct wireguard_peer *peer)
{
memset(handshake, 0, sizeof(struct noise_handshake));
init_rwsem(&handshake->lock);
@@ -63,7 +73,8 @@ bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static
handshake->entry.peer = peer;
memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
if (peer_preshared_key)
- memcpy(handshake->preshared_key, peer_preshared_key, NOISE_SYMMETRIC_KEY_LEN);
+ memcpy(handshake->preshared_key, peer_preshared_key,
+ NOISE_SYMMETRIC_KEY_LEN);
handshake->static_identity = static_identity;
handshake->state = HANDSHAKE_ZEROED;
return noise_precompute_static_static(peer);
@@ -81,16 +92,19 @@ static void handshake_zero(struct noise_handshake *handshake)
void noise_handshake_clear(struct noise_handshake *handshake)
{
- index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
+ index_hashtable_remove(&handshake->entry.peer->device->index_hashtable,
+ &handshake->entry);
down_write(&handshake->lock);
handshake_zero(handshake);
up_write(&handshake->lock);
- index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
+ index_hashtable_remove(&handshake->entry.peer->device->index_hashtable,
+ &handshake->entry);
}
static struct noise_keypair *keypair_create(struct wireguard_peer *peer)
{
- struct noise_keypair *keypair = kzalloc(sizeof(struct noise_keypair), GFP_KERNEL);
+ struct noise_keypair *keypair =
+ kzalloc(sizeof(struct noise_keypair), GFP_KERNEL);
if (unlikely(!keypair))
return NULL;
@@ -108,9 +122,14 @@ static void keypair_free_rcu(struct rcu_head *rcu)
static void keypair_free_kref(struct kref *kref)
{
- struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount);
- net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id);
- index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry);
+ struct noise_keypair *keypair =
+ container_of(kref, struct noise_keypair, refcount);
+ net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
+ keypair->entry.peer->device->dev->name,
+ keypair->internal_id,
+ keypair->entry.peer->internal_id);
+ index_hashtable_remove(&keypair->entry.peer->device->index_hashtable,
+ &keypair->entry);
call_rcu_bh(&keypair->rcu, keypair_free_rcu);
}
@@ -119,13 +138,16 @@ void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
if (unlikely(!keypair))
return;
if (unlikely(unreference_now))
- index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry);
+ index_hashtable_remove(
+ &keypair->entry.peer->device->index_hashtable,
+ &keypair->entry);
kref_put(&keypair->refcount, keypair_free_kref);
}
struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair)
{
- RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Taking noise keypair reference without holding the RCU BH read lock");
+ RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
+ "Taking noise keypair reference without holding the RCU BH read lock");
if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
return NULL;
return keypair;
@@ -136,55 +158,67 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs)
struct noise_keypair *old;
spin_lock_bh(&keypairs->keypair_update_lock);
- old = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
+ old = rcu_dereference_protected(keypairs->previous_keypair,
+ lockdep_is_held(&keypairs->keypair_update_lock));
RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
noise_keypair_put(old, true);
- old = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
+ old = rcu_dereference_protected(keypairs->next_keypair,
+ lockdep_is_held(&keypairs->keypair_update_lock));
RCU_INIT_POINTER(keypairs->next_keypair, NULL);
noise_keypair_put(old, true);
- old = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
+ old = rcu_dereference_protected(keypairs->current_keypair,
+ lockdep_is_held(&keypairs->keypair_update_lock));
RCU_INIT_POINTER(keypairs->current_keypair, NULL);
noise_keypair_put(old, true);
spin_unlock_bh(&keypairs->keypair_update_lock);
}
-static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypair *new_keypair)
+static void add_new_keypair(struct noise_keypairs *keypairs,
+ struct noise_keypair *new_keypair)
{
struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
spin_lock_bh(&keypairs->keypair_update_lock);
- previous_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
- next_keypair = rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
- current_keypair = rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
+ previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
+ lockdep_is_held(&keypairs->keypair_update_lock));
+ next_keypair = rcu_dereference_protected(keypairs->next_keypair,
+ lockdep_is_held(&keypairs->keypair_update_lock));
+ current_keypair = rcu_dereference_protected(keypairs->current_keypair,
+ lockdep_is_held(&keypairs->keypair_update_lock));
if (new_keypair->i_am_the_initiator) {
- /* If we're the initiator, it means we've sent a handshake, and received
- * a confirmation response, which means this new keypair can now be used.
+ /* If we're the initiator, it means we've sent a handshake, and
+ * received a confirmation response, which means this new
+ * keypair can now be used.
*/
if (next_keypair) {
- /* If there already was a next keypair pending, we demote it to be
- * the previous keypair, and free the existing current.
- * TODO: note that this means KCI can result in this transition. It
- * would perhaps be more sound to always just get rid of the unused
- * next keypair instead of putting it in the previous slot, but this
- * might be a bit less robust. Something to think about and decide on.
+ /* If there already was a next keypair pending, we
+ * demote it to be the previous keypair, and free the
+ * existing current. Note that this means KCI can result
+ * in this transition. It would perhaps be more sound to
+ * always just get rid of the unused next keypair
+ * instead of putting it in the previous slot, but this
+ * might be a bit less robust. Something to think about
+ * for the future.
*/
RCU_INIT_POINTER(keypairs->next_keypair, NULL);
- rcu_assign_pointer(keypairs->previous_keypair, next_keypair);
+ rcu_assign_pointer(keypairs->previous_keypair,
+ next_keypair);
noise_keypair_put(current_keypair, true);
- } else /* If there wasn't an existing next keypair, we replace the
- * previous with the current one.
+ } else /* If there wasn't an existing next keypair, we replace
+ * the previous with the current one.
*/
- rcu_assign_pointer(keypairs->previous_keypair, current_keypair);
- /* At this point we can get rid of the old previous keypair, and set up
- * the new keypair.
+ rcu_assign_pointer(keypairs->previous_keypair,
+ current_keypair);
+ /* At this point we can get rid of the old previous keypair, and
+ * set up the new keypair.
*/
noise_keypair_put(previous_keypair, true);
rcu_assign_pointer(keypairs->current_keypair, new_keypair);
} else {
- /* If we're the responder, it means we can't use the new keypair until
- * we receive confirmation via the first data packet, so we get rid of
- * the existing previous one, the possibly existing next one, and slide
- * in the new next one.
+ /* If we're the responder, it means we can't use the new keypair
+ * until we receive confirmation via the first data packet, so
+ * we get rid of the existing previous one, the possibly
+ * existing next one, and slide in the new next one.
*/
rcu_assign_pointer(keypairs->next_keypair, new_keypair);
noise_keypair_put(next_keypair, true);
@@ -194,19 +228,25 @@ static void add_new_keypair(struct noise_keypairs *keypairs, struct noise_keypai
spin_unlock_bh(&keypairs->keypair_update_lock);
}
-bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair)
+bool noise_received_with_keypair(struct noise_keypairs *keypairs,
+ struct noise_keypair *received_keypair)
{
- bool key_is_new;
struct noise_keypair *old_keypair;
+ bool key_is_new;
/* We first check without taking the spinlock. */
- key_is_new = received_keypair == rcu_access_pointer(keypairs->next_keypair);
+ key_is_new = received_keypair ==
+ rcu_access_pointer(keypairs->next_keypair);
if (likely(!key_is_new))
return false;
spin_lock_bh(&keypairs->keypair_update_lock);
- /* After locking, we double check that things didn't change from beneath us. */
- if (unlikely(received_keypair != rcu_dereference_protected(keypairs->next_keypair, lockdep_is_held(&keypairs->keypair_update_lock)))) {
+ /* After locking, we double check that things didn't change from
+ * beneath us.
+ */
+ if (unlikely(received_keypair !=
+ rcu_dereference_protected(keypairs->next_keypair,
+ lockdep_is_held(&keypairs->keypair_update_lock)))) {
spin_unlock_bh(&keypairs->keypair_update_lock);
return false;
}
@@ -215,8 +255,11 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k
* into the current, the current into the previous, and get rid of
* the old previous.
*/
- old_keypair = rcu_dereference_protected(keypairs->previous_keypair, lockdep_is_held(&keypairs->keypair_update_lock));
- rcu_assign_pointer(keypairs->previous_keypair, rcu_dereference_protected(keypairs->current_keypair, lockdep_is_held(&keypairs->keypair_update_lock)));
+ old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
+ lockdep_is_held(&keypairs->keypair_update_lock));
+ rcu_assign_pointer(keypairs->previous_keypair,
+ rcu_dereference_protected(keypairs->current_keypair,
+ lockdep_is_held(&keypairs->keypair_update_lock)));
noise_keypair_put(old_keypair, true);
rcu_assign_pointer(keypairs->current_keypair, received_keypair);
RCU_INIT_POINTER(keypairs->next_keypair, NULL);
@@ -226,34 +269,46 @@ bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_k
}
/* Must hold static_identity->lock */
-void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN])
+void noise_set_static_identity_private_key(
+ struct noise_static_identity *static_identity,
+ const u8 private_key[NOISE_PUBLIC_KEY_LEN])
{
- memcpy(static_identity->static_private, private_key, NOISE_PUBLIC_KEY_LEN);
- static_identity->has_identity = curve25519_generate_public(static_identity->static_public, private_key);
+ memcpy(static_identity->static_private, private_key,
+ NOISE_PUBLIC_KEY_LEN);
+ static_identity->has_identity = curve25519_generate_public(
+ static_identity->static_public, private_key);
}
/* This is Hugo Krawczyk's HKDF:
* - https://eprint.iacr.org/2010/264.pdf
* - https://tools.ietf.org/html/rfc5869
*/
-static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, size_t first_len, size_t second_len, size_t third_len, size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
+static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
+ size_t first_len, size_t second_len, size_t third_len,
+ size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
{
- u8 secret[BLAKE2S_OUTBYTES];
u8 output[BLAKE2S_OUTBYTES + 1];
+ u8 secret[BLAKE2S_OUTBYTES];
#ifdef DEBUG
- BUG_ON(first_len > BLAKE2S_OUTBYTES || second_len > BLAKE2S_OUTBYTES || third_len > BLAKE2S_OUTBYTES || ((second_len || second_dst || third_len || third_dst) && (!first_len || !first_dst)) || ((third_len || third_dst) && (!second_len || !second_dst)));
+ BUG_ON(first_len > BLAKE2S_OUTBYTES || second_len > BLAKE2S_OUTBYTES ||
+ third_len > BLAKE2S_OUTBYTES ||
+ ((second_len || second_dst || third_len || third_dst) &&
+ (!first_len || !first_dst)) ||
+ ((third_len || third_dst) && (!second_len || !second_dst)));
#endif
/* Extract entropy from data into secret */
- blake2s_hmac(secret, data, chaining_key, BLAKE2S_OUTBYTES, data_len, NOISE_HASH_LEN);
+ blake2s_hmac(secret, data, chaining_key, BLAKE2S_OUTBYTES, data_len,
+ NOISE_HASH_LEN);
if (!first_dst || !first_len)
goto out;
/* Expand first key: key = secret, data = 0x1 */
output[0] = 1;
- blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, 1, BLAKE2S_OUTBYTES);
+ blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, 1,
+ BLAKE2S_OUTBYTES);
memcpy(first_dst, output, first_len);
if (!second_dst || !second_len)
@@ -261,7 +316,8 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, si
/* Expand second key: key = secret, data = first-key || 0x2 */
output[BLAKE2S_OUTBYTES] = 2;
- blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES);
+ blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES,
+ BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES);
memcpy(second_dst, output, second_len);
if (!third_dst || !third_len)
@@ -269,7 +325,8 @@ static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, si
/* Expand third key: key = secret, data = second-key || 0x3 */
output[BLAKE2S_OUTBYTES] = 3;
- blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES);
+ blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES,
+ BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES);
memcpy(third_dst, output, third_len);
out:
@@ -282,25 +339,34 @@ static void symmetric_key_init(struct noise_symmetric_key *key)
{
spin_lock_init(&key->counter.receive.lock);
atomic64_set(&key->counter.counter, 0);
- memset(key->counter.receive.backtrack, 0, sizeof(key->counter.receive.backtrack));
+ memset(key->counter.receive.backtrack, 0,
+ sizeof(key->counter.receive.backtrack));
key->birthdate = ktime_get_boot_fast_ns();
key->is_valid = true;
}
-static void derive_keys(struct noise_symmetric_key *first_dst, struct noise_symmetric_key *second_dst, const u8 chaining_key[NOISE_HASH_LEN])
+static void derive_keys(struct noise_symmetric_key *first_dst,
+ struct noise_symmetric_key *second_dst,
+ const u8 chaining_key[NOISE_HASH_LEN])
{
- kdf(first_dst->key, second_dst->key, NULL, NULL, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, chaining_key);
+ kdf(first_dst->key, second_dst->key, NULL, NULL,
+ NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0,
+ chaining_key);
symmetric_key_init(first_dst);
symmetric_key_init(second_dst);
}
-static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN], u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 private[NOISE_PUBLIC_KEY_LEN], const u8 public[NOISE_PUBLIC_KEY_LEN])
+static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
+ u8 key[NOISE_SYMMETRIC_KEY_LEN],
+ const u8 private[NOISE_PUBLIC_KEY_LEN],
+ const u8 public[NOISE_PUBLIC_KEY_LEN])
{
u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
if (unlikely(!curve25519(dh_calculation, private, public)))
return false;
- kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
+ kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN,
+ NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN);
return true;
}
@@ -315,42 +381,59 @@ static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
blake2s_final(&blake, hash, NOISE_HASH_LEN);
}
-static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
+static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
+ u8 key[NOISE_SYMMETRIC_KEY_LEN],
+ const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
{
u8 temp_hash[NOISE_HASH_LEN];
- kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
+ kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN,
+ NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key);
mix_hash(hash, temp_hash, NOISE_HASH_LEN);
memzero_explicit(temp_hash, NOISE_HASH_LEN);
}
-static void handshake_init(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
+static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
+ u8 hash[NOISE_HASH_LEN],
+ const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
{
memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN);
}
-static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN])
+static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
+ size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
+ u8 hash[NOISE_HASH_LEN])
{
- chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash, NOISE_HASH_LEN, 0 /* Always zero for Noise_IK */, key);
+ chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash,
+ NOISE_HASH_LEN,
+ 0 /* Always zero for Noise_IK */, key);
mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len));
}
-static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN])
+static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
+ size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
+ u8 hash[NOISE_HASH_LEN])
{
- if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len, hash, NOISE_HASH_LEN, 0 /* Always zero for Noise_IK */, key))
+ if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len,
+ hash, NOISE_HASH_LEN,
+ 0 /* Always zero for Noise_IK */, key))
return false;
mix_hash(hash, src_ciphertext, src_len);
return true;
}
-static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN], const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN], u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN])
+static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
+ const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
+ u8 chaining_key[NOISE_HASH_LEN],
+ u8 hash[NOISE_HASH_LEN])
{
if (ephemeral_dst != ephemeral_src)
memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
- kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
+ kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0,
+ NOISE_PUBLIC_KEY_LEN, chaining_key);
}
static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
@@ -363,14 +446,15 @@ static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
*(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
}
-bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake)
+bool noise_handshake_create_initiation(struct message_handshake_initiation *dst,
+ struct noise_handshake *handshake)
{
u8 timestamp[NOISE_TIMESTAMP_LEN];
u8 key[NOISE_SYMMETRIC_KEY_LEN];
bool ret = false;
- /* We need to wait for crng _before_ taking any locks, since curve25519_generate_secret
- * uses get_random_bytes_wait.
+ /* We need to wait for crng _before_ taking any locks, since
+ * curve25519_generate_secret uses get_random_bytes_wait.
*/
wait_for_random_bytes();
@@ -382,29 +466,42 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst,
dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);
- handshake_init(handshake->chaining_key, handshake->hash, handshake->remote_static);
+ handshake_init(handshake->chaining_key, handshake->hash,
+ handshake->remote_static);
/* e */
curve25519_generate_secret(handshake->ephemeral_private);
- if (!curve25519_generate_public(dst->unencrypted_ephemeral, handshake->ephemeral_private))
+ if (!curve25519_generate_public(dst->unencrypted_ephemeral,
+ handshake->ephemeral_private))
goto out;
- message_ephemeral(dst->unencrypted_ephemeral, dst->unencrypted_ephemeral, handshake->chaining_key, handshake->hash);
+ message_ephemeral(dst->unencrypted_ephemeral,
+ dst->unencrypted_ephemeral, handshake->chaining_key,
+ handshake->hash);
/* es */
- if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private, handshake->remote_static))
+ if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private,
+ handshake->remote_static))
goto out;
/* s */
- message_encrypt(dst->encrypted_static, handshake->static_identity->static_public, NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
+ message_encrypt(dst->encrypted_static,
+ handshake->static_identity->static_public,
+ NOISE_PUBLIC_KEY_LEN, key, handshake->hash);
/* ss */
- kdf(handshake->chaining_key, key, NULL, handshake->precomputed_static_static, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, handshake->chaining_key);
+ kdf(handshake->chaining_key, key, NULL,
+ handshake->precomputed_static_static, NOISE_HASH_LEN,
+ NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
+ handshake->chaining_key);
/* {t} */
tai64n_now(timestamp);
- message_encrypt(dst->encrypted_timestamp, timestamp, NOISE_TIMESTAMP_LEN, key, handshake->hash);
+ message_encrypt(dst->encrypted_timestamp, timestamp,
+ NOISE_TIMESTAMP_LEN, key, handshake->hash);
- dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
+ dst->sender_index = index_hashtable_insert(
+ &handshake->entry.peer->device->index_hashtable,
+ &handshake->entry);
handshake->state = HANDSHAKE_CREATED_INITIATION;
ret = true;
@@ -416,17 +513,19 @@ out:
return ret;
}
-struct wireguard_peer *noise_handshake_consume_initiation(struct message_handshake_initiation *src, struct wireguard_device *wg)
+struct wireguard_peer *
+noise_handshake_consume_initiation(struct message_handshake_initiation *src,
+ struct wireguard_device *wg)
{
+ struct wireguard_peer *peer = NULL, *ret_peer = NULL;
+ struct noise_handshake *handshake;
bool replay_attack, flood_attack;
+ u8 key[NOISE_SYMMETRIC_KEY_LEN];
+ u8 chaining_key[NOISE_HASH_LEN];
+ u8 hash[NOISE_HASH_LEN];
u8 s[NOISE_PUBLIC_KEY_LEN];
u8 e[NOISE_PUBLIC_KEY_LEN];
u8 t[NOISE_TIMESTAMP_LEN];
- struct noise_handshake *handshake;
- struct wireguard_peer *peer = NULL, *ret_peer = NULL;
- u8 key[NOISE_SYMMETRIC_KEY_LEN];
- u8 hash[NOISE_HASH_LEN];
- u8 chaining_key[NOISE_HASH_LEN];
down_read(&wg->static_identity.lock);
if (unlikely(!wg->static_identity.has_identity))
@@ -442,7 +541,8 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha
goto out;
/* s */
- if (!message_decrypt(s, src->encrypted_static, sizeof(src->encrypted_static), key, hash))
+ if (!message_decrypt(s, src->encrypted_static,
+ sizeof(src->encrypted_static), key, hash))
goto out;
/* Lookup which peer we're actually talking to */
@@ -452,15 +552,21 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha
handshake = &peer->handshake;
/* ss */
- kdf(chaining_key, key, NULL, handshake->precomputed_static_static, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key);
+ kdf(chaining_key, key, NULL, handshake->precomputed_static_static,
+ NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN,
+ chaining_key);
/* {t} */
- if (!message_decrypt(t, src->encrypted_timestamp, sizeof(src->encrypted_timestamp), key, hash))
+ if (!message_decrypt(t, src->encrypted_timestamp,
+ sizeof(src->encrypted_timestamp), key, hash))
goto out;
down_read(&handshake->lock);
- replay_attack = memcmp(t, handshake->latest_timestamp, NOISE_TIMESTAMP_LEN) <= 0;
- flood_attack = handshake->last_initiation_consumption + NSEC_PER_SEC / INITIATIONS_PER_SECOND > ktime_get_boot_fast_ns();
+ replay_attack = memcmp(t, handshake->latest_timestamp,
+ NOISE_TIMESTAMP_LEN) <= 0;
+ flood_attack = handshake->last_initiation_consumption +
+ NSEC_PER_SEC / INITIATIONS_PER_SECOND >
+ ktime_get_boot_fast_ns();
up_read(&handshake->lock);
if (replay_attack || flood_attack)
goto out;
@@ -487,13 +593,14 @@ out:
return ret_peer;
}
-bool noise_handshake_create_response(struct message_handshake_response *dst, struct noise_handshake *handshake)
+bool noise_handshake_create_response(struct message_handshake_response *dst,
+ struct noise_handshake *handshake)
{
bool ret = false;
u8 key[NOISE_SYMMETRIC_KEY_LEN];
- /* We need to wait for crng _before_ taking any locks, since curve25519_generate_secret
- * uses get_random_bytes_wait.
+ /* We need to wait for crng _before_ taking any locks, since
+ * curve25519_generate_secret uses get_random_bytes_wait.
*/
wait_for_random_bytes();
@@ -508,25 +615,33 @@ bool noise_handshake_create_response(struct message_handshake_response *dst, str
/* e */
curve25519_generate_secret(handshake->ephemeral_private);
- if (!curve25519_generate_public(dst->unencrypted_ephemeral, handshake->ephemeral_private))
+ if (!curve25519_generate_public(dst->unencrypted_ephemeral,
+ handshake->ephemeral_private))
goto out;
- message_ephemeral(dst->unencrypted_ephemeral, dst->unencrypted_ephemeral, handshake->chaining_key, handshake->hash);
+ message_ephemeral(dst->unencrypted_ephemeral,
+ dst->unencrypted_ephemeral, handshake->chaining_key,
+ handshake->hash);
/* ee */
- if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, handshake->remote_ephemeral))
+ if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
+ handshake->remote_ephemeral))
goto out;
/* se */
- if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, handshake->remote_static))
+ if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private,
+ handshake->remote_static))
goto out;
/* psk */
- mix_psk(handshake->chaining_key, handshake->hash, key, handshake->preshared_key);
+ mix_psk(handshake->chaining_key, handshake->hash, key,
+ handshake->preshared_key);
/* {} */
message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash);
- dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
+ dst->sender_index = index_hashtable_insert(
+ &handshake->entry.peer->device->index_hashtable,
+ &handshake->entry);
handshake->state = HANDSHAKE_CREATED_RESPONSE;
ret = true;
@@ -538,7 +653,9 @@ out:
return ret;
}
-struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg)
+struct wireguard_peer *
+noise_handshake_consume_response(struct message_handshake_response *src,
+ struct wireguard_device *wg)
{
struct noise_handshake *handshake;
struct wireguard_peer *peer = NULL, *ret_peer = NULL;
@@ -555,7 +672,9 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
if (unlikely(!wg->static_identity.has_identity))
goto out;
- handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index, &peer);
+ handshake = (struct noise_handshake *)index_hashtable_lookup(
+ &wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE,
+ src->receiver_index, &peer);
if (unlikely(!handshake))
goto out;
@@ -563,7 +682,8 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
state = handshake->state;
memcpy(hash, handshake->hash, NOISE_HASH_LEN);
memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
- memcpy(ephemeral_private, handshake->ephemeral_private, NOISE_PUBLIC_KEY_LEN);
+ memcpy(ephemeral_private, handshake->ephemeral_private,
+ NOISE_PUBLIC_KEY_LEN);
up_read(&handshake->lock);
if (state != HANDSHAKE_CREATED_INITIATION)
@@ -584,12 +704,15 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
mix_psk(chaining_key, hash, key, handshake->preshared_key);
/* {} */
- if (!message_decrypt(NULL, src->encrypted_nothing, sizeof(src->encrypted_nothing), key, hash))
+ if (!message_decrypt(NULL, src->encrypted_nothing,
+ sizeof(src->encrypted_nothing), key, hash))
goto fail;
/* Success! Copy everything to peer */
down_write(&handshake->lock);
- /* It's important to check that the state is still the same, while we have an exclusive lock */
+ /* It's important to check that the state is still the same, while we
+ * have an exclusive lock.
+ */
if (handshake->state != state) {
up_write(&handshake->lock);
goto fail;
@@ -615,32 +738,43 @@ out:
return ret_peer;
}
-bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs)
+bool noise_handshake_begin_session(struct noise_handshake *handshake,
+ struct noise_keypairs *keypairs)
{
struct noise_keypair *new_keypair;
bool ret = false;
down_write(&handshake->lock);
- if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
+ if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
+ handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
goto out;
new_keypair = keypair_create(handshake->entry.peer);
if (!new_keypair)
goto out;
- new_keypair->i_am_the_initiator = handshake->state == HANDSHAKE_CONSUMED_RESPONSE;
+ new_keypair->i_am_the_initiator = handshake->state ==
+ HANDSHAKE_CONSUMED_RESPONSE;
new_keypair->remote_index = handshake->remote_index;
if (new_keypair->i_am_the_initiator)
- derive_keys(&new_keypair->sending, &new_keypair->receiving, handshake->chaining_key);
+ derive_keys(&new_keypair->sending, &new_keypair->receiving,
+ handshake->chaining_key);
else
- derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key);
+ derive_keys(&new_keypair->receiving, &new_keypair->sending,
+ handshake->chaining_key);
handshake_zero(handshake);
rcu_read_lock_bh();
- if (likely(!container_of(handshake, struct wireguard_peer, handshake)->is_dead)) {
+ if (likely(!container_of(handshake, struct wireguard_peer,
+ handshake)->is_dead)) {
add_new_keypair(keypairs, new_keypair);
- net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n", handshake->entry.peer->device->dev->name, new_keypair->internal_id, handshake->entry.peer->internal_id);
- ret = index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry);
+ net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
+ handshake->entry.peer->device->dev->name,
+ new_keypair->internal_id,
+ handshake->entry.peer->internal_id);
+ ret = index_hashtable_replace(
+ &handshake->entry.peer->device->index_hashtable,
+ &handshake->entry, &new_keypair->entry);
} else
kzfree(new_keypair);
rcu_read_unlock_bh();
diff --git a/src/noise.h b/src/noise.h
index be59587..6a563ce 100644
--- a/src/noise.h
+++ b/src/noise.h
@@ -86,29 +86,44 @@ struct noise_handshake {
u8 latest_timestamp[NOISE_TIMESTAMP_LEN];
__le32 remote_index;
- /* Protects all members except the immutable (after noise_handshake_init): remote_static, precomputed_static_static, static_identity */
+ /* Protects all members except the immutable (after noise_handshake_
+ * init): remote_static, precomputed_static_static, static_identity. */
struct rw_semaphore lock;
};
struct wireguard_device;
void noise_init(void);
-bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer);
+bool noise_handshake_init(struct noise_handshake *handshake,
+ struct noise_static_identity *static_identity,
+ const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
+ const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
+ struct wireguard_peer *peer);
void noise_handshake_clear(struct noise_handshake *handshake);
void noise_keypair_put(struct noise_keypair *keypair, bool unreference_now);
struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair);
void noise_keypairs_clear(struct noise_keypairs *keypairs);
-bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair);
+bool noise_received_with_keypair(struct noise_keypairs *keypairs,
+ struct noise_keypair *received_keypair);
-void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN]);
+void noise_set_static_identity_private_key(
+ struct noise_static_identity *static_identity,
+ const u8 private_key[NOISE_PUBLIC_KEY_LEN]);
bool noise_precompute_static_static(struct wireguard_peer *peer);
-bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake);
-struct wireguard_peer *noise_handshake_consume_initiation(struct message_handshake_initiation *src, struct wireguard_device *wg);
+bool noise_handshake_create_initiation(struct message_handshake_initiation *dst,
+ struct noise_handshake *handshake);
+struct wireguard_peer *
+noise_handshake_consume_initiation(struct message_handshake_initiation *src,
+ struct wireguard_device *wg);
-bool noise_handshake_create_response(struct message_handshake_response *dst, struct noise_handshake *handshake);
-struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg);
+bool noise_handshake_create_response(struct message_handshake_response *dst,
+ struct noise_handshake *handshake);
+struct wireguard_peer *
+noise_handshake_consume_response(struct message_handshake_response *src,
+ struct wireguard_device *wg);
-bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noise_keypairs *keypairs);
+bool noise_handshake_begin_session(struct noise_handshake *handshake,
+ struct noise_keypairs *keypairs);
#endif /* _WG_NOISE_H */
diff --git a/src/peer.c b/src/peer.c
index b9703d5..c079b71 100644
--- a/src/peer.c
+++ b/src/peer.c
@@ -17,7 +17,10 @@
static atomic64_t peer_counter = ATOMIC64_INIT(0);
-struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN])
+struct wireguard_peer *
+peer_create(struct wireguard_device *wg,
+ const u8 public_key[NOISE_PUBLIC_KEY_LEN],
+ const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN])
{
struct wireguard_peer *peer;
@@ -31,11 +34,13 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_
return NULL;
peer->device = wg;
- if (!noise_handshake_init(&peer->handshake, &wg->static_identity, public_key, preshared_key, peer))
+ if (!noise_handshake_init(&peer->handshake, &wg->static_identity,
+ public_key, preshared_key, peer))
goto err_1;
if (dst_cache_init(&peer->endpoint_cache, GFP_KERNEL))
goto err_1;
- if (packet_queue_init(&peer->tx_queue, packet_tx_worker, false, MAX_QUEUED_PACKETS))
+ if (packet_queue_init(&peer->tx_queue, packet_tx_worker, false,
+ MAX_QUEUED_PACKETS))
goto err_2;
if (packet_queue_init(&peer->rx_queue, NULL, false, MAX_QUEUED_PACKETS))
goto err_3;
@@ -50,7 +55,9 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_
rwlock_init(&peer->endpoint_lock);
kref_init(&peer->refcount);
skb_queue_head_init(&peer->staged_packet_queue);
- atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns() - (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC);
+ atomic64_set(&peer->last_sent_handshake,
+ ktime_get_boot_fast_ns() -
+ (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC);
set_bit(NAPI_STATE_NO_BUSY_POLL, &peer->napi.state);
netif_napi_add(wg->dev, &peer->napi, packet_rx_poll, NAPI_POLL_WEIGHT);
napi_enable(&peer->napi);
@@ -71,15 +78,16 @@ err_1:
struct wireguard_peer *peer_get_maybe_zero(struct wireguard_peer *peer)
{
- RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(), "Taking peer reference without holding the RCU read lock");
+ RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
+ "Taking peer reference without holding the RCU read lock");
if (unlikely(!peer || !kref_get_unless_zero(&peer->refcount)))
return NULL;
return peer;
}
-/* We have a separate "remove" function to get rid of the final reference because
- * peer_list, clearing handshakes, and flushing all require mutexes which requires
- * sleeping, which must only be done from certain contexts.
+/* We have a separate "remove" function to get rid of the final reference
+ * because peer_list, clearing handshakes, and flushing all require mutexes
+ * which requires sleeping, which must only be done from certain contexts.
*/
void peer_remove(struct wireguard_peer *peer)
{
@@ -87,37 +95,49 @@ void peer_remove(struct wireguard_peer *peer)
return;
lockdep_assert_held(&peer->device->device_update_lock);
- /* Remove from configuration-time lookup structures so new packets can't enter. */
+ /* Remove from configuration-time lookup structures so new packets
+ * can't enter.
+ */
list_del_init(&peer->peer_list);
- allowedips_remove_by_peer(&peer->device->peer_allowedips, peer, &peer->device->device_update_lock);
+ allowedips_remove_by_peer(&peer->device->peer_allowedips, peer,
+ &peer->device->device_update_lock);
pubkey_hashtable_remove(&peer->device->peer_hashtable, peer);
/* Mark as dead, so that we don't allow jumping contexts after. */
WRITE_ONCE(peer->is_dead, true);
synchronize_rcu_bh();
- /* Now that no more keypairs can be created for this peer, we destroy existing ones. */
+ /* Now that no more keypairs can be created for this peer, we destroy
+ * existing ones.
+ */
noise_keypairs_clear(&peer->keypairs);
- /* Destroy all ongoing timers that were in-flight at the beginning of this function. */
+ /* Destroy all ongoing timers that were in-flight at the beginning of
+ * this function.
+ */
timers_stop(peer);
- /* The transition between packet encryption/decryption queues isn't guarded
- * by is_dead, but each reference's life is strictly bounded by two
- * generations: once for parallel crypto and once for serial ingestion,
- * so we can simply flush twice, and be sure that we no longer have references
- * inside these queues.
- *
- * a) For encrypt/decrypt. */
+ /* The transition between packet encryption/decryption queues isn't
+ * guarded by is_dead, but each reference's life is strictly bounded by
+ * two generations: once for parallel crypto and once for serial
+ * ingestion, so we can simply flush twice, and be sure that we no
+ * longer have references inside these queues.
+ */
+
+ /* a) For encrypt/decrypt. */
flush_workqueue(peer->device->packet_crypt_wq);
/* b.1) For send (but not receive, since that's napi). */
flush_workqueue(peer->device->packet_crypt_wq);
/* b.2.1) For receive (but not send, since that's wq). */
napi_disable(&peer->napi);
- /* b.2.1) It's now safe to remove the napi struct, which must be done here from process context. */
+ /* b.2.1) It's now safe to remove the napi struct, which must be done
+ * here from process context.
+ */
netif_napi_del(&peer->napi);
- /* Ensure any workstructs we own (like transmit_handshake_work or clear_peer_work) no longer are in use. */
+ /* Ensure any workstructs we own (like transmit_handshake_work or
+ * clear_peer_work) no longer are in use.
+ */
flush_workqueue(peer->device->handshake_send_wq);
--peer->device->num_peers;
@@ -126,7 +146,8 @@ void peer_remove(struct wireguard_peer *peer)
static void rcu_release(struct rcu_head *rcu)
{
- struct wireguard_peer *peer = container_of(rcu, struct wireguard_peer, rcu);
+ struct wireguard_peer *peer =
+ container_of(rcu, struct wireguard_peer, rcu);
dst_cache_destroy(&peer->endpoint_cache);
packet_queue_free(&peer->rx_queue, false);
packet_queue_free(&peer->tx_queue, false);
@@ -135,11 +156,19 @@ static void rcu_release(struct rcu_head *rcu)
static void kref_release(struct kref *refcount)
{
- struct wireguard_peer *peer = container_of(refcount, struct wireguard_peer, refcount);
- pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr);
- /* Remove ourself from dynamic runtime lookup structures, now that the last reference is gone. */
- index_hashtable_remove(&peer->device->index_hashtable, &peer->handshake.entry);
- /* Remove any lingering packets that didn't have a chance to be transmitted. */
+ struct wireguard_peer *peer =
+ container_of(refcount, struct wireguard_peer, refcount);
+ pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n",
+ peer->device->dev->name, peer->internal_id,
+ &peer->endpoint.addr);
+ /* Remove ourself from dynamic runtime lookup structures, now that the
+ * last reference is gone.
+ */
+ index_hashtable_remove(&peer->device->index_hashtable,
+ &peer->handshake.entry);
+ /* Remove any lingering packets that didn't have a chance to be
+ * transmitted.
+ */
skb_queue_purge(&peer->staged_packet_queue);
/* Free the memory used. */
call_rcu_bh(&peer->rcu, rcu_release);
@@ -157,6 +186,6 @@ void peer_remove_all(struct wireguard_device *wg)
struct wireguard_peer *peer, *temp;
lockdep_assert_held(&wg->device_update_lock);
- list_for_each_entry_safe(peer, temp, &wg->peer_list, peer_list)
+ list_for_each_entry_safe (peer, temp, &wg->peer_list, peer_list)
peer_remove(peer);
}
diff --git a/src/peer.h b/src/peer.h
index 29d2e00..5613ccc 100644
--- a/src/peer.h
+++ b/src/peer.h
@@ -27,7 +27,8 @@ struct endpoint {
union {
struct {
struct in_addr src4;
- int src_if4; /* Essentially the same as addr6->scope_id */
+ /* Essentially the same as addr6->scope_id */
+ int src_if4;
};
struct in6_addr src6;
};
@@ -48,10 +49,13 @@ struct wireguard_peer {
struct cookie latest_cookie;
struct hlist_node pubkey_hash;
u64 rx_bytes, tx_bytes;
- struct timer_list timer_retransmit_handshake, timer_send_keepalive, timer_new_handshake, timer_zero_key_material, timer_persistent_keepalive;
+ struct timer_list timer_retransmit_handshake, timer_send_keepalive;
+ struct timer_list timer_new_handshake, timer_zero_key_material;
+ struct timer_list timer_persistent_keepalive;
unsigned int timer_handshake_attempts;
u16 persistent_keepalive_interval;
- bool timers_enabled, timer_need_another_keepalive, sent_lastminute_handshake;
+ bool timers_enabled, timer_need_another_keepalive;
+ bool sent_lastminute_handshake;
struct timespec walltime_last_handshake;
struct kref refcount;
struct rcu_head rcu;
@@ -61,9 +65,13 @@ struct wireguard_peer {
bool is_dead;
};
-struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]);
+struct wireguard_peer *
+peer_create(struct wireguard_device *wg,
+ const u8 public_key[NOISE_PUBLIC_KEY_LEN],
+ const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]);
-struct wireguard_peer * __must_check peer_get_maybe_zero(struct wireguard_peer *peer);
+struct wireguard_peer *__must_check
+peer_get_maybe_zero(struct wireguard_peer *peer);
static inline struct wireguard_peer *peer_get(struct wireguard_peer *peer)
{
kref_get(&peer->refcount);
@@ -73,6 +81,7 @@ void peer_put(struct wireguard_peer *peer);
void peer_remove(struct wireguard_peer *peer);
void peer_remove_all(struct wireguard_device *wg);
-struct wireguard_peer *peer_lookup_by_index(struct wireguard_device *wg, u32 index);
+struct wireguard_peer *peer_lookup_by_index(struct wireguard_device *wg,
+ u32 index);
#endif /* _WG_PEER_H */
diff --git a/src/queueing.c b/src/queueing.c
index c8394fc..9ec6588 100644
--- a/src/queueing.c
+++ b/src/queueing.c
@@ -5,22 +5,25 @@
#include "queueing.h"
-struct multicore_worker __percpu *packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr)
+struct multicore_worker __percpu *
+packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr)
{
int cpu;
- struct multicore_worker __percpu *worker = alloc_percpu(struct multicore_worker);
+ struct multicore_worker __percpu *worker =
+ alloc_percpu(struct multicore_worker);
if (!worker)
return NULL;
- for_each_possible_cpu(cpu) {
+ for_each_possible_cpu (cpu) {
per_cpu_ptr(worker, cpu)->ptr = ptr;
INIT_WORK(&per_cpu_ptr(worker, cpu)->work, function);
}
return worker;
}
-int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool multicore, unsigned int len)
+int packet_queue_init(struct crypt_queue *queue, work_func_t function,
+ bool multicore, unsigned int len)
{
int ret;
@@ -30,7 +33,8 @@ int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool mult
return ret;
if (function) {
if (multicore) {
- queue->worker = packet_alloc_percpu_multicore_worker(function, queue);
+ queue->worker = packet_alloc_percpu_multicore_worker(
+ function, queue);
if (!queue->worker)
return -ENOMEM;
} else
diff --git a/src/queueing.h b/src/queueing.h
index 6a1de33..66b7134 100644
--- a/src/queueing.h
+++ b/src/queueing.h
@@ -19,9 +19,11 @@ struct crypt_queue;
struct sk_buff;
/* queueing.c APIs: */
-int packet_queue_init(struct crypt_queue *queue, work_func_t function, bool multicore, unsigned int len);
+int packet_queue_init(struct crypt_queue *queue, work_func_t function,
+ bool multicore, unsigned int len);
void packet_queue_free(struct crypt_queue *queue, bool multicore);
-struct multicore_worker __percpu *packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr);
+struct multicore_worker __percpu *
+packet_alloc_percpu_multicore_worker(work_func_t function, void *ptr);
/* receive.c APIs: */
void packet_receive(struct wireguard_device *wg, struct sk_buff *skb);
@@ -32,9 +34,12 @@ int packet_rx_poll(struct napi_struct *napi, int budget);
void packet_decrypt_worker(struct work_struct *work);
/* send.c APIs: */
-void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool is_retry);
+void packet_send_queued_handshake_initiation(struct wireguard_peer *peer,
+ bool is_retry);
void packet_send_handshake_response(struct wireguard_peer *peer);
-void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index);
+void packet_send_handshake_cookie(struct wireguard_device *wg,
+ struct sk_buff *initiating_skb,
+ __le32 sender_index);
void packet_send_keepalive(struct wireguard_peer *peer);
void packet_send_staged_packets(struct wireguard_peer *peer);
/* Workqueue workers: */
@@ -42,7 +47,12 @@ void packet_handshake_send_worker(struct work_struct *work);
void packet_tx_worker(struct work_struct *work);
void packet_encrypt_worker(struct work_struct *work);
-enum packet_state { PACKET_STATE_UNCRYPTED, PACKET_STATE_CRYPTED, PACKET_STATE_DEAD };
+enum packet_state {
+ PACKET_STATE_UNCRYPTED,
+ PACKET_STATE_CRYPTED,
+ PACKET_STATE_DEAD
+};
+
struct packet_cb {
u64 nonce;
struct noise_keypair *keypair;
@@ -50,15 +60,22 @@ struct packet_cb {
u32 mtu;
u8 ds;
};
+
#define PACKET_PEER(skb) (((struct packet_cb *)skb->cb)->keypair->entry.peer)
#define PACKET_CB(skb) ((struct packet_cb *)skb->cb)
/* Returns either the correct skb->protocol value, or 0 if invalid. */
static inline __be16 skb_examine_untrusted_ip_hdr(struct sk_buff *skb)
{
- if (skb_network_header(skb) >= skb->head && (skb_network_header(skb) + sizeof(struct iphdr)) <= skb_tail_pointer(skb) && ip_hdr(skb)->version == 4)
+ if (skb_network_header(skb) >= skb->head &&
+ (skb_network_header(skb) + sizeof(struct iphdr)) <=
+ skb_tail_pointer(skb) &&
+ ip_hdr(skb)->version == 4)
return htons(ETH_P_IP);
- if (skb_network_header(skb) >= skb->head && (skb_network_header(skb) + sizeof(struct ipv6hdr)) <= skb_tail_pointer(skb) && ipv6_hdr(skb)->version == 6)
+ if (skb_network_header(skb) >= skb->head &&
+ (skb_network_header(skb) + sizeof(struct ipv6hdr)) <=
+ skb_tail_pointer(skb) &&
+ ipv6_hdr(skb)->version == 6)
return htons(ETH_P_IPV6);
return 0;
}
@@ -67,7 +84,9 @@ static inline void skb_reset(struct sk_buff *skb)
{
const int pfmemalloc = skb->pfmemalloc;
skb_scrub_packet(skb, true);
- memset(&skb->headers_start, 0, offsetof(struct sk_buff, headers_end) - offsetof(struct sk_buff, headers_start));
+ memset(&skb->headers_start, 0,
+ offsetof(struct sk_buff, headers_end) -
+ offsetof(struct sk_buff, headers_start));
skb->pfmemalloc = pfmemalloc;
skb->queue_mapping = 0;
skb->nohdr = 0;
@@ -89,7 +108,8 @@ static inline int cpumask_choose_online(int *stored_cpu, unsigned int id)
{
unsigned int cpu = *stored_cpu, cpu_index, i;
- if (unlikely(cpu == nr_cpumask_bits || !cpumask_test_cpu(cpu, cpu_online_mask))) {
+ if (unlikely(cpu == nr_cpumask_bits ||
+ !cpumask_test_cpu(cpu, cpu_online_mask))) {
cpu_index = id % cpumask_weight(cpu_online_mask);
cpu = cpumask_first(cpu_online_mask);
for (i = 0; i < cpu_index; ++i)
@@ -103,8 +123,8 @@ static inline int cpumask_choose_online(int *stored_cpu, unsigned int id)
* the same CPU twice. A race-free version of this would be to instead store an
* atomic sequence number, do an increment-and-return, and then iterate through
* every possible CPU until we get to that index -- choose_cpu. However that's
- * a bit slower, and it doesn't seem like this potential race actually introduces
- * any performance loss, so we live with it.
+ * a bit slower, and it doesn't seem like this potential race actually
+ * introduces any performance loss, so we live with it.
*/
static inline int cpumask_next_online(int *next)
{
@@ -116,7 +136,9 @@ static inline int cpumask_next_online(int *next)
return cpu;
}
-static inline int queue_enqueue_per_device_and_peer(struct crypt_queue *device_queue, struct crypt_queue *peer_queue, struct sk_buff *skb, struct workqueue_struct *wq, int *next_cpu)
+static inline int queue_enqueue_per_device_and_peer(
+ struct crypt_queue *device_queue, struct crypt_queue *peer_queue,
+ struct sk_buff *skb, struct workqueue_struct *wq, int *next_cpu)
{
int cpu;
@@ -136,18 +158,24 @@ static inline int queue_enqueue_per_device_and_peer(struct crypt_queue *device_q
return 0;
}
-static inline void queue_enqueue_per_peer(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state)
+static inline void queue_enqueue_per_peer(struct crypt_queue *queue,
+ struct sk_buff *skb,
+ enum packet_state state)
{
/* We take a reference, because as soon as we call atomic_set, the
* peer can be freed from below us.
*/
struct wireguard_peer *peer = peer_get(PACKET_PEER(skb));
atomic_set_release(&PACKET_CB(skb)->state, state);
- queue_work_on(cpumask_choose_online(&peer->serial_work_cpu, peer->internal_id), peer->device->packet_crypt_wq, &queue->work);
+ queue_work_on(cpumask_choose_online(&peer->serial_work_cpu,
+ peer->internal_id),
+ peer->device->packet_crypt_wq, &queue->work);
peer_put(peer);
}
-static inline void queue_enqueue_per_peer_napi(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state)
+static inline void queue_enqueue_per_peer_napi(struct crypt_queue *queue,
+ struct sk_buff *skb,
+ enum packet_state state)
{
/* We take a reference, because as soon as we call atomic_set, the
* peer can be freed from below us.
diff --git a/src/ratelimiter.c b/src/ratelimiter.c
index 638da2b..3966ce8 100644
--- a/src/ratelimiter.c
+++ b/src/ratelimiter.c
@@ -41,7 +41,8 @@ enum {
static void entry_free(struct rcu_head *rcu)
{
- kmem_cache_free(entry_cache, container_of(rcu, struct ratelimiter_entry, rcu));
+ kmem_cache_free(entry_cache,
+ container_of(rcu, struct ratelimiter_entry, rcu));
atomic_dec(&total_entries);
}
@@ -54,20 +55,22 @@ static void entry_uninit(struct ratelimiter_entry *entry)
/* Calling this function with a NULL work uninits all entries. */
static void gc_entries(struct work_struct *work)
{
- unsigned int i;
+ const u64 now = ktime_get_boot_fast_ns();
struct ratelimiter_entry *entry;
struct hlist_node *temp;
- const u64 now = ktime_get_boot_fast_ns();
+ unsigned int i;
for (i = 0; i < table_size; ++i) {
spin_lock(&table_lock);
- hlist_for_each_entry_safe(entry, temp, &table_v4[i], hash) {
- if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC)
+ hlist_for_each_entry_safe (entry, temp, &table_v4[i], hash) {
+ if (unlikely(!work) ||
+ now - entry->last_time_ns > NSEC_PER_SEC)
entry_uninit(entry);
}
#if IS_ENABLED(CONFIG_IPV6)
- hlist_for_each_entry_safe(entry, temp, &table_v6[i], hash) {
- if (unlikely(!work) || now - entry->last_time_ns > NSEC_PER_SEC)
+ hlist_for_each_entry_safe (entry, temp, &table_v6[i], hash) {
+ if (unlikely(!work) ||
+ now - entry->last_time_ns > NSEC_PER_SEC)
entry_uninit(entry);
}
#endif
@@ -81,34 +84,41 @@ static void gc_entries(struct work_struct *work)
bool ratelimiter_allow(struct sk_buff *skb, struct net *net)
{
+ struct { __be64 ip; u32 net; } data =
+ { .net = (unsigned long)net & 0xffffffff };
struct ratelimiter_entry *entry;
struct hlist_head *bucket;
- struct { __be64 ip; u32 net; } data = { .net = (unsigned long)net & 0xffffffff };
if (skb->protocol == htons(ETH_P_IP)) {
data.ip = (__force __be64)ip_hdr(skb)->saddr;
- bucket = &table_v4[hsiphash(&data, sizeof(u32) * 3, &key) & (table_size - 1)];
+ bucket = &table_v4[hsiphash(&data, sizeof(u32) * 3, &key) &
+ (table_size - 1)];
}
#if IS_ENABLED(CONFIG_IPV6)
else if (skb->protocol == htons(ETH_P_IPV6)) {
- memcpy(&data.ip, &ipv6_hdr(skb)->saddr, sizeof(__be64)); /* Only 64 bits */
- bucket = &table_v6[hsiphash(&data, sizeof(u32) * 3, &key) & (table_size - 1)];
+ memcpy(&data.ip, &ipv6_hdr(skb)->saddr,
+ sizeof(__be64)); /* Only 64 bits */
+ bucket = &table_v6[hsiphash(&data, sizeof(u32) * 3, &key) &
+ (table_size - 1)];
}
#endif
else
return false;
rcu_read_lock();
- hlist_for_each_entry_rcu(entry, bucket, hash) {
+ hlist_for_each_entry_rcu (entry, bucket, hash) {
if (entry->net == net && entry->ip == data.ip) {
u64 now, tokens;
bool ret;
- /* Inspired by nft_limit.c, but this is actually a slightly different
- * algorithm. Namely, we incorporate the burst as part of the maximum
- * tokens, rather than as part of the rate.
+ /* Quasi-inspired by nft_limit.c, but this is actually a
+ * slightly different algorithm. Namely, we incorporate
+ * the burst as part of the maximum tokens, rather than
+ * as part of the rate.
*/
spin_lock(&entry->lock);
now = ktime_get_boot_fast_ns();
- tokens = min_t(u64, TOKEN_MAX, entry->tokens + now - entry->last_time_ns);
+ tokens = min_t(u64, TOKEN_MAX,
+ entry->tokens + now -
+ entry->last_time_ns);
entry->last_time_ns = now;
ret = tokens >= PACKET_COST;
entry->tokens = ret ? tokens - PACKET_COST : tokens;
@@ -157,7 +167,10 @@ int ratelimiter_init(void)
* we borrow their wisdom about good table sizes on different systems
* dependent on RAM. This calculation here comes from there.
*/
- table_size = (totalram_pages > (1U << 30) / PAGE_SIZE) ? 8192 : max_t(unsigned long, 16, roundup_pow_of_two((totalram_pages << PAGE_SHIFT) / (1U << 14) / sizeof(struct hlist_head)));
+ table_size = (totalram_pages > (1U << 30) / PAGE_SIZE) ? 8192 :
+ max_t(unsigned long, 16, roundup_pow_of_two(
+ (totalram_pages << PAGE_SHIFT) /
+ (1U << 14) / sizeof(struct hlist_head)));
max_entries = table_size * 8;
table_v4 = kvzalloc(table_size * sizeof(struct hlist_head), GFP_KERNEL);
diff --git a/src/receive.c b/src/receive.c
index 4e73da1..a2d1d9d 100644
--- a/src/receive.c
+++ b/src/receive.c
@@ -20,7 +20,8 @@
/* Must be called with bh disabled. */
static inline void rx_stats(struct wireguard_peer *peer, size_t len)
{
- struct pcpu_sw_netstats *tstats = get_cpu_ptr(peer->device->dev->tstats);
+ struct pcpu_sw_netstats *tstats =
+ get_cpu_ptr(peer->device->dev->tstats);
u64_stats_update_begin(&tstats->syncp);
++tstats->rx_packets;
@@ -36,38 +37,57 @@ static inline size_t validate_header_len(struct sk_buff *skb)
{
if (unlikely(skb->len < sizeof(struct message_header)))
return 0;
- if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) && skb->len >= MESSAGE_MINIMUM_LENGTH)
+ if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_DATA) &&
+ skb->len >= MESSAGE_MINIMUM_LENGTH)
return sizeof(struct message_data);
- if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) && skb->len == sizeof(struct message_handshake_initiation))
+ if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION) &&
+ skb->len == sizeof(struct message_handshake_initiation))
return sizeof(struct message_handshake_initiation);
- if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) && skb->len == sizeof(struct message_handshake_response))
+ if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE) &&
+ skb->len == sizeof(struct message_handshake_response))
return sizeof(struct message_handshake_response);
- if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) && skb->len == sizeof(struct message_handshake_cookie))
+ if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE) &&
+ skb->len == sizeof(struct message_handshake_cookie))
return sizeof(struct message_handshake_cookie);
return 0;
}
-static inline int skb_prepare_header(struct sk_buff *skb, struct wireguard_device *wg)
+static inline int skb_prepare_header(struct sk_buff *skb,
+ struct wireguard_device *wg)
{
- struct udphdr *udp;
size_t data_offset, data_len, header_len;
+ struct udphdr *udp;
- if (unlikely(skb_examine_untrusted_ip_hdr(skb) != skb->protocol || skb_transport_header(skb) < skb->head || (skb_transport_header(skb) + sizeof(struct udphdr)) > skb_tail_pointer(skb)))
+ if (unlikely(skb_examine_untrusted_ip_hdr(skb) != skb->protocol ||
+ skb_transport_header(skb) < skb->head ||
+ (skb_transport_header(skb) + sizeof(struct udphdr)) >
+ skb_tail_pointer(skb)))
return -EINVAL; /* Bogus IP header */
udp = udp_hdr(skb);
data_offset = (u8 *)udp - skb->data;
- if (unlikely(data_offset > U16_MAX || data_offset + sizeof(struct udphdr) > skb->len))
- return -EINVAL; /* Packet has offset at impossible location or isn't big enough to have UDP fields */
+ if (unlikely(data_offset > U16_MAX ||
+ data_offset + sizeof(struct udphdr) > skb->len))
+ /* Packet has offset at impossible location or isn't big enough
+ * to have UDP fields.
+ */
+ return -EINVAL;
data_len = ntohs(udp->len);
- if (unlikely(data_len < sizeof(struct udphdr) || data_len > skb->len - data_offset))
- return -EINVAL; /* UDP packet is reporting too small of a size or lying about its size */
+ if (unlikely(data_len < sizeof(struct udphdr) ||
+ data_len > skb->len - data_offset))
+ /* UDP packet is reporting too small of a size or lying about
+ * its size.
+ */
+ return -EINVAL;
data_len -= sizeof(struct udphdr);
data_offset = (u8 *)udp + sizeof(struct udphdr) - skb->data;
- if (unlikely(!pskb_may_pull(skb, data_offset + sizeof(struct message_header)) || pskb_trim(skb, data_len + data_offset) < 0))
+ if (unlikely(!pskb_may_pull(skb,
+ data_offset + sizeof(struct message_header)) ||
+ pskb_trim(skb, data_len + data_offset) < 0))
return -EINVAL;
skb_pull(skb, data_offset);
if (unlikely(skb->len != data_len))
- return -EINVAL; /* Final len does not agree with calculated len */
+ /* Final len does not agree with calculated len */
+ return -EINVAL;
header_len = validate_header_len(skb);
if (unlikely(!header_len))
return -EINVAL;
@@ -78,74 +98,96 @@ static inline int skb_prepare_header(struct sk_buff *skb, struct wireguard_devic
return 0;
}
-static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff *skb)
+static void receive_handshake_packet(struct wireguard_device *wg,
+ struct sk_buff *skb)
{
- static u64 last_under_load; /* Yes this is global, so that our load calculation applies to the whole system. */
struct wireguard_peer *peer = NULL;
- bool under_load;
enum cookie_mac_state mac_state;
+ /* This is global, so that our load calculation applies to
+ * the whole system.
+ */
+ static u64 last_under_load;
bool packet_needs_cookie;
+ bool under_load;
if (SKB_TYPE_LE32(skb) == cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE)) {
- net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n", wg->dev->name, skb);
- cookie_message_consume((struct message_handshake_cookie *)skb->data, wg);
+ net_dbg_skb_ratelimited("%s: Receiving cookie response from %pISpfsc\n",
+ wg->dev->name, skb);
+ cookie_message_consume(
+ (struct message_handshake_cookie *)skb->data, wg);
return;
}
- under_load = skb_queue_len(&wg->incoming_handshakes) >= MAX_QUEUED_INCOMING_HANDSHAKES / 8;
+ under_load = skb_queue_len(&wg->incoming_handshakes) >=
+ MAX_QUEUED_INCOMING_HANDSHAKES / 8;
if (under_load)
last_under_load = ktime_get_boot_fast_ns();
else if (last_under_load)
under_load = !has_expired(last_under_load, 1);
- mac_state = cookie_validate_packet(&wg->cookie_checker, skb, under_load);
- if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) || (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE))
+ mac_state = cookie_validate_packet(&wg->cookie_checker, skb,
+ under_load);
+ if ((under_load && mac_state == VALID_MAC_WITH_COOKIE) ||
+ (!under_load && mac_state == VALID_MAC_BUT_NO_COOKIE))
packet_needs_cookie = false;
else if (under_load && mac_state == VALID_MAC_BUT_NO_COOKIE)
packet_needs_cookie = true;
else {
- net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n", wg->dev->name, skb);
+ net_dbg_skb_ratelimited("%s: Invalid MAC of handshake, dropping packet from %pISpfsc\n",
+ wg->dev->name, skb);
return;
}
switch (SKB_TYPE_LE32(skb)) {
case cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION): {
- struct message_handshake_initiation *message = (struct message_handshake_initiation *)skb->data;
+ struct message_handshake_initiation *message =
+ (struct message_handshake_initiation *)skb->data;
if (packet_needs_cookie) {
- packet_send_handshake_cookie(wg, skb, message->sender_index);
+ packet_send_handshake_cookie(wg, skb,
+ message->sender_index);
return;
}
peer = noise_handshake_consume_initiation(message, wg);
if (unlikely(!peer)) {
- net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n", wg->dev->name, skb);
+ net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n",
+ wg->dev->name, skb);
return;
}
socket_set_peer_endpoint_from_skb(peer, skb);
- net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n", wg->dev->name, peer->internal_id, &peer->endpoint.addr);
+ net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n",
+ wg->dev->name, peer->internal_id,
+ &peer->endpoint.addr);
packet_send_handshake_response(peer);
break;
}
case cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE): {
- struct message_handshake_response *message = (struct message_handshake_response *)skb->data;
+ struct message_handshake_response *message =
+ (struct message_handshake_response *)skb->data;
if (packet_needs_cookie) {
- packet_send_handshake_cookie(wg, skb, message->sender_index);
+ packet_send_handshake_cookie(wg, skb,
+ message->sender_index);
return;
}
peer = noise_handshake_consume_response(message, wg);
if (unlikely(!peer)) {
- net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n", wg->dev->name, skb);
+ net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n",
+ wg->dev->name, skb);
return;
}
socket_set_peer_endpoint_from_skb(peer, skb);
- net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n", wg->dev->name, peer->internal_id, &peer->endpoint.addr);
- if (noise_handshake_begin_session(&peer->handshake, &peer->keypairs)) {
+ net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n",
+ wg->dev->name, peer->internal_id,
+ &peer->endpoint.addr);
+ if (noise_handshake_begin_session(&peer->handshake,
+ &peer->keypairs)) {
timers_session_derived(peer);
timers_handshake_complete(peer);
- /* Calling this function will either send any existing packets in the queue
- * and not send a keepalive, which is the best case, Or, if there's nothing
- * in the queue, it will send a keepalive, in order to give immediate
- * confirmation of the session.
+ /* Calling this function will either send any existing
+ * packets in the queue and not send a keepalive, which
+ * is the best case, Or, if there's nothing in the
+ * queue, it will send a keepalive, in order to give
+ * immediate confirmation of the session.
*/
packet_send_keepalive(peer);
}
@@ -169,7 +211,8 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff
void packet_handshake_receive_worker(struct work_struct *work)
{
- struct wireguard_device *wg = container_of(work, struct multicore_worker, work)->ptr;
+ struct wireguard_device *wg =
+ container_of(work, struct multicore_worker, work)->ptr;
struct sk_buff *skb;
while ((skb = skb_dequeue(&wg->incoming_handshakes)) != NULL) {
@@ -189,8 +232,10 @@ static inline void keep_key_fresh(struct wireguard_peer *peer)
rcu_read_lock_bh();
keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
- if (likely(keypair && keypair->sending.is_valid) && keypair->i_am_the_initiator &&
- unlikely(has_expired(keypair->sending.birthdate, REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT)))
+ if (likely(keypair && keypair->sending.is_valid) &&
+ keypair->i_am_the_initiator &&
+ unlikely(has_expired(keypair->sending.birthdate,
+ REJECT_AFTER_TIME - KEEPALIVE_TIMEOUT - REKEY_TIMEOUT)))
send = true;
rcu_read_unlock_bh();
@@ -200,7 +245,9 @@ static inline void keep_key_fresh(struct wireguard_peer *peer)
}
}
-static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *key, simd_context_t simd_context)
+static inline bool skb_decrypt(struct sk_buff *skb,
+ struct noise_symmetric_key *key,
+ simd_context_t simd_context)
{
struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1];
struct sk_buff *trailer;
@@ -210,12 +257,15 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *
if (unlikely(!key))
return false;
- if (unlikely(!key->is_valid || has_expired(key->birthdate, REJECT_AFTER_TIME) || key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
+ if (unlikely(!key->is_valid ||
+ has_expired(key->birthdate, REJECT_AFTER_TIME) ||
+ key->counter.receive.counter >= REJECT_AFTER_MESSAGES)) {
key->is_valid = false;
return false;
}
- PACKET_CB(skb)->nonce = le64_to_cpu(((struct message_data *)skb->data)->counter);
+ PACKET_CB(skb)->nonce =
+ le64_to_cpu(((struct message_data *)skb->data)->counter);
/* We ensure that the network header is part of the packet before we
* call skb_cow_data, so that there's no chance that data is removed
@@ -233,7 +283,9 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *
if (skb_to_sgvec(skb, sg, 0, skb->len) <= 0)
return false;
- if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0, PACKET_CB(skb)->nonce, key->key, simd_context))
+ if (!chacha20poly1305_decrypt_sg(sg, sg, skb->len, NULL, 0,
+ PACKET_CB(skb)->nonce, key->key,
+ simd_context))
return false;
/* Another ugly situation of pushing and pulling the header so as to
@@ -248,33 +300,39 @@ static inline bool skb_decrypt(struct sk_buff *skb, struct noise_symmetric_key *
}
/* This is RFC6479, a replay detection bitmap algorithm that avoids bitshifts */
-static inline bool counter_validate(union noise_counter *counter, u64 their_counter)
+static inline bool counter_validate(union noise_counter *counter,
+ u64 their_counter)
{
- bool ret = false;
unsigned long index, index_current, top, i;
+ bool ret = false;
spin_lock_bh(&counter->receive.lock);
- if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 || their_counter >= REJECT_AFTER_MESSAGES))
+ if (unlikely(counter->receive.counter >= REJECT_AFTER_MESSAGES + 1 ||
+ their_counter >= REJECT_AFTER_MESSAGES))
goto out;
++their_counter;
- if (unlikely((COUNTER_WINDOW_SIZE + their_counter) < counter->receive.counter))
+ if (unlikely((COUNTER_WINDOW_SIZE + their_counter) <
+ counter->receive.counter))
goto out;
index = their_counter >> ilog2(BITS_PER_LONG);
if (likely(their_counter > counter->receive.counter)) {
index_current = counter->receive.counter >> ilog2(BITS_PER_LONG);
- top = min_t(unsigned long, index - index_current, COUNTER_BITS_TOTAL / BITS_PER_LONG);
+ top = min_t(unsigned long, index - index_current,
+ COUNTER_BITS_TOTAL / BITS_PER_LONG);
for (i = 1; i <= top; ++i)
- counter->receive.backtrack[(i + index_current) & ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
+ counter->receive.backtrack[(i + index_current) &
+ ((COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1)] = 0;
counter->receive.counter = their_counter;
}
index &= (COUNTER_BITS_TOTAL / BITS_PER_LONG) - 1;
- ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1), &counter->receive.backtrack[index]);
+ ret = !test_and_set_bit(their_counter & (BITS_PER_LONG - 1),
+ &counter->receive.backtrack[index]);
out:
spin_unlock_bh(&counter->receive.lock);
@@ -282,15 +340,18 @@ out:
}
#include "selftest/counter.h"
-static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff *skb, struct endpoint *endpoint)
+static void packet_consume_data_done(struct wireguard_peer *peer,
+ struct sk_buff *skb,
+ struct endpoint *endpoint)
{
- struct wireguard_peer *routed_peer;
struct net_device *dev = peer->device->dev;
+ struct wireguard_peer *routed_peer;
unsigned int len, len_before_trim;
socket_set_peer_endpoint(peer, endpoint);
- if (unlikely(noise_received_with_keypair(&peer->keypairs, PACKET_CB(skb)->keypair))) {
+ if (unlikely(noise_received_with_keypair(&peer->keypairs,
+ PACKET_CB(skb)->keypair))) {
timers_handshake_complete(peer);
packet_send_staged_packets(peer);
}
@@ -303,7 +364,9 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff
/* A packet with length 0 is a keepalive packet */
if (unlikely(!skb->len)) {
rx_stats(peer, message_data_len(0));
- net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr);
+ net_dbg_ratelimited("%s: Receiving keepalive packet from peer %llu (%pISpfsc)\n",
+ dev->name, peer->internal_id,
+ &peer->endpoint.addr);
goto packet_processed;
}
@@ -311,7 +374,10 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff
if (unlikely(skb_network_header(skb) < skb->head))
goto dishonest_packet_size;
- if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) && (ip_hdr(skb)->version == 4 || (ip_hdr(skb)->version == 6 && pskb_network_may_pull(skb, sizeof(struct ipv6hdr)))))))
+ if (unlikely(!(pskb_network_may_pull(skb, sizeof(struct iphdr)) &&
+ (ip_hdr(skb)->version == 4 ||
+ (ip_hdr(skb)->version == 6 &&
+ pskb_network_may_pull(skb, sizeof(struct ipv6hdr)))))))
goto dishonest_packet_type;
skb->dev = dev;
@@ -324,7 +390,8 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff
if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
IP_ECN_set_ce(ip_hdr(skb));
} else if (skb->protocol == htons(ETH_P_IPV6)) {
- len = ntohs(ipv6_hdr(skb)->payload_len) + sizeof(struct ipv6hdr);
+ len = ntohs(ipv6_hdr(skb)->payload_len) +
+ sizeof(struct ipv6hdr);
if (INET_ECN_is_ce(PACKET_CB(skb)->ds))
IP6_ECN_set_ce(skb, ipv6_hdr(skb));
} else
@@ -344,23 +411,29 @@ static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff
if (unlikely(napi_gro_receive(&peer->napi, skb) == GRO_DROP)) {
++dev->stats.rx_dropped;
- net_dbg_ratelimited("%s: Failed to give packet to userspace from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr);
+ net_dbg_ratelimited("%s: Failed to give packet to userspace from peer %llu (%pISpfsc)\n",
+ dev->name, peer->internal_id,
+ &peer->endpoint.addr);
} else
rx_stats(peer, message_data_len(len_before_trim));
return;
dishonest_packet_peer:
- net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n", dev->name, skb, peer->internal_id, &peer->endpoint.addr);
+ net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n",
+ dev->name, skb, peer->internal_id,
+ &peer->endpoint.addr);
++dev->stats.rx_errors;
++dev->stats.rx_frame_errors;
goto packet_processed;
dishonest_packet_type:
- net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr);
+ net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n",
+ dev->name, peer->internal_id, &peer->endpoint.addr);
++dev->stats.rx_errors;
++dev->stats.rx_frame_errors;
goto packet_processed;
dishonest_packet_size:
- net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr);
+ net_dbg_ratelimited("%s: Packet has incorrect size from peer %llu (%pISpfsc)\n",
+ dev->name, peer->internal_id, &peer->endpoint.addr);
++dev->stats.rx_errors;
++dev->stats.rx_length_errors;
goto packet_processed;
@@ -370,19 +443,22 @@ packet_processed:
int packet_rx_poll(struct napi_struct *napi, int budget)
{
- struct wireguard_peer *peer = container_of(napi, struct wireguard_peer, napi);
+ struct wireguard_peer *peer =
+ container_of(napi, struct wireguard_peer, napi);
struct crypt_queue *queue = &peer->rx_queue;
struct noise_keypair *keypair;
- struct sk_buff *skb;
struct endpoint endpoint;
enum packet_state state;
+ struct sk_buff *skb;
int work_done = 0;
bool free;
if (unlikely(budget <= 0))
return 0;
- while ((skb = __ptr_ring_peek(&queue->ring)) != NULL && (state = atomic_read_acquire(&PACKET_CB(skb)->state)) != PACKET_STATE_UNCRYPTED) {
+ while ((skb = __ptr_ring_peek(&queue->ring)) != NULL &&
+ (state = atomic_read_acquire(&PACKET_CB(skb)->state)) !=
+ PACKET_STATE_UNCRYPTED) {
__ptr_ring_discard_one(&queue->ring);
peer = PACKET_PEER(skb);
keypair = PACKET_CB(skb)->keypair;
@@ -391,8 +467,12 @@ int packet_rx_poll(struct napi_struct *napi, int budget)
if (unlikely(state != PACKET_STATE_CRYPTED))
goto next;
- if (unlikely(!counter_validate(&keypair->receiving.counter, PACKET_CB(skb)->nonce))) {
- net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", peer->device->dev->name, PACKET_CB(skb)->nonce, keypair->receiving.counter.receive.counter);
+ if (unlikely(!counter_validate(&keypair->receiving.counter,
+ PACKET_CB(skb)->nonce))) {
+ net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n",
+ peer->device->dev->name,
+ PACKET_CB(skb)->nonce,
+ keypair->receiving.counter.receive.counter);
goto next;
}
@@ -403,7 +483,7 @@ int packet_rx_poll(struct napi_struct *napi, int budget)
packet_consume_data_done(peer, skb, &endpoint);
free = false;
-next:
+ next:
noise_keypair_put(keypair, false);
peer_put(peer);
if (unlikely(free))
@@ -421,34 +501,46 @@ next:
void packet_decrypt_worker(struct work_struct *work)
{
- struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr;
- struct sk_buff *skb;
+ struct crypt_queue *queue =
+ container_of(work, struct multicore_worker, work)->ptr;
simd_context_t simd_context = simd_get();
+ struct sk_buff *skb;
while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) {
- enum packet_state state = likely(skb_decrypt(skb, &PACKET_CB(skb)->keypair->receiving, simd_context)) ? PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
- queue_enqueue_per_peer_napi(&PACKET_PEER(skb)->rx_queue, skb, state);
+ enum packet_state state = likely(skb_decrypt(skb,
+ &PACKET_CB(skb)->keypair->receiving,
+ simd_context)) ?
+ PACKET_STATE_CRYPTED : PACKET_STATE_DEAD;
+ queue_enqueue_per_peer_napi(&PACKET_PEER(skb)->rx_queue, skb,
+ state);
simd_context = simd_relax(simd_context);
}
simd_put(simd_context);
}
-static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb)
+static void packet_consume_data(struct wireguard_device *wg,
+ struct sk_buff *skb)
{
- struct wireguard_peer *peer = NULL;
__le32 idx = ((struct message_data *)skb->data)->key_idx;
+ struct wireguard_peer *peer = NULL;
int ret;
rcu_read_lock_bh();
- PACKET_CB(skb)->keypair = (struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx, &peer);
+ PACKET_CB(skb)->keypair =
+ (struct noise_keypair *)index_hashtable_lookup(
+ &wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx,
+ &peer);
if (unlikely(!noise_keypair_get(PACKET_CB(skb)->keypair)))
goto err_keypair;
if (unlikely(peer->is_dead))
goto err;
- ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu);
+ ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue,
+ &peer->rx_queue, skb,
+ wg->packet_crypt_wq,
+ &wg->decrypt_queue.last_cpu);
if (unlikely(ret == -EPIPE))
queue_enqueue_per_peer(&peer->rx_queue, skb, PACKET_STATE_DEAD);
if (likely(!ret || ret == -EPIPE)) {
@@ -473,14 +565,20 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
case cpu_to_le32(MESSAGE_HANDSHAKE_COOKIE): {
int cpu;
- if (skb_queue_len(&wg->incoming_handshakes) > MAX_QUEUED_INCOMING_HANDSHAKES || unlikely(!rng_is_initialized())) {
- net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n", wg->dev->name, skb);
+ if (skb_queue_len(&wg->incoming_handshakes) >
+ MAX_QUEUED_INCOMING_HANDSHAKES ||
+ unlikely(!rng_is_initialized())) {
+ net_dbg_skb_ratelimited("%s: Dropping handshake packet from %pISpfsc\n",
+ wg->dev->name, skb);
goto err;
}
skb_queue_tail(&wg->incoming_handshakes, skb);
- /* Queues up a call to packet_process_queued_handshake_packets(skb): */
+ /* Queues up a call to packet_process_queued_handshake_
+ * packets(skb):
+ */
cpu = cpumask_next_online(&wg->incoming_handshake_cpu);
- queue_work_on(cpu, wg->handshake_receive_wq, &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work);
+ queue_work_on(cpu, wg->handshake_receive_wq,
+ &per_cpu_ptr(wg->incoming_handshakes_worker, cpu)->work);
break;
}
case cpu_to_le32(MESSAGE_DATA):
@@ -488,7 +586,8 @@ void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)
packet_consume_data(wg, skb);
break;
default:
- net_dbg_skb_ratelimited("%s: Invalid packet from %pISpfsc\n", wg->dev->name, skb);
+ net_dbg_skb_ratelimited("%s: Invalid packet from %pISpfsc\n",
+ wg->dev->name, skb);
goto err;
}
return;
diff --git a/src/selftest/allowedips.h b/src/selftest/allowedips.h
index 1cf8c8f..28461d7 100644
--- a/src/selftest/allowedips.h
+++ b/src/selftest/allowedips.h
@@ -8,7 +8,8 @@
#ifdef DEBUG_PRINT_TRIE_GRAPHVIZ
#include <linux/siphash.h>
-static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, u8 cidr)
+static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits,
+ u8 cidr)
{
swap_endian(dst, src, bits);
memset(dst + (cidr + 7) / 8, 0, bits / 8 - (cidr + 7) / 8);
@@ -18,34 +19,44 @@ static __init void swap_endian_and_apply_cidr(u8 *dst, const u8 *src, u8 bits, u
static __init void print_node(struct allowedips_node *node, u8 bits)
{
- u32 color = 0;
- char *style = "dotted";
char *fmt_connection = KERN_DEBUG "\t\"%p/%d\" -> \"%p/%d\";\n";
- char *fmt_declaration = KERN_DEBUG "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
+ char *fmt_declaration = KERN_DEBUG
+ "\t\"%p/%d\"[style=%s, color=\"#%06x\"];\n";
+ char *style = "dotted";
u8 ip1[16], ip2[16];
+ u32 color = 0;
+
if (bits == 32) {
fmt_connection = KERN_DEBUG "\t\"%pI4/%d\" -> \"%pI4/%d\";\n";
- fmt_declaration = KERN_DEBUG "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
+ fmt_declaration = KERN_DEBUG
+ "\t\"%pI4/%d\"[style=%s, color=\"#%06x\"];\n";
} else if (bits == 128) {
fmt_connection = KERN_DEBUG "\t\"%pI6/%d\" -> \"%pI6/%d\";\n";
- fmt_declaration = KERN_DEBUG "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
+ fmt_declaration = KERN_DEBUG
+ "\t\"%pI6/%d\"[style=%s, color=\"#%06x\"];\n";
}
if (node->peer) {
hsiphash_key_t key = { 0 };
memcpy(&key, &node->peer, sizeof(node->peer));
- color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 | hsiphash_1u32(0xbabecafe, &key) % 200 << 8 | hsiphash_1u32(0xabad1dea, &key) % 200;
+ color = hsiphash_1u32(0xdeadbeef, &key) % 200 << 16 |
+ hsiphash_1u32(0xbabecafe, &key) % 200 << 8 |
+ hsiphash_1u32(0xabad1dea, &key) % 200;
style = "bold";
}
swap_endian_and_apply_cidr(ip1, node->bits, bits, node->cidr);
printk(fmt_declaration, ip1, node->cidr, style, color);
if (node->bit[0]) {
- swap_endian_and_apply_cidr(ip2, node->bit[0]->bits, bits, node->cidr);
- printk(fmt_connection, ip1, node->cidr, ip2, node->bit[0]->cidr);
+ swap_endian_and_apply_cidr(ip2, node->bit[0]->bits, bits,
+ node->cidr);
+ printk(fmt_connection, ip1, node->cidr, ip2,
+ node->bit[0]->cidr);
print_node(node->bit[0], bits);
}
if (node->bit[1]) {
- swap_endian_and_apply_cidr(ip2, node->bit[1]->bits, bits, node->cidr);
- printk(fmt_connection, ip1, node->cidr, ip2, node->bit[1]->cidr);
+ swap_endian_and_apply_cidr(ip2, node->bit[1]->bits, bits,
+ node->cidr);
+ printk(fmt_connection, ip1, node->cidr, ip2,
+ node->bit[1]->cidr);
print_node(node->bit[1], bits);
}
}
@@ -79,9 +90,10 @@ static __init void horrible_allowedips_init(struct horrible_allowedips *table)
}
static __init void horrible_allowedips_free(struct horrible_allowedips *table)
{
- struct hlist_node *h;
struct horrible_allowedips_node *node;
- hlist_for_each_entry_safe(node, h, &table->head, table) {
+ struct hlist_node *h;
+
+ hlist_for_each_entry_safe (node, h, &table->head, table) {
hlist_del(&node->table);
kfree(node);
}
@@ -89,20 +101,21 @@ static __init void horrible_allowedips_free(struct horrible_allowedips *table)
static __init inline union nf_inet_addr horrible_cidr_to_mask(uint8_t cidr)
{
union nf_inet_addr mask;
+
memset(&mask, 0x00, 128 / 8);
memset(&mask, 0xff, cidr / 8);
if (cidr % 32)
- mask.all[cidr / 32] = htonl((0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
+ mask.all[cidr / 32] = htonl(
+ (0xFFFFFFFFUL << (32 - (cidr % 32))) & 0xFFFFFFFFUL);
return mask;
}
static __init inline uint8_t horrible_mask_to_cidr(union nf_inet_addr subnet)
{
- return hweight32(subnet.all[0])
- + hweight32(subnet.all[1])
- + hweight32(subnet.all[2])
- + hweight32(subnet.all[3]);
+ return hweight32(subnet.all[0]) + hweight32(subnet.all[1]) +
+ hweight32(subnet.all[2]) + hweight32(subnet.all[3]);
}
-static __init inline void horrible_mask_self(struct horrible_allowedips_node *node)
+static __init inline void
+horrible_mask_self(struct horrible_allowedips_node *node)
{
if (node->ip_version == 4)
node->ip.ip &= node->mask.ip;
@@ -113,24 +126,36 @@ static __init inline void horrible_mask_self(struct horrible_allowedips_node *no
node->ip.ip6[3] &= node->mask.ip6[3];
}
}
-static __init inline bool horrible_match_v4(const struct horrible_allowedips_node *node, struct in_addr *ip)
+static __init inline bool
+horrible_match_v4(const struct horrible_allowedips_node *node,
+ struct in_addr *ip)
{
return (ip->s_addr & node->mask.ip) == node->ip.ip;
}
-static __init inline bool horrible_match_v6(const struct horrible_allowedips_node *node, struct in6_addr *ip)
+static __init inline bool
+horrible_match_v6(const struct horrible_allowedips_node *node,
+ struct in6_addr *ip)
{
- return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) == node->ip.ip6[0] &&
- (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) == node->ip.ip6[1] &&
- (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) == node->ip.ip6[2] &&
- (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
+ return (ip->in6_u.u6_addr32[0] & node->mask.ip6[0]) ==
+ node->ip.ip6[0] &&
+ (ip->in6_u.u6_addr32[1] & node->mask.ip6[1]) ==
+ node->ip.ip6[1] &&
+ (ip->in6_u.u6_addr32[2] & node->mask.ip6[2]) ==
+ node->ip.ip6[2] &&
+ (ip->in6_u.u6_addr32[3] & node->mask.ip6[3]) == node->ip.ip6[3];
}
-static __init void horrible_insert_ordered(struct horrible_allowedips *table, struct horrible_allowedips_node *node)
+static __init void
+horrible_insert_ordered(struct horrible_allowedips *table,
+ struct horrible_allowedips_node *node)
{
struct horrible_allowedips_node *other = NULL, *where = NULL;
uint8_t my_cidr = horrible_mask_to_cidr(node->mask);
- hlist_for_each_entry(other, &table->head, table) {
- if (!memcmp(&other->mask, &node->mask, sizeof(union nf_inet_addr)) &&
- !memcmp(&other->ip, &node->ip, sizeof(union nf_inet_addr)) &&
+
+ hlist_for_each_entry (other, &table->head, table) {
+ if (!memcmp(&other->mask, &node->mask,
+ sizeof(union nf_inet_addr)) &&
+ !memcmp(&other->ip, &node->ip,
+ sizeof(union nf_inet_addr)) &&
other->ip_version == node->ip_version) {
other->value = node->value;
kfree(node);
@@ -147,9 +172,13 @@ static __init void horrible_insert_ordered(struct horrible_allowedips *table, st
else
hlist_add_before(&node->table, &where->table);
}
-static __init int horrible_allowedips_insert_v4(struct horrible_allowedips *table, struct in_addr *ip, uint8_t cidr, void *value)
+static __init int
+horrible_allowedips_insert_v4(struct horrible_allowedips *table,
+ struct in_addr *ip, uint8_t cidr, void *value)
{
- struct horrible_allowedips_node *node = kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL);
+ struct horrible_allowedips_node *node =
+ kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL);
+
if (!node)
return -ENOMEM;
node->ip.in = *ip;
@@ -160,9 +189,13 @@ static __init int horrible_allowedips_insert_v4(struct horrible_allowedips *tabl
horrible_insert_ordered(table, node);
return 0;
}
-static __init int horrible_allowedips_insert_v6(struct horrible_allowedips *table, struct in6_addr *ip, uint8_t cidr, void *value)
+static __init int
+horrible_allowedips_insert_v6(struct horrible_allowedips *table,
+ struct in6_addr *ip, uint8_t cidr, void *value)
{
- struct horrible_allowedips_node *node = kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL);
+ struct horrible_allowedips_node *node =
+ kzalloc(sizeof(struct horrible_allowedips_node), GFP_KERNEL);
+
if (!node)
return -ENOMEM;
node->ip.in6 = *ip;
@@ -173,11 +206,14 @@ static __init int horrible_allowedips_insert_v6(struct horrible_allowedips *tabl
horrible_insert_ordered(table, node);
return 0;
}
-static __init void *horrible_allowedips_lookup_v4(struct horrible_allowedips *table, struct in_addr *ip)
+static __init void *
+horrible_allowedips_lookup_v4(struct horrible_allowedips *table,
+ struct in_addr *ip)
{
struct horrible_allowedips_node *node;
void *ret = NULL;
- hlist_for_each_entry(node, &table->head, table) {
+
+ hlist_for_each_entry (node, &table->head, table) {
if (node->ip_version != 4)
continue;
if (horrible_match_v4(node, ip)) {
@@ -187,11 +223,14 @@ static __init void *horrible_allowedips_lookup_v4(struct horrible_allowedips *ta
}
return ret;
}
-static __init void *horrible_allowedips_lookup_v6(struct horrible_allowedips *table, struct in6_addr *ip)
+static __init void *
+horrible_allowedips_lookup_v6(struct horrible_allowedips *table,
+ struct in6_addr *ip)
{
struct horrible_allowedips_node *node;
void *ret = NULL;
- hlist_for_each_entry(node, &table->head, table) {
+
+ hlist_for_each_entry (node, &table->head, table) {
if (node->ip_version != 6)
continue;
if (horrible_match_v6(node, ip)) {
@@ -204,13 +243,13 @@ static __init void *horrible_allowedips_lookup_v6(struct horrible_allowedips *ta
static __init bool randomized_test(void)
{
- DEFINE_MUTEX(mutex);
- bool ret = false;
unsigned int i, j, k, mutate_amount, cidr;
+ u8 ip[16], mutate_mask[16], mutated[16];
struct wireguard_peer **peers, *peer;
- struct allowedips t;
struct horrible_allowedips h;
- u8 ip[16], mutate_mask[16], mutated[16];
+ DEFINE_MUTEX(mutex);
+ struct allowedips t;
+ bool ret = false;
mutex_init(&mutex);
@@ -237,11 +276,13 @@ static __init bool randomized_test(void)
prandom_bytes(ip, 4);
cidr = prandom_u32_max(32) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
- if (allowedips_insert_v4(&t, (struct in_addr *)ip, cidr, peer, &mutex) < 0) {
+ if (allowedips_insert_v4(&t, (struct in_addr *)ip, cidr, peer,
+ &mutex) < 0) {
pr_info("allowedips random self-test: out of memory\n");
goto free;
}
- if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip, cidr, peer) < 0) {
+ if (horrible_allowedips_insert_v4(&h, (struct in_addr *)ip,
+ cidr, peer) < 0) {
pr_info("allowedips random self-test: out of memory\n");
goto free;
}
@@ -251,18 +292,23 @@ static __init bool randomized_test(void)
mutate_amount = prandom_u32_max(32);
for (k = 0; k < mutate_amount / 8; ++k)
mutate_mask[k] = 0xff;
- mutate_mask[k] = 0xff << ((8 - (mutate_amount % 8)) % 8);
+ mutate_mask[k] = 0xff
+ << ((8 - (mutate_amount % 8)) % 8);
for (; k < 4; ++k)
mutate_mask[k] = 0;
for (k = 0; k < 4; ++k)
- mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & prandom_u32_max(256));
+ mutated[k] = (mutated[k] & mutate_mask[k]) |
+ (~mutate_mask[k] &
+ prandom_u32_max(256));
cidr = prandom_u32_max(32) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
- if (allowedips_insert_v4(&t, (struct in_addr *)mutated, cidr, peer, &mutex) < 0) {
+ if (allowedips_insert_v4(&t, (struct in_addr *)mutated,
+ cidr, peer, &mutex) < 0) {
pr_info("allowedips random self-test: out of memory\n");
goto free;
}
- if (horrible_allowedips_insert_v4(&h, (struct in_addr *)mutated, cidr, peer)) {
+ if (horrible_allowedips_insert_v4(&h,
+ (struct in_addr *)mutated, cidr, peer)) {
pr_info("allowedips random self-test: out of memory\n");
goto free;
}
@@ -273,11 +319,13 @@ static __init bool randomized_test(void)
prandom_bytes(ip, 16);
cidr = prandom_u32_max(128) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
- if (allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr, peer, &mutex) < 0) {
+ if (allowedips_insert_v6(&t, (struct in6_addr *)ip, cidr, peer,
+ &mutex) < 0) {
pr_info("allowedips random self-test: out of memory\n");
goto free;
}
- if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip, cidr, peer) < 0) {
+ if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)ip,
+ cidr, peer) < 0) {
pr_info("allowedips random self-test: out of memory\n");
goto free;
}
@@ -287,18 +335,24 @@ static __init bool randomized_test(void)
mutate_amount = prandom_u32_max(128);
for (k = 0; k < mutate_amount / 8; ++k)
mutate_mask[k] = 0xff;
- mutate_mask[k] = 0xff << ((8 - (mutate_amount % 8)) % 8);
+ mutate_mask[k] = 0xff
+ << ((8 - (mutate_amount % 8)) % 8);
for (; k < 4; ++k)
mutate_mask[k] = 0;
for (k = 0; k < 4; ++k)
- mutated[k] = (mutated[k] & mutate_mask[k]) | (~mutate_mask[k] & prandom_u32_max(256));
+ mutated[k] = (mutated[k] & mutate_mask[k]) |
+ (~mutate_mask[k] &
+ prandom_u32_max(256));
cidr = prandom_u32_max(128) + 1;
peer = peers[prandom_u32_max(NUM_PEERS)];
- if (allowedips_insert_v6(&t, (struct in6_addr *)mutated, cidr, peer, &mutex) < 0) {
+ if (allowedips_insert_v6(&t, (struct in6_addr *)mutated,
+ cidr, peer, &mutex) < 0) {
pr_info("allowedips random self-test: out of memory\n");
goto free;
}
- if (horrible_allowedips_insert_v6(&h, (struct in6_addr *)mutated, cidr, peer)) {
+ if (horrible_allowedips_insert_v6(
+ &h, (struct in6_addr *)mutated, cidr,
+ peer)) {
pr_info("allowedips random self-test: out of memory\n");
goto free;
}
@@ -314,7 +368,8 @@ static __init bool randomized_test(void)
for (i = 0; i < NUM_QUERIES; ++i) {
prandom_bytes(ip, 4);
- if (lookup(t.root4, 32, ip) != horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
+ if (lookup(t.root4, 32, ip) !=
+ horrible_allowedips_lookup_v4(&h, (struct in_addr *)ip)) {
pr_info("allowedips random self-test: FAIL\n");
goto free;
}
@@ -322,7 +377,8 @@ static __init bool randomized_test(void)
for (i = 0; i < NUM_QUERIES; ++i) {
prandom_bytes(ip, 16);
- if (lookup(t.root6, 128, ip) != horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
+ if (lookup(t.root6, 128, ip) !=
+ horrible_allowedips_lookup_v6(&h, (struct in6_addr *)ip)) {
pr_info("allowedips random self-test: FAIL\n");
goto free;
}
@@ -376,15 +432,22 @@ static __init int walk_callback(void *ctx, const u8 *ip, u8 cidr, int family)
wctx->count++;
- if (cidr == 27 && !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr)))
+ if (cidr == 27 &&
+ !memcmp(ip, ip4(192, 95, 5, 64), sizeof(struct in_addr)))
wctx->found_a = true;
- else if (cidr == 128 && !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543), sizeof(struct in6_addr)))
+ else if (cidr == 128 &&
+ !memcmp(ip, ip6(0x26075300, 0x60006b00, 0, 0xc05f0543),
+ sizeof(struct in6_addr)))
wctx->found_b = true;
- else if (cidr == 29 && !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr)))
+ else if (cidr == 29 &&
+ !memcmp(ip, ip4(10, 1, 0, 16), sizeof(struct in_addr)))
wctx->found_c = true;
- else if (cidr == 83 && !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0), sizeof(struct in6_addr)))
+ else if (cidr == 83 &&
+ !memcmp(ip, ip6(0x26075300, 0x6d8a6bf8, 0xdab1e000, 0),
+ sizeof(struct in6_addr)))
wctx->found_d = true;
- else if (cidr == 21 && !memcmp(ip, ip6(0x26075000, 0, 0, 0), sizeof(struct in6_addr)))
+ else if (cidr == 21 &&
+ !memcmp(ip, ip6(0x26075000, 0, 0, 0), sizeof(struct in6_addr)))
wctx->found_e = true;
else
wctx->found_other = true;
@@ -392,50 +455,55 @@ static __init int walk_callback(void *ctx, const u8 *ip, u8 cidr, int family)
return 0;
}
-#define init_peer(name) do { \
- name = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL); \
- if (!name) { \
- pr_info("allowedips self-test: out of memory\n"); \
- goto free; \
- } \
- kref_init(&name->refcount); \
-} while (0)
-
-#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \
- allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), cidr, mem, &mutex)
-
-#define maybe_fail \
- ++i; \
- if (!_s) { \
- pr_info("allowedips self-test %zu: FAIL\n", i); \
- success = false; \
- }
-
-#define test(version, mem, ipa, ipb, ipc, ipd) do { \
- bool _s = lookup(t.root##version, version == 4 ? 32 : 128, ip##version(ipa, ipb, ipc, ipd)) == mem; \
- maybe_fail \
-} while (0)
-
-#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \
- bool _s = lookup(t.root##version, version == 4 ? 32 : 128, ip##version(ipa, ipb, ipc, ipd)) != mem; \
- maybe_fail \
-} while (0)
-
-#define test_boolean(cond) do { \
- bool _s = (cond); \
- maybe_fail \
-} while (0)
+#define init_peer(name) do { \
+ name = kzalloc(sizeof(struct wireguard_peer), GFP_KERNEL); \
+ if (!name) { \
+ pr_info("allowedips self-test: out of memory\n"); \
+ goto free; \
+ } \
+ kref_init(&name->refcount); \
+ } while (0)
+
+#define insert(version, mem, ipa, ipb, ipc, ipd, cidr) \
+ allowedips_insert_v##version(&t, ip##version(ipa, ipb, ipc, ipd), \
+ cidr, mem, &mutex)
+
+#define maybe_fail() do { \
+ ++i; \
+ if (!_s) { \
+ pr_info("allowedips self-test %zu: FAIL\n", i); \
+ success = false; \
+ } \
+ } while (0)
+
+#define test(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = lookup(t.root##version, version == 4 ? 32 : 128, \
+ ip##version(ipa, ipb, ipc, ipd)) == mem; \
+ maybe_fail(); \
+ } while (0)
+
+#define test_negative(version, mem, ipa, ipb, ipc, ipd) do { \
+ bool _s = lookup(t.root##version, version == 4 ? 32 : 128, \
+ ip##version(ipa, ipb, ipc, ipd)) != mem; \
+ maybe_fail(); \
+ } while (0)
+
+#define test_boolean(cond) do { \
+ bool _s = (cond); \
+ maybe_fail(); \
+ } while (0)
bool __init allowedips_selftest(void)
{
- DEFINE_MUTEX(mutex);
- struct allowedips t;
- struct walk_ctx wctx = { 0 };
+ struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL,
+ *e = NULL, *f = NULL, *g = NULL, *h = NULL;
struct allowedips_cursor cursor = { 0 };
- struct wireguard_peer *a = NULL, *b = NULL, *c = NULL, *d = NULL, *e = NULL, *f = NULL, *g = NULL, *h = NULL;
- size_t i = 0;
+ struct walk_ctx wctx = { 0 };
bool success = false;
+ struct allowedips t;
+ DEFINE_MUTEX(mutex);
struct in6_addr ip;
+ size_t i = 0;
__be64 part;
mutex_init(&mutex);
@@ -455,19 +523,23 @@ bool __init allowedips_selftest(void)
insert(4, b, 192, 168, 4, 4, 32);
insert(4, c, 192, 168, 0, 0, 16);
insert(4, d, 192, 95, 5, 64, 27);
- insert(4, c, 192, 95, 5, 65, 27); /* replaces previous entry, and maskself is required */
+ /* replaces previous entry, and maskself is required */
+ insert(4, c, 192, 95, 5, 65, 27);
insert(6, d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128);
insert(6, c, 0x26075300, 0x60006b00, 0, 0, 64);
insert(4, e, 0, 0, 0, 0, 0);
insert(6, e, 0, 0, 0, 0, 0);
- insert(6, f, 0, 0, 0, 0, 0); /* replaces previous entry */
+ /* replaces previous entry */
+ insert(6, f, 0, 0, 0, 0, 0);
insert(6, g, 0x24046800, 0, 0, 0, 32);
- insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64); /* maskself is required */
+ /* maskself is required */
+ insert(6, h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64);
insert(6, a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128);
insert(6, c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128);
insert(6, b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98);
insert(4, g, 64, 15, 112, 0, 20);
- insert(4, h, 64, 15, 123, 211, 25); /* maskself is required */
+ /* maskself is required */
+ insert(4, h, 64, 15, 123, 211, 25);
insert(4, a, 10, 0, 0, 0, 25);
insert(4, b, 10, 0, 0, 128, 25);
insert(4, a, 10, 1, 0, 0, 30);
diff --git a/src/selftest/counter.h b/src/selftest/counter.h
index 5344075..1c2a3b4 100644
--- a/src/selftest/counter.h
+++ b/src/selftest/counter.h
@@ -6,13 +6,24 @@
#ifdef DEBUG
bool __init packet_counter_selftest(void)
{
- bool success = true;
unsigned int test_num = 0, i;
union noise_counter counter;
+ bool success = true;
-#define T_INIT do { memset(&counter, 0, sizeof(union noise_counter)); spin_lock_init(&counter.receive.lock); } while (0)
+#define T_INIT do { \
+ memset(&counter, 0, sizeof(union noise_counter)); \
+ spin_lock_init(&counter.receive.lock); \
+ } while (0)
#define T_LIM (COUNTER_WINDOW_SIZE + 1)
-#define T(n, v) do { ++test_num; if (counter_validate(&counter, n) != v) { pr_info("nonce counter self-test %u: FAIL\n", test_num); success = false; } } while (0)
+#define T(n, v) do { \
+ ++test_num; \
+ if (counter_validate(&counter, n) != v) { \
+ pr_info("nonce counter self-test %u: FAIL\n", \
+ test_num); \
+ success = false; \
+ } \
+ } while (0)
+
T_INIT;
/* 1 */ T(0, true);
/* 2 */ T(1, true);
@@ -62,22 +73,22 @@ bool __init packet_counter_selftest(void)
T(0, false);
T_INIT;
- for (i = COUNTER_WINDOW_SIZE + 1; i-- > 0 ;)
+ for (i = COUNTER_WINDOW_SIZE + 1; i-- > 0;)
T(i, true);
T_INIT;
- for (i = COUNTER_WINDOW_SIZE + 2; i-- > 1 ;)
+ for (i = COUNTER_WINDOW_SIZE + 2; i-- > 1;)
T(i, true);
T(0, false);
T_INIT;
- for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1 ;)
+ for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1;)
T(i, true);
T(COUNTER_WINDOW_SIZE + 1, true);
T(0, false);
T_INIT;
- for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1 ;)
+ for (i = COUNTER_WINDOW_SIZE + 1; i-- > 1;)
T(i, true);
T(0, true);
T(COUNTER_WINDOW_SIZE + 1, true);
diff --git a/src/selftest/ratelimiter.h b/src/selftest/ratelimiter.h
index c05eac7..a71ddb1 100644
--- a/src/selftest/ratelimiter.h
+++ b/src/selftest/ratelimiter.h
@@ -7,7 +7,10 @@
#include <linux/jiffies.h>
-static const struct { bool result; unsigned int msec_to_sleep_before; } expected_results[] __initconst = {
+static const struct {
+ bool result;
+ unsigned int msec_to_sleep_before;
+} expected_results[] __initconst = {
[0 ... PACKETS_BURSTABLE - 1] = { true, 0 },
[PACKETS_BURSTABLE] = { false, 0 },
[PACKETS_BURSTABLE + 1] = { true, MSEC_PER_SEC / PACKETS_PER_SECOND },
@@ -29,14 +32,14 @@ static __init unsigned int maximum_jiffies_at_index(int index)
bool __init ratelimiter_selftest(void)
{
- struct sk_buff *skb4;
- struct iphdr *hdr4;
+ int i, test = 0, tries = 0, ret = false;
+ unsigned long loop_start_time;
#if IS_ENABLED(CONFIG_IPV6)
struct sk_buff *skb6;
struct ipv6hdr *hdr6;
#endif
- int i, test = 0, tries = 0, ret = false;
- unsigned long loop_start_time;
+ struct sk_buff *skb4;
+ struct iphdr *hdr4;
BUILD_BUG_ON(MSEC_PER_SEC % PACKETS_PER_SECOND != 0);
@@ -81,21 +84,24 @@ bool __init ratelimiter_selftest(void)
restart:
loop_start_time = jiffies;
for (i = 0; i < ARRAY_SIZE(expected_results); ++i) {
-#define ensure_time do {\
- if (time_is_before_jiffies(loop_start_time + maximum_jiffies_at_index(i))) { \
- if (++tries >= 5000) \
- goto err; \
- gc_entries(NULL); \
- rcu_barrier(); \
- msleep(500); \
- goto restart; \
- }} while (0)
+#define ensure_time do { \
+ if (time_is_before_jiffies(loop_start_time + \
+ maximum_jiffies_at_index(i))) { \
+ if (++tries >= 5000) \
+ goto err; \
+ gc_entries(NULL); \
+ rcu_barrier(); \
+ msleep(500); \
+ goto restart; \
+ } \
+ } while (0)
if (expected_results[i].msec_to_sleep_before)
msleep(expected_results[i].msec_to_sleep_before);
ensure_time;
- if (ratelimiter_allow(skb4, &init_net) != expected_results[i].result)
+ if (ratelimiter_allow(skb4, &init_net) !=
+ expected_results[i].result)
goto err;
++test;
hdr4->saddr = htonl(ntohl(hdr4->saddr) + i + 1);
@@ -106,17 +112,21 @@ restart:
hdr4->saddr = htonl(ntohl(hdr4->saddr) - i - 1);
#if IS_ENABLED(CONFIG_IPV6)
- hdr6->saddr.in6_u.u6_addr32[2] = hdr6->saddr.in6_u.u6_addr32[3] = htonl(i);
+ hdr6->saddr.in6_u.u6_addr32[2] =
+ hdr6->saddr.in6_u.u6_addr32[3] = htonl(i);
ensure_time;
- if (ratelimiter_allow(skb6, &init_net) != expected_results[i].result)
+ if (ratelimiter_allow(skb6, &init_net) !=
+ expected_results[i].result)
goto err;
++test;
- hdr6->saddr.in6_u.u6_addr32[0] = htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) + i + 1);
+ hdr6->saddr.in6_u.u6_addr32[0] =
+ htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) + i + 1);
ensure_time;
if (!ratelimiter_allow(skb6, &init_net))
goto err;
++test;
- hdr6->saddr.in6_u.u6_addr32[0] = htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) - i - 1);
+ hdr6->saddr.in6_u.u6_addr32[0] =
+ htonl(ntohl(hdr6->saddr.in6_u.u6_addr32[0]) - i - 1);
ensure_time;
#endif
}
diff --git a/src/send.c b/src/send.c
index 3af7ef3..ec0074e 100644
--- a/src/send.c
+++ b/src/send.c
@@ -23,46 +23,63 @@ static void packet_send_handshake_initiation(struct wireguard_peer *peer)
{
struct message_handshake_initiation packet;
- if (!has_expired(atomic64_read(&peer->last_sent_handshake), REKEY_TIMEOUT))
+ if (!has_expired(atomic64_read(&peer->last_sent_handshake),
+ REKEY_TIMEOUT))
return; /* This function is rate limited. */
atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns());
- net_dbg_ratelimited("%s: Sending handshake initiation to peer %llu (%pISpfsc)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr);
+ net_dbg_ratelimited("%s: Sending handshake initiation to peer %llu (%pISpfsc)\n",
+ peer->device->dev->name, peer->internal_id,
+ &peer->endpoint.addr);
if (noise_handshake_create_initiation(&packet, &peer->handshake)) {
cookie_add_mac_to_packet(&packet, sizeof(packet), peer);
timers_any_authenticated_packet_traversal(peer);
timers_any_authenticated_packet_sent(peer);
- atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns());
- socket_send_buffer_to_peer(peer, &packet, sizeof(struct message_handshake_initiation), HANDSHAKE_DSCP);
+ atomic64_set(&peer->last_sent_handshake,
+ ktime_get_boot_fast_ns());
+ socket_send_buffer_to_peer(
+ peer, &packet,
+ sizeof(struct message_handshake_initiation),
+ HANDSHAKE_DSCP);
timers_handshake_initiated(peer);
}
}
void packet_handshake_send_worker(struct work_struct *work)
{
- struct wireguard_peer *peer = container_of(work, struct wireguard_peer, transmit_handshake_work);
+ struct wireguard_peer *peer = container_of(work, struct wireguard_peer,
+ transmit_handshake_work);
packet_send_handshake_initiation(peer);
peer_put(peer);
}
-void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool is_retry)
+void packet_send_queued_handshake_initiation(struct wireguard_peer *peer,
+ bool is_retry)
{
if (!is_retry)
peer->timer_handshake_attempts = 0;
rcu_read_lock_bh();
- /* We check last_sent_handshake here in addition to the actual function we're queueing
- * up, so that we don't queue things if not strictly necessary.
+ /* We check last_sent_handshake here in addition to the actual function
+ * we're queueing up, so that we don't queue things if not strictly
+ * necessary:
*/
- if (!has_expired(atomic64_read(&peer->last_sent_handshake), REKEY_TIMEOUT) || unlikely(peer->is_dead))
+ if (!has_expired(atomic64_read(&peer->last_sent_handshake),
+ REKEY_TIMEOUT) || unlikely(peer->is_dead))
goto out;
peer_get(peer);
- /* Queues up calling packet_send_queued_handshakes(peer), where we do a peer_put(peer) after: */
- if (!queue_work(peer->device->handshake_send_wq, &peer->transmit_handshake_work))
- peer_put(peer); /* If the work was already queued, we want to drop the extra reference */
+ /* Queues up calling packet_send_queued_handshakes(peer), where we do a
+ * peer_put(peer) after:
+ */
+ if (!queue_work(peer->device->handshake_send_wq,
+ &peer->transmit_handshake_work))
+ /* If the work was already queued, we want to drop the
+ * extra reference:
+ */
+ peer_put(peer);
out:
rcu_read_unlock_bh();
}
@@ -72,27 +89,39 @@ void packet_send_handshake_response(struct wireguard_peer *peer)
struct message_handshake_response packet;
atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns());
- net_dbg_ratelimited("%s: Sending handshake response to peer %llu (%pISpfsc)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr);
+ net_dbg_ratelimited("%s: Sending handshake response to peer %llu (%pISpfsc)\n",
+ peer->device->dev->name, peer->internal_id,
+ &peer->endpoint.addr);
if (noise_handshake_create_response(&packet, &peer->handshake)) {
cookie_add_mac_to_packet(&packet, sizeof(packet), peer);
- if (noise_handshake_begin_session(&peer->handshake, &peer->keypairs)) {
+ if (noise_handshake_begin_session(&peer->handshake,
+ &peer->keypairs)) {
timers_session_derived(peer);
timers_any_authenticated_packet_traversal(peer);
timers_any_authenticated_packet_sent(peer);
- atomic64_set(&peer->last_sent_handshake, ktime_get_boot_fast_ns());
- socket_send_buffer_to_peer(peer, &packet, sizeof(struct message_handshake_response), HANDSHAKE_DSCP);
+ atomic64_set(&peer->last_sent_handshake,
+ ktime_get_boot_fast_ns());
+ socket_send_buffer_to_peer(
+ peer, &packet,
+ sizeof(struct message_handshake_response),
+ HANDSHAKE_DSCP);
}
}
}
-void packet_send_handshake_cookie(struct wireguard_device *wg, struct sk_buff *initiating_skb, __le32 sender_index)
+void packet_send_handshake_cookie(struct wireguard_device *wg,
+ struct sk_buff *initiating_skb,
+ __le32 sender_index)
{
struct message_handshake_cookie packet;
- net_dbg_skb_ratelimited("%s: Sending cookie response for denied handshake message for %pISpfsc\n", wg->dev->name, initiating_skb);
- cookie_message_create(&packet, initiating_skb, sender_index, &wg->cookie_checker);
- socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet, sizeof(packet));
+ net_dbg_skb_ratelimited("%s: Sending cookie response for denied handshake message for %pISpfsc\n",
+ wg->dev->name, initiating_skb);
+ cookie_message_create(&packet, initiating_skb, sender_index,
+ &wg->cookie_checker);
+ socket_send_buffer_as_reply_to_skb(wg, initiating_skb, &packet,
+ sizeof(packet));
}
static inline void keep_key_fresh(struct wireguard_peer *peer)
@@ -103,8 +132,11 @@ static inline void keep_key_fresh(struct wireguard_peer *peer)
rcu_read_lock_bh();
keypair = rcu_dereference_bh(peer->keypairs.current_keypair);
if (likely(keypair && keypair->sending.is_valid) &&
- (unlikely(atomic64_read(&keypair->sending.counter.counter) > REKEY_AFTER_MESSAGES) ||
- (keypair->i_am_the_initiator && unlikely(has_expired(keypair->sending.birthdate, REKEY_AFTER_TIME)))))
+ (unlikely(atomic64_read(&keypair->sending.counter.counter) >
+ REKEY_AFTER_MESSAGES) ||
+ (keypair->i_am_the_initiator &&
+ unlikely(has_expired(keypair->sending.birthdate,
+ REKEY_AFTER_TIME)))))
send = true;
rcu_read_unlock_bh();
@@ -114,9 +146,10 @@ static inline void keep_key_fresh(struct wireguard_peer *peer)
static inline unsigned int skb_padding(struct sk_buff *skb)
{
- /* We do this modulo business with the MTU, just in case the networking layer
- * gives us a packet that's bigger than the MTU. In that case, we wouldn't want
- * the final subtraction to overflow in the case of the padded_size being clamped.
+ /* We do this modulo business with the MTU, just in case the networking
+ * layer gives us a packet that's bigger than the MTU. In that case, we
+ * wouldn't want the final subtraction to overflow in the case of the
+ * padded_size being clamped.
*/
unsigned int last_unit = skb->len % PACKET_CB(skb)->mtu;
unsigned int padded_size = ALIGN(last_unit, MESSAGE_PADDING_MULTIPLE);
@@ -126,38 +159,49 @@ static inline unsigned int skb_padding(struct sk_buff *skb)
return padded_size - last_unit;
}
-static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypair, simd_context_t simd_context)
+static inline bool skb_encrypt(struct sk_buff *skb,
+ struct noise_keypair *keypair,
+ simd_context_t simd_context)
{
+ unsigned int padding_len, plaintext_len, trailer_len;
struct scatterlist sg[MAX_SKB_FRAGS * 2 + 1];
struct message_data *header;
- unsigned int padding_len, plaintext_len, trailer_len;
- int num_frags;
struct sk_buff *trailer;
+ int num_frags;
- /* Calculate lengths */
+ /* Calculate lengths. */
padding_len = skb_padding(skb);
trailer_len = padding_len + noise_encrypted_len(0);
plaintext_len = skb->len + padding_len;
- /* Expand data section to have room for padding and auth tag */
+ /* Expand data section to have room for padding and auth tag. */
num_frags = skb_cow_data(skb, trailer_len, &trailer);
if (unlikely(num_frags < 0 || num_frags > ARRAY_SIZE(sg)))
return false;
- /* Set the padding to zeros, and make sure it and the auth tag are part of the skb */
+ /* Set the padding to zeros, and make sure it and the auth tag are part
+ * of the skb.
+ */
memset(skb_tail_pointer(trailer), 0, padding_len);
- /* Expand head section to have room for our header and the network stack's headers. */
+ /* Expand head section to have room for our header and the network
+ * stack's headers.
+ */
if (unlikely(skb_cow_head(skb, DATA_PACKET_HEAD_ROOM) < 0))
return false;
- /* We have to remember to add the checksum to the innerpacket, in case the receiver forwards it. */
+ /* We have to remember to add the checksum to the innerpacket, in case
+ * the receiver forwards it.
+ */
if (likely(!skb_checksum_setup(skb, true)))
skb_checksum_help(skb);
- /* Only after checksumming can we safely add on the padding at the end and the header. */
+ /* Only after checksumming can we safely add on the padding at the end
+ * and the header.
+ */
skb_set_inner_network_header(skb, 0);
- header = (struct message_data *)skb_push(skb, sizeof(struct message_data));
+ header = (struct message_data *)skb_push(skb,
+ sizeof(struct message_data));
header->header.type = cpu_to_le32(MESSAGE_DATA);
header->key_idx = keypair->remote_index;
header->counter = cpu_to_le64(PACKET_CB(skb)->nonce);
@@ -165,9 +209,12 @@ static inline bool skb_encrypt(struct sk_buff *skb, struct noise_keypair *keypai
/* Now we can encrypt the scattergather segments */
sg_init_table(sg, num_frags);
- if (skb_to_sgvec(skb, sg, sizeof(struct message_data), noise_encrypted_len(plaintext_len)) <= 0)
+ if (skb_to_sgvec(skb, sg, sizeof(struct message_data),
+ noise_encrypted_len(plaintext_len)) <= 0)
return false;
- return chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0, PACKET_CB(skb)->nonce, keypair->sending.key, simd_context);
+ return chacha20poly1305_encrypt_sg(sg, sg, plaintext_len, NULL, 0,
+ PACKET_CB(skb)->nonce,
+ keypair->sending.key, simd_context);
}
void packet_send_keepalive(struct wireguard_peer *peer)
@@ -175,38 +222,45 @@ void packet_send_keepalive(struct wireguard_peer *peer)
struct sk_buff *skb;
if (skb_queue_empty(&peer->staged_packet_queue)) {
- skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH, GFP_ATOMIC);
+ skb = alloc_skb(DATA_PACKET_HEAD_ROOM + MESSAGE_MINIMUM_LENGTH,
+ GFP_ATOMIC);
if (unlikely(!skb))
return;
skb_reserve(skb, DATA_PACKET_HEAD_ROOM);
skb->dev = peer->device->dev;
PACKET_CB(skb)->mtu = skb->dev->mtu;
skb_queue_tail(&peer->staged_packet_queue, skb);
- net_dbg_ratelimited("%s: Sending keepalive packet to peer %llu (%pISpfsc)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr);
+ net_dbg_ratelimited("%s: Sending keepalive packet to peer %llu (%pISpfsc)\n",
+ peer->device->dev->name, peer->internal_id,
+ &peer->endpoint.addr);
}
packet_send_staged_packets(peer);
}
-#define skb_walk_null_queue_safe(first, skb, next) for (skb = first, next = skb->next; skb; skb = next, next = skb ? skb->next : NULL)
+#define skb_walk_null_queue_safe(first, skb, next) \
+ for (skb = first, next = skb->next; skb; \
+ skb = next, next = skb ? skb->next : NULL)
static inline void skb_free_null_queue(struct sk_buff *first)
{
struct sk_buff *skb, *next;
- skb_walk_null_queue_safe(first, skb, next)
+ skb_walk_null_queue_safe (first, skb, next)
dev_kfree_skb(skb);
}
-static void packet_create_data_done(struct sk_buff *first, struct wireguard_peer *peer)
+static void packet_create_data_done(struct sk_buff *first,
+ struct wireguard_peer *peer)
{
struct sk_buff *skb, *next;
bool is_keepalive, data_sent = false;
timers_any_authenticated_packet_traversal(peer);
timers_any_authenticated_packet_sent(peer);
- skb_walk_null_queue_safe(first, skb, next) {
+ skb_walk_null_queue_safe (first, skb, next) {
is_keepalive = skb->len == message_data_len(0);
- if (likely(!socket_send_skb_to_peer(peer, skb, PACKET_CB(skb)->ds) && !is_keepalive))
+ if (likely(!socket_send_skb_to_peer(peer, skb,
+ PACKET_CB(skb)->ds) && !is_keepalive))
data_sent = true;
}
@@ -218,13 +272,16 @@ static void packet_create_data_done(struct sk_buff *first, struct wireguard_peer
void packet_tx_worker(struct work_struct *work)
{
- struct crypt_queue *queue = container_of(work, struct crypt_queue, work);
+ struct crypt_queue *queue =
+ container_of(work, struct crypt_queue, work);
struct wireguard_peer *peer;
struct noise_keypair *keypair;
struct sk_buff *first;
enum packet_state state;
- while ((first = __ptr_ring_peek(&queue->ring)) != NULL && (state = atomic_read_acquire(&PACKET_CB(first)->state)) != PACKET_STATE_UNCRYPTED) {
+ while ((first = __ptr_ring_peek(&queue->ring)) != NULL &&
+ (state = atomic_read_acquire(&PACKET_CB(first)->state)) !=
+ PACKET_STATE_UNCRYPTED) {
__ptr_ring_discard_one(&queue->ring);
peer = PACKET_PEER(first);
keypair = PACKET_CB(first)->keypair;
@@ -241,22 +298,25 @@ void packet_tx_worker(struct work_struct *work)
void packet_encrypt_worker(struct work_struct *work)
{
- struct crypt_queue *queue = container_of(work, struct multicore_worker, work)->ptr;
+ struct crypt_queue *queue =
+ container_of(work, struct multicore_worker, work)->ptr;
struct sk_buff *first, *skb, *next;
simd_context_t simd_context = simd_get();
while ((first = ptr_ring_consume_bh(&queue->ring)) != NULL) {
enum packet_state state = PACKET_STATE_CRYPTED;
- skb_walk_null_queue_safe(first, skb, next) {
- if (likely(skb_encrypt(skb, PACKET_CB(first)->keypair, simd_context)))
+ skb_walk_null_queue_safe (first, skb, next) {
+ if (likely(skb_encrypt(skb, PACKET_CB(first)->keypair,
+ simd_context)))
skb_reset(skb);
else {
state = PACKET_STATE_DEAD;
break;
}
}
- queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first, state);
+ queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first,
+ state);
simd_context = simd_relax(simd_context);
}
@@ -273,9 +333,13 @@ static void packet_create_data(struct sk_buff *first)
if (unlikely(peer->is_dead))
goto err;
- ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu);
+ ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue,
+ &peer->tx_queue, first,
+ wg->packet_crypt_wq,
+ &wg->encrypt_queue.last_cpu);
if (unlikely(ret == -EPIPE))
- queue_enqueue_per_peer(&peer->tx_queue, first, PACKET_STATE_DEAD);
+ queue_enqueue_per_peer(&peer->tx_queue, first,
+ PACKET_STATE_DEAD);
err:
rcu_read_unlock_bh();
if (likely(!ret || ret == -EPIPE))
@@ -287,8 +351,8 @@ err:
void packet_send_staged_packets(struct wireguard_peer *peer)
{
- struct noise_keypair *keypair;
struct noise_symmetric_key *key;
+ struct noise_keypair *keypair;
struct sk_buff_head packets;
struct sk_buff *skb;
@@ -302,7 +366,8 @@ void packet_send_staged_packets(struct wireguard_peer *peer)
/* First we make sure we have a valid reference to a valid key. */
rcu_read_lock_bh();
- keypair = noise_keypair_get(rcu_dereference_bh(peer->keypairs.current_keypair));
+ keypair = noise_keypair_get(
+ rcu_dereference_bh(peer->keypairs.current_keypair));
rcu_read_unlock_bh();
if (unlikely(!keypair))
goto out_nokey;
@@ -312,13 +377,17 @@ void packet_send_staged_packets(struct wireguard_peer *peer)
if (unlikely(has_expired(key->birthdate, REJECT_AFTER_TIME)))
goto out_invalid;
- /* After we know we have a somewhat valid key, we now try to assign nonces to
- * all of the packets in the queue. If we can't assign nonces for all of them,
- * we just consider it a failure and wait for the next handshake.
+ /* After we know we have a somewhat valid key, we now try to assign
+ * nonces to all of the packets in the queue. If we can't assign nonces
+ * for all of them, we just consider it a failure and wait for the next
+ * handshake.
*/
- skb_queue_walk(&packets, skb) {
- PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0 /* No outer TOS: no leak. TODO: should we use flowi->tos as outer? */, ip_hdr(skb), skb);
- PACKET_CB(skb)->nonce = atomic64_inc_return(&key->counter.counter) - 1;
+ skb_queue_walk (&packets, skb) {
+ /* 0 for no outer TOS: no leak. TODO: should we use flowi->tos
+ * as outer? */
+ PACKET_CB(skb)->ds = ip_tunnel_ecn_encap(0, ip_hdr(skb), skb);
+ PACKET_CB(skb)->nonce =
+ atomic64_inc_return(&key->counter.counter) - 1;
if (unlikely(PACKET_CB(skb)->nonce >= REJECT_AFTER_MESSAGES))
goto out_invalid;
}
@@ -337,19 +406,19 @@ out_nokey:
/* We orphan the packets if we're waiting on a handshake, so that they
* don't block a socket's pool.
*/
- skb_queue_walk(&packets, skb)
+ skb_queue_walk (&packets, skb)
skb_orphan(skb);
- /* Then we put them back on the top of the queue. We're not too concerned about
- * accidentally getting things a little out of order if packets are being added
- * really fast, because this queue is for before packets can even be sent and
- * it's small anyway.
+ /* Then we put them back on the top of the queue. We're not too
+ * concerned about accidentally getting things a little out of order if
+ * packets are being added really fast, because this queue is for before
+ * packets can even be sent and it's small anyway.
*/
spin_lock_bh(&peer->staged_packet_queue.lock);
skb_queue_splice(&packets, &peer->staged_packet_queue);
spin_unlock_bh(&peer->staged_packet_queue.lock);
- /* If we're exiting because there's something wrong with the key, it means
- * we should initiate a new handshake.
+ /* If we're exiting because there's something wrong with the key, it
+ * means we should initiate a new handshake.
*/
packet_send_queued_handshake_initiation(peer, false);
}
diff --git a/src/socket.c b/src/socket.c
index 663a462..f544dd0 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -17,7 +17,9 @@
#include <net/udp_tunnel.h>
#include <net/ipv6.h>
-static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
+static inline int send4(struct wireguard_device *wg, struct sk_buff *skb,
+ struct endpoint *endpoint, u8 ds,
+ struct dst_cache *cache)
{
struct flowi4 fl = {
.saddr = endpoint->src4.s_addr,
@@ -49,14 +51,21 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct
if (!rt) {
security_sk_classify_flow(sock, flowi4_to_flowi(&fl));
- if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0, fl.saddr, RT_SCOPE_HOST))) {
- endpoint->src4.s_addr = *(__force __be32 *)&endpoint->src_if4 = fl.saddr = 0;
+ if (unlikely(!inet_confirm_addr(sock_net(sock), NULL, 0,
+ fl.saddr, RT_SCOPE_HOST))) {
+ endpoint->src4.s_addr = 0;
+ *(__force __be32 *)&endpoint->src_if4 = 0;
+ fl.saddr = 0;
if (cache)
dst_cache_reset(cache);
}
rt = ip_route_output_flow(sock_net(sock), &fl, sock);
- if (unlikely(endpoint->src_if4 && ((IS_ERR(rt) && PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) && rt->dst.dev->ifindex != endpoint->src_if4)))) {
- endpoint->src4.s_addr = *(__force __be32 *)&endpoint->src_if4 = fl.saddr = 0;
+ if (unlikely(endpoint->src_if4 && ((IS_ERR(rt) &&
+ PTR_ERR(rt) == -EINVAL) || (!IS_ERR(rt) &&
+ rt->dst.dev->ifindex != endpoint->src_if4)))) {
+ endpoint->src4.s_addr = 0;
+ *(__force __be32 *)&endpoint->src_if4 = 0;
+ fl.saddr = 0;
if (cache)
dst_cache_reset(cache);
if (!IS_ERR(rt))
@@ -65,18 +74,22 @@ static inline int send4(struct wireguard_device *wg, struct sk_buff *skb, struct
}
if (unlikely(IS_ERR(rt))) {
ret = PTR_ERR(rt);
- net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", wg->dev->name, &endpoint->addr, ret);
+ net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n",
+ wg->dev->name, &endpoint->addr, ret);
goto err;
} else if (unlikely(rt->dst.dev == skb->dev)) {
ip_rt_put(rt);
ret = -ELOOP;
- net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", wg->dev->name, &endpoint->addr);
+ net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n",
+ wg->dev->name, &endpoint->addr);
goto err;
}
if (cache)
dst_cache_set_ip4(cache, &rt->dst, fl.saddr);
}
- udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds, ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport, fl.fl4_dport, false, false);
+ udp_tunnel_xmit_skb(rt, sock, skb, fl.saddr, fl.daddr, ds,
+ ip4_dst_hoplimit(&rt->dst), 0, fl.fl4_sport,
+ fl.fl4_dport, false, false);
goto out;
err:
@@ -86,7 +99,9 @@ out:
return ret;
}
-static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct endpoint *endpoint, u8 ds, struct dst_cache *cache)
+static inline int send6(struct wireguard_device *wg, struct sk_buff *skb,
+ struct endpoint *endpoint, u8 ds,
+ struct dst_cache *cache)
{
#if IS_ENABLED(CONFIG_IPV6)
struct flowi6 fl = {
@@ -121,26 +136,32 @@ static inline int send6(struct wireguard_device *wg, struct sk_buff *skb, struct
if (!dst) {
security_sk_classify_flow(sock, flowi6_to_flowi(&fl));
- if (unlikely(!ipv6_addr_any(&fl.saddr) && !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) {
+ if (unlikely(!ipv6_addr_any(&fl.saddr) &&
+ !ipv6_chk_addr(sock_net(sock), &fl.saddr, NULL, 0))) {
endpoint->src6 = fl.saddr = in6addr_any;
if (cache)
dst_cache_reset(cache);
}
- ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock), sock, &dst, &fl);
+ ret = ipv6_stub->ipv6_dst_lookup(sock_net(sock), sock, &dst,
+ &fl);
if (unlikely(ret)) {
- net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n", wg->dev->name, &endpoint->addr, ret);
+ net_dbg_ratelimited("%s: No route to %pISpfsc, error %d\n",
+ wg->dev->name, &endpoint->addr, ret);
goto err;
} else if (unlikely(dst->dev == skb->dev)) {
dst_release(dst);
ret = -ELOOP;
- net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n", wg->dev->name, &endpoint->addr);
+ net_dbg_ratelimited("%s: Avoiding routing loop to %pISpfsc\n",
+ wg->dev->name, &endpoint->addr);
goto err;
}
if (cache)
dst_cache_set_ip6(cache, dst, &fl.saddr);
}
- udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds, ip6_dst_hoplimit(dst), 0, fl.fl6_sport, fl.fl6_dport, false);
+ udp_tunnel6_xmit_skb(dst, sock, skb, skb->dev, &fl.saddr, &fl.daddr, ds,
+ ip6_dst_hoplimit(dst), 0, fl.fl6_sport,
+ fl.fl6_dport, false);
goto out;
err:
@@ -153,16 +174,19 @@ out:
#endif
}
-int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds)
+int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb,
+ u8 ds)
{
size_t skb_len = skb->len;
int ret = -EAFNOSUPPORT;
read_lock_bh(&peer->endpoint_lock);
if (peer->endpoint.addr.sa_family == AF_INET)
- ret = send4(peer->device, skb, &peer->endpoint, ds, &peer->endpoint_cache);
+ ret = send4(peer->device, skb, &peer->endpoint, ds,
+ &peer->endpoint_cache);
else if (peer->endpoint.addr.sa_family == AF_INET6)
- ret = send6(peer->device, skb, &peer->endpoint, ds, &peer->endpoint_cache);
+ ret = send6(peer->device, skb, &peer->endpoint, ds,
+ &peer->endpoint_cache);
else
dev_kfree_skb(skb);
if (likely(!ret))
@@ -172,7 +196,8 @@ int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8
return ret;
}
-int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, size_t len, u8 ds)
+int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer,
+ size_t len, u8 ds)
{
struct sk_buff *skb = alloc_skb(len + SKB_HEADER_LEN, GFP_ATOMIC);
@@ -185,7 +210,9 @@ int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *buffer, size_t
return socket_send_skb_to_peer(peer, skb, ds);
}
-int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_buff *in_skb, void *buffer, size_t len)
+int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg,
+ struct sk_buff *in_skb, void *buffer,
+ size_t len)
{
int ret = 0;
struct sk_buff *skb;
@@ -208,12 +235,15 @@ int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_bu
ret = send4(wg, skb, &endpoint, 0, NULL);
else if (endpoint.addr.sa_family == AF_INET6)
ret = send6(wg, skb, &endpoint, 0, NULL);
- /* No other possibilities if the endpoint is valid, which it is, as we checked above. */
+ /* No other possibilities if the endpoint is valid, which it is,
+ * as we checked above.
+ */
return ret;
}
-int socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *skb)
+int socket_endpoint_from_skb(struct endpoint *endpoint,
+ const struct sk_buff *skb)
{
memset(endpoint, 0, sizeof(struct endpoint));
if (skb->protocol == htons(ETH_P_IP)) {
@@ -226,25 +256,32 @@ int socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *sk
endpoint->addr6.sin6_family = AF_INET6;
endpoint->addr6.sin6_port = udp_hdr(skb)->source;
endpoint->addr6.sin6_addr = ipv6_hdr(skb)->saddr;
- endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id(&ipv6_hdr(skb)->saddr, skb->skb_iif);
+ endpoint->addr6.sin6_scope_id = ipv6_iface_scope_id(
+ &ipv6_hdr(skb)->saddr, skb->skb_iif);
endpoint->src6 = ipv6_hdr(skb)->daddr;
} else
return -EINVAL;
return 0;
}
-static inline bool endpoint_eq(const struct endpoint *a, const struct endpoint *b)
+static inline bool endpoint_eq(const struct endpoint *a,
+ const struct endpoint *b)
{
return (a->addr.sa_family == AF_INET && b->addr.sa_family == AF_INET &&
- a->addr4.sin_port == b->addr4.sin_port && a->addr4.sin_addr.s_addr == b->addr4.sin_addr.s_addr &&
+ a->addr4.sin_port == b->addr4.sin_port &&
+ a->addr4.sin_addr.s_addr == b->addr4.sin_addr.s_addr &&
a->src4.s_addr == b->src4.s_addr && a->src_if4 == b->src_if4) ||
- (a->addr.sa_family == AF_INET6 && b->addr.sa_family == AF_INET6 &&
- a->addr6.sin6_port == b->addr6.sin6_port && ipv6_addr_equal(&a->addr6.sin6_addr, &b->addr6.sin6_addr) &&
- a->addr6.sin6_scope_id == b->addr6.sin6_scope_id && ipv6_addr_equal(&a->src6, &b->src6)) ||
+ (a->addr.sa_family == AF_INET6 &&
+ b->addr.sa_family == AF_INET6 &&
+ a->addr6.sin6_port == b->addr6.sin6_port &&
+ ipv6_addr_equal(&a->addr6.sin6_addr, &b->addr6.sin6_addr) &&
+ a->addr6.sin6_scope_id == b->addr6.sin6_scope_id &&
+ ipv6_addr_equal(&a->src6, &b->src6)) ||
unlikely(!a->addr.sa_family && !b->addr.sa_family);
}
-void socket_set_peer_endpoint(struct wireguard_peer *peer, const struct endpoint *endpoint)
+void socket_set_peer_endpoint(struct wireguard_peer *peer,
+ const struct endpoint *endpoint)
{
/* First we check unlocked, in order to optimize, since it's pretty rare
* that an endpoint will change. If we happen to be mid-write, and two
@@ -268,7 +305,8 @@ out:
write_unlock_bh(&peer->endpoint_lock);
}
-void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer, const struct sk_buff *skb)
+void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer,
+ const struct sk_buff *skb)
{
struct endpoint endpoint;
@@ -362,7 +400,8 @@ retry:
udp_tunnel_sock_release(new4);
if (ret == -EADDRINUSE && !port && retries++ < 100)
goto retry;
- pr_err("%s: Could not create IPv6 socket\n", wg->dev->name);
+ pr_err("%s: Could not create IPv6 socket\n",
+ wg->dev->name);
return ret;
}
set_sock_opts(new6);
@@ -374,13 +413,16 @@ retry:
return 0;
}
-void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6)
+void socket_reinit(struct wireguard_device *wg, struct sock *new4,
+ struct sock *new6)
{
struct sock *old4, *old6;
mutex_lock(&wg->socket_update_lock);
- old4 = rcu_dereference_protected(wg->sock4, lockdep_is_held(&wg->socket_update_lock));
- old6 = rcu_dereference_protected(wg->sock6, lockdep_is_held(&wg->socket_update_lock));
+ old4 = rcu_dereference_protected(wg->sock4,
+ lockdep_is_held(&wg->socket_update_lock));
+ old6 = rcu_dereference_protected(wg->sock6,
+ lockdep_is_held(&wg->socket_update_lock));
rcu_assign_pointer(wg->sock4, new4);
rcu_assign_pointer(wg->sock6, new6);
if (new4)
diff --git a/src/socket.h b/src/socket.h
index 9faba77..d873ffa 100644
--- a/src/socket.h
+++ b/src/socket.h
@@ -12,22 +12,31 @@
#include <linux/if_ether.h>
int socket_init(struct wireguard_device *wg, u16 port);
-void socket_reinit(struct wireguard_device *wg, struct sock *new4, struct sock *new6);
-int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *data, size_t len, u8 ds);
-int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb, u8 ds);
-int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg, struct sk_buff *in_skb, void *out_buffer, size_t len);
+void socket_reinit(struct wireguard_device *wg, struct sock *new4,
+ struct sock *new6);
+int socket_send_buffer_to_peer(struct wireguard_peer *peer, void *data,
+ size_t len, u8 ds);
+int socket_send_skb_to_peer(struct wireguard_peer *peer, struct sk_buff *skb,
+ u8 ds);
+int socket_send_buffer_as_reply_to_skb(struct wireguard_device *wg,
+ struct sk_buff *in_skb, void *out_buffer,
+ size_t len);
-int socket_endpoint_from_skb(struct endpoint *endpoint, const struct sk_buff *skb);
-void socket_set_peer_endpoint(struct wireguard_peer *peer, const struct endpoint *endpoint);
-void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer, const struct sk_buff *skb);
+int socket_endpoint_from_skb(struct endpoint *endpoint,
+ const struct sk_buff *skb);
+void socket_set_peer_endpoint(struct wireguard_peer *peer,
+ const struct endpoint *endpoint);
+void socket_set_peer_endpoint_from_skb(struct wireguard_peer *peer,
+ const struct sk_buff *skb);
void socket_clear_peer_endpoint_src(struct wireguard_peer *peer);
#if defined(CONFIG_DYNAMIC_DEBUG) || defined(DEBUG)
-#define net_dbg_skb_ratelimited(fmt, dev, skb, ...) do { \
- struct endpoint __endpoint; \
- socket_endpoint_from_skb(&__endpoint, skb); \
- net_dbg_ratelimited(fmt, dev, &__endpoint.addr, ##__VA_ARGS__); \
-} while (0)
+#define net_dbg_skb_ratelimited(fmt, dev, skb, ...) do { \
+ struct endpoint __endpoint; \
+ socket_endpoint_from_skb(&__endpoint, skb); \
+ net_dbg_ratelimited(fmt, dev, &__endpoint.addr, \
+ ##__VA_ARGS__); \
+ } while (0)
#else
#define net_dbg_skb_ratelimited(fmt, skb, ...)
#endif
diff --git a/src/timers.c b/src/timers.c
index 7db892a..ad76466 100644
--- a/src/timers.c
+++ b/src/timers.c
@@ -10,22 +10,33 @@
#include "socket.h"
/*
- * Timer for retransmitting the handshake if we don't hear back after `REKEY_TIMEOUT + jitter` ms
- * Timer for sending empty packet if we have received a packet but after have not sent one for `KEEPALIVE_TIMEOUT` ms
- * Timer for initiating new handshake if we have sent a packet but after have not received one (even empty) for `(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT)` ms
- * Timer for zeroing out all ephemeral keys after `(REJECT_AFTER_TIME * 3)` ms if no new keys have been received
- * Timer for, if enabled, sending an empty authenticated packet every user-specified seconds
+ * - Timer for retransmitting the handshake if we don't hear back after
+ * `REKEY_TIMEOUT + jitter` ms.
+ *
+ * - Timer for sending empty packet if we have received a packet but after have
+ * not sent one for `KEEPALIVE_TIMEOUT` ms.
+ *
+ * - Timer for initiating new handshake if we have sent a packet but after have
+ * not received one (even empty) for `(KEEPALIVE_TIMEOUT + REKEY_TIMEOUT)` ms.
+ *
+ * - Timer for zeroing out all ephemeral keys after `(REJECT_AFTER_TIME * 3)` ms
+ * if no new keys have been received.
+ *
+ * - Timer for, if enabled, sending an empty authenticated packet every user-
+ * specified seconds.
*/
-#define peer_get_from_timer(timer_name) \
- struct wireguard_peer *peer; \
- rcu_read_lock_bh(); \
- peer = peer_get_maybe_zero(from_timer(peer, timer, timer_name)); \
- rcu_read_unlock_bh(); \
- if (unlikely(!peer)) \
+#define peer_get_from_timer(timer_name) \
+ struct wireguard_peer *peer; \
+ rcu_read_lock_bh(); \
+ peer = peer_get_maybe_zero(from_timer(peer, timer, timer_name)); \
+ rcu_read_unlock_bh(); \
+ if (unlikely(!peer)) \
return;
-static inline void mod_peer_timer(struct wireguard_peer *peer, struct timer_list *timer, unsigned long expires)
+static inline void mod_peer_timer(struct wireguard_peer *peer,
+ struct timer_list *timer,
+ unsigned long expires)
{
rcu_read_lock_bh();
if (likely(netif_running(peer->device->dev) && !peer->is_dead))
@@ -33,7 +44,8 @@ static inline void mod_peer_timer(struct wireguard_peer *peer, struct timer_list
rcu_read_unlock_bh();
}
-static inline void del_peer_timer(struct wireguard_peer *peer, struct timer_list *timer)
+static inline void del_peer_timer(struct wireguard_peer *peer,
+ struct timer_list *timer)
{
rcu_read_lock_bh();
if (likely(netif_running(peer->device->dev) && !peer->is_dead))
@@ -46,7 +58,9 @@ static void expired_retransmit_handshake(struct timer_list *timer)
peer_get_from_timer(timer_retransmit_handshake);
if (peer->timer_handshake_attempts > MAX_TIMER_HANDSHAKES) {
- pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2);
+ pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d attempts, giving up\n",
+ peer->device->dev->name, peer->internal_id,
+ &peer->endpoint.addr, MAX_TIMER_HANDSHAKES + 2);
del_peer_timer(peer, &peer->timer_send_keepalive);
/* We drop all packets without a keypair and don't try again,
@@ -58,12 +72,18 @@ static void expired_retransmit_handshake(struct timer_list *timer)
* of a partial exchange.
*/
if (!timer_pending(&peer->timer_zero_key_material))
- mod_peer_timer(peer, &peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ);
+ mod_peer_timer(peer, &peer->timer_zero_key_material,
+ jiffies + REJECT_AFTER_TIME * 3 * HZ);
} else {
++peer->timer_handshake_attempts;
- pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d seconds, retrying (try %d)\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, REKEY_TIMEOUT, peer->timer_handshake_attempts + 1);
+ pr_debug("%s: Handshake for peer %llu (%pISpfsc) did not complete after %d seconds, retrying (try %d)\n",
+ peer->device->dev->name, peer->internal_id,
+ &peer->endpoint.addr, REKEY_TIMEOUT,
+ peer->timer_handshake_attempts + 1);
- /* We clear the endpoint address src address, in case this is the cause of trouble. */
+ /* We clear the endpoint address src address, in case this is
+ * the cause of trouble.
+ */
socket_clear_peer_endpoint_src(peer);
packet_send_queued_handshake_initiation(peer, true);
@@ -78,7 +98,8 @@ static void expired_send_keepalive(struct timer_list *timer)
packet_send_keepalive(peer);
if (peer->timer_need_another_keepalive) {
peer->timer_need_another_keepalive = false;
- mod_peer_timer(peer, &peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ);
+ mod_peer_timer(peer, &peer->timer_send_keepalive,
+ jiffies + KEEPALIVE_TIMEOUT * HZ);
}
peer_put(peer);
}
@@ -87,8 +108,12 @@ static void expired_new_handshake(struct timer_list *timer)
{
peer_get_from_timer(timer_new_handshake);
- pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
- /* We clear the endpoint address src address, in case this is the cause of trouble. */
+ pr_debug("%s: Retrying handshake with peer %llu (%pISpfsc) because we stopped hearing back after %d seconds\n",
+ peer->device->dev->name, peer->internal_id,
+ &peer->endpoint.addr, KEEPALIVE_TIMEOUT + REKEY_TIMEOUT);
+ /* We clear the endpoint address src address, in case this is the cause
+ * of trouble.
+ */
socket_clear_peer_endpoint_src(peer);
packet_send_queued_handshake_initiation(peer, false);
peer_put(peer);
@@ -100,16 +125,22 @@ static void expired_zero_key_material(struct timer_list *timer)
rcu_read_lock_bh();
if (!peer->is_dead) {
- if (!queue_work(peer->device->handshake_send_wq, &peer->clear_peer_work)) /* Should take our reference. */
- peer_put(peer); /* If the work was already on the queue, we want to drop the extra reference */
+ /* Should take our reference. */
+ if (!queue_work(peer->device->handshake_send_wq,
+ &peer->clear_peer_work))
+ /* If the work was already on the queue, we want to drop the extra reference */
+ peer_put(peer);
}
rcu_read_unlock_bh();
}
static void queued_expired_zero_key_material(struct work_struct *work)
{
- struct wireguard_peer *peer = container_of(work, struct wireguard_peer, clear_peer_work);
+ struct wireguard_peer *peer =
+ container_of(work, struct wireguard_peer, clear_peer_work);
- pr_debug("%s: Zeroing out all keys for peer %llu (%pISpfsc), since we haven't received a new one in %d seconds\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr, REJECT_AFTER_TIME * 3);
+ pr_debug("%s: Zeroing out all keys for peer %llu (%pISpfsc), since we haven't received a new one in %d seconds\n",
+ peer->device->dev->name, peer->internal_id,
+ &peer->endpoint.addr, REJECT_AFTER_TIME * 3);
noise_handshake_clear(&peer->handshake);
noise_keypairs_clear(&peer->keypairs);
peer_put(peer);
@@ -128,7 +159,8 @@ static void expired_send_persistent_keepalive(struct timer_list *timer)
void timers_data_sent(struct wireguard_peer *peer)
{
if (!timer_pending(&peer->timer_new_handshake))
- mod_peer_timer(peer, &peer->timer_new_handshake, jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ);
+ mod_peer_timer(peer, &peer->timer_new_handshake,
+ jiffies + (KEEPALIVE_TIMEOUT + REKEY_TIMEOUT) * HZ);
}
/* Should be called after an authenticated data packet is received. */
@@ -136,19 +168,24 @@ void timers_data_received(struct wireguard_peer *peer)
{
if (likely(netif_running(peer->device->dev))) {
if (!timer_pending(&peer->timer_send_keepalive))
- mod_peer_timer(peer, &peer->timer_send_keepalive, jiffies + KEEPALIVE_TIMEOUT * HZ);
+ mod_peer_timer(peer, &peer->timer_send_keepalive,
+ jiffies + KEEPALIVE_TIMEOUT * HZ);
else
peer->timer_need_another_keepalive = true;
}
}
-/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */
+/* Should be called after any type of authenticated packet is sent, whether
+ * keepalive, data, or handshake.
+*/
void timers_any_authenticated_packet_sent(struct wireguard_peer *peer)
{
del_peer_timer(peer, &peer->timer_send_keepalive);
}
-/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */
+/* Should be called after any type of authenticated packet is received, whether
+ * keepalive, data, or handshake.
+ */
void timers_any_authenticated_packet_received(struct wireguard_peer *peer)
{
del_peer_timer(peer, &peer->timer_new_handshake);
@@ -157,10 +194,15 @@ void timers_any_authenticated_packet_received(struct wireguard_peer *peer)
/* Should be called after a handshake initiation message is sent. */
void timers_handshake_initiated(struct wireguard_peer *peer)
{
- mod_peer_timer(peer, &peer->timer_retransmit_handshake, jiffies + REKEY_TIMEOUT * HZ + prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES));
+ mod_peer_timer(
+ peer, &peer->timer_retransmit_handshake,
+ jiffies + REKEY_TIMEOUT * HZ +
+ prandom_u32_max(REKEY_TIMEOUT_JITTER_MAX_JIFFIES));
}
-/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
+/* Should be called after a handshake response message is received and processed
+ * or when getting key confirmation via the first data message.
+ */
void timers_handshake_complete(struct wireguard_peer *peer)
{
del_peer_timer(peer, &peer->timer_retransmit_handshake);
@@ -169,26 +211,34 @@ void timers_handshake_complete(struct wireguard_peer *peer)
getnstimeofday(&peer->walltime_last_handshake);
}
-/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
+/* Should be called after an ephemeral key is created, which is before sending a
+ * handshake response or after receiving a handshake response.
+ */
void timers_session_derived(struct wireguard_peer *peer)
{
- mod_peer_timer(peer, &peer->timer_zero_key_material, jiffies + REJECT_AFTER_TIME * 3 * HZ);
+ mod_peer_timer(peer, &peer->timer_zero_key_material,
+ jiffies + REJECT_AFTER_TIME * 3 * HZ);
}
-/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
+/* Should be called before a packet with authentication, whether
+ * keepalive, data, or handshakem is sent, or after one is received.
+ */
void timers_any_authenticated_packet_traversal(struct wireguard_peer *peer)
{
if (peer->persistent_keepalive_interval)
- mod_peer_timer(peer, &peer->timer_persistent_keepalive, jiffies + peer->persistent_keepalive_interval * HZ);
+ mod_peer_timer(peer, &peer->timer_persistent_keepalive,
+ jiffies + peer->persistent_keepalive_interval * HZ);
}
void timers_init(struct wireguard_peer *peer)
{
- timer_setup(&peer->timer_retransmit_handshake, expired_retransmit_handshake, 0);
+ timer_setup(&peer->timer_retransmit_handshake,
+ expired_retransmit_handshake, 0);
timer_setup(&peer->timer_send_keepalive, expired_send_keepalive, 0);
timer_setup(&peer->timer_new_handshake, expired_new_handshake, 0);
timer_setup(&peer->timer_zero_key_material, expired_zero_key_material, 0);
- timer_setup(&peer->timer_persistent_keepalive, expired_send_persistent_keepalive, 0);
+ timer_setup(&peer->timer_persistent_keepalive,
+ expired_send_persistent_keepalive, 0);
INIT_WORK(&peer->clear_peer_work, queued_expired_zero_key_material);
peer->timer_handshake_attempts = 0;
peer->sent_lastminute_handshake = false;
diff --git a/src/timers.h b/src/timers.h
index c95bf31..483529c 100644
--- a/src/timers.h
+++ b/src/timers.h
@@ -23,7 +23,8 @@ void timers_any_authenticated_packet_traversal(struct wireguard_peer *peer);
static inline bool has_expired(u64 birthday_nanoseconds, u64 expiration_seconds)
{
- return (s64)(birthday_nanoseconds + expiration_seconds * NSEC_PER_SEC) <= (s64)ktime_get_boot_fast_ns();
+ return (s64)(birthday_nanoseconds + expiration_seconds * NSEC_PER_SEC)
+ <= (s64)ktime_get_boot_fast_ns();
}
#endif /* _WG_TIMERS_H */