From e664a05d4b77cd398b80b34f0e4aea068596f80b Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Wed, 30 May 2018 23:16:41 +0200 Subject: KeyEncoding: more constant time Signed-off-by: Jason A. Donenfeld --- .../java/com/wireguard/crypto/KeyEncoding.java | 49 ++++++++++++---------- 1 file changed, 28 insertions(+), 21 deletions(-) (limited to 'app/src/main/java/com/wireguard/crypto/KeyEncoding.java') diff --git a/app/src/main/java/com/wireguard/crypto/KeyEncoding.java b/app/src/main/java/com/wireguard/crypto/KeyEncoding.java index ec86da37..1f32cc52 100644 --- a/app/src/main/java/com/wireguard/crypto/KeyEncoding.java +++ b/app/src/main/java/com/wireguard/crypto/KeyEncoding.java @@ -62,11 +62,10 @@ public final class KeyEncoding { final byte[] key = new byte[KEY_LENGTH]; if (input.length != KEY_LENGTH_BASE64 || input[KEY_LENGTH_BASE64 - 1] != '=') throw new IllegalArgumentException(KEY_LENGTH_BASE64_EXCEPTION_MESSAGE); - int i; + int i, ret = 0; for (i = 0; i < KEY_LENGTH / 3; ++i) { final int val = decodeBase64(input, i * 4); - if (val < 0) - throw new IllegalArgumentException(KEY_LENGTH_BASE64_EXCEPTION_MESSAGE); + ret |= val >>> 31; key[i * 3] = (byte) ((val >>> 16) & 0xff); key[i * 3 + 1] = (byte) ((val >>> 8) & 0xff); key[i * 3 + 2] = (byte) (val & 0xff); @@ -78,10 +77,12 @@ public final class KeyEncoding { 'A', }; final int val = decodeBase64(endSegment, 0); - if (val < 0 || (val & 0xff) != 0) - throw new IllegalArgumentException(KEY_LENGTH_BASE64_EXCEPTION_MESSAGE); + ret |= (val >>> 31) | (val & 0xff); key[i * 3] = (byte) ((val >>> 16) & 0xff); key[i * 3 + 1] = (byte) ((val >>> 8) & 0xff); + + if (ret != 0) + throw new IllegalArgumentException(KEY_LENGTH_BASE64_EXCEPTION_MESSAGE); return key; } @@ -90,25 +91,31 @@ public final class KeyEncoding { final byte[] key = new byte[KEY_LENGTH]; if (input.length != KEY_LENGTH_HEX) throw new IllegalArgumentException(KEY_LENGTH_HEX_EXCEPTION_MESSAGE); + int ret = 0; + + for (int i = 0; i < KEY_LENGTH_HEX; i += 2) { + int c, c_num, c_num0, c_alpha, c_alpha0, c_val, c_acc; - int c_acc = 0; - int state = 0; + c = input[i]; + c_num = c ^ 48; + c_num0 = ((c_num - 10) >>> 8) & 0xff; + c_alpha = (c & ~32) - 55; + c_alpha0 = (((c_alpha - 10) ^ (c_alpha - 16)) >>> 8) & 0xff; + ret |= ((c_num0 | c_alpha0) - 1) >>> 8; + c_val = (c_num0 & c_num) | (c_alpha0 & c_alpha); + c_acc = c_val * 16; - for (int i = 0; i < KEY_LENGTH_HEX; ++i) { - final int c = input[i]; - final int c_num = c ^ 48; - final int c_num0 = (c_num - 10) >> 8; - final int c_alpha = (c & ~32) - 55; - final int c_alpha0 = ((c_alpha - 10) ^ (c_alpha - 16)) >> 8; - if ((c_num0 | c_alpha0) == 0) - throw new IllegalArgumentException(KEY_LENGTH_HEX_EXCEPTION_MESSAGE); - final int c_val = (c_num0 & c_num) | (c_alpha0 & c_alpha); - if (state == 0) - c_acc = c_val * 16; - else - key[i / 2] = (byte) (c_acc | c_val); - state = ~state; + c = input[i + 1]; + c_num = c ^ 48; + c_num0 = ((c_num - 10) >>> 8) & 0xff; + c_alpha = (c & ~32) - 55; + c_alpha0 = (((c_alpha - 10) ^ (c_alpha - 16)) >>> 8) & 0xff; + ret |= ((c_num0 | c_alpha0) - 1) >>> 8; + c_val = (c_num0 & c_num) | (c_alpha0 & c_alpha); + key[i / 2] = (byte) (c_acc | c_val); } + if (ret != 0) + throw new IllegalArgumentException(KEY_LENGTH_HEX_EXCEPTION_MESSAGE); return key; } -- cgit v1.2.3