/*
 * Copyright © 2017-2019 WireGuard LLC. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
package com.wireguard.android.model

import android.annotation.SuppressLint
import android.content.BroadcastReceiver
import android.content.Context
import android.content.Intent
import android.os.Build
import android.util.Log
import androidx.databinding.BaseObservable
import androidx.databinding.Bindable
import com.wireguard.android.Application.Companion.get
import com.wireguard.android.Application.Companion.getBackend
import com.wireguard.android.Application.Companion.getSharedPreferences
import com.wireguard.android.Application.Companion.getTunnelManager
import com.wireguard.android.BR
import com.wireguard.android.R
import com.wireguard.android.backend.Statistics
import com.wireguard.android.backend.Tunnel
import com.wireguard.android.configStore.ConfigStore
import com.wireguard.android.databinding.ObservableSortedKeyedArrayList
import com.wireguard.config.Config
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext

/**
 * Maintains and mediates changes to the set of available WireGuard tunnels.
 *
 * The suspending entry points switch to [Dispatchers.Main.immediate] for their bookkeeping and
 * hop to [Dispatchers.IO] around calls into the [ConfigStore] and the backend.
 */
class TunnelManager(private val configStore: ConfigStore) : BaseObservable() {
    /** Completed with [tunnelMap] at the end of [onTunnelsLoaded]. */
    private val tunnels = CompletableDeferred<ObservableSortedKeyedArrayList<String, ObservableTunnel>>()
    private val context: Context = get()

    /** All tunnels currently known to the manager, keyed by tunnel name. */
    private val tunnelMap: ObservableSortedKeyedArrayList<String, ObservableTunnel> = ObservableSortedKeyedArrayList(TunnelComparator)

    /** Set to `true` by [onTunnelsLoaded]; [restoreState] does nothing before that. */
    private var haveLoaded = false

    /**
     * Wraps the given name/config/state in a new [ObservableTunnel], adds it to [tunnelMap] and
     * returns it.
     */
    private fun addToList(name: String, config: Config?, state: Tunnel.State): ObservableTunnel {
        val tunnel = ObservableTunnel(this, name, config, state)
        tunnelMap.add(tunnel)
        return tunnel
    }

    /** Suspends until [onTunnelsLoaded] has completed, then returns the tunnel list. */
    suspend fun getTunnels(): ObservableSortedKeyedArrayList<String, ObservableTunnel> = tunnels.await()

    /**
     * Creates a new tunnel in the config store and registers it, initially [Tunnel.State.DOWN].
     *
     * @param name name of the new tunnel.
     * @param config configuration to store; dereferenced with `!!`, so `null` fails with a
     * [NullPointerException].
     * @return the newly registered tunnel.
     * @throws IllegalArgumentException if [name] is rejected by [Tunnel.isNameInvalid] or a tunnel
     * with that name is already registered.
     */
    suspend fun create(name: String, config: Config?): ObservableTunnel = withContext(Dispatchers.Main.immediate) {
        if (Tunnel.isNameInvalid(name))
            throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
        if (tunnelMap.containsKey(name))
            throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
        addToList(name, withContext(Dispatchers.IO) { configStore.create(name, config!!) }, Tunnel.State.DOWN)
    }

    /**
     * Removes [tunnel] from the manager and deletes its stored configuration.
     *
     * A tunnel that was [Tunnel.State.UP] is brought down first. If deleting the configuration
     * fails, the tunnel is brought back up (when it was up before); on any failure the tunnel is
     * re-added to the manager, [lastUsedTunnel] is restored, and the error is rethrown.
     */
    suspend fun delete(tunnel: ObservableTunnel) = withContext(Dispatchers.Main.immediate) {
        val originalState = tunnel.state
        val wasLastUsed = tunnel == lastUsedTunnel
        // Make sure nothing touches the tunnel.
        if (wasLastUsed)
            lastUsedTunnel = null
        tunnelMap.remove(tunnel)
        try {
            if (originalState == Tunnel.State.UP)
                withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
            try {
                withContext(Dispatchers.IO) { configStore.delete(tunnel.name) }
            } catch (e: Throwable) {
                if (originalState == Tunnel.State.UP)
                    withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
                throw e
            }
        } catch (e: Throwable) {
            // Failure, put the tunnel back.
            tunnelMap.add(tunnel)
            if (wasLastUsed)
                lastUsedTunnel = tunnel
            throw e
        }
    }

    /**
     * The tunnel most recently recorded as used, or `null` if there is none.
     *
     * Assigning a different value notifies data-binding observers and synchronously writes the
     * tunnel name to (or removes it from) shared preferences under [KEY_LAST_USED_TUNNEL].
     */
    @get:Bindable
    @SuppressLint("ApplySharedPref")
    var lastUsedTunnel: ObservableTunnel? = null
        private set(value) {
            if (value == field) return
            field = value
            notifyPropertyChanged(BR.lastUsedTunnel)
            if (value != null)
                getSharedPreferences().edit().putString(KEY_LAST_USED_TUNNEL, value.name).commit()
            else
                getSharedPreferences().edit().remove(KEY_LAST_USED_TUNNEL).commit()
        }

    /**
     * Loads the configuration of [tunnel] from the config store, passes it to
     * [ObservableTunnel.onConfigChanged] and returns the result.
     */
    suspend fun getTunnelConfig(tunnel: ObservableTunnel): Config = withContext(Dispatchers.Main.immediate) {
        tunnel.onConfigChanged(withContext(Dispatchers.IO) { configStore.load(tunnel.name) })!!
    }

    /**
     * Starts loading the tunnel list: enumerates the config store and the backend's running
     * tunnel names, then hands both to [onTunnelsLoaded]. Failures are logged, not propagated.
     */
    fun onCreate() {
        GlobalScope.launch(Dispatchers.Main.immediate) {
            try {
                onTunnelsLoaded(withContext(Dispatchers.IO) { configStore.enumerate() }, withContext(Dispatchers.IO) { getBackend().runningTunnelNames })
            } catch (e: Throwable) {
                Log.e(TAG, Log.getStackTraceString(e))
            }
        }
    }

    /**
     * Populates [tunnelMap] from the loaded names, restores [lastUsedTunnel] from shared
     * preferences, triggers a forced [restoreState] and completes [tunnels].
     *
     * @param present names of all tunnels found in the config store.
     * @param running names of the tunnels the backend reports as running; these are added as
     * [Tunnel.State.UP], all others as [Tunnel.State.DOWN].
     */
    private fun onTunnelsLoaded(present: Iterable<String>, running: Collection<String>) {
        for (name in present)
            addToList(name, null, if (running.contains(name)) Tunnel.State.UP else Tunnel.State.DOWN)
        val lastUsedName = getSharedPreferences().getString(KEY_LAST_USED_TUNNEL, null)
        if (lastUsedName != null)
            lastUsedTunnel = tunnelMap[lastUsedName]
        haveLoaded = true
        restoreState(true)
        tunnels.complete(tunnelMap)
    }

    /**
     * Re-queries the backend for its running tunnel names and pushes the resulting UP/DOWN state
     * to every tunnel in [tunnelMap]. Failures are logged, not propagated.
     */
    private fun refreshTunnelStates() {
        GlobalScope.launch(Dispatchers.Main.immediate) {
            try {
                val running = withContext(Dispatchers.IO) { getBackend().runningTunnelNames }
                for (tunnel in tunnelMap)
                    tunnel.onStateChanged(if (running.contains(tunnel.name)) Tunnel.State.UP else Tunnel.State.DOWN)
            } catch (e: Throwable) {
                Log.e(TAG, Log.getStackTraceString(e))
            }
        }
    }

    /**
     * Brings up every known tunnel whose name is stored under [KEY_RUNNING_TUNNELS].
     *
     * Does nothing until the tunnel list has been loaded, or when the stored set is missing or
     * empty. Failures are logged, not propagated.
     *
     * @param force when `false`, the restore only happens if the [KEY_RESTORE_ON_BOOT] preference
     * is `true`; when `true`, that preference is not consulted.
     */
    fun restoreState(force: Boolean) {
        if (!haveLoaded || (!force && !getSharedPreferences().getBoolean(KEY_RESTORE_ON_BOOT, false)))
            return
        val previouslyRunning = getSharedPreferences().getStringSet(KEY_RUNNING_TUNNELS, null)
                ?: return
        if (previouslyRunning.isEmpty()) return
        GlobalScope.launch(Dispatchers.Main.immediate) {
            withContext(Dispatchers.IO) {
                try {
                    tunnelMap.filter { previouslyRunning.contains(it.name) }.map { async(SupervisorJob()) { setTunnelState(it, Tunnel.State.UP) } }.awaitAll()
                } catch (e: Throwable) {
                    Log.e(TAG, Log.getStackTraceString(e))
                }
            }
        }
    }

    /**
     * Synchronously writes the names of all tunnels currently [Tunnel.State.UP] to shared
     * preferences under [KEY_RUNNING_TUNNELS].
     */
    @SuppressLint("ApplySharedPref")
    fun saveState() {
        getSharedPreferences().edit().putStringSet(KEY_RUNNING_TUNNELS, tunnelMap.filter { it.state == Tunnel.State.UP }.map { it.name }.toSet()).commit()
    }

    /**
     * Applies [config] to [tunnel]: hands it to the backend together with the tunnel's current
     * state, saves it to the config store, and passes the saved result to
     * [ObservableTunnel.onConfigChanged].
     *
     * @return the value returned by [ObservableTunnel.onConfigChanged].
     */
    suspend fun setTunnelConfig(tunnel: ObservableTunnel, config: Config): Config = withContext(Dispatchers.Main.immediate) {
        tunnel.onConfigChanged(withContext(Dispatchers.IO) {
            getBackend().setState(tunnel, tunnel.state, config)
            configStore.save(tunnel.name, config)
        })!!
    }

    /**
     * Renames [tunnel] to [name].
     *
     * The tunnel is taken out of the manager for the duration of the rename, brought down if it
     * was up, renamed in the config store, and brought back up afterwards. Whatever happens, the
     * tunnel is added back to the manager (under whatever name it ends up with) before this
     * function returns or throws.
     *
     * @return the value returned by [ObservableTunnel.onNameChanged].
     * @throws IllegalArgumentException if [name] is rejected by [Tunnel.isNameInvalid] or a tunnel
     * with that name is already registered.
     */
    suspend fun setTunnelName(tunnel: ObservableTunnel, name: String): String = withContext(Dispatchers.Main.immediate) {
        if (Tunnel.isNameInvalid(name))
            throw IllegalArgumentException(context.getString(R.string.tunnel_error_invalid_name))
        if (tunnelMap.containsKey(name)) {
            throw IllegalArgumentException(context.getString(R.string.tunnel_error_already_exists, name))
        }
        val originalState = tunnel.state
        val wasLastUsed = tunnel == lastUsedTunnel
        // Make sure nothing touches the tunnel.
        if (wasLastUsed)
            lastUsedTunnel = null
        tunnelMap.remove(tunnel)
        var throwable: Throwable? = null
        var newName: String? = null
        try {
            if (originalState == Tunnel.State.UP)
                withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.DOWN, null) }
            withContext(Dispatchers.IO) { configStore.rename(tunnel.name, name) }
            newName = tunnel.onNameChanged(name)
            if (originalState == Tunnel.State.UP)
                withContext(Dispatchers.IO) { getBackend().setState(tunnel, Tunnel.State.UP, tunnel.config) }
        } catch (e: Throwable) {
            throwable = e
            // On failure, we don't know what state the tunnel might be in. Fix that.
            // The refresh is best-effort: if it fails too, keep the original error as the one
            // that is reported and still fall through so the tunnel is put back in the manager.
            try {
                getTunnelState(tunnel)
            } catch (refreshFailure: Throwable) {
                e.addSuppressed(refreshFailure)
            }
        }
        // Add the tunnel back to the manager, under whatever name it thinks it has.
        tunnelMap.add(tunnel)
        if (wasLastUsed)
            lastUsedTunnel = tunnel
        if (throwable != null)
            throw throwable
        newName!!
    }

    /**
     * Asks the backend to move [tunnel] into [state].
     *
     * Whether or not the backend call succeeds, the tunnel is notified of the resulting state
     * (its previous state on failure) and [saveState] is called; a backend failure is rethrown
     * afterwards. A tunnel that ends up [Tunnel.State.UP] becomes [lastUsedTunnel].
     *
     * @return the state reported by the backend.
     */
    suspend fun setTunnelState(tunnel: ObservableTunnel, state: Tunnel.State): Tunnel.State = withContext(Dispatchers.Main.immediate) {
        var newState = tunnel.state
        var throwable: Throwable? = null
        try {
            newState = withContext(Dispatchers.IO) { getBackend().setState(tunnel, state, tunnel.getConfigAsync()) }
            if (newState == Tunnel.State.UP)
                lastUsedTunnel = tunnel
        } catch (e: Throwable) {
            throwable = e
        }
        tunnel.onStateChanged(newState)
        saveState()
        if (throwable != null)
            throw throwable
        newState
    }

    /**
     * Broadcast receiver for tunnel control intents.
     *
     * `REFRESH_TUNNEL_STATES` is always honoured. `SET_TUNNEL_UP` / `SET_TUNNEL_DOWN` are only
     * honoured on API level M and above and when the `allow_remote_control_intents` preference is
     * `true`; the target tunnel is named by the `tunnel` string extra, and unknown names are
     * ignored.
     */
    class IntentReceiver : BroadcastReceiver() {
        override fun onReceive(context: Context, intent: Intent?) {
            val manager = getTunnelManager()
            if (intent == null) return
            val action = intent.action ?: return
            if ("com.wireguard.android.action.REFRESH_TUNNEL_STATES" == action) {
                manager.refreshTunnelStates()
                return
            }
            if (Build.VERSION.SDK_INT < Build.VERSION_CODES.M || !getSharedPreferences().getBoolean("allow_remote_control_intents", false))
                return
            val state: Tunnel.State
            state = when (action) {
                "com.wireguard.android.action.SET_TUNNEL_UP" -> Tunnel.State.UP
                "com.wireguard.android.action.SET_TUNNEL_DOWN" -> Tunnel.State.DOWN
                else -> return
            }
            val tunnelName = intent.getStringExtra("tunnel") ?: return
            GlobalScope.launch(Dispatchers.Main.immediate) {
                val tunnels = manager.getTunnels()
                val tunnel = tunnels[tunnelName] ?: return@launch
                manager.setTunnelState(tunnel, state)
            }
        }
    }

    /**
     * Queries the backend for the current state of [tunnel], passes it to
     * [ObservableTunnel.onStateChanged] and returns the result.
     */
    suspend fun getTunnelState(tunnel: ObservableTunnel): Tunnel.State = withContext(Dispatchers.Main.immediate) {
        tunnel.onStateChanged(withContext(Dispatchers.IO) { getBackend().getState(tunnel) })
    }

    /**
     * Queries the backend for the statistics of [tunnel], passes them to
     * [ObservableTunnel.onStatisticsChanged] and returns the result.
     */
    suspend fun getTunnelStatistics(tunnel: ObservableTunnel): Statistics = withContext(Dispatchers.Main.immediate) {
        tunnel.onStatisticsChanged(withContext(Dispatchers.IO) { getBackend().getStatistics(tunnel) })!!
    }

    companion object {
        private const val TAG = "WireGuard/TunnelManager"

        // Shared-preference keys.
        private const val KEY_LAST_USED_TUNNEL = "last_used_tunnel"
        private const val KEY_RESTORE_ON_BOOT = "restore_on_boot"
        private const val KEY_RUNNING_TUNNELS = "enabled_configs"
    }
}