summaryrefslogtreecommitdiffhomepage
path: root/tunnel/tools/libwg-go/api-android.go
diff options
context:
space:
mode:
Diffstat (limited to 'tunnel/tools/libwg-go/api-android.go')
-rw-r--r--tunnel/tools/libwg-go/api-android.go163
1 files changed, 160 insertions, 3 deletions
diff --git a/tunnel/tools/libwg-go/api-android.go b/tunnel/tools/libwg-go/api-android.go
index 01608923..7881d74c 100644
--- a/tunnel/tools/libwg-go/api-android.go
+++ b/tunnel/tools/libwg-go/api-android.go
@@ -17,14 +17,31 @@ import (
"os/signal"
"runtime"
"runtime/debug"
+ "sort"
"strings"
"unsafe"
"golang.org/x/sys/unix"
+
+ "golang.zx2c4.com/go118/netip"
+
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
+ "golang.zx2c4.com/wireguard/tun/netstack"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+
+ "golang.zx2c4.com/wireguard/android/gen"
)
type AndroidLogger struct {
@@ -48,12 +65,31 @@ func (l AndroidLogger) Printf(format string, args ...interface{}) {
type TunnelHandle struct {
device *device.Device
uapi net.Listener
+ wgId tcpip.NICID
+ tunId tcpip.NICID
+ tunFd int
logger *device.Logger
+ tnet *netstack.Net
}
var tunnelHandles map[int32]TunnelHandle
+var stk *stack.Stack
+var nextID tcpip.NICID
+
+func getNextNicID() (nic tcpip.NICID) {
+ nic = nextID
+ nextID++
+ return
+}
func init() {
+ nextID = tcpip.NICID(1)
+ opts := stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
+ HandleLocal: true,
+ }
+ stk = stack.New(opts)
tunnelHandles = make(map[int32]TunnelHandle)
signals := make(chan os.Signal)
signal.Notify(signals, unix.SIGUSR2)
@@ -73,21 +109,136 @@ func init() {
}()
}
+func closedFunc(_ tcpip.Error) {
+ //log.Printf("closedFunc: %v", err)
+}
+
+func newTun(fd int, mtu int) (stack.LinkEndpoint, error) {
+ tunEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: uint32(mtu), ClosedFunc: closedFunc})
+ if err != nil {
+ return nil, err
+ }
+ return tunEP, err
+}
+
+func addAddress(nicID tcpip.NICID, addr netip.Addr) error {
+ var protoNumber tcpip.NetworkProtocolNumber
+ if addr.Is4() {
+ protoNumber = ipv4.ProtocolNumber
+ } else {
+ protoNumber = ipv6.ProtocolNumber
+ }
+ protoAddr := tcpip.ProtocolAddress{
+ Protocol: protoNumber,
+ AddressWithPrefix: tcpip.Address(addr.AsSlice()).WithPrefix(),
+ }
+ err := stk.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{})
+ if err != nil {
+ return fmt.Errorf("AddProtocolAddress(%d, %v): %v", nicID, protoAddr, err)
+ }
+ return nil
+}
+
+func createNetTUNWithStack(wgId tcpip.NICID, tunId tcpip.NICID, tunFd int, routeAddresses, tunAddresses, wgAddresses []netip.Addr, mtu int, logger *device.Logger) (tun.Device, error) {
+ dev, netEP, err := netstack.NewNetTUN(stk, mtu)
+ if err != nil {
+ return nil, err //fmt.Errorf("NewNetTUN: %v", tcpipErr)
+ }
+
+ tcpipErr := stk.CreateNICWithOptions(wgId, netEP, stack.NICOptions{Name: "wg0"})
+ if tcpipErr != nil {
+ return nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
+ }
+
+ tunEP, err := newTun(tunFd, mtu)
+ if err != nil {
+ return nil, err
+ }
+
+ tcpipErr = stk.CreateNICWithOptions(tunId, tunEP, stack.NICOptions{Name: "tun0"})
+ if tcpipErr != nil {
+ return nil, fmt.Errorf("CreateNICWithOptions: %v", tcpipErr)
+ }
+
+ for _, addr := range tunAddresses {
+ err := addAddress(tunId, addr)
+ if err != nil {
+ return nil, err
+ }
+ }
+ for _, addr := range wgAddresses {
+ err := addAddress(wgId, addr)
+ if err != nil {
+ return nil, err
+ }
+ }
+ for _, addr := range routeAddresses {
+ dest := tcpip.Address(addr.AsSlice()).WithPrefix().Subnet()
+ stk.AddRoute(tcpip.Route{Destination: dest, NIC: tunId})
+ }
+ stk.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: wgId})
+ stk.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: wgId})
+
+ return dev, nil
+}
+
//export wgTurnOn
func wgTurnOn(interfaceName string, tunFd int32, settings string) int32 {
+ return turnOn(interfaceName, tunFd, settings, nil)
+}
+
+func turnOn(interfaceName string, tunFd int32, settings string, addresses []*gen.InetAddress) int32 {
tag := cstring("WireGuard/GoBackend/" + interfaceName)
logger := &device.Logger{
Verbosef: AndroidLogger{level: C.ANDROID_LOG_DEBUG, tag: tag}.Printf,
Errorf: AndroidLogger{level: C.ANDROID_LOG_ERROR, tag: tag}.Printf,
}
- tun, name, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
+ logger.Verbosef("Got addresses '%v'", addresses)
+
+ tunId := getNextNicID()
+ wgId := getNextNicID()
+
+ routeAddresses := make([]netip.Addr, len(addresses))
+ pos := 0
+ for _, addr := range addresses {
+ ip, ok := netip.AddrFromSlice(addr.Address)
+ if ok {
+ routeAddresses[pos] = ip
+ pos++
+ }
+ }
+
+ tun, err := createNetTUNWithStack(wgId, tunId, int(tunFd),
+ // routeAddresses
+ routeAddresses,
+ // tunAddresses
+ []netip.Addr{
+ netip.MustParseAddr("169.254.0.1"),
+ netip.MustParseAddr("fe80::1"),
+ },
+ // wgAddresses
+ []netip.Addr{
+ netip.MustParseAddr("169.254.1.2"),
+ netip.MustParseAddr("fe80::2"),
+ },
+ 1420, logger)
if err != nil {
unix.Close(int(tunFd))
- logger.Errorf("CreateUnmonitoredTUNFromFD: %v", err)
+ logger.Errorf("CcreateNetTUNWithStack: %v", err)
return -1
}
+ // Sort route table for longest prefix match
+ routes := stk.GetRouteTable()
+ sort.Slice(routes, func(i, j int) bool {
+ return routes[i].Destination.Prefix() > routes[j].Destination.Prefix()
+ })
+ stk.SetRouteTable(routes)
+ stk.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true)
+ stk.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true)
+
+ name := "FIXME"
logger.Verbosef("Attaching to interface %v", name)
device := device.NewDevice(tun, conn.NewStdNetBind(), logger)
@@ -143,7 +294,7 @@ func wgTurnOn(interfaceName string, tunFd int32, settings string) int32 {
device.Close()
return -1
}
- tunnelHandles[i] = TunnelHandle{device: device, uapi: uapi}
+ tunnelHandles[i] = TunnelHandle{device: device, uapi: uapi, wgId: wgId, tunId: tunId, tunFd: int(tunFd), logger: logger}
return i
}
@@ -154,6 +305,12 @@ func wgTurnOff(tunnelHandle int32) {
return
}
delete(tunnelHandles, tunnelHandle)
+ handle.logger.Verbosef("Remove wg")
+ stk.RemoveNIC(handle.wgId)
+ handle.logger.Verbosef("Remove tun")
+ stk.RemoveNIC(handle.tunId)
+ handle.logger.Verbosef("Remove done.")
+ unix.Close(handle.tunFd)
if handle.uapi != nil {
handle.uapi.Close()
}