diff options
author | Samuel Holland <samuel@sholland.org> | 2018-09-05 20:17:14 -0500 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-12-08 02:39:41 +0100 |
commit | daba6506567ffc5f88c3a2a3cf8c009b2a9a7a6d (patch) | |
tree | d95ad1ae84d02fc3e18a211aa1e1ef8150d8fa35 /app/src/main/java/com/wireguard | |
parent | 1d44e27caee1064096e2505b93b1e3164a5039c5 (diff) |
Remodel the Model
- The configuration and crypto model is now entirely independent
of Android classes other than Nullable and TextUtils.
- Model classes are immutable and use builders that enforce the
appropriate optional/required attributes.
- The Android config proxies (for Parcelable and databinding) are
moved to the Android side of the codebase, and are designed to be
safe for two-way databinding. This allows proper observability in
TunnelDetailFragment.
- Various robustness fixes and documentation updates to helper classes.
Diffstat (limited to 'app/src/main/java/com/wireguard')
36 files changed, 1987 insertions, 1243 deletions
diff --git a/app/src/main/java/com/wireguard/android/backend/GoBackend.java b/app/src/main/java/com/wireguard/android/backend/GoBackend.java index 295df9d0..97cf0f8e 100644 --- a/app/src/main/java/com/wireguard/android/backend/GoBackend.java +++ b/app/src/main/java/com/wireguard/android/backend/GoBackend.java @@ -22,13 +22,10 @@ import com.wireguard.android.util.ExceptionLoggers; import com.wireguard.android.util.SharedLibraryLoader; import com.wireguard.config.Config; import com.wireguard.config.InetNetwork; -import com.wireguard.config.Interface; import com.wireguard.config.Peer; -import com.wireguard.crypto.KeyEncoding; import java.net.InetAddress; import java.util.Collections; -import java.util.Formatter; import java.util.Objects; import java.util.Set; import java.util.concurrent.TimeUnit; @@ -146,29 +143,7 @@ public final class GoBackend implements Backend { } // Build config - final Interface iface = config.getInterface(); - final String goConfig; - try (final Formatter fmt = new Formatter(new StringBuilder())) { - fmt.format("replace_peers=true\n"); - if (iface.getPrivateKey() != null) - fmt.format("private_key=%s\n", KeyEncoding.keyToHex(KeyEncoding.keyFromBase64(iface.getPrivateKey()))); - if (iface.getListenPort() != 0) - fmt.format("listen_port=%d\n", config.getInterface().getListenPort()); - for (final Peer peer : config.getPeers()) { - if (peer.getPublicKey() != null) - fmt.format("public_key=%s\n", KeyEncoding.keyToHex(KeyEncoding.keyFromBase64(peer.getPublicKey()))); - if (peer.getPreSharedKey() != null) - fmt.format("preshared_key=%s\n", KeyEncoding.keyToHex(KeyEncoding.keyFromBase64(peer.getPreSharedKey()))); - if (peer.getEndpoint() != null) - fmt.format("endpoint=%s\n", peer.getResolvedEndpointString()); - if (peer.getPersistentKeepalive() != 0) - fmt.format("persistent_keepalive_interval=%d\n", peer.getPersistentKeepalive()); - for (final InetNetwork addr : peer.getAllowedIPs()) { - fmt.format("allowed_ip=%s\n", addr.toString()); - } - } - goConfig = fmt.toString(); - } + final String goConfig = config.toWgUserspaceString(); // Create the vpn tunnel with android API final VpnService.Builder builder = service.getBuilder(); @@ -184,18 +159,15 @@ public final class GoBackend implements Backend { for (final InetNetwork addr : config.getInterface().getAddresses()) builder.addAddress(addr.getAddress(), addr.getMask()); - for (final InetAddress addr : config.getInterface().getDnses()) + for (final InetAddress addr : config.getInterface().getDnsServers()) builder.addDnsServer(addr.getHostAddress()); for (final Peer peer : config.getPeers()) { - for (final InetNetwork addr : peer.getAllowedIPs()) + for (final InetNetwork addr : peer.getAllowedIps()) builder.addRoute(addr.getAddress(), addr.getMask()); } - int mtu = config.getInterface().getMtu(); - if (mtu == 0) - mtu = 1280; - builder.setMtu(mtu); + builder.setMtu(config.getInterface().getMtu().orElse(1280)); builder.setBlocking(true); try (final ParcelFileDescriptor tun = builder.establish()) { diff --git a/app/src/main/java/com/wireguard/android/backend/WgQuickBackend.java b/app/src/main/java/com/wireguard/android/backend/WgQuickBackend.java index bfc363a4..68799057 100644 --- a/app/src/main/java/com/wireguard/android/backend/WgQuickBackend.java +++ b/app/src/main/java/com/wireguard/android/backend/WgQuickBackend.java @@ -114,7 +114,7 @@ public final class WgQuickBackend implements Backend { final File tempFile = new File(localTemporaryDir, tunnel.getName() + ".conf"); try (final FileOutputStream stream = new FileOutputStream(tempFile, false)) { - stream.write(config.toString().getBytes(StandardCharsets.UTF_8)); + stream.write(config.toWgQuickString().getBytes(StandardCharsets.UTF_8)); } String command = String.format("wg-quick %s '%s'", state.toString().toLowerCase(), tempFile.getAbsolutePath()); diff --git a/app/src/main/java/com/wireguard/android/configStore/FileConfigStore.java b/app/src/main/java/com/wireguard/android/configStore/FileConfigStore.java index 0e66dab8..654cb48f 100644 --- a/app/src/main/java/com/wireguard/android/configStore/FileConfigStore.java +++ b/app/src/main/java/com/wireguard/android/configStore/FileConfigStore.java @@ -9,6 +9,7 @@ import android.content.Context; import android.util.Log; import com.wireguard.config.Config; +import com.wireguard.config.ParseException; import java.io.File; import java.io.FileInputStream; @@ -41,7 +42,7 @@ public final class FileConfigStore implements ConfigStore { if (!file.createNewFile()) throw new IOException("Configuration file " + file.getName() + " already exists"); try (final FileOutputStream stream = new FileOutputStream(file, false)) { - stream.write(config.toString().getBytes(StandardCharsets.UTF_8)); + stream.write(config.toWgQuickString().getBytes(StandardCharsets.UTF_8)); } return config; } @@ -67,9 +68,9 @@ public final class FileConfigStore implements ConfigStore { } @Override - public Config load(final String name) throws IOException { + public Config load(final String name) throws IOException, ParseException { try (final FileInputStream stream = new FileInputStream(fileFor(name))) { - return Config.from(stream); + return Config.parse(stream); } } @@ -94,7 +95,7 @@ public final class FileConfigStore implements ConfigStore { if (!file.isFile()) throw new FileNotFoundException("Configuration file " + file.getName() + " not found"); try (final FileOutputStream stream = new FileOutputStream(file, false)) { - stream.write(config.toString().getBytes(StandardCharsets.UTF_8)); + stream.write(config.toWgQuickString().getBytes(StandardCharsets.UTF_8)); } return config; } diff --git a/app/src/main/java/com/wireguard/android/databinding/BindingAdapters.java b/app/src/main/java/com/wireguard/android/databinding/BindingAdapters.java index 629f99e5..fe01bf10 100644 --- a/app/src/main/java/com/wireguard/android/databinding/BindingAdapters.java +++ b/app/src/main/java/com/wireguard/android/databinding/BindingAdapters.java @@ -6,21 +6,30 @@ package com.wireguard.android.databinding; import android.databinding.BindingAdapter; +import android.databinding.DataBindingUtil; import android.databinding.ObservableList; +import android.databinding.ViewDataBinding; import android.databinding.adapters.ListenerUtil; +import android.support.annotation.Nullable; import android.support.v7.widget.LinearLayoutManager; import android.support.v7.widget.RecyclerView; import android.text.InputFilter; +import android.view.LayoutInflater; import android.widget.LinearLayout; import android.widget.TextView; +import com.wireguard.android.BR; import com.wireguard.android.R; import com.wireguard.android.databinding.ObservableKeyedRecyclerViewAdapter.RowConfigurationHandler; import com.wireguard.android.util.ObservableKeyedList; import com.wireguard.android.widget.ToggleSwitch; import com.wireguard.android.widget.ToggleSwitch.OnBeforeCheckedChangeListener; +import com.wireguard.config.Attribute; +import com.wireguard.config.InetNetwork; import com.wireguard.util.Keyed; +import java9.util.Optional; + /** * Static methods for use by generated code in the Android data binding library. */ @@ -42,9 +51,10 @@ public final class BindingAdapters { } @BindingAdapter({"items", "layout"}) - public static <E> void setItems(final LinearLayout view, - final ObservableList<E> oldList, final int oldLayoutId, - final ObservableList<E> newList, final int newLayoutId) { + public static <E> + void setItems(final LinearLayout view, + @Nullable final ObservableList<E> oldList, final int oldLayoutId, + @Nullable final ObservableList<E> newList, final int newLayoutId) { if (oldList == newList && oldLayoutId == newLayoutId) return; ItemChangeListener<E> listener = ListenerUtil.getListener(view, R.id.item_change_listener); @@ -66,11 +76,34 @@ public final class BindingAdapters { listener.setList(newList); } + @BindingAdapter({"items", "layout"}) + public static <E> + void setItems(final LinearLayout view, + @Nullable final Iterable<E> oldList, final int oldLayoutId, + @Nullable final Iterable<E> newList, final int newLayoutId) { + if (oldList == newList && oldLayoutId == newLayoutId) + return; + view.removeAllViews(); + if (newList == null) + return; + final LayoutInflater layoutInflater = LayoutInflater.from(view.getContext()); + for (final E item : newList) { + final ViewDataBinding binding = + DataBindingUtil.inflate(layoutInflater, newLayoutId, view, false); + binding.setVariable(BR.collection, newList); + binding.setVariable(BR.item, item); + binding.executePendingBindings(); + view.addView(binding.getRoot()); + } + } + @BindingAdapter(requireAll = false, value = {"items", "layout", "configurationHandler"}) public static <K, E extends Keyed<? extends K>> void setItems(final RecyclerView view, - final ObservableKeyedList<K, E> oldList, final int oldLayoutId, final RowConfigurationHandler oldRowConfigurationHandler, - final ObservableKeyedList<K, E> newList, final int newLayoutId, final RowConfigurationHandler newRowConfigurationHandler) { + @Nullable final ObservableKeyedList<K, E> oldList, final int oldLayoutId, + final RowConfigurationHandler oldRowConfigurationHandler, + @Nullable final ObservableKeyedList<K, E> newList, final int newLayoutId, + final RowConfigurationHandler newRowConfigurationHandler) { if (view.getLayoutManager() == null) view.setLayoutManager(new LinearLayoutManager(view.getContext(), RecyclerView.VERTICAL, false)); @@ -103,4 +136,13 @@ public final class BindingAdapters { view.setOnBeforeCheckedChangeListener(listener); } + @BindingAdapter("android:text") + public static void setText(final TextView view, final Optional<?> text) { + view.setText(text.map(Object::toString).orElse("")); + } + + @BindingAdapter("android:text") + public static void setText(final TextView view, @Nullable final Iterable<InetNetwork> networks) { + view.setText(networks != null ? Attribute.join(networks) : ""); + } } diff --git a/app/src/main/java/com/wireguard/android/databinding/ObservableKeyedRecyclerViewAdapter.java b/app/src/main/java/com/wireguard/android/databinding/ObservableKeyedRecyclerViewAdapter.java index 26e7687f..5bfa64f1 100644 --- a/app/src/main/java/com/wireguard/android/databinding/ObservableKeyedRecyclerViewAdapter.java +++ b/app/src/main/java/com/wireguard/android/databinding/ObservableKeyedRecyclerViewAdapter.java @@ -73,7 +73,7 @@ public class ObservableKeyedRecyclerViewAdapter<K, E extends Keyed<? extends K>> holder.binding.executePendingBindings(); if (rowConfigurationHandler != null) { - E item = getItem(position); + final E item = getItem(position); if (item != null) { rowConfigurationHandler.onConfigureRow(holder.binding, item, position); } diff --git a/app/src/main/java/com/wireguard/android/fragment/AppListDialogFragment.java b/app/src/main/java/com/wireguard/android/fragment/AppListDialogFragment.java index 8bf5a22d..20633c3e 100644 --- a/app/src/main/java/com/wireguard/android/fragment/AppListDialogFragment.java +++ b/app/src/main/java/com/wireguard/android/fragment/AppListDialogFragment.java @@ -27,19 +27,21 @@ import com.wireguard.android.util.ObservableKeyedArrayList; import com.wireguard.android.util.ObservableKeyedList; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; +import java9.util.Comparators; + public class AppListDialogFragment extends DialogFragment { private static final String KEY_EXCLUDED_APPS = "excludedApps"; private final ObservableKeyedList<String, ApplicationData> appData = new ObservableKeyedArrayList<>(); - private List<String> currentlyExcludedApps; + @Nullable private List<String> currentlyExcludedApps; - public static <T extends Fragment & AppExclusionListener> AppListDialogFragment newInstance(final String[] excludedApps, final T target) { + public static <T extends Fragment & AppExclusionListener> + AppListDialogFragment newInstance(final ArrayList<String> excludedApps, final T target) { final Bundle extras = new Bundle(); - extras.putStringArray(KEY_EXCLUDED_APPS, excludedApps); + extras.putStringArrayList(KEY_EXCLUDED_APPS, excludedApps); final AppListDialogFragment fragment = new AppListDialogFragment(); fragment.setTargetFragment(target, 0); fragment.setArguments(extras); @@ -64,7 +66,7 @@ public class AppListDialogFragment extends DialogFragment { appData.add(new ApplicationData(resolveInfo.loadIcon(pm), resolveInfo.loadLabel(pm).toString(), packageName, currentlyExcludedApps.contains(packageName))); } - Collections.sort(appData, (lhs, rhs) -> lhs.getName().toLowerCase().compareTo(rhs.getName().toLowerCase())); + Collections.sort(appData, Comparators.comparing(ApplicationData::getName, String.CASE_INSENSITIVE_ORDER)); return appData; }).whenComplete(((data, throwable) -> { if (data != null) { @@ -82,12 +84,11 @@ public class AppListDialogFragment extends DialogFragment { @Override public void onCreate(@Nullable final Bundle savedInstanceState) { super.onCreate(savedInstanceState); - - currentlyExcludedApps = Arrays.asList(getArguments().getStringArray(KEY_EXCLUDED_APPS)); + currentlyExcludedApps = getArguments().getStringArrayList(KEY_EXCLUDED_APPS); } @Override - public Dialog onCreateDialog(final Bundle savedInstanceState) { + public Dialog onCreateDialog(@Nullable final Bundle savedInstanceState) { final AlertDialog.Builder alertDialogBuilder = new AlertDialog.Builder(getActivity()); alertDialogBuilder.setTitle(R.string.excluded_applications); diff --git a/app/src/main/java/com/wireguard/android/fragment/ConfigNamingDialogFragment.java b/app/src/main/java/com/wireguard/android/fragment/ConfigNamingDialogFragment.java index 83799818..0931868e 100644 --- a/app/src/main/java/com/wireguard/android/fragment/ConfigNamingDialogFragment.java +++ b/app/src/main/java/com/wireguard/android/fragment/ConfigNamingDialogFragment.java @@ -19,8 +19,11 @@ import com.wireguard.android.Application; import com.wireguard.android.R; import com.wireguard.android.databinding.ConfigNamingDialogFragmentBinding; import com.wireguard.config.Config; +import com.wireguard.config.ParseException; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.Objects; public class ConfigNamingDialogFragment extends DialogFragment { @@ -63,8 +66,8 @@ public class ConfigNamingDialogFragment extends DialogFragment { super.onCreate(savedInstanceState); try { - config = Config.from(getArguments().getString(KEY_CONFIG_TEXT)); - } catch (final IOException exception) { + config = Config.parse(new ByteArrayInputStream(getArguments().getString(KEY_CONFIG_TEXT).getBytes(StandardCharsets.UTF_8))); + } catch (final IOException | ParseException exception) { throw new RuntimeException("Invalid config passed to " + getClass().getSimpleName(), exception); } } diff --git a/app/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.java b/app/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.java index fcc601f3..b4e7202f 100644 --- a/app/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.java +++ b/app/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.java @@ -16,7 +16,6 @@ import android.view.ViewGroup; import com.wireguard.android.R; import com.wireguard.android.databinding.TunnelDetailFragmentBinding; import com.wireguard.android.model.Tunnel; -import com.wireguard.config.Config; /** * Fragment that shows details about a specific tunnel. @@ -25,12 +24,6 @@ import com.wireguard.config.Config; public class TunnelDetailFragment extends BaseFragment { @Nullable private TunnelDetailFragmentBinding binding; - private void onConfigLoaded(final String name, final Config config) { - if (binding != null) { - binding.setConfig(new Config.Observable(config, name)); - } - } - @Override public void onCreate(@Nullable final Bundle savedInstanceState) { super.onCreate(savedInstanceState); @@ -65,7 +58,7 @@ public class TunnelDetailFragment extends BaseFragment { if (newTunnel == null) binding.setConfig(null); else - newTunnel.getConfigAsync().thenAccept(a -> onConfigLoaded(newTunnel.getName(), a)); + newTunnel.getConfigAsync().thenAccept(binding::setConfig); } @Override diff --git a/app/src/main/java/com/wireguard/android/fragment/TunnelEditorFragment.java b/app/src/main/java/com/wireguard/android/fragment/TunnelEditorFragment.java index 8f319e1e..f1250e64 100644 --- a/app/src/main/java/com/wireguard/android/fragment/TunnelEditorFragment.java +++ b/app/src/main/java/com/wireguard/android/fragment/TunnelEditorFragment.java @@ -7,7 +7,6 @@ package com.wireguard.android.fragment; import android.app.Activity; import android.content.Context; -import android.databinding.Observable; import android.databinding.ObservableList; import android.os.Bundle; import android.support.annotation.Nullable; @@ -24,19 +23,16 @@ import android.view.inputmethod.InputMethodManager; import android.widget.Toast; import com.wireguard.android.Application; -import com.wireguard.android.BR; import com.wireguard.android.R; import com.wireguard.android.databinding.TunnelEditorFragmentBinding; import com.wireguard.android.fragment.AppListDialogFragment.AppExclusionListener; import com.wireguard.android.model.Tunnel; import com.wireguard.android.model.TunnelManager; import com.wireguard.android.util.ExceptionLoggers; -import com.wireguard.config.Attribute; +import com.wireguard.android.viewmodel.ConfigProxy; import com.wireguard.config.Config; -import com.wireguard.config.Peer; import java.util.ArrayList; -import java.util.Collection; import java.util.List; import java.util.Objects; @@ -48,64 +44,13 @@ public class TunnelEditorFragment extends BaseFragment implements AppExclusionLi private static final String KEY_LOCAL_CONFIG = "local_config"; private static final String KEY_ORIGINAL_NAME = "original_name"; private static final String TAG = "WireGuard/" + TunnelEditorFragment.class.getSimpleName(); - private final Collection<Object> breakObjectOrientedLayeringHandlerReceivers = new ArrayList<>(); - @Nullable private TunnelEditorFragmentBinding binding; - private final Observable.OnPropertyChangedCallback breakObjectOrientedLayeringHandler = new Observable.OnPropertyChangedCallback() { - @Override - public void onPropertyChanged(final Observable sender, final int propertyId) { - if (binding == null) - return; - final Config.Observable config = binding.getConfig(); - if (config == null) - return; - if (propertyId == BR.config) { - config.addOnPropertyChangedCallback(breakObjectOrientedLayeringHandler); - breakObjectOrientedLayeringHandlerReceivers.add(config); - config.getInterfaceSection().addOnPropertyChangedCallback(breakObjectOrientedLayeringHandler); - breakObjectOrientedLayeringHandlerReceivers.add(config.getInterfaceSection()); - config.getPeers().addOnListChangedCallback(breakObjectListOrientedLayeringHandler); - breakObjectOrientedLayeringHandlerReceivers.add(config.getPeers()); - } else if (propertyId == BR.dnses || propertyId == BR.peers) - ; - else - return; - final int numSiblings = config.getPeers().size() - 1; - for (final Peer.Observable peer : config.getPeers()) { - peer.setInterfaceDNSRoutes(config.getInterfaceSection().getDnses()); - peer.setNumSiblings(numSiblings); - } - } - }; - private final ObservableList.OnListChangedCallback<? extends ObservableList<Peer.Observable>> breakObjectListOrientedLayeringHandler = new ObservableList.OnListChangedCallback<ObservableList<Peer.Observable>>() { - @Override - public void onChanged(final ObservableList<Peer.Observable> sender) { - } - - @Override - public void onItemRangeChanged(final ObservableList<Peer.Observable> sender, final int positionStart, final int itemCount) { - } - - @Override - public void onItemRangeInserted(final ObservableList<Peer.Observable> sender, final int positionStart, final int itemCount) { - if (binding != null) - breakObjectOrientedLayeringHandler.onPropertyChanged(binding.getConfig(), BR.peers); - } - @Override - public void onItemRangeMoved(final ObservableList<Peer.Observable> sender, final int fromPosition, final int toPosition, final int itemCount) { - } - - @Override - public void onItemRangeRemoved(final ObservableList<Peer.Observable> sender, final int positionStart, final int itemCount) { - if (binding != null) - breakObjectOrientedLayeringHandler.onPropertyChanged(binding.getConfig(), BR.peers); - } - }; + @Nullable private TunnelEditorFragmentBinding binding; @Nullable private Tunnel tunnel; - private void onConfigLoaded(final String name, final Config config) { + private void onConfigLoaded(final Config config) { if (binding != null) { - binding.setConfig(new Config.Observable(config, name)); + binding.setConfig(new ConfigProxy(config)); } } @@ -143,29 +88,23 @@ public class TunnelEditorFragment extends BaseFragment implements AppExclusionLi @Nullable final Bundle savedInstanceState) { super.onCreateView(inflater, container, savedInstanceState); binding = TunnelEditorFragmentBinding.inflate(inflater, container, false); - binding.addOnPropertyChangedCallback(breakObjectOrientedLayeringHandler); - breakObjectOrientedLayeringHandlerReceivers.add(binding); binding.executePendingBindings(); return binding.getRoot(); } - @SuppressWarnings("unchecked") @Override public void onDestroyView() { binding = null; - for (final Object o : breakObjectOrientedLayeringHandlerReceivers) { - if (o instanceof Observable) - ((Observable) o).removeOnPropertyChangedCallback(breakObjectOrientedLayeringHandler); - else if (o instanceof ObservableList) - ((ObservableList) o).removeOnListChangedCallback(breakObjectListOrientedLayeringHandler); - } super.onDestroyView(); } @Override public void onExcludedAppsSelected(final List<String> excludedApps) { Objects.requireNonNull(binding, "Tried to set excluded apps while no view was loaded"); - binding.getConfig().getInterfaceSection().setExcludedApplications(Attribute.iterableToString(excludedApps)); + final ObservableList<String> excludedApplications = + binding.getConfig().getInterface().getExcludedApplications(); + excludedApplications.clear(); + excludedApplications.addAll(excludedApps); } private void onFinished() { @@ -195,25 +134,27 @@ public class TunnelEditorFragment extends BaseFragment implements AppExclusionLi public boolean onOptionsItemSelected(final MenuItem item) { switch (item.getItemId()) { case R.id.menu_action_save: - final Config newConfig = new Config(); + if (binding == null) + return false; + final Config newConfig; try { - binding.getConfig().commitData(newConfig); + newConfig = binding.getConfig().resolve(); } catch (final Exception e) { final String error = ExceptionLoggers.unwrapMessage(e); - final String tunnelName = tunnel == null ? binding.getConfig().getName() : tunnel.getName(); + final String tunnelName = tunnel == null ? binding.getName() : tunnel.getName(); final String message = getString(R.string.config_save_error, tunnelName, error); Log.e(TAG, message, e); Snackbar.make(binding.mainContainer, error, Snackbar.LENGTH_LONG).show(); return false; } if (tunnel == null) { - Log.d(TAG, "Attempting to create new tunnel " + binding.getConfig().getName()); + Log.d(TAG, "Attempting to create new tunnel " + binding.getName()); final TunnelManager manager = Application.getTunnelManager(); - manager.create(binding.getConfig().getName(), newConfig) + manager.create(binding.getName(), newConfig) .whenComplete(this::onTunnelCreated); - } else if (!tunnel.getName().equals(binding.getConfig().getName())) { - Log.d(TAG, "Attempting to rename tunnel to " + binding.getConfig().getName()); - tunnel.setName(binding.getConfig().getName()) + } else if (!tunnel.getName().equals(binding.getName())) { + Log.d(TAG, "Attempting to rename tunnel to " + binding.getName()); + tunnel.setName(binding.getName()) .whenComplete((a, b) -> onTunnelRenamed(tunnel, newConfig, b)); } else { Log.d(TAG, "Attempting to save config of " + tunnel.getName()); @@ -229,7 +170,7 @@ public class TunnelEditorFragment extends BaseFragment implements AppExclusionLi public void onRequestSetExcludedApplications(@SuppressWarnings("unused") final View view) { final FragmentManager fragmentManager = getFragmentManager(); if (fragmentManager != null && binding != null) { - final String[] excludedApps = Attribute.stringToList(binding.getConfig().getInterfaceSection().getExcludedApplications()); + final ArrayList<String> excludedApps = new ArrayList<>(binding.getConfig().getInterface().getExcludedApplications()); final AppListDialogFragment fragment = AppListDialogFragment.newInstance(excludedApps, this); fragment.show(fragmentManager, null); } @@ -237,19 +178,25 @@ public class TunnelEditorFragment extends BaseFragment implements AppExclusionLi @Override public void onSaveInstanceState(final Bundle outState) { - outState.putParcelable(KEY_LOCAL_CONFIG, binding.getConfig()); + if (binding != null) + outState.putParcelable(KEY_LOCAL_CONFIG, binding.getConfig()); outState.putString(KEY_ORIGINAL_NAME, tunnel == null ? null : tunnel.getName()); super.onSaveInstanceState(outState); } @Override - public void onSelectedTunnelChanged(@Nullable final Tunnel oldTunnel, @Nullable final Tunnel newTunnel) { + public void onSelectedTunnelChanged(@Nullable final Tunnel oldTunnel, + @Nullable final Tunnel newTunnel) { tunnel = newTunnel; if (binding == null) return; - binding.setConfig(new Config.Observable(null, null)); - if (tunnel != null) - tunnel.getConfigAsync().thenAccept(a -> onConfigLoaded(tunnel.getName(), a)); + binding.setConfig(new ConfigProxy()); + if (tunnel != null) { + binding.setName(tunnel.getName()); + tunnel.getConfigAsync().thenAccept(this::onConfigLoaded); + } else { + binding.setName(""); + } } private void onTunnelCreated(final Tunnel newTunnel, @Nullable final Throwable throwable) { @@ -301,7 +248,7 @@ public class TunnelEditorFragment extends BaseFragment implements AppExclusionLi onSelectedTunnelChanged(null, getSelectedTunnel()); } else { tunnel = getSelectedTunnel(); - final Config.Observable config = savedInstanceState.getParcelable(KEY_LOCAL_CONFIG); + final ConfigProxy config = savedInstanceState.getParcelable(KEY_LOCAL_CONFIG); final String originalName = savedInstanceState.getString(KEY_ORIGINAL_NAME); if (tunnel != null && !tunnel.getName().equals(originalName)) onSelectedTunnelChanged(null, tunnel); diff --git a/app/src/main/java/com/wireguard/android/fragment/TunnelListFragment.java b/app/src/main/java/com/wireguard/android/fragment/TunnelListFragment.java index 7509e40c..783c0a29 100644 --- a/app/src/main/java/com/wireguard/android/fragment/TunnelListFragment.java +++ b/app/src/main/java/com/wireguard/android/fragment/TunnelListFragment.java @@ -41,9 +41,10 @@ import com.wireguard.android.util.ExceptionLoggers; import com.wireguard.android.widget.MultiselectableRelativeLayout; import com.wireguard.android.widget.fab.FloatingActionsMenuRecyclerViewScrollListener; import com.wireguard.config.Config; +import com.wireguard.config.ParseException; -import java.io.BufferedReader; -import java.io.InputStreamReader; +import java.io.ByteArrayInputStream; +import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; @@ -79,13 +80,13 @@ public class TunnelListFragment extends BaseFragment { private void importTunnel(@NonNull final String configText) { try { // Ensure the config text is parseable before proceeding… - Config.from(configText); + Config.parse(new ByteArrayInputStream(configText.getBytes(StandardCharsets.UTF_8))); // Config text is valid, now create the tunnel… final FragmentManager fragmentManager = getFragmentManager(); if (fragmentManager != null) ConfigNamingDialogFragment.newInstance(configText).show(fragmentManager, null); - } catch (final Exception exception) { + } catch (final IllegalArgumentException | IOException | ParseException exception) { onTunnelImportFinished(Collections.emptyList(), Collections.singletonList(exception)); } } @@ -122,7 +123,6 @@ public class TunnelListFragment extends BaseFragment { if (isZip) { try (ZipInputStream zip = new ZipInputStream(contentResolver.openInputStream(uri))) { - BufferedReader reader = new BufferedReader(new InputStreamReader(zip, StandardCharsets.UTF_8)); ZipEntry entry; while ((entry = zip.getNextEntry()) != null) { if (entry.isDirectory()) @@ -140,7 +140,7 @@ public class TunnelListFragment extends BaseFragment { continue; Config config = null; try { - config = Config.from(reader); + config = Config.parse(zip); } catch (Exception e) { throwables.add(e); } @@ -150,7 +150,7 @@ public class TunnelListFragment extends BaseFragment { } } else { futureTunnels.add(Application.getTunnelManager().create(name, - Config.from(contentResolver.openInputStream(uri))).toCompletableFuture()); + Config.parse(contentResolver.openInputStream(uri))).toCompletableFuture()); } if (futureTunnels.isEmpty()) { diff --git a/app/src/main/java/com/wireguard/android/model/ApplicationData.java b/app/src/main/java/com/wireguard/android/model/ApplicationData.java index efe1ef87..f7c335de 100644 --- a/app/src/main/java/com/wireguard/android/model/ApplicationData.java +++ b/app/src/main/java/com/wireguard/android/model/ApplicationData.java @@ -13,7 +13,6 @@ import com.wireguard.android.BR; import com.wireguard.util.Keyed; public class ApplicationData extends BaseObservable implements Keyed<String> { - private final Drawable icon; private final String name; private final String packageName; diff --git a/app/src/main/java/com/wireguard/android/model/Tunnel.java b/app/src/main/java/com/wireguard/android/model/Tunnel.java index 6d37e009..9092b288 100644 --- a/app/src/main/java/com/wireguard/android/model/Tunnel.java +++ b/app/src/main/java/com/wireguard/android/model/Tunnel.java @@ -49,7 +49,8 @@ public class Tunnel extends BaseObservable implements Keyed<String> { return manager.delete(this); } - @Bindable @Nullable + @Bindable + @Nullable public Config getConfig() { if (config == null) manager.getTunnelConfig(this).whenComplete(ExceptionLoggers.E); @@ -81,7 +82,8 @@ public class Tunnel extends BaseObservable implements Keyed<String> { return TunnelManager.getTunnelState(this); } - @Bindable @Nullable + @Bindable + @Nullable public Statistics getStatistics() { // FIXME: Check age of statistics. if (statistics == null) diff --git a/app/src/main/java/com/wireguard/android/model/TunnelManager.java b/app/src/main/java/com/wireguard/android/model/TunnelManager.java index 3fd7bfc0..83df3595 100644 --- a/app/src/main/java/com/wireguard/android/model/TunnelManager.java +++ b/app/src/main/java/com/wireguard/android/model/TunnelManager.java @@ -44,6 +44,7 @@ public final class TunnelManager extends BaseObservable { private static final String KEY_LAST_USED_TUNNEL = "last_used_tunnel"; private static final String KEY_RESTORE_ON_BOOT = "restore_on_boot"; private static final String KEY_RUNNING_TUNNELS = "enabled_configs"; + private final CompletableFuture<ObservableSortedKeyedList<String, Tunnel>> completableTunnels = new CompletableFuture<>(); private final ConfigStore configStore; private final Context context = Application.get(); @@ -111,7 +112,8 @@ public final class TunnelManager extends BaseObservable { }); } - @Bindable @Nullable + @Bindable + @Nullable public Tunnel getLastUsedTunnel() { return lastUsedTunnel; } diff --git a/app/src/main/java/com/wireguard/android/preference/VersionPreference.java b/app/src/main/java/com/wireguard/android/preference/VersionPreference.java index 228facc7..1f3f5aa8 100644 --- a/app/src/main/java/com/wireguard/android/preference/VersionPreference.java +++ b/app/src/main/java/com/wireguard/android/preference/VersionPreference.java @@ -34,7 +34,8 @@ public class VersionPreference extends Preference { }); } - @Override @Nullable + @Nullable + @Override public CharSequence getSummary() { return versionSummary; } diff --git a/app/src/main/java/com/wireguard/android/util/ExceptionLoggers.java b/app/src/main/java/com/wireguard/android/util/ExceptionLoggers.java index a32e77a4..199b1fbd 100644 --- a/app/src/main/java/com/wireguard/android/util/ExceptionLoggers.java +++ b/app/src/main/java/com/wireguard/android/util/ExceptionLoggers.java @@ -5,9 +5,15 @@ package com.wireguard.android.util; +import android.content.res.Resources; import android.support.annotation.Nullable; import android.util.Log; +import com.wireguard.android.Application; +import com.wireguard.android.R; +import com.wireguard.config.ParseException; +import com.wireguard.crypto.Key; + import java9.util.concurrent.CompletionException; import java9.util.function.BiConsumer; @@ -34,12 +40,35 @@ public enum ExceptionLoggers implements BiConsumer<Object, Throwable> { return throwable; } - public static String unwrapMessage(Throwable throwable) { - throwable = unwrap(throwable); - final String message = throwable.getMessage(); - if (message != null) - return message; - return throwable.getClass().getSimpleName(); + public static String unwrapMessage(final Throwable throwable) { + final Throwable innerThrowable = unwrap(throwable); + final Resources resources = Application.get().getResources(); + String message; + if (innerThrowable instanceof ParseException) { + final ParseException parseException = (ParseException) innerThrowable; + message = resources.getString(R.string.parse_error, parseException.getText(), parseException.getContext()); + if (parseException.getMessage() != null) + message += ": " + parseException.getMessage(); + } else if (innerThrowable instanceof Key.KeyFormatException) { + final Key.KeyFormatException keyFormatException = (Key.KeyFormatException) innerThrowable; + switch (keyFormatException.getFormat()) { + case BASE64: + message = resources.getString(R.string.key_length_base64_exception_message); + break; + case BINARY: + message = resources.getString(R.string.key_length_exception_message); + break; + case HEX: + message = resources.getString(R.string.key_length_hex_exception_message); + break; + default: + // Will never happen, as getFormat is not nullable. + message = null; + } + } else { + message = throwable.getMessage(); + } + return message != null ? message : innerThrowable.getClass().getSimpleName(); } @Override diff --git a/app/src/main/java/com/wireguard/android/util/FragmentUtils.java b/app/src/main/java/com/wireguard/android/util/FragmentUtils.java index d5838a95..b7fdd095 100644 --- a/app/src/main/java/com/wireguard/android/util/FragmentUtils.java +++ b/app/src/main/java/com/wireguard/android/util/FragmentUtils.java @@ -11,7 +11,6 @@ import android.view.ContextThemeWrapper; import com.wireguard.android.activity.SettingsActivity; public final class FragmentUtils { - private FragmentUtils() { // Prevent instantiation } diff --git a/app/src/main/java/com/wireguard/android/util/ObservableKeyedArrayList.java b/app/src/main/java/com/wireguard/android/util/ObservableKeyedArrayList.java index 2ba87535..7af829fb 100644 --- a/app/src/main/java/com/wireguard/android/util/ObservableKeyedArrayList.java +++ b/app/src/main/java/com/wireguard/android/util/ObservableKeyedArrayList.java @@ -64,13 +64,15 @@ public class ObservableKeyedArrayList<K, E extends Keyed<? extends K>> return indexOfKey(key) >= 0; } - @Override @Nullable + @Nullable + @Override public E get(final K key) { final int index = indexOfKey(key); return index >= 0 ? get(index) : null; } - @Override @Nullable + @Nullable + @Override public E getLast(final K key) { final int index = lastIndexOfKey(key); return index >= 0 ? get(index) : null; diff --git a/app/src/main/java/com/wireguard/android/util/ObservableSortedKeyedArrayList.java b/app/src/main/java/com/wireguard/android/util/ObservableSortedKeyedArrayList.java index 7ef94106..d287d33d 100644 --- a/app/src/main/java/com/wireguard/android/util/ObservableSortedKeyedArrayList.java +++ b/app/src/main/java/com/wireguard/android/util/ObservableSortedKeyedArrayList.java @@ -28,8 +28,7 @@ import java.util.Spliterator; public class ObservableSortedKeyedArrayList<K, E extends Keyed<? extends K>> extends ObservableKeyedArrayList<K, E> implements ObservableSortedKeyedList<K, E> { - @Nullable - private final Comparator<? super K> comparator; + @Nullable private final Comparator<? super K> comparator; private final transient KeyList<K, E> keyList = new KeyList<>(this); @SuppressWarnings("WeakerAccess") diff --git a/app/src/main/java/com/wireguard/android/viewmodel/ConfigProxy.java b/app/src/main/java/com/wireguard/android/viewmodel/ConfigProxy.java new file mode 100644 index 00000000..abe8cbcf --- /dev/null +++ b/app/src/main/java/com/wireguard/android/viewmodel/ConfigProxy.java @@ -0,0 +1,93 @@ +/* + * Copyright © 2017-2018 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.viewmodel; + +import android.databinding.ObservableArrayList; +import android.databinding.ObservableList; +import android.os.Parcel; +import android.os.Parcelable; + +import com.wireguard.config.Config; +import com.wireguard.config.ParseException; +import com.wireguard.config.Peer; + +import java.util.ArrayList; +import java.util.Collection; + +public class ConfigProxy implements Parcelable { + public static final Parcelable.Creator<ConfigProxy> CREATOR = new ConfigProxyCreator(); + + private final InterfaceProxy interfaze; + private final ObservableList<PeerProxy> peers = new ObservableArrayList<>(); + + private ConfigProxy(final Parcel in) { + interfaze = in.readParcelable(InterfaceProxy.class.getClassLoader()); + in.readTypedList(peers, PeerProxy.CREATOR); + for (final PeerProxy proxy : peers) + proxy.bind(this); + } + + public ConfigProxy(final Config other) { + interfaze = new InterfaceProxy(other.getInterface()); + for (final Peer peer : other.getPeers()) { + final PeerProxy proxy = new PeerProxy(peer); + peers.add(proxy); + proxy.bind(this); + } + } + + public ConfigProxy() { + interfaze = new InterfaceProxy(); + } + + public PeerProxy addPeer() { + final PeerProxy proxy = new PeerProxy(); + peers.add(proxy); + proxy.bind(this); + return proxy; + } + + @Override + public int describeContents() { + return 0; + } + + public InterfaceProxy getInterface() { + return interfaze; + } + + public ObservableList<PeerProxy> getPeers() { + return peers; + } + + public Config resolve() throws ParseException { + final Collection<Peer> resolvedPeers = new ArrayList<>(); + for (final PeerProxy proxy : peers) + resolvedPeers.add(proxy.resolve()); + return new Config.Builder() + .setInterface(interfaze.resolve()) + .addPeers(resolvedPeers) + .build(); + } + + @Override + public void writeToParcel(final Parcel dest, final int flags) { + dest.writeParcelable(interfaze, flags); + dest.writeTypedList(peers); + } + + private static class ConfigProxyCreator implements Parcelable.Creator<ConfigProxy> { + @Override + public ConfigProxy createFromParcel(final Parcel in) { + return new ConfigProxy(in); + } + + @Override + public ConfigProxy[] newArray(final int size) { + return new ConfigProxy[size]; + } + } +} diff --git a/app/src/main/java/com/wireguard/android/viewmodel/InterfaceProxy.java b/app/src/main/java/com/wireguard/android/viewmodel/InterfaceProxy.java new file mode 100644 index 00000000..63d82042 --- /dev/null +++ b/app/src/main/java/com/wireguard/android/viewmodel/InterfaceProxy.java @@ -0,0 +1,189 @@ +/* + * Copyright © 2017-2018 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.viewmodel; + +import android.databinding.BaseObservable; +import android.databinding.Bindable; +import android.databinding.ObservableArrayList; +import android.databinding.ObservableList; +import android.os.Parcel; +import android.os.Parcelable; + +import com.wireguard.android.BR; +import com.wireguard.config.Attribute; +import com.wireguard.config.Interface; +import com.wireguard.config.ParseException; +import com.wireguard.crypto.Key; +import com.wireguard.crypto.KeyPair; + +import java.net.InetAddress; +import java.util.List; + +import java9.util.stream.Collectors; +import java9.util.stream.StreamSupport; + +public class InterfaceProxy extends BaseObservable implements Parcelable { + public static final Parcelable.Creator<InterfaceProxy> CREATOR = new InterfaceProxyCreator(); + + private final ObservableList<String> excludedApplications = new ObservableArrayList<>(); + private String addresses; + private String dnsServers; + private String listenPort; + private String mtu; + private String privateKey; + private String publicKey; + + private InterfaceProxy(final Parcel in) { + addresses = in.readString(); + dnsServers = in.readString(); + in.readStringList(excludedApplications); + listenPort = in.readString(); + mtu = in.readString(); + privateKey = in.readString(); + publicKey = in.readString(); + } + + public InterfaceProxy(final Interface other) { + addresses = Attribute.join(other.getAddresses()); + final List<String> dnsServerStrings = StreamSupport.stream(other.getDnsServers()) + .map(InetAddress::getHostAddress) + .collect(Collectors.toUnmodifiableList()); + dnsServers = Attribute.join(dnsServerStrings); + excludedApplications.addAll(other.getExcludedApplications()); + listenPort = other.getListenPort().map(String::valueOf).orElse(""); + mtu = other.getMtu().map(String::valueOf).orElse(""); + final KeyPair keyPair = other.getKeyPair(); + privateKey = keyPair.getPrivateKey().toBase64(); + publicKey = keyPair.getPublicKey().toBase64(); + } + + public InterfaceProxy() { + addresses = ""; + dnsServers = ""; + listenPort = ""; + mtu = ""; + privateKey = ""; + publicKey = ""; + } + + @Override + public int describeContents() { + return 0; + } + + public void generateKeyPair() { + final KeyPair keyPair = new KeyPair(); + privateKey = keyPair.getPrivateKey().toBase64(); + publicKey = keyPair.getPublicKey().toBase64(); + notifyPropertyChanged(BR.privateKey); + notifyPropertyChanged(BR.publicKey); + } + + @Bindable + public String getAddresses() { + return addresses; + } + + @Bindable + public String getDnsServers() { + return dnsServers; + } + + public ObservableList<String> getExcludedApplications() { + return excludedApplications; + } + + @Bindable + public String getListenPort() { + return listenPort; + } + + @Bindable + public String getMtu() { + return mtu; + } + + @Bindable + public String getPrivateKey() { + return privateKey; + } + + @Bindable + public String getPublicKey() { + return publicKey; + } + + public Interface resolve() throws ParseException { + final Interface.Builder builder = new Interface.Builder(); + if (!addresses.isEmpty()) + builder.parseAddresses(addresses); + if (!dnsServers.isEmpty()) + builder.parseDnsServers(dnsServers); + if (!excludedApplications.isEmpty()) + builder.excludeApplications(excludedApplications); + if (!listenPort.isEmpty()) + builder.parseListenPort(listenPort); + if (!mtu.isEmpty()) + builder.parseMtu(mtu); + if (!privateKey.isEmpty()) + builder.parsePrivateKey(privateKey); + return builder.build(); + } + + public void setAddresses(final String addresses) { + this.addresses = addresses; + notifyPropertyChanged(BR.addresses); + } + + public void setDnsServers(final String dnsServers) { + this.dnsServers = dnsServers; + notifyPropertyChanged(BR.dnsServers); + } + + public void setListenPort(final String listenPort) { + this.listenPort = listenPort; + notifyPropertyChanged(BR.listenPort); + } + + public void setMtu(final String mtu) { + this.mtu = mtu; + notifyPropertyChanged(BR.mtu); + } + + public void setPrivateKey(final String privateKey) { + this.privateKey = privateKey; + try { + publicKey = new KeyPair(Key.fromBase64(privateKey)).getPublicKey().toBase64(); + } catch (final Key.KeyFormatException ignored) { + publicKey = ""; + } + notifyPropertyChanged(BR.privateKey); + notifyPropertyChanged(BR.publicKey); + } + + @Override + public void writeToParcel(final Parcel dest, final int flags) { + dest.writeString(addresses); + dest.writeString(dnsServers); + dest.writeStringList(excludedApplications); + dest.writeString(listenPort); + dest.writeString(mtu); + dest.writeString(privateKey); + dest.writeString(publicKey); + } + + private static class InterfaceProxyCreator implements Parcelable.Creator<InterfaceProxy> { + @Override + public InterfaceProxy createFromParcel(final Parcel in) { + return new InterfaceProxy(in); + } + + @Override + public InterfaceProxy[] newArray(final int size) { + return new InterfaceProxy[size]; + } + } +} diff --git a/app/src/main/java/com/wireguard/android/viewmodel/PeerProxy.java b/app/src/main/java/com/wireguard/android/viewmodel/PeerProxy.java new file mode 100644 index 00000000..822a4278 --- /dev/null +++ b/app/src/main/java/com/wireguard/android/viewmodel/PeerProxy.java @@ -0,0 +1,379 @@ +/* + * Copyright © 2017-2018 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.viewmodel; + +import android.databinding.BaseObservable; +import android.databinding.Bindable; +import android.databinding.Observable; +import android.databinding.ObservableList; +import android.os.Parcel; +import android.os.Parcelable; +import android.support.annotation.Nullable; + +import com.wireguard.android.BR; +import com.wireguard.config.Attribute; +import com.wireguard.config.InetEndpoint; +import com.wireguard.config.ParseException; +import com.wireguard.config.Peer; +import com.wireguard.crypto.Key; + +import java.lang.ref.WeakReference; +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +import java9.util.Lists; +import java9.util.Sets; +import java9.util.stream.Collectors; +import java9.util.stream.Stream; + +public class PeerProxy extends BaseObservable implements Parcelable { + public static final Parcelable.Creator<PeerProxy> CREATOR = new PeerProxyCreator(); + private static final Set<String> IPV4_PUBLIC_NETWORKS = new LinkedHashSet<>(Lists.of( + "0.0.0.0/5", "8.0.0.0/7", "11.0.0.0/8", "12.0.0.0/6", "16.0.0.0/4", "32.0.0.0/3", + "64.0.0.0/2", "128.0.0.0/3", "160.0.0.0/5", "168.0.0.0/6", "172.0.0.0/12", + "172.32.0.0/11", "172.64.0.0/10", "172.128.0.0/9", "173.0.0.0/8", "174.0.0.0/7", + "176.0.0.0/4", "192.0.0.0/9", "192.128.0.0/11", "192.160.0.0/13", "192.169.0.0/16", + "192.170.0.0/15", "192.172.0.0/14", "192.176.0.0/12", "192.192.0.0/10", + "193.0.0.0/8", "194.0.0.0/7", "196.0.0.0/6", "200.0.0.0/5", "208.0.0.0/4" + )); + private static final Set<String> IPV4_WILDCARD = Sets.of("0.0.0.0/0"); + + private final List<String> dnsRoutes = new ArrayList<>(); + private String allowedIps; + private AllowedIpsState allowedIpsState = AllowedIpsState.INVALID; + private String endpoint; + @Nullable private InterfaceDnsListener interfaceDnsListener; + @Nullable private ConfigProxy owner; + @Nullable private PeerListListener peerListListener; + private String persistentKeepalive; + private String preSharedKey; + private String publicKey; + private int totalPeers; + + private PeerProxy(final Parcel in) { + allowedIps = in.readString(); + endpoint = in.readString(); + persistentKeepalive = in.readString(); + preSharedKey = in.readString(); + publicKey = in.readString(); + } + + public PeerProxy(final Peer other) { + allowedIps = Attribute.join(other.getAllowedIps()); + endpoint = other.getEndpoint().map(InetEndpoint::toString).orElse(""); + persistentKeepalive = other.getPersistentKeepalive().map(String::valueOf).orElse(""); + preSharedKey = other.getPreSharedKey().map(Key::toBase64).orElse(""); + publicKey = other.getPublicKey().toBase64(); + } + + public PeerProxy() { + allowedIps = ""; + endpoint = ""; + persistentKeepalive = ""; + preSharedKey = ""; + publicKey = ""; + } + + public void bind(final ConfigProxy owner) { + final InterfaceProxy interfaze = owner.getInterface(); + final ObservableList<PeerProxy> peers = owner.getPeers(); + if (interfaceDnsListener == null) + interfaceDnsListener = new InterfaceDnsListener(this); + interfaze.addOnPropertyChangedCallback(interfaceDnsListener); + setInterfaceDns(interfaze.getDnsServers()); + if (peerListListener == null) + peerListListener = new PeerListListener(this); + peers.addOnListChangedCallback(peerListListener); + setTotalPeers(peers.size()); + this.owner = owner; + } + + private void calculateAllowedIpsState() { + final AllowedIpsState newState; + if (totalPeers == 1) { + // String comparison works because we only care if allowedIps is a superset of one of + // the above sets of (valid) *networks*. We are not checking for a superset based on + // the individual addresses in each set. + final Collection<String> networkStrings = getAllowedIpsSet(); + // If allowedIps contains both the wildcard and the public networks, then private + // networks aren't excluded! + if (networkStrings.containsAll(IPV4_WILDCARD)) + newState = AllowedIpsState.CONTAINS_IPV4_WILDCARD; + else if (networkStrings.containsAll(IPV4_PUBLIC_NETWORKS)) + newState = AllowedIpsState.CONTAINS_IPV4_PUBLIC_NETWORKS; + else + newState = AllowedIpsState.OTHER; + } else { + newState = AllowedIpsState.INVALID; + } + if (newState != allowedIpsState) { + allowedIpsState = newState; + notifyPropertyChanged(BR.ableToExcludePrivateIps); + notifyPropertyChanged(BR.excludingPrivateIps); + } + } + + @Override + public int describeContents() { + return 0; + } + + @Bindable + public String getAllowedIps() { + return allowedIps; + } + + private Set<String> getAllowedIpsSet() { + return new LinkedHashSet<>(Lists.of(Attribute.split(allowedIps))); + } + + @Bindable + public String getEndpoint() { + return endpoint; + } + + @Bindable + public String getPersistentKeepalive() { + return persistentKeepalive; + } + + @Bindable + public String getPreSharedKey() { + return preSharedKey; + } + + @Bindable + public String getPublicKey() { + return publicKey; + } + + @Bindable + public boolean isAbleToExcludePrivateIps() { + return allowedIpsState == AllowedIpsState.CONTAINS_IPV4_PUBLIC_NETWORKS + || allowedIpsState == AllowedIpsState.CONTAINS_IPV4_WILDCARD; + } + + @Bindable + public boolean isExcludingPrivateIps() { + return allowedIpsState == AllowedIpsState.CONTAINS_IPV4_PUBLIC_NETWORKS; + } + + public Peer resolve() throws ParseException { + final Peer.Builder builder = new Peer.Builder(); + if (!allowedIps.isEmpty()) + builder.parseAllowedIPs(allowedIps); + if (!endpoint.isEmpty()) + builder.parseEndpoint(endpoint); + if (!persistentKeepalive.isEmpty()) + builder.parsePersistentKeepalive(persistentKeepalive); + if (!preSharedKey.isEmpty()) + builder.parsePreSharedKey(preSharedKey); + if (!publicKey.isEmpty()) + builder.parsePublicKey(publicKey); + return builder.build(); + } + + public void setAllowedIps(final String allowedIps) { + this.allowedIps = allowedIps; + notifyPropertyChanged(BR.allowedIps); + calculateAllowedIpsState(); + } + + public void setEndpoint(final String endpoint) { + this.endpoint = endpoint; + notifyPropertyChanged(BR.endpoint); + } + + public void setExcludingPrivateIps(final boolean excludingPrivateIps) { + if (!isAbleToExcludePrivateIps() || isExcludingPrivateIps() == excludingPrivateIps) + return; + final Set<String> oldNetworks = excludingPrivateIps ? IPV4_WILDCARD : IPV4_PUBLIC_NETWORKS; + final Set<String> newNetworks = excludingPrivateIps ? IPV4_PUBLIC_NETWORKS : IPV4_WILDCARD; + final Collection<String> input = getAllowedIpsSet(); + final int outputSize = input.size() - oldNetworks.size() + newNetworks.size(); + final Collection<String> output = new LinkedHashSet<>(outputSize); + boolean replaced = false; + // Replace the first instance of the wildcard with the public network list, or vice versa. + for (final String network : input) { + if (oldNetworks.contains(network)) { + if (!replaced) { + for (final String replacement : newNetworks) + if (!output.contains(replacement)) + output.add(replacement); + replaced = true; + } + } else if (!output.contains(network)) { + output.add(network); + } + } + // DNS servers only need to handled specially when we're excluding private IPs. + if (excludingPrivateIps) + output.addAll(dnsRoutes); + else + output.removeAll(dnsRoutes); + allowedIps = Attribute.join(output); + allowedIpsState = excludingPrivateIps ? + AllowedIpsState.CONTAINS_IPV4_PUBLIC_NETWORKS : AllowedIpsState.CONTAINS_IPV4_WILDCARD; + notifyPropertyChanged(BR.allowedIps); + notifyPropertyChanged(BR.excludingPrivateIps); + } + + private void setInterfaceDns(final CharSequence dnsServers) { + final List<String> newDnsRoutes = Stream.of(Attribute.split(dnsServers)) + .map(server -> server + "/32") + .collect(Collectors.toUnmodifiableList()); + if (allowedIpsState == AllowedIpsState.CONTAINS_IPV4_PUBLIC_NETWORKS) { + final Collection<String> input = getAllowedIpsSet(); + final Collection<String> output = new LinkedHashSet<>(input.size() + 1); + // Yes, this is quadratic in the number of DNS servers, but most users have 1 or 2. + for (final String network : input) + if (!dnsRoutes.contains(network) || newDnsRoutes.contains(network)) + output.add(network); + // Since output is a Set, this does the Right Thing™ (it does not duplicate networks). + output.addAll(newDnsRoutes); + // None of the public networks are /32s, so this cannot change the AllowedIPs state. + allowedIps = Attribute.join(output); + notifyPropertyChanged(BR.allowedIps); + } + dnsRoutes.clear(); + dnsRoutes.addAll(newDnsRoutes); + } + + public void setPersistentKeepalive(final String persistentKeepalive) { + this.persistentKeepalive = persistentKeepalive; + notifyPropertyChanged(BR.persistentKeepalive); + } + + public void setPreSharedKey(final String preSharedKey) { + this.preSharedKey = preSharedKey; + notifyPropertyChanged(BR.preSharedKey); + } + + public void setPublicKey(final String publicKey) { + this.publicKey = publicKey; + notifyPropertyChanged(BR.publicKey); + } + + private void setTotalPeers(final int totalPeers) { + if (this.totalPeers == totalPeers) + return; + this.totalPeers = totalPeers; + calculateAllowedIpsState(); + } + + public void unbind() { + if (owner == null) + return; + final InterfaceProxy interfaze = owner.getInterface(); + final ObservableList<PeerProxy> peers = owner.getPeers(); + if (interfaceDnsListener != null) + interfaze.removeOnPropertyChangedCallback(interfaceDnsListener); + if (peerListListener != null) + peers.removeOnListChangedCallback(peerListListener); + peers.remove(this); + setInterfaceDns(""); + setTotalPeers(0); + owner = null; + } + + @Override + public void writeToParcel(final Parcel dest, final int flags) { + dest.writeString(allowedIps); + dest.writeString(endpoint); + dest.writeString(persistentKeepalive); + dest.writeString(preSharedKey); + dest.writeString(publicKey); + } + + private enum AllowedIpsState { + CONTAINS_IPV4_PUBLIC_NETWORKS, + CONTAINS_IPV4_WILDCARD, + INVALID, + OTHER + } + + private static final class InterfaceDnsListener extends Observable.OnPropertyChangedCallback { + private final WeakReference<PeerProxy> weakPeerProxy; + + private InterfaceDnsListener(final PeerProxy peerProxy) { + weakPeerProxy = new WeakReference<>(peerProxy); + } + + @Override + public void onPropertyChanged(final Observable sender, final int propertyId) { + @Nullable final PeerProxy peerProxy = weakPeerProxy.get(); + if (peerProxy == null) { + sender.removeOnPropertyChangedCallback(this); + return; + } + // This shouldn't be possible, but try to avoid a ClassCastException anyway. + if (!(sender instanceof InterfaceProxy)) + return; + if (!(propertyId == BR._all || propertyId == BR.dnsServers)) + return; + peerProxy.setInterfaceDns(((InterfaceProxy) sender).getDnsServers()); + } + } + + private static final class PeerListListener + extends ObservableList.OnListChangedCallback<ObservableList<PeerProxy>> { + private final WeakReference<PeerProxy> weakPeerProxy; + + private PeerListListener(final PeerProxy peerProxy) { + weakPeerProxy = new WeakReference<>(peerProxy); + } + + @Override + public void onChanged(final ObservableList<PeerProxy> sender) { + @Nullable final PeerProxy peerProxy = weakPeerProxy.get(); + if (peerProxy == null) { + sender.removeOnListChangedCallback(this); + return; + } + peerProxy.setTotalPeers(sender.size()); + } + + @Override + public void onItemRangeChanged(final ObservableList<PeerProxy> sender, + final int positionStart, final int itemCount) { + // Do nothing. + } + + @Override + public void onItemRangeInserted(final ObservableList<PeerProxy> sender, + final int positionStart, final int itemCount) { + onChanged(sender); + } + + @Override + public void onItemRangeMoved(final ObservableList<PeerProxy> sender, + final int fromPosition, final int toPosition, + final int itemCount) { + // Do nothing. + } + + @Override + public void onItemRangeRemoved(final ObservableList<PeerProxy> sender, + final int positionStart, final int itemCount) { + onChanged(sender); + } + } + + private static class PeerProxyCreator implements Parcelable.Creator<PeerProxy> { + @Override + public PeerProxy createFromParcel(final Parcel in) { + return new PeerProxy(in); + } + + @Override + public PeerProxy[] newArray(final int size) { + return new PeerProxy[size]; + } + } +} diff --git a/app/src/main/java/com/wireguard/android/widget/KeyInputFilter.java b/app/src/main/java/com/wireguard/android/widget/KeyInputFilter.java index b6cdada7..6332b856 100644 --- a/app/src/main/java/com/wireguard/android/widget/KeyInputFilter.java +++ b/app/src/main/java/com/wireguard/android/widget/KeyInputFilter.java @@ -10,7 +10,7 @@ import android.text.InputFilter; import android.text.SpannableStringBuilder; import android.text.Spanned; -import com.wireguard.crypto.KeyEncoding; +import com.wireguard.crypto.Key; /** * InputFilter for entering WireGuard private/public keys encoded with base64. @@ -25,7 +25,8 @@ public class KeyInputFilter implements InputFilter { return new KeyInputFilter(); } - @Override @Nullable + @Nullable + @Override public CharSequence filter(final CharSequence source, final int sStart, final int sEnd, final Spanned dest, @@ -38,9 +39,9 @@ public class KeyInputFilter implements InputFilter { final int dIndex = dStart + (sIndex - sStart); // Restrict characters to the base64 character set. // Ensure adding this character does not push the length over the limit. - if (((dIndex + 1 < KeyEncoding.KEY_LENGTH_BASE64 && isAllowed(c)) || - (dIndex + 1 == KeyEncoding.KEY_LENGTH_BASE64 && c == '=')) && - dLength + (sIndex - sStart) < KeyEncoding.KEY_LENGTH_BASE64) { + if (((dIndex + 1 < Key.Format.BASE64.getLength() && isAllowed(c)) || + (dIndex + 1 == Key.Format.BASE64.getLength() && c == '=')) && + dLength + (sIndex - sStart) < Key.Format.BASE64.getLength()) { ++rIndex; } else { if (replacement == null) diff --git a/app/src/main/java/com/wireguard/android/widget/NameInputFilter.java b/app/src/main/java/com/wireguard/android/widget/NameInputFilter.java index db5336d0..2352630e 100644 --- a/app/src/main/java/com/wireguard/android/widget/NameInputFilter.java +++ b/app/src/main/java/com/wireguard/android/widget/NameInputFilter.java @@ -25,7 +25,8 @@ public class NameInputFilter implements InputFilter { return new NameInputFilter(); } - @Override @Nullable + @Nullable + @Override public CharSequence filter(final CharSequence source, final int sStart, final int sEnd, final Spanned dest, diff --git a/app/src/main/java/com/wireguard/android/widget/fab/FloatingActionsMenu.java b/app/src/main/java/com/wireguard/android/widget/fab/FloatingActionsMenu.java index ed838914..7f5b67e6 100644 --- a/app/src/main/java/com/wireguard/android/widget/fab/FloatingActionsMenu.java +++ b/app/src/main/java/com/wireguard/android/widget/fab/FloatingActionsMenu.java @@ -539,7 +539,7 @@ public class FloatingActionsMenu extends ViewGroup { return new SavedState[size]; } }; - public boolean mExpanded; + private boolean mExpanded; public SavedState(final Parcelable parcel) { super(parcel); diff --git a/app/src/main/java/com/wireguard/config/Attribute.java b/app/src/main/java/com/wireguard/config/Attribute.java index d4bdb6c8..d61cc744 100644 --- a/app/src/main/java/com/wireguard/config/Attribute.java +++ b/app/src/main/java/com/wireguard/config/Attribute.java @@ -1,94 +1,49 @@ /* - * Copyright © 2017-2018 WireGuard LLC. All Rights Reserved. + * Copyright © 2018 WireGuard LLC. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0 */ package com.wireguard.config; -import android.annotation.SuppressLint; -import android.support.annotation.Nullable; import android.text.TextUtils; -import java.util.HashMap; -import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; -/** - * The set of valid attributes for an interface or peer in a WireGuard configuration file. - */ - -public enum Attribute { - ADDRESS("Address"), - ALLOWED_IPS("AllowedIPs"), - DNS("DNS"), - EXCLUDED_APPLICATIONS("ExcludedApplications"), - ENDPOINT("Endpoint"), - LISTEN_PORT("ListenPort"), - MTU("MTU"), - PERSISTENT_KEEPALIVE("PersistentKeepalive"), - PRESHARED_KEY("PresharedKey"), - PRIVATE_KEY("PrivateKey"), - PUBLIC_KEY("PublicKey"); - - private static final String[] EMPTY_LIST = new String[0]; - private static final Map<String, Attribute> KEY_MAP; - private static final Pattern LIST_SEPARATOR_PATTERN = Pattern.compile("\\s*,\\s*"); - private static final Pattern SEPARATOR_PATTERN = Pattern.compile("\\s|="); - - static { - KEY_MAP = new HashMap<>(Attribute.values().length); - for (final Attribute key : Attribute.values()) { - KEY_MAP.put(key.token.toLowerCase(), key); - } - } +import java9.util.Optional; - private final Pattern pattern; - private final String token; +public final class Attribute { + private static final Pattern LINE_PATTERN = Pattern.compile("(\\w+)\\s*=\\s*([^\\s#][^#]*)"); + private static final Pattern LIST_SEPARATOR = Pattern.compile("\\s*,\\s*"); - Attribute(final String token) { - pattern = Pattern.compile(token + "\\s*=\\s*(\\S.*)"); - this.token = token; - } - - public static <T> String iterableToString(final Iterable<T> iterable) { - return TextUtils.join(", ", iterable); - } - - @Nullable - public static Attribute match(final CharSequence line) { - return KEY_MAP.get(SEPARATOR_PATTERN.split(line)[0].toLowerCase()); - } + private final String key; + private final String value; - public static String[] stringToList(@Nullable final String string) { - if (TextUtils.isEmpty(string)) - return EMPTY_LIST; - return LIST_SEPARATOR_PATTERN.split(string.trim()); + private Attribute(final String key, final String value) { + this.key = key; + this.value = value; } - @SuppressLint("DefaultLocale") - public String composeWith(@Nullable final Object value) { - return String.format("%s = %s%n", token, value); + public static String join(final Iterable<?> values) { + return TextUtils.join(", ", values); } - @SuppressLint("DefaultLocale") - public String composeWith(final int value) { - return String.format("%s = %d%n", token, value); + public static Optional<Attribute> parse(final CharSequence line) { + final Matcher matcher = LINE_PATTERN.matcher(line); + if (!matcher.matches()) + return Optional.empty(); + return Optional.of(new Attribute(matcher.group(1), matcher.group(2))); } - public <T> String composeWith(final Iterable<T> value) { - return String.format("%s = %s%n", token, iterableToString(value)); + public static String[] split(final CharSequence value) { + return LIST_SEPARATOR.split(value); } - @Nullable - public String parse(final CharSequence line) { - final Matcher matcher = pattern.matcher(line); - return matcher.matches() ? matcher.group(1) : null; + public String getKey() { + return key; } - @Nullable - public String[] parseList(final CharSequence line) { - final Matcher matcher = pattern.matcher(line); - return matcher.matches() ? stringToList(matcher.group(1)) : null; + public String getValue() { + return value; } } diff --git a/app/src/main/java/com/wireguard/config/Config.java b/app/src/main/java/com/wireguard/config/Config.java index 61e31838..7645583d 100644 --- a/app/src/main/java/com/wireguard/config/Config.java +++ b/app/src/main/java/com/wireguard/config/Config.java @@ -5,170 +5,193 @@ package com.wireguard.config; -import android.content.Context; -import android.databinding.BaseObservable; -import android.databinding.Bindable; -import android.databinding.ObservableArrayList; -import android.databinding.ObservableList; -import android.os.Parcel; -import android.os.Parcelable; import android.support.annotation.Nullable; -import com.android.databinding.library.baseAdapters.BR; -import com.wireguard.android.Application; -import com.wireguard.android.R; - import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; -import java.io.StringReader; -import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Objects; +import java.util.Set; /** - * Represents a wg-quick configuration file, its name, and its connection state. + * Represents the contents of a wg-quick configuration file, made up of one or more "Interface" + * sections (combined together), and zero or more "Peer" sections (treated individually). + * <p> + * Instances of this class are immutable. */ - -public class Config { - private final Interface interfaceSection = new Interface(); - private List<Peer> peers = new ArrayList<>(); - - public static Config from(final String string) throws IOException { - return from(new BufferedReader(new StringReader(string))); +public final class Config { + private final Interface interfaze; + private final List<Peer> peers; + + private Config(final Builder builder) { + interfaze = Objects.requireNonNull(builder.interfaze, "An [Interface] section is required"); + // Defensively copy to ensure immutability even if the Builder is reused. + peers = Collections.unmodifiableList(new ArrayList<>(builder.peers)); } - public static Config from(final InputStream stream) throws IOException { - return from(new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8))); - } - - public static Config from(final BufferedReader reader) throws IOException { - final Config config = new Config(); - final Context context = Application.get(); - Peer currentPeer = null; - String line; - boolean inInterfaceSection = false; - while ((line = reader.readLine()) != null) { - final int commentIndex = line.indexOf('#'); - if (commentIndex != -1) - line = line.substring(0, commentIndex); - line = line.trim(); - if (line.isEmpty()) - continue; - if ("[Interface]".toLowerCase().equals(line.toLowerCase())) { - currentPeer = null; - inInterfaceSection = true; - } else if ("[Peer]".toLowerCase().equals(line.toLowerCase())) { - currentPeer = new Peer(); - config.peers.add(currentPeer); - inInterfaceSection = false; - } else if (inInterfaceSection) { - config.interfaceSection.parse(line); - } else if (currentPeer != null) { - currentPeer.parse(line); - } else { - throw new IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_config_line, line)); + /** + * Parses an series of "Interface" and "Peer" sections into a {@code Config}. Throws + * {@link ParseException} if the input is not well-formed or contains unparseable sections. + * + * @param stream a stream of UTF-8 text that is interpreted as a WireGuard configuration file + * @return a {@code Config} instance representing the supplied configuration + */ + public static Config parse(final InputStream stream) throws IOException, ParseException { + final Builder builder = new Builder(); + try (final BufferedReader reader = new BufferedReader(new InputStreamReader(stream))) { + final Collection<String> interfaceLines = new ArrayList<>(); + final Collection<String> peerLines = new ArrayList<>(); + boolean inInterfaceSection = false; + boolean inPeerSection = false; + @Nullable String line; + while ((line = reader.readLine()) != null) { + final int commentIndex = line.indexOf('#'); + if (commentIndex != -1) + line = line.substring(0, commentIndex); + line = line.trim(); + if (line.isEmpty()) + continue; + if (line.startsWith("[")) { + // Consume all [Peer] lines read so far. + if (inPeerSection) { + builder.parsePeer(peerLines); + peerLines.clear(); + } + if ("[Interface]".equalsIgnoreCase(line)) { + inInterfaceSection = true; + inPeerSection = false; + } else if ("[Peer]".equalsIgnoreCase(line)) { + inInterfaceSection = false; + inPeerSection = true; + } else { + throw new ParseException("top level", line, "Unknown section name"); + } + } else if (inInterfaceSection) { + interfaceLines.add(line); + } else if (inPeerSection) { + peerLines.add(line); + } else { + throw new ParseException("top level", line, "Expected [Interface] or [Peer]"); + } } + if (inPeerSection) + builder.parsePeer(peerLines); + else if (!inInterfaceSection) + throw new ParseException("top level", "", "Empty configuration"); + // Combine all [Interface] sections in the file. + builder.parseInterface(interfaceLines); } - if (!inInterfaceSection && currentPeer == null) { - throw new IllegalArgumentException(context.getString(R.string.tunnel_error_no_config_information)); - } - return config; + return builder.build(); } + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof Config)) + return false; + final Config other = (Config) obj; + return interfaze.equals(other.interfaze) && peers.equals(other.peers); + } + + /** + * Returns the interface section of the configuration. + * + * @return the interface configuration + */ public Interface getInterface() { - return interfaceSection; + return interfaze; } + /** + * Returns a list of the configuration's peer sections. + * + * @return a list of {@link Peer}s + */ public List<Peer> getPeers() { return peers; } @Override + public int hashCode() { + return 31 * interfaze.hashCode() + peers.hashCode(); + } + + /** + * Converts the {@code Config} into a string suitable for debugging purposes. The {@code Config} + * is identified by its interface's public key and the number of peers it has. + * + * @return a concise single-line identifier for the {@code Config} + */ + @Override public String toString() { - final StringBuilder sb = new StringBuilder().append(interfaceSection); + return "(Config " + interfaze + " (" + peers.size() + " peers))"; + } + + /** + * Converts the {@code Config} into a string suitable for use as a {@code wg-quick} + * configuration file. + * + * @return the {@code Config} represented as one [Interface] and zero or more [Peer] sections + */ + public String toWgQuickString() { + final StringBuilder sb = new StringBuilder(); + sb.append("[Interface]\n").append(interfaze.toWgQuickString()); for (final Peer peer : peers) - sb.append('\n').append(peer); + sb.append("\n[Peer]\n").append(peer.toWgQuickString()); return sb.toString(); } - public static class Observable extends BaseObservable implements Parcelable { - public static final Creator<Observable> CREATOR = new Creator<Observable>() { - @Override - public Observable createFromParcel(final Parcel in) { - return new Observable(in); - } - - @Override - public Observable[] newArray(final int size) { - return new Observable[size]; - } - }; - private final Interface.Observable observableInterface; - private final ObservableList<Peer.Observable> observablePeers; - @Nullable private String name; - - public Observable(@Nullable final Config parent, @Nullable final String name) { - this.name = name; - - observableInterface = new Interface.Observable(parent == null ? null : parent.interfaceSection); - observablePeers = new ObservableArrayList<>(); - if (parent != null) { - for (final Peer peer : parent.getPeers()) - observablePeers.add(new Peer.Observable(peer)); - } - } - - private Observable(final Parcel in) { - name = in.readString(); - observableInterface = in.readParcelable(Interface.Observable.class.getClassLoader()); - observablePeers = new ObservableArrayList<>(); - in.readTypedList(observablePeers, Peer.Observable.CREATOR); - } + /** + * Serializes the {@code Config} for use with the WireGuard cross-platform userspace API. + * + * @return the {@code Config} represented as a series of "key=value" lines + */ + public String toWgUserspaceString() { + final StringBuilder sb = new StringBuilder(); + sb.append(interfaze.toWgUserspaceString()); + sb.append("replace_peers=true\n"); + for (final Peer peer : peers) + sb.append(peer.toWgUserspaceString()); + return sb.toString(); + } - public void commitData(final Config parent) { - observableInterface.commitData(parent.interfaceSection); - final List<Peer> newPeers = new ArrayList<>(observablePeers.size()); - for (final Peer.Observable observablePeer : observablePeers) { - final Peer peer = new Peer(); - observablePeer.commitData(peer); - newPeers.add(peer); - } - parent.peers = newPeers; - notifyChange(); - } + @SuppressWarnings("UnusedReturnValue") + public static final class Builder { + // Defaults to an empty set. + private final Set<Peer> peers = new LinkedHashSet<>(); + // No default; must be provided before building. + @Nullable private Interface interfaze; - @Override - public int describeContents() { - return 0; + public Builder addPeer(final Peer peer) { + peers.add(peer); + return this; } - @Bindable - public Interface.Observable getInterfaceSection() { - return observableInterface; + public Builder addPeers(final Collection<Peer> peers) { + this.peers.addAll(peers); + return this; } - @Bindable - public String getName() { - return name == null ? "" : name; + public Config build() { + return new Config(this); } - @Bindable - public ObservableList<Peer.Observable> getPeers() { - return observablePeers; + public Builder parseInterface(final Iterable<? extends CharSequence> lines) throws ParseException { + return setInterface(Interface.parse(lines)); } - public void setName(final String name) { - this.name = name; - notifyPropertyChanged(BR.name); + public Builder parsePeer(final Iterable<? extends CharSequence> lines) throws ParseException { + return addPeer(Peer.parse(lines)); } - @Override - public void writeToParcel(final Parcel dest, final int flags) { - dest.writeString(name); - dest.writeParcelable(observableInterface, flags); - dest.writeTypedList(observablePeers); + public Builder setInterface(final Interface interfaze) { + this.interfaze = interfaze; + return this; } } } diff --git a/app/src/main/java/com/wireguard/config/InetAddresses.java b/app/src/main/java/com/wireguard/config/InetAddresses.java index c50c5a0e..989598da 100644 --- a/app/src/main/java/com/wireguard/config/InetAddresses.java +++ b/app/src/main/java/com/wireguard/config/InetAddresses.java @@ -5,21 +5,22 @@ package com.wireguard.config; -import android.support.annotation.Nullable; - -import com.wireguard.android.Application; -import com.wireguard.android.R; - import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.net.Inet4Address; +import java.net.Inet6Address; import java.net.InetAddress; +/** + * Utility methods for creating instances of {@link InetAddress}. + */ public final class InetAddresses { private static final Method PARSER_METHOD; static { try { // This method is only present on Android. + // noinspection JavaReflectionMemberAccess PARSER_METHOD = InetAddress.class.getMethod("parseNumericAddress", String.class); } catch (final NoSuchMethodException e) { throw new RuntimeException(e); @@ -30,13 +31,23 @@ public final class InetAddresses { // Prevent instantiation. } - public static InetAddress parse(@Nullable final String address) { - if (address == null || address.isEmpty()) - throw new IllegalArgumentException(Application.get().getString(R.string.tunnel_error_empty_inetaddress)); + /** + * Parses a numeric IPv4 or IPv6 address without performing any DNS lookups. + * + * @param address a string representing the IP address + * @return an instance of {@link Inet4Address} or {@link Inet6Address}, as appropriate + */ + public static InetAddress parse(final String address) { + if (address.isEmpty()) + throw new IllegalArgumentException("Empty address"); try { return (InetAddress) PARSER_METHOD.invoke(null, address); } catch (final IllegalAccessException | InvocationTargetException e) { - throw new RuntimeException(e.getCause() == null ? e : e.getCause()); + final Throwable cause = e.getCause(); + // Re-throw parsing exceptions with the original type, as callers might try to catch + // them. On the other hand, callers cannot be expected to handle reflection failures. + throw cause instanceof IllegalArgumentException ? + (IllegalArgumentException) cause : new RuntimeException(e); } } } diff --git a/app/src/main/java/com/wireguard/config/InetEndpoint.java b/app/src/main/java/com/wireguard/config/InetEndpoint.java index 3efe4203..06d0ca80 100644 --- a/app/src/main/java/com/wireguard/config/InetEndpoint.java +++ b/app/src/main/java/com/wireguard/config/InetEndpoint.java @@ -5,36 +5,68 @@ package com.wireguard.config; -import android.annotation.SuppressLint; +import android.support.annotation.Nullable; -import com.wireguard.android.Application; -import com.wireguard.android.R; +import org.threeten.bp.Duration; +import org.threeten.bp.Instant; import java.net.Inet4Address; -import java.net.Inet6Address; import java.net.InetAddress; import java.net.URI; import java.net.URISyntaxException; import java.net.UnknownHostException; +import java.util.regex.Pattern; -import javax.annotation.Nullable; +import java9.util.Optional; + + +/** + * An external endpoint (host and port) used to connect to a WireGuard {@link Peer}. + * <p> + * Instances of this class are externally immutable. + */ +public final class InetEndpoint { + private static final Pattern BARE_IPV6 = Pattern.compile("^[^\\[]*:"); + private static final Pattern FORBIDDEN_CHARACTERS = Pattern.compile("[/?#]"); -public class InetEndpoint { private final String host; + private final boolean isResolved; + private final Object lock = new Object(); private final int port; - @Nullable private InetAddress resolvedHost; + private Instant lastResolution = Instant.EPOCH; + @Nullable private InetEndpoint resolved; - public InetEndpoint(@Nullable final String endpoint) { - if (endpoint.indexOf('/') != -1 || endpoint.indexOf('?') != -1 || endpoint.indexOf('#') != -1) - throw new IllegalArgumentException(Application.get().getString(R.string.tunnel_error_forbidden_endpoint_chars)); + private InetEndpoint(final String host, final boolean isResolved, final int port) { + this.host = host; + this.isResolved = isResolved; + this.port = port; + } + + public static InetEndpoint parse(final String endpoint) { + if (FORBIDDEN_CHARACTERS.matcher(endpoint).find()) + throw new IllegalArgumentException("Forbidden characters in Endpoint"); final URI uri; try { uri = new URI("wg://" + endpoint); } catch (final URISyntaxException e) { throw new IllegalArgumentException(e); } - host = uri.getHost(); - port = uri.getPort(); + try { + InetAddresses.parse(uri.getHost()); + // Parsing ths host as a numeric address worked, so we don't need to do DNS lookups. + return new InetEndpoint(uri.getHost(), true, uri.getPort()); + } catch (final IllegalArgumentException ignored) { + // Failed to parse the host as a numeric address, so it must be a DNS hostname/FQDN. + return new InetEndpoint(uri.getHost(), false, uri.getPort()); + } + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof InetEndpoint)) + return false; + final InetEndpoint other = (InetEndpoint) obj; + return host.equals(other.host) && port == other.port; } public String getHost() { @@ -45,28 +77,47 @@ public class InetEndpoint { return port; } - @SuppressLint("DefaultLocale") - public String getResolvedEndpoint() throws UnknownHostException { - if (resolvedHost == null) { - final InetAddress[] candidates = InetAddress.getAllByName(host); - if (candidates.length == 0) - throw new UnknownHostException(host); - for (final InetAddress addr : candidates) { - if (addr instanceof Inet4Address) { - resolvedHost = addr; - break; + /** + * Generate an {@code InetEndpoint} instance with the same port and the host resolved using DNS + * to a numeric address. If the host is already numeric, the existing instance may be returned. + * Because this function may perform network I/O, it must not be called from the main thread. + * + * @return the resolved endpoint, or {@link Optional#empty()} + */ + public Optional<InetEndpoint> getResolved() { + if (isResolved) + return Optional.of(this); + synchronized (lock) { + //TODO(zx2c4): Implement a real timeout mechanism using DNS TTL + if (Duration.between(lastResolution, Instant.now()).toMinutes() > 1) { + try { + // Prefer v4 endpoints over v6 to work around DNS64 and IPv6 NAT issues. + final InetAddress[] candidates = InetAddress.getAllByName(host); + InetAddress address = candidates[0]; + for (final InetAddress candidate : candidates) { + if (candidate instanceof Inet4Address) { + address = candidate; + break; + } + } + resolved = new InetEndpoint(address.getHostAddress(), true, port); + lastResolution = Instant.now(); + } catch (final UnknownHostException e) { + resolved = null; } } - if (resolvedHost == null) - resolvedHost = candidates[0]; + return Optional.ofNullable(resolved); } - return String.format(resolvedHost instanceof Inet6Address ? - "[%s]:%d" : "%s:%d", resolvedHost.getHostAddress(), port); } - @SuppressLint("DefaultLocale") - public String getEndpoint() { - return String.format(host.contains(":") && !host.contains("[") ? - "[%s]:%d" : "%s:%d", host, port); + @Override + public int hashCode() { + return host.hashCode() ^ port; + } + + @Override + public String toString() { + final boolean isBareIpv6 = isResolved && BARE_IPV6.matcher(host).matches(); + return (isBareIpv6 ? '[' + host + ']' : host) + ':' + port; } } diff --git a/app/src/main/java/com/wireguard/config/InetNetwork.java b/app/src/main/java/com/wireguard/config/InetNetwork.java index 836a1335..9e5e8c64 100644 --- a/app/src/main/java/com/wireguard/config/InetNetwork.java +++ b/app/src/main/java/com/wireguard/config/InetNetwork.java @@ -7,26 +7,36 @@ package com.wireguard.config; import java.net.Inet4Address; import java.net.InetAddress; -import java.util.Objects; -public class InetNetwork { +/** + * An Internet network, denoted by its address and netmask + * <p> + * Instances of this class are immutable. + */ +public final class InetNetwork { private final InetAddress address; private final int mask; - public InetNetwork(final String input) { - final int slash = input.lastIndexOf('/'); + private InetNetwork(final InetAddress address, final int mask) { + this.address = address; + this.mask = mask; + } + + public static InetNetwork parse(final String network) { + final int slash = network.lastIndexOf('/'); final int rawMask; final String rawAddress; if (slash >= 0) { - rawMask = Integer.parseInt(input.substring(slash + 1), 10); - rawAddress = input.substring(0, slash); + rawMask = Integer.parseInt(network.substring(slash + 1), 10); + rawAddress = network.substring(0, slash); } else { rawMask = -1; - rawAddress = input; + rawAddress = network; } - address = InetAddresses.parse(rawAddress); + final InetAddress address = InetAddresses.parse(rawAddress); final int maxMask = (address instanceof Inet4Address) ? 32 : 128; - mask = rawMask >= 0 && rawMask <= maxMask ? rawMask : maxMask; + final int mask = rawMask >= 0 && rawMask <= maxMask ? rawMask : maxMask; + return new InetNetwork(address, mask); } @Override @@ -34,7 +44,7 @@ public class InetNetwork { if (!(obj instanceof InetNetwork)) return false; final InetNetwork other = (InetNetwork) obj; - return Objects.equals(address, other.address) && mask == other.mask; + return address.equals(other.address) && mask == other.mask; } public InetAddress getAddress() { diff --git a/app/src/main/java/com/wireguard/config/Interface.java b/app/src/main/java/com/wireguard/config/Interface.java index aa1d986b..dc1a291d 100644 --- a/app/src/main/java/com/wireguard/config/Interface.java +++ b/app/src/main/java/com/wireguard/config/Interface.java @@ -5,395 +5,345 @@ package com.wireguard.config; -import android.content.Context; -import android.databinding.BaseObservable; -import android.databinding.Bindable; -import android.os.Parcel; -import android.os.Parcelable; import android.support.annotation.Nullable; -import com.wireguard.android.Application; -import com.wireguard.android.BR; -import com.wireguard.android.R; -import com.wireguard.crypto.Keypair; +import com.wireguard.crypto.Key; +import com.wireguard.crypto.KeyPair; import java.net.InetAddress; -import java.util.ArrayList; -import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Objects; +import java.util.Set; + +import java9.util.Lists; +import java9.util.Optional; +import java9.util.stream.Collectors; +import java9.util.stream.Stream; +import java9.util.stream.StreamSupport; /** - * Represents the configuration for a WireGuard interface (an [Interface] block). + * Represents the configuration for a WireGuard interface (an [Interface] block). Interfaces must + * have a private key (used to initialize a {@code KeyPair}), and may optionally have several other + * attributes. + * <p> + * Instances of this class are immutable. */ - -public class Interface { - private final List<InetNetwork> addressList; - private final Context context = Application.get(); - private final List<InetAddress> dnsList; - private final List<String> excludedApplications; - @Nullable private Keypair keypair; - private int listenPort; - private int mtu; - - public Interface() { - addressList = new ArrayList<>(); - dnsList = new ArrayList<>(); - excludedApplications = new ArrayList<>(); +public final class Interface { + private static final int MAX_UDP_PORT = 65535; + private static final int MIN_UDP_PORT = 0; + + private final Set<InetNetwork> addresses; + private final Set<InetAddress> dnsServers; + private final Set<String> excludedApplications; + private final KeyPair keyPair; + private final Optional<Integer> listenPort; + private final Optional<Integer> mtu; + + private Interface(final Builder builder) { + // Defensively copy to ensure immutability even if the Builder is reused. + addresses = Collections.unmodifiableSet(new LinkedHashSet<>(builder.addresses)); + dnsServers = Collections.unmodifiableSet(new LinkedHashSet<>(builder.dnsServers)); + excludedApplications = Collections.unmodifiableSet(new LinkedHashSet<>(builder.excludedApplications)); + keyPair = Objects.requireNonNull(builder.keyPair, "Interfaces must have a private key"); + listenPort = builder.listenPort; + mtu = builder.mtu; } - private void addAddresses(@Nullable final String[] addresses) { - if (addresses != null && addresses.length > 0) { - for (final String addr : addresses) { - if (addr.isEmpty()) - throw new IllegalArgumentException(context.getString(R.string.tunnel_error_empty_interface_address)); - addressList.add(new InetNetwork(addr)); + /** + * Parses an series of "KEY = VALUE" lines into an {@code Interface}. Throws + * {@link ParseException} if the input is not well-formed or contains unknown attributes. + * + * @param lines An iterable sequence of lines, containing at least a private key attribute + * @return An {@code Interface} with all of the attributes from {@code lines} set + */ + public static Interface parse(final Iterable<? extends CharSequence> lines) throws ParseException { + final Builder builder = new Builder(); + for (final CharSequence line : lines) { + final Attribute attribute = Attribute.parse(line) + .orElseThrow(() -> new ParseException("[Interface]", line, "Syntax error")); + switch (attribute.getKey().toLowerCase()) { + case "address": + builder.parseAddresses(attribute.getValue()); + break; + case "dns": + builder.parseDnsServers(attribute.getValue()); + break; + case "excludedapplications": + builder.parseExcludedApplications(attribute.getValue()); + break; + case "listenport": + builder.parseListenPort(attribute.getValue()); + break; + case "mtu": + builder.parseMtu(attribute.getValue()); + break; + case "privatekey": + builder.parsePrivateKey(attribute.getValue()); + break; + default: + throw new ParseException("[Interface]", attribute.getKey(), "Unknown attribute"); } } + return builder.build(); } - private void addDnses(@Nullable final String[] dnses) { - if (dnses != null && dnses.length > 0) { - for (final String dns : dnses) { - dnsList.add(InetAddresses.parse(dns)); - } - } - } - - private void addExcludedApplications(@Nullable final String[] applications) { - if (applications != null && applications.length > 0) { - excludedApplications.addAll(Arrays.asList(applications)); - } - } - - @Nullable - private String getAddressString() { - if (addressList.isEmpty()) - return null; - return Attribute.iterableToString(addressList); - } - - public InetNetwork[] getAddresses() { - return addressList.toArray(new InetNetwork[addressList.size()]); - } - - @Nullable - private String getDnsString() { - if (dnsList.isEmpty()) - return null; - return Attribute.iterableToString(getDnsStrings()); + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof Interface)) + return false; + final Interface other = (Interface) obj; + return addresses.equals(other.addresses) + && dnsServers.equals(other.dnsServers) + && excludedApplications.equals(other.excludedApplications) + && keyPair.equals(other.keyPair) + && listenPort.equals(other.listenPort) + && mtu.equals(other.mtu); } - private List<String> getDnsStrings() { - final List<String> strings = new ArrayList<>(); - for (final InetAddress addr : dnsList) - strings.add(addr.getHostAddress()); - return strings; + /** + * Returns the set of IP addresses assigned to the interface. + * + * @return a set of {@link InetNetwork}s + */ + public Set<InetNetwork> getAddresses() { + // The collection is already immutable. + return addresses; } - public InetAddress[] getDnses() { - return dnsList.toArray(new InetAddress[dnsList.size()]); + /** + * Returns the set of DNS servers associated with the interface. + * + * @return a set of {@link InetAddress}es + */ + public Set<InetAddress> getDnsServers() { + // The collection is already immutable. + return dnsServers; } - public String[] getExcludedApplications() { - return excludedApplications.toArray(new String[excludedApplications.size()]); + /** + * Returns the set of applications excluded from using the interface. + * + * @return a set of package names + */ + public Set<String> getExcludedApplications() { + // The collection is already immutable. + return excludedApplications; } - @Nullable - private String getExcludedApplicationsString() { - if (excludedApplications.isEmpty()) - return null; - return Attribute.iterableToString(excludedApplications); + /** + * Returns the public/private key pair used by the interface. + * + * @return a key pair + */ + public KeyPair getKeyPair() { + return keyPair; } - public int getListenPort() { + /** + * Returns the UDP port number that the WireGuard interface will listen on. + * + * @return a UDP port number, or {@code Optional.empty()} if none is configured + */ + public Optional<Integer> getListenPort() { return listenPort; } - @Nullable - private String getListenPortString() { - if (listenPort == 0) - return null; - return Integer.valueOf(listenPort).toString(); - } - - public int getMtu() { + /** + * Returns the MTU used for the WireGuard interface. + * + * @return the MTU, or {@code Optional.empty()} if none is configured + */ + public Optional<Integer> getMtu() { return mtu; } - @Nullable - private String getMtuString() { - if (mtu == 0) - return null; - return Integer.toString(mtu); - } - - @Nullable - public String getPrivateKey() { - if (keypair == null) - return null; - return keypair.getPrivateKey(); - } - - @Nullable - public String getPublicKey() { - if (keypair == null) - return null; - return keypair.getPublicKey(); - } - - public void parse(final String line) { - final Attribute key = Attribute.match(line); - if (key == null) - throw new IllegalArgumentException(String.format(context.getString(R.string.tunnel_error_interface_parse_failed), line)); - switch (key) { - case ADDRESS: - addAddresses(key.parseList(line)); - break; - case DNS: - addDnses(key.parseList(line)); - break; - case EXCLUDED_APPLICATIONS: - addExcludedApplications(key.parseList(line)); - break; - case LISTEN_PORT: - setListenPortString(key.parse(line)); - break; - case MTU: - setMtuString(key.parse(line)); - break; - case PRIVATE_KEY: - setPrivateKey(key.parse(line)); - break; - default: - throw new IllegalArgumentException(line); - } - } - - private void setAddressString(@Nullable final String addressString) { - addressList.clear(); - addAddresses(Attribute.stringToList(addressString)); - } - - private void setDnsString(@Nullable final String dnsString) { - dnsList.clear(); - addDnses(Attribute.stringToList(dnsString)); - } - - private void setExcludedApplicationsString(@Nullable final String applicationsString) { - excludedApplications.clear(); - addExcludedApplications(Attribute.stringToList(applicationsString)); - } - - private void setListenPort(final int listenPort) { - this.listenPort = listenPort; - } - - private void setListenPortString(@Nullable final String port) { - if (port != null && !port.isEmpty()) - setListenPort(Integer.parseInt(port, 10)); - else - setListenPort(0); - } - - private void setMtu(final int mtu) { - this.mtu = mtu; - } - - private void setMtuString(@Nullable final String mtu) { - if (mtu != null && !mtu.isEmpty()) - setMtu(Integer.parseInt(mtu, 10)); - else - setMtu(0); - } - - private void setPrivateKey(@Nullable String privateKey) { - if (privateKey != null && privateKey.isEmpty()) - privateKey = null; - keypair = privateKey == null ? null : new Keypair(privateKey); + @Override + public int hashCode() { + int hash = 1; + hash = 31 * hash + addresses.hashCode(); + hash = 31 * hash + dnsServers.hashCode(); + hash = 31 * hash + excludedApplications.hashCode(); + hash = 31 * hash + keyPair.hashCode(); + hash = 31 * hash + listenPort.hashCode(); + hash = 31 * hash + mtu.hashCode(); + return hash; } + /** + * Converts the {@code Interface} into a string suitable for debugging purposes. The {@code + * Interface} is identified by its public key and (if set) the port used for its UDP socket. + * + * @return A concise single-line identifier for the {@code Interface} + */ @Override public String toString() { - final StringBuilder sb = new StringBuilder().append("[Interface]\n"); - if (!addressList.isEmpty()) - sb.append(Attribute.ADDRESS.composeWith(addressList)); - if (!dnsList.isEmpty()) - sb.append(Attribute.DNS.composeWith(getDnsStrings())); - if (!excludedApplications.isEmpty()) - sb.append(Attribute.EXCLUDED_APPLICATIONS.composeWith(excludedApplications)); - if (listenPort != 0) - sb.append(Attribute.LISTEN_PORT.composeWith(listenPort)); - if (mtu != 0) - sb.append(Attribute.MTU.composeWith(mtu)); - if (keypair != null) - sb.append(Attribute.PRIVATE_KEY.composeWith(keypair.getPrivateKey())); + final StringBuilder sb = new StringBuilder("(Interface "); + sb.append(keyPair.getPublicKey().toBase64()); + listenPort.ifPresent(lp -> sb.append(" @").append(lp)); + sb.append(')'); return sb.toString(); } - public static class Observable extends BaseObservable implements Parcelable { - public static final Creator<Observable> CREATOR = new Creator<Observable>() { - @Override - public Observable createFromParcel(final Parcel in) { - return new Observable(in); - } - - @Override - public Observable[] newArray(final int size) { - return new Observable[size]; - } - }; - @Nullable private String addresses; - @Nullable private String dnses; - @Nullable private String excludedApplications; - @Nullable private String listenPort; - @Nullable private String mtu; - @Nullable private String privateKey; - @Nullable private String publicKey; - - public Observable(@Nullable final Interface parent) { - if (parent != null) - loadData(parent); - } - - private Observable(final Parcel in) { - addresses = in.readString(); - dnses = in.readString(); - publicKey = in.readString(); - privateKey = in.readString(); - listenPort = in.readString(); - mtu = in.readString(); - excludedApplications = in.readString(); - } - - public void commitData(final Interface parent) { - parent.setAddressString(addresses); - parent.setDnsString(dnses); - parent.setExcludedApplicationsString(excludedApplications); - parent.setPrivateKey(privateKey); - parent.setListenPortString(listenPort); - parent.setMtuString(mtu); - loadData(parent); - notifyChange(); - } - - @Override - public int describeContents() { - return 0; + /** + * Converts the {@code Interface} into a string suitable for inclusion in a {@code wg-quick} + * configuration file. + * + * @return The {@code Interface} represented as a series of "Key = Value" lines + */ + public String toWgQuickString() { + final StringBuilder sb = new StringBuilder(); + if (!addresses.isEmpty()) + sb.append("Address = ").append(Attribute.join(addresses)).append('\n'); + if (!dnsServers.isEmpty()) { + final List<String> dnsServerStrings = StreamSupport.stream(dnsServers) + .map(InetAddress::getHostAddress) + .collect(Collectors.toUnmodifiableList()); + sb.append("DNS = ").append(Attribute.join(dnsServerStrings)).append('\n'); } + if (!excludedApplications.isEmpty()) + sb.append("ExcludedApplications = ").append(Attribute.join(excludedApplications)).append('\n'); + listenPort.ifPresent(lp -> sb.append("ListenPort = ").append(lp).append('\n')); + mtu.ifPresent(m -> sb.append("MTU = ").append(m).append('\n')); + sb.append("PrivateKey = ").append(keyPair.getPrivateKey().toBase64()).append('\n'); + return sb.toString(); + } - public void generateKeypair() { - final Keypair keypair = new Keypair(); - privateKey = keypair.getPrivateKey(); - publicKey = keypair.getPublicKey(); - notifyPropertyChanged(BR.privateKey); - notifyPropertyChanged(BR.publicKey); - } + /** + * Serializes the {@code Interface} for use with the WireGuard cross-platform userspace API. + * Note that not all attributes are included in this representation. + * + * @return the {@code Interface} represented as a series of "KEY=VALUE" lines + */ + public String toWgUserspaceString() { + final StringBuilder sb = new StringBuilder(); + sb.append("private_key=").append(keyPair.getPrivateKey().toHex()).append('\n'); + listenPort.ifPresent(lp -> sb.append("listen_port=").append(lp).append('\n')); + return sb.toString(); + } - @Nullable - @Bindable - public String getAddresses() { - return addresses; + @SuppressWarnings("UnusedReturnValue") + public static final class Builder { + // Defaults to an empty set. + private final Set<InetNetwork> addresses = new LinkedHashSet<>(); + // Defaults to an empty set. + private final Set<InetAddress> dnsServers = new LinkedHashSet<>(); + // Defaults to an empty set. + private final Set<String> excludedApplications = new LinkedHashSet<>(); + // No default; must be provided before building. + @Nullable private KeyPair keyPair; + // Defaults to not present. + private Optional<Integer> listenPort = Optional.empty(); + // Defaults to not present. + private Optional<Integer> mtu = Optional.empty(); + + public Builder addAddress(final InetNetwork address) { + addresses.add(address); + return this; } - @Nullable - @Bindable - public String getDnses() { - return dnses; + public Builder addAddresses(final Collection<InetNetwork> addresses) { + this.addresses.addAll(addresses); + return this; } - @Nullable - @Bindable - public String getExcludedApplications() { - return excludedApplications; + public Builder addDnsServer(final InetAddress dnsServer) { + dnsServers.add(dnsServer); + return this; } - @Bindable - public int getExcludedApplicationsCount() { - return Attribute.stringToList(excludedApplications).length; + public Builder addDnsServers(final Collection<? extends InetAddress> dnsServers) { + this.dnsServers.addAll(dnsServers); + return this; } - @Nullable - @Bindable - public String getListenPort() { - return listenPort; + public Interface build() { + return new Interface(this); } - @Nullable - @Bindable - public String getMtu() { - return mtu; + public Builder excludeApplication(final String application) { + excludedApplications.add(application); + return this; } - @Nullable - @Bindable - public String getPrivateKey() { - return privateKey; + public Builder excludeApplications(final Collection<String> applications) { + excludedApplications.addAll(applications); + return this; } - @Nullable - @Bindable - public String getPublicKey() { - return publicKey; + public Builder parseAddresses(final CharSequence addresses) throws ParseException { + try { + final List<InetNetwork> parsed = Stream.of(Attribute.split(addresses)) + .map(InetNetwork::parse) + .collect(Collectors.toUnmodifiableList()); + return addAddresses(parsed); + } catch (final IllegalArgumentException e) { + throw new ParseException("Address", addresses, e); + } } - private void loadData(final Interface parent) { - addresses = parent.getAddressString(); - dnses = parent.getDnsString(); - excludedApplications = parent.getExcludedApplicationsString(); - publicKey = parent.getPublicKey(); - privateKey = parent.getPrivateKey(); - listenPort = parent.getListenPortString(); - mtu = parent.getMtuString(); + public Builder parseDnsServers(final CharSequence dnsServers) throws ParseException { + try { + final List<InetAddress> parsed = Stream.of(Attribute.split(dnsServers)) + .map(InetAddresses::parse) + .collect(Collectors.toUnmodifiableList()); + return addDnsServers(parsed); + } catch (final IllegalArgumentException e) { + throw new ParseException("DNS", dnsServers, e); + } } - public void setAddresses(final String addresses) { - this.addresses = addresses; - notifyPropertyChanged(BR.addresses); + public Builder parseExcludedApplications(final CharSequence apps) throws ParseException { + try { + return excludeApplications(Lists.of(Attribute.split(apps))); + } catch (final IllegalArgumentException e) { + throw new ParseException("ExcludedApplications", apps, e); + } } - public void setDnses(final String dnses) { - this.dnses = dnses; - notifyPropertyChanged(BR.dnses); + public Builder parseListenPort(final String listenPort) throws ParseException { + try { + return setListenPort(Integer.parseInt(listenPort)); + } catch (final IllegalArgumentException e) { + throw new ParseException("ListenPort", listenPort, e); + } } - public void setExcludedApplications(final String excludedApplications) { - this.excludedApplications = excludedApplications; - notifyPropertyChanged(BR.excludedApplications); - notifyPropertyChanged(BR.excludedApplicationsCount); + public Builder parseMtu(final String mtu) throws ParseException { + try { + return setMtu(Integer.parseInt(mtu)); + } catch (final IllegalArgumentException e) { + throw new ParseException("MTU", mtu, e); + } } - public void setListenPort(final String listenPort) { - this.listenPort = listenPort; - notifyPropertyChanged(BR.listenPort); + public Builder parsePrivateKey(final String privateKey) throws ParseException { + try { + return setKeyPair(new KeyPair(Key.fromBase64(privateKey))); + } catch (final Key.KeyFormatException e) { + throw new ParseException("PrivateKey", "(omitted)", e); + } } - public void setMtu(final String mtu) { - this.mtu = mtu; - notifyPropertyChanged(BR.mtu); + public Builder setKeyPair(final KeyPair keyPair) { + this.keyPair = keyPair; + return this; } - public void setPrivateKey(final String privateKey) { - this.privateKey = privateKey; - - try { - publicKey = new Keypair(privateKey).getPublicKey(); - } catch (final IllegalArgumentException ignored) { - publicKey = ""; - } - - notifyPropertyChanged(BR.privateKey); - notifyPropertyChanged(BR.publicKey); + public Builder setListenPort(final int listenPort) { + if (listenPort < MIN_UDP_PORT || listenPort > MAX_UDP_PORT) + throw new IllegalArgumentException("ListenPort must be a valid UDP port number"); + this.listenPort = listenPort == 0 ? Optional.empty() : Optional.of(listenPort); + return this; } - @Override - public void writeToParcel(final Parcel dest, final int flags) { - dest.writeString(addresses); - dest.writeString(dnses); - dest.writeString(publicKey); - dest.writeString(privateKey); - dest.writeString(listenPort); - dest.writeString(mtu); - dest.writeString(excludedApplications); + public Builder setMtu(final int mtu) { + if (mtu < 0) + throw new IllegalArgumentException("MTU must not be negative"); + this.mtu = mtu == 0 ? Optional.empty() : Optional.of(mtu); + return this; } } } diff --git a/app/src/main/java/com/wireguard/config/ParseException.java b/app/src/main/java/com/wireguard/config/ParseException.java new file mode 100644 index 00000000..1fccb534 --- /dev/null +++ b/app/src/main/java/com/wireguard/config/ParseException.java @@ -0,0 +1,41 @@ +/* + * Copyright © 2018 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.config; + +/** + * An exception representing a failure to parse an element of a WireGuard configuration. The context + * for this failure can be retrieved with {@link #getContext}, and the text that failed to parse can + * be retrieved with {@link #getText}. + */ +public class ParseException extends Exception { + private final String context; + private final CharSequence text; + + public ParseException(final String context, final CharSequence text, final String message) { + super(message); + this.context = context; + this.text = text; + } + + public ParseException(final String context, final CharSequence text, final Throwable cause) { + super(cause.getMessage(), cause); + this.context = context; + this.text = text; + } + + public ParseException(final String context, final CharSequence text) { + this.context = context; + this.text = text; + } + + public String getContext() { + return context; + } + + public CharSequence getText() { + return text; + } +} diff --git a/app/src/main/java/com/wireguard/config/Peer.java b/app/src/main/java/com/wireguard/config/Peer.java index 5cf0283c..50135fb0 100644 --- a/app/src/main/java/com/wireguard/config/Peer.java +++ b/app/src/main/java/com/wireguard/config/Peer.java @@ -5,363 +5,291 @@ package com.wireguard.config; -import android.annotation.SuppressLint; -import android.content.Context; -import android.databinding.BaseObservable; -import android.databinding.Bindable; -import android.os.Parcel; -import android.os.Parcelable; import android.support.annotation.Nullable; -import com.android.databinding.library.baseAdapters.BR; -import com.wireguard.android.Application; -import com.wireguard.android.R; -import com.wireguard.crypto.KeyEncoding; - -import java.net.Inet6Address; -import java.net.InetSocketAddress; -import java.net.URI; -import java.net.URISyntaxException; -import java.net.UnknownHostException; -import java.util.ArrayList; -import java.util.Arrays; +import com.wireguard.crypto.Key; + import java.util.Collection; -import java.util.HashSet; +import java.util.Collections; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Objects; +import java.util.Set; -import java9.lang.Iterables; +import java9.util.Optional; +import java9.util.stream.Collectors; +import java9.util.stream.Stream; /** - * Represents the configuration for a WireGuard peer (a [Peer] block). + * Represents the configuration for a WireGuard peer (a [Peer] block). Peers must have a public key, + * and may optionally have several other attributes. + * <p> + * Instances of this class are immutable. */ - -public class Peer { - private final List<InetNetwork> allowedIPsList; - private final Context context = Application.get(); - @Nullable private InetEndpoint endpoint; - private int persistentKeepalive; - @Nullable private String preSharedKey; - @Nullable private String publicKey; - - public Peer() { - allowedIPsList = new ArrayList<>(); +public final class Peer { + private final Set<InetNetwork> allowedIps; + private final Optional<InetEndpoint> endpoint; + private final Optional<Integer> persistentKeepalive; + private final Optional<Key> preSharedKey; + private final Key publicKey; + + private Peer(final Builder builder) { + // Defensively copy to ensure immutability even if the Builder is reused. + allowedIps = Collections.unmodifiableSet(new LinkedHashSet<>(builder.allowedIps)); + endpoint = builder.endpoint; + persistentKeepalive = builder.persistentKeepalive; + preSharedKey = builder.preSharedKey; + publicKey = Objects.requireNonNull(builder.publicKey, "Peers must have a public key"); } - private void addAllowedIPs(@Nullable final String[] allowedIPs) { - if (allowedIPs != null && allowedIPs.length > 0) { - for (final String allowedIP : allowedIPs) { - allowedIPsList.add(new InetNetwork(allowedIP)); + /** + * Parses an series of "KEY = VALUE" lines into a {@code Peer}. Throws {@link ParseException} if + * the input is not well-formed or contains unknown attributes. + * + * @param lines an iterable sequence of lines, containing at least a public key attribute + * @return a {@code Peer} with all of its attributes set from {@code lines} + */ + public static Peer parse(final Iterable<? extends CharSequence> lines) throws ParseException { + final Builder builder = new Builder(); + for (final CharSequence line : lines) { + final Attribute attribute = Attribute.parse(line) + .orElseThrow(() -> new ParseException("[Peer]", line, "Syntax error")); + switch (attribute.getKey().toLowerCase()) { + case "allowedips": + builder.parseAllowedIPs(attribute.getValue()); + break; + case "endpoint": + builder.parseEndpoint(attribute.getValue()); + break; + case "persistentkeepalive": + builder.parsePersistentKeepalive(attribute.getValue()); + break; + case "presharedkey": + builder.parsePreSharedKey(attribute.getValue()); + break; + case "publickey": + builder.parsePublicKey(attribute.getValue()); + break; + default: + throw new ParseException("[Peer]", line, "Unknown attribute"); } } + return builder.build(); } - public InetNetwork[] getAllowedIPs() { - return allowedIPsList.toArray(new InetNetwork[allowedIPsList.size()]); + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof Peer)) + return false; + final Peer other = (Peer) obj; + return allowedIps.equals(other.allowedIps) + && endpoint.equals(other.endpoint) + && persistentKeepalive.equals(other.persistentKeepalive) + && preSharedKey.equals(other.preSharedKey) + && publicKey.equals(other.publicKey); } - @Nullable - private String getAllowedIPsString() { - if (allowedIPsList.isEmpty()) - return null; - return Attribute.iterableToString(allowedIPsList); + /** + * Returns the peer's set of allowed IPs. + * + * @return the set of allowed IPs + */ + public Set<InetNetwork> getAllowedIps() { + // The collection is already immutable. + return allowedIps; } - @Nullable - public InetEndpoint getEndpoint() { + /** + * Returns the peer's endpoint. + * + * @return the endpoint, or {@code Optional.empty()} if none is configured + */ + public Optional<InetEndpoint> getEndpoint() { return endpoint; } - @Nullable - private String getEndpointString() { - if (endpoint == null) - return null; - return endpoint.getEndpoint(); - } - - public int getPersistentKeepalive() { + /** + * Returns the peer's persistent keepalive. + * + * @return the persistent keepalive, or {@code Optional.empty()} if none is configured + */ + public Optional<Integer> getPersistentKeepalive() { return persistentKeepalive; } - @Nullable - private String getPersistentKeepaliveString() { - if (persistentKeepalive == 0) - return null; - return Integer.valueOf(persistentKeepalive).toString(); - } - - @Nullable - public String getPreSharedKey() { + /** + * Returns the peer's pre-shared key. + * + * @return the pre-shared key, or {@code Optional.empty()} if none is configured + */ + public Optional<Key> getPreSharedKey() { return preSharedKey; } - @Nullable - public String getPublicKey() { + /** + * Returns the peer's public key. + * + * @return the public key + */ + public Key getPublicKey() { return publicKey; } - public String getResolvedEndpointString() throws UnknownHostException { - if (endpoint == null) - throw new UnknownHostException("{empty}"); - return endpoint.getResolvedEndpoint(); - } - - public void parse(final String line) { - final Attribute key = Attribute.match(line); - if (key == null) - throw new IllegalArgumentException(context.getString(R.string.tunnel_error_interface_parse_failed, line)); - switch (key) { - case ALLOWED_IPS: - addAllowedIPs(key.parseList(line)); - break; - case ENDPOINT: - setEndpointString(key.parse(line)); - break; - case PERSISTENT_KEEPALIVE: - setPersistentKeepaliveString(key.parse(line)); - break; - case PRESHARED_KEY: - setPreSharedKey(key.parse(line)); - break; - case PUBLIC_KEY: - setPublicKey(key.parse(line)); - break; - default: - throw new IllegalArgumentException(line); - } - } - - private void setAllowedIPsString(@Nullable final String allowedIPsString) { - allowedIPsList.clear(); - addAllowedIPs(Attribute.stringToList(allowedIPsString)); - } - - private void setEndpoint(@Nullable final InetEndpoint endpoint) { - this.endpoint = endpoint; - } - - private void setEndpointString(@Nullable final String endpoint) { - if (endpoint != null && !endpoint.isEmpty()) - setEndpoint(new InetEndpoint(endpoint)); - else - setEndpoint(null); - } - - private void setPersistentKeepalive(final int persistentKeepalive) { - this.persistentKeepalive = persistentKeepalive; - } - - private void setPersistentKeepaliveString(@Nullable final String persistentKeepalive) { - if (persistentKeepalive != null && !persistentKeepalive.isEmpty()) - setPersistentKeepalive(Integer.parseInt(persistentKeepalive, 10)); - else - setPersistentKeepalive(0); - } - - private void setPreSharedKey(@Nullable String preSharedKey) { - if (preSharedKey != null && preSharedKey.isEmpty()) - preSharedKey = null; - if (preSharedKey != null) - KeyEncoding.keyFromBase64(preSharedKey); - this.preSharedKey = preSharedKey; - } - - private void setPublicKey(@Nullable String publicKey) { - if (publicKey != null && publicKey.isEmpty()) - publicKey = null; - if (publicKey != null) - KeyEncoding.keyFromBase64(publicKey); - this.publicKey = publicKey; + @Override + public int hashCode() { + int hash = 1; + hash = 31 * hash + allowedIps.hashCode(); + hash = 31 * hash + endpoint.hashCode(); + hash = 31 * hash + persistentKeepalive.hashCode(); + hash = 31 * hash + preSharedKey.hashCode(); + hash = 31 * hash + publicKey.hashCode(); + return hash; } + /** + * Converts the {@code Peer} into a string suitable for debugging purposes. The {@code Peer} is + * identified by its public key and (if known) its endpoint. + * + * @return a concise single-line identifier for the {@code Peer} + */ @Override public String toString() { - final StringBuilder sb = new StringBuilder().append("[Peer]\n"); - if (!allowedIPsList.isEmpty()) - sb.append(Attribute.ALLOWED_IPS.composeWith(allowedIPsList)); - if (endpoint != null) - sb.append(Attribute.ENDPOINT.composeWith(getEndpointString())); - if (persistentKeepalive != 0) - sb.append(Attribute.PERSISTENT_KEEPALIVE.composeWith(persistentKeepalive)); - if (preSharedKey != null) - sb.append(Attribute.PRESHARED_KEY.composeWith(preSharedKey)); - if (publicKey != null) - sb.append(Attribute.PUBLIC_KEY.composeWith(publicKey)); + final StringBuilder sb = new StringBuilder("(Peer "); + sb.append(publicKey.toBase64()); + endpoint.ifPresent(ep -> sb.append(" @").append(ep)); + sb.append(')'); return sb.toString(); } - public static class Observable extends BaseObservable implements Parcelable { - public static final Creator<Observable> CREATOR = new Creator<Observable>() { - @Override - public Observable createFromParcel(final Parcel in) { - return new Observable(in); - } - - @Override - public Observable[] newArray(final int size) { - return new Observable[size]; - } - }; - private static final List<String> DEFAULT_ROUTE_MOD_RFC1918_V4 = Arrays.asList("0.0.0.0/5", "8.0.0.0/7", "11.0.0.0/8", "12.0.0.0/6", "16.0.0.0/4", "32.0.0.0/3", "64.0.0.0/2", "128.0.0.0/3", "160.0.0.0/5", "168.0.0.0/6", "172.0.0.0/12", "172.32.0.0/11", "172.64.0.0/10", "172.128.0.0/9", "173.0.0.0/8", "174.0.0.0/7", "176.0.0.0/4", "192.0.0.0/9", "192.128.0.0/11", "192.160.0.0/13", "192.169.0.0/16", "192.170.0.0/15", "192.172.0.0/14", "192.176.0.0/12", "192.192.0.0/10", "193.0.0.0/8", "194.0.0.0/7", "196.0.0.0/6", "200.0.0.0/5", "208.0.0.0/4"); - private static final String DEFAULT_ROUTE_V4 = "0.0.0.0/0"; - private final List<String> interfaceDNSRoutes = new ArrayList<>(); - @Nullable private String allowedIPs; - @Nullable private String endpoint; - private int numSiblings; - @Nullable private String persistentKeepalive; - @Nullable private String preSharedKey; - @Nullable private String publicKey; - - public Observable(final Peer parent) { - loadData(parent); - } - - private Observable(final Parcel in) { - allowedIPs = in.readString(); - endpoint = in.readString(); - persistentKeepalive = in.readString(); - preSharedKey = in.readString(); - publicKey = in.readString(); - numSiblings = in.readInt(); - in.readStringList(interfaceDNSRoutes); - } - - public static Observable newInstance() { - return new Observable(new Peer()); - } - - public void commitData(final Peer parent) { - parent.setAllowedIPsString(allowedIPs); - parent.setEndpointString(endpoint); - parent.setPersistentKeepaliveString(persistentKeepalive); - parent.setPreSharedKey(preSharedKey); - parent.setPublicKey(publicKey); - if (parent.getPublicKey() == null) - throw new IllegalArgumentException(Application.get().getString(R.string.tunnel_error_empty_peer_public_key)); - loadData(parent); - notifyChange(); - } - - @Override - public int describeContents() { - return 0; - } - - @Bindable @Nullable - public String getAllowedIPs() { - return allowedIPs; - } - - @Bindable - public boolean getCanToggleExcludePrivateIPs() { - final Collection<String> ips = Arrays.asList(Attribute.stringToList(allowedIPs)); - return numSiblings == 0 && (ips.contains(DEFAULT_ROUTE_V4) || ips.containsAll(DEFAULT_ROUTE_MOD_RFC1918_V4)); - } + /** + * Converts the {@code Peer} into a string suitable for inclusion in a {@code wg-quick} + * configuration file. + * + * @return the {@code Peer} represented as a series of "Key = Value" lines + */ + public String toWgQuickString() { + final StringBuilder sb = new StringBuilder(); + if (!allowedIps.isEmpty()) + sb.append("AllowedIPs = ").append(Attribute.join(allowedIps)).append('\n'); + endpoint.ifPresent(ep -> sb.append("Endpoint = ").append(ep).append('\n')); + persistentKeepalive.ifPresent(pk -> sb.append("PersistentKeepalive = ").append(pk).append('\n')); + preSharedKey.ifPresent(psk -> sb.append("PreSharedKey = ").append(psk.toBase64()).append('\n')); + sb.append("PublicKey = ").append(publicKey.toBase64()).append('\n'); + return sb.toString(); + } - @Bindable @Nullable - public String getEndpoint() { - return endpoint; - } + /** + * Serializes the {@code Peer} for use with the WireGuard cross-platform userspace API. Note + * that not all attributes are included in this representation. + * + * @return the {@code Peer} represented as a series of "key=value" lines + */ + public String toWgUserspaceString() { + final StringBuilder sb = new StringBuilder(); + // The order here is important: public_key signifies the beginning of a new peer. + sb.append("public_key=").append(publicKey.toHex()).append('\n'); + for (final InetNetwork allowedIp : allowedIps) + sb.append("allowed_ip=").append(allowedIp).append('\n'); + endpoint.flatMap(InetEndpoint::getResolved).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n')); + persistentKeepalive.ifPresent(pk -> sb.append("persistent_keepalive_interval=").append(pk).append('\n')); + preSharedKey.ifPresent(psk -> sb.append("preshared_key=").append(psk.toHex()).append('\n')); + return sb.toString(); + } - @Bindable - public boolean getIsExcludePrivateIPsOn() { - return numSiblings == 0 && Arrays.asList(Attribute.stringToList(allowedIPs)).containsAll(DEFAULT_ROUTE_MOD_RFC1918_V4); + @SuppressWarnings("UnusedReturnValue") + public static final class Builder { + // See wg(8) + private static final int MAX_PERSISTENT_KEEPALIVE = 65535; + + // Defaults to an empty set. + private final Set<InetNetwork> allowedIps = new LinkedHashSet<>(); + // Defaults to not present. + private Optional<InetEndpoint> endpoint = Optional.empty(); + // Defaults to not present. + private Optional<Integer> persistentKeepalive = Optional.empty(); + // Defaults to not present. + private Optional<Key> preSharedKey = Optional.empty(); + // No default; must be provided before building. + @Nullable private Key publicKey; + + public Builder addAllowedIp(final InetNetwork allowedIp) { + allowedIps.add(allowedIp); + return this; } - @Bindable @Nullable - public String getPersistentKeepalive() { - return persistentKeepalive; + public Builder addAllowedIps(final Collection<InetNetwork> allowedIps) { + this.allowedIps.addAll(allowedIps); + return this; } - @Bindable @Nullable - public String getPreSharedKey() { - return preSharedKey; + public Peer build() { + return new Peer(this); } - @Bindable @Nullable - public String getPublicKey() { - return publicKey; + public Builder parseAllowedIPs(final CharSequence allowedIps) throws ParseException { + try { + final List<InetNetwork> parsed = Stream.of(Attribute.split(allowedIps)) + .map(InetNetwork::parse) + .collect(Collectors.toUnmodifiableList()); + return addAllowedIps(parsed); + } catch (final IllegalArgumentException e) { + throw new ParseException("AllowedIPs", allowedIps, e); + } } - private void loadData(final Peer parent) { - allowedIPs = parent.getAllowedIPsString(); - endpoint = parent.getEndpointString(); - persistentKeepalive = parent.getPersistentKeepaliveString(); - preSharedKey = parent.getPreSharedKey(); - publicKey = parent.getPublicKey(); + public Builder parseEndpoint(final String endpoint) throws ParseException { + try { + return setEndpoint(InetEndpoint.parse(endpoint)); + } catch (final IllegalArgumentException e) { + throw new ParseException("Endpoint", endpoint, e); + } } - public void setAllowedIPs(final String allowedIPs) { - this.allowedIPs = allowedIPs; - notifyPropertyChanged(BR.allowedIPs); - notifyPropertyChanged(BR.canToggleExcludePrivateIPs); - notifyPropertyChanged(BR.isExcludePrivateIPsOn); + public Builder parsePersistentKeepalive(final String persistentKeepalive) throws ParseException { + try { + return setPersistentKeepalive(Integer.parseInt(persistentKeepalive)); + } catch (final IllegalArgumentException e) { + throw new ParseException("PersistentKeepalive", persistentKeepalive, e); + } } - public void setEndpoint(final String endpoint) { - this.endpoint = endpoint; - notifyPropertyChanged(BR.endpoint); + public Builder parsePreSharedKey(final String preSharedKey) throws ParseException { + try { + return setPreSharedKey(Key.fromBase64(preSharedKey)); + } catch (final Key.KeyFormatException e) { + throw new ParseException("PresharedKey", preSharedKey, e); + } } - public void setInterfaceDNSRoutes(@Nullable final String dnsServers) { - final Collection<String> ips = new HashSet<>(Arrays.asList(Attribute.stringToList(allowedIPs))); - final boolean modifyAllowedIPs = ips.containsAll(DEFAULT_ROUTE_MOD_RFC1918_V4); - - ips.removeAll(interfaceDNSRoutes); - interfaceDNSRoutes.clear(); - for (final String dnsServer : Attribute.stringToList(dnsServers)) { - if (!dnsServer.contains(":")) - interfaceDNSRoutes.add(dnsServer + "/32"); + public Builder parsePublicKey(final String publicKey) throws ParseException { + try { + return setPublicKey(Key.fromBase64(publicKey)); + } catch (final Key.KeyFormatException e) { + throw new ParseException("PublicKey", publicKey, e); } - ips.addAll(interfaceDNSRoutes); - if (modifyAllowedIPs) - setAllowedIPs(Attribute.iterableToString(ips)); } - public void setNumSiblings(final int num) { - numSiblings = num; - notifyPropertyChanged(BR.canToggleExcludePrivateIPs); - notifyPropertyChanged(BR.isExcludePrivateIPsOn); + public Builder setEndpoint(final InetEndpoint endpoint) { + this.endpoint = Optional.of(endpoint); + return this; } - public void setPersistentKeepalive(final String persistentKeepalive) { - this.persistentKeepalive = persistentKeepalive; - notifyPropertyChanged(BR.persistentKeepalive); + public Builder setPersistentKeepalive(final int persistentKeepalive) { + if (persistentKeepalive < 0 || persistentKeepalive > MAX_PERSISTENT_KEEPALIVE) + throw new IllegalArgumentException("Invalid value for PersistentKeepalive"); + this.persistentKeepalive = persistentKeepalive == 0 ? + Optional.empty() : Optional.of(persistentKeepalive); + return this; } - public void setPreSharedKey(final String preSharedKey) { - this.preSharedKey = preSharedKey; - notifyPropertyChanged(BR.preSharedKey); + public Builder setPreSharedKey(final Key preSharedKey) { + this.preSharedKey = Optional.of(preSharedKey); + return this; } - public void setPublicKey(final String publicKey) { + public Builder setPublicKey(final Key publicKey) { this.publicKey = publicKey; - notifyPropertyChanged(BR.publicKey); - } - - public void toggleExcludePrivateIPs() { - final Collection<String> ips = new HashSet<>(Arrays.asList(Attribute.stringToList(allowedIPs))); - final boolean hasDefaultRoute = ips.contains(DEFAULT_ROUTE_V4); - final boolean hasDefaultRouteModRFC1918 = ips.containsAll(DEFAULT_ROUTE_MOD_RFC1918_V4); - if ((!hasDefaultRoute && !hasDefaultRouteModRFC1918) || numSiblings > 0) - return; - Iterables.removeIf(ips, ip -> !ip.contains(":")); - if (hasDefaultRoute) { - ips.addAll(DEFAULT_ROUTE_MOD_RFC1918_V4); - ips.addAll(interfaceDNSRoutes); - } else if (hasDefaultRouteModRFC1918) - ips.add(DEFAULT_ROUTE_V4); - setAllowedIPs(Attribute.iterableToString(ips)); - } - - @Override - public void writeToParcel(final Parcel dest, final int flags) { - dest.writeString(allowedIPs); - dest.writeString(endpoint); - dest.writeString(persistentKeepalive); - dest.writeString(preSharedKey); - dest.writeString(publicKey); - dest.writeInt(numSiblings); - dest.writeStringList(interfaceDNSRoutes); + return this; } } } diff --git a/app/src/main/java/com/wireguard/crypto/Key.java b/app/src/main/java/com/wireguard/crypto/Key.java new file mode 100644 index 00000000..85146794 --- /dev/null +++ b/app/src/main/java/com/wireguard/crypto/Key.java @@ -0,0 +1,255 @@ +/* + * Copyright © 2017-2018 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +import java.util.Arrays; + +/** + * Represents a WireGuard public or private key. This class uses specialized constant-time base64 + * and hexadecimal codec implementations that resist side-channel attacks. + * <p> + * Instances of this class are immutable. + */ +@SuppressWarnings("MagicNumber") +public final class Key { + private final byte[] key; + + /** + * Constructs an object encapsulating the supplied key. + * + * @param key an array of bytes containing a binary key. Callers of this constructor are + * responsible for ensuring that the array is of the correct length. + */ + private Key(final byte[] key) { + // Defensively copy to ensure immutability. + this.key = Arrays.copyOf(key, key.length); + } + + /** + * Decodes a single 4-character base64 chunk to an integer in constant time. + * + * @param src an array of at least 4 characters in base64 format + * @param srcOffset the offset of the beginning of the chunk in {@code src} + * @return the decoded 3-byte integer, or some arbitrary integer value if the input was not + * valid base64 + */ + private static int decodeBase64(final char[] src, final int srcOffset) { + int val = 0; + for (int i = 0; i < 4; ++i) { + final char c = src[i + srcOffset]; + val |= (-1 + + ((((('A' - 1) - c) & (c - ('Z' + 1))) >>> 8) & (c - 64)) + + ((((('a' - 1) - c) & (c - ('z' + 1))) >>> 8) & (c - 70)) + + ((((('0' - 1) - c) & (c - ('9' + 1))) >>> 8) & (c + 5)) + + ((((('+' - 1) - c) & (c - ('+' + 1))) >>> 8) & 63) + + ((((('/' - 1) - c) & (c - ('/' + 1))) >>> 8) & 64) + ) << (18 - 6 * i); + } + return val; + } + + /** + * Encodes a single 4-character base64 chunk from 3 consecutive bytes in constant time. + * + * @param src an array of at least 3 bytes + * @param srcOffset the offset of the beginning of the chunk in {@code src} + * @param dest an array of at least 4 characters + * @param destOffset the offset of the beginning of the chunk in {@code dest} + */ + private static void encodeBase64(final byte[] src, final int srcOffset, + final char[] dest, final int destOffset) { + final byte[] input = { + (byte) ((src[srcOffset] >>> 2) & 63), + (byte) ((src[srcOffset] << 4 | ((src[1 + srcOffset] & 0xff) >>> 4)) & 63), + (byte) ((src[1 + srcOffset] << 2 | ((src[2 + srcOffset] & 0xff) >>> 6)) & 63), + (byte) ((src[2 + srcOffset]) & 63), + }; + for (int i = 0; i < 4; ++i) { + dest[i + destOffset] = (char) (input[i] + 'A' + + (((25 - input[i]) >>> 8) & 6) + - (((51 - input[i]) >>> 8) & 75) + - (((61 - input[i]) >>> 8) & 15) + + (((62 - input[i]) >>> 8) & 3)); + } + } + + /** + * Decodes a WireGuard public or private key from its base64 string representation. This + * function throws a {@link KeyFormatException} if the source string is not well-formed. + * + * @param str the base64 string representation of a WireGuard key + * @return the decoded key encapsulated in an immutable container + */ + public static Key fromBase64(final String str) { + final char[] input = str.toCharArray(); + if (input.length != Format.BASE64.length || input[Format.BASE64.length - 1] != '=') + throw new KeyFormatException(Format.BASE64); + final byte[] key = new byte[Format.BINARY.length]; + int i; + int ret = 0; + for (i = 0; i < key.length / 3; ++i) { + final int val = decodeBase64(input, i * 4); + ret |= val >>> 31; + key[i * 3] = (byte) ((val >>> 16) & 0xff); + key[i * 3 + 1] = (byte) ((val >>> 8) & 0xff); + key[i * 3 + 2] = (byte) (val & 0xff); + } + final char[] endSegment = { + input[i * 4], + input[i * 4 + 1], + input[i * 4 + 2], + 'A', + }; + final int val = decodeBase64(endSegment, 0); + ret |= (val >>> 31) | (val & 0xff); + key[i * 3] = (byte) ((val >>> 16) & 0xff); + key[i * 3 + 1] = (byte) ((val >>> 8) & 0xff); + + if (ret != 0) + throw new KeyFormatException(Format.BASE64); + return new Key(key); + } + + /** + * Wraps a WireGuard public or private key in an immutable container. This function throws a + * {@link KeyFormatException} if the source data is not the correct length. + * + * @param bytes an array of bytes containing a WireGuard key in binary format + * @return the key encapsulated in an immutable container + */ + public static Key fromBytes(final byte[] bytes) { + if (bytes.length != Format.BINARY.length) + throw new KeyFormatException(Format.BINARY); + return new Key(bytes); + } + + /** + * Decodes a WireGuard public or private key from its hexadecimal string representation. This + * function throws a {@link KeyFormatException} if the source string is not well-formed. + * + * @param str the hexadecimal string representation of a WireGuard key + * @return the decoded key encapsulated in an immutable container + */ + public static Key fromHex(final String str) { + final char[] input = str.toCharArray(); + if (input.length != Format.HEX.length) + throw new KeyFormatException(Format.HEX); + final byte[] key = new byte[Format.BINARY.length]; + int ret = 0; + for (int i = 0; i < key.length; ++i) { + int c; + int cNum; + int cNum0; + int cAlpha; + int cAlpha0; + int cVal; + final int cAcc; + + c = input[i * 2]; + cNum = c ^ 48; + cNum0 = ((cNum - 10) >>> 8) & 0xff; + cAlpha = (c & ~32) - 55; + cAlpha0 = (((cAlpha - 10) ^ (cAlpha - 16)) >>> 8) & 0xff; + ret |= ((cNum0 | cAlpha0) - 1) >>> 8; + cVal = (cNum0 & cNum) | (cAlpha0 & cAlpha); + cAcc = cVal * 16; + + c = input[i * 2 + 1]; + cNum = c ^ 48; + cNum0 = ((cNum - 10) >>> 8) & 0xff; + cAlpha = (c & ~32) - 55; + cAlpha0 = (((cAlpha - 10) ^ (cAlpha - 16)) >>> 8) & 0xff; + ret |= ((cNum0 | cAlpha0) - 1) >>> 8; + cVal = (cNum0 & cNum) | (cAlpha0 & cAlpha); + key[i] = (byte) (cAcc | cVal); + } + if (ret != 0) + throw new KeyFormatException(Format.HEX); + return new Key(key); + } + + /** + * Returns the key as an array of bytes. + * + * @return an array of bytes containing the raw binary key + */ + public byte[] getBytes() { + // Defensively copy to ensure immutability. + return Arrays.copyOf(key, key.length); + } + + /** + * Encodes the key to base64. + * + * @return a string containing the encoded key + */ + public String toBase64() { + final char[] output = new char[Format.BASE64.length]; + int i; + for (i = 0; i < key.length / 3; ++i) + encodeBase64(key, i * 3, output, i * 4); + final byte[] endSegment = { + key[i * 3], + key[i * 3 + 1], + 0, + }; + encodeBase64(endSegment, 0, output, i * 4); + output[Format.BASE64.length - 1] = '='; + return new String(output); + } + + /** + * Encodes the key to hexadecimal ASCII characters. + * + * @return a string containing the encoded key + */ + public String toHex() { + final char[] output = new char[Format.HEX.length]; + for (int i = 0; i < key.length; ++i) { + output[i * 2] = (char) (87 + (key[i] >> 4 & 0xf) + + ((((key[i] >> 4 & 0xf) - 10) >> 8) & ~38)); + output[i * 2 + 1] = (char) (87 + (key[i] & 0xf) + + ((((key[i] & 0xf) - 10) >> 8) & ~38)); + } + return new String(output); + } + + /** + * The supported formats for encoding a WireGuard key. + */ + public enum Format { + BASE64(44), + BINARY(32), + HEX(64); + + private final int length; + + Format(final int length) { + this.length = length; + } + + public int getLength() { + return length; + } + } + + /** + * An exception thrown when attempting to parse an invalid key (too short, too long, or byte + * data inappropriate for the format). The format being parsed can be accessed with the + * {@link #getFormat} method. + */ + public static final class KeyFormatException extends RuntimeException { + private final Format format; + + private KeyFormatException(final Format format) { + this.format = format; + } + + public Format getFormat() { + return format; + } + } +} diff --git a/app/src/main/java/com/wireguard/crypto/KeyEncoding.java b/app/src/main/java/com/wireguard/crypto/KeyEncoding.java deleted file mode 100644 index d29c2d44..00000000 --- a/app/src/main/java/com/wireguard/crypto/KeyEncoding.java +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright © 2017-2018 WireGuard LLC. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package com.wireguard.crypto; - -import com.wireguard.android.Application; -import com.wireguard.android.R; - -/** - * This is a specialized constant-time base64 and hex implementation that resists side-channel attacks. - */ - -@SuppressWarnings("MagicNumber") -public final class KeyEncoding { - public static final int KEY_LENGTH = 32; - public static final int KEY_LENGTH_BASE64 = 44; - public static final int KEY_LENGTH_HEX = 64; - private static final String KEY_LENGTH_BASE64_EXCEPTION_MESSAGE = - Application.get().getString(R.string.key_length_base64_exception_message); - private static final String KEY_LENGTH_EXCEPTION_MESSAGE = - Application.get().getString(R.string.key_length_exception_message); - private static final String KEY_LENGTH_HEX_EXCEPTION_MESSAGE = - Application.get().getString(R.string.key_length_hex_exception_message); - - private KeyEncoding() { - // Prevent instantiation. - } - - private static int decodeBase64(final char[] src, final int srcOffset) { - int val = 0; - for (int i = 0; i < 4; ++i) { - final char c = src[i + srcOffset]; - val |= (-1 - + ((((('A' - 1) - c) & (c - ('Z' + 1))) >>> 8) & (c - 64)) - + ((((('a' - 1) - c) & (c - ('z' + 1))) >>> 8) & (c - 70)) - + ((((('0' - 1) - c) & (c - ('9' + 1))) >>> 8) & (c + 5)) - + ((((('+' - 1) - c) & (c - ('+' + 1))) >>> 8) & 63) - + ((((('/' - 1) - c) & (c - ('/' + 1))) >>> 8) & 64) - ) << (18 - 6 * i); - } - return val; - } - - private static void encodeBase64(final byte[] src, final int srcOffset, - final char[] dest, final int destOffset) { - final byte[] input = { - (byte) ((src[srcOffset] >>> 2) & 63), - (byte) ((src[srcOffset] << 4 | ((src[1 + srcOffset] & 0xff) >>> 4)) & 63), - (byte) ((src[1 + srcOffset] << 2 | ((src[2 + srcOffset] & 0xff) >>> 6)) & 63), - (byte) ((src[2 + srcOffset]) & 63), - }; - for (int i = 0; i < 4; ++i) { - dest[i + destOffset] = (char) (input[i] + 'A' - + (((25 - input[i]) >>> 8) & 6) - - (((51 - input[i]) >>> 8) & 75) - - (((61 - input[i]) >>> 8) & 15) - + (((62 - input[i]) >>> 8) & 3)); - } - } - - public static byte[] keyFromBase64(final String str) { - final char[] input = str.toCharArray(); - final byte[] key = new byte[KEY_LENGTH]; - if (input.length != KEY_LENGTH_BASE64 || input[KEY_LENGTH_BASE64 - 1] != '=') - throw new IllegalArgumentException(KEY_LENGTH_BASE64_EXCEPTION_MESSAGE); - int i; - int ret = 0; - for (i = 0; i < KEY_LENGTH / 3; ++i) { - final int val = decodeBase64(input, i * 4); - ret |= val >>> 31; - key[i * 3] = (byte) ((val >>> 16) & 0xff); - key[i * 3 + 1] = (byte) ((val >>> 8) & 0xff); - key[i * 3 + 2] = (byte) (val & 0xff); - } - final char[] endSegment = { - input[i * 4], - input[i * 4 + 1], - input[i * 4 + 2], - 'A', - }; - final int val = decodeBase64(endSegment, 0); - ret |= (val >>> 31) | (val & 0xff); - key[i * 3] = (byte) ((val >>> 16) & 0xff); - key[i * 3 + 1] = (byte) ((val >>> 8) & 0xff); - - if (ret != 0) - throw new IllegalArgumentException(KEY_LENGTH_BASE64_EXCEPTION_MESSAGE); - return key; - } - - public static byte[] keyFromHex(final String str) { - final char[] input = str.toCharArray(); - final byte[] key = new byte[KEY_LENGTH]; - if (input.length != KEY_LENGTH_HEX) - throw new IllegalArgumentException(KEY_LENGTH_HEX_EXCEPTION_MESSAGE); - int ret = 0; - - for (int i = 0; i < KEY_LENGTH_HEX; i += 2) { - int c; - int cNum; - int cNum0; - int cAlpha; - int cAlpha0; - int cVal; - final int cAcc; - - c = input[i]; - cNum = c ^ 48; - cNum0 = ((cNum - 10) >>> 8) & 0xff; - cAlpha = (c & ~32) - 55; - cAlpha0 = (((cAlpha - 10) ^ (cAlpha - 16)) >>> 8) & 0xff; - ret |= ((cNum0 | cAlpha0) - 1) >>> 8; - cVal = (cNum0 & cNum) | (cAlpha0 & cAlpha); - cAcc = cVal * 16; - - c = input[i + 1]; - cNum = c ^ 48; - cNum0 = ((cNum - 10) >>> 8) & 0xff; - cAlpha = (c & ~32) - 55; - cAlpha0 = (((cAlpha - 10) ^ (cAlpha - 16)) >>> 8) & 0xff; - ret |= ((cNum0 | cAlpha0) - 1) >>> 8; - cVal = (cNum0 & cNum) | (cAlpha0 & cAlpha); - key[i / 2] = (byte) (cAcc | cVal); - } - if (ret != 0) - throw new IllegalArgumentException(KEY_LENGTH_HEX_EXCEPTION_MESSAGE); - return key; - } - - public static String keyToBase64(final byte[] key) { - final char[] output = new char[KEY_LENGTH_BASE64]; - if (key.length != KEY_LENGTH) - throw new IllegalArgumentException(KEY_LENGTH_EXCEPTION_MESSAGE); - int i; - for (i = 0; i < KEY_LENGTH / 3; ++i) - encodeBase64(key, i * 3, output, i * 4); - final byte[] endSegment = { - key[i * 3], - key[i * 3 + 1], - 0, - }; - encodeBase64(endSegment, 0, output, i * 4); - output[KEY_LENGTH_BASE64 - 1] = '='; - return new String(output); - } - - public static String keyToHex(final byte[] key) { - final char[] output = new char[KEY_LENGTH_HEX]; - if (key.length != KEY_LENGTH) - throw new IllegalArgumentException(KEY_LENGTH_EXCEPTION_MESSAGE); - for (int i = 0; i < KEY_LENGTH; ++i) { - output[i * 2] = (char) (87 + (key[i] >> 4 & 0xf) - + ((((key[i] >> 4 & 0xf) - 10) >> 8) & ~38)); - output[i * 2 + 1] = (char) (87 + (key[i] & 0xf) - + ((((key[i] & 0xf) - 10) >> 8) & ~38)); - } - return new String(output); - } -} diff --git a/app/src/main/java/com/wireguard/crypto/KeyPair.java b/app/src/main/java/com/wireguard/crypto/KeyPair.java new file mode 100644 index 00000000..2b2bf564 --- /dev/null +++ b/app/src/main/java/com/wireguard/crypto/KeyPair.java @@ -0,0 +1,81 @@ +/* + * Copyright © 2017-2018 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.crypto; + +import java.security.SecureRandom; + +/** + * Represents a Curve25519 key pair as used by WireGuard. + * <p> + * Instances of this class are immutable. + */ +public class KeyPair { + private final Key privateKey; + private final Key publicKey; + + /** + * Creates a key pair using a newly-generated private key. + */ + public KeyPair() { + this(generatePrivateKey()); + } + + /** + * Creates a key pair using an existing private key. + * + * @param privateKey a private key, used to derive the public key + */ + public KeyPair(final Key privateKey) { + this.privateKey = privateKey; + publicKey = generatePublicKey(privateKey); + } + + /** + * Generates a private key using the system's {@link SecureRandom} number generator. + * + * @return a well-formed random private key + */ + @SuppressWarnings("MagicNumber") + private static Key generatePrivateKey() { + final SecureRandom secureRandom = new SecureRandom(); + final byte[] privateKey = new byte[Key.Format.BINARY.getLength()]; + secureRandom.nextBytes(privateKey); + privateKey[0] &= 248; + privateKey[31] &= 127; + privateKey[31] |= 64; + return Key.fromBytes(privateKey); + } + + /** + * Generates a public key from an existing private key. + * + * @param privateKey a private key + * @return a well-formed public key that corresponds to the supplied private key + */ + private static Key generatePublicKey(final Key privateKey) { + final byte[] publicKey = new byte[Key.Format.BINARY.getLength()]; + Curve25519.eval(publicKey, 0, privateKey.getBytes(), null); + return Key.fromBytes(publicKey); + } + + /** + * Returns the private key from the key pair. + * + * @return the private key + */ + public Key getPrivateKey() { + return privateKey; + } + + /** + * Returns the public key from the key pair. + * + * @return the public key + */ + public Key getPublicKey() { + return publicKey; + } +} diff --git a/app/src/main/java/com/wireguard/crypto/Keypair.java b/app/src/main/java/com/wireguard/crypto/Keypair.java deleted file mode 100644 index 0ee27542..00000000 --- a/app/src/main/java/com/wireguard/crypto/Keypair.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright © 2017-2018 WireGuard LLC. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package com.wireguard.crypto; - -import java.security.SecureRandom; - -/** - * Represents a Curve25519 keypair as used by WireGuard. - */ - -public class Keypair { - private final byte[] privateKey; - private final byte[] publicKey; - - public Keypair() { - this(generatePrivateKey()); - } - - private Keypair(final byte[] privateKey) { - this.privateKey = privateKey; - publicKey = generatePublicKey(privateKey); - } - - public Keypair(final String privateKey) { - this(KeyEncoding.keyFromBase64(privateKey)); - } - - @SuppressWarnings("MagicNumber") - private static byte[] generatePrivateKey() { - final SecureRandom secureRandom = new SecureRandom(); - final byte[] privateKey = new byte[KeyEncoding.KEY_LENGTH]; - secureRandom.nextBytes(privateKey); - privateKey[0] &= 248; - privateKey[31] &= 127; - privateKey[31] |= 64; - return privateKey; - } - - private static byte[] generatePublicKey(final byte[] privateKey) { - final byte[] publicKey = new byte[KeyEncoding.KEY_LENGTH]; - Curve25519.eval(publicKey, 0, privateKey, null); - return publicKey; - } - - public String getPrivateKey() { - return KeyEncoding.keyToBase64(privateKey); - } - - public String getPublicKey() { - return KeyEncoding.keyToBase64(publicKey); - } -} |