summaryrefslogtreecommitdiffhomepage
path: root/ui/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'ui/src/main/java')
-rw-r--r--ui/src/main/java/com/wireguard/android/updater/Updater.kt177
1 files changed, 98 insertions, 79 deletions
diff --git a/ui/src/main/java/com/wireguard/android/updater/Updater.kt b/ui/src/main/java/com/wireguard/android/updater/Updater.kt
index d7050090..d0bd18a1 100644
--- a/ui/src/main/java/com/wireguard/android/updater/Updater.kt
+++ b/ui/src/main/java/com/wireguard/android/updater/Updater.kt
@@ -18,7 +18,9 @@ import androidx.core.content.IntentCompat
import com.wireguard.android.Application
import com.wireguard.android.BuildConfig
import com.wireguard.android.util.UserKnobs
+import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.Job
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
@@ -41,12 +43,16 @@ import kotlin.time.Duration.Companion.seconds
object Updater {
private const val TAG = "WireGuard/Updater"
- private const val LATEST_VERSION_URL = "https://download.wireguard.com/android-client/latest.sig"
+ private const val LATEST_VERSION_URL =
+ "https://download.wireguard.com/android-client/latest.sig"
private const val APK_PATH_URL = "https://download.wireguard.com/android-client/%s"
- private const val APK_NAME_PREFIX = BuildConfig.APPLICATION_ID + "-"
+ private val APK_NAME_PREFIX = BuildConfig.APPLICATION_ID.removeSuffix(".debug") + "-"
private const val APK_NAME_SUFFIX = ".apk"
- private const val RELEASE_PUBLIC_KEY_BASE64 = "RWTAzwGRYr3EC9px0Ia3fbttz8WcVN6wrOwWp2delz4el6SI8XmkKSMp"
- private val CURRENT_VERSION = BuildConfig.VERSION_NAME.removeSuffix("-debug")
+ private const val RELEASE_PUBLIC_KEY_BASE64 =
+ "RWTAzwGRYr3EC9px0Ia3fbttz8WcVN6wrOwWp2delz4el6SI8XmkKSMp"
+ private val CURRENT_VERSION = Version(BuildConfig.VERSION_NAME.removeSuffix("-debug"))
+
+ private val updaterScope = CoroutineScope(Job() + Dispatchers.IO)
sealed class Progress {
object Complete : Progress()
@@ -89,7 +95,7 @@ object Updater {
class Failure(val error: Throwable) : Progress() {
fun retry() {
- Application.getCoroutineScope().launch {
+ updaterScope.launch {
downloadAndUpdateWrapErrors()
}
}
@@ -104,29 +110,61 @@ object Updater {
mutableState.emit(progress)
}
- private fun versionIsNewer(lhs: String, rhs: String): Boolean {
- val lhsParts = lhs.split(".")
- val rhsParts = rhs.split(".")
- if (lhsParts.isEmpty() || rhsParts.isEmpty())
- throw InvalidParameterException("Version is empty")
-
- for (i in 0 until max(lhsParts.size, rhsParts.size)) {
- val lhsPart = if (i < lhsParts.size) lhsParts[i].toULong() else 0UL
- val rhsPart = if (i < rhsParts.size) rhsParts[i].toULong() else 0UL
- if (lhsPart == rhsPart)
- continue
- return lhsPart > rhsPart
+ private class Sha256Digest(hex: String) {
+ val bytes: ByteArray
+
+ init {
+ if (hex.length != 64)
+ throw InvalidParameterException("SHA256 hashes must be 32 bytes long")
+ bytes = hex.chunked(2).map { it.toInt(16).toByte() }.toByteArray()
+ }
+ }
+
+ @OptIn(ExperimentalUnsignedTypes::class)
+ private class Version(version: String) : Comparable<Version> {
+ val parts: ULongArray
+
+ init {
+ val strParts = version.split(".")
+ if (strParts.isEmpty())
+ throw InvalidParameterException("Version has no parts")
+ parts = ULongArray(strParts.size)
+ for (i in parts.indices) {
+ parts[i] = strParts[i].toULong()
+ }
+ }
+
+ override fun toString(): String {
+ return parts.joinToString(".")
+ }
+
+ override fun compareTo(other: Version): Int {
+ for (i in 0 until max(parts.size, other.parts.size)) {
+ val lhsPart = if (i < parts.size) parts[i] else 0UL
+ val rhsPart = if (i < other.parts.size) other.parts[i] else 0UL
+ if (lhsPart > rhsPart)
+ return 1
+ else if (lhsPart < rhsPart)
+ return -1
+ }
+ return 0
}
- return false
}
- private fun versionOfFile(name: String): String? {
+ private class Update(val fileName: String, val version: Version, val hash: Sha256Digest)
+
+ private fun versionOfFile(name: String): Version? {
if (!name.startsWith(APK_NAME_PREFIX) || !name.endsWith(APK_NAME_SUFFIX))
return null
- return name.substring(APK_NAME_PREFIX.length, name.length - APK_NAME_SUFFIX.length)
+ return try {
+ Version(name.substring(APK_NAME_PREFIX.length, name.length - APK_NAME_SUFFIX.length))
+ } catch (_: Throwable) {
+ null
+ }
}
- private fun verifySignedFileList(signifyDigest: String): Map<String, Sha256Digest> {
+ private fun verifySignedFileList(signifyDigest: String): List<Update> {
+ val updates = ArrayList<Update>(1)
val publicKeyBytes = Base64.decode(RELEASE_PUBLIC_KEY_BASE64, Base64.DEFAULT)
if (publicKeyBytes == null || publicKeyBytes.size != 32 + 10 || publicKeyBytes[0] != 'E'.code.toByte() || publicKeyBytes[1] != 'd'.code.toByte())
throw InvalidKeyException("Invalid public key")
@@ -149,32 +187,23 @@ object Updater {
)
)
throw SecurityException("Invalid signature")
- val hashes: MutableMap<String, Sha256Digest> = HashMap()
for (line in lines[2].split("\n").dropLastWhile { it.isEmpty() }) {
val components = line.split(" ", limit = 2)
if (components.size != 2)
throw InvalidParameterException("Invalid file list format: too few components")
- hashes[components[1]] = Sha256Digest(components[0])
- }
- return hashes
- }
-
- private class Sha256Digest(hex: String) {
- val bytes: ByteArray
-
- init {
- if (hex.length != 64)
- throw InvalidParameterException("SHA256 hashes must be 32 bytes long")
- bytes = hex.chunked(2).map { it.toInt(16).toByte() }.toByteArray()
+ /* If version is null, it's not a file we understand, but still a legitimate entry, so don't throw. */
+ val version = versionOfFile(components[1]) ?: continue
+ updates.add(Update(components[1], version, Sha256Digest(components[0])))
}
+ return updates
}
- private fun checkForUpdates(): Pair<String, Sha256Digest> {
+ private fun checkForUpdates(): Update? {
val connection = URL(LATEST_VERSION_URL).openConnection() as HttpURLConnection
connection.setRequestProperty("User-Agent", Application.USER_AGENT)
connection.connect()
if (connection.responseCode != HttpURLConnection.HTTP_OK)
- throw IOException("File list could not be fetched: ${connection.responseCode}")
+ throw IOException(connection.responseMessage)
var fileListBytes = ByteArray(1024 * 512 /* 512 KiB */)
connection.inputStream.use {
val len = it.read(fileListBytes)
@@ -182,26 +211,7 @@ object Updater {
throw IOException("File list is empty")
fileListBytes = fileListBytes.sliceArray(0 until len)
}
- val fileList = verifySignedFileList(fileListBytes.decodeToString())
- if (fileList.isEmpty())
- throw InvalidParameterException("File list is empty")
- var newestFile: String? = null
- var newestVersion: String? = null
- var newestFileHash: Sha256Digest? = null
- for (file in fileList) {
- val fileVersion = versionOfFile(file.key)
- try {
- if (fileVersion != null && (newestVersion == null || versionIsNewer(fileVersion, newestVersion))) {
- newestVersion = fileVersion
- newestFile = file.key
- newestFileHash = file.value
- }
- } catch (_: Throwable) {
- }
- }
- if (newestFile == null || newestFileHash == null)
- throw InvalidParameterException("File list is empty")
- return Pair(newestFile, newestFileHash)
+ return verifySignedFileList(fileListBytes.decodeToString()).maxByOrNull { it.version }
}
private suspend fun downloadAndUpdate() = withContext(Dispatchers.IO) {
@@ -224,14 +234,14 @@ object Updater {
emitProgress(Progress.Rechecking)
val update = checkForUpdates()
- val updateVersion = versionOfFile(checkForUpdates().first) ?: throw Exception("No versions returned")
- if (!versionIsNewer(updateVersion, CURRENT_VERSION)) {
+ if (update == null || update.version <= CURRENT_VERSION) {
emitProgress(Progress.Complete)
return@withContext
}
emitProgress(Progress.Downloading(0UL, 0UL), true)
- val connection = URL(APK_PATH_URL.format(update.first)).openConnection() as HttpURLConnection
+ val connection =
+ URL(APK_PATH_URL.format(update.fileName)).openConnection() as HttpURLConnection
connection.setRequestProperty("User-Agent", Application.USER_AGENT)
connection.connect()
if (connection.responseCode != HttpURLConnection.HTTP_OK)
@@ -246,7 +256,8 @@ object Updater {
emitProgress(Progress.Downloading(downloadedByteLen, totalByteLen), true)
val installer = context.packageManager.packageInstaller
- val params = PackageInstaller.SessionParams(PackageInstaller.SessionParams.MODE_FULL_INSTALL)
+ val params =
+ PackageInstaller.SessionParams(PackageInstaller.SessionParams.MODE_FULL_INSTALL)
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S)
params.setRequireUserAction(PackageInstaller.SessionParams.USER_ACTION_NOT_REQUIRED)
params.setAppPackageName(context.packageName) /* Enforces updates; disallows new apps. */
@@ -275,7 +286,7 @@ object Updater {
}
emitProgress(Progress.Installing)
- if (!digest.digest().contentEquals(update.second.bytes))
+ if (!digest.digest().contentEquals(update.hash.bytes))
throw SecurityException("Update has invalid hash")
sessionFailure = false
} finally {
@@ -305,10 +316,17 @@ object Updater {
return
when (val status =
- intent.getIntExtra(PackageInstaller.EXTRA_STATUS, PackageInstaller.STATUS_FAILURE_INVALID)) {
+ intent.getIntExtra(
+ PackageInstaller.EXTRA_STATUS,
+ PackageInstaller.STATUS_FAILURE_INVALID
+ )) {
PackageInstaller.STATUS_PENDING_USER_ACTION -> {
val id = intent.getIntExtra(PackageInstaller.EXTRA_SESSION_ID, 0)
- val userIntervention = IntentCompat.getParcelableExtra(intent, Intent.EXTRA_INTENT, Intent::class.java)!!
+ val userIntervention = IntentCompat.getParcelableExtra(
+ intent,
+ Intent.EXTRA_INTENT,
+ Intent::class.java
+ )!!
Application.getCoroutineScope().launch {
emitProgress(Progress.NeedsUserIntervention(userIntervention, id))
}
@@ -328,7 +346,8 @@ object Updater {
} catch (_: SecurityException) {
}
val message =
- intent.getStringExtra(PackageInstaller.EXTRA_STATUS_MESSAGE) ?: "Installation error $status"
+ intent.getStringExtra(PackageInstaller.EXTRA_STATUS_MESSAGE)
+ ?: "Installation error $status"
Application.getCoroutineScope().launch {
val e = Exception(message)
Log.e(TAG, "Update failure", e)
@@ -344,21 +363,22 @@ object Updater {
if (installerIsGooglePlay())
return
- Application.getCoroutineScope().launch(Dispatchers.IO) {
- if (UserKnobs.updaterNewerVersionSeen.firstOrNull()?.let { versionIsNewer(it, CURRENT_VERSION) } == true)
+ updaterScope.launch {
+ if (UserKnobs.updaterNewerVersionSeen.firstOrNull()
+ ?.let { Version(it) > CURRENT_VERSION } == true
+ )
return@launch
var waitTime = 15
while (true) {
try {
- val updateVersion = versionOfFile(checkForUpdates().first) ?: throw IllegalStateException("No versions returned")
- if (versionIsNewer(updateVersion, CURRENT_VERSION)) {
- Log.i(TAG, "Update available: $updateVersion")
- UserKnobs.setUpdaterNewerVersionSeen(updateVersion)
+ val update = checkForUpdates() ?: continue
+ if (update.version > CURRENT_VERSION) {
+ Log.i(TAG, "Update available: ${update.version}")
+ UserKnobs.setUpdaterNewerVersionSeen(update.version.toString())
return@launch
}
- } catch (e: Throwable) {
- Log.e(TAG, "Failed to check for updates", e)
+ } catch (_: Throwable) {
}
delay(waitTime.minutes)
waitTime = 45
@@ -366,18 +386,17 @@ object Updater {
}
UserKnobs.updaterNewerVersionSeen.onEach { ver ->
- if (ver != null && versionIsNewer(
- ver,
- CURRENT_VERSION
- ) && UserKnobs.updaterNewerVersionConsented.firstOrNull()
- ?.let { versionIsNewer(it, CURRENT_VERSION) } != true
+ if (ver != null && Version(ver) > CURRENT_VERSION && UserKnobs.updaterNewerVersionConsented.firstOrNull()
+ ?.let { Version(it) > CURRENT_VERSION } != true
)
emitProgress(Progress.Available(ver))
}.launchIn(Application.getCoroutineScope())
UserKnobs.updaterNewerVersionConsented.onEach { ver ->
- if (ver != null && versionIsNewer(ver, CURRENT_VERSION))
- downloadAndUpdateWrapErrors()
+ if (ver != null && Version(ver) > CURRENT_VERSION)
+ updaterScope.launch {
+ downloadAndUpdateWrapErrors()
+ }
}.launchIn(Application.getCoroutineScope())
}