diff options
author | Mikael Magnusson <mikma@users.sourceforge.net> | 2022-12-29 00:23:17 +0100 |
---|---|---|
committer | Mikael Magnusson <mikma@users.sourceforge.net> | 2023-02-09 21:22:14 +0100 |
commit | 2eccfc91f3bda1158d056e3cb5e5868afa787731 (patch) | |
tree | 7aca22c01b7592b0b2233cc535565a196000fefd /tunnel | |
parent | da0f9feb33bbc97140963f575e4538d846f3048b (diff) |
tunnel: auto-detect IPv6/IPv4 preference
Detect IP address change.
Request non-VPN network.
Update endpoint when needed.
Unregister network on wgTurnOff and use IPv4 if network is not known.
Diffstat (limited to 'tunnel')
-rw-r--r-- | tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java | 53 | ||||
-rw-r--r-- | tunnel/src/main/java/com/wireguard/config/Config.java | 12 | ||||
-rw-r--r-- | tunnel/src/main/java/com/wireguard/config/InetEndpoint.java | 37 | ||||
-rw-r--r-- | tunnel/src/main/java/com/wireguard/config/Peer.java | 13 | ||||
-rw-r--r-- | tunnel/src/main/java/com/wireguard/util/Resolver.java | 134 | ||||
-rw-r--r-- | tunnel/src/main/proto/libwg.proto | 10 | ||||
-rw-r--r-- | tunnel/tools/libwg-go/api-android.go | 5 | ||||
-rw-r--r-- | tunnel/tools/libwg-go/service.go | 27 |
8 files changed, 271 insertions, 20 deletions
diff --git a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java index 5d544f56..e8148c04 100644 --- a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java +++ b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java @@ -9,7 +9,11 @@ import android.content.Context; import android.content.Intent; import android.content.pm.PackageManager; import android.net.ConnectivityManager; +import android.net.LinkProperties; import android.net.LocalSocketAddress; +import android.net.Network; +import android.net.NetworkCapabilities; +import android.net.NetworkRequest; import android.net.ProxyInfo; import android.net.Uri; import android.os.Build; @@ -25,6 +29,8 @@ import com.google.protobuf.Empty; import com.wireguard.android.backend.BackendException.Reason; import com.wireguard.android.backend.Tunnel.State; import com.wireguard.android.backend.gen.GetConnectionOwnerUidResponse; +import com.wireguard.android.backend.gen.IpcSetRequest; +import com.wireguard.android.backend.gen.IpcSetResponse; import com.wireguard.android.backend.gen.LibwgGrpc; import com.wireguard.android.backend.gen.ReverseRequest; import com.wireguard.android.backend.gen.ReverseResponse; @@ -32,6 +38,7 @@ import com.wireguard.android.backend.gen.StartHttpProxyRequest; import com.wireguard.android.backend.gen.StartHttpProxyResponse; import com.wireguard.android.backend.gen.StopHttpProxyRequest; import com.wireguard.android.backend.gen.StopHttpProxyResponse; +import com.wireguard.android.backend.gen.TunnelHandle; import com.wireguard.android.backend.gen.VersionRequest; import com.wireguard.android.backend.gen.VersionResponse; import com.wireguard.android.util.SharedLibraryLoader; @@ -43,6 +50,7 @@ import com.wireguard.config.Peer; import com.wireguard.crypto.Key; import com.wireguard.crypto.KeyFormatException; import com.wireguard.util.NonNullForAll; +import com.wireguard.util.Resolver; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; @@ -94,6 +102,8 @@ public final class GoBackend implements Backend { private int currentTunnelHandle = -1; private ManagedChannel channel; private ConnectivityManager connectivityManager; + private ConnectivityManager.NetworkCallback myNetworkCallback = new MyNetworkCallback(); + @Nullable private Network activeNetwork; /** * Public constructor for GoBackend. @@ -448,13 +458,19 @@ public final class GoBackend implements Backend { } + activeNetwork = connectivityManager.getActiveNetwork(); + if (!connectivityManager.getNetworkCapabilities(activeNetwork).hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)) { + Log.w(TAG, "VPN network is active, null activeNetwork"); + activeNetwork = null; + } + final Resolver resolver = new Resolver(activeNetwork, connectivityManager.getLinkProperties(activeNetwork)); dnsRetry: for (int i = 0; i < DNS_RESOLUTION_RETRIES; ++i) { // Pre-resolve IPs so they're cached when building the userspace string for (final Peer peer : config.getPeers()) { final InetEndpoint ep = peer.getEndpoint().orElse(null); if (ep == null) continue; - if (ep.getResolved().orElse(null) == null) { + if (ep.getResolved(resolver, true).orElse(null) == null) { if (i < DNS_RESOLUTION_RETRIES - 1) { Log.w(TAG, "DNS host \"" + ep.getHost() + "\" failed to resolve; trying again"); Thread.sleep(1000); @@ -467,7 +483,7 @@ public final class GoBackend implements Backend { } // Build config - final String goConfig = config.toWgUserspaceString(); + final String goConfig = config.toWgUserspaceString(resolver); // Create the vpn tunnel with android API final VpnService.Builder builder = service.getBuilder(); @@ -541,6 +557,9 @@ public final class GoBackend implements Backend { service.protect(wgGetSocketV4(currentTunnelHandle)); service.protect(wgGetSocketV6(currentTunnelHandle)); + + NetworkRequest req = new NetworkRequest.Builder().addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN).build(); + connectivityManager.requestNetwork(req, myNetworkCallback); } else { if (currentTunnelHandle == -1) { Log.w(TAG, "Tunnel already down"); @@ -551,6 +570,8 @@ public final class GoBackend implements Backend { currentTunnelHandle = -1; currentConfig = null; stopHttpProxy(); + connectivityManager.unregisterNetworkCallback(myNetworkCallback); + activeNetwork = null; wgTurnOff(handleToClose); } @@ -616,8 +637,11 @@ public final class GoBackend implements Backend { owner.stopHttpProxy(); final Tunnel tunnel = owner.currentTunnel; if (tunnel != null) { - if (owner.currentTunnelHandle != -1) + if (owner.currentTunnelHandle != -1) { + owner.connectivityManager.unregisterNetworkCallback(owner.myNetworkCallback); + owner.activeNetwork = null; wgTurnOff(owner.currentTunnelHandle); + } owner.currentTunnel = null; owner.currentTunnelHandle = -1; owner.currentConfig = null; @@ -643,4 +667,27 @@ public final class GoBackend implements Backend { this.owner = owner; } } + + private class MyNetworkCallback extends ConnectivityManager.NetworkCallback { + @Override + public void onAvailable(Network network) { + activeNetwork = network; + Log.w(TAG, "onAvailable: " + activeNetwork); + } + + @Override + public void onLinkPropertiesChanged(Network network, LinkProperties linkProperties) { + Log.w(TAG, "onLinkPropertiesChanged: " + network + " is default:" + (network.equals(activeNetwork))); + if (network.equals(activeNetwork) && currentConfig != null && currentTunnelHandle > -1) { + final Resolver resolver = new Resolver(network, linkProperties); + final String goConfig = currentConfig.toWgUserspaceStringWithChangedEndpoints(resolver); + Log.w(TAG, "is default network, config:" + goConfig); + + LibwgGrpc.LibwgBlockingStub stub = LibwgGrpc.newBlockingStub(channel); + TunnelHandle tunnel = TunnelHandle.newBuilder().setHandle(currentTunnelHandle).build(); + IpcSetRequest request = IpcSetRequest.newBuilder().setTunnel(tunnel).setConfig(goConfig).build(); + IpcSetResponse resp = stub.ipcSet(request); + } + } + } } diff --git a/tunnel/src/main/java/com/wireguard/config/Config.java b/tunnel/src/main/java/com/wireguard/config/Config.java index 807ebec8..ea05e9c8 100644 --- a/tunnel/src/main/java/com/wireguard/config/Config.java +++ b/tunnel/src/main/java/com/wireguard/config/Config.java @@ -9,6 +9,7 @@ import com.wireguard.config.BadConfigException.Location; import com.wireguard.config.BadConfigException.Reason; import com.wireguard.config.BadConfigException.Section; import com.wireguard.util.NonNullForAll; +import com.wireguard.util.Resolver; import java.io.BufferedReader; import java.io.IOException; @@ -173,12 +174,19 @@ public final class Config { * * @return the {@code Config} represented as a series of "key=value" lines */ - public String toWgUserspaceString() { + public String toWgUserspaceString(Resolver resolver) { final StringBuilder sb = new StringBuilder(); sb.append(interfaze.toWgUserspaceString()); sb.append("replace_peers=true\n"); for (final Peer peer : peers) - sb.append(peer.toWgUserspaceString()); + sb.append(peer.toWgUserspaceString(resolver)); + return sb.toString(); + } + + public String toWgUserspaceStringWithChangedEndpoints(Resolver resolver) { + final StringBuilder sb = new StringBuilder(); + for (final Peer peer : peers) + sb.append(peer.toWgUserspaceStringWithChangedEndpoint(resolver)); return sb.toString(); } diff --git a/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java b/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java index 66855f11..abd888c8 100644 --- a/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java +++ b/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java @@ -6,6 +6,7 @@ package com.wireguard.config; import com.wireguard.util.NonNullForAll; +import com.wireguard.util.Resolver; import java.net.Inet4Address; import java.net.InetAddress; @@ -87,24 +88,34 @@ public final class InetEndpoint { * * @return the resolved endpoint, or {@link Optional#empty()} */ - public Optional<InetEndpoint> getResolved() { - if (isResolved) + public Optional<InetEndpoint> getResolved(Resolver resolver) { + return getResolved(resolver, false, false); + } + + public Optional<InetEndpoint> getResolved(Resolver resolver, Boolean force) { + return getResolved(resolver, force, false); + } + + public Optional<InetEndpoint> getResolvedIfChanged(Resolver resolver) { + return getResolved(resolver, true, true); + } + + public Optional<InetEndpoint> getResolved(Resolver resolver, Boolean force, Boolean ifChanged) { + if (!force && isResolved) return Optional.of(this); synchronized (lock) { //TODO(zx2c4): Implement a real timeout mechanism using DNS TTL - if (Duration.between(lastResolution, Instant.now()).toMinutes() > 1) { + if (force || Duration.between(lastResolution, Instant.now()).toMinutes() > 1) { try { - // Prefer v4 endpoints over v6 to work around DNS64 and IPv6 NAT issues. - final InetAddress[] candidates = InetAddress.getAllByName(host); - InetAddress address = candidates[0]; - for (final InetAddress candidate : candidates) { - if (candidate instanceof Inet4Address) { - address = candidate; - break; - } - } - resolved = new InetEndpoint(address.getHostAddress(), true, port); + InetAddress address = resolver.resolve(host); + InetEndpoint resolvedNow = new InetEndpoint(address.getHostAddress(), true, port); lastResolution = Instant.now(); + + if (ifChanged && resolvedNow.equals(resolved)) { + return Optional.empty(); + } + + resolved = resolvedNow; } catch (final UnknownHostException e) { resolved = null; } diff --git a/tunnel/src/main/java/com/wireguard/config/Peer.java b/tunnel/src/main/java/com/wireguard/config/Peer.java index 9b87b397..3cf5dc15 100644 --- a/tunnel/src/main/java/com/wireguard/config/Peer.java +++ b/tunnel/src/main/java/com/wireguard/config/Peer.java @@ -11,6 +11,7 @@ import com.wireguard.config.BadConfigException.Section; import com.wireguard.crypto.Key; import com.wireguard.crypto.KeyFormatException; import com.wireguard.util.NonNullForAll; +import com.wireguard.util.Resolver; import java.util.Collection; import java.util.Collections; @@ -190,18 +191,26 @@ public final class Peer { * * @return the {@code Peer} represented as a series of "key=value" lines */ - public String toWgUserspaceString() { + public String toWgUserspaceString(Resolver resolver) { final StringBuilder sb = new StringBuilder(); // The order here is important: public_key signifies the beginning of a new peer. sb.append("public_key=").append(publicKey.toHex()).append('\n'); for (final InetNetwork allowedIp : allowedIps) sb.append("allowed_ip=").append(allowedIp).append('\n'); - endpoint.flatMap(InetEndpoint::getResolved).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n')); + endpoint.flatMap(ep -> ep.getResolved(resolver)).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n')); persistentKeepalive.ifPresent(pk -> sb.append("persistent_keepalive_interval=").append(pk).append('\n')); preSharedKey.ifPresent(psk -> sb.append("preshared_key=").append(psk.toHex()).append('\n')); return sb.toString(); } + public String toWgUserspaceStringWithChangedEndpoint(Resolver resolver) { + final StringBuilder sb = new StringBuilder(); + // The order here is important: public_key signifies the beginning of a new peer. + sb.append("public_key=").append(publicKey.toHex()).append('\n'); + endpoint.flatMap(ep -> ep.getResolved(resolver, true)).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n')); + return sb.toString(); + } + @SuppressWarnings("UnusedReturnValue") public static final class Builder { // See wg(8) diff --git a/tunnel/src/main/java/com/wireguard/util/Resolver.java b/tunnel/src/main/java/com/wireguard/util/Resolver.java new file mode 100644 index 00000000..f401b584 --- /dev/null +++ b/tunnel/src/main/java/com/wireguard/util/Resolver.java @@ -0,0 +1,134 @@ +/* + * Copyright © 2023 WireGuard LLC. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.wireguard.util; + +import java.io.IOException; +import java.net.DatagramSocket; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketException; +import java.net.UnknownHostException; + +import android.net.IpPrefix; +import android.net.LinkProperties; +import android.net.Network; +import android.util.Log; + +import androidx.annotation.Nullable; + +@NonNullForAll +public class Resolver { + private static final String TAG = "WireGuard/Resolver"; + @Nullable private final Network network; + @Nullable private final LinkProperties linkProps; + @Nullable private IpPrefix nat64Prefix; + + public Resolver(Network network, LinkProperties linkProps) { + this.network = network; + this.linkProps = linkProps; + if (linkProps != null) { + this.nat64Prefix = linkProps.getNat64Prefix(); + } + } + + static boolean isULA(Inet6Address addr) { + byte[] raw = addr.getAddress(); + return ((raw[0] & 0xfe) == 0xfc); + } + + boolean isWithinNAT64Prefix(Inet6Address address) { + if (nat64Prefix == null) + return false; + + int prefixLength = nat64Prefix.getPrefixLength(); + byte[] rawAddr = address.getAddress(); + byte[] rawPrefix = nat64Prefix.getRawAddress(); + + for (int i=0; i < prefixLength/8; i++) { + if (rawAddr[i] != rawPrefix[i]) + return false; + } + + return true; + } + + boolean isPreferredIPv6(Inet6Address local, Inet6Address remote) { + if (linkProps == null) { + // Prefer IPv4 if there are not link properties that can + // be tested. + return false; + } + + // * Prefer IPv4 if local or remote address is ULA + // * Prefer IPv4 if remote IPv6 is within NAT64 prefix. + // * Otherwise prefer IPv6 + boolean isLocalULA = isULA(local); + boolean isRemoteULA = isULA(remote); + + if (isLocalULA || isRemoteULA) { + return false; + } + + if (isWithinNAT64Prefix(remote)) { + return false; + } + + return true; + } + + public InetAddress resolve(String host) throws UnknownHostException { + final InetAddress[] candidates = network != null ? network.getAllByName(host) : InetAddress.getAllByName(host); + InetAddress address = candidates[0]; + for (final InetAddress candidate : candidates) { + DatagramSocket sock; + + try { + sock = new DatagramSocket(); + if (network != null) { + network.bindSocket(sock); + } + } catch (SocketException e) { + // Return first candidate as fallback + Log.w(TAG, "DatagramSocket failed, fallback to: \"" + address); + return address; + } catch (IOException e) { + // Return first candidate as fallback + Log.w(TAG, "BindSocket failed, fallback to: \"" + address); + return address; + } + + sock.connect(candidate, 51820); + + if (sock.getLocalAddress().isAnyLocalAddress()) { + // Connect didn't find a local address. + Log.w(TAG, "No local address"); + continue; + } + + Log.w(TAG, "Local address: " + sock.getLocalAddress()); + + if (candidate instanceof Inet4Address) { + // Accept IPv4 as preferred address. + address = candidate; + break; + } + + Inet6Address local = (Inet6Address)sock.getLocalAddress(); + InetSocketAddress remoteSockAddr = (InetSocketAddress)sock.getRemoteSocketAddress(); + Inet6Address remote = (Inet6Address)remoteSockAddr.getAddress(); + sock.close(); + + if (isPreferredIPv6(local, remote)) { + address = candidate; + break; + } + } + Log.w(TAG, "Resolved \"" + host + "\" to: " + address); + return address; + } +} diff --git a/tunnel/src/main/proto/libwg.proto b/tunnel/src/main/proto/libwg.proto index 977dacdd..e633ea46 100644 --- a/tunnel/src/main/proto/libwg.proto +++ b/tunnel/src/main/proto/libwg.proto @@ -14,6 +14,7 @@ service Libwg { rpc StartHttpProxy(StartHttpProxyRequest) returns (StartHttpProxyResponse); rpc StopHttpProxy(StopHttpProxyRequest) returns (StopHttpProxyResponse); rpc Reverse(stream ReverseRequest) returns (stream ReverseResponse); + rpc IpcSet(IpcSetRequest) returns (IpcSetResponse); } message TunnelHandle { int32 handle = 1; } @@ -91,3 +92,12 @@ message GetConnectionOwnerUidResponse { int32 uid = 1; string package = 2; // context.getPackageManager().getNameForUid() } + +message IpcSetRequest { + TunnelHandle tunnel = 1; + string config = 2; +} + +message IpcSetResponse { + Error error = 1; +} diff --git a/tunnel/tools/libwg-go/api-android.go b/tunnel/tools/libwg-go/api-android.go index fd0142e1..0ab80be9 100644 --- a/tunnel/tools/libwg-go/api-android.go +++ b/tunnel/tools/libwg-go/api-android.go @@ -53,6 +53,11 @@ type TunnelHandle struct { var tunnelHandles map[int32]TunnelHandle +func GetTunnel(handle int32) (tunnelHandle TunnelHandle, ok bool) { + tunnelHandle, ok = tunnelHandles[handle] + return +} + func init() { tunnelHandles = make(map[int32]TunnelHandle) signals := make(chan os.Signal) diff --git a/tunnel/tools/libwg-go/service.go b/tunnel/tools/libwg-go/service.go index 37fb4b40..1f2e629c 100644 --- a/tunnel/tools/libwg-go/service.go +++ b/tunnel/tools/libwg-go/service.go @@ -217,3 +217,30 @@ func (e *LibwgServiceImpl) Reverse(stream gen.Libwg_ReverseServer) error { e.logger.Verbosef("Reverse returns") return nil } + +func (e *LibwgServiceImpl) IpcSet(ctx context.Context, req *gen.IpcSetRequest) (*gen.IpcSetResponse, error) { + tunnel, ok := GetTunnel(req.GetTunnel().GetHandle()) + if !ok { + r := &gen.IpcSetResponse{ + Error: &gen.Error{ + Message: fmt.Sprintf("Invalid tunnel"), + }, + } + return r, nil + } + + err := tunnel.device.IpcSet(req.GetConfig()) + if err != nil { + r := &gen.IpcSetResponse{ + Error: &gen.Error{ + Message: fmt.Sprintf("IpcSet failed: %v", err), + }, + } + return r, nil + } + + r := &gen.IpcSetResponse{ + } + + return r, nil +} |