diff options
Diffstat (limited to 'tunnel/tools/libwg-go/api-android.go')
-rw-r--r-- | tunnel/tools/libwg-go/api-android.go | 163 |
1 files changed, 160 insertions, 3 deletions
diff --git a/tunnel/tools/libwg-go/api-android.go b/tunnel/tools/libwg-go/api-android.go index 01608923..7881d74c 100644 --- a/tunnel/tools/libwg-go/api-android.go +++ b/tunnel/tools/libwg-go/api-android.go @@ -17,14 +17,31 @@ import ( "os/signal" "runtime" "runtime/debug" + "sort" "strings" "unsafe" "golang.org/x/sys/unix" + + "golang.zx2c4.com/go118/netip" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" + "golang.zx2c4.com/wireguard/tun/netstack" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/fdbased" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + + "golang.zx2c4.com/wireguard/android/gen" ) type AndroidLogger struct { @@ -48,12 +65,31 @@ func (l AndroidLogger) Printf(format string, args ...interface{}) { type TunnelHandle struct { device *device.Device uapi net.Listener + wgId tcpip.NICID + tunId tcpip.NICID + tunFd int logger *device.Logger + tnet *netstack.Net } var tunnelHandles map[int32]TunnelHandle +var stk *stack.Stack +var nextID tcpip.NICID + +func getNextNicID() (nic tcpip.NICID) { + nic = nextID + nextID++ + return +} func init() { + nextID = tcpip.NICID(1) + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, + HandleLocal: true, + } + stk = stack.New(opts) tunnelHandles = make(map[int32]TunnelHandle) signals := make(chan os.Signal) signal.Notify(signals, unix.SIGUSR2) @@ -73,21 +109,136 @@ func init() { }() } +func closedFunc(_ tcpip.Error) { + //log.Printf("closedFunc: %v", err) +} + +func newTun(fd int, mtu int) (stack.LinkEndpoint, error) { + tunEP, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: uint32(mtu), ClosedFunc: closedFunc}) + if err != nil { + return nil, err + } + return tunEP, err +} + +func addAddress(nicID tcpip.NICID, addr netip.Addr) error { + var protoNumber tcpip.NetworkProtocolNumber + if addr.Is4() { + protoNumber = ipv4.ProtocolNumber + } else { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.Address(addr.AsSlice()).WithPrefix(), + } + err := stk.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}) + if err != nil { + return fmt.Errorf("AddProtocolAddress(%d, %v): %v", nicID, protoAddr, err) + } + return nil +} + +func createNetTUNWithStack(wgId tcpip.NICID, tunId tcpip.NICID, tunFd int, routeAddresses, tunAddresses, wgAddresses []netip.Addr, mtu int, logger *device.Logger) (tun.Device, error) { + dev, netEP, err := netstack.NewNetTUN(stk, mtu) + if err != nil { + return nil, err //fmt.Errorf("NewNetTUN: %v", tcpipErr) + } + + tcpipErr := stk.CreateNICWithOptions(wgId, netEP, stack.NICOptions{Name: "wg0"}) + if tcpipErr != nil { + return nil, fmt.Errorf("CreateNIC: %v", tcpipErr) + } + + tunEP, err := newTun(tunFd, mtu) + if err != nil { + return nil, err + } + + tcpipErr = stk.CreateNICWithOptions(tunId, tunEP, stack.NICOptions{Name: "tun0"}) + if tcpipErr != nil { + return nil, fmt.Errorf("CreateNICWithOptions: %v", tcpipErr) + } + + for _, addr := range tunAddresses { + err := addAddress(tunId, addr) + if err != nil { + return nil, err + } + } + for _, addr := range wgAddresses { + err := addAddress(wgId, addr) + if err != nil { + return nil, err + } + } + for _, addr := range routeAddresses { + dest := tcpip.Address(addr.AsSlice()).WithPrefix().Subnet() + stk.AddRoute(tcpip.Route{Destination: dest, NIC: tunId}) + } + stk.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: wgId}) + stk.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: wgId}) + + return dev, nil +} + //export wgTurnOn func wgTurnOn(interfaceName string, tunFd int32, settings string) int32 { + return turnOn(interfaceName, tunFd, settings, nil) +} + +func turnOn(interfaceName string, tunFd int32, settings string, addresses []*gen.InetAddress) int32 { tag := cstring("WireGuard/GoBackend/" + interfaceName) logger := &device.Logger{ Verbosef: AndroidLogger{level: C.ANDROID_LOG_DEBUG, tag: tag}.Printf, Errorf: AndroidLogger{level: C.ANDROID_LOG_ERROR, tag: tag}.Printf, } - tun, name, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd)) + logger.Verbosef("Got addresses '%v'", addresses) + + tunId := getNextNicID() + wgId := getNextNicID() + + routeAddresses := make([]netip.Addr, len(addresses)) + pos := 0 + for _, addr := range addresses { + ip, ok := netip.AddrFromSlice(addr.Address) + if ok { + routeAddresses[pos] = ip + pos++ + } + } + + tun, err := createNetTUNWithStack(wgId, tunId, int(tunFd), + // routeAddresses + routeAddresses, + // tunAddresses + []netip.Addr{ + netip.MustParseAddr("169.254.0.1"), + netip.MustParseAddr("fe80::1"), + }, + // wgAddresses + []netip.Addr{ + netip.MustParseAddr("169.254.1.2"), + netip.MustParseAddr("fe80::2"), + }, + 1420, logger) if err != nil { unix.Close(int(tunFd)) - logger.Errorf("CreateUnmonitoredTUNFromFD: %v", err) + logger.Errorf("CcreateNetTUNWithStack: %v", err) return -1 } + // Sort route table for longest prefix match + routes := stk.GetRouteTable() + sort.Slice(routes, func(i, j int) bool { + return routes[i].Destination.Prefix() > routes[j].Destination.Prefix() + }) + stk.SetRouteTable(routes) + stk.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true) + stk.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true) + + name := "FIXME" logger.Verbosef("Attaching to interface %v", name) device := device.NewDevice(tun, conn.NewStdNetBind(), logger) @@ -143,7 +294,7 @@ func wgTurnOn(interfaceName string, tunFd int32, settings string) int32 { device.Close() return -1 } - tunnelHandles[i] = TunnelHandle{device: device, uapi: uapi} + tunnelHandles[i] = TunnelHandle{device: device, uapi: uapi, wgId: wgId, tunId: tunId, tunFd: int(tunFd), logger: logger} return i } @@ -154,6 +305,12 @@ func wgTurnOff(tunnelHandle int32) { return } delete(tunnelHandles, tunnelHandle) + handle.logger.Verbosef("Remove wg") + stk.RemoveNIC(handle.wgId) + handle.logger.Verbosef("Remove tun") + stk.RemoveNIC(handle.tunId) + handle.logger.Verbosef("Remove done.") + unix.Close(handle.tunFd) if handle.uapi != nil { handle.uapi.Close() } |