diff options
-rw-r--r-- | src/config.c | 2 | ||||
-rw-r--r-- | src/crypto/curve25519.c | 44 | ||||
-rw-r--r-- | src/crypto/curve25519.h | 4 | ||||
-rw-r--r-- | src/noise.c | 39 | ||||
-rw-r--r-- | src/selftest/curve25519.h | 28 |
5 files changed, 74 insertions, 43 deletions
diff --git a/src/config.c b/src/config.c index c420f7c..f307ff6 100644 --- a/src/config.c +++ b/src/config.c @@ -159,7 +159,7 @@ int config_set_device(struct wireguard_device *wg, void __user *user_device) u8 public_key[NOISE_PUBLIC_KEY_LEN] = { 0 }; struct wireguard_peer *peer; /* We remove before setting, to prevent race, which means doing two 25519-genpub ops. */ - curve25519_generate_public(public_key, in_device.private_key); + bool unused __attribute((unused)) = curve25519_generate_public(public_key, in_device.private_key); peer = pubkey_hashtable_lookup(&wg->peer_hashtable, public_key); if (peer) { peer_put(peer); diff --git a/src/crypto/curve25519.c b/src/crypto/curve25519.c index 5412b64..f95be39 100644 --- a/src/crypto/curve25519.c +++ b/src/crypto/curve25519.c @@ -16,6 +16,7 @@ static __always_inline void normalize_secret(u8 secret[CURVE25519_POINT_SIZE]) secret[31] &= 127; secret[31] |= 64; } +static const u8 null_point[CURVE25519_POINT_SIZE] = { 0 }; #ifdef CONFIG_X86_64 #include <asm/cpufeature.h> @@ -62,9 +63,14 @@ static void curve25519_sandy2x(u8 mypublic[CURVE25519_POINT_SIZE], const u8 secr #undef x1 #undef x2 #undef z2 - curve25519_sandy2x_fe51_invert(&z_51, &z_51); - curve25519_sandy2x_fe51_mul(&x_51, &x_51, &z_51); - curve25519_sandy2x_fe51_pack(mypublic, &x_51); + curve25519_sandy2x_fe51_invert(&z_51, (const fe51 *)&z_51); + curve25519_sandy2x_fe51_mul(&x_51, (const fe51 *)&x_51, (const fe51 *)&z_51); + curve25519_sandy2x_fe51_pack(mypublic, (const fe51 *)&x_51); + + memzero_explicit(e, sizeof(e)); + memzero_explicit(var, sizeof(var)); + memzero_explicit(x_51, sizeof(x_51)); + memzero_explicit(z_51, sizeof(z_51)); } static void curve25519_sandy2x_base(u8 pub[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE]) @@ -89,9 +95,14 @@ static void curve25519_sandy2x_base(u8 pub[CURVE25519_POINT_SIZE], const u8 secr x_51[4] = (x2[9] << 26) + x2[8]; #undef x2 #undef z2 - curve25519_sandy2x_fe51_invert(&z_51, &z_51); - curve25519_sandy2x_fe51_mul(&x_51, &x_51, &z_51); - curve25519_sandy2x_fe51_pack(pub, &x_51); + curve25519_sandy2x_fe51_invert(&z_51, (const fe51 *)&z_51); + curve25519_sandy2x_fe51_mul(&x_51, (const fe51 *)&x_51, (const fe51 *)&z_51); + curve25519_sandy2x_fe51_pack(pub, (const fe51 *)&x_51); + + memzero_explicit(e, sizeof(e)); + memzero_explicit(var, sizeof(var)); + memzero_explicit(x_51, sizeof(x_51)); + memzero_explicit(z_51, sizeof(z_51)); } #else void curve25519_fpu_init(void) { } @@ -473,7 +484,7 @@ static void crecip(felem out, const felem z) /* 2^255 - 21 */ fmul(out, t0, a); } -void curve25519(u8 mypublic[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE], const u8 basepoint[CURVE25519_POINT_SIZE]) +bool curve25519(u8 mypublic[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE], const u8 basepoint[CURVE25519_POINT_SIZE]) { #ifdef CONFIG_X86_64 if (curve25519_use_avx && irq_fpu_usable()) { @@ -501,21 +512,21 @@ void curve25519(u8 mypublic[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_P memzero_explicit(z, sizeof(z)); memzero_explicit(zmone, sizeof(zmone)); } + return crypto_memneq(mypublic, null_point, CURVE25519_POINT_SIZE); } -void curve25519_generate_public(u8 pub[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE]) +bool curve25519_generate_public(u8 pub[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE]) { + static const u8 basepoint[CURVE25519_POINT_SIZE] = { 9 }; #ifdef CONFIG_X86_64 if (curve25519_use_avx && irq_fpu_usable()) { kernel_fpu_begin(); curve25519_sandy2x_base(pub, secret); kernel_fpu_end(); - } else -#endif - { - static const u8 basepoint[CURVE25519_POINT_SIZE] = { 9 }; - curve25519(pub, secret, basepoint); + return crypto_memneq(pub, null_point, CURVE25519_POINT_SIZE); } +#endif + return curve25519(pub, secret, basepoint); } #else typedef s64 limb; @@ -1306,7 +1317,7 @@ static void cmult(limb *resultx, limb *resultz, const u8 *n, const limb *q) memcpy(resultz, nqz, sizeof(limb) * 10); } -void curve25519(u8 mypublic[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE], const u8 basepoint[CURVE25519_POINT_SIZE]) +bool curve25519(u8 mypublic[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE], const u8 basepoint[CURVE25519_POINT_SIZE]) { limb bp[10], x[10], z[11], zmone[10]; u8 e[32]; @@ -1325,12 +1336,13 @@ void curve25519(u8 mypublic[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_P memzero_explicit(x, sizeof(x)); memzero_explicit(z, sizeof(z)); memzero_explicit(zmone, sizeof(zmone)); + return crypto_memneq(mypublic, null_point, CURVE25519_POINT_SIZE); } -void curve25519_generate_public(u8 pub[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE]) +bool curve25519_generate_public(u8 pub[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE]) { static const u8 basepoint[CURVE25519_POINT_SIZE] = { 9 }; - curve25519(pub, secret, basepoint); + return curve25519(pub, secret, basepoint); } #endif diff --git a/src/crypto/curve25519.h b/src/crypto/curve25519.h index 8e440a1..16be496 100644 --- a/src/crypto/curve25519.h +++ b/src/crypto/curve25519.h @@ -9,9 +9,9 @@ enum curve25519_lengths { CURVE25519_POINT_SIZE = 32 }; -void curve25519(u8 mypublic[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE], const u8 basepoint[CURVE25519_POINT_SIZE]); +bool __must_check curve25519(u8 mypublic[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE], const u8 basepoint[CURVE25519_POINT_SIZE]); void curve25519_generate_secret(u8 secret[CURVE25519_POINT_SIZE]); -void curve25519_generate_public(u8 pub[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE]); +bool __must_check curve25519_generate_public(u8 pub[CURVE25519_POINT_SIZE], const u8 secret[CURVE25519_POINT_SIZE]); void curve25519_fpu_init(void); diff --git a/src/noise.c b/src/noise.c index 0ffffd7..608a175 100644 --- a/src/noise.c +++ b/src/noise.c @@ -189,8 +189,7 @@ void noise_set_static_identity_private_key(struct noise_static_identity *static_ down_write(&static_identity->lock); if (private_key) { memcpy(static_identity->static_private, private_key, NOISE_PUBLIC_KEY_LEN); - curve25519_generate_public(static_identity->static_public, private_key); - static_identity->has_identity = true; + static_identity->has_identity = curve25519_generate_public(static_identity->static_public, private_key); } else { memset(static_identity->static_private, 0, NOISE_PUBLIC_KEY_LEN); memset(static_identity->static_public, 0, NOISE_PUBLIC_KEY_LEN); @@ -263,13 +262,15 @@ static void mix_key(u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 chaining_key[NOISE_HASH_ kdf(chaining_key, key, src, NOISE_HASH_LEN, NOISE_SYMMETRIC_KEY_LEN, src_len, chaining_key); } -static void mix_dh(u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 chaining_key[NOISE_HASH_LEN], +static __must_check bool mix_dh(u8 key[NOISE_SYMMETRIC_KEY_LEN], u8 chaining_key[NOISE_HASH_LEN], const u8 private[NOISE_PUBLIC_KEY_LEN], const u8 public[NOISE_PUBLIC_KEY_LEN]) { u8 dh_calculation[NOISE_PUBLIC_KEY_LEN]; - curve25519(dh_calculation, private, public); + if (unlikely(!curve25519(dh_calculation, private, public))) + return false; mix_key(key, chaining_key, dh_calculation, NOISE_PUBLIC_KEY_LEN); memzero_explicit(dh_calculation, NOISE_PUBLIC_KEY_LEN); + return true; } static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len) @@ -346,20 +347,23 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst, /* e */ curve25519_generate_secret(handshake->ephemeral_private); - curve25519_generate_public(handshake->ephemeral_public, handshake->ephemeral_private); + if (!curve25519_generate_public(handshake->ephemeral_public, handshake->ephemeral_private)) + goto out; handshake_nocrypt(dst->unencrypted_ephemeral, handshake->ephemeral_public, NOISE_PUBLIC_KEY_LEN, handshake->hash); if (handshake->static_identity->has_psk) mix_key(handshake->key, handshake->chaining_key, handshake->ephemeral_public, NOISE_PUBLIC_KEY_LEN); /* es */ - mix_dh(handshake->key, handshake->chaining_key, handshake->ephemeral_private, handshake->remote_static); + if (!mix_dh(handshake->key, handshake->chaining_key, handshake->ephemeral_private, handshake->remote_static)) + goto out; /* s */ if (!handshake_encrypt(dst->encrypted_static, handshake->static_identity->static_public, NOISE_PUBLIC_KEY_LEN, handshake->key, handshake->hash)) goto out; /* ss */ - mix_dh(handshake->key, handshake->chaining_key, handshake->static_identity->static_private, handshake->remote_static); + if (!mix_dh(handshake->key, handshake->chaining_key, handshake->static_identity->static_private, handshake->remote_static)) + goto out; /* t */ tai64n_now(timestamp); @@ -402,14 +406,16 @@ struct wireguard_peer *noise_handshake_consume_initiation(struct message_handsha mix_key(key, chaining_key, e, NOISE_PUBLIC_KEY_LEN); /* es */ - mix_dh(key, chaining_key, wg->static_identity.static_private, e); + if (!mix_dh(key, chaining_key, wg->static_identity.static_private, e)) + goto out; /* s */ if (!handshake_decrypt(s, src->encrypted_static, sizeof(src->encrypted_static), key, hash)) goto out; /* ss */ - mix_dh(key, chaining_key, wg->static_identity.static_private, s); + if (!mix_dh(key, chaining_key, wg->static_identity.static_private, s)) + goto out; /* t */ if (!handshake_decrypt(t, src->encrypted_timestamp, sizeof(src->encrypted_timestamp), key, hash)) @@ -464,16 +470,19 @@ bool noise_handshake_create_response(struct message_handshake_response *dst, str /* e */ curve25519_generate_secret(handshake->ephemeral_private); - curve25519_generate_public(handshake->ephemeral_public, handshake->ephemeral_private); + if (!curve25519_generate_public(handshake->ephemeral_public, handshake->ephemeral_private)) + goto out; handshake_nocrypt(dst->unencrypted_ephemeral, handshake->ephemeral_public, NOISE_PUBLIC_KEY_LEN, handshake->hash); if (handshake->static_identity->has_psk) mix_key(handshake->key, handshake->chaining_key, handshake->ephemeral_public, NOISE_PUBLIC_KEY_LEN); /* ee */ - mix_dh(handshake->key, handshake->chaining_key, handshake->ephemeral_private, handshake->remote_ephemeral); + if (!mix_dh(handshake->key, handshake->chaining_key, handshake->ephemeral_private, handshake->remote_ephemeral)) + goto out; /* se */ - mix_dh(handshake->key, handshake->chaining_key, handshake->ephemeral_private, handshake->remote_static); + if (!mix_dh(handshake->key, handshake->chaining_key, handshake->ephemeral_private, handshake->remote_static)) + goto out; if (!handshake_encrypt(dst->encrypted_nothing, NULL, 0, handshake->key, handshake->hash)) goto out; @@ -527,10 +536,12 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake mix_key(key, chaining_key, e, NOISE_PUBLIC_KEY_LEN); /* ee */ - mix_dh(key, chaining_key, ephemeral_private, e); + if (!mix_dh(key, chaining_key, ephemeral_private, e)) + goto out; /* se */ - mix_dh(key, chaining_key, wg->static_identity.static_private, e); + if (!mix_dh(key, chaining_key, wg->static_identity.static_private, e)) + goto out; /* decrypt nothing */ if (!handshake_decrypt(NULL, src->encrypted_nothing, sizeof(src->encrypted_nothing), key, hash)) diff --git a/src/selftest/curve25519.h b/src/selftest/curve25519.h index 04b25a3..0ed3dae 100644 --- a/src/selftest/curve25519.h +++ b/src/selftest/curve25519.h @@ -5,54 +5,62 @@ struct curve25519_test_vector { u8 private[CURVE25519_POINT_SIZE]; u8 public[CURVE25519_POINT_SIZE]; u8 result[CURVE25519_POINT_SIZE]; + bool valid; }; static const struct curve25519_test_vector curve25519_test_vectors[] = { { .private = { 0x77, 0x07, 0x6d, 0x0a, 0x73, 0x18, 0xa5, 0x7d, 0x3c, 0x16, 0xc1, 0x72, 0x51, 0xb2, 0x66, 0x45, 0xdf, 0x4c, 0x2f, 0x87, 0xeb, 0xc0, 0x99, 0x2a, 0xb1, 0x77, 0xfb, 0xa5, 0x1d, 0xb9, 0x2c, 0x2a }, .public = { 0xde, 0x9e, 0xdb, 0x7d, 0x7b, 0x7d, 0xc1, 0xb4, 0xd3, 0x5b, 0x61, 0xc2, 0xec, 0xe4, 0x35, 0x37, 0x3f, 0x83, 0x43, 0xc8, 0x5b, 0x78, 0x67, 0x4d, 0xad, 0xfc, 0x7e, 0x14, 0x6f, 0x88, 0x2b, 0x4f }, - .result = { 0x4a, 0x5d, 0x9d, 0x5b, 0xa4, 0xce, 0x2d, 0xe1, 0x72, 0x8e, 0x3b, 0xf4, 0x80, 0x35, 0x0f, 0x25, 0xe0, 0x7e, 0x21, 0xc9, 0x47, 0xd1, 0x9e, 0x33, 0x76, 0xf0, 0x9b, 0x3c, 0x1e, 0x16, 0x17, 0x42 } + .result = { 0x4a, 0x5d, 0x9d, 0x5b, 0xa4, 0xce, 0x2d, 0xe1, 0x72, 0x8e, 0x3b, 0xf4, 0x80, 0x35, 0x0f, 0x25, 0xe0, 0x7e, 0x21, 0xc9, 0x47, 0xd1, 0x9e, 0x33, 0x76, 0xf0, 0x9b, 0x3c, 0x1e, 0x16, 0x17, 0x42 }, + .valid = true }, { .private = { 0x5d, 0xab, 0x08, 0x7e, 0x62, 0x4a, 0x8a, 0x4b, 0x79, 0xe1, 0x7f, 0x8b, 0x83, 0x80, 0x0e, 0xe6, 0x6f, 0x3b, 0xb1, 0x29, 0x26, 0x18, 0xb6, 0xfd, 0x1c, 0x2f, 0x8b, 0x27, 0xff, 0x88, 0xe0, 0xeb }, .public = { 0x85, 0x20, 0xf0, 0x09, 0x89, 0x30, 0xa7, 0x54, 0x74, 0x8b, 0x7d, 0xdc, 0xb4, 0x3e, 0xf7, 0x5a, 0x0d, 0xbf, 0x3a, 0x0d, 0x26, 0x38, 0x1a, 0xf4, 0xeb, 0xa4, 0xa9, 0x8e, 0xaa, 0x9b, 0x4e, 0x6a }, - .result = { 0x4a, 0x5d, 0x9d, 0x5b, 0xa4, 0xce, 0x2d, 0xe1, 0x72, 0x8e, 0x3b, 0xf4, 0x80, 0x35, 0x0f, 0x25, 0xe0, 0x7e, 0x21, 0xc9, 0x47, 0xd1, 0x9e, 0x33, 0x76, 0xf0, 0x9b, 0x3c, 0x1e, 0x16, 0x17, 0x42 } + .result = { 0x4a, 0x5d, 0x9d, 0x5b, 0xa4, 0xce, 0x2d, 0xe1, 0x72, 0x8e, 0x3b, 0xf4, 0x80, 0x35, 0x0f, 0x25, 0xe0, 0x7e, 0x21, 0xc9, 0x47, 0xd1, 0x9e, 0x33, 0x76, 0xf0, 0x9b, 0x3c, 0x1e, 0x16, 0x17, 0x42 }, + .valid = true }, { .private = { 1 }, .public = { 0x25, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, - .result = { 0x3c, 0x77, 0x77, 0xca, 0xf9, 0x97, 0xb2, 0x64, 0x41, 0x60, 0x77, 0x66, 0x5b, 0x4e, 0x22, 0x9d, 0xb, 0x95, 0x48, 0xdc, 0xc, 0xd8, 0x19, 0x98, 0xdd, 0xcd, 0xc5, 0xc8, 0x53, 0x3c, 0x79, 0x7f } + .result = { 0x3c, 0x77, 0x77, 0xca, 0xf9, 0x97, 0xb2, 0x64, 0x41, 0x60, 0x77, 0x66, 0x5b, 0x4e, 0x22, 0x9d, 0xb, 0x95, 0x48, 0xdc, 0xc, 0xd8, 0x19, 0x98, 0xdd, 0xcd, 0xc5, 0xc8, 0x53, 0x3c, 0x79, 0x7f }, + .valid = true }, { .private = { 1 }, .public = { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }, - .result = { 0xb3, 0x2d, 0x13, 0x62, 0xc2, 0x48, 0xd6, 0x2f, 0xe6, 0x26, 0x19, 0xcf, 0xf0, 0x4d, 0xd4, 0x3d, 0xb7, 0x3f, 0xfc, 0x1b, 0x63, 0x8, 0xed, 0xe3, 0xb, 0x78, 0xd8, 0x73, 0x80, 0xf1, 0xe8, 0x34 } + .result = { 0xb3, 0x2d, 0x13, 0x62, 0xc2, 0x48, 0xd6, 0x2f, 0xe6, 0x26, 0x19, 0xcf, 0xf0, 0x4d, 0xd4, 0x3d, 0xb7, 0x3f, 0xfc, 0x1b, 0x63, 0x8, 0xed, 0xe3, 0xb, 0x78, 0xd8, 0x73, 0x80, 0xf1, 0xe8, 0x34 }, + .valid = true }, { .private = { 0xa5, 0x46, 0xe3, 0x6b, 0xf0, 0x52, 0x7c, 0x9d, 0x3b, 0x16, 0x15, 0x4b, 0x82, 0x46, 0x5e, 0xdd, 0x62, 0x14, 0x4c, 0x0a, 0xc1, 0xfc, 0x5a, 0x18, 0x50, 0x6a, 0x22, 0x44, 0xba, 0x44, 0x9a, 0xc4 }, .public = { 0xe6, 0xdb, 0x68, 0x67, 0x58, 0x30, 0x30, 0xdb, 0x35, 0x94, 0xc1, 0xa4, 0x24, 0xb1, 0x5f, 0x7c, 0x72, 0x66, 0x24, 0xec, 0x26, 0xb3, 0x35, 0x3b, 0x10, 0xa9, 0x03, 0xa6, 0xd0, 0xab, 0x1c, 0x4c }, - .result = { 0xc3, 0xda, 0x55, 0x37, 0x9d, 0xe9, 0xc6, 0x90, 0x8e, 0x94, 0xea, 0x4d, 0xf2, 0x8d, 0x08, 0x4f, 0x32, 0xec, 0xcf, 0x03, 0x49, 0x1c, 0x71, 0xf7, 0x54, 0xb4, 0x07, 0x55, 0x77, 0xa2, 0x85, 0x52 } + .result = { 0xc3, 0xda, 0x55, 0x37, 0x9d, 0xe9, 0xc6, 0x90, 0x8e, 0x94, 0xea, 0x4d, 0xf2, 0x8d, 0x08, 0x4f, 0x32, 0xec, 0xcf, 0x03, 0x49, 0x1c, 0x71, 0xf7, 0x54, 0xb4, 0x07, 0x55, 0x77, 0xa2, 0x85, 0x52 }, + .valid = true }, { .private = { 1, 2, 3, 4 }, .public = { 0 }, - .result = { 0 } + .result = { 0 }, + .valid = false }, { .private = { 2, 4, 6, 8 }, .public = { 0xe0, 0xeb, 0x7a, 0x7c, 0x3b, 0x41, 0xb8, 0xae, 0x16, 0x56, 0xe3, 0xfa, 0xf1, 0x9f, 0xc4, 0x6a, 0xda, 0x09, 0x8d, 0xeb, 0x9c, 0x32, 0xb1, 0xfd, 0x86, 0x62, 0x05, 0x16, 0x5f, 0x49, 0xb8 }, - .result = { 0 } + .result = { 0 }, + .valid = false } }; bool curve25519_selftest(void) { - bool success = true; + bool success = true, ret; size_t i = 0; u8 out[CURVE25519_POINT_SIZE]; for (i = 0; i < ARRAY_SIZE(curve25519_test_vectors); ++i) { memset(out, 0, CURVE25519_POINT_SIZE); - curve25519(out, curve25519_test_vectors[i].private, curve25519_test_vectors[i].public); - if (memcmp(out, curve25519_test_vectors[i].result, CURVE25519_POINT_SIZE)) { + ret = curve25519(out, curve25519_test_vectors[i].private, curve25519_test_vectors[i].public); + if (ret != curve25519_test_vectors[i].valid || memcmp(out, curve25519_test_vectors[i].result, CURVE25519_POINT_SIZE)) { pr_info("curve25519 self-test %zu: FAIL\n", i + 1); success = false; break; |