diff options
author | Harsh Shandilya <me@msfjarvis.dev> | 2020-03-09 19:00:14 +0530 |
---|---|---|
committer | Harsh Shandilya <me@msfjarvis.dev> | 2020-03-09 19:24:26 +0530 |
commit | adc613d8011af7c508050badb1272e8326554c39 (patch) | |
tree | 4eadedc0767e1f4f482b7c22ec905329acab62a6 /tunnel/src/main/java/com/wireguard/android | |
parent | fd573f6c1c37bcfcd09237dfcd55e08b1e0eaff9 (diff) |
Migrate tunnel related classes to tunnel/ Gradle module
Signed-off-by: Harsh Shandilya <me@msfjarvis.dev>
Diffstat (limited to 'tunnel/src/main/java/com/wireguard/android')
10 files changed, 1215 insertions, 0 deletions
diff --git a/tunnel/src/main/java/com/wireguard/android/backend/Backend.java b/tunnel/src/main/java/com/wireguard/android/backend/Backend.java new file mode 100644 index 00000000..ed3a5ebd --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/Backend.java @@ -0,0 +1,63 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import com.wireguard.config.Config; + +import java.util.Collection; +import java.util.Set; + +import androidx.annotation.Nullable; + +/** + * Interface for implementations of the WireGuard secure network tunnel. + */ + +public interface Backend { + /** + * Enumerate names of currently-running tunnels. + * + * @return The set of running tunnel names. + */ + Set<String> getRunningTunnelNames(); + + /** + * Get the state of a tunnel. + * + * @param tunnel The tunnel to examine the state of. + * @return The state of the tunnel. + */ + Tunnel.State getState(Tunnel tunnel) throws Exception; + + /** + * Get statistics about traffic and errors on this tunnel. If the tunnel is not running, the + * statistics object will be filled with zero values. + * + * @param tunnel The tunnel to retrieve statistics for. + * @return The statistics for the tunnel. + */ + Statistics getStatistics(Tunnel tunnel) throws Exception; + + /** + * Determine version of underlying backend. + * + * @return The version of the backend. + * @throws Exception + */ + String getVersion() throws Exception; + + /** + * Set the state of a tunnel, updating it's configuration. If the tunnel is already up, config + * may update the running configuration; config may be null when setting the tunnel down. + * + * @param tunnel The tunnel to control the state of. + * @param state The new state for this tunnel. Must be {@code UP}, {@code DOWN}, or + * {@code TOGGLE}. + * @param config The configuration for this tunnel, may be null if state is {@code DOWN}. + * @return The updated state of the tunnel. + */ + Tunnel.State setState(Tunnel tunnel, Tunnel.State state, @Nullable Config config) throws Exception; +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/BackendException.java b/tunnel/src/main/java/com/wireguard/android/backend/BackendException.java new file mode 100644 index 00000000..e1e8eaa9 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/BackendException.java @@ -0,0 +1,30 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +public final class BackendException extends Exception { + public enum Reason { + UNKNOWN_KERNEL_MODULE_NAME, + WG_QUICK_CONFIG_ERROR_CODE, + TUNNEL_MISSING_CONFIG, + VPN_NOT_AUTHORIZED, + UNABLE_TO_START_VPN, + TUN_CREATION_ERROR, + GO_ACTIVATION_ERROR_CODE + } + private final Reason reason; + private final Object[] format; + public BackendException(final Reason reason, final Object ...format) { + this.reason = reason; + this.format = format; + } + public Reason getReason() { + return reason; + } + public Object[] getFormat() { + return format; + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java new file mode 100644 index 00000000..6ad5afa4 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java @@ -0,0 +1,292 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import android.content.Context; +import android.content.Intent; +import android.os.Build; +import android.os.ParcelFileDescriptor; +import androidx.annotation.Nullable; +import androidx.collection.ArraySet; +import android.util.Log; + +import com.wireguard.android.backend.BackendException.Reason; +import com.wireguard.android.backend.Tunnel.State; +import com.wireguard.android.util.SharedLibraryLoader; +import com.wireguard.config.Config; +import com.wireguard.config.InetNetwork; +import com.wireguard.config.Peer; +import com.wireguard.crypto.Key; +import com.wireguard.crypto.KeyFormatException; + +import java.net.InetAddress; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import java9.util.concurrent.CompletableFuture; + +public final class GoBackend implements Backend { + private static final String TAG = "WireGuard/" + GoBackend.class.getSimpleName(); + private static CompletableFuture<VpnService> vpnService = new CompletableFuture<>(); + public interface AlwaysOnCallback { + void alwaysOnTriggered(); + } + @Nullable private static AlwaysOnCallback alwaysOnCallback; + public static void setAlwaysOnCallback(AlwaysOnCallback cb) { + alwaysOnCallback = cb; + } + + private final Context context; + @Nullable private Tunnel currentTunnel; + @Nullable private Config currentConfig; + private int currentTunnelHandle = -1; + + public GoBackend(final Context context) { + SharedLibraryLoader.loadSharedLibrary(context, "wg-go"); + this.context = context; + } + + private static native String wgGetConfig(int handle); + + private static native int wgGetSocketV4(int handle); + + private static native int wgGetSocketV6(int handle); + + private static native void wgTurnOff(int handle); + + private static native int wgTurnOn(String ifName, int tunFd, String settings); + + private static native String wgVersion(); + + @Override + public Set<String> getRunningTunnelNames() { + if (currentTunnel != null) { + final Set<String> runningTunnels = new ArraySet<>(); + runningTunnels.add(currentTunnel.getName()); + return runningTunnels; + } + return Collections.emptySet(); + } + + @Override + public State getState(final Tunnel tunnel) { + return currentTunnel == tunnel ? State.UP : State.DOWN; + } + + @Override + public Statistics getStatistics(final Tunnel tunnel) { + final Statistics stats = new Statistics(); + if (tunnel != currentTunnel) { + return stats; + } + final String config = wgGetConfig(currentTunnelHandle); + Key key = null; + long rx = 0, tx = 0; + for (final String line : config.split("\\n")) { + if (line.startsWith("public_key=")) { + if (key != null) + stats.add(key, rx, tx); + rx = 0; + tx = 0; + try { + key = Key.fromHex(line.substring(11)); + } catch (final KeyFormatException ignored) { + key = null; + } + } else if (line.startsWith("rx_bytes=")) { + if (key == null) + continue; + try { + rx = Long.parseLong(line.substring(9)); + } catch (final NumberFormatException ignored) { + rx = 0; + } + } else if (line.startsWith("tx_bytes=")) { + if (key == null) + continue; + try { + tx = Long.parseLong(line.substring(9)); + } catch (final NumberFormatException ignored) { + tx = 0; + } + } + } + if (key != null) + stats.add(key, rx, tx); + return stats; + } + + @Override + public String getVersion() { + return wgVersion(); + } + + @Override + public State setState(final Tunnel tunnel, State state, @Nullable final Config config) throws Exception { + final State originalState = getState(tunnel); + + if (state == State.TOGGLE) + state = originalState == State.UP ? State.DOWN : State.UP; + if (state == originalState && tunnel == currentTunnel && config == currentConfig) + return originalState; + if (state == State.UP) { + final Config originalConfig = currentConfig; + final Tunnel originalTunnel = currentTunnel; + if (currentTunnel != null) + setStateInternal(currentTunnel, null, State.DOWN); + try { + setStateInternal(tunnel, config, state); + } catch(final Exception e) { + if (originalTunnel != null) + setStateInternal(originalTunnel, originalConfig, State.UP); + throw e; + } + } else if (state == State.DOWN && tunnel == currentTunnel) { + setStateInternal(tunnel, null, State.DOWN); + } + return getState(tunnel); + } + + private void setStateInternal(final Tunnel tunnel, @Nullable final Config config, final State state) + throws Exception { + Log.i(TAG, "Bringing tunnel " + tunnel.getName() + " " + state); + + if (state == State.UP) { + if (config == null) + throw new BackendException(Reason.TUNNEL_MISSING_CONFIG); + + if (VpnService.prepare(context) != null) + throw new BackendException(Reason.VPN_NOT_AUTHORIZED); + + final VpnService service; + if (!vpnService.isDone()) + startVpnService(); + + try { + service = vpnService.get(2, TimeUnit.SECONDS); + } catch (final TimeoutException e) { + final Exception be = new BackendException(Reason.UNABLE_TO_START_VPN); + be.initCause(e); + throw be; + } + service.setOwner(this); + + if (currentTunnelHandle != -1) { + Log.w(TAG, "Tunnel already up"); + return; + } + + // Build config + final String goConfig = config.toWgUserspaceString(); + + // Create the vpn tunnel with android API + final VpnService.Builder builder = service.getBuilder(); + builder.setSession(tunnel.getName()); + + for (final String excludedApplication : config.getInterface().getExcludedApplications()) + builder.addDisallowedApplication(excludedApplication); + + for (final InetNetwork addr : config.getInterface().getAddresses()) + builder.addAddress(addr.getAddress(), addr.getMask()); + + for (final InetAddress addr : config.getInterface().getDnsServers()) + builder.addDnsServer(addr.getHostAddress()); + + for (final Peer peer : config.getPeers()) { + for (final InetNetwork addr : peer.getAllowedIps()) + builder.addRoute(addr.getAddress(), addr.getMask()); + } + + builder.setMtu(config.getInterface().getMtu().orElse(1280)); + + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) + builder.setMetered(false); + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) + service.setUnderlyingNetworks(null); + + builder.setBlocking(true); + try (final ParcelFileDescriptor tun = builder.establish()) { + if (tun == null) + throw new BackendException(Reason.TUN_CREATION_ERROR); + Log.d(TAG, "Go backend v" + wgVersion()); + currentTunnelHandle = wgTurnOn(tunnel.getName(), tun.detachFd(), goConfig); + } + if (currentTunnelHandle < 0) + throw new BackendException(Reason.GO_ACTIVATION_ERROR_CODE, currentTunnelHandle); + + currentTunnel = tunnel; + currentConfig = config; + + service.protect(wgGetSocketV4(currentTunnelHandle)); + service.protect(wgGetSocketV6(currentTunnelHandle)); + } else { + if (currentTunnelHandle == -1) { + Log.w(TAG, "Tunnel already down"); + return; + } + + wgTurnOff(currentTunnelHandle); + currentTunnel = null; + currentTunnelHandle = -1; + currentConfig = null; + } + + tunnel.onStateChange(state); + } + + private void startVpnService() { + Log.d(TAG, "Requesting to start VpnService"); + context.startService(new Intent(context, VpnService.class)); + } + + public static class VpnService extends android.net.VpnService { + @Nullable private GoBackend owner; + + public void setOwner(final GoBackend owner) { + this.owner = owner; + } + + public Builder getBuilder() { + return new Builder(); + } + + @Override + public void onCreate() { + vpnService.complete(this); + super.onCreate(); + } + + @Override + public void onDestroy() { + if (owner != null) { + final Tunnel tunnel = owner.currentTunnel; + if (tunnel != null) { + if (owner.currentTunnelHandle != -1) + wgTurnOff(owner.currentTunnelHandle); + owner.currentTunnel = null; + owner.currentTunnelHandle = -1; + owner.currentConfig = null; + tunnel.onStateChange(State.DOWN); + } + } + vpnService = vpnService.newIncompleteFuture(); + super.onDestroy(); + } + + @Override + public int onStartCommand(@Nullable final Intent intent, final int flags, final int startId) { + vpnService.complete(this); + if (intent == null || intent.getComponent() == null || !intent.getComponent().getPackageName().equals(getPackageName())) { + Log.d(TAG, "Service started by Always-on VPN feature"); + if (alwaysOnCallback != null) + alwaysOnCallback.alwaysOnTriggered(); + } + return super.onStartCommand(intent, flags, startId); + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/Statistics.java b/tunnel/src/main/java/com/wireguard/android/backend/Statistics.java new file mode 100644 index 00000000..2ca87d23 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/Statistics.java @@ -0,0 +1,62 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import android.os.SystemClock; +import android.util.Pair; + +import com.wireguard.crypto.Key; + +import java.util.HashMap; +import java.util.Map; + +public class Statistics { + private long lastTouched = SystemClock.elapsedRealtime(); + private final Map<Key, Pair<Long, Long>> peerBytes = new HashMap<>(); + + Statistics() { } + + void add(final Key key, final long rx, final long tx) { + peerBytes.put(key, Pair.create(rx, tx)); + lastTouched = SystemClock.elapsedRealtime(); + } + + public boolean isStale() { + return SystemClock.elapsedRealtime() - lastTouched > 900; + } + + public Key[] peers() { + return peerBytes.keySet().toArray(new Key[0]); + } + + public long peerRx(final Key peer) { + if (!peerBytes.containsKey(peer)) + return 0; + return peerBytes.get(peer).first; + } + + public long peerTx(final Key peer) { + if (!peerBytes.containsKey(peer)) + return 0; + return peerBytes.get(peer).second; + } + + public long totalRx() { + long rx = 0; + for (final Pair<Long, Long> val : peerBytes.values()) { + rx += val.first; + } + return rx; + } + + public long totalTx() { + long tx = 0; + for (final Pair<Long, Long> val : peerBytes.values()) { + tx += val.second; + } + return tx; + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/Tunnel.java b/tunnel/src/main/java/com/wireguard/android/backend/Tunnel.java new file mode 100644 index 00000000..af2f59f7 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/Tunnel.java @@ -0,0 +1,46 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import java.util.regex.Pattern; + +/** + * Represents a WireGuard tunnel. + */ + +public interface Tunnel { + enum State { + DOWN, + TOGGLE, + UP; + + public static State of(final boolean running) { + return running ? UP : DOWN; + } + } + + int NAME_MAX_LENGTH = 15; + Pattern NAME_PATTERN = Pattern.compile("[a-zA-Z0-9_=+.-]{1,15}"); + + static boolean isNameInvalid(final CharSequence name) { + return !NAME_PATTERN.matcher(name).matches(); + } + + /** + * Get the name of the tunnel, which should always pass the !isNameInvalid test. + * + * @return The name of the tunnel. + */ + String getName(); + + /** + * React to a change in state of the tunnel. Should only be directly called by Backend. + * + * @param newState The new state of the tunnel. + * @return The new state of the tunnel. + */ + void onStateChange(State newState); +} diff --git a/tunnel/src/main/java/com/wireguard/android/backend/WgQuickBackend.java b/tunnel/src/main/java/com/wireguard/android/backend/WgQuickBackend.java new file mode 100644 index 00000000..9695aab7 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/backend/WgQuickBackend.java @@ -0,0 +1,158 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.backend; + +import androidx.annotation.Nullable; + +import android.content.Context; +import android.util.Log; + +import com.wireguard.android.backend.BackendException.Reason; +import com.wireguard.android.backend.Tunnel.State; +import com.wireguard.android.util.RootShell; +import com.wireguard.android.util.ToolsInstaller; +import com.wireguard.config.Config; +import com.wireguard.crypto.Key; + +import java.io.File; +import java.io.FileOutputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.HashMap; + +import java9.util.stream.Collectors; +import java9.util.stream.Stream; + +/** + * WireGuard backend that uses {@code wg-quick} to implement tunnel configuration. + */ + +public final class WgQuickBackend implements Backend { + private static final String TAG = "WireGuard/" + WgQuickBackend.class.getSimpleName(); + + private final RootShell rootShell; + private final ToolsInstaller toolsInstaller; + private final File localTemporaryDir; + private final Map<Tunnel, Config> runningConfigs = new HashMap<>(); + + public WgQuickBackend(final Context context, final RootShell rootShell, final ToolsInstaller toolsInstaller) { + localTemporaryDir = new File(context.getCacheDir(), "tmp"); + this.rootShell = rootShell; + this.toolsInstaller = toolsInstaller; + } + + @Override + public Set<String> getRunningTunnelNames() { + final List<String> output = new ArrayList<>(); + // Don't throw an exception here or nothing will show up in the UI. + try { + toolsInstaller.ensureToolsAvailable(); + if (rootShell.run(output, "wg show interfaces") != 0 || output.isEmpty()) + return Collections.emptySet(); + } catch (final Exception e) { + Log.w(TAG, "Unable to enumerate running tunnels", e); + return Collections.emptySet(); + } + // wg puts all interface names on the same line. Split them into separate elements. + return Stream.of(output.get(0).split(" ")).collect(Collectors.toUnmodifiableSet()); + } + + @Override + public State getState(final Tunnel tunnel) { + return getRunningTunnelNames().contains(tunnel.getName()) ? State.UP : State.DOWN; + } + + @Override + public Statistics getStatistics(final Tunnel tunnel) { + final Statistics stats = new Statistics(); + final Collection<String> output = new ArrayList<>(); + try { + if (rootShell.run(output, String.format("wg show '%s' transfer", tunnel.getName())) != 0) + return stats; + } catch (final Exception ignored) { + return stats; + } + for (final String line : output) { + final String[] parts = line.split("\\t"); + if (parts.length != 3) + continue; + try { + stats.add(Key.fromBase64(parts[0]), Long.parseLong(parts[1]), Long.parseLong(parts[2])); + } catch (final Exception ignored) { + } + } + return stats; + } + + @Override + public String getVersion() throws Exception { + final List<String> output = new ArrayList<>(); + if (rootShell.run(output, "cat /sys/module/wireguard/version") != 0 || output.isEmpty()) + throw new BackendException(Reason.UNKNOWN_KERNEL_MODULE_NAME); + return output.get(0); + } + + @Override + public State setState(final Tunnel tunnel, State state, @Nullable final Config config) throws Exception { + final State originalState = getState(tunnel); + final Config originalConfig = runningConfigs.get(tunnel); + + if (state == State.TOGGLE) + state = originalState == State.UP ? State.DOWN : State.UP; + if ((state == State.UP && originalState == State.UP && originalConfig != null && originalConfig == config) || + (state == State.DOWN && originalState == State.DOWN)) + return originalState; + if (state == State.UP) { + toolsInstaller.ensureToolsAvailable(); + if (originalState == State.UP) + setStateInternal(tunnel, originalConfig == null ? config : originalConfig, State.DOWN); + try { + setStateInternal(tunnel, config, State.UP); + } catch(final Exception e) { + if (originalState == State.UP && originalConfig != null) + setStateInternal(tunnel, originalConfig, State.UP); + throw e; + } + } else if (state == State.DOWN) { + setStateInternal(tunnel, originalConfig == null ? config : originalConfig, State.DOWN); + } + return state; + } + + private void setStateInternal(final Tunnel tunnel, @Nullable final Config config, final State state) throws Exception { + Log.i(TAG, "Bringing tunnel " + tunnel.getName() + " " + state); + + Objects.requireNonNull(config, "Trying to set state up with a null config"); + + final File tempFile = new File(localTemporaryDir, tunnel.getName() + ".conf"); + try (final FileOutputStream stream = new FileOutputStream(tempFile, false)) { + stream.write(config.toWgQuickString().getBytes(StandardCharsets.UTF_8)); + } + String command = String.format("wg-quick %s '%s'", + state.toString().toLowerCase(Locale.ENGLISH), tempFile.getAbsolutePath()); + if (state == State.UP) + command = "cat /sys/module/wireguard/version && " + command; + final int result = rootShell.run(null, command); + // noinspection ResultOfMethodCallIgnored + tempFile.delete(); + if (result != 0) + throw new BackendException(Reason.WG_QUICK_CONFIG_ERROR_CODE, result); + + if (state == State.UP) + runningConfigs.put(tunnel, config); + else + runningConfigs.remove(tunnel); + + tunnel.onStateChange(state); + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/util/AsyncWorker.java b/tunnel/src/main/java/com/wireguard/android/util/AsyncWorker.java new file mode 100644 index 00000000..1d041851 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/util/AsyncWorker.java @@ -0,0 +1,63 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util; + +import android.os.Handler; + +import java.util.concurrent.Executor; + +import java9.util.concurrent.CompletableFuture; +import java9.util.concurrent.CompletionStage; + +/** + * Helper class for running asynchronous tasks and ensuring they are completed on the main thread. + */ + +public class AsyncWorker { + private final Executor executor; + private final Handler handler; + + public AsyncWorker(final Executor executor, final Handler handler) { + this.executor = executor; + this.handler = handler; + } + + public CompletionStage<Void> runAsync(final AsyncRunnable<?> runnable) { + final CompletableFuture<Void> future = new CompletableFuture<>(); + executor.execute(() -> { + try { + runnable.run(); + handler.post(() -> future.complete(null)); + } catch (final Throwable t) { + handler.post(() -> future.completeExceptionally(t)); + } + }); + return future; + } + + public <T> CompletionStage<T> supplyAsync(final AsyncSupplier<T, ?> supplier) { + final CompletableFuture<T> future = new CompletableFuture<>(); + executor.execute(() -> { + try { + final T result = supplier.get(); + handler.post(() -> future.complete(result)); + } catch (final Throwable t) { + handler.post(() -> future.completeExceptionally(t)); + } + }); + return future; + } + + @FunctionalInterface + public interface AsyncRunnable<E extends Throwable> { + void run() throws E; + } + + @FunctionalInterface + public interface AsyncSupplier<T, E extends Throwable> { + T get() throws E; + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/util/RootShell.java b/tunnel/src/main/java/com/wireguard/android/util/RootShell.java new file mode 100644 index 00000000..1fc2c9f2 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/util/RootShell.java @@ -0,0 +1,211 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util; + +import android.content.Context; +import androidx.annotation.Nullable; +import android.util.Log; + +import com.wireguard.android.util.RootShell.RootShellException.Reason; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.UUID; + +/** + * Helper class for running commands as root. + */ + +public class RootShell { + private static final String SU = "su"; + private static final String TAG = "WireGuard/" + RootShell.class.getSimpleName(); + + private final File localBinaryDir; + private final File localTemporaryDir; + private final Object lock = new Object(); + private final String preamble; + @Nullable private Process process; + @Nullable private BufferedReader stderr; + @Nullable private OutputStreamWriter stdin; + @Nullable private BufferedReader stdout; + + public RootShell(final Context context) { + localBinaryDir = new File(context.getCodeCacheDir(), "bin"); + localTemporaryDir = new File(context.getCacheDir(), "tmp"); + preamble = String.format("export CALLING_PACKAGE=%s PATH=\"%s:$PATH\" TMPDIR='%s'; id -u\n", + context.getPackageName(), localBinaryDir, localTemporaryDir); + } + + private static boolean isExecutableInPath(final String name) { + final String path = System.getenv("PATH"); + if (path == null) + return false; + for (final String dir : path.split(":")) + if (new File(dir, name).canExecute()) + return true; + return false; + } + + private boolean isRunning() { + synchronized (lock) { + try { + // Throws an exception if the process hasn't finished yet. + if (process != null) + process.exitValue(); + return false; + } catch (final IllegalThreadStateException ignored) { + // The existing process is still running. + return true; + } + } + } + + /** + * Run a command in a root shell. + * + * @param output Lines read from stdout are appended to this list. Pass null if the + * output from the shell is not important. + * @param command Command to run as root. + * @return The exit value of the command. + */ + public int run(@Nullable final Collection<String> output, final String command) + throws IOException, RootShellException { + synchronized (lock) { + /* Start inside synchronized block to prevent a concurrent call to stop(). */ + start(); + final String marker = UUID.randomUUID().toString(); + final String script = "echo " + marker + "; echo " + marker + " >&2; (" + command + + "); ret=$?; echo " + marker + " $ret; echo " + marker + " $ret >&2\n"; + Log.v(TAG, "executing: " + command); + stdin.write(script); + stdin.flush(); + String line; + int errnoStdout = Integer.MIN_VALUE; + int errnoStderr = Integer.MAX_VALUE; + int markersSeen = 0; + while ((line = stdout.readLine()) != null) { + if (line.startsWith(marker)) { + ++markersSeen; + if (line.length() > marker.length() + 1) { + errnoStdout = Integer.valueOf(line.substring(marker.length() + 1)); + break; + } + } else if (markersSeen > 0) { + if (output != null) + output.add(line); + Log.v(TAG, "stdout: " + line); + } + } + while ((line = stderr.readLine()) != null) { + if (line.startsWith(marker)) { + ++markersSeen; + if (line.length() > marker.length() + 1) { + errnoStderr = Integer.valueOf(line.substring(marker.length() + 1)); + break; + } + } else if (markersSeen > 2) { + Log.v(TAG, "stderr: " + line); + } + } + if (markersSeen != 4) + throw new RootShellException(Reason.SHELL_MARKER_COUNT_ERROR, markersSeen); + if (errnoStdout != errnoStderr) + throw new RootShellException(Reason.SHELL_EXIT_STATUS_READ_ERROR); + Log.v(TAG, "exit: " + errnoStdout); + return errnoStdout; + } + } + + public void start() throws IOException, RootShellException { + if (!isExecutableInPath(SU)) + throw new RootShellException(Reason.NO_ROOT_ACCESS); + synchronized (lock) { + if (isRunning()) + return; + if (!localBinaryDir.isDirectory() && !localBinaryDir.mkdirs()) + throw new RootShellException(Reason.CREATE_BIN_DIR_ERROR); + if (!localTemporaryDir.isDirectory() && !localTemporaryDir.mkdirs()) + throw new RootShellException(Reason.CREATE_TEMP_DIR_ERROR); + try { + final ProcessBuilder builder = new ProcessBuilder().command(SU); + builder.environment().put("LC_ALL", "C"); + try { + process = builder.start(); + } catch (final IOException e) { + // A failure at this stage means the device isn't rooted. + final RootShellException rse = new RootShellException(Reason.NO_ROOT_ACCESS); + rse.initCause(e); + throw rse; + } + stdin = new OutputStreamWriter(process.getOutputStream(), StandardCharsets.UTF_8); + stdout = new BufferedReader(new InputStreamReader(process.getInputStream(), + StandardCharsets.UTF_8)); + stderr = new BufferedReader(new InputStreamReader(process.getErrorStream(), + StandardCharsets.UTF_8)); + stdin.write(preamble); + stdin.flush(); + // Check that the shell started successfully. + final String uid = stdout.readLine(); + if (!"0".equals(uid)) { + Log.w(TAG, "Root check did not return correct UID: " + uid); + throw new RootShellException(Reason.NO_ROOT_ACCESS); + } + if (!isRunning()) { + String line; + while ((line = stderr.readLine()) != null) { + Log.w(TAG, "Root check returned an error: " + line); + if (line.contains("Permission denied")) + throw new RootShellException(Reason.NO_ROOT_ACCESS); + } + throw new RootShellException(Reason.SHELL_START_ERROR, process.exitValue()); + } + } catch (final IOException | RootShellException e) { + stop(); + throw e; + } + } + } + + public void stop() { + synchronized (lock) { + if (process != null) { + process.destroy(); + process = null; + } + } + } + + public static class RootShellException extends Exception { + public enum Reason { + NO_ROOT_ACCESS, + SHELL_MARKER_COUNT_ERROR, + SHELL_EXIT_STATUS_READ_ERROR, + SHELL_START_ERROR, + CREATE_BIN_DIR_ERROR, + CREATE_TEMP_DIR_ERROR + } + private final Reason reason; + private final Object[] format; + public RootShellException(final Reason reason, final Object ...format) { + this.reason = reason; + this.format = format; + } + public boolean isIORelated() { + return reason != Reason.NO_ROOT_ACCESS; + } + public Reason getReason() { + return reason; + } + public Object[] getFormat() { + return format; + } + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/util/SharedLibraryLoader.java b/tunnel/src/main/java/com/wireguard/android/util/SharedLibraryLoader.java new file mode 100644 index 00000000..93e44b64 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/util/SharedLibraryLoader.java @@ -0,0 +1,94 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util; + +import android.content.Context; +import android.os.Build; +import android.util.Log; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.zip.ZipEntry; +import java.util.zip.ZipFile; + +public final class SharedLibraryLoader { + private static final String TAG = "WireGuard/" + SharedLibraryLoader.class.getSimpleName(); + + private SharedLibraryLoader() { + } + + public static boolean extractLibrary(final Context context, final String libName, final File destination) throws IOException { + final Collection<String> apks = new HashSet<>(); + if (context.getApplicationInfo().sourceDir != null) + apks.add(context.getApplicationInfo().sourceDir); + if (context.getApplicationInfo().splitSourceDirs != null) + apks.addAll(Arrays.asList(context.getApplicationInfo().splitSourceDirs)); + + for (final String abi : Build.SUPPORTED_ABIS) { + for (final String apk : apks) { + final ZipFile zipFile; + try { + zipFile = new ZipFile(new File(apk), ZipFile.OPEN_READ); + } catch (final IOException e) { + throw new RuntimeException(e); + } + + final String mappedLibName = System.mapLibraryName(libName); + final byte[] buffer = new byte[1024 * 32]; + final String libZipPath = "lib" + File.separatorChar + abi + File.separatorChar + mappedLibName; + final ZipEntry zipEntry = zipFile.getEntry(libZipPath); + if (zipEntry == null) + continue; + Log.d(TAG, "Extracting apk:/" + libZipPath + " to " + destination.getAbsolutePath()); + try (final FileOutputStream out = new FileOutputStream(destination); + final InputStream in = zipFile.getInputStream(zipEntry)) { + int len; + while ((len = in.read(buffer)) != -1) { + out.write(buffer, 0, len); + } + out.getFD().sync(); + } + zipFile.close(); + return true; + } + } + return false; + } + + public static void loadSharedLibrary(final Context context, final String libName) { + Throwable noAbiException; + try { + System.loadLibrary(libName); + return; + } catch (final UnsatisfiedLinkError e) { + Log.d(TAG, "Failed to load library normally, so attempting to extract from apk", e); + noAbiException = e; + } + File f = null; + try { + f = File.createTempFile("lib", ".so", context.getCodeCacheDir()); + if (extractLibrary(context, libName, f)) { + System.load(f.getAbsolutePath()); + return; + } + } catch (final Exception e) { + Log.d(TAG, "Failed to load library apk:/" + libName, e); + noAbiException = e; + } finally { + if (f != null) + // noinspection ResultOfMethodCallIgnored + f.delete(); + } + if (noAbiException instanceof RuntimeException) + throw (RuntimeException) noAbiException; + throw new RuntimeException(noAbiException); + } +} diff --git a/tunnel/src/main/java/com/wireguard/android/util/ToolsInstaller.java b/tunnel/src/main/java/com/wireguard/android/util/ToolsInstaller.java new file mode 100644 index 00000000..ac18cabf --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/android/util/ToolsInstaller.java @@ -0,0 +1,196 @@ +/* + * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util; + +import android.content.Context; +import androidx.annotation.Nullable; +import android.system.OsConstants; +import android.util.Log; + +import com.wireguard.android.util.RootShell.RootShellException; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +/** + * Helper to install WireGuard tools to the system partition. + */ + +public final class ToolsInstaller { + public static final int ERROR = 0x0; + public static final int MAGISK = 0x4; + public static final int NO = 0x2; + public static final int SYSTEM = 0x8; + public static final int YES = 0x1; + private static final String[] EXECUTABLES = {"wg", "wg-quick"}; + private static final File[] INSTALL_DIRS = { + new File("/system/xbin"), + new File("/system/bin"), + }; + @Nullable private static final File INSTALL_DIR = getInstallDir(); + private static final String TAG = "WireGuard/" + ToolsInstaller.class.getSimpleName(); + + private final Context context; + private final RootShell rootShell; + private final File localBinaryDir; + private final Object lock = new Object(); + @Nullable private Boolean areToolsAvailable; + @Nullable private Boolean installAsMagiskModule; + + public ToolsInstaller(final Context context, final RootShell rootShell) { + localBinaryDir = new File(context.getCodeCacheDir(), "bin"); + this.context = context; + this.rootShell = rootShell; + } + + @Nullable + private static File getInstallDir() { + final String path = System.getenv("PATH"); + if (path == null) + return INSTALL_DIRS[0]; + final List<String> paths = Arrays.asList(path.split(":")); + for (final File dir : INSTALL_DIRS) { + if (paths.contains(dir.getPath()) && dir.isDirectory()) + return dir; + } + return null; + } + + public int areInstalled() throws RootShellException { + if (INSTALL_DIR == null) + return ERROR; + final StringBuilder script = new StringBuilder(); + for (final String name : EXECUTABLES) { + script.append(String.format("cmp -s '%s' '%s' && ", + new File(localBinaryDir, name).getAbsolutePath(), + new File(INSTALL_DIR, name).getAbsolutePath())); + } + script.append("exit ").append(OsConstants.EALREADY).append(';'); + try { + final int ret = rootShell.run(null, script.toString()); + if (ret == OsConstants.EALREADY) + return willInstallAsMagiskModule() ? YES | MAGISK : YES | SYSTEM; + else + return willInstallAsMagiskModule() ? NO | MAGISK : NO | SYSTEM; + } catch (final IOException ignored) { + return ERROR; + } catch (final RootShellException e) { + if (e.isIORelated()) + return ERROR; + throw e; + } + } + + public void ensureToolsAvailable() throws FileNotFoundException { + synchronized (lock) { + if (areToolsAvailable == null) { + try { + Log.d(TAG, extract() ? "Tools are now extracted into our private binary dir" : + "Tools were already extracted into our private binary dir"); + areToolsAvailable = true; + } catch (final IOException e) { + Log.e(TAG, "The wg and wg-quick tools are not available", e); + areToolsAvailable = false; + } + } + if (!areToolsAvailable) + throw new FileNotFoundException("Required tools unavailable"); + } + } + + public int install() throws RootShellException, IOException { + if (!context.getPackageName().startsWith("com.wireguard.")) + throw new SecurityException("The tools may only be installed system-wide from the main WireGuard app."); + return willInstallAsMagiskModule() ? installMagisk() : installSystem(); + } + + private int installMagisk() throws RootShellException, IOException { + extract(); + final StringBuilder script = new StringBuilder("set -ex; "); + + script.append("trap 'rm -rf /sbin/.magisk/img/wireguard' INT TERM EXIT; "); + script.append(String.format("rm -rf /sbin/.magisk/img/wireguard/; mkdir -p /sbin/.magisk/img/wireguard%s; ", INSTALL_DIR)); + script.append("printf 'name=WireGuard Command Line Tools\nversion=1.0\nversionCode=1\nauthor=zx2c4\ndescription=Command line tools for WireGuard\nminMagisk=1500\n' > /sbin/.magisk/img/wireguard/module.prop; "); + script.append("touch /sbin/.magisk/img/wireguard/auto_mount; "); + for (final String name : EXECUTABLES) { + final File destination = new File("/sbin/.magisk/img/wireguard" + INSTALL_DIR, name); + script.append(String.format("cp '%s' '%s'; chmod 755 '%s'; chcon 'u:object_r:system_file:s0' '%s' || true; ", + new File(localBinaryDir, name), destination, destination, destination)); + } + script.append("trap - INT TERM EXIT;"); + + try { + return rootShell.run(null, script.toString()) == 0 ? YES | MAGISK : ERROR; + } catch (final IOException ignored) { + return ERROR; + } catch (final RootShellException e) { + if (e.isIORelated()) + return ERROR; + throw e; + } + } + + private int installSystem() throws RootShellException, IOException { + if (INSTALL_DIR == null) + return OsConstants.ENOENT; + extract(); + final StringBuilder script = new StringBuilder("set -ex; "); + script.append("trap 'mount -o ro,remount /system' EXIT; mount -o rw,remount /system; "); + for (final String name : EXECUTABLES) { + final File destination = new File(INSTALL_DIR, name); + script.append(String.format("cp '%s' '%s'; chmod 755 '%s'; restorecon '%s' || true; ", + new File(localBinaryDir, name), destination, destination, destination)); + } + try { + return rootShell.run(null, script.toString()) == 0 ? YES | SYSTEM : ERROR; + } catch (final IOException ignored) { + return ERROR; + } catch (final RootShellException e) { + if (e.isIORelated()) + return ERROR; + throw e; + } + } + + public boolean extract() throws IOException { + localBinaryDir.mkdirs(); + final File files[] = new File[EXECUTABLES.length]; + final File tempFiles[] = new File[EXECUTABLES.length]; + boolean allExist = true; + for (int i = 0; i < files.length; ++i) { + files[i] = new File(localBinaryDir, EXECUTABLES[i]); + tempFiles[i] = new File(localBinaryDir, EXECUTABLES[i] + ".tmp"); + allExist &= files[i].exists(); + } + if (allExist) + return false; + for (int i = 0; i < files.length; ++i) { + if (!SharedLibraryLoader.extractLibrary(context, EXECUTABLES[i], tempFiles[i])) + throw new FileNotFoundException("Unable to find " + EXECUTABLES[i]); + if (!tempFiles[i].setExecutable(true, false)) + throw new IOException("Unable to mark " + tempFiles[i].getAbsolutePath() + " as executable"); + if (!tempFiles[i].renameTo(files[i])) + throw new IOException("Unable to rename " + tempFiles[i].getAbsolutePath() + " to " + files[i].getAbsolutePath()); + } + return true; + } + + private boolean willInstallAsMagiskModule() { + synchronized (lock) { + if (installAsMagiskModule == null) { + try { + installAsMagiskModule = rootShell.run(null, "[ -d /sbin/.magisk/mirror -a -d /sbin/.magisk/img -a ! -f /cache/.disable_magisk ]") == OsConstants.EXIT_SUCCESS; + } catch (final Exception ignored) { + installAsMagiskModule = false; + } + } + return installAsMagiskModule; + } + } +} |