diff options
Diffstat (limited to 'ui')
-rw-r--r-- | ui/src/main/java/com/wireguard/android/updater/Updater.kt | 177 |
1 files changed, 98 insertions, 79 deletions
diff --git a/ui/src/main/java/com/wireguard/android/updater/Updater.kt b/ui/src/main/java/com/wireguard/android/updater/Updater.kt index d7050090..d0bd18a1 100644 --- a/ui/src/main/java/com/wireguard/android/updater/Updater.kt +++ b/ui/src/main/java/com/wireguard/android/updater/Updater.kt @@ -18,7 +18,9 @@ import androidx.core.content.IntentCompat import com.wireguard.android.Application import com.wireguard.android.BuildConfig import com.wireguard.android.util.UserKnobs +import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job import kotlinx.coroutines.delay import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.asStateFlow @@ -41,12 +43,16 @@ import kotlin.time.Duration.Companion.seconds object Updater { private const val TAG = "WireGuard/Updater" - private const val LATEST_VERSION_URL = "https://download.wireguard.com/android-client/latest.sig" + private const val LATEST_VERSION_URL = + "https://download.wireguard.com/android-client/latest.sig" private const val APK_PATH_URL = "https://download.wireguard.com/android-client/%s" - private const val APK_NAME_PREFIX = BuildConfig.APPLICATION_ID + "-" + private val APK_NAME_PREFIX = BuildConfig.APPLICATION_ID.removeSuffix(".debug") + "-" private const val APK_NAME_SUFFIX = ".apk" - private const val RELEASE_PUBLIC_KEY_BASE64 = "RWTAzwGRYr3EC9px0Ia3fbttz8WcVN6wrOwWp2delz4el6SI8XmkKSMp" - private val CURRENT_VERSION = BuildConfig.VERSION_NAME.removeSuffix("-debug") + private const val RELEASE_PUBLIC_KEY_BASE64 = + "RWTAzwGRYr3EC9px0Ia3fbttz8WcVN6wrOwWp2delz4el6SI8XmkKSMp" + private val CURRENT_VERSION = Version(BuildConfig.VERSION_NAME.removeSuffix("-debug")) + + private val updaterScope = CoroutineScope(Job() + Dispatchers.IO) sealed class Progress { object Complete : Progress() @@ -89,7 +95,7 @@ object Updater { class Failure(val error: Throwable) : Progress() { fun retry() { - Application.getCoroutineScope().launch { + updaterScope.launch { downloadAndUpdateWrapErrors() } } @@ -104,29 +110,61 @@ object Updater { mutableState.emit(progress) } - private fun versionIsNewer(lhs: String, rhs: String): Boolean { - val lhsParts = lhs.split(".") - val rhsParts = rhs.split(".") - if (lhsParts.isEmpty() || rhsParts.isEmpty()) - throw InvalidParameterException("Version is empty") - - for (i in 0 until max(lhsParts.size, rhsParts.size)) { - val lhsPart = if (i < lhsParts.size) lhsParts[i].toULong() else 0UL - val rhsPart = if (i < rhsParts.size) rhsParts[i].toULong() else 0UL - if (lhsPart == rhsPart) - continue - return lhsPart > rhsPart + private class Sha256Digest(hex: String) { + val bytes: ByteArray + + init { + if (hex.length != 64) + throw InvalidParameterException("SHA256 hashes must be 32 bytes long") + bytes = hex.chunked(2).map { it.toInt(16).toByte() }.toByteArray() + } + } + + @OptIn(ExperimentalUnsignedTypes::class) + private class Version(version: String) : Comparable<Version> { + val parts: ULongArray + + init { + val strParts = version.split(".") + if (strParts.isEmpty()) + throw InvalidParameterException("Version has no parts") + parts = ULongArray(strParts.size) + for (i in parts.indices) { + parts[i] = strParts[i].toULong() + } + } + + override fun toString(): String { + return parts.joinToString(".") + } + + override fun compareTo(other: Version): Int { + for (i in 0 until max(parts.size, other.parts.size)) { + val lhsPart = if (i < parts.size) parts[i] else 0UL + val rhsPart = if (i < other.parts.size) other.parts[i] else 0UL + if (lhsPart > rhsPart) + return 1 + else if (lhsPart < rhsPart) + return -1 + } + return 0 } - return false } - private fun versionOfFile(name: String): String? { + private class Update(val fileName: String, val version: Version, val hash: Sha256Digest) + + private fun versionOfFile(name: String): Version? { if (!name.startsWith(APK_NAME_PREFIX) || !name.endsWith(APK_NAME_SUFFIX)) return null - return name.substring(APK_NAME_PREFIX.length, name.length - APK_NAME_SUFFIX.length) + return try { + Version(name.substring(APK_NAME_PREFIX.length, name.length - APK_NAME_SUFFIX.length)) + } catch (_: Throwable) { + null + } } - private fun verifySignedFileList(signifyDigest: String): Map<String, Sha256Digest> { + private fun verifySignedFileList(signifyDigest: String): List<Update> { + val updates = ArrayList<Update>(1) val publicKeyBytes = Base64.decode(RELEASE_PUBLIC_KEY_BASE64, Base64.DEFAULT) if (publicKeyBytes == null || publicKeyBytes.size != 32 + 10 || publicKeyBytes[0] != 'E'.code.toByte() || publicKeyBytes[1] != 'd'.code.toByte()) throw InvalidKeyException("Invalid public key") @@ -149,32 +187,23 @@ object Updater { ) ) throw SecurityException("Invalid signature") - val hashes: MutableMap<String, Sha256Digest> = HashMap() for (line in lines[2].split("\n").dropLastWhile { it.isEmpty() }) { val components = line.split(" ", limit = 2) if (components.size != 2) throw InvalidParameterException("Invalid file list format: too few components") - hashes[components[1]] = Sha256Digest(components[0]) - } - return hashes - } - - private class Sha256Digest(hex: String) { - val bytes: ByteArray - - init { - if (hex.length != 64) - throw InvalidParameterException("SHA256 hashes must be 32 bytes long") - bytes = hex.chunked(2).map { it.toInt(16).toByte() }.toByteArray() + /* If version is null, it's not a file we understand, but still a legitimate entry, so don't throw. */ + val version = versionOfFile(components[1]) ?: continue + updates.add(Update(components[1], version, Sha256Digest(components[0]))) } + return updates } - private fun checkForUpdates(): Pair<String, Sha256Digest> { + private fun checkForUpdates(): Update? { val connection = URL(LATEST_VERSION_URL).openConnection() as HttpURLConnection connection.setRequestProperty("User-Agent", Application.USER_AGENT) connection.connect() if (connection.responseCode != HttpURLConnection.HTTP_OK) - throw IOException("File list could not be fetched: ${connection.responseCode}") + throw IOException(connection.responseMessage) var fileListBytes = ByteArray(1024 * 512 /* 512 KiB */) connection.inputStream.use { val len = it.read(fileListBytes) @@ -182,26 +211,7 @@ object Updater { throw IOException("File list is empty") fileListBytes = fileListBytes.sliceArray(0 until len) } - val fileList = verifySignedFileList(fileListBytes.decodeToString()) - if (fileList.isEmpty()) - throw InvalidParameterException("File list is empty") - var newestFile: String? = null - var newestVersion: String? = null - var newestFileHash: Sha256Digest? = null - for (file in fileList) { - val fileVersion = versionOfFile(file.key) - try { - if (fileVersion != null && (newestVersion == null || versionIsNewer(fileVersion, newestVersion))) { - newestVersion = fileVersion - newestFile = file.key - newestFileHash = file.value - } - } catch (_: Throwable) { - } - } - if (newestFile == null || newestFileHash == null) - throw InvalidParameterException("File list is empty") - return Pair(newestFile, newestFileHash) + return verifySignedFileList(fileListBytes.decodeToString()).maxByOrNull { it.version } } private suspend fun downloadAndUpdate() = withContext(Dispatchers.IO) { @@ -224,14 +234,14 @@ object Updater { emitProgress(Progress.Rechecking) val update = checkForUpdates() - val updateVersion = versionOfFile(checkForUpdates().first) ?: throw Exception("No versions returned") - if (!versionIsNewer(updateVersion, CURRENT_VERSION)) { + if (update == null || update.version <= CURRENT_VERSION) { emitProgress(Progress.Complete) return@withContext } emitProgress(Progress.Downloading(0UL, 0UL), true) - val connection = URL(APK_PATH_URL.format(update.first)).openConnection() as HttpURLConnection + val connection = + URL(APK_PATH_URL.format(update.fileName)).openConnection() as HttpURLConnection connection.setRequestProperty("User-Agent", Application.USER_AGENT) connection.connect() if (connection.responseCode != HttpURLConnection.HTTP_OK) @@ -246,7 +256,8 @@ object Updater { emitProgress(Progress.Downloading(downloadedByteLen, totalByteLen), true) val installer = context.packageManager.packageInstaller - val params = PackageInstaller.SessionParams(PackageInstaller.SessionParams.MODE_FULL_INSTALL) + val params = + PackageInstaller.SessionParams(PackageInstaller.SessionParams.MODE_FULL_INSTALL) if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) params.setRequireUserAction(PackageInstaller.SessionParams.USER_ACTION_NOT_REQUIRED) params.setAppPackageName(context.packageName) /* Enforces updates; disallows new apps. */ @@ -275,7 +286,7 @@ object Updater { } emitProgress(Progress.Installing) - if (!digest.digest().contentEquals(update.second.bytes)) + if (!digest.digest().contentEquals(update.hash.bytes)) throw SecurityException("Update has invalid hash") sessionFailure = false } finally { @@ -305,10 +316,17 @@ object Updater { return when (val status = - intent.getIntExtra(PackageInstaller.EXTRA_STATUS, PackageInstaller.STATUS_FAILURE_INVALID)) { + intent.getIntExtra( + PackageInstaller.EXTRA_STATUS, + PackageInstaller.STATUS_FAILURE_INVALID + )) { PackageInstaller.STATUS_PENDING_USER_ACTION -> { val id = intent.getIntExtra(PackageInstaller.EXTRA_SESSION_ID, 0) - val userIntervention = IntentCompat.getParcelableExtra(intent, Intent.EXTRA_INTENT, Intent::class.java)!! + val userIntervention = IntentCompat.getParcelableExtra( + intent, + Intent.EXTRA_INTENT, + Intent::class.java + )!! Application.getCoroutineScope().launch { emitProgress(Progress.NeedsUserIntervention(userIntervention, id)) } @@ -328,7 +346,8 @@ object Updater { } catch (_: SecurityException) { } val message = - intent.getStringExtra(PackageInstaller.EXTRA_STATUS_MESSAGE) ?: "Installation error $status" + intent.getStringExtra(PackageInstaller.EXTRA_STATUS_MESSAGE) + ?: "Installation error $status" Application.getCoroutineScope().launch { val e = Exception(message) Log.e(TAG, "Update failure", e) @@ -344,21 +363,22 @@ object Updater { if (installerIsGooglePlay()) return - Application.getCoroutineScope().launch(Dispatchers.IO) { - if (UserKnobs.updaterNewerVersionSeen.firstOrNull()?.let { versionIsNewer(it, CURRENT_VERSION) } == true) + updaterScope.launch { + if (UserKnobs.updaterNewerVersionSeen.firstOrNull() + ?.let { Version(it) > CURRENT_VERSION } == true + ) return@launch var waitTime = 15 while (true) { try { - val updateVersion = versionOfFile(checkForUpdates().first) ?: throw IllegalStateException("No versions returned") - if (versionIsNewer(updateVersion, CURRENT_VERSION)) { - Log.i(TAG, "Update available: $updateVersion") - UserKnobs.setUpdaterNewerVersionSeen(updateVersion) + val update = checkForUpdates() ?: continue + if (update.version > CURRENT_VERSION) { + Log.i(TAG, "Update available: ${update.version}") + UserKnobs.setUpdaterNewerVersionSeen(update.version.toString()) return@launch } - } catch (e: Throwable) { - Log.e(TAG, "Failed to check for updates", e) + } catch (_: Throwable) { } delay(waitTime.minutes) waitTime = 45 @@ -366,18 +386,17 @@ object Updater { } UserKnobs.updaterNewerVersionSeen.onEach { ver -> - if (ver != null && versionIsNewer( - ver, - CURRENT_VERSION - ) && UserKnobs.updaterNewerVersionConsented.firstOrNull() - ?.let { versionIsNewer(it, CURRENT_VERSION) } != true + if (ver != null && Version(ver) > CURRENT_VERSION && UserKnobs.updaterNewerVersionConsented.firstOrNull() + ?.let { Version(it) > CURRENT_VERSION } != true ) emitProgress(Progress.Available(ver)) }.launchIn(Application.getCoroutineScope()) UserKnobs.updaterNewerVersionConsented.onEach { ver -> - if (ver != null && versionIsNewer(ver, CURRENT_VERSION)) - downloadAndUpdateWrapErrors() + if (ver != null && Version(ver) > CURRENT_VERSION) + updaterScope.launch { + downloadAndUpdateWrapErrors() + } }.launchIn(Application.getCoroutineScope()) } |