diff options
author | Harsh Shandilya <me@msfjarvis.dev> | 2020-03-09 19:00:14 +0530 |
---|---|---|
committer | Harsh Shandilya <me@msfjarvis.dev> | 2020-03-09 19:24:26 +0530 |
commit | adc613d8011af7c508050badb1272e8326554c39 (patch) | |
tree | 4eadedc0767e1f4f482b7c22ec905329acab62a6 /tunnel/src/main/java | |
parent | fd573f6c1c37bcfcd09237dfcd55e08b1e0eaff9 (diff) |
Migrate tunnel related classes to tunnel/ Gradle module
Signed-off-by: Harsh Shandilya <me@msfjarvis.dev>
Diffstat (limited to 'tunnel/src/main/java')
24 files changed, 3485 insertions, 0 deletions
diff --git a/tunnel/src/main/java/com/wireguard/android/backend/Backend.java b/tunnel/src/main/java/com/wireguard/android/backend/Backend.java new file mode 100644 index 00000000..ed3a5ebd --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/Backend.java @@ -0,0 +1,63 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import com.wireguard.config.Config; + +import java.util.Collection; +import java.util.Set; + +import androidx.annotation.Nullable; + +/** + * Interface for implementations of the WireGuard secure network tunnel. + */ + +public interface Backend { + /** + * Enumerate names of currently-running tunnels. + * + * @return The set of running tunnel names. + */ + Set<String> getRunningTunnelNames(); + + /** + * Get the state of a tunnel. + * + * @param tunnel The tunnel to examine the state of. + * @return The state of the tunnel. + */ + Tunnel.State getState(Tunnel tunnel) throws Exception; + + /** + * Get statistics about traffic and errors on this tunnel. If the tunnel is not running, the + * statistics object will be filled with zero values. + * + * @param tunnel The tunnel to retrieve statistics for. + * @return The statistics for the tunnel. + */ + Statistics getStatistics(Tunnel tunnel) throws Exception; + + /** + * Determine version of underlying backend. + * + * @return The version of the backend. + * @throws Exception + */ + String getVersion() throws Exception; + + /** + * Set the state of a tunnel, updating it's configuration. If the tunnel is already up, config + * may update the running configuration; config may be null when setting the tunnel down. + * + * @param tunnel The tunnel to control the state of. + * @param state The new state for this tunnel. Must be {@code UP}, {@code DOWN}, or + * {@code TOGGLE}. + * @param config The configuration for this tunnel, may be null if state is {@code DOWN}. + * @return The updated state of the tunnel. + */ + Tunnel.State setState(Tunnel tunnel, Tunnel.State state, @Nullable Config config) throws Exception; +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/BackendException.java b/tunnel/src/main/java/com/wireguard/android/backend/BackendException.java new file mode 100644 index 00000000..e1e8eaa9 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/BackendException.java @@ -0,0 +1,30 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +public final class BackendException extends Exception { + public enum Reason { + UNKNOWN_KERNEL_MODULE_NAME, + WG_QUICK_CONFIG_ERROR_CODE, + TUNNEL_MISSING_CONFIG, + VPN_NOT_AUTHORIZED, + UNABLE_TO_START_VPN, + TUN_CREATION_ERROR, + GO_ACTIVATION_ERROR_CODE + } + private final Reason reason; + private final Object[] format; + public BackendException(final Reason reason, final Object ...format) { + this.reason = reason; + this.format = format; + } + public Reason getReason() { + return reason; + } + public Object[] getFormat() { + return format; + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java new file mode 100644 index 00000000..6ad5afa4 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java @@ -0,0 +1,292 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import android.content.Context; +import android.content.Intent; +import android.os.Build; +import android.os.ParcelFileDescriptor; +import androidx.annotation.Nullable; +import androidx.collection.ArraySet; +import android.util.Log; + +import com.wireguard.android.backend.BackendException.Reason; +import com.wireguard.android.backend.Tunnel.State; +import com.wireguard.android.util.SharedLibraryLoader; +import com.wireguard.config.Config; +import com.wireguard.config.InetNetwork; +import com.wireguard.config.Peer; +import com.wireguard.crypto.Key; +import com.wireguard.crypto.KeyFormatException; + +import java.net.InetAddress; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import java9.util.concurrent.CompletableFuture; + +public final class GoBackend implements Backend { + private static final String TAG = "WireGuard/" + GoBackend.class.getSimpleName(); + private static CompletableFuture<VpnService> vpnService = new CompletableFuture<>(); + public interface AlwaysOnCallback { + void alwaysOnTriggered(); + } + @Nullable private static AlwaysOnCallback alwaysOnCallback; + public static void setAlwaysOnCallback(AlwaysOnCallback cb) { + alwaysOnCallback = cb; + } + + private final Context context; + @Nullable private Tunnel currentTunnel; + @Nullable private Config currentConfig; + private int currentTunnelHandle = -1; + + public GoBackend(final Context context) { + SharedLibraryLoader.loadSharedLibrary(context, "wg-go"); + this.context = context; + } + + private static native String wgGetConfig(int handle); + + private static native int wgGetSocketV4(int handle); + + private static native int wgGetSocketV6(int handle); + + private static native void wgTurnOff(int handle); + + private static native int wgTurnOn(String ifName, int tunFd, String settings); + + private static native String wgVersion(); + + @Override + public Set<String> getRunningTunnelNames() { + if (currentTunnel != null) { + final Set<String> runningTunnels = new ArraySet<>(); + runningTunnels.add(currentTunnel.getName()); + return runningTunnels; + } + return Collections.emptySet(); + } + + @Override + public State getState(final Tunnel tunnel) { + return currentTunnel == tunnel ? State.UP : State.DOWN; + } + + @Override + public Statistics getStatistics(final Tunnel tunnel) { + final Statistics stats = new Statistics(); + if (tunnel != currentTunnel) { + return stats; + } + final String config = wgGetConfig(currentTunnelHandle); + Key key = null; + long rx = 0, tx = 0; + for (final String line : config.split("\\n")) { + if (line.startsWith("public_key=")) { + if (key != null) + stats.add(key, rx, tx); + rx = 0; + tx = 0; + try { + key = Key.fromHex(line.substring(11)); + } catch (final KeyFormatException ignored) { + key = null; + } + } else if (line.startsWith("rx_bytes=")) { + if (key == null) + continue; + try { + rx = Long.parseLong(line.substring(9)); + } catch (final NumberFormatException ignored) { + rx = 0; + } + } else if (line.startsWith("tx_bytes=")) { + if (key == null) + continue; + try { + tx = Long.parseLong(line.substring(9)); + } catch (final NumberFormatException ignored) { + tx = 0; + } + } + } + if (key != null) + stats.add(key, rx, tx); + return stats; + } + + @Override + public String getVersion() { + return wgVersion(); + } + + @Override + public State setState(final Tunnel tunnel, State state, @Nullable final Config config) throws Exception { + final State originalState = getState(tunnel); + + if (state == State.TOGGLE) + state = originalState == State.UP ? State.DOWN : State.UP; + if (state == originalState && tunnel == currentTunnel && config == currentConfig) + return originalState; + if (state == State.UP) { + final Config originalConfig = currentConfig; + final Tunnel originalTunnel = currentTunnel; + if (currentTunnel != null) + setStateInternal(currentTunnel, null, State.DOWN); + try { + setStateInternal(tunnel, config, state); + } catch(final Exception e) { + if (originalTunnel != null) + setStateInternal(originalTunnel, originalConfig, State.UP); + throw e; + } + } else if (state == State.DOWN && tunnel == currentTunnel) { + setStateInternal(tunnel, null, State.DOWN); + } + return getState(tunnel); + } + + private void setStateInternal(final Tunnel tunnel, @Nullable final Config config, final State state) + throws Exception { + Log.i(TAG, "Bringing tunnel " + tunnel.getName() + " " + state); + + if (state == State.UP) { + if (config == null) + throw new BackendException(Reason.TUNNEL_MISSING_CONFIG); + + if (VpnService.prepare(context) != null) + throw new BackendException(Reason.VPN_NOT_AUTHORIZED); + + final VpnService service; + if (!vpnService.isDone()) + startVpnService(); + + try { + service = vpnService.get(2, TimeUnit.SECONDS); + } catch (final TimeoutException e) { + final Exception be = new BackendException(Reason.UNABLE_TO_START_VPN); + be.initCause(e); + throw be; + } + service.setOwner(this); + + if (currentTunnelHandle != -1) { + Log.w(TAG, "Tunnel already up"); + return; + } + + // Build config + final String goConfig = config.toWgUserspaceString(); + + // Create the vpn tunnel with android API + final VpnService.Builder builder = service.getBuilder(); + builder.setSession(tunnel.getName()); + + for (final String excludedApplication : config.getInterface().getExcludedApplications()) + builder.addDisallowedApplication(excludedApplication); + + for (final InetNetwork addr : config.getInterface().getAddresses()) + builder.addAddress(addr.getAddress(), addr.getMask()); + + for (final InetAddress addr : config.getInterface().getDnsServers()) + builder.addDnsServer(addr.getHostAddress()); + + for (final Peer peer : config.getPeers()) { + for (final InetNetwork addr : peer.getAllowedIps()) + builder.addRoute(addr.getAddress(), addr.getMask()); + } + + builder.setMtu(config.getInterface().getMtu().orElse(1280)); + + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) + builder.setMetered(false); + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) + service.setUnderlyingNetworks(null); + + builder.setBlocking(true); + try (final ParcelFileDescriptor tun = builder.establish()) { + if (tun == null) + throw new BackendException(Reason.TUN_CREATION_ERROR); + Log.d(TAG, "Go backend v" + wgVersion()); + currentTunnelHandle = wgTurnOn(tunnel.getName(), tun.detachFd(), goConfig); + } + if (currentTunnelHandle < 0) + throw new BackendException(Reason.GO_ACTIVATION_ERROR_CODE, currentTunnelHandle); + + currentTunnel = tunnel; + currentConfig = config; + + service.protect(wgGetSocketV4(currentTunnelHandle)); + service.protect(wgGetSocketV6(currentTunnelHandle)); + } else { + if (currentTunnelHandle == -1) { + Log.w(TAG, "Tunnel already down"); + return; + } + + wgTurnOff(currentTunnelHandle); + currentTunnel = null; + currentTunnelHandle = -1; + currentConfig = null; + } + + tunnel.onStateChange(state); + } + + private void startVpnService() { + Log.d(TAG, "Requesting to start VpnService"); + context.startService(new Intent(context, VpnService.class)); + } + + public static class VpnService extends android.net.VpnService { + @Nullable private GoBackend owner; + + public void setOwner(final GoBackend owner) { + this.owner = owner; + } + + public Builder getBuilder() { + return new Builder(); + } + + @Override + public void onCreate() { + vpnService.complete(this); + super.onCreate(); + } + + @Override + public void onDestroy() { + if (owner != null) { + final Tunnel tunnel = owner.currentTunnel; + if (tunnel != null) { + if (owner.currentTunnelHandle != -1) + wgTurnOff(owner.currentTunnelHandle); + owner.currentTunnel = null; + owner.currentTunnelHandle = -1; + owner.currentConfig = null; + tunnel.onStateChange(State.DOWN); + } + } + vpnService = vpnService.newIncompleteFuture(); + super.onDestroy(); + } + + @Override + public int onStartCommand(@Nullable final Intent intent, final int flags, final int startId) { + vpnService.complete(this); + if (intent == null || intent.getComponent() == null || !intent.getComponent().getPackageName().equals(getPackageName())) { + Log.d(TAG, "Service started by Always-on VPN feature"); + if (alwaysOnCallback != null) + alwaysOnCallback.alwaysOnTriggered(); + } + return super.onStartCommand(intent, flags, startId); + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/Statistics.java b/tunnel/src/main/java/com/wireguard/android/backend/Statistics.java new file mode 100644 index 00000000..2ca87d23 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/Statistics.java @@ -0,0 +1,62 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import android.os.SystemClock; +import android.util.Pair; + +import com.wireguard.crypto.Key; + +import java.util.HashMap; +import java.util.Map; + +public class Statistics { + private long lastTouched = SystemClock.elapsedRealtime(); + private final Map<Key, Pair<Long, Long>> peerBytes = new HashMap<>(); + + Statistics() { } + + void add(final Key key, final long rx, final long tx) { + peerBytes.put(key, Pair.create(rx, tx)); + lastTouched = SystemClock.elapsedRealtime(); + } + + public boolean isStale() { + return SystemClock.elapsedRealtime() - lastTouched > 900; + } + + public Key[] peers() { + return peerBytes.keySet().toArray(new Key[0]); + } + + public long peerRx(final Key peer) { + if (!peerBytes.containsKey(peer)) + return 0; + return peerBytes.get(peer).first; + } + + public long peerTx(final Key peer) { + if (!peerBytes.containsKey(peer)) + return 0; + return peerBytes.get(peer).second; + } + + public long totalRx() { + long rx = 0; + for (final Pair<Long, Long> val : peerBytes.values()) { + rx += val.first; + } + return rx; + } + + public long totalTx() { + long tx = 0; + for (final Pair<Long, Long> val : peerBytes.values()) { + tx += val.second; + } + return tx; + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/Tunnel.java b/tunnel/src/main/java/com/wireguard/android/backend/Tunnel.java new file mode 100644 index 00000000..af2f59f7 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/Tunnel.java @@ -0,0 +1,46 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import java.util.regex.Pattern; + +/** + * Represents a WireGuard tunnel. + */ + +public interface Tunnel { + enum State { + DOWN, + TOGGLE, + UP; + + public static State of(final boolean running) { + return running ? UP : DOWN; + } + } + + int NAME_MAX_LENGTH = 15; + Pattern NAME_PATTERN = Pattern.compile("[a-zA-Z0-9_=+.-]{1,15}"); + + static boolean isNameInvalid(final CharSequence name) { + return !NAME_PATTERN.matcher(name).matches(); + } + + /** + * Get the name of the tunnel, which should always pass the !isNameInvalid test. + * + * @return The name of the tunnel. + */ + String getName(); + + /** + * React to a change in state of the tunnel. Should only be directly called by Backend. + * + * @param newState The new state of the tunnel. + * @return The new state of the tunnel. + */ + void onStateChange(State newState); +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/WgQuickBackend.java b/tunnel/src/main/java/com/wireguard/android/backend/WgQuickBackend.java new file mode 100644 index 00000000..9695aab7 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/WgQuickBackend.java @@ -0,0 +1,158 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import androidx.annotation.Nullable; + +import android.content.Context; +import android.util.Log; + +import com.wireguard.android.backend.BackendException.Reason; +import com.wireguard.android.backend.Tunnel.State; +import com.wireguard.android.util.RootShell; +import com.wireguard.android.util.ToolsInstaller; +import com.wireguard.config.Config; +import com.wireguard.crypto.Key; + +import java.io.File; +import java.io.FileOutputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.HashMap; + +import java9.util.stream.Collectors; +import java9.util.stream.Stream; + +/** + * WireGuard backend that uses {@code wg-quick} to implement tunnel configuration. + */ + +public final class WgQuickBackend implements Backend { + private static final String TAG = "WireGuard/" + WgQuickBackend.class.getSimpleName(); + + private final RootShell rootShell; + private final ToolsInstaller toolsInstaller; + private final File localTemporaryDir; + private final Map<Tunnel, Config> runningConfigs = new HashMap<>(); + + public WgQuickBackend(final Context context, final RootShell rootShell, final ToolsInstaller toolsInstaller) { + localTemporaryDir = new File(context.getCacheDir(), "tmp"); + this.rootShell = rootShell; + this.toolsInstaller = toolsInstaller; + } + + @Override + public Set<String> getRunningTunnelNames() { + final List<String> output = new ArrayList<>(); + // Don't throw an exception here or nothing will show up in the UI. + try { + toolsInstaller.ensureToolsAvailable(); + if (rootShell.run(output, "wg show interfaces") != 0 || output.isEmpty()) + return Collections.emptySet(); + } catch (final Exception e) { + Log.w(TAG, "Unable to enumerate running tunnels", e); + return Collections.emptySet(); + } + // wg puts all interface names on the same line. Split them into separate elements. + return Stream.of(output.get(0).split(" ")).collect(Collectors.toUnmodifiableSet()); + } + + @Override + public State getState(final Tunnel tunnel) { + return getRunningTunnelNames().contains(tunnel.getName()) ? State.UP : State.DOWN; + } + + @Override + public Statistics getStatistics(final Tunnel tunnel) { + final Statistics stats = new Statistics(); + final Collection<String> output = new ArrayList<>(); + try { + if (rootShell.run(output, String.format("wg show '%s' transfer", tunnel.getName())) != 0) + return stats; + } catch (final Exception ignored) { + return stats; + } + for (final String line : output) { + final String[] parts = line.split("\\t"); + if (parts.length != 3) + continue; + try { + stats.add(Key.fromBase64(parts[0]), Long.parseLong(parts[1]), Long.parseLong(parts[2])); + } catch (final Exception ignored) { + } + } + return stats; + } + + @Override + public String getVersion() throws Exception { + final List<String> output = new ArrayList<>(); + if (rootShell.run(output, "cat /sys/module/wireguard/version") != 0 || output.isEmpty()) + throw new BackendException(Reason.UNKNOWN_KERNEL_MODULE_NAME); + return output.get(0); + } + + @Override + public State setState(final Tunnel tunnel, State state, @Nullable final Config config) throws Exception { + final State originalState = getState(tunnel); + final Config originalConfig = runningConfigs.get(tunnel); + + if (state == State.TOGGLE) + state = originalState == State.UP ? State.DOWN : State.UP; + if ((state == State.UP && originalState == State.UP && originalConfig != null && originalConfig == config) || + (state == State.DOWN && originalState == State.DOWN)) + return originalState; + if (state == State.UP) { + toolsInstaller.ensureToolsAvailable(); + if (originalState == State.UP) + setStateInternal(tunnel, originalConfig == null ? config : originalConfig, State.DOWN); + try { + setStateInternal(tunnel, config, State.UP); + } catch(final Exception e) { + if (originalState == State.UP && originalConfig != null) + setStateInternal(tunnel, originalConfig, State.UP); + throw e; + } + } else if (state == State.DOWN) { + setStateInternal(tunnel, originalConfig == null ? config : originalConfig, State.DOWN); + } + return state; + } + + private void setStateInternal(final Tunnel tunnel, @Nullable final Config config, final State state) throws Exception { + Log.i(TAG, "Bringing tunnel " + tunnel.getName() + " " + state); + + Objects.requireNonNull(config, "Trying to set state up with a null config"); + + final File tempFile = new File(localTemporaryDir, tunnel.getName() + ".conf"); + try (final FileOutputStream stream = new FileOutputStream(tempFile, false)) { + stream.write(config.toWgQuickString().getBytes(StandardCharsets.UTF_8)); + } + String command = String.format("wg-quick %s '%s'", + state.toString().toLowerCase(Locale.ENGLISH), tempFile.getAbsolutePath()); + if (state == State.UP) + command = "cat /sys/module/wireguard/version && " + command; + final int result = rootShell.run(null, command); + // noinspection ResultOfMethodCallIgnored + tempFile.delete(); + if (result != 0) + throw new BackendException(Reason.WG_QUICK_CONFIG_ERROR_CODE, result); + + if (state == State.UP) + runningConfigs.put(tunnel, config); + else + runningConfigs.remove(tunnel); + + tunnel.onStateChange(state); + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/util/AsyncWorker.java b/tunnel/src/main/java/com/wireguard/android/util/AsyncWorker.java new file mode 100644 index 00000000..1d041851 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/util/AsyncWorker.java @@ -0,0 +1,63 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util; + +import android.os.Handler; + +import java.util.concurrent.Executor; + +import java9.util.concurrent.CompletableFuture; +import java9.util.concurrent.CompletionStage; + +/** + * Helper class for running asynchronous tasks and ensuring they are completed on the main thread. + */ + +public class AsyncWorker { + private final Executor executor; + private final Handler handler; + + public AsyncWorker(final Executor executor, final Handler handler) { + this.executor = executor; + this.handler = handler; + } + + public CompletionStage<Void> runAsync(final AsyncRunnable<?> runnable) { + final CompletableFuture<Void> future = new CompletableFuture<>(); + executor.execute(() -> { + try { + runnable.run(); + handler.post(() -> future.complete(null)); + } catch (final Throwable t) { + handler.post(() -> future.completeExceptionally(t)); + } + }); + return future; + } + + public <T> CompletionStage<T> supplyAsync(final AsyncSupplier<T, ?> supplier) { + final CompletableFuture<T> future = new CompletableFuture<>(); + executor.execute(() -> { + try { + final T result = supplier.get(); + handler.post(() -> future.complete(result)); + } catch (final Throwable t) { + handler.post(() -> future.completeExceptionally(t)); + } + }); + return future; + } + + @FunctionalInterface + public interface AsyncRunnable<E extends Throwable> { + void run() throws E; + } + + @FunctionalInterface + public interface AsyncSupplier<T, E extends Throwable> { + T get() throws E; + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/util/RootShell.java b/tunnel/src/main/java/com/wireguard/android/util/RootShell.java new file mode 100644 index 00000000..1fc2c9f2 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/util/RootShell.java @@ -0,0 +1,211 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util; + +import android.content.Context; +import androidx.annotation.Nullable; +import android.util.Log; + +import com.wireguard.android.util.RootShell.RootShellException.Reason; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.UUID; + +/** + * Helper class for running commands as root. + */ + +public class RootShell { + private static final String SU = "su"; + private static final String TAG = "WireGuard/" + RootShell.class.getSimpleName(); + + private final File localBinaryDir; + private final File localTemporaryDir; + private final Object lock = new Object(); + private final String preamble; + @Nullable private Process process; + @Nullable private BufferedReader stderr; + @Nullable private OutputStreamWriter stdin; + @Nullable private BufferedReader stdout; + + public RootShell(final Context context) { + localBinaryDir = new File(context.getCodeCacheDir(), "bin"); + localTemporaryDir = new File(context.getCacheDir(), "tmp"); + preamble = String.format("export CALLING_PACKAGE=%s PATH=\"%s:$PATH\" TMPDIR='%s'; id -u\n", + context.getPackageName(), localBinaryDir, localTemporaryDir); + } + + private static boolean isExecutableInPath(final String name) { + final String path = System.getenv("PATH"); + if (path == null) + return false; + for (final String dir : path.split(":")) + if (new File(dir, name).canExecute()) + return true; + return false; + } + + private boolean isRunning() { + synchronized (lock) { + try { + // Throws an exception if the process hasn't finished yet. + if (process != null) + process.exitValue(); + return false; + } catch (final IllegalThreadStateException ignored) { + // The existing process is still running. + return true; + } + } + } + + /** + * Run a command in a root shell. + * + * @param output Lines read from stdout are appended to this list. Pass null if the + * output from the shell is not important. + * @param command Command to run as root. + * @return The exit value of the command. + */ + public int run(@Nullable final Collection<String> output, final String command) + throws IOException, RootShellException { + synchronized (lock) { + /* Start inside synchronized block to prevent a concurrent call to stop(). */ + start(); + final String marker = UUID.randomUUID().toString(); + final String script = "echo " + marker + "; echo " + marker + " >&2; (" + command + + "); ret=$?; echo " + marker + " $ret; echo " + marker + " $ret >&2\n"; + Log.v(TAG, "executing: " + command); + stdin.write(script); + stdin.flush(); + String line; + int errnoStdout = Integer.MIN_VALUE; + int errnoStderr = Integer.MAX_VALUE; + int markersSeen = 0; + while ((line = stdout.readLine()) != null) { + if (line.startsWith(marker)) { + ++markersSeen; + if (line.length() > marker.length() + 1) { + errnoStdout = Integer.valueOf(line.substring(marker.length() + 1)); + break; + } + } else if (markersSeen > 0) { + if (output != null) + output.add(line); + Log.v(TAG, "stdout: " + line); + } + } + while ((line = stderr.readLine()) != null) { + if (line.startsWith(marker)) { + ++markersSeen; + if (line.length() > marker.length() + 1) { + errnoStderr = Integer.valueOf(line.substring(marker.length() + 1)); + break; + } + } else if (markersSeen > 2) { + Log.v(TAG, "stderr: " + line); + } + } + if (markersSeen != 4) + throw new RootShellException(Reason.SHELL_MARKER_COUNT_ERROR, markersSeen); + if (errnoStdout != errnoStderr) + throw new RootShellException(Reason.SHELL_EXIT_STATUS_READ_ERROR); + Log.v(TAG, "exit: " + errnoStdout); + return errnoStdout; + } + } + + public void start() throws IOException, RootShellException { + if (!isExecutableInPath(SU)) + throw new RootShellException(Reason.NO_ROOT_ACCESS); + synchronized (lock) { + if (isRunning()) + return; + if (!localBinaryDir.isDirectory() && !localBinaryDir.mkdirs()) + throw new RootShellException(Reason.CREATE_BIN_DIR_ERROR); + if (!localTemporaryDir.isDirectory() && !localTemporaryDir.mkdirs()) + throw new RootShellException(Reason.CREATE_TEMP_DIR_ERROR); + try { + final ProcessBuilder builder = new ProcessBuilder().command(SU); + builder.environment().put("LC_ALL", "C"); + try { + process = builder.start(); + } catch (final IOException e) { + // A failure at this stage means the device isn't rooted. + final RootShellException rse = new RootShellException(Reason.NO_ROOT_ACCESS); + rse.initCause(e); + throw rse; + } + stdin = new OutputStreamWriter(process.getOutputStream(), StandardCharsets.UTF_8); + stdout = new BufferedReader(new InputStreamReader(process.getInputStream(), + StandardCharsets.UTF_8)); + stderr = new BufferedReader(new InputStreamReader(process.getErrorStream(), + StandardCharsets.UTF_8)); + stdin.write(preamble); + stdin.flush(); + // Check that the shell started successfully. + final String uid = stdout.readLine(); + if (!"0".equals(uid)) { + Log.w(TAG, "Root check did not return correct UID: " + uid); + throw new RootShellException(Reason.NO_ROOT_ACCESS); + } + if (!isRunning()) { + String line; + while ((line = stderr.readLine()) != null) { + Log.w(TAG, "Root check returned an error: " + line); + if (line.contains("Permission denied")) + throw new RootShellException(Reason.NO_ROOT_ACCESS); + } + throw new RootShellException(Reason.SHELL_START_ERROR, process.exitValue()); + } + } catch (final IOException | RootShellException e) { + stop(); + throw e; + } + } + } + + public void stop() { + synchronized (lock) { + if (process != null) { + process.destroy(); + process = null; + } + } + } + + public static class RootShellException extends Exception { + public enum Reason { + NO_ROOT_ACCESS, + SHELL_MARKER_COUNT_ERROR, + SHELL_EXIT_STATUS_READ_ERROR, + SHELL_START_ERROR, + CREATE_BIN_DIR_ERROR, + CREATE_TEMP_DIR_ERROR + } + private final Reason reason; + private final Object[] format; + public RootShellException(final Reason reason, final Object ...format) { + this.reason = reason; + this.format = format; + } + public boolean isIORelated() { + return reason != Reason.NO_ROOT_ACCESS; + } + public Reason getReason() { + return reason; + } + public Object[] getFormat() { + return format; + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/util/SharedLibraryLoader.java b/tunnel/src/main/java/com/wireguard/android/util/SharedLibraryLoader.java new file mode 100644 index 00000000..93e44b64 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/util/SharedLibraryLoader.java @@ -0,0 +1,94 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util; + +import android.content.Context; +import android.os.Build; +import android.util.Log; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.zip.ZipEntry; +import java.util.zip.ZipFile; + +public final class SharedLibraryLoader { + private static final String TAG = "WireGuard/" + SharedLibraryLoader.class.getSimpleName(); + + private SharedLibraryLoader() { + } + + public static boolean extractLibrary(final Context context, final String libName, final File destination) throws IOException { + final Collection<String> apks = new HashSet<>(); + if (context.getApplicationInfo().sourceDir != null) + apks.add(context.getApplicationInfo().sourceDir); + if (context.getApplicationInfo().splitSourceDirs != null) + apks.addAll(Arrays.asList(context.getApplicationInfo().splitSourceDirs)); + + for (final String abi : Build.SUPPORTED_ABIS) { + for (final String apk : apks) { + final ZipFile zipFile; + try { + zipFile = new ZipFile(new File(apk), ZipFile.OPEN_READ); + } catch (final IOException e) { + throw new RuntimeException(e); + } + + final String mappedLibName = System.mapLibraryName(libName); + final byte[] buffer = new byte[1024 * 32]; + final String libZipPath = "lib" + File.separatorChar + abi + File.separatorChar + mappedLibName; + final ZipEntry zipEntry = zipFile.getEntry(libZipPath); + if (zipEntry == null) + continue; + Log.d(TAG, "Extracting apk:/" + libZipPath + " to " + destination.getAbsolutePath()); + try (final FileOutputStream out = new FileOutputStream(destination); + final InputStream in = zipFile.getInputStream(zipEntry)) { + int len; + while ((len = in.read(buffer)) != -1) { + out.write(buffer, 0, len); + } + out.getFD().sync(); + } + zipFile.close(); + return true; + } + } + return false; + } + + public static void loadSharedLibrary(final Context context, final String libName) { + Throwable noAbiException; + try { + System.loadLibrary(libName); + return; + } catch (final UnsatisfiedLinkError e) { + Log.d(TAG, "Failed to load library normally, so attempting to extract from apk", e); + noAbiException = e; + } + File f = null; + try { + f = File.createTempFile("lib", ".so", context.getCodeCacheDir()); + if (extractLibrary(context, libName, f)) { + System.load(f.getAbsolutePath()); + return; + } + } catch (final Exception e) { + Log.d(TAG, "Failed to load library apk:/" + libName, e); + noAbiException = e; + } finally { + if (f != null) + // noinspection ResultOfMethodCallIgnored + f.delete(); + } + if (noAbiException instanceof RuntimeException) + throw (RuntimeException) noAbiException; + throw new RuntimeException(noAbiException); + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/util/ToolsInstaller.java b/tunnel/src/main/java/com/wireguard/android/util/ToolsInstaller.java new file mode 100644 index 00000000..ac18cabf --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/util/ToolsInstaller.java @@ -0,0 +1,196 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util; + +import android.content.Context; +import androidx.annotation.Nullable; +import android.system.OsConstants; +import android.util.Log; + +import com.wireguard.android.util.RootShell.RootShellException; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +/** + * Helper to install WireGuard tools to the system partition. + */ + +public final class ToolsInstaller { + public static final int ERROR = 0x0; + public static final int MAGISK = 0x4; + public static final int NO = 0x2; + public static final int SYSTEM = 0x8; + public static final int YES = 0x1; + private static final String[] EXECUTABLES = {"wg", "wg-quick"}; + private static final File[] INSTALL_DIRS = { + new File("/system/xbin"), + new File("/system/bin"), + }; + @Nullable private static final File INSTALL_DIR = getInstallDir(); + private static final String TAG = "WireGuard/" + ToolsInstaller.class.getSimpleName(); + + private final Context context; + private final RootShell rootShell; + private final File localBinaryDir; + private final Object lock = new Object(); + @Nullable private Boolean areToolsAvailable; + @Nullable private Boolean installAsMagiskModule; + + public ToolsInstaller(final Context context, final RootShell rootShell) { + localBinaryDir = new File(context.getCodeCacheDir(), "bin"); + this.context = context; + this.rootShell = rootShell; + } + + @Nullable + private static File getInstallDir() { + final String path = System.getenv("PATH"); + if (path == null) + return INSTALL_DIRS[0]; + final List<String> paths = Arrays.asList(path.split(":")); + for (final File dir : INSTALL_DIRS) { + if (paths.contains(dir.getPath()) && dir.isDirectory()) + return dir; + } + return null; + } + + public int areInstalled() throws RootShellException { + if (INSTALL_DIR == null) + return ERROR; + final StringBuilder script = new StringBuilder(); + for (final String name : EXECUTABLES) { + script.append(String.format("cmp -s '%s' '%s' && ", + new File(localBinaryDir, name).getAbsolutePath(), + new File(INSTALL_DIR, name).getAbsolutePath())); + } + script.append("exit ").append(OsConstants.EALREADY).append(';'); + try { + final int ret = rootShell.run(null, script.toString()); + if (ret == OsConstants.EALREADY) + return willInstallAsMagiskModule() ? YES | MAGISK : YES | SYSTEM; + else + return willInstallAsMagiskModule() ? NO | MAGISK : NO | SYSTEM; + } catch (final IOException ignored) { + return ERROR; + } catch (final RootShellException e) { + if (e.isIORelated()) + return ERROR; + throw e; + } + } + + public void ensureToolsAvailable() throws FileNotFoundException { + synchronized (lock) { + if (areToolsAvailable == null) { + try { + Log.d(TAG, extract() ? "Tools are now extracted into our private binary dir" : + "Tools were already extracted into our private binary dir"); + areToolsAvailable = true; + } catch (final IOException e) { + Log.e(TAG, "The wg and wg-quick tools are not available", e); + areToolsAvailable = false; + } + } + if (!areToolsAvailable) + throw new FileNotFoundException("Required tools unavailable"); + } + } + + public int install() throws RootShellException, IOException { + if (!context.getPackageName().startsWith("com.wireguard.")) + throw new SecurityException("The tools may only be installed system-wide from the main WireGuard app."); + return willInstallAsMagiskModule() ? installMagisk() : installSystem(); + } + + private int installMagisk() throws RootShellException, IOException { + extract(); + final StringBuilder script = new StringBuilder("set -ex; "); + + script.append("trap 'rm -rf /sbin/.magisk/img/wireguard' INT TERM EXIT; "); + script.append(String.format("rm -rf /sbin/.magisk/img/wireguard/; mkdir -p /sbin/.magisk/img/wireguard%s; ", INSTALL_DIR)); + script.append("printf 'name=WireGuard Command Line Tools\nversion=1.0\nversionCode=1\nauthor=zx2c4\ndescription=Command line tools for WireGuard\nminMagisk=1500\n' > /sbin/.magisk/img/wireguard/module.prop; "); + script.append("touch /sbin/.magisk/img/wireguard/auto_mount; "); + for (final String name : EXECUTABLES) { + final File destination = new File("/sbin/.magisk/img/wireguard" + INSTALL_DIR, name); + script.append(String.format("cp '%s' '%s'; chmod 755 '%s'; chcon 'u:object_r:system_file:s0' '%s' || true; ", + new File(localBinaryDir, name), destination, destination, destination)); + } + script.append("trap - INT TERM EXIT;"); + + try { + return rootShell.run(null, script.toString()) == 0 ? YES | MAGISK : ERROR; + } catch (final IOException ignored) { + return ERROR; + } catch (final RootShellException e) { + if (e.isIORelated()) + return ERROR; + throw e; + } + } + + private int installSystem() throws RootShellException, IOException { + if (INSTALL_DIR == null) + return OsConstants.ENOENT; + extract(); + final StringBuilder script = new StringBuilder("set -ex; "); + script.append("trap 'mount -o ro,remount /system' EXIT; mount -o rw,remount /system; "); + for (final String name : EXECUTABLES) { + final File destination = new File(INSTALL_DIR, name); + script.append(String.format("cp '%s' '%s'; chmod 755 '%s'; restorecon '%s' || true; ", + new File(localBinaryDir, name), destination, destination, destination)); + } + try { + return rootShell.run(null, script.toString()) == 0 ? YES | SYSTEM : ERROR; + } catch (final IOException ignored) { + return ERROR; + } catch (final RootShellException e) { + if (e.isIORelated()) + return ERROR; + throw e; + } + } + + public boolean extract() throws IOException { + localBinaryDir.mkdirs(); + final File files[] = new File[EXECUTABLES.length]; + final File tempFiles[] = new File[EXECUTABLES.length]; + boolean allExist = true; + for (int i = 0; i < files.length; ++i) { + files[i] = new File(localBinaryDir, EXECUTABLES[i]); + tempFiles[i] = new File(localBinaryDir, EXECUTABLES[i] + ".tmp"); + allExist &= files[i].exists(); + } + if (allExist) + return false; + for (int i = 0; i < files.length; ++i) { + if (!SharedLibraryLoader.extractLibrary(context, EXECUTABLES[i], tempFiles[i])) + throw new FileNotFoundException("Unable to find " + EXECUTABLES[i]); + if (!tempFiles[i].setExecutable(true, false)) + throw new IOException("Unable to mark " + tempFiles[i].getAbsolutePath() + " as executable"); + if (!tempFiles[i].renameTo(files[i])) + throw new IOException("Unable to rename " + tempFiles[i].getAbsolutePath() + " to " + files[i].getAbsolutePath()); + } + return true; + } + + private boolean willInstallAsMagiskModule() { + synchronized (lock) { + if (installAsMagiskModule == null) { + try { + installAsMagiskModule = rootShell.run(null, "[ -d /sbin/.magisk/mirror -a -d /sbin/.magisk/img -a ! -f /cache/.disable_magisk ]") == OsConstants.EXIT_SUCCESS; + } catch (final Exception ignored) { + installAsMagiskModule = false; + } + } + return installAsMagiskModule; + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/config/Attribute.java b/tunnel/src/main/java/com/wireguard/config/Attribute.java new file mode 100644 index 00000000..1e9e25f0 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/config/Attribute.java @@ -0,0 +1,58 @@ +/* + * Copyright © 2018-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import java.util.Iterator; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import java9.util.Optional; + +public final class Attribute { + private static final Pattern LINE_PATTERN = Pattern.compile("(\\w+)\\s*=\\s*([^\\s#][^#]*)"); + private static final Pattern LIST_SEPARATOR = Pattern.compile("\\s*,\\s*"); + + private final String key; + private final String value; + + private Attribute(final String key, final String value) { + this.key = key; + this.value = value; + } + + public static String join(final Iterable<?> values) { + final Iterator<?> it = values.iterator(); + if (!it.hasNext()) { + return ""; + } + final StringBuilder sb = new StringBuilder(); + sb.append(it.next()); + while (it.hasNext()) { + sb.append(", "); + sb.append(it.next()); + } + return sb.toString(); + } + + public static Optional<Attribute> parse(final CharSequence line) { + final Matcher matcher = LINE_PATTERN.matcher(line); + if (!matcher.matches()) + return Optional.empty(); + return Optional.of(new Attribute(matcher.group(1), matcher.group(2))); + } + + public static String[] split(final CharSequence value) { + return LIST_SEPARATOR.split(value); + } + + public String getKey() { + return key; + } + + public String getValue() { + return value; + } +} diff --git a/tunnel/src/main/java/com/wireguard/config/BadConfigException.java b/tunnel/src/main/java/com/wireguard/config/BadConfigException.java new file mode 100644 index 00000000..6d41b065 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/config/BadConfigException.java @@ -0,0 +1,118 @@ +/* + * Copyright © 2018-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import androidx.annotation.Nullable; + +import com.wireguard.crypto.KeyFormatException; + +public class BadConfigException extends Exception { + private final Location location; + private final Reason reason; + private final Section section; + @Nullable private final CharSequence text; + + private BadConfigException(final Section section, final Location location, + final Reason reason, @Nullable final CharSequence text, + @Nullable final Throwable cause) { + super(cause); + this.section = section; + this.location = location; + this.reason = reason; + this.text = text; + } + + public BadConfigException(final Section section, final Location location, + final Reason reason, @Nullable final CharSequence text) { + this(section, location, reason, text, null); + } + + public BadConfigException(final Section section, final Location location, + final KeyFormatException cause) { + this(section, location, Reason.INVALID_KEY, null, cause); + } + + public BadConfigException(final Section section, final Location location, + @Nullable final CharSequence text, + final NumberFormatException cause) { + this(section, location, Reason.INVALID_NUMBER, text, cause); + } + + public BadConfigException(final Section section, final Location location, + final ParseException cause) { + this(section, location, Reason.INVALID_VALUE, cause.getText(), cause); + } + + public Location getLocation() { + return location; + } + + public Reason getReason() { + return reason; + } + + public Section getSection() { + return section; + } + + @Nullable + public CharSequence getText() { + return text; + } + + public enum Location { + TOP_LEVEL(""), + ADDRESS("Address"), + ALLOWED_IPS("AllowedIPs"), + DNS("DNS"), + ENDPOINT("Endpoint"), + EXCLUDED_APPLICATIONS("ExcludedApplications"), + LISTEN_PORT("ListenPort"), + MTU("MTU"), + PERSISTENT_KEEPALIVE("PersistentKeepalive"), + PRE_SHARED_KEY("PresharedKey"), + PRIVATE_KEY("PrivateKey"), + PUBLIC_KEY("PublicKey"); + + private final String name; + + Location(final String name) { + this.name = name; + } + + public String getName() { + return name; + } + } + + public enum Reason { + INVALID_KEY, + INVALID_NUMBER, + INVALID_VALUE, + MISSING_ATTRIBUTE, + MISSING_SECTION, + MISSING_VALUE, + SYNTAX_ERROR, + UNKNOWN_ATTRIBUTE, + UNKNOWN_SECTION + } + + public enum Section { + CONFIG("Config"), + INTERFACE("Interface"), + PEER("Peer"); + + private final String name; + + Section(final String name) { + this.name = name; + } + + public String getName() { + return name; + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/config/Config.java b/tunnel/src/main/java/com/wireguard/config/Config.java new file mode 100644 index 00000000..62651b08 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/config/Config.java @@ -0,0 +1,221 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import androidx.annotation.Nullable; + +import com.wireguard.config.BadConfigException.Location; +import com.wireguard.config.BadConfigException.Reason; +import com.wireguard.config.BadConfigException.Section; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +/** + * Represents the contents of a wg-quick configuration file, made up of one or more "Interface" + * sections (combined together), and zero or more "Peer" sections (treated individually). + * <p> + * Instances of this class are immutable. + */ +public final class Config { + private final Interface interfaze; + private final List<Peer> peers; + + private Config(final Builder builder) { + interfaze = Objects.requireNonNull(builder.interfaze, "An [Interface] section is required"); + // Defensively copy to ensure immutability even if the Builder is reused. + peers = Collections.unmodifiableList(new ArrayList<>(builder.peers)); + } + + /** + * Parses an series of "Interface" and "Peer" sections into a {@code Config}. Throws + * {@link BadConfigException} if the input is not well-formed or contains data that cannot + * be parsed. + * + * @param stream a stream of UTF-8 text that is interpreted as a WireGuard configuration + * @return a {@code Config} instance representing the supplied configuration + */ + public static Config parse(final InputStream stream) + throws IOException, BadConfigException { + return parse(new BufferedReader(new InputStreamReader(stream))); + } + + /** + * Parses an series of "Interface" and "Peer" sections into a {@code Config}. Throws + * {@link BadConfigException} if the input is not well-formed or contains data that cannot + * be parsed. + * + * @param reader a BufferedReader of UTF-8 text that is interpreted as a WireGuard configuration + * @return a {@code Config} instance representing the supplied configuration + */ + public static Config parse(final BufferedReader reader) + throws IOException, BadConfigException { + final Builder builder = new Builder(); + final Collection<String> interfaceLines = new ArrayList<>(); + final Collection<String> peerLines = new ArrayList<>(); + boolean inInterfaceSection = false; + boolean inPeerSection = false; + @Nullable String line; + while ((line = reader.readLine()) != null) { + final int commentIndex = line.indexOf('#'); + if (commentIndex != -1) + line = line.substring(0, commentIndex); + line = line.trim(); + if (line.isEmpty()) + continue; + if (line.startsWith("[")) { + // Consume all [Peer] lines read so far. + if (inPeerSection) { + builder.parsePeer(peerLines); + peerLines.clear(); + } + if ("[Interface]".equalsIgnoreCase(line)) { + inInterfaceSection = true; + inPeerSection = false; + } else if ("[Peer]".equalsIgnoreCase(line)) { + inInterfaceSection = false; + inPeerSection = true; + } else { + throw new BadConfigException(Section.CONFIG, Location.TOP_LEVEL, + Reason.UNKNOWN_SECTION, line); + } + } else if (inInterfaceSection) { + interfaceLines.add(line); + } else if (inPeerSection) { + peerLines.add(line); + } else { + throw new BadConfigException(Section.CONFIG, Location.TOP_LEVEL, + Reason.UNKNOWN_SECTION, line); + } + } + if (inPeerSection) + builder.parsePeer(peerLines); + else if (!inInterfaceSection) + throw new BadConfigException(Section.CONFIG, Location.TOP_LEVEL, + Reason.MISSING_SECTION, null); + // Combine all [Interface] sections in the file. + builder.parseInterface(interfaceLines); + return builder.build(); + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof Config)) + return false; + final Config other = (Config) obj; + return interfaze.equals(other.interfaze) && peers.equals(other.peers); + } + + /** + * Returns the interface section of the configuration. + * + * @return the interface configuration + */ + public Interface getInterface() { + return interfaze; + } + + /** + * Returns a list of the configuration's peer sections. + * + * @return a list of {@link Peer}s + */ + public List<Peer> getPeers() { + return peers; + } + + @Override + public int hashCode() { + return 31 * interfaze.hashCode() + peers.hashCode(); + } + + /** + * Converts the {@code Config} into a string suitable for debugging purposes. The {@code Config} + * is identified by its interface's public key and the number of peers it has. + * + * @return a concise single-line identifier for the {@code Config} + */ + @Override + public String toString() { + return "(Config " + interfaze + " (" + peers.size() + " peers))"; + } + + /** + * Converts the {@code Config} into a string suitable for use as a {@code wg-quick} + * configuration file. + * + * @return the {@code Config} represented as one [Interface] and zero or more [Peer] sections + */ + public String toWgQuickString() { + final StringBuilder sb = new StringBuilder(); + sb.append("[Interface]\n").append(interfaze.toWgQuickString()); + for (final Peer peer : peers) + sb.append("\n[Peer]\n").append(peer.toWgQuickString()); + return sb.toString(); + } + + /** + * Serializes the {@code Config} for use with the WireGuard cross-platform userspace API. + * + * @return the {@code Config} represented as a series of "key=value" lines + */ + public String toWgUserspaceString() { + final StringBuilder sb = new StringBuilder(); + sb.append(interfaze.toWgUserspaceString()); + sb.append("replace_peers=true\n"); + for (final Peer peer : peers) + sb.append(peer.toWgUserspaceString()); + return sb.toString(); + } + + @SuppressWarnings("UnusedReturnValue") + public static final class Builder { + // Defaults to an empty set. + private final Set<Peer> peers = new LinkedHashSet<>(); + // No default; must be provided before building. + @Nullable private Interface interfaze; + + public Builder addPeer(final Peer peer) { + peers.add(peer); + return this; + } + + public Builder addPeers(final Collection<Peer> peers) { + this.peers.addAll(peers); + return this; + } + + public Config build() { + if (interfaze == null) + throw new IllegalArgumentException("An [Interface] section is required"); + return new Config(this); + } + + public Builder parseInterface(final Iterable<? extends CharSequence> lines) + throws BadConfigException { + return setInterface(Interface.parse(lines)); + } + + public Builder parsePeer(final Iterable<? extends CharSequence> lines) + throws BadConfigException { + return addPeer(Peer.parse(lines)); + } + + public Builder setInterface(final Interface interfaze) { + this.interfaze = interfaze; + return this; + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/config/InetAddresses.java b/tunnel/src/main/java/com/wireguard/config/InetAddresses.java new file mode 100644 index 00000000..5303e27f --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/config/InetAddresses.java @@ -0,0 +1,71 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import java.lang.reflect.Method; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.regex.Pattern; + +import javax.annotation.Nullable; + +/** + * Utility methods for creating instances of {@link InetAddress}. + */ +public final class InetAddresses { + @Nullable private static final Method PARSER_METHOD; + private static final Pattern WONT_TOUCH_RESOLVER = Pattern.compile("^(((([0-9A-Fa-f]{1,4}:){7}([0-9A-Fa-f]{1,4}|:))|(([0-9A-Fa-f]{1,4}:){6}(:[0-9A-Fa-f]{1,4}|((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3})|:))|(([0-9A-Fa-f]{1,4}:){5}(((:[0-9A-Fa-f]{1,4}){1,2})|:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3})|:))|(([0-9A-Fa-f]{1,4}:){4}(((:[0-9A-Fa-f]{1,4}){1,3})|((:[0-9A-Fa-f]{1,4})?:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){3}(((:[0-9A-Fa-f]{1,4}){1,4})|((:[0-9A-Fa-f]{1,4}){0,2}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){2}(((:[0-9A-Fa-f]{1,4}){1,5})|((:[0-9A-Fa-f]{1,4}){0,3}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(([0-9A-Fa-f]{1,4}:){1}(((:[0-9A-Fa-f]{1,4}){1,6})|((:[0-9A-Fa-f]{1,4}){0,4}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:))|(:(((:[0-9A-Fa-f]{1,4}){1,7})|((:[0-9A-Fa-f]{1,4}){0,5}:((25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)(\\.(25[0-5]|2[0-4]\\d|1\\d\\d|[1-9]?\\d)){3}))|:)))(%.+)?)|((?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?))$"); + + static { + Method m = null; + try { + if (android.os.Build.VERSION.SDK_INT < android.os.Build.VERSION_CODES.Q) + // noinspection JavaReflectionMemberAccess + m = InetAddress.class.getMethod("parseNumericAddress", String.class); + } catch (final Exception ignored) { + } + PARSER_METHOD = m; + } + + private InetAddresses() { } + + /** + * Parses a numeric IPv4 or IPv6 address without performing any DNS lookups. + * + * @param address a string representing the IP address + * @return an instance of {@link Inet4Address} or {@link Inet6Address}, as appropriate + */ + public static InetAddress parse(final String address) throws ParseException { + if (address.isEmpty()) + throw new ParseException(InetAddress.class, address, "Empty address"); + try { + if (android.os.Build.VERSION.SDK_INT >= android.os.Build.VERSION_CODES.Q) + return android.net.InetAddresses.parseNumericAddress(address); + else if (PARSER_METHOD != null) + return (InetAddress) PARSER_METHOD.invoke(null, address); + else + throw new NoSuchMethodException("parseNumericAddress"); + } catch (final IllegalArgumentException e) { + throw new ParseException(InetAddress.class, address, e); + } catch (final Exception e) { + final Throwable cause = e.getCause(); + // Re-throw parsing exceptions with the original type, as callers might try to catch + // them. On the other hand, callers cannot be expected to handle reflection failures. + if (cause instanceof IllegalArgumentException) + throw new ParseException(InetAddress.class, address, cause); + try { + if (WONT_TOUCH_RESOLVER.matcher(address).matches()) + return InetAddress.getByName(address); + else + throw new ParseException(InetAddress.class, address, "Not an IP address"); + } catch (final UnknownHostException f) { + throw new ParseException(InetAddress.class, address, f); + } + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java b/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java new file mode 100644 index 00000000..a442258e --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java @@ -0,0 +1,125 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import androidx.annotation.Nullable; + +import org.threeten.bp.Duration; +import org.threeten.bp.Instant; + +import java.net.Inet4Address; +import java.net.InetAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.UnknownHostException; +import java.util.regex.Pattern; + +import java9.util.Optional; + + +/** + * An external endpoint (host and port) used to connect to a WireGuard {@link Peer}. + * <p> + * Instances of this class are externally immutable. + */ +public final class InetEndpoint { + private static final Pattern BARE_IPV6 = Pattern.compile("^[^\\[\\]]*:[^\\[\\]]*"); + private static final Pattern FORBIDDEN_CHARACTERS = Pattern.compile("[/?#]"); + + private final String host; + private final boolean isResolved; + private final Object lock = new Object(); + private final int port; + private Instant lastResolution = Instant.EPOCH; + @Nullable private InetEndpoint resolved; + + private InetEndpoint(final String host, final boolean isResolved, final int port) { + this.host = host; + this.isResolved = isResolved; + this.port = port; + } + + public static InetEndpoint parse(final String endpoint) throws ParseException { + if (FORBIDDEN_CHARACTERS.matcher(endpoint).find()) + throw new ParseException(InetEndpoint.class, endpoint, "Forbidden characters"); + final URI uri; + try { + uri = new URI("wg://" + endpoint); + } catch (final URISyntaxException e) { + throw new IllegalArgumentException(e); + } + if (uri.getPort() < 0 || uri.getPort() > 65535) + throw new ParseException(InetEndpoint.class, endpoint, "Missing/invalid port number"); + try { + InetAddresses.parse(uri.getHost()); + // Parsing ths host as a numeric address worked, so we don't need to do DNS lookups. + return new InetEndpoint(uri.getHost(), true, uri.getPort()); + } catch (final ParseException ignored) { + // Failed to parse the host as a numeric address, so it must be a DNS hostname/FQDN. + return new InetEndpoint(uri.getHost(), false, uri.getPort()); + } + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof InetEndpoint)) + return false; + final InetEndpoint other = (InetEndpoint) obj; + return host.equals(other.host) && port == other.port; + } + + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + /** + * Generate an {@code InetEndpoint} instance with the same port and the host resolved using DNS + * to a numeric address. If the host is already numeric, the existing instance may be returned. + * Because this function may perform network I/O, it must not be called from the main thread. + * + * @return the resolved endpoint, or {@link Optional#empty()} + */ + public Optional<InetEndpoint> getResolved() { + if (isResolved) + return Optional.of(this); + synchronized (lock) { + //TODO(zx2c4): Implement a real timeout mechanism using DNS TTL + if (Duration.between(lastResolution, Instant.now()).toMinutes() > 1) { + try { + // Prefer v4 endpoints over v6 to work around DNS64 and IPv6 NAT issues. + final InetAddress[] candidates = InetAddress.getAllByName(host); + InetAddress address = candidates[0]; + for (final InetAddress candidate : candidates) { + if (candidate instanceof Inet4Address) { + address = candidate; + break; + } + } + resolved = new InetEndpoint(address.getHostAddress(), true, port); + lastResolution = Instant.now(); + } catch (final UnknownHostException e) { + resolved = null; + } + } + return Optional.ofNullable(resolved); + } + } + + @Override + public int hashCode() { + return host.hashCode() ^ port; + } + + @Override + public String toString() { + final boolean isBareIpv6 = isResolved && BARE_IPV6.matcher(host).matches(); + return (isBareIpv6 ? '[' + host + ']' : host) + ':' + port; + } +} diff --git a/tunnel/src/main/java/com/wireguard/config/InetNetwork.java b/tunnel/src/main/java/com/wireguard/config/InetNetwork.java new file mode 100644 index 00000000..f89322fd --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/config/InetNetwork.java @@ -0,0 +1,76 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import java.net.Inet4Address; +import java.net.InetAddress; + +/** + * An Internet network, denoted by its address and netmask + * <p> + * Instances of this class are immutable. + */ +public final class InetNetwork { + private final InetAddress address; + private final int mask; + + private InetNetwork(final InetAddress address, final int mask) { + this.address = address; + this.mask = mask; + } + + public static InetNetwork parse(final String network) throws ParseException { + final int slash = network.lastIndexOf('/'); + final String maskString; + final int rawMask; + final String rawAddress; + if (slash >= 0) { + maskString = network.substring(slash + 1); + try { + rawMask = Integer.parseInt(maskString, 10); + } catch (final NumberFormatException ignored) { + throw new ParseException(Integer.class, maskString); + } + rawAddress = network.substring(0, slash); + } else { + maskString = ""; + rawMask = -1; + rawAddress = network; + } + final InetAddress address = InetAddresses.parse(rawAddress); + final int maxMask = (address instanceof Inet4Address) ? 32 : 128; + if (rawMask > maxMask) + throw new ParseException(InetNetwork.class, maskString, "Invalid network mask"); + final int mask = rawMask >= 0 && rawMask <= maxMask ? rawMask : maxMask; + return new InetNetwork(address, mask); + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof InetNetwork)) + return false; + final InetNetwork other = (InetNetwork) obj; + return address.equals(other.address) && mask == other.mask; + } + + public InetAddress getAddress() { + return address; + } + + public int getMask() { + return mask; + } + + @Override + public int hashCode() { + return address.hashCode() ^ mask; + } + + @Override + public String toString() { + return address.getHostAddress() + '/' + mask; + } +} diff --git a/tunnel/src/main/java/com/wireguard/config/Interface.java b/tunnel/src/main/java/com/wireguard/config/Interface.java new file mode 100644 index 00000000..54944424 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/config/Interface.java @@ -0,0 +1,355 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import androidx.annotation.Nullable; + +import com.wireguard.config.BadConfigException.Location; +import com.wireguard.config.BadConfigException.Reason; +import com.wireguard.config.BadConfigException.Section; +import com.wireguard.crypto.Key; +import com.wireguard.crypto.KeyFormatException; +import com.wireguard.crypto.KeyPair; + +import java.net.InetAddress; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Set; + +import java9.util.Lists; +import java9.util.Optional; +import java9.util.stream.Collectors; +import java9.util.stream.StreamSupport; + +/** + * Represents the configuration for a WireGuard interface (an [Interface] block). Interfaces must + * have a private key (used to initialize a {@code KeyPair}), and may optionally have several other + * attributes. + * <p> + * Instances of this class are immutable. + */ +public final class Interface { + private static final int MAX_UDP_PORT = 65535; + private static final int MIN_UDP_PORT = 0; + + private final Set<InetNetwork> addresses; + private final Set<InetAddress> dnsServers; + private final Set<String> excludedApplications; + private final KeyPair keyPair; + private final Optional<Integer> listenPort; + private final Optional<Integer> mtu; + + private Interface(final Builder builder) { + // Defensively copy to ensure immutability even if the Builder is reused. + addresses = Collections.unmodifiableSet(new LinkedHashSet<>(builder.addresses)); + dnsServers = Collections.unmodifiableSet(new LinkedHashSet<>(builder.dnsServers)); + excludedApplications = Collections.unmodifiableSet(new LinkedHashSet<>(builder.excludedApplications)); + keyPair = Objects.requireNonNull(builder.keyPair, "Interfaces must have a private key"); + listenPort = builder.listenPort; + mtu = builder.mtu; + } + + /** + * Parses an series of "KEY = VALUE" lines into an {@code Interface}. Throws + * {@link ParseException} if the input is not well-formed or contains unknown attributes. + * + * @param lines An iterable sequence of lines, containing at least a private key attribute + * @return An {@code Interface} with all of the attributes from {@code lines} set + */ + public static Interface parse(final Iterable<? extends CharSequence> lines) + throws BadConfigException { + final Builder builder = new Builder(); + for (final CharSequence line : lines) { + final Attribute attribute = Attribute.parse(line).orElseThrow(() -> + new BadConfigException(Section.INTERFACE, Location.TOP_LEVEL, + Reason.SYNTAX_ERROR, line)); + switch (attribute.getKey().toLowerCase(Locale.ENGLISH)) { + case "address": + builder.parseAddresses(attribute.getValue()); + break; + case "dns": + builder.parseDnsServers(attribute.getValue()); + break; + case "excludedapplications": + builder.parseExcludedApplications(attribute.getValue()); + break; + case "listenport": + builder.parseListenPort(attribute.getValue()); + break; + case "mtu": + builder.parseMtu(attribute.getValue()); + break; + case "privatekey": + builder.parsePrivateKey(attribute.getValue()); + break; + default: + throw new BadConfigException(Section.INTERFACE, Location.TOP_LEVEL, + Reason.UNKNOWN_ATTRIBUTE, attribute.getKey()); + } + } + return builder.build(); + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof Interface)) + return false; + final Interface other = (Interface) obj; + return addresses.equals(other.addresses) + && dnsServers.equals(other.dnsServers) + && excludedApplications.equals(other.excludedApplications) + && keyPair.equals(other.keyPair) + && listenPort.equals(other.listenPort) + && mtu.equals(other.mtu); + } + + /** + * Returns the set of IP addresses assigned to the interface. + * + * @return a set of {@link InetNetwork}s + */ + public Set<InetNetwork> getAddresses() { + // The collection is already immutable. + return addresses; + } + + /** + * Returns the set of DNS servers associated with the interface. + * + * @return a set of {@link InetAddress}es + */ + public Set<InetAddress> getDnsServers() { + // The collection is already immutable. + return dnsServers; + } + + /** + * Returns the set of applications excluded from using the interface. + * + * @return a set of package names + */ + public Set<String> getExcludedApplications() { + // The collection is already immutable. + return excludedApplications; + } + + /** + * Returns the public/private key pair used by the interface. + * + * @return a key pair + */ + public KeyPair getKeyPair() { + return keyPair; + } + + /** + * Returns the UDP port number that the WireGuard interface will listen on. + * + * @return a UDP port number, or {@code Optional.empty()} if none is configured + */ + public Optional<Integer> getListenPort() { + return listenPort; + } + + /** + * Returns the MTU used for the WireGuard interface. + * + * @return the MTU, or {@code Optional.empty()} if none is configured + */ + public Optional<Integer> getMtu() { + return mtu; + } + + @Override + public int hashCode() { + int hash = 1; + hash = 31 * hash + addresses.hashCode(); + hash = 31 * hash + dnsServers.hashCode(); + hash = 31 * hash + excludedApplications.hashCode(); + hash = 31 * hash + keyPair.hashCode(); + hash = 31 * hash + listenPort.hashCode(); + hash = 31 * hash + mtu.hashCode(); + return hash; + } + + /** + * Converts the {@code Interface} into a string suitable for debugging purposes. The {@code + * Interface} is identified by its public key and (if set) the port used for its UDP socket. + * + * @return A concise single-line identifier for the {@code Interface} + */ + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("(Interface "); + sb.append(keyPair.getPublicKey().toBase64()); + listenPort.ifPresent(lp -> sb.append(" @").append(lp)); + sb.append(')'); + return sb.toString(); + } + + /** + * Converts the {@code Interface} into a string suitable for inclusion in a {@code wg-quick} + * configuration file. + * + * @return The {@code Interface} represented as a series of "Key = Value" lines + */ + public String toWgQuickString() { + final StringBuilder sb = new StringBuilder(); + if (!addresses.isEmpty()) + sb.append("Address = ").append(Attribute.join(addresses)).append('\n'); + if (!dnsServers.isEmpty()) { + final List<String> dnsServerStrings = StreamSupport.stream(dnsServers) + .map(InetAddress::getHostAddress) + .collect(Collectors.toUnmodifiableList()); + sb.append("DNS = ").append(Attribute.join(dnsServerStrings)).append('\n'); + } + if (!excludedApplications.isEmpty()) + sb.append("ExcludedApplications = ").append(Attribute.join(excludedApplications)).append('\n'); + listenPort.ifPresent(lp -> sb.append("ListenPort = ").append(lp).append('\n')); + mtu.ifPresent(m -> sb.append("MTU = ").append(m).append('\n')); + sb.append("PrivateKey = ").append(keyPair.getPrivateKey().toBase64()).append('\n'); + return sb.toString(); + } + + /** + * Serializes the {@code Interface} for use with the WireGuard cross-platform userspace API. + * Note that not all attributes are included in this representation. + * + * @return the {@code Interface} represented as a series of "KEY=VALUE" lines + */ + public String toWgUserspaceString() { + final StringBuilder sb = new StringBuilder(); + sb.append("private_key=").append(keyPair.getPrivateKey().toHex()).append('\n'); + listenPort.ifPresent(lp -> sb.append("listen_port=").append(lp).append('\n')); + return sb.toString(); + } + + @SuppressWarnings("UnusedReturnValue") + public static final class Builder { + // Defaults to an empty set. + private final Set<InetNetwork> addresses = new LinkedHashSet<>(); + // Defaults to an empty set. + private final Set<InetAddress> dnsServers = new LinkedHashSet<>(); + // Defaults to an empty set. + private final Set<String> excludedApplications = new LinkedHashSet<>(); + // No default; must be provided before building. + @Nullable private KeyPair keyPair; + // Defaults to not present. + private Optional<Integer> listenPort = Optional.empty(); + // Defaults to not present. + private Optional<Integer> mtu = Optional.empty(); + + public Builder addAddress(final InetNetwork address) { + addresses.add(address); + return this; + } + + public Builder addAddresses(final Collection<InetNetwork> addresses) { + this.addresses.addAll(addresses); + return this; + } + + public Builder addDnsServer(final InetAddress dnsServer) { + dnsServers.add(dnsServer); + return this; + } + + public Builder addDnsServers(final Collection<? extends InetAddress> dnsServers) { + this.dnsServers.addAll(dnsServers); + return this; + } + + public Interface build() throws BadConfigException { + if (keyPair == null) + throw new BadConfigException(Section.INTERFACE, Location.PRIVATE_KEY, + Reason.MISSING_ATTRIBUTE, null); + return new Interface(this); + } + + public Builder excludeApplication(final String application) { + excludedApplications.add(application); + return this; + } + + public Builder excludeApplications(final Collection<String> applications) { + excludedApplications.addAll(applications); + return this; + } + + public Builder parseAddresses(final CharSequence addresses) throws BadConfigException { + try { + for (final String address : Attribute.split(addresses)) + addAddress(InetNetwork.parse(address)); + return this; + } catch (final ParseException e) { + throw new BadConfigException(Section.INTERFACE, Location.ADDRESS, e); + } + } + + public Builder parseDnsServers(final CharSequence dnsServers) throws BadConfigException { + try { + for (final String dnsServer : Attribute.split(dnsServers)) + addDnsServer(InetAddresses.parse(dnsServer)); + return this; + } catch (final ParseException e) { + throw new BadConfigException(Section.INTERFACE, Location.DNS, e); + } + } + + public Builder parseExcludedApplications(final CharSequence apps) { + return excludeApplications(Lists.of(Attribute.split(apps))); + } + + public Builder parseListenPort(final String listenPort) throws BadConfigException { + try { + return setListenPort(Integer.parseInt(listenPort)); + } catch (final NumberFormatException e) { + throw new BadConfigException(Section.INTERFACE, Location.LISTEN_PORT, listenPort, e); + } + } + + public Builder parseMtu(final String mtu) throws BadConfigException { + try { + return setMtu(Integer.parseInt(mtu)); + } catch (final NumberFormatException e) { + throw new BadConfigException(Section.INTERFACE, Location.MTU, mtu, e); + } + } + + public Builder parsePrivateKey(final String privateKey) throws BadConfigException { + try { + return setKeyPair(new KeyPair(Key.fromBase64(privateKey))); + } catch (final KeyFormatException e) { + throw new BadConfigException(Section.INTERFACE, Location.PRIVATE_KEY, e); + } + } + + public Builder setKeyPair(final KeyPair keyPair) { + this.keyPair = keyPair; + return this; + } + + public Builder setListenPort(final int listenPort) throws BadConfigException { + if (listenPort < MIN_UDP_PORT || listenPort > MAX_UDP_PORT) + throw new BadConfigException(Section.INTERFACE, Location.LISTEN_PORT, + Reason.INVALID_VALUE, String.valueOf(listenPort)); + this.listenPort = listenPort == 0 ? Optional.empty() : Optional.of(listenPort); + return this; + } + + public Builder setMtu(final int mtu) throws BadConfigException { + if (mtu < 0) + throw new BadConfigException(Section.INTERFACE, Location.LISTEN_PORT, + Reason.INVALID_VALUE, String.valueOf(mtu)); + this.mtu = mtu == 0 ? Optional.empty() : Optional.of(mtu); + return this; + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/config/ParseException.java b/tunnel/src/main/java/com/wireguard/config/ParseException.java new file mode 100644 index 00000000..c79d1fa1 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/config/ParseException.java @@ -0,0 +1,44 @@ +/* + * Copyright © 2018-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import androidx.annotation.Nullable; + +/** + */ +public class ParseException extends Exception { + private final Class<?> parsingClass; + private final CharSequence text; + + public ParseException(final Class<?> parsingClass, final CharSequence text, + @Nullable final String message, @Nullable final Throwable cause) { + super(message, cause); + this.parsingClass = parsingClass; + this.text = text; + } + + public ParseException(final Class<?> parsingClass, final CharSequence text, + @Nullable final String message) { + this(parsingClass, text, message, null); + } + + public ParseException(final Class<?> parsingClass, final CharSequence text, + @Nullable final Throwable cause) { + this(parsingClass, text, null, cause); + } + + public ParseException(final Class<?> parsingClass, final CharSequence text) { + this(parsingClass, text, null, null); + } + + public Class<?> getParsingClass() { + return parsingClass; + } + + public CharSequence getText() { + return text; + } +} diff --git a/tunnel/src/main/java/com/wireguard/config/Peer.java b/tunnel/src/main/java/com/wireguard/config/Peer.java new file mode 100644 index 00000000..37fcfa69 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/config/Peer.java @@ -0,0 +1,306 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +import androidx.annotation.Nullable; + +import com.wireguard.config.BadConfigException.Location; +import com.wireguard.config.BadConfigException.Reason; +import com.wireguard.config.BadConfigException.Section; +import com.wireguard.crypto.Key; +import com.wireguard.crypto.KeyFormatException; + +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Locale; +import java.util.Objects; +import java.util.Set; + +import java9.util.Optional; + +/** + * Represents the configuration for a WireGuard peer (a [Peer] block). Peers must have a public key, + * and may optionally have several other attributes. + * <p> + * Instances of this class are immutable. + */ +public final class Peer { + private final Set<InetNetwork> allowedIps; + private final Optional<InetEndpoint> endpoint; + private final Optional<Integer> persistentKeepalive; + private final Optional<Key> preSharedKey; + private final Key publicKey; + + private Peer(final Builder builder) { + // Defensively copy to ensure immutability even if the Builder is reused. + allowedIps = Collections.unmodifiableSet(new LinkedHashSet<>(builder.allowedIps)); + endpoint = builder.endpoint; + persistentKeepalive = builder.persistentKeepalive; + preSharedKey = builder.preSharedKey; + publicKey = Objects.requireNonNull(builder.publicKey, "Peers must have a public key"); + } + + /** + * Parses an series of "KEY = VALUE" lines into a {@code Peer}. Throws {@link ParseException} if + * the input is not well-formed or contains unknown attributes. + * + * @param lines an iterable sequence of lines, containing at least a public key attribute + * @return a {@code Peer} with all of its attributes set from {@code lines} + */ + public static Peer parse(final Iterable<? extends CharSequence> lines) + throws BadConfigException { + final Builder builder = new Builder(); + for (final CharSequence line : lines) { + final Attribute attribute = Attribute.parse(line).orElseThrow(() -> + new BadConfigException(Section.PEER, Location.TOP_LEVEL, + Reason.SYNTAX_ERROR, line)); + switch (attribute.getKey().toLowerCase(Locale.ENGLISH)) { + case "allowedips": + builder.parseAllowedIPs(attribute.getValue()); + break; + case "endpoint": + builder.parseEndpoint(attribute.getValue()); + break; + case "persistentkeepalive": + builder.parsePersistentKeepalive(attribute.getValue()); + break; + case "presharedkey": + builder.parsePreSharedKey(attribute.getValue()); + break; + case "publickey": + builder.parsePublicKey(attribute.getValue()); + break; + default: + throw new BadConfigException(Section.PEER, Location.TOP_LEVEL, + Reason.UNKNOWN_ATTRIBUTE, attribute.getKey()); + } + } + return builder.build(); + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof Peer)) + return false; + final Peer other = (Peer) obj; + return allowedIps.equals(other.allowedIps) + && endpoint.equals(other.endpoint) + && persistentKeepalive.equals(other.persistentKeepalive) + && preSharedKey.equals(other.preSharedKey) + && publicKey.equals(other.publicKey); + } + + /** + * Returns the peer's set of allowed IPs. + * + * @return the set of allowed IPs + */ + public Set<InetNetwork> getAllowedIps() { + // The collection is already immutable. + return allowedIps; + } + + /** + * Returns the peer's endpoint. + * + * @return the endpoint, or {@code Optional.empty()} if none is configured + */ + public Optional<InetEndpoint> getEndpoint() { + return endpoint; + } + + /** + * Returns the peer's persistent keepalive. + * + * @return the persistent keepalive, or {@code Optional.empty()} if none is configured + */ + public Optional<Integer> getPersistentKeepalive() { + return persistentKeepalive; + } + + /** + * Returns the peer's pre-shared key. + * + * @return the pre-shared key, or {@code Optional.empty()} if none is configured + */ + public Optional<Key> getPreSharedKey() { + return preSharedKey; + } + + /** + * Returns the peer's public key. + * + * @return the public key + */ + public Key getPublicKey() { + return publicKey; + } + + @Override + public int hashCode() { + int hash = 1; + hash = 31 * hash + allowedIps.hashCode(); + hash = 31 * hash + endpoint.hashCode(); + hash = 31 * hash + persistentKeepalive.hashCode(); + hash = 31 * hash + preSharedKey.hashCode(); + hash = 31 * hash + publicKey.hashCode(); + return hash; + } + + /** + * Converts the {@code Peer} into a string suitable for debugging purposes. The {@code Peer} is + * identified by its public key and (if known) its endpoint. + * + * @return a concise single-line identifier for the {@code Peer} + */ + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("(Peer "); + sb.append(publicKey.toBase64()); + endpoint.ifPresent(ep -> sb.append(" @").append(ep)); + sb.append(')'); + return sb.toString(); + } + + /** + * Converts the {@code Peer} into a string suitable for inclusion in a {@code wg-quick} + * configuration file. + * + * @return the {@code Peer} represented as a series of "Key = Value" lines + */ + public String toWgQuickString() { + final StringBuilder sb = new StringBuilder(); + if (!allowedIps.isEmpty()) + sb.append("AllowedIPs = ").append(Attribute.join(allowedIps)).append('\n'); + endpoint.ifPresent(ep -> sb.append("Endpoint = ").append(ep).append('\n')); + persistentKeepalive.ifPresent(pk -> sb.append("PersistentKeepalive = ").append(pk).append('\n')); + preSharedKey.ifPresent(psk -> sb.append("PreSharedKey = ").append(psk.toBase64()).append('\n')); + sb.append("PublicKey = ").append(publicKey.toBase64()).append('\n'); + return sb.toString(); + } + + /** + * Serializes the {@code Peer} for use with the WireGuard cross-platform userspace API. Note + * that not all attributes are included in this representation. + * + * @return the {@code Peer} represented as a series of "key=value" lines + */ + public String toWgUserspaceString() { + final StringBuilder sb = new StringBuilder(); + // The order here is important: public_key signifies the beginning of a new peer. + sb.append("public_key=").append(publicKey.toHex()).append('\n'); + for (final InetNetwork allowedIp : allowedIps) + sb.append("allowed_ip=").append(allowedIp).append('\n'); + endpoint.flatMap(InetEndpoint::getResolved).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n')); + persistentKeepalive.ifPresent(pk -> sb.append("persistent_keepalive_interval=").append(pk).append('\n')); + preSharedKey.ifPresent(psk -> sb.append("preshared_key=").append(psk.toHex()).append('\n')); + return sb.toString(); + } + + @SuppressWarnings("UnusedReturnValue") + public static final class Builder { + // See wg(8) + private static final int MAX_PERSISTENT_KEEPALIVE = 65535; + + // Defaults to an empty set. + private final Set<InetNetwork> allowedIps = new LinkedHashSet<>(); + // Defaults to not present. + private Optional<InetEndpoint> endpoint = Optional.empty(); + // Defaults to not present. + private Optional<Integer> persistentKeepalive = Optional.empty(); + // Defaults to not present. + private Optional<Key> preSharedKey = Optional.empty(); + // No default; must be provided before building. + @Nullable private Key publicKey; + + public Builder addAllowedIp(final InetNetwork allowedIp) { + allowedIps.add(allowedIp); + return this; + } + + public Builder addAllowedIps(final Collection<InetNetwork> allowedIps) { + this.allowedIps.addAll(allowedIps); + return this; + } + + public Peer build() throws BadConfigException { + if (publicKey == null) + throw new BadConfigException(Section.PEER, Location.PUBLIC_KEY, + Reason.MISSING_ATTRIBUTE, null); + return new Peer(this); + } + + public Builder parseAllowedIPs(final CharSequence allowedIps) throws BadConfigException { + try { + for (final String allowedIp : Attribute.split(allowedIps)) + addAllowedIp(InetNetwork.parse(allowedIp)); + return this; + } catch (final ParseException e) { + throw new BadConfigException(Section.PEER, Location.ALLOWED_IPS, e); + } + } + + public Builder parseEndpoint(final String endpoint) throws BadConfigException { + try { + return setEndpoint(InetEndpoint.parse(endpoint)); + } catch (final ParseException e) { + throw new BadConfigException(Section.PEER, Location.ENDPOINT, e); + } + } + + public Builder parsePersistentKeepalive(final String persistentKeepalive) + throws BadConfigException { + try { + return setPersistentKeepalive(Integer.parseInt(persistentKeepalive)); + } catch (final NumberFormatException e) { + throw new BadConfigException(Section.PEER, Location.PERSISTENT_KEEPALIVE, + persistentKeepalive, e); + } + } + + public Builder parsePreSharedKey(final String preSharedKey) throws BadConfigException { + try { + return setPreSharedKey(Key.fromBase64(preSharedKey)); + } catch (final KeyFormatException e) { + throw new BadConfigException(Section.PEER, Location.PRE_SHARED_KEY, e); + } + } + + public Builder parsePublicKey(final String publicKey) throws BadConfigException { + try { + return setPublicKey(Key.fromBase64(publicKey)); + } catch (final KeyFormatException e) { + throw new BadConfigException(Section.PEER, Location.PUBLIC_KEY, e); + } + } + + public Builder setEndpoint(final InetEndpoint endpoint) { + this.endpoint = Optional.of(endpoint); + return this; + } + + public Builder setPersistentKeepalive(final int persistentKeepalive) + throws BadConfigException { + if (persistentKeepalive < 0 || persistentKeepalive > MAX_PERSISTENT_KEEPALIVE) + throw new BadConfigException(Section.PEER, Location.PERSISTENT_KEEPALIVE, + Reason.INVALID_VALUE, String.valueOf(persistentKeepalive)); + this.persistentKeepalive = persistentKeepalive == 0 ? + Optional.empty() : Optional.of(persistentKeepalive); + return this; + } + + public Builder setPreSharedKey(final Key preSharedKey) { + this.preSharedKey = Optional.of(preSharedKey); + return this; + } + + public Builder setPublicKey(final Key publicKey) { + this.publicKey = publicKey; + return this; + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/crypto/Curve25519.java b/tunnel/src/main/java/com/wireguard/crypto/Curve25519.java new file mode 100644 index 00000000..5622fc5f --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/crypto/Curve25519.java @@ -0,0 +1,497 @@ +/* + * Copyright © 2016 Southern Storm Software, Pty Ltd. + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +import androidx.annotation.Nullable; + +import java.util.Arrays; + +/** + * Implementation of the Curve25519 elliptic curve algorithm. + * <p> + * This implementation was imported to WireGuard from noise-java: + * https://github.com/rweather/noise-java + * <p> + * This implementation is based on that from arduinolibs: + * https://github.com/rweather/arduinolibs + * <p> + * Differences in this version are due to using 26-bit limbs for the + * representation instead of the 8/16/32-bit limbs in the original. + * <p> + * References: http://cr.yp.to/ecdh.html, RFC 7748 + */ +@SuppressWarnings({"MagicNumber", "NonConstantFieldWithUpperCaseName", "SuspiciousNameCombination"}) +public final class Curve25519 { + // Numbers modulo 2^255 - 19 are broken up into ten 26-bit words. + private static final int NUM_LIMBS_255BIT = 10; + private static final int NUM_LIMBS_510BIT = 20; + + private final int[] A; + private final int[] AA; + private final int[] B; + private final int[] BB; + private final int[] C; + private final int[] CB; + private final int[] D; + private final int[] DA; + private final int[] E; + private final long[] t1; + private final int[] t2; + private final int[] x_1; + private final int[] x_2; + private final int[] x_3; + private final int[] z_2; + private final int[] z_3; + + /** + * Constructs the temporary state holder for Curve25519 evaluation. + */ + private Curve25519() { + // Allocate memory for all of the temporary variables we will need. + x_1 = new int[NUM_LIMBS_255BIT]; + x_2 = new int[NUM_LIMBS_255BIT]; + x_3 = new int[NUM_LIMBS_255BIT]; + z_2 = new int[NUM_LIMBS_255BIT]; + z_3 = new int[NUM_LIMBS_255BIT]; + A = new int[NUM_LIMBS_255BIT]; + B = new int[NUM_LIMBS_255BIT]; + C = new int[NUM_LIMBS_255BIT]; + D = new int[NUM_LIMBS_255BIT]; + E = new int[NUM_LIMBS_255BIT]; + AA = new int[NUM_LIMBS_255BIT]; + BB = new int[NUM_LIMBS_255BIT]; + DA = new int[NUM_LIMBS_255BIT]; + CB = new int[NUM_LIMBS_255BIT]; + t1 = new long[NUM_LIMBS_510BIT]; + t2 = new int[NUM_LIMBS_510BIT]; + } + + /** + * Conditional swap of two values. + * + * @param select Set to 1 to swap, 0 to leave as-is. + * @param x The first value. + * @param y The second value. + */ + private static void cswap(int select, final int[] x, final int[] y) { + select = -select; + for (int index = 0; index < NUM_LIMBS_255BIT; ++index) { + final int dummy = select & (x[index] ^ y[index]); + x[index] ^= dummy; + y[index] ^= dummy; + } + } + + /** + * Evaluates the Curve25519 curve. + * + * @param result Buffer to place the result of the evaluation into. + * @param offset Offset into the result buffer. + * @param privateKey The private key to use in the evaluation. + * @param publicKey The public key to use in the evaluation, or null + * if the base point of the curve should be used. + */ + public static void eval(final byte[] result, final int offset, + final byte[] privateKey, @Nullable final byte[] publicKey) { + final Curve25519 state = new Curve25519(); + try { + // Unpack the public key value. If null, use 9 as the base point. + Arrays.fill(state.x_1, 0); + if (publicKey != null) { + // Convert the input value from little-endian into 26-bit limbs. + for (int index = 0; index < 32; ++index) { + final int bit = (index * 8) % 26; + final int word = (index * 8) / 26; + final int value = publicKey[index] & 0xFF; + if (bit <= (26 - 8)) { + state.x_1[word] |= value << bit; + } else { + state.x_1[word] |= value << bit; + state.x_1[word] &= 0x03FFFFFF; + state.x_1[word + 1] |= value >> (26 - bit); + } + } + + // Just in case, we reduce the number modulo 2^255 - 19 to + // make sure that it is in range of the field before we start. + // This eliminates values between 2^255 - 19 and 2^256 - 1. + state.reduceQuick(state.x_1); + state.reduceQuick(state.x_1); + } else { + state.x_1[0] = 9; + } + + // Initialize the other temporary variables. + Arrays.fill(state.x_2, 0); // x_2 = 1 + state.x_2[0] = 1; + Arrays.fill(state.z_2, 0); // z_2 = 0 + System.arraycopy(state.x_1, 0, state.x_3, 0, state.x_1.length); // x_3 = x_1 + Arrays.fill(state.z_3, 0); // z_3 = 1 + state.z_3[0] = 1; + + // Evaluate the curve for every bit of the private key. + state.evalCurve(privateKey); + + // Compute x_2 * (z_2 ^ (p - 2)) where p = 2^255 - 19. + state.recip(state.z_3, state.z_2); + state.mul(state.x_2, state.x_2, state.z_3); + + // Convert x_2 into little-endian in the result buffer. + for (int index = 0; index < 32; ++index) { + final int bit = (index * 8) % 26; + final int word = (index * 8) / 26; + if (bit <= (26 - 8)) + result[offset + index] = (byte) (state.x_2[word] >> bit); + else + result[offset + index] = (byte) ((state.x_2[word] >> bit) | (state.x_2[word + 1] << (26 - bit))); + } + } finally { + // Clean up all temporary state before we exit. + state.destroy(); + } + } + + /** + * Subtracts two numbers modulo 2^255 - 19. + * + * @param result The result. + * @param x The first number to subtract. + * @param y The second number to subtract. + */ + private static void sub(final int[] result, final int[] x, final int[] y) { + int index; + int borrow; + + // Subtract y from x to generate the intermediate result. + borrow = 0; + for (index = 0; index < NUM_LIMBS_255BIT; ++index) { + borrow = x[index] - y[index] - ((borrow >> 26) & 0x01); + result[index] = borrow & 0x03FFFFFF; + } + + // If we had a borrow, then the result has gone negative and we + // have to add 2^255 - 19 to the result to make it positive again. + // The top bits of "borrow" will be all 1's if there is a borrow + // or it will be all 0's if there was no borrow. Easiest is to + // conditionally subtract 19 and then mask off the high bits. + borrow = result[0] - ((-((borrow >> 26) & 0x01)) & 19); + result[0] = borrow & 0x03FFFFFF; + for (index = 1; index < NUM_LIMBS_255BIT; ++index) { + borrow = result[index] - ((borrow >> 26) & 0x01); + result[index] = borrow & 0x03FFFFFF; + } + result[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; + } + + /** + * Adds two numbers modulo 2^255 - 19. + * + * @param result The result. + * @param x The first number to add. + * @param y The second number to add. + */ + private void add(final int[] result, final int[] x, final int[] y) { + int carry = x[0] + y[0]; + result[0] = carry & 0x03FFFFFF; + for (int index = 1; index < NUM_LIMBS_255BIT; ++index) { + carry = (carry >> 26) + x[index] + y[index]; + result[index] = carry & 0x03FFFFFF; + } + reduceQuick(result); + } + + /** + * Destroy all sensitive data in this object. + */ + private void destroy() { + // Destroy all temporary variables. + Arrays.fill(x_1, 0); + Arrays.fill(x_2, 0); + Arrays.fill(x_3, 0); + Arrays.fill(z_2, 0); + Arrays.fill(z_3, 0); + Arrays.fill(A, 0); + Arrays.fill(B, 0); + Arrays.fill(C, 0); + Arrays.fill(D, 0); + Arrays.fill(E, 0); + Arrays.fill(AA, 0); + Arrays.fill(BB, 0); + Arrays.fill(DA, 0); + Arrays.fill(CB, 0); + Arrays.fill(t1, 0L); + Arrays.fill(t2, 0); + } + + /** + * Evaluates the curve for every bit in a secret key. + * + * @param s The 32-byte secret key. + */ + private void evalCurve(final byte[] s) { + int sposn = 31; + int sbit = 6; + int svalue = s[sposn] | 0x40; + int swap = 0; + + // Iterate over all 255 bits of "s" from the highest to the lowest. + // We ignore the high bit of the 256-bit representation of "s". + while (true) { + // Conditional swaps on entry to this bit but only if we + // didn't swap on the previous bit. + final int select = (svalue >> sbit) & 0x01; + swap ^= select; + cswap(swap, x_2, x_3); + cswap(swap, z_2, z_3); + swap = select; + + // Evaluate the curve. + add(A, x_2, z_2); // A = x_2 + z_2 + square(AA, A); // AA = A^2 + sub(B, x_2, z_2); // B = x_2 - z_2 + square(BB, B); // BB = B^2 + sub(E, AA, BB); // E = AA - BB + add(C, x_3, z_3); // C = x_3 + z_3 + sub(D, x_3, z_3); // D = x_3 - z_3 + mul(DA, D, A); // DA = D * A + mul(CB, C, B); // CB = C * B + add(x_3, DA, CB); // x_3 = (DA + CB)^2 + square(x_3, x_3); + sub(z_3, DA, CB); // z_3 = x_1 * (DA - CB)^2 + square(z_3, z_3); + mul(z_3, z_3, x_1); + mul(x_2, AA, BB); // x_2 = AA * BB + mulA24(z_2, E); // z_2 = E * (AA + a24 * E) + add(z_2, z_2, AA); + mul(z_2, z_2, E); + + // Move onto the next lower bit of "s". + if (sbit > 0) { + --sbit; + } else if (sposn == 0) { + break; + } else if (sposn == 1) { + --sposn; + svalue = s[sposn] & 0xF8; + sbit = 7; + } else { + --sposn; + svalue = s[sposn]; + sbit = 7; + } + } + + // Final conditional swaps. + cswap(swap, x_2, x_3); + cswap(swap, z_2, z_3); + } + + /** + * Multiplies two numbers modulo 2^255 - 19. + * + * @param result The result. + * @param x The first number to multiply. + * @param y The second number to multiply. + */ + private void mul(final int[] result, final int[] x, final int[] y) { + // Multiply the two numbers to create the intermediate result. + long v = x[0]; + for (int i = 0; i < NUM_LIMBS_255BIT; ++i) { + t1[i] = v * y[i]; + } + for (int i = 1; i < NUM_LIMBS_255BIT; ++i) { + v = x[i]; + for (int j = 0; j < (NUM_LIMBS_255BIT - 1); ++j) { + t1[i + j] += v * y[j]; + } + t1[i + NUM_LIMBS_255BIT - 1] = v * y[NUM_LIMBS_255BIT - 1]; + } + + // Propagate carries and convert back into 26-bit words. + v = t1[0]; + t2[0] = ((int) v) & 0x03FFFFFF; + for (int i = 1; i < NUM_LIMBS_510BIT; ++i) { + v = (v >> 26) + t1[i]; + t2[i] = ((int) v) & 0x03FFFFFF; + } + + // Reduce the result modulo 2^255 - 19. + reduce(result, t2, NUM_LIMBS_255BIT); + } + + /** + * Multiplies a number by the a24 constant, modulo 2^255 - 19. + * + * @param result The result. + * @param x The number to multiply by a24. + */ + private void mulA24(final int[] result, final int[] x) { + final long a24 = 121665; + long carry = 0; + for (int index = 0; index < NUM_LIMBS_255BIT; ++index) { + carry += a24 * x[index]; + t2[index] = ((int) carry) & 0x03FFFFFF; + carry >>= 26; + } + t2[NUM_LIMBS_255BIT] = ((int) carry) & 0x03FFFFFF; + reduce(result, t2, 1); + } + + /** + * Raise x to the power of (2^250 - 1). + * + * @param result The result. Must not overlap with x. + * @param x The argument. + */ + private void pow250(final int[] result, final int[] x) { + // The big-endian hexadecimal expansion of (2^250 - 1) is: + // 03FFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF + // + // The naive implementation needs to do 2 multiplications per 1 bit and + // 1 multiplication per 0 bit. We can improve upon this by creating a + // pattern 0000000001 ... 0000000001. If we square and multiply the + // pattern by itself we can turn the pattern into the partial results + // 0000000011 ... 0000000011, 0000000111 ... 0000000111, etc. + // This averages out to about 1.1 multiplications per 1 bit instead of 2. + + // Build a pattern of 250 bits in length of repeated copies of 0000000001. + square(A, x); + for (int j = 0; j < 9; ++j) + square(A, A); + mul(result, A, x); + for (int i = 0; i < 23; ++i) { + for (int j = 0; j < 10; ++j) + square(A, A); + mul(result, result, A); + } + + // Multiply bit-shifted versions of the 0000000001 pattern into + // the result to "fill in" the gaps in the pattern. + square(A, result); + mul(result, result, A); + for (int j = 0; j < 8; ++j) { + square(A, A); + mul(result, result, A); + } + } + + /** + * Computes the reciprocal of a number modulo 2^255 - 19. + * + * @param result The result. Must not overlap with x. + * @param x The argument. + */ + private void recip(final int[] result, final int[] x) { + // The reciprocal is the same as x ^ (p - 2) where p = 2^255 - 19. + // The big-endian hexadecimal expansion of (p - 2) is: + // 7FFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFEB + // Start with the 250 upper bits of the expansion of (p - 2). + pow250(result, x); + + // Deal with the 5 lowest bits of (p - 2), 01011, from highest to lowest. + square(result, result); + square(result, result); + mul(result, result, x); + square(result, result); + square(result, result); + mul(result, result, x); + square(result, result); + mul(result, result, x); + } + + /** + * Reduce a number modulo 2^255 - 19. + * + * @param result The result. + * @param x The value to be reduced. This array will be + * modified during the reduction. + * @param size The number of limbs in the high order half of x. + */ + private void reduce(final int[] result, final int[] x, final int size) { + // Calculate (x mod 2^255) + ((x / 2^255) * 19) which will + // either produce the answer we want or it will produce a + // value of the form "answer + j * (2^255 - 19)". There are + // 5 left-over bits in the top-most limb of the bottom half. + int carry = 0; + int limb = x[NUM_LIMBS_255BIT - 1] >> 21; + x[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; + for (int index = 0; index < size; ++index) { + limb += x[NUM_LIMBS_255BIT + index] << 5; + carry += (limb & 0x03FFFFFF) * 19 + x[index]; + x[index] = carry & 0x03FFFFFF; + limb >>= 26; + carry >>= 26; + } + if (size < NUM_LIMBS_255BIT) { + // The high order half of the number is short; e.g. for mulA24(). + // Propagate the carry through the rest of the low order part. + for (int index = size; index < NUM_LIMBS_255BIT; ++index) { + carry += x[index]; + x[index] = carry & 0x03FFFFFF; + carry >>= 26; + } + } + + // The "j" value may still be too large due to the final carry-out. + // We must repeat the reduction. If we already have the answer, + // then this won't do any harm but we must still do the calculation + // to preserve the overall timing. The "j" value will be between + // 0 and 19, which means that the carry we care about is in the + // top 5 bits of the highest limb of the bottom half. + carry = (x[NUM_LIMBS_255BIT - 1] >> 21) * 19; + x[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; + for (int index = 0; index < NUM_LIMBS_255BIT; ++index) { + carry += x[index]; + result[index] = carry & 0x03FFFFFF; + carry >>= 26; + } + + // At this point "x" will either be the answer or it will be the + // answer plus (2^255 - 19). Perform a trial subtraction to + // complete the reduction process. + reduceQuick(result); + } + + /** + * Reduces a number modulo 2^255 - 19 where it is known that the + * number can be reduced with only 1 trial subtraction. + * + * @param x The number to reduce, and the result. + */ + private void reduceQuick(final int[] x) { + // Perform a trial subtraction of (2^255 - 19) from "x" which is + // equivalent to adding 19 and subtracting 2^255. We add 19 here; + // the subtraction of 2^255 occurs in the next step. + int carry = 19; + for (int index = 0; index < NUM_LIMBS_255BIT; ++index) { + carry += x[index]; + t2[index] = carry & 0x03FFFFFF; + carry >>= 26; + } + + // If there was a borrow, then the original "x" is the correct answer. + // If there was no borrow, then "t2" is the correct answer. Select the + // correct answer but do it in a way that instruction timing will not + // reveal which value was selected. Borrow will occur if bit 21 of + // "t2" is zero. Turn the bit into a selection mask. + final int mask = -((t2[NUM_LIMBS_255BIT - 1] >> 21) & 0x01); + final int nmask = ~mask; + t2[NUM_LIMBS_255BIT - 1] &= 0x001FFFFF; + for (int index = 0; index < NUM_LIMBS_255BIT; ++index) + x[index] = (x[index] & nmask) | (t2[index] & mask); + } + + /** + * Squares a number modulo 2^255 - 19. + * + * @param result The result. + * @param x The number to square. + */ + private void square(final int[] result, final int[] x) { + mul(result, x, x); + } +} diff --git a/tunnel/src/main/java/com/wireguard/crypto/Key.java b/tunnel/src/main/java/com/wireguard/crypto/Key.java new file mode 100644 index 00000000..6648a5f3 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/crypto/Key.java @@ -0,0 +1,288 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +import com.wireguard.crypto.KeyFormatException.Type; + +import java.security.MessageDigest; +import java.security.SecureRandom; +import java.util.Arrays; + +/** + * Represents a WireGuard public or private key. This class uses specialized constant-time base64 + * and hexadecimal codec implementations that resist side-channel attacks. + * <p> + * Instances of this class are immutable. + */ +@SuppressWarnings("MagicNumber") +public final class Key { + private final byte[] key; + + /** + * Constructs an object encapsulating the supplied key. + * + * @param key an array of bytes containing a binary key. Callers of this constructor are + * responsible for ensuring that the array is of the correct length. + */ + private Key(final byte[] key) { + // Defensively copy to ensure immutability. + this.key = Arrays.copyOf(key, key.length); + } + + /** + * Decodes a single 4-character base64 chunk to an integer in constant time. + * + * @param src an array of at least 4 characters in base64 format + * @param srcOffset the offset of the beginning of the chunk in {@code src} + * @return the decoded 3-byte integer, or some arbitrary integer value if the input was not + * valid base64 + */ + private static int decodeBase64(final char[] src, final int srcOffset) { + int val = 0; + for (int i = 0; i < 4; ++i) { + final char c = src[i + srcOffset]; + val |= (-1 + + ((((('A' - 1) - c) & (c - ('Z' + 1))) >>> 8) & (c - 64)) + + ((((('a' - 1) - c) & (c - ('z' + 1))) >>> 8) & (c - 70)) + + ((((('0' - 1) - c) & (c - ('9' + 1))) >>> 8) & (c + 5)) + + ((((('+' - 1) - c) & (c - ('+' + 1))) >>> 8) & 63) + + ((((('/' - 1) - c) & (c - ('/' + 1))) >>> 8) & 64) + ) << (18 - 6 * i); + } + return val; + } + + /** + * Encodes a single 4-character base64 chunk from 3 consecutive bytes in constant time. + * + * @param src an array of at least 3 bytes + * @param srcOffset the offset of the beginning of the chunk in {@code src} + * @param dest an array of at least 4 characters + * @param destOffset the offset of the beginning of the chunk in {@code dest} + */ + private static void encodeBase64(final byte[] src, final int srcOffset, + final char[] dest, final int destOffset) { + final byte[] input = { + (byte) ((src[srcOffset] >>> 2) & 63), + (byte) ((src[srcOffset] << 4 | ((src[1 + srcOffset] & 0xff) >>> 4)) & 63), + (byte) ((src[1 + srcOffset] << 2 | ((src[2 + srcOffset] & 0xff) >>> 6)) & 63), + (byte) ((src[2 + srcOffset]) & 63), + }; + for (int i = 0; i < 4; ++i) { + dest[i + destOffset] = (char) (input[i] + 'A' + + (((25 - input[i]) >>> 8) & 6) + - (((51 - input[i]) >>> 8) & 75) + - (((61 - input[i]) >>> 8) & 15) + + (((62 - input[i]) >>> 8) & 3)); + } + } + + /** + * Decodes a WireGuard public or private key from its base64 string representation. This + * function throws a {@link KeyFormatException} if the source string is not well-formed. + * + * @param str the base64 string representation of a WireGuard key + * @return the decoded key encapsulated in an immutable container + */ + public static Key fromBase64(final String str) throws KeyFormatException { + final char[] input = str.toCharArray(); + if (input.length != Format.BASE64.length || input[Format.BASE64.length - 1] != '=') + throw new KeyFormatException(Format.BASE64, Type.LENGTH); + final byte[] key = new byte[Format.BINARY.length]; + int i; + int ret = 0; + for (i = 0; i < key.length / 3; ++i) { + final int val = decodeBase64(input, i * 4); + ret |= val >>> 31; + key[i * 3] = (byte) ((val >>> 16) & 0xff); + key[i * 3 + 1] = (byte) ((val >>> 8) & 0xff); + key[i * 3 + 2] = (byte) (val & 0xff); + } + final char[] endSegment = { + input[i * 4], + input[i * 4 + 1], + input[i * 4 + 2], + 'A', + }; + final int val = decodeBase64(endSegment, 0); + ret |= (val >>> 31) | (val & 0xff); + key[i * 3] = (byte) ((val >>> 16) & 0xff); + key[i * 3 + 1] = (byte) ((val >>> 8) & 0xff); + + if (ret != 0) + throw new KeyFormatException(Format.BASE64, Type.CONTENTS); + return new Key(key); + } + + /** + * Wraps a WireGuard public or private key in an immutable container. This function throws a + * {@link KeyFormatException} if the source data is not the correct length. + * + * @param bytes an array of bytes containing a WireGuard key in binary format + * @return the key encapsulated in an immutable container + */ + public static Key fromBytes(final byte[] bytes) throws KeyFormatException { + if (bytes.length != Format.BINARY.length) + throw new KeyFormatException(Format.BINARY, Type.LENGTH); + return new Key(bytes); + } + + /** + * Decodes a WireGuard public or private key from its hexadecimal string representation. This + * function throws a {@link KeyFormatException} if the source string is not well-formed. + * + * @param str the hexadecimal string representation of a WireGuard key + * @return the decoded key encapsulated in an immutable container + */ + public static Key fromHex(final String str) throws KeyFormatException { + final char[] input = str.toCharArray(); + if (input.length != Format.HEX.length) + throw new KeyFormatException(Format.HEX, Type.LENGTH); + final byte[] key = new byte[Format.BINARY.length]; + int ret = 0; + for (int i = 0; i < key.length; ++i) { + int c; + int cNum; + int cNum0; + int cAlpha; + int cAlpha0; + int cVal; + final int cAcc; + + c = input[i * 2]; + cNum = c ^ 48; + cNum0 = ((cNum - 10) >>> 8) & 0xff; + cAlpha = (c & ~32) - 55; + cAlpha0 = (((cAlpha - 10) ^ (cAlpha - 16)) >>> 8) & 0xff; + ret |= ((cNum0 | cAlpha0) - 1) >>> 8; + cVal = (cNum0 & cNum) | (cAlpha0 & cAlpha); + cAcc = cVal * 16; + + c = input[i * 2 + 1]; + cNum = c ^ 48; + cNum0 = ((cNum - 10) >>> 8) & 0xff; + cAlpha = (c & ~32) - 55; + cAlpha0 = (((cAlpha - 10) ^ (cAlpha - 16)) >>> 8) & 0xff; + ret |= ((cNum0 | cAlpha0) - 1) >>> 8; + cVal = (cNum0 & cNum) | (cAlpha0 & cAlpha); + key[i] = (byte) (cAcc | cVal); + } + if (ret != 0) + throw new KeyFormatException(Format.HEX, Type.CONTENTS); + return new Key(key); + } + + /** + * Generates a private key using the system's {@link SecureRandom} number generator. + * + * @return a well-formed random private key + */ + static Key generatePrivateKey() { + final SecureRandom secureRandom = new SecureRandom(); + final byte[] privateKey = new byte[Format.BINARY.getLength()]; + secureRandom.nextBytes(privateKey); + privateKey[0] &= 248; + privateKey[31] &= 127; + privateKey[31] |= 64; + return new Key(privateKey); + } + + /** + * Generates a public key from an existing private key. + * + * @param privateKey a private key + * @return a well-formed public key that corresponds to the supplied private key + */ + static Key generatePublicKey(final Key privateKey) { + final byte[] publicKey = new byte[Format.BINARY.getLength()]; + Curve25519.eval(publicKey, 0, privateKey.getBytes(), null); + return new Key(publicKey); + } + + /** + * Returns the key as an array of bytes. + * + * @return an array of bytes containing the raw binary key + */ + public byte[] getBytes() { + // Defensively copy to ensure immutability. + return Arrays.copyOf(key, key.length); + } + + /** + * Encodes the key to base64. + * + * @return a string containing the encoded key + */ + public String toBase64() { + final char[] output = new char[Format.BASE64.length]; + int i; + for (i = 0; i < key.length / 3; ++i) + encodeBase64(key, i * 3, output, i * 4); + final byte[] endSegment = { + key[i * 3], + key[i * 3 + 1], + 0, + }; + encodeBase64(endSegment, 0, output, i * 4); + output[Format.BASE64.length - 1] = '='; + return new String(output); + } + + /** + * Encodes the key to hexadecimal ASCII characters. + * + * @return a string containing the encoded key + */ + public String toHex() { + final char[] output = new char[Format.HEX.length]; + for (int i = 0; i < key.length; ++i) { + output[i * 2] = (char) (87 + (key[i] >> 4 & 0xf) + + ((((key[i] >> 4 & 0xf) - 10) >> 8) & ~38)); + output[i * 2 + 1] = (char) (87 + (key[i] & 0xf) + + ((((key[i] & 0xf) - 10) >> 8) & ~38)); + } + return new String(output); + } + + @Override + public int hashCode() { + int ret = 0; + for (int i = 0; i < key.length / 4; ++i) + ret ^= (key[i * 4 + 0] >> 0) + (key[i * 4 + 1] >> 8) + (key[i * 4 + 2] >> 16) + (key[i * 4 + 3] >> 24); + return ret; + } + + @Override + public boolean equals(final Object obj) { + if (obj == this) + return true; + if (obj == null || obj.getClass() != getClass()) + return false; + final Key other = (Key) obj; + return MessageDigest.isEqual(key, other.key); + } + + /** + * The supported formats for encoding a WireGuard key. + */ + public enum Format { + BASE64(44), + BINARY(32), + HEX(64); + + private final int length; + + Format(final int length) { + this.length = length; + } + + public int getLength() { + return length; + } + } + +} diff --git a/tunnel/src/main/java/com/wireguard/crypto/KeyFormatException.java b/tunnel/src/main/java/com/wireguard/crypto/KeyFormatException.java new file mode 100644 index 00000000..5818b4d4 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/crypto/KeyFormatException.java @@ -0,0 +1,34 @@ +/* + * Copyright © 2018-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +/** + * An exception thrown when attempting to parse an invalid key (too short, too long, or byte + * data inappropriate for the format). The format being parsed can be accessed with the + * {@link #getFormat} method. + */ +public final class KeyFormatException extends Exception { + private final Key.Format format; + private final Type type; + + KeyFormatException(final Key.Format format, final Type type) { + this.format = format; + this.type = type; + } + + public Key.Format getFormat() { + return format; + } + + public Type getType() { + return type; + } + + public enum Type { + CONTENTS, + LENGTH + } +} diff --git a/tunnel/src/main/java/com/wireguard/crypto/KeyPair.java b/tunnel/src/main/java/com/wireguard/crypto/KeyPair.java new file mode 100644 index 00000000..f8238e91 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/crypto/KeyPair.java @@ -0,0 +1,51 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +/** + * Represents a Curve25519 key pair as used by WireGuard. + * <p> + * Instances of this class are immutable. + */ +public class KeyPair { + private final Key privateKey; + private final Key publicKey; + + /** + * Creates a key pair using a newly-generated private key. + */ + public KeyPair() { + this(Key.generatePrivateKey()); + } + + /** + * Creates a key pair using an existing private key. + * + * @param privateKey a private key, used to derive the public key + */ + public KeyPair(final Key privateKey) { + this.privateKey = privateKey; + publicKey = Key.generatePublicKey(privateKey); + } + + /** + * Returns the private key from the key pair. + * + * @return the private key + */ + public Key getPrivateKey() { + return privateKey; + } + + /** + * Returns the public key from the key pair. + * + * @return the public key + */ + public Key getPublicKey() { + return publicKey; + } +} diff --git a/tunnel/src/main/java/com/wireguard/util/NonNullForAll.java b/tunnel/src/main/java/com/wireguard/util/NonNullForAll.java new file mode 100644 index 00000000..f179fa49 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/util/NonNullForAll.java @@ -0,0 +1,26 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.util; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +import javax.annotation.Nonnull; +import javax.annotation.meta.TypeQualifierDefault; + +/** + * This annotation can be applied to a package, class or method to indicate that all + * class fields and method parameters and return values in that element are nonnull + * by default unless overridden. + */ +@Documented +@Nonnull +@TypeQualifierDefault({ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER}) +@Retention(RetentionPolicy.RUNTIME) +public @interface NonNullForAll { +} |