diff options
Diffstat (limited to 'tunnel')
-rw-r--r-- | tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java | 52 | ||||
-rw-r--r-- | tunnel/src/main/java/com/wireguard/config/Config.java | 12 | ||||
-rw-r--r-- | tunnel/src/main/java/com/wireguard/config/InetEndpoint.java | 25 | ||||
-rw-r--r-- | tunnel/src/main/java/com/wireguard/config/Peer.java | 13 | ||||
-rw-r--r-- | tunnel/src/main/java/com/wireguard/util/Resolver.java | 136 | ||||
-rw-r--r-- | tunnel/tools/libwg-go/api-android.go | 13 | ||||
-rw-r--r-- | tunnel/tools/libwg-go/jni.c | 13 |
7 files changed, 244 insertions, 20 deletions
diff --git a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java index 242f81d8..19f328f2 100644 --- a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java +++ b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java @@ -7,6 +7,11 @@ package com.wireguard.android.backend; import android.content.Context; import android.content.Intent; +import android.net.ConnectivityManager; +import android.net.LinkProperties; +import android.net.Network; +import android.net.NetworkCapabilities; +import android.net.NetworkRequest; import android.net.ProxyInfo; import android.os.Build; import android.os.ParcelFileDescriptor; @@ -23,6 +28,7 @@ import com.wireguard.config.Peer; import com.wireguard.crypto.Key; import com.wireguard.crypto.KeyFormatException; import com.wireguard.util.NonNullForAll; +import com.wireguard.util.Resolver; import java.net.InetAddress; import java.net.URL; @@ -52,6 +58,9 @@ public final class GoBackend implements Backend { @Nullable private Config currentConfig; @Nullable private Tunnel currentTunnel; private int currentTunnelHandle = -1; + private ConnectivityManager connectivityManager; + private ConnectivityManager.NetworkCallback myNetworkCallback = new MyNetworkCallback(); + @Nullable private Network activeNetwork; /** * Public constructor for GoBackend. @@ -61,6 +70,7 @@ public final class GoBackend implements Backend { public GoBackend(final Context context) { SharedLibraryLoader.loadSharedLibrary(context, "wg-go"); this.context = context; + connectivityManager = context.getSystemService(ConnectivityManager.class); } /** @@ -79,6 +89,8 @@ public final class GoBackend implements Backend { private static native int wgGetSocketV6(int handle); + private static native int wgSetConfig(int handle, String settings); + private static native void wgTurnOff(int handle); private static native int wgTurnOn(String ifName, int tunFd, String settings); @@ -258,13 +270,19 @@ public final class GoBackend implements Backend { } + activeNetwork = connectivityManager.getActiveNetwork(); + if (!connectivityManager.getNetworkCapabilities(activeNetwork).hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)) { + Log.w(TAG, "VPN network is active, null activeNetwork"); + activeNetwork = null; + } + final Resolver resolver = new Resolver(activeNetwork, connectivityManager.getLinkProperties(activeNetwork)); dnsRetry: for (int i = 0; i < DNS_RESOLUTION_RETRIES; ++i) { // Pre-resolve IPs so they're cached when building the userspace string for (final Peer peer : config.getPeers()) { final InetEndpoint ep = peer.getEndpoint().orElse(null); if (ep == null) continue; - if (ep.getResolved().orElse(null) == null) { + if (ep.getResolved(resolver, true).orElse(null) == null) { if (i < DNS_RESOLUTION_RETRIES - 1) { Log.w(TAG, "DNS host \"" + ep.getHost() + "\" failed to resolve; trying again"); Thread.sleep(1000); @@ -277,7 +295,7 @@ public final class GoBackend implements Backend { } // Build config - final String goConfig = config.toWgUserspaceString(); + final String goConfig = config.toWgUserspaceString(resolver); // Create the vpn tunnel with android API final VpnService.Builder builder = service.getBuilder(); @@ -339,6 +357,9 @@ public final class GoBackend implements Backend { service.protect(wgGetSocketV4(currentTunnelHandle)); service.protect(wgGetSocketV6(currentTunnelHandle)); + + NetworkRequest req = new NetworkRequest.Builder().addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN).build(); + connectivityManager.requestNetwork(req, myNetworkCallback); } else { if (currentTunnelHandle == -1) { Log.w(TAG, "Tunnel already down"); @@ -348,6 +369,8 @@ public final class GoBackend implements Backend { currentTunnel = null; currentTunnelHandle = -1; currentConfig = null; + connectivityManager.unregisterNetworkCallback(myNetworkCallback); + activeNetwork = null; wgTurnOff(handleToClose); try { vpnService.get(0, TimeUnit.NANOSECONDS).stopSelf(); @@ -415,8 +438,11 @@ public final class GoBackend implements Backend { if (owner != null) { final Tunnel tunnel = owner.currentTunnel; if (tunnel != null) { - if (owner.currentTunnelHandle != -1) + if (owner.currentTunnelHandle != -1) { + owner.connectivityManager.unregisterNetworkCallback(owner.myNetworkCallback); + owner.activeNetwork = null; wgTurnOff(owner.currentTunnelHandle); + } owner.currentTunnel = null; owner.currentTunnelHandle = -1; owner.currentConfig = null; @@ -442,4 +468,24 @@ public final class GoBackend implements Backend { this.owner = owner; } } + + private class MyNetworkCallback extends ConnectivityManager.NetworkCallback { + @Override + public void onAvailable(Network network) { + activeNetwork = network; + Log.w(TAG, "onAvailable: " + activeNetwork); + } + + @Override + public void onLinkPropertiesChanged(Network network, LinkProperties linkProperties) { + Log.w(TAG, "onLinkPropertiesChanged: " + network + " is default:" + (network.equals(activeNetwork))); + if (network.equals(activeNetwork) && currentConfig != null && currentTunnelHandle > -1) { + final Resolver resolver = new Resolver(network, linkProperties); + final String goConfig = currentConfig.toWgEndpointsUserspaceString(resolver); + Log.w(TAG, "is default network, config:" + goConfig); + + wgSetConfig(currentTunnelHandle, goConfig); + } + } + } } diff --git a/tunnel/src/main/java/com/wireguard/config/Config.java b/tunnel/src/main/java/com/wireguard/config/Config.java index ee9cebce..12ddb242 100644 --- a/tunnel/src/main/java/com/wireguard/config/Config.java +++ b/tunnel/src/main/java/com/wireguard/config/Config.java @@ -9,6 +9,7 @@ import com.wireguard.config.BadConfigException.Location; import com.wireguard.config.BadConfigException.Reason; import com.wireguard.config.BadConfigException.Section; import com.wireguard.util.NonNullForAll; +import com.wireguard.util.Resolver; import java.io.BufferedReader; import java.io.IOException; @@ -173,12 +174,19 @@ public final class Config { * * @return the {@code Config} represented as a series of "key=value" lines */ - public String toWgUserspaceString() { + public String toWgUserspaceString(Resolver resolver) { final StringBuilder sb = new StringBuilder(); sb.append(interfaze.toWgUserspaceString()); sb.append("replace_peers=true\n"); for (final Peer peer : peers) - sb.append(peer.toWgUserspaceString()); + sb.append(peer.toWgUserspaceString(resolver)); + return sb.toString(); + } + + public String toWgEndpointsUserspaceString(Resolver resolver) { + final StringBuilder sb = new StringBuilder(); + for (final Peer peer : peers) + sb.append(peer.toWgEndpointsUserspaceString(resolver)); return sb.toString(); } diff --git a/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java b/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java index d1db432b..e1e9d653 100644 --- a/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java +++ b/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java @@ -6,6 +6,7 @@ package com.wireguard.config; import com.wireguard.util.NonNullForAll; +import com.wireguard.util.Resolver; import java.net.Inet4Address; import java.net.InetAddress; @@ -87,24 +88,22 @@ public final class InetEndpoint { * * @return the resolved endpoint, or {@link Optional#empty()} */ - public Optional<InetEndpoint> getResolved() { - if (isResolved) + public Optional<InetEndpoint> getResolved(Resolver resolver) { + return getResolved(resolver, false); + } + + public Optional<InetEndpoint> getResolved(Resolver resolver, Boolean force) { + if (!force && isResolved) return Optional.of(this); synchronized (lock) { //TODO(zx2c4): Implement a real timeout mechanism using DNS TTL - if (Duration.between(lastResolution, Instant.now()).toMinutes() > 1) { + if (force || Duration.between(lastResolution, Instant.now()).toMinutes() > 1) { try { - // Prefer v4 endpoints over v6 to work around DNS64 and IPv6 NAT issues. - final InetAddress[] candidates = InetAddress.getAllByName(host); - InetAddress address = candidates[0]; - for (final InetAddress candidate : candidates) { - if (candidate instanceof Inet4Address) { - address = candidate; - break; - } - } - resolved = new InetEndpoint(address.getHostAddress(), true, port); + InetAddress address = resolver.resolve(host); + InetEndpoint resolvedNow = new InetEndpoint(address.getHostAddress(), true, port); lastResolution = Instant.now(); + + resolved = resolvedNow; } catch (final UnknownHostException e) { resolved = null; } diff --git a/tunnel/src/main/java/com/wireguard/config/Peer.java b/tunnel/src/main/java/com/wireguard/config/Peer.java index 8a0fd763..1d20a961 100644 --- a/tunnel/src/main/java/com/wireguard/config/Peer.java +++ b/tunnel/src/main/java/com/wireguard/config/Peer.java @@ -11,6 +11,7 @@ import com.wireguard.config.BadConfigException.Section; import com.wireguard.crypto.Key; import com.wireguard.crypto.KeyFormatException; import com.wireguard.util.NonNullForAll; +import com.wireguard.util.Resolver; import java.util.Collection; import java.util.Collections; @@ -190,18 +191,26 @@ public final class Peer { * * @return the {@code Peer} represented as a series of "key=value" lines */ - public String toWgUserspaceString() { + public String toWgUserspaceString(Resolver resolver) { final StringBuilder sb = new StringBuilder(); // The order here is important: public_key signifies the beginning of a new peer. sb.append("public_key=").append(publicKey.toHex()).append('\n'); for (final InetNetwork allowedIp : allowedIps) sb.append("allowed_ip=").append(allowedIp).append('\n'); - endpoint.flatMap(InetEndpoint::getResolved).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n')); + endpoint.flatMap(ep -> ep.getResolved(resolver)).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n')); persistentKeepalive.ifPresent(pk -> sb.append("persistent_keepalive_interval=").append(pk).append('\n')); preSharedKey.ifPresent(psk -> sb.append("preshared_key=").append(psk.toHex()).append('\n')); return sb.toString(); } + public String toWgEndpointsUserspaceString(Resolver resolver) { + final StringBuilder sb = new StringBuilder(); + // The order here is important: public_key signifies the beginning of a new peer. + sb.append("public_key=").append(publicKey.toHex()).append('\n'); + endpoint.flatMap(ep -> ep.getResolved(resolver, true)).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n')); + return sb.toString(); + } + @SuppressWarnings("UnusedReturnValue") public static final class Builder { // See wg(8) diff --git a/tunnel/src/main/java/com/wireguard/util/Resolver.java b/tunnel/src/main/java/com/wireguard/util/Resolver.java new file mode 100644 index 00000000..04e6b9a8 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/util/Resolver.java @@ -0,0 +1,136 @@ +/* + * Copyright © 2023 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.util; + +import java.io.IOException; +import java.net.DatagramSocket; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketException; +import java.net.UnknownHostException; + +import android.net.IpPrefix; +import android.net.LinkProperties; +import android.net.Network; +import android.util.Log; + +import androidx.annotation.Nullable; + +@NonNullForAll +public class Resolver { + private static final String TAG = "WireGuard/Resolver"; + @Nullable private final Network network; + @Nullable private final LinkProperties linkProps; + @Nullable private IpPrefix nat64Prefix; + + public Resolver(Network network, LinkProperties linkProps) { + this.network = network; + this.linkProps = linkProps; + if (linkProps != null) { + this.nat64Prefix = linkProps.getNat64Prefix(); + } + } + + static boolean isULA(Inet6Address addr) { + byte[] raw = addr.getAddress(); + return ((raw[0] & 0xfe) == 0xfc); + } + + boolean isWithinNAT64Prefix(Inet6Address address) { + if (nat64Prefix == null) + return false; + + int prefixLength = nat64Prefix.getPrefixLength(); + byte[] rawAddr = address.getAddress(); + byte[] rawPrefix = nat64Prefix.getRawAddress(); + + for (int i=0; i < prefixLength/8; i++) { + if (rawAddr[i] != rawPrefix[i]) + return false; + } + + return true; + } + + boolean isPreferredIPv6(Inet6Address local, Inet6Address remote) { + if (linkProps == null) { + // Prefer IPv4 if there are not link properties that can + // be tested. + return false; + } + + // * Prefer IPv4 if local or remote address is ULA + // * Prefer IPv4 if remote IPv6 is within NAT64 prefix. + // * Otherwise prefer IPv6 + boolean isLocalULA = isULA(local); + boolean isRemoteULA = isULA(remote); + + if (isLocalULA || isRemoteULA) { + return false; + } + + if (isWithinNAT64Prefix(remote)) { + return false; + } + + return true; + } + + public InetAddress resolve(String host) throws UnknownHostException { + final InetAddress[] candidates = network != null ? network.getAllByName(host) : InetAddress.getAllByName(host); + InetAddress address = candidates[0]; + for (final InetAddress candidate : candidates) { + DatagramSocket sock; + + try { + sock = new DatagramSocket(); + if (network != null) { + network.bindSocket(sock); + } + } catch (SocketException e) { + // Return first candidate as fallback + Log.w(TAG, "DatagramSocket failed, fallback to: \"" + address); + return address; + } catch (IOException e) { + // Return first candidate as fallback + Log.w(TAG, "BindSocket failed, fallback to: \"" + address); + return address; + } + + sock.connect(candidate, 51820); + + if (sock.getLocalAddress().isAnyLocalAddress()) { + // Connect didn't find a local address. + Log.w(TAG, "No local address"); + sock.close(); + continue; + } + + Log.w(TAG, "Local address: " + sock.getLocalAddress()); + + if (candidate instanceof Inet4Address) { + // Accept IPv4 as preferred address. + address = candidate; + sock.close(); + break; + } + + Inet6Address local = (Inet6Address)sock.getLocalAddress(); + InetSocketAddress remoteSockAddr = (InetSocketAddress)sock.getRemoteSocketAddress(); + Inet6Address remote = (Inet6Address)remoteSockAddr.getAddress(); + sock.close(); + + if (isPreferredIPv6(local, remote)) { + address = candidate; + break; + } + } + Log.w(TAG, "Resolved \"" + host + "\" to: " + address); + return address; + } +} diff --git a/tunnel/tools/libwg-go/api-android.go b/tunnel/tools/libwg-go/api-android.go index d47c5d76..d4e36480 100644 --- a/tunnel/tools/libwg-go/api-android.go +++ b/tunnel/tools/libwg-go/api-android.go @@ -224,4 +224,17 @@ func wgVersion() *C.char { return C.CString("unknown") } +//export wgSetConfig +func wgSetConfig(tunnelHandle int32, settings string) int32 { + handle, ok := tunnelHandles[tunnelHandle] + if !ok { + return -1 + } + err := handle.device.IpcSet(settings) + if err != nil { + return -1 + } + return 0 +} + func main() {} diff --git a/tunnel/tools/libwg-go/jni.c b/tunnel/tools/libwg-go/jni.c index 7ad94d35..e47b0be5 100644 --- a/tunnel/tools/libwg-go/jni.c +++ b/tunnel/tools/libwg-go/jni.c @@ -14,6 +14,7 @@ extern int wgGetSocketV4(int handle); extern int wgGetSocketV6(int handle); extern char *wgGetConfig(int handle); extern char *wgVersion(); +extern int wgSetConfig(int handle, struct go_string settings); JNIEXPORT jint JNICALL Java_com_wireguard_android_backend_GoBackend_wgTurnOn(JNIEnv *env, jclass c, jstring ifname, jint tun_fd, jstring settings) { @@ -69,3 +70,15 @@ JNIEXPORT jstring JNICALL Java_com_wireguard_android_backend_GoBackend_wgVersion free(version); return ret; } + +JNIEXPORT jint JNICALL Java_com_wireguard_android_backend_GoBackend_wgSetConfig(JNIEnv *env, jclass c, jint handle, jstring settings) +{ + const char *settings_str = (*env)->GetStringUTFChars(env, settings, 0); + size_t settings_len = (*env)->GetStringUTFLength(env, settings); + int ret = wgSetConfig(handle, (struct go_string){ + .str = settings_str, + .n = settings_len + }); + (*env)->ReleaseStringUTFChars(env, settings, settings_str); + return ret; +} |