diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2017-04-27 11:10:50 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2017-05-17 18:07:42 +0200 |
commit | 41e7aa153984364087a9ef07eca02c72961825c7 (patch) | |
tree | 8d8e81c26bbc77e387194835b9e9467a1d58c498 /src | |
parent | a2223db43496574b9211590f0dea09a718c6ca62 (diff) |
noise: redesign preshared key mode
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'src')
-rw-r--r-- | src/config.c | 32 | ||||
-rw-r--r-- | src/cookie.c | 66 | ||||
-rw-r--r-- | src/cookie.h | 5 | ||||
-rw-r--r-- | src/noise.c | 177 | ||||
-rw-r--r-- | src/noise.h | 9 | ||||
-rw-r--r-- | src/peer.c | 6 | ||||
-rw-r--r-- | src/peer.h | 2 | ||||
-rwxr-xr-x | src/tests/netns.sh | 4 | ||||
-rw-r--r-- | src/uapi.h | 11 |
9 files changed, 151 insertions, 161 deletions
diff --git a/src/config.c b/src/config.c index a5a25c9..46ee2f1 100644 --- a/src/config.c +++ b/src/config.c @@ -60,7 +60,9 @@ static int set_peer(struct wireguard_device *wg, void __user *user_peer, size_t peer = pubkey_hashtable_lookup(&wg->peer_hashtable, in_peer.public_key); if (!peer) { /* Peer doesn't exist yet. Add a new one. */ if (in_peer.flags & WGPEER_REMOVE_ME) - return -ENODEV; /* Tried to remove a non existing peer. */ + return -ENODEV; /* Tried to remove a non-existing peer. */ + if (in_peer.flags & WGPEER_REMOVE_PRESHARED_KEY) + return -EINVAL; /* Tried to remove a psk for a non-existing peer. */ down_read(&wg->static_identity.lock); if (wg->static_identity.has_identity && !memcmp(in_peer.public_key, wg->static_identity.static_public, NOISE_PUBLIC_KEY_LEN)) { @@ -71,7 +73,7 @@ static int set_peer(struct wireguard_device *wg, void __user *user_peer, size_t } up_read(&wg->static_identity.lock); - peer = peer_rcu_get(peer_create(wg, in_peer.public_key)); + peer = peer_rcu_get(peer_create(wg, in_peer.public_key, in_peer.preshared_key)); if (!peer) return -ENOMEM; if (netdev_pub(wg)->flags & IFF_UP) @@ -84,6 +86,16 @@ static int set_peer(struct wireguard_device *wg, void __user *user_peer, size_t goto out; } + if (in_peer.flags & WGPEER_REMOVE_PRESHARED_KEY) { + down_write(&peer->handshake.lock); + memset(&peer->handshake.preshared_key, 0, NOISE_SYMMETRIC_KEY_LEN); + up_write(&peer->handshake.lock); + } else if (memcmp(zeros, in_peer.preshared_key, WG_KEY_LEN)) { + down_write(&peer->handshake.lock); + memcpy(&peer->handshake.preshared_key, in_peer.preshared_key, NOISE_SYMMETRIC_KEY_LEN); + up_write(&peer->handshake.lock); + } + if (in_peer.endpoint.addr.sa_family == AF_INET || in_peer.endpoint.addr.sa_family == AF_INET6) { struct endpoint endpoint = { { { 0 } } }; memcpy(&endpoint, &in_peer.endpoint, sizeof(in_peer.endpoint)); @@ -170,16 +182,8 @@ int config_set_device(struct wireguard_device *wg, void __user *user_device) modified_static_identity = true; } - if (in_device.flags & WGDEVICE_REMOVE_PRESHARED_KEY) { - noise_set_static_identity_preshared_key(&wg->static_identity, NULL); - modified_static_identity = true; - } else if (memcmp(zeros, in_device.preshared_key, WG_KEY_LEN)) { - noise_set_static_identity_preshared_key(&wg->static_identity, in_device.preshared_key); - modified_static_identity = true; - } - if (modified_static_identity) - cookie_checker_precompute_keys(&wg->cookie_checker, NULL); + cookie_checker_precompute_device_keys(&wg->cookie_checker); for (i = 0, offset = 0, user_peer = user_device + sizeof(struct wgdevice); i < in_device.num_peers; ++i, user_peer += offset) { ret = set_peer(wg, user_peer, &offset); @@ -249,7 +253,11 @@ static int populate_peer(struct wireguard_peer *peer, void *ctx) if (ret) return ret; + down_read(&peer->handshake.lock); memcpy(out_peer.public_key, peer->handshake.remote_static, NOISE_PUBLIC_KEY_LEN); + memcpy(out_peer.preshared_key, peer->handshake.preshared_key, NOISE_SYMMETRIC_KEY_LEN); + up_read(&peer->handshake.lock); + read_lock_bh(&peer->endpoint_lock); if (peer->endpoint.addr.sa_family == AF_INET) out_peer.endpoint.addr4 = peer->endpoint.addr4; @@ -315,8 +323,6 @@ int config_get_device(struct wireguard_device *wg, void __user *user_device) memcpy(out_device.private_key, wg->static_identity.static_private, WG_KEY_LEN); memcpy(out_device.public_key, wg->static_identity.static_public, WG_KEY_LEN); } - if (wg->static_identity.has_psk) - memcpy(out_device.preshared_key, wg->static_identity.preshared_key, WG_KEY_LEN); up_read(&wg->static_identity.lock); peer_data.out_len = in_device.peers_size; diff --git a/src/cookie.c b/src/cookie.c index 6ecc02a..2046493 100644 --- a/src/cookie.c +++ b/src/cookie.c @@ -23,31 +23,39 @@ int cookie_checker_init(struct cookie_checker *checker, struct wireguard_device return 0; } -static int precompute_peer_key(struct wireguard_peer *peer, void *psk) +enum { COOKIE_KEY_LABEL_LEN = 8 }; +static const u8 mac1_key_label[COOKIE_KEY_LABEL_LEN] = "mac1----"; +static const u8 cookie_key_label[COOKIE_KEY_LABEL_LEN] = "cookie--"; + +static void precompute_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 pubkey[NOISE_PUBLIC_KEY_LEN], const u8 label[COOKIE_KEY_LABEL_LEN]) { - blake2s(peer->latest_cookie.cookie_decryption_key, peer->handshake.remote_static, psk, NOISE_SYMMETRIC_KEY_LEN, NOISE_PUBLIC_KEY_LEN, psk ? NOISE_SYMMETRIC_KEY_LEN : 0); - return 0; + struct blake2s_state blake; + blake2s_init(&blake, NOISE_SYMMETRIC_KEY_LEN); + blake2s_update(&blake, label, COOKIE_KEY_LABEL_LEN); + blake2s_update(&blake, pubkey, NOISE_PUBLIC_KEY_LEN); + blake2s_final(&blake, key, NOISE_SYMMETRIC_KEY_LEN); } -void cookie_checker_precompute_keys(struct cookie_checker *checker, struct wireguard_peer *peer) +void cookie_checker_precompute_device_keys(struct cookie_checker *checker) { down_read(&checker->device->static_identity.lock); - if (unlikely(!checker->device->static_identity.has_identity)) { - memset(checker->cookie_encryption_key, 0, NOISE_SYMMETRIC_KEY_LEN); - goto out; + if (likely(checker->device->static_identity.has_identity)) { + precompute_key(checker->cookie_encryption_key, checker->device->static_identity.static_public, cookie_key_label); + precompute_key(checker->message_mac1_key, checker->device->static_identity.static_public, mac1_key_label); } - - if (peer) - precompute_peer_key(peer, checker->device->static_identity.has_psk ? checker->device->static_identity.preshared_key : NULL); else { - blake2s(checker->cookie_encryption_key, checker->device->static_identity.static_public, checker->device->static_identity.preshared_key, NOISE_SYMMETRIC_KEY_LEN, NOISE_PUBLIC_KEY_LEN, checker->device->static_identity.has_psk ? NOISE_SYMMETRIC_KEY_LEN : 0); - peer_for_each_unlocked(checker->device, precompute_peer_key, checker->device->static_identity.has_psk ? checker->device->static_identity.preshared_key : NULL); + memset(checker->cookie_encryption_key, 0, NOISE_SYMMETRIC_KEY_LEN); + memset(checker->message_mac1_key, 0, NOISE_SYMMETRIC_KEY_LEN); } - -out: up_read(&checker->device->static_identity.lock); } +void cookie_checker_precompute_peer_keys(struct wireguard_peer *peer) +{ + precompute_key(peer->latest_cookie.cookie_decryption_key, peer->handshake.remote_static, cookie_key_label); + precompute_key(peer->latest_cookie.message_mac1_key, peer->handshake.remote_static, mac1_key_label); +} + void cookie_checker_uninit(struct cookie_checker *checker) { ratelimiter_uninit(&checker->ratelimiter); @@ -59,18 +67,10 @@ void cookie_init(struct cookie *cookie) init_rwsem(&cookie->lock); } -static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len, const u8 pubkey[NOISE_PUBLIC_KEY_LEN], const u8 psk[NOISE_SYMMETRIC_KEY_LEN]) +static void compute_mac1(u8 mac1[COOKIE_LEN], const void *message, size_t len, const u8 key[NOISE_SYMMETRIC_KEY_LEN]) { - struct blake2s_state state; len = len - sizeof(struct message_macs) + offsetof(struct message_macs, mac1); - - if (psk) - blake2s_init_key(&state, COOKIE_LEN, psk, NOISE_SYMMETRIC_KEY_LEN); - else - blake2s_init(&state, COOKIE_LEN); - blake2s_update(&state, pubkey, NOISE_PUBLIC_KEY_LEN); - blake2s_update(&state, message, len); - blake2s_final(&state, mac1, COOKIE_LEN); + blake2s(mac1, message, key, COOKIE_LEN, len, NOISE_SYMMETRIC_KEY_LEN); } static void compute_mac2(u8 mac2[COOKIE_LEN], const void *message, size_t len, const u8 cookie[COOKIE_LEN]) @@ -111,13 +111,7 @@ enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, str struct message_macs *macs = (struct message_macs *)(skb->data + skb->len - sizeof(struct message_macs)); ret = INVALID_MAC; - down_read(&checker->device->static_identity.lock); - if (unlikely(!checker->device->static_identity.has_identity)) { - up_read(&checker->device->static_identity.lock); - goto out; - } - compute_mac1(computed_mac, skb->data, skb->len, checker->device->static_identity.static_public, checker->device->static_identity.has_psk ? checker->device->static_identity.preshared_key : NULL); - up_read(&checker->device->static_identity.lock); + compute_mac1(computed_mac, skb->data, skb->len, checker->message_mac1_key); if (crypto_memneq(computed_mac, macs->mac1, COOKIE_LEN)) goto out; @@ -146,16 +140,8 @@ void cookie_add_mac_to_packet(void *message, size_t len, struct wireguard_peer * { struct message_macs *macs = (struct message_macs *)((u8 *)message + len - sizeof(struct message_macs)); - down_read(&peer->device->static_identity.lock); - if (unlikely(!peer->device->static_identity.has_identity)) { - memset(macs, 0, sizeof(struct message_macs)); - up_read(&peer->device->static_identity.lock); - return; - } - compute_mac1(macs->mac1, message, len, peer->handshake.remote_static, peer->device->static_identity.has_psk ? peer->device->static_identity.preshared_key : NULL); - up_read(&peer->device->static_identity.lock); - down_write(&peer->latest_cookie.lock); + compute_mac1(macs->mac1, message, len, peer->latest_cookie.message_mac1_key); memcpy(peer->latest_cookie.last_mac1_sent, macs->mac1, COOKIE_LEN); peer->latest_cookie.have_sent_mac1 = true; up_write(&peer->latest_cookie.lock); diff --git a/src/cookie.h b/src/cookie.h index 87a0e5a..c87d3dd 100644 --- a/src/cookie.h +++ b/src/cookie.h @@ -14,6 +14,7 @@ struct sk_buff; struct cookie_checker { u8 secret[NOISE_HASH_LEN]; u8 cookie_encryption_key[NOISE_SYMMETRIC_KEY_LEN]; + u8 message_mac1_key[NOISE_SYMMETRIC_KEY_LEN]; u64 secret_birthdate; struct rw_semaphore secret_lock; struct ratelimiter ratelimiter; @@ -27,6 +28,7 @@ struct cookie { bool have_sent_mac1; u8 last_mac1_sent[COOKIE_LEN]; u8 cookie_decryption_key[NOISE_SYMMETRIC_KEY_LEN]; + u8 message_mac1_key[NOISE_SYMMETRIC_KEY_LEN]; struct rw_semaphore lock; }; @@ -39,7 +41,8 @@ enum cookie_mac_state { int cookie_checker_init(struct cookie_checker *checker, struct wireguard_device *wg); void cookie_checker_uninit(struct cookie_checker *checker); -void cookie_checker_precompute_keys(struct cookie_checker *checker, struct wireguard_peer *peer); +void cookie_checker_precompute_device_keys(struct cookie_checker *checker); +void cookie_checker_precompute_peer_keys(struct wireguard_peer *peer); void cookie_init(struct cookie *cookie); enum cookie_mac_state cookie_validate_packet(struct cookie_checker *checker, struct sk_buff *skb, bool check_cookie); diff --git a/src/noise.c b/src/noise.c index d6c6398..6e5db8c 100644 --- a/src/noise.c +++ b/src/noise.c @@ -14,34 +14,38 @@ #include <linux/highmem.h> #include <crypto/algapi.h> -/* This implements Noise_IK: +/* This implements Noise_IKpsk2: * * <- s * ****** - * -> e, es, s, ss, t - * <- e, ee, se + * -> e, es, s, ss, {t} + * <- e, ee, se, psk, {} */ -static const u8 handshake_name[33] = "Noise_IK_25519_ChaChaPoly_BLAKE2s"; -static const u8 handshake_psk_name[36] = "NoisePSK_IK_25519_ChaChaPoly_BLAKE2s"; -static u8 handshake_name_hash[NOISE_HASH_LEN] __read_mostly; -static u8 handshake_psk_name_hash[NOISE_HASH_LEN] __read_mostly; -static const u8 identifier_name[34] = "WireGuard v0 zx2c4 Jason@zx2c4.com"; +static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"; +static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com"; +static u8 handshake_init_hash[NOISE_HASH_LEN] __read_mostly; +static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __read_mostly; static atomic64_t keypair_counter = ATOMIC64_INIT(0); void noise_init(void) { - blake2s(handshake_name_hash, handshake_name, NULL, NOISE_HASH_LEN, sizeof(handshake_name), 0); - blake2s(handshake_psk_name_hash, handshake_psk_name, NULL, NOISE_HASH_LEN, sizeof(handshake_psk_name), 0); + struct blake2s_state blake; + blake2s(handshake_init_chaining_key, handshake_name, NULL, NOISE_HASH_LEN, sizeof(handshake_name), 0); + blake2s_init(&blake, NOISE_HASH_LEN); + blake2s_update(&blake, handshake_init_chaining_key, NOISE_HASH_LEN); + blake2s_update(&blake, identifier_name, sizeof(identifier_name)); + blake2s_final(&blake, handshake_init_hash, NOISE_HASH_LEN); } -void noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], struct wireguard_peer *peer) +void noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer) { memset(handshake, 0, sizeof(struct noise_handshake)); init_rwsem(&handshake->lock); handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE; handshake->entry.peer = peer; memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN); + memcpy(handshake->preshared_key, peer_preshared_key, NOISE_SYMMETRIC_KEY_LEN); handshake->static_identity = static_identity; handshake->state = HANDSHAKE_ZEROED; } @@ -55,7 +59,6 @@ void noise_handshake_clear(struct noise_handshake *handshake) memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN); memset(&handshake->hash, 0, NOISE_HASH_LEN); memset(&handshake->chaining_key, 0, NOISE_HASH_LEN); - memset(&handshake->key, 0, NOISE_SYMMETRIC_KEY_LEN); handshake->remote_index = 0; handshake->state = HANDSHAKE_ZEROED; up_write(&handshake->lock); @@ -198,44 +201,44 @@ void noise_set_static_identity_private_key(struct noise_static_identity *static_ up_write(&static_identity->lock); } -void noise_set_static_identity_preshared_key(struct noise_static_identity *static_identity, const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]) -{ - down_write(&static_identity->lock); - if (preshared_key) { - memcpy(static_identity->preshared_key, preshared_key, NOISE_SYMMETRIC_KEY_LEN); - static_identity->has_psk = true; - } else { - memset(static_identity->preshared_key, 0, NOISE_SYMMETRIC_KEY_LEN); - static_identity->has_psk = false; - } - up_write(&static_identity->lock); -} - /* This is Hugo Krawczyk's HKDF: * - https://eprint.iacr.org/2010/264.pdf * - https://tools.ietf.org/html/rfc5869 */ -static void kdf(u8 *first_dst, u8 *second_dst, const u8 *data, - size_t first_len, size_t second_len, size_t data_len, - const u8 chaining_key[NOISE_HASH_LEN]) +static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data, size_t first_len, size_t second_len, size_t third_len, size_t data_len, const u8 chaining_key[NOISE_HASH_LEN]) { u8 secret[BLAKE2S_OUTBYTES]; u8 output[BLAKE2S_OUTBYTES + 1]; - BUG_ON(first_len > BLAKE2S_OUTBYTES || second_len > BLAKE2S_OUTBYTES); + BUG_ON(first_len > BLAKE2S_OUTBYTES || second_len > BLAKE2S_OUTBYTES || third_len > BLAKE2S_OUTBYTES || ((second_len || second_dst || third_len || third_dst) && (!first_len || !first_dst)) || ((third_len || third_dst) && (!second_len || !second_dst))); /* Extract entropy from data into secret */ blake2s_hmac(secret, data, chaining_key, BLAKE2S_OUTBYTES, data_len, NOISE_HASH_LEN); + if (!first_dst || !first_len) + goto out; + /* Expand first key: key = secret, data = 0x1 */ output[0] = 1; blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, 1, BLAKE2S_OUTBYTES); memcpy(first_dst, output, first_len); + if (!second_dst || !second_len) + goto out; + /* Expand second key: key = secret, data = first-key || 0x2 */ output[BLAKE2S_OUTBYTES] = 2; blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); memcpy(second_dst, output, second_len); + if (!third_dst || !third_len) + goto out; + + /* Expand third key: key = secret, data = second-key || 0x3 */ + output[BLAKE2S_OUTBYTES] = 3; + blake2s_hmac(output, output, secret, BLAKE2S_OUTBYTES, BLAKE2S_OUTBYTES + 1, BLAKE2S_OUTBYTES); + memcpy(third_dst, output, third_len); + +out: /* Clear sensitive data from stack */ memzero_explicit(secret, BLAKE2S_OUTBYTES); memzero_explicit(output, BLAKE2S_OUTBYTES + 1); @@ -252,23 +255,17 @@ static void symmetric_key_init(struct noise_symmetric_key *key) static void derive_keys(struct noise_symmetric_key *first_dst, struct noise_symmetric_key *second_dst, const u8 chaining_key[NOISE_HASH_LEN]) { - kdf(first_dst->key, second_dst->key, NULL, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, chaining_key); + kdf(first_dst->key, second_dst->key, NULL, NULL, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, 0, chaining_key); symmetric_key_init(first_dst); symmetric_key_init(second_dst); } -static void mix_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 chaining_key[NOISE_HASH_LEN], const u8 *src, size_t src_len) -{ - kdf(chaining_key, key, src, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, src_len, chaining_key); -} - -static __must_check bool mix_dh(u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 chaining_key[NOISE_HASH_LEN], - const u8 private[NOISE_PUBLIC_KEY_LEN], const u8 public[NOISE_PUBLIC_KEY_LEN]) +static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN], u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 private[NOISE_PUBLIC_KEY_LEN], const u8 public[NOISE_PUBLIC_KEY_LEN]) { u8 dh_calculation[NOISE_PUBLIC_KEY_LEN]; if (unlikely(!curve25519(dh_calculation, private, public))) return false; - mix_key(key, chaining_key, dh_calculation, NOISE_PUBLIC_KEY_LEN); + kdf(chaining_key, key, NULL, dh_calculation, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN); return true; } @@ -282,29 +279,28 @@ static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len) blake2s_final(&blake, hash, NOISE_HASH_LEN); } -static void handshake_init(u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], - const u8 remote_static[NOISE_PUBLIC_KEY_LEN], const u8 psk[NOISE_SYMMETRIC_KEY_LEN]) +static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], u8 key[NOISE_SYMMETRIC_KEY_LEN], const u8 psk[NOISE_SYMMETRIC_KEY_LEN]) { - memset(key, 0, NOISE_SYMMETRIC_KEY_LEN); - memcpy(hash, psk ? handshake_psk_name_hash : handshake_name_hash, NOISE_HASH_LEN); - mix_hash(hash, identifier_name, sizeof(identifier_name)); - if (psk) { - u8 temp_hash[NOISE_HASH_LEN]; - kdf(chaining_key, temp_hash, psk, NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, handshake_psk_name_hash); - mix_hash(hash, temp_hash, NOISE_HASH_LEN); - memzero_explicit(temp_hash, NOISE_HASH_LEN); - } else - memcpy(chaining_key, handshake_name_hash, NOISE_HASH_LEN); + u8 temp_hash[NOISE_HASH_LEN]; + kdf(chaining_key, temp_hash, key, psk, NOISE_HASH_LEN, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, NOISE_SYMMETRIC_KEY_LEN, chaining_key); + mix_hash(hash, temp_hash, NOISE_HASH_LEN); + memzero_explicit(temp_hash, NOISE_HASH_LEN); +} + +static void handshake_init(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN], const u8 remote_static[NOISE_PUBLIC_KEY_LEN]) +{ + memcpy(hash, handshake_init_hash, NOISE_HASH_LEN); + memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN); mix_hash(hash, remote_static, NOISE_PUBLIC_KEY_LEN); } -static void handshake_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN]) +static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN]) { chacha20poly1305_encrypt(dst_ciphertext, src_plaintext, src_len, hash, NOISE_HASH_LEN, 0 /* Always zero for Noise_IK */, key); mix_hash(hash, dst_ciphertext, noise_encrypted_len(src_len)); } -static bool handshake_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN]) +static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext, size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 hash[NOISE_HASH_LEN]) { if (!chacha20poly1305_decrypt(dst_plaintext, src_ciphertext, src_len, hash, NOISE_HASH_LEN, 0 /* Always zero for Noise_IK */, key)) return false; @@ -312,10 +308,11 @@ static bool handshake_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext, size_ return true; } -static void handshake_nocrypt(u8 *dst, const u8 *src, size_t src_len, u8 hash[NOISE_HASH_LEN]) +static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN], const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN], u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN]) { - memcpy(dst, src, src_len); - mix_hash(hash, src, src_len); + memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN); + mix_hash(hash, ephemeral_src, NOISE_PUBLIC_KEY_LEN); + kdf(chaining_key, NULL, NULL, ephemeral_src, NOISE_HASH_LEN, 0, 0, NOISE_PUBLIC_KEY_LEN, chaining_key); } static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN]) @@ -330,6 +327,7 @@ static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN]) bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake) { u8 timestamp[NOISE_TIMESTAMP_LEN]; + u8 key[NOISE_SYMMETRIC_KEY_LEN]; bool ret = false; down_read(&handshake->static_identity->lock); @@ -340,31 +338,28 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION); - handshake_init(handshake->key, handshake->chaining_key, handshake->hash, handshake->remote_static, - handshake->static_identity->has_psk ? handshake->static_identity->preshared_key : NULL); + handshake_init(handshake->chaining_key, handshake->hash, handshake->remote_static); /* e */ curve25519_generate_secret(handshake->ephemeral_private); if (!curve25519_generate_public(handshake->ephemeral_public, handshake->ephemeral_private)) goto out; - handshake_nocrypt(dst->unencrypted_ephemeral, handshake->ephemeral_public, NOISE_PUBLIC_KEY_LEN, handshake->hash); - if (handshake->static_identity->has_psk) - mix_key(handshake->key, handshake->chaining_key, handshake->ephemeral_public, NOISE_PUBLIC_KEY_LEN); + message_ephemeral(dst->unencrypted_ephemeral, handshake->ephemeral_public, handshake->chaining_key, handshake->hash); /* es */ - if (!mix_dh(handshake->key, handshake->chaining_key, handshake->ephemeral_private, handshake->remote_static)) + if (!mix_dh(handshake->chaining_key, key, handshake->ephemeral_private, handshake->remote_static)) goto out; /* s */ - handshake_encrypt(dst->encrypted_static, handshake->static_identity->static_public, NOISE_PUBLIC_KEY_LEN, handshake->key, handshake->hash); + message_encrypt(dst->encrypted_static, handshake->static_identity->static_public, NOISE_PUBLIC_KEY_LEN, key, handshake->hash); /* ss */ - if (!mix_dh(handshake->key, handshake->chaining_key, handshake->static_identity->static_private, handshake->remote_static)) + if (!mix_dh(handshake->chaining_key, key, handshake->static_identity->static_private, handshake->remote_static)) goto out; - /* t */ + /* {t} */ tai64n_now(timestamp); - handshake_encrypt(dst->encrypted_timestamp, timestamp, NOISE_TIMESTAMP_LEN, handshake->key, handshake->hash); + message_encrypt(dst->encrypted_timestamp, timestamp, NOISE_TIMESTAMP_LEN, key, handshake->hash); dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry); @@ -374,6 +369,7 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, out: up_write(&handshake->lock); up_read(&handshake->static_identity->lock); + memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN); return ret; } @@ -393,28 +389,25 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha if (unlikely(!wg->static_identity.has_identity)) goto out; - handshake_init(key, chaining_key, hash, wg->static_identity.static_public, - wg->static_identity.has_psk ? wg->static_identity.preshared_key : NULL); + handshake_init(chaining_key, hash, wg->static_identity.static_public); /* e */ - handshake_nocrypt(e, src->unencrypted_ephemeral, sizeof(src->unencrypted_ephemeral), hash); - if (wg->static_identity.has_psk) - mix_key(key, chaining_key, e, NOISE_PUBLIC_KEY_LEN); + message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash); /* es */ - if (!mix_dh(key, chaining_key, wg->static_identity.static_private, e)) + if (!mix_dh(chaining_key, key, wg->static_identity.static_private, e)) goto out; /* s */ - if (!handshake_decrypt(s, src->encrypted_static, sizeof(src->encrypted_static), key, hash)) + if (!message_decrypt(s, src->encrypted_static, sizeof(src->encrypted_static), key, hash)) goto out; /* ss */ - if (!mix_dh(key, chaining_key, wg->static_identity.static_private, s)) + if (!mix_dh(chaining_key, key, wg->static_identity.static_private, s)) goto out; - /* t */ - if (!handshake_decrypt(t, src->encrypted_timestamp, sizeof(src->encrypted_timestamp), key, hash)) + /* {t} */ + if (!message_decrypt(t, src->encrypted_timestamp, sizeof(src->encrypted_timestamp), key, hash)) goto out; /* Lookup which peer we're actually talking to */ @@ -436,7 +429,6 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha down_write(&handshake->lock); memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN); memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN); - memcpy(handshake->key, key, NOISE_SYMMETRIC_KEY_LEN); memcpy(handshake->hash, hash, NOISE_HASH_LEN); memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN); handshake->remote_index = src->sender_index; @@ -455,6 +447,7 @@ out: bool noise_handshake_create_response(struct message_handshake_response *dst, struct noise_handshake *handshake) { bool ret = false; + u8 key[NOISE_SYMMETRIC_KEY_LEN]; down_read(&handshake->static_identity->lock); down_write(&handshake->lock); @@ -468,19 +461,21 @@ bool noise_handshake_create_response(struct message_handshake_response *dst, str curve25519_generate_secret(handshake->ephemeral_private); if (!curve25519_generate_public(handshake->ephemeral_public, handshake->ephemeral_private)) goto out; - handshake_nocrypt(dst->unencrypted_ephemeral, handshake->ephemeral_public, NOISE_PUBLIC_KEY_LEN, handshake->hash); - if (handshake->static_identity->has_psk) - mix_key(handshake->key, handshake->chaining_key, handshake->ephemeral_public, NOISE_PUBLIC_KEY_LEN); + message_ephemeral(dst->unencrypted_ephemeral, handshake->ephemeral_public, handshake->chaining_key, handshake->hash); /* ee */ - if (!mix_dh(handshake->key, handshake->chaining_key, handshake->ephemeral_private, handshake->remote_ephemeral)) + if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, handshake->remote_ephemeral)) goto out; /* se */ - if (!mix_dh(handshake->key, handshake->chaining_key, handshake->ephemeral_private, handshake->remote_static)) + if (!mix_dh(handshake->chaining_key, NULL, handshake->ephemeral_private, handshake->remote_static)) goto out; - handshake_encrypt(dst->encrypted_nothing, NULL, 0, handshake->key, handshake->hash); + /* psk */ + mix_psk(handshake->chaining_key, handshake->hash, key, handshake->preshared_key); + + /* {} */ + message_encrypt(dst->encrypted_nothing, NULL, 0, key, handshake->hash); dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry); @@ -490,6 +485,7 @@ bool noise_handshake_create_response(struct message_handshake_response *dst, str out: up_write(&handshake->lock); up_read(&handshake->static_identity->lock); + memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN); return ret; } @@ -516,7 +512,6 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake down_read(&handshake->lock); state = handshake->state; - memcpy(key, handshake->key, NOISE_SYMMETRIC_KEY_LEN); memcpy(hash, handshake->hash, NOISE_HASH_LEN); memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN); memcpy(ephemeral_private, handshake->ephemeral_private, NOISE_PUBLIC_KEY_LEN); @@ -526,26 +521,26 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake goto fail; /* e */ - handshake_nocrypt(e, src->unencrypted_ephemeral, sizeof(src->unencrypted_ephemeral), hash); - if (wg->static_identity.has_psk) - mix_key(key, chaining_key, e, NOISE_PUBLIC_KEY_LEN); + message_ephemeral(e, src->unencrypted_ephemeral, chaining_key, hash); /* ee */ - if (!mix_dh(key, chaining_key, ephemeral_private, e)) + if (!mix_dh(chaining_key, NULL, ephemeral_private, e)) goto out; /* se */ - if (!mix_dh(key, chaining_key, wg->static_identity.static_private, e)) + if (!mix_dh(chaining_key, NULL, wg->static_identity.static_private, e)) goto out; - /* decrypt nothing */ - if (!handshake_decrypt(NULL, src->encrypted_nothing, sizeof(src->encrypted_nothing), key, hash)) + /* psk */ + mix_psk(chaining_key, hash, key, handshake->preshared_key); + + /* {} */ + if (!message_decrypt(NULL, src->encrypted_nothing, sizeof(src->encrypted_nothing), key, hash)) goto fail; /* Success! Copy everything to peer */ down_write(&handshake->lock); memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN); - memcpy(handshake->key, key, NOISE_SYMMETRIC_KEY_LEN); memcpy(handshake->hash, hash, NOISE_HASH_LEN); memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN); handshake->remote_index = src->sender_index; diff --git a/src/noise.h b/src/noise.h index e60584b..c9b2b56 100644 --- a/src/noise.h +++ b/src/noise.h @@ -53,10 +53,9 @@ struct noise_keypairs { }; struct noise_static_identity { - bool has_identity, has_psk; + bool has_identity; u8 static_public[NOISE_PUBLIC_KEY_LEN]; u8 static_private[NOISE_PUBLIC_KEY_LEN]; - u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]; struct rw_semaphore lock; }; @@ -82,7 +81,8 @@ struct noise_handshake { u8 remote_static[NOISE_PUBLIC_KEY_LEN]; u8 remote_ephemeral[NOISE_PUBLIC_KEY_LEN]; - u8 key[NOISE_SYMMETRIC_KEY_LEN]; + u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]; + u8 hash[NOISE_HASH_LEN]; u8 chaining_key[NOISE_HASH_LEN]; @@ -102,7 +102,7 @@ struct message_data; struct message_handshake_cookie; void noise_init(void); -void noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], struct wireguard_peer *peer); +void noise_handshake_init(struct noise_handshake *handshake, struct noise_static_identity *static_identity, const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN], const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN], struct wireguard_peer *peer); void noise_handshake_clear(struct noise_handshake *handshake); void noise_keypair_put(struct noise_keypair *keypair); struct noise_keypair *noise_keypair_get(struct noise_keypair *keypair); @@ -110,7 +110,6 @@ void noise_keypairs_clear(struct noise_keypairs *keypairs); bool noise_received_with_keypair(struct noise_keypairs *keypairs, struct noise_keypair *received_keypair); void noise_set_static_identity_private_key(struct noise_static_identity *static_identity, const u8 private_key[NOISE_PUBLIC_KEY_LEN]); -void noise_set_static_identity_preshared_key(struct noise_static_identity *static_identity, const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]); bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, struct noise_handshake *handshake); struct wireguard_peer *noise_handshake_consume_initiation(struct message_handshake_initiation *src, struct wireguard_device *wg); @@ -14,7 +14,7 @@ static atomic64_t peer_counter = ATOMIC64_INIT(0); -struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN]) +struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]) { struct wireguard_peer *peer; lockdep_assert_held(&wg->device_update_lock); @@ -34,8 +34,8 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_ peer->internal_id = atomic64_inc_return(&peer_counter); peer->device = wg; cookie_init(&peer->latest_cookie); - noise_handshake_init(&peer->handshake, &wg->static_identity, public_key, peer); - cookie_checker_precompute_keys(&wg->cookie_checker, peer); + noise_handshake_init(&peer->handshake, &wg->static_identity, public_key, preshared_key, peer); + cookie_checker_precompute_peer_keys(peer); mutex_init(&peer->keypairs.keypair_update_lock); INIT_WORK(&peer->transmit_handshake_work, packet_send_queued_handshakes); rwlock_init(&peer->endpoint_lock); @@ -56,7 +56,7 @@ struct wireguard_peer { #endif }; -struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN]); +struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]); struct wireguard_peer *peer_get(struct wireguard_peer *peer); struct wireguard_peer *peer_rcu_get(struct wireguard_peer *peer); diff --git a/src/tests/netns.sh b/src/tests/netns.sh index 6dc917e..9ef83a7 100755 --- a/src/tests/netns.sh +++ b/src/tests/netns.sh @@ -88,15 +88,15 @@ configure_peers() { n1 wg set wg0 \ private-key <(echo "$key1") \ - preshared-key <(echo "$psk") \ listen-port 1 \ peer "$pub2" \ + preshared-key <(echo "$psk") \ allowed-ips 192.168.241.2/32,fd00::2/128 n2 wg set wg0 \ private-key <(echo "$key2") \ - preshared-key <(echo "$psk") \ listen-port 2 \ peer "$pub1" \ + preshared-key <(echo "$psk") \ allowed-ips 192.168.241.1/32,fd00::1/128 ip1 link set up dev wg0 @@ -99,10 +99,13 @@ struct wgipmask { enum { WGPEER_REMOVE_ME = (1 << 0), - WGPEER_REPLACE_IPMASKS = (1 << 1) + WGPEER_REPLACE_IPMASKS = (1 << 1), + WGPEER_REMOVE_PRESHARED_KEY = (1 << 2) }; + struct wgpeer { __u8 public_key[WG_KEY_LEN]; /* Get/Set */ + __u8 preshared_key[WG_KEY_LEN]; /* Get/Set */ __u32 flags; /* Set */ union { @@ -121,12 +124,11 @@ struct wgpeer { enum { WGDEVICE_REPLACE_PEERS = (1 << 0), WGDEVICE_REMOVE_PRIVATE_KEY = (1 << 1), - WGDEVICE_REMOVE_PRESHARED_KEY = (1 << 2), - WGDEVICE_REMOVE_FWMARK = (1 << 3) + WGDEVICE_REMOVE_FWMARK = (1 << 2) }; enum { - WG_API_VERSION_MAGIC = 0xbeef0001 + WG_API_VERSION_MAGIC = 0xbeef0002 }; struct wgdevice { @@ -136,7 +138,6 @@ struct wgdevice { __u8 public_key[WG_KEY_LEN]; /* Get */ __u8 private_key[WG_KEY_LEN]; /* Get/Set */ - __u8 preshared_key[WG_KEY_LEN]; /* Get/Set */ __u32 fwmark; /* Get/Set */ __u16 port; /* Get/Set */ |