summaryrefslogtreecommitdiffhomepage
path: root/tunnel
diff options
context:
space:
mode:
authorMikael Magnusson <mikma@users.sourceforge.net>2022-12-29 00:23:17 +0100
committerMikael Magnusson <mikma@users.sourceforge.net>2023-03-23 22:52:53 +0100
commitc647112c07e91232dbaed28251293c4c9c2aa083 (patch)
tree29c635f04a67a0253660b2db2f89abdad58da345 /tunnel
parentd8056a134b0021650cef914d0985e556802c095e (diff)
tunnel: auto-detect IPv6/IPv4 preference
Detect IP address change. Request non-VPN network. Update endpoint when needed. Unregister network on wgTurnOff and use IPv4 if network is not known.
Diffstat (limited to 'tunnel')
-rw-r--r--tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java53
-rw-r--r--tunnel/src/main/java/com/wireguard/config/Config.java12
-rw-r--r--tunnel/src/main/java/com/wireguard/config/InetEndpoint.java37
-rw-r--r--tunnel/src/main/java/com/wireguard/config/Peer.java13
-rw-r--r--tunnel/src/main/java/com/wireguard/util/Resolver.java134
-rw-r--r--tunnel/src/main/proto/libwg.proto10
-rw-r--r--tunnel/tools/libwg-go/api-android.go5
-rw-r--r--tunnel/tools/libwg-go/service.go27
8 files changed, 271 insertions, 20 deletions
diff --git a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java
index 968f2ade..f38a72fd 100644
--- a/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java
+++ b/tunnel/src/main/java/com/wireguard/android/backend/GoBackend.java
@@ -9,7 +9,11 @@ import android.content.Context;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.net.ConnectivityManager;
+import android.net.LinkProperties;
import android.net.LocalSocketAddress;
+import android.net.Network;
+import android.net.NetworkCapabilities;
+import android.net.NetworkRequest;
import android.net.ProxyInfo;
import android.net.Uri;
import android.os.Build;
@@ -25,6 +29,8 @@ import com.google.protobuf.Empty;
import com.wireguard.android.backend.BackendException.Reason;
import com.wireguard.android.backend.Tunnel.State;
import com.wireguard.android.backend.gen.GetConnectionOwnerUidResponse;
+import com.wireguard.android.backend.gen.IpcSetRequest;
+import com.wireguard.android.backend.gen.IpcSetResponse;
import com.wireguard.android.backend.gen.LibwgGrpc;
import com.wireguard.android.backend.gen.ReverseRequest;
import com.wireguard.android.backend.gen.ReverseResponse;
@@ -32,6 +38,7 @@ import com.wireguard.android.backend.gen.StartHttpProxyRequest;
import com.wireguard.android.backend.gen.StartHttpProxyResponse;
import com.wireguard.android.backend.gen.StopHttpProxyRequest;
import com.wireguard.android.backend.gen.StopHttpProxyResponse;
+import com.wireguard.android.backend.gen.TunnelHandle;
import com.wireguard.android.backend.gen.VersionRequest;
import com.wireguard.android.backend.gen.VersionResponse;
import com.wireguard.android.util.SharedLibraryLoader;
@@ -43,6 +50,7 @@ import com.wireguard.config.Peer;
import com.wireguard.crypto.Key;
import com.wireguard.crypto.KeyFormatException;
import com.wireguard.util.NonNullForAll;
+import com.wireguard.util.Resolver;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
@@ -94,6 +102,8 @@ public final class GoBackend implements Backend {
private int currentTunnelHandle = -1;
private ManagedChannel channel;
private ConnectivityManager connectivityManager;
+ private ConnectivityManager.NetworkCallback myNetworkCallback = new MyNetworkCallback();
+ @Nullable private Network activeNetwork;
/**
* Public constructor for GoBackend.
@@ -448,13 +458,19 @@ public final class GoBackend implements Backend {
}
+ activeNetwork = connectivityManager.getActiveNetwork();
+ if (!connectivityManager.getNetworkCapabilities(activeNetwork).hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN)) {
+ Log.w(TAG, "VPN network is active, null activeNetwork");
+ activeNetwork = null;
+ }
+ final Resolver resolver = new Resolver(activeNetwork, connectivityManager.getLinkProperties(activeNetwork));
dnsRetry: for (int i = 0; i < DNS_RESOLUTION_RETRIES; ++i) {
// Pre-resolve IPs so they're cached when building the userspace string
for (final Peer peer : config.getPeers()) {
final InetEndpoint ep = peer.getEndpoint().orElse(null);
if (ep == null)
continue;
- if (ep.getResolved().orElse(null) == null) {
+ if (ep.getResolved(resolver, true).orElse(null) == null) {
if (i < DNS_RESOLUTION_RETRIES - 1) {
Log.w(TAG, "DNS host \"" + ep.getHost() + "\" failed to resolve; trying again");
Thread.sleep(1000);
@@ -467,7 +483,7 @@ public final class GoBackend implements Backend {
}
// Build config
- final String goConfig = config.toWgUserspaceString();
+ final String goConfig = config.toWgUserspaceString(resolver);
// Create the vpn tunnel with android API
final VpnService.Builder builder = service.getBuilder();
@@ -541,6 +557,9 @@ public final class GoBackend implements Backend {
service.protect(wgGetSocketV4(currentTunnelHandle));
service.protect(wgGetSocketV6(currentTunnelHandle));
+
+ NetworkRequest req = new NetworkRequest.Builder().addCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN).build();
+ connectivityManager.requestNetwork(req, myNetworkCallback);
} else {
if (currentTunnelHandle == -1) {
Log.w(TAG, "Tunnel already down");
@@ -551,6 +570,8 @@ public final class GoBackend implements Backend {
currentTunnelHandle = -1;
currentConfig = null;
stopHttpProxy();
+ connectivityManager.unregisterNetworkCallback(myNetworkCallback);
+ activeNetwork = null;
wgTurnOff(handleToClose);
}
@@ -616,8 +637,11 @@ public final class GoBackend implements Backend {
owner.stopHttpProxy();
final Tunnel tunnel = owner.currentTunnel;
if (tunnel != null) {
- if (owner.currentTunnelHandle != -1)
+ if (owner.currentTunnelHandle != -1) {
+ owner.connectivityManager.unregisterNetworkCallback(owner.myNetworkCallback);
+ owner.activeNetwork = null;
wgTurnOff(owner.currentTunnelHandle);
+ }
owner.currentTunnel = null;
owner.currentTunnelHandle = -1;
owner.currentConfig = null;
@@ -643,4 +667,27 @@ public final class GoBackend implements Backend {
this.owner = owner;
}
}
+
+ private class MyNetworkCallback extends ConnectivityManager.NetworkCallback {
+ @Override
+ public void onAvailable(Network network) {
+ activeNetwork = network;
+ Log.w(TAG, "onAvailable: " + activeNetwork);
+ }
+
+ @Override
+ public void onLinkPropertiesChanged(Network network, LinkProperties linkProperties) {
+ Log.w(TAG, "onLinkPropertiesChanged: " + network + " is default:" + (network.equals(activeNetwork)));
+ if (network.equals(activeNetwork) && currentConfig != null && currentTunnelHandle > -1) {
+ final Resolver resolver = new Resolver(network, linkProperties);
+ final String goConfig = currentConfig.toWgUserspaceStringWithChangedEndpoints(resolver);
+ Log.w(TAG, "is default network, config:" + goConfig);
+
+ LibwgGrpc.LibwgBlockingStub stub = LibwgGrpc.newBlockingStub(channel);
+ TunnelHandle tunnel = TunnelHandle.newBuilder().setHandle(currentTunnelHandle).build();
+ IpcSetRequest request = IpcSetRequest.newBuilder().setTunnel(tunnel).setConfig(goConfig).build();
+ IpcSetResponse resp = stub.ipcSet(request);
+ }
+ }
+ }
}
diff --git a/tunnel/src/main/java/com/wireguard/config/Config.java b/tunnel/src/main/java/com/wireguard/config/Config.java
index ee9cebce..5cd03fc0 100644
--- a/tunnel/src/main/java/com/wireguard/config/Config.java
+++ b/tunnel/src/main/java/com/wireguard/config/Config.java
@@ -9,6 +9,7 @@ import com.wireguard.config.BadConfigException.Location;
import com.wireguard.config.BadConfigException.Reason;
import com.wireguard.config.BadConfigException.Section;
import com.wireguard.util.NonNullForAll;
+import com.wireguard.util.Resolver;
import java.io.BufferedReader;
import java.io.IOException;
@@ -173,12 +174,19 @@ public final class Config {
*
* @return the {@code Config} represented as a series of "key=value" lines
*/
- public String toWgUserspaceString() {
+ public String toWgUserspaceString(Resolver resolver) {
final StringBuilder sb = new StringBuilder();
sb.append(interfaze.toWgUserspaceString());
sb.append("replace_peers=true\n");
for (final Peer peer : peers)
- sb.append(peer.toWgUserspaceString());
+ sb.append(peer.toWgUserspaceString(resolver));
+ return sb.toString();
+ }
+
+ public String toWgUserspaceStringWithChangedEndpoints(Resolver resolver) {
+ final StringBuilder sb = new StringBuilder();
+ for (final Peer peer : peers)
+ sb.append(peer.toWgUserspaceStringWithChangedEndpoint(resolver));
return sb.toString();
}
diff --git a/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java b/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java
index d1db432b..9d087cd6 100644
--- a/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java
+++ b/tunnel/src/main/java/com/wireguard/config/InetEndpoint.java
@@ -6,6 +6,7 @@
package com.wireguard.config;
import com.wireguard.util.NonNullForAll;
+import com.wireguard.util.Resolver;
import java.net.Inet4Address;
import java.net.InetAddress;
@@ -87,24 +88,34 @@ public final class InetEndpoint {
*
* @return the resolved endpoint, or {@link Optional#empty()}
*/
- public Optional<InetEndpoint> getResolved() {
- if (isResolved)
+ public Optional<InetEndpoint> getResolved(Resolver resolver) {
+ return getResolved(resolver, false, false);
+ }
+
+ public Optional<InetEndpoint> getResolved(Resolver resolver, Boolean force) {
+ return getResolved(resolver, force, false);
+ }
+
+ public Optional<InetEndpoint> getResolvedIfChanged(Resolver resolver) {
+ return getResolved(resolver, true, true);
+ }
+
+ public Optional<InetEndpoint> getResolved(Resolver resolver, Boolean force, Boolean ifChanged) {
+ if (!force && isResolved)
return Optional.of(this);
synchronized (lock) {
//TODO(zx2c4): Implement a real timeout mechanism using DNS TTL
- if (Duration.between(lastResolution, Instant.now()).toMinutes() > 1) {
+ if (force || Duration.between(lastResolution, Instant.now()).toMinutes() > 1) {
try {
- // Prefer v4 endpoints over v6 to work around DNS64 and IPv6 NAT issues.
- final InetAddress[] candidates = InetAddress.getAllByName(host);
- InetAddress address = candidates[0];
- for (final InetAddress candidate : candidates) {
- if (candidate instanceof Inet4Address) {
- address = candidate;
- break;
- }
- }
- resolved = new InetEndpoint(address.getHostAddress(), true, port);
+ InetAddress address = resolver.resolve(host);
+ InetEndpoint resolvedNow = new InetEndpoint(address.getHostAddress(), true, port);
lastResolution = Instant.now();
+
+ if (ifChanged && resolvedNow.equals(resolved)) {
+ return Optional.empty();
+ }
+
+ resolved = resolvedNow;
} catch (final UnknownHostException e) {
resolved = null;
}
diff --git a/tunnel/src/main/java/com/wireguard/config/Peer.java b/tunnel/src/main/java/com/wireguard/config/Peer.java
index 8a0fd763..d7b75f16 100644
--- a/tunnel/src/main/java/com/wireguard/config/Peer.java
+++ b/tunnel/src/main/java/com/wireguard/config/Peer.java
@@ -11,6 +11,7 @@ import com.wireguard.config.BadConfigException.Section;
import com.wireguard.crypto.Key;
import com.wireguard.crypto.KeyFormatException;
import com.wireguard.util.NonNullForAll;
+import com.wireguard.util.Resolver;
import java.util.Collection;
import java.util.Collections;
@@ -190,18 +191,26 @@ public final class Peer {
*
* @return the {@code Peer} represented as a series of "key=value" lines
*/
- public String toWgUserspaceString() {
+ public String toWgUserspaceString(Resolver resolver) {
final StringBuilder sb = new StringBuilder();
// The order here is important: public_key signifies the beginning of a new peer.
sb.append("public_key=").append(publicKey.toHex()).append('\n');
for (final InetNetwork allowedIp : allowedIps)
sb.append("allowed_ip=").append(allowedIp).append('\n');
- endpoint.flatMap(InetEndpoint::getResolved).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n'));
+ endpoint.flatMap(ep -> ep.getResolved(resolver)).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n'));
persistentKeepalive.ifPresent(pk -> sb.append("persistent_keepalive_interval=").append(pk).append('\n'));
preSharedKey.ifPresent(psk -> sb.append("preshared_key=").append(psk.toHex()).append('\n'));
return sb.toString();
}
+ public String toWgUserspaceStringWithChangedEndpoint(Resolver resolver) {
+ final StringBuilder sb = new StringBuilder();
+ // The order here is important: public_key signifies the beginning of a new peer.
+ sb.append("public_key=").append(publicKey.toHex()).append('\n');
+ endpoint.flatMap(ep -> ep.getResolved(resolver, true)).ifPresent(ep -> sb.append("endpoint=").append(ep).append('\n'));
+ return sb.toString();
+ }
+
@SuppressWarnings("UnusedReturnValue")
public static final class Builder {
// See wg(8)
diff --git a/tunnel/src/main/java/com/wireguard/util/Resolver.java b/tunnel/src/main/java/com/wireguard/util/Resolver.java
new file mode 100644
index 00000000..f401b584
--- /dev/null
+++ b/tunnel/src/main/java/com/wireguard/util/Resolver.java
@@ -0,0 +1,134 @@
+/*
+ * Copyright © 2023 WireGuard LLC. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package com.wireguard.util;
+
+import java.io.IOException;
+import java.net.DatagramSocket;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.SocketException;
+import java.net.UnknownHostException;
+
+import android.net.IpPrefix;
+import android.net.LinkProperties;
+import android.net.Network;
+import android.util.Log;
+
+import androidx.annotation.Nullable;
+
+@NonNullForAll
+public class Resolver {
+ private static final String TAG = "WireGuard/Resolver";
+ @Nullable private final Network network;
+ @Nullable private final LinkProperties linkProps;
+ @Nullable private IpPrefix nat64Prefix;
+
+ public Resolver(Network network, LinkProperties linkProps) {
+ this.network = network;
+ this.linkProps = linkProps;
+ if (linkProps != null) {
+ this.nat64Prefix = linkProps.getNat64Prefix();
+ }
+ }
+
+ static boolean isULA(Inet6Address addr) {
+ byte[] raw = addr.getAddress();
+ return ((raw[0] & 0xfe) == 0xfc);
+ }
+
+ boolean isWithinNAT64Prefix(Inet6Address address) {
+ if (nat64Prefix == null)
+ return false;
+
+ int prefixLength = nat64Prefix.getPrefixLength();
+ byte[] rawAddr = address.getAddress();
+ byte[] rawPrefix = nat64Prefix.getRawAddress();
+
+ for (int i=0; i < prefixLength/8; i++) {
+ if (rawAddr[i] != rawPrefix[i])
+ return false;
+ }
+
+ return true;
+ }
+
+ boolean isPreferredIPv6(Inet6Address local, Inet6Address remote) {
+ if (linkProps == null) {
+ // Prefer IPv4 if there are not link properties that can
+ // be tested.
+ return false;
+ }
+
+ // * Prefer IPv4 if local or remote address is ULA
+ // * Prefer IPv4 if remote IPv6 is within NAT64 prefix.
+ // * Otherwise prefer IPv6
+ boolean isLocalULA = isULA(local);
+ boolean isRemoteULA = isULA(remote);
+
+ if (isLocalULA || isRemoteULA) {
+ return false;
+ }
+
+ if (isWithinNAT64Prefix(remote)) {
+ return false;
+ }
+
+ return true;
+ }
+
+ public InetAddress resolve(String host) throws UnknownHostException {
+ final InetAddress[] candidates = network != null ? network.getAllByName(host) : InetAddress.getAllByName(host);
+ InetAddress address = candidates[0];
+ for (final InetAddress candidate : candidates) {
+ DatagramSocket sock;
+
+ try {
+ sock = new DatagramSocket();
+ if (network != null) {
+ network.bindSocket(sock);
+ }
+ } catch (SocketException e) {
+ // Return first candidate as fallback
+ Log.w(TAG, "DatagramSocket failed, fallback to: \"" + address);
+ return address;
+ } catch (IOException e) {
+ // Return first candidate as fallback
+ Log.w(TAG, "BindSocket failed, fallback to: \"" + address);
+ return address;
+ }
+
+ sock.connect(candidate, 51820);
+
+ if (sock.getLocalAddress().isAnyLocalAddress()) {
+ // Connect didn't find a local address.
+ Log.w(TAG, "No local address");
+ continue;
+ }
+
+ Log.w(TAG, "Local address: " + sock.getLocalAddress());
+
+ if (candidate instanceof Inet4Address) {
+ // Accept IPv4 as preferred address.
+ address = candidate;
+ break;
+ }
+
+ Inet6Address local = (Inet6Address)sock.getLocalAddress();
+ InetSocketAddress remoteSockAddr = (InetSocketAddress)sock.getRemoteSocketAddress();
+ Inet6Address remote = (Inet6Address)remoteSockAddr.getAddress();
+ sock.close();
+
+ if (isPreferredIPv6(local, remote)) {
+ address = candidate;
+ break;
+ }
+ }
+ Log.w(TAG, "Resolved \"" + host + "\" to: " + address);
+ return address;
+ }
+}
diff --git a/tunnel/src/main/proto/libwg.proto b/tunnel/src/main/proto/libwg.proto
index 977dacdd..e633ea46 100644
--- a/tunnel/src/main/proto/libwg.proto
+++ b/tunnel/src/main/proto/libwg.proto
@@ -14,6 +14,7 @@ service Libwg {
rpc StartHttpProxy(StartHttpProxyRequest) returns (StartHttpProxyResponse);
rpc StopHttpProxy(StopHttpProxyRequest) returns (StopHttpProxyResponse);
rpc Reverse(stream ReverseRequest) returns (stream ReverseResponse);
+ rpc IpcSet(IpcSetRequest) returns (IpcSetResponse);
}
message TunnelHandle { int32 handle = 1; }
@@ -91,3 +92,12 @@ message GetConnectionOwnerUidResponse {
int32 uid = 1;
string package = 2; // context.getPackageManager().getNameForUid()
}
+
+message IpcSetRequest {
+ TunnelHandle tunnel = 1;
+ string config = 2;
+}
+
+message IpcSetResponse {
+ Error error = 1;
+}
diff --git a/tunnel/tools/libwg-go/api-android.go b/tunnel/tools/libwg-go/api-android.go
index fd0142e1..0ab80be9 100644
--- a/tunnel/tools/libwg-go/api-android.go
+++ b/tunnel/tools/libwg-go/api-android.go
@@ -53,6 +53,11 @@ type TunnelHandle struct {
var tunnelHandles map[int32]TunnelHandle
+func GetTunnel(handle int32) (tunnelHandle TunnelHandle, ok bool) {
+ tunnelHandle, ok = tunnelHandles[handle]
+ return
+}
+
func init() {
tunnelHandles = make(map[int32]TunnelHandle)
signals := make(chan os.Signal)
diff --git a/tunnel/tools/libwg-go/service.go b/tunnel/tools/libwg-go/service.go
index 37fb4b40..1f2e629c 100644
--- a/tunnel/tools/libwg-go/service.go
+++ b/tunnel/tools/libwg-go/service.go
@@ -217,3 +217,30 @@ func (e *LibwgServiceImpl) Reverse(stream gen.Libwg_ReverseServer) error {
e.logger.Verbosef("Reverse returns")
return nil
}
+
+func (e *LibwgServiceImpl) IpcSet(ctx context.Context, req *gen.IpcSetRequest) (*gen.IpcSetResponse, error) {
+ tunnel, ok := GetTunnel(req.GetTunnel().GetHandle())
+ if !ok {
+ r := &gen.IpcSetResponse{
+ Error: &gen.Error{
+ Message: fmt.Sprintf("Invalid tunnel"),
+ },
+ }
+ return r, nil
+ }
+
+ err := tunnel.device.IpcSet(req.GetConfig())
+ if err != nil {
+ r := &gen.IpcSetResponse{
+ Error: &gen.Error{
+ Message: fmt.Sprintf("IpcSet failed: %v", err),
+ },
+ }
+ return r, nil
+ }
+
+ r := &gen.IpcSetResponse{
+ }
+
+ return r, nil
+}