diff options
5 files changed, 175 insertions, 267 deletions
diff --git a/ui/src/main/java/com/wireguard/android/activity/TvMainActivity.kt b/ui/src/main/java/com/wireguard/android/activity/TvMainActivity.kt index 0b03d474..4b110a30 100644 --- a/ui/src/main/java/com/wireguard/android/activity/TvMainActivity.kt +++ b/ui/src/main/java/com/wireguard/android/activity/TvMainActivity.kt @@ -5,36 +5,23 @@ package com.wireguard.android.activity -import android.net.Uri import android.os.Bundle -import android.provider.OpenableColumns -import android.util.Log +import android.widget.Toast import androidx.activity.result.contract.ActivityResultContracts import androidx.lifecycle.lifecycleScope import com.google.android.material.button.MaterialButton -import com.google.android.material.snackbar.Snackbar -import com.wireguard.android.Application import com.wireguard.android.R import com.wireguard.android.model.ObservableTunnel -import com.wireguard.android.util.ErrorMessages -import com.wireguard.config.Config -import kotlinx.coroutines.Deferred -import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.async +import com.wireguard.android.util.TunnelImporter import kotlinx.coroutines.launch -import kotlinx.coroutines.withContext -import java.io.BufferedReader -import java.io.InputStreamReader -import java.nio.charset.StandardCharsets -import java.util.ArrayList -import java.util.Locale -import java.util.zip.ZipEntry -import java.util.zip.ZipInputStream class TvMainActivity : BaseActivity() { private val tunnelFileImportResultLauncher = registerForActivityResult(ActivityResultContracts.GetContent()) { data -> - importTunnel(data) + lifecycleScope.launch { + TunnelImporter.importTunnel(contentResolver, data) { + Toast.makeText(this@TvMainActivity, it, Toast.LENGTH_LONG).show() + } + } } override fun onSelectedTunnelChanged(oldTunnel: ObservableTunnel?, newTunnel: ObservableTunnel?) { @@ -48,114 +35,6 @@ class TvMainActivity : BaseActivity() { } } - private fun onTunnelImportFinished(tunnels: List<ObservableTunnel>, throwables: Collection<Throwable>) { - var message = "" - for (throwable in throwables) { - val error = ErrorMessages[throwable] - message = getString(R.string.import_error, error) - Log.e(TAG, message, throwable) - } - if (tunnels.size == 1 && throwables.isEmpty()) - message = getString(R.string.import_success, tunnels[0].name) - else if (tunnels.isEmpty() && throwables.size == 1) - else if (throwables.isEmpty()) - message = resources.getQuantityString(R.plurals.import_total_success, - tunnels.size, tunnels.size) - else if (!throwables.isEmpty()) - message = resources.getQuantityString(R.plurals.import_partial_success, - tunnels.size + throwables.size, - tunnels.size, tunnels.size + throwables.size) - Snackbar.make(findViewById(android.R.id.content), message, Snackbar.LENGTH_LONG).show() - } - - private fun importTunnel(uri: Uri?) { - lifecycleScope.launch { - withContext(Dispatchers.IO) { - if (uri == null) { - return@withContext - } - val futureTunnels = ArrayList<Deferred<ObservableTunnel>>() - val throwables = ArrayList<Throwable>() - try { - val columns = arrayOf(OpenableColumns.DISPLAY_NAME) - var name = "" - contentResolver.query(uri, columns, null, null, null)?.use { cursor -> - if (cursor.moveToFirst() && !cursor.isNull(0)) { - name = cursor.getString(0) - } - } - if (name.isEmpty()) { - name = Uri.decode(uri.lastPathSegment) - } - var idx = name.lastIndexOf('/') - if (idx >= 0) { - require(idx < name.length - 1) { resources.getString(R.string.illegal_filename_error, name) } - name = name.substring(idx + 1) - } - val isZip = name.toLowerCase(Locale.ROOT).endsWith(".zip") - if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) { - name = name.substring(0, name.length - ".conf".length) - } else { - require(isZip) { resources.getString(R.string.bad_extension_error) } - } - - if (isZip) { - ZipInputStream(contentResolver.openInputStream(uri)).use { zip -> - val reader = BufferedReader(InputStreamReader(zip, StandardCharsets.UTF_8)) - var entry: ZipEntry? - while (true) { - entry = zip.nextEntry ?: break - name = entry.name - idx = name.lastIndexOf('/') - if (idx >= 0) { - if (idx >= name.length - 1) { - continue - } - name = name.substring(name.lastIndexOf('/') + 1) - } - if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) { - name = name.substring(0, name.length - ".conf".length) - } else { - continue - } - try { - Config.parse(reader) - } catch (e: Throwable) { - throwables.add(e) - null - }?.let { - val nameCopy = name - futureTunnels.add(async(SupervisorJob()) { Application.getTunnelManager().create(nameCopy, it) }) - } - } - } - } else { - futureTunnels.add(async(SupervisorJob()) { Application.getTunnelManager().create(name, Config.parse(contentResolver.openInputStream(uri)!!)) }) - } - - if (futureTunnels.isEmpty()) { - if (throwables.size == 1) { - throw throwables[0] - } else { - require(throwables.isNotEmpty()) { resources.getString(R.string.no_configs_error) } - } - } - val tunnels = futureTunnels.mapNotNull { - try { - it.await() - } catch (e: Throwable) { - throwables.add(e) - null - } - } - withContext(Dispatchers.Main.immediate) { onTunnelImportFinished(tunnels, throwables) } - } catch (e: Throwable) { - withContext(Dispatchers.Main.immediate) { onTunnelImportFinished(emptyList(), listOf(e)) } - } - } - } - } - companion object { const val TAG = "WireGuard/TvMainActivity" } diff --git a/ui/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.kt b/ui/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.kt index 9b643e5f..ce39fd8f 100644 --- a/ui/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.kt +++ b/ui/src/main/java/com/wireguard/android/fragment/TunnelDetailFragment.kt @@ -17,7 +17,6 @@ import com.wireguard.android.backend.Tunnel import com.wireguard.android.databinding.TunnelDetailFragmentBinding import com.wireguard.android.databinding.TunnelDetailPeerBinding import com.wireguard.android.model.ObservableTunnel -import com.wireguard.android.util.formatBytes import kotlinx.coroutines.delay import kotlinx.coroutines.launch @@ -29,6 +28,16 @@ class TunnelDetailFragment : BaseFragment() { private var lastState = Tunnel.State.TOGGLE private var timerActive = true + private fun formatBytes(bytes: Long): String { + return when { + bytes < 1024 -> getString(R.string.transfer_bytes, bytes) + bytes < 1024 * 1024 -> getString(R.string.transfer_kibibytes, bytes / 1024.0) + bytes < 1024 * 1024 * 1024 -> getString(R.string.transfer_mibibytes, bytes / (1024.0 * 1024.0)) + bytes < 1024 * 1024 * 1024 * 1024L -> getString(R.string.transfer_gibibytes, bytes / (1024.0 * 1024.0 * 1024.0)) + else -> getString(R.string.transfer_tibibytes, bytes / (1024.0 * 1024.0 * 1024.0) / 1024.0) + } + } + override fun onCreate(savedInstanceState: Bundle?) { super.onCreate(savedInstanceState) setHasOptionsMenu(true) @@ -108,7 +117,7 @@ class TunnelDetailFragment : BaseFragment() { peer.transferText.visibility = View.GONE continue } - peer.transferText.text = getString(R.string.transfer_rx_tx, context?.formatBytes(rx), context?.formatBytes(tx)) + peer.transferText.text = getString(R.string.transfer_rx_tx, formatBytes(rx), formatBytes(tx)) peer.transferLabel.visibility = View.VISIBLE peer.transferText.visibility = View.VISIBLE } diff --git a/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt b/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt index eb3d6f78..66b0e8ba 100644 --- a/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt +++ b/ui/src/main/java/com/wireguard/android/fragment/TunnelListFragment.kt @@ -32,6 +32,7 @@ import com.wireguard.android.databinding.TunnelListItemBinding import com.wireguard.android.fragment.ConfigNamingDialogFragment.Companion.newInstance import com.wireguard.android.model.ObservableTunnel import com.wireguard.android.util.ErrorMessages +import com.wireguard.android.util.TunnelImporter import com.wireguard.android.widget.MultiselectableRelativeLayout import com.wireguard.config.Config import kotlinx.coroutines.Deferred @@ -59,114 +60,15 @@ class TunnelListFragment : BaseFragment() { private var actionMode: ActionMode? = null private var binding: TunnelListFragmentBinding? = null private val tunnelFileImportResultLauncher = registerForActivityResult(ActivityResultContracts.GetContent()) { data -> - importTunnel(data) - } - - private val qrImportResultLauncher = registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { result -> - val qrCode = IntentIntegrator.parseActivityResult(result.resultCode, result.data) - qrCode?.contents?.let { importTunnel(it) } - } - - private fun importTunnel(configText: String) { - try { - // Ensure the config text is parseable before proceeding… - Config.parse(ByteArrayInputStream(configText.toByteArray(StandardCharsets.UTF_8))) - - // Config text is valid, now create the tunnel… - newInstance(configText).show(parentFragmentManager, null) - } catch (e: Throwable) { - onTunnelImportFinished(emptyList(), listOf<Throwable>(e)) + lifecycleScope.launch { + val contentResolver = activity?.contentResolver ?: return@launch + TunnelImporter.importTunnel(contentResolver, data) { showSnackbar(it) } } } - private fun importTunnel(uri: Uri?) { - lifecycleScope.launch { - withContext(Dispatchers.IO) { - val activity = activity - if (activity == null || uri == null) { - return@withContext - } - val contentResolver = activity.contentResolver - val futureTunnels = ArrayList<Deferred<ObservableTunnel>>() - val throwables = ArrayList<Throwable>() - try { - val columns = arrayOf(OpenableColumns.DISPLAY_NAME) - var name = "" - contentResolver.query(uri, columns, null, null, null)?.use { cursor -> - if (cursor.moveToFirst() && !cursor.isNull(0)) { - name = cursor.getString(0) - } - } - if (name.isEmpty()) { - name = Uri.decode(uri.lastPathSegment) - } - var idx = name.lastIndexOf('/') - if (idx >= 0) { - require(idx < name.length - 1) { resources.getString(R.string.illegal_filename_error, name) } - name = name.substring(idx + 1) - } - val isZip = name.toLowerCase(Locale.ROOT).endsWith(".zip") - if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) { - name = name.substring(0, name.length - ".conf".length) - } else { - require(isZip) { resources.getString(R.string.bad_extension_error) } - } - - if (isZip) { - ZipInputStream(contentResolver.openInputStream(uri)).use { zip -> - val reader = BufferedReader(InputStreamReader(zip, StandardCharsets.UTF_8)) - var entry: ZipEntry? - while (true) { - entry = zip.nextEntry ?: break - name = entry.name - idx = name.lastIndexOf('/') - if (idx >= 0) { - if (idx >= name.length - 1) { - continue - } - name = name.substring(name.lastIndexOf('/') + 1) - } - if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) { - name = name.substring(0, name.length - ".conf".length) - } else { - continue - } - try { - Config.parse(reader) - } catch (e: Throwable) { - throwables.add(e) - null - }?.let { - val nameCopy = name - futureTunnels.add(async(SupervisorJob()) { Application.getTunnelManager().create(nameCopy, it) }) - } - } - } - } else { - futureTunnels.add(async(SupervisorJob()) { Application.getTunnelManager().create(name, Config.parse(contentResolver.openInputStream(uri)!!)) }) - } - - if (futureTunnels.isEmpty()) { - if (throwables.size == 1) { - throw throwables[0] - } else { - require(throwables.isNotEmpty()) { resources.getString(R.string.no_configs_error) } - } - } - val tunnels = futureTunnels.mapNotNull { - try { - it.await() - } catch (e: Throwable) { - throwables.add(e) - null - } - } - withContext(Dispatchers.Main.immediate) { onTunnelImportFinished(tunnels, throwables) } - } catch (e: Throwable) { - withContext(Dispatchers.Main.immediate) { onTunnelImportFinished(emptyList(), listOf(e)) } - } - } - } + private val qrImportResultLauncher = registerForActivityResult(ActivityResultContracts.StartActivityForResult()) { result -> + val qrCode = IntentIntegrator.parseActivityResult(result.resultCode, result.data)?.contents ?: return@registerForActivityResult + lifecycleScope.launch { TunnelImporter.importTunnel(parentFragmentManager, qrCode) { showSnackbar(it) } } } override fun onViewCreated(view: View, savedInstanceState: Bundle?) { @@ -241,26 +143,6 @@ class TunnelListFragment : BaseFragment() { showSnackbar(message) } - private fun onTunnelImportFinished(tunnels: List<ObservableTunnel>, throwables: Collection<Throwable>) { - var message = "" - for (throwable in throwables) { - val error = ErrorMessages[throwable] - message = getString(R.string.import_error, error) - Log.e(TAG, message, throwable) - } - if (tunnels.size == 1 && throwables.isEmpty()) - message = getString(R.string.import_success, tunnels[0].name) - else if (tunnels.isEmpty() && throwables.size == 1) - else if (throwables.isEmpty()) - message = resources.getQuantityString(R.plurals.import_total_success, - tunnels.size, tunnels.size) - else if (!throwables.isEmpty()) - message = resources.getQuantityString(R.plurals.import_partial_success, - tunnels.size + throwables.size, - tunnels.size, tunnels.size + throwables.size) - showSnackbar(message) - } - override fun onViewStateRestored(savedInstanceState: Bundle?) { super.onViewStateRestored(savedInstanceState) binding ?: return @@ -423,8 +305,6 @@ class TunnelListFragment : BaseFragment() { } companion object { - const val REQUEST_IMPORT = 1 - private const val REQUEST_TARGET_FRAGMENT = 2 private const val CHECKED_ITEMS = "CHECKED_ITEMS" private const val TAG = "WireGuard/TunnelListFragment" } diff --git a/ui/src/main/java/com/wireguard/android/util/Extensions.kt b/ui/src/main/java/com/wireguard/android/util/Extensions.kt index b419feef..62bda7e9 100644 --- a/ui/src/main/java/com/wireguard/android/util/Extensions.kt +++ b/ui/src/main/java/com/wireguard/android/util/Extensions.kt @@ -21,16 +21,6 @@ fun Context.resolveAttribute(@AttrRes attrRes: Int): Int { return typedValue.data } -fun Context.formatBytes(bytes: Long): String { - return when { - bytes < 1024 -> getString(R.string.transfer_bytes, bytes) - bytes < 1024 * 1024 -> getString(R.string.transfer_kibibytes, bytes / 1024.0) - bytes < 1024 * 1024 * 1024 -> getString(R.string.transfer_mibibytes, bytes / (1024.0 * 1024.0)) - bytes < 1024 * 1024 * 1024 * 1024L -> getString(R.string.transfer_gibibytes, bytes / (1024.0 * 1024.0 * 1024.0)) - else -> getString(R.string.transfer_tibibytes, bytes / (1024.0 * 1024.0 * 1024.0) / 1024.0) - } -} - val Any.applicationScope: CoroutineScope get() = Application.getCoroutineScope() diff --git a/ui/src/main/java/com/wireguard/android/util/TunnelImporter.kt b/ui/src/main/java/com/wireguard/android/util/TunnelImporter.kt new file mode 100644 index 00000000..a197bd7b --- /dev/null +++ b/ui/src/main/java/com/wireguard/android/util/TunnelImporter.kt @@ -0,0 +1,150 @@ +/* + * Copyright © 2020 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.android.util + +import android.content.ContentResolver +import android.net.Uri +import android.provider.OpenableColumns +import android.util.Log +import androidx.fragment.app.FragmentManager +import com.wireguard.android.Application +import com.wireguard.android.R +import com.wireguard.android.fragment.ConfigNamingDialogFragment +import com.wireguard.android.model.ObservableTunnel +import com.wireguard.config.Config +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.async +import kotlinx.coroutines.withContext +import java.io.BufferedReader +import java.io.ByteArrayInputStream +import java.io.InputStreamReader +import java.nio.charset.StandardCharsets +import java.util.ArrayList +import java.util.Locale +import java.util.zip.ZipEntry +import java.util.zip.ZipInputStream + +object TunnelImporter { + suspend fun importTunnel(contentResolver: ContentResolver, uri: Uri, messageCallback: (CharSequence) -> Unit) = withContext(Dispatchers.IO) { + val context = Application.get().applicationContext + val futureTunnels = ArrayList<Deferred<ObservableTunnel>>() + val throwables = ArrayList<Throwable>() + try { + val columns = arrayOf(OpenableColumns.DISPLAY_NAME) + var name = "" + contentResolver.query(uri, columns, null, null, null)?.use { cursor -> + if (cursor.moveToFirst() && !cursor.isNull(0)) { + name = cursor.getString(0) + } + } + if (name.isEmpty()) { + name = Uri.decode(uri.lastPathSegment) + } + var idx = name.lastIndexOf('/') + if (idx >= 0) { + require(idx < name.length - 1) { context.getString(R.string.illegal_filename_error, name) } + name = name.substring(idx + 1) + } + val isZip = name.toLowerCase(Locale.ROOT).endsWith(".zip") + if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) { + name = name.substring(0, name.length - ".conf".length) + } else { + require(isZip) { context.getString(R.string.bad_extension_error) } + } + + if (isZip) { + ZipInputStream(contentResolver.openInputStream(uri)).use { zip -> + val reader = BufferedReader(InputStreamReader(zip, StandardCharsets.UTF_8)) + var entry: ZipEntry? + while (true) { + entry = zip.nextEntry ?: break + name = entry.name + idx = name.lastIndexOf('/') + if (idx >= 0) { + if (idx >= name.length - 1) { + continue + } + name = name.substring(name.lastIndexOf('/') + 1) + } + if (name.toLowerCase(Locale.ROOT).endsWith(".conf")) { + name = name.substring(0, name.length - ".conf".length) + } else { + continue + } + try { + Config.parse(reader) + } catch (e: Throwable) { + throwables.add(e) + null + }?.let { + val nameCopy = name + futureTunnels.add(async(SupervisorJob()) { Application.getTunnelManager().create(nameCopy, it) }) + } + } + } + } else { + futureTunnels.add(async(SupervisorJob()) { Application.getTunnelManager().create(name, Config.parse(contentResolver.openInputStream(uri)!!)) }) + } + + if (futureTunnels.isEmpty()) { + if (throwables.size == 1) { + throw throwables[0] + } else { + require(throwables.isNotEmpty()) { context.getString(R.string.no_configs_error) } + } + } + val tunnels = futureTunnels.mapNotNull { + try { + it.await() + } catch (e: Throwable) { + throwables.add(e) + null + } + } + withContext(Dispatchers.Main.immediate) { onTunnelImportFinished(tunnels, throwables, messageCallback) } + } catch (e: Throwable) { + withContext(Dispatchers.Main.immediate) { onTunnelImportFinished(emptyList(), listOf(e), messageCallback) } + } + } + + fun importTunnel(parentFragmentManager: FragmentManager, configText: String, messageCallback: (CharSequence) -> Unit) { + try { + // Ensure the config text is parseable before proceeding… + Config.parse(ByteArrayInputStream(configText.toByteArray(StandardCharsets.UTF_8))) + + // Config text is valid, now create the tunnel… + ConfigNamingDialogFragment.newInstance(configText).show(parentFragmentManager, null) + } catch (e: Throwable) { + onTunnelImportFinished(emptyList(), listOf<Throwable>(e), messageCallback) + } + } + + private fun onTunnelImportFinished(tunnels: List<ObservableTunnel>, throwables: Collection<Throwable>, messageCallback: (CharSequence) -> Unit) { + val context = Application.get().applicationContext + var message = "" + for (throwable in throwables) { + val error = ErrorMessages[throwable] + message = context.getString(R.string.import_error, error) + Log.e(TAG, message, throwable) + } + if (tunnels.size == 1 && throwables.isEmpty()) + message = context.getString(R.string.import_success, tunnels[0].name) + else if (tunnels.isEmpty() && throwables.size == 1) + else if (throwables.isEmpty()) + message = context.resources.getQuantityString(R.plurals.import_total_success, + tunnels.size, tunnels.size) + else if (!throwables.isEmpty()) + message = context.resources.getQuantityString(R.plurals.import_partial_success, + tunnels.size + throwables.size, + tunnels.size, tunnels.size + throwables.size) + + messageCallback(message) + } + + private const val TAG = "WireGuard/TunnelImporter" +}
\ No newline at end of file |