summaryrefslogtreecommitdiffhomepage
path: root/tunnel
diff options
context:
space:
mode:
Diffstat (limited to 'tunnel')
-rw-r--r--tunnel/tools/libwg-go/Makefile2
-rw-r--r--tunnel/tools/libwg-go/api-android.go4
-rw-r--r--tunnel/tools/libwg-go/nat-tun.go500
3 files changed, 504 insertions, 2 deletions
diff --git a/tunnel/tools/libwg-go/Makefile b/tunnel/tools/libwg-go/Makefile
index 189807d2..b01ffcad 100644
--- a/tunnel/tools/libwg-go/Makefile
+++ b/tunnel/tools/libwg-go/Makefile
@@ -71,7 +71,7 @@ gen/%_grpc.pb.go: $(PROTODIR)/%.proto $(BUILDDIR)/go-$(GO_VERSION)/.prepared $(P
$(PROTOC) -I $(PROTODIR) -I $(PROTO_INCLUDEDIR) --go-grpc_out=./gen --go-grpc_opt=paths=source_relative $<
$(DESTDIR)/libwg-go.so: export PATH := $(BUILDDIR)/go-$(GO_VERSION)/bin/:$(PATH)
-$(DESTDIR)/libwg-go.so: $(BUILDDIR)/go-$(GO_VERSION)/.prepared go.mod api-android.go http-proxy.go service.go gen/libwg.pb.go gen/libwg_grpc.pb.go jni.c
+$(DESTDIR)/libwg-go.so: $(BUILDDIR)/go-$(GO_VERSION)/.prepared go.mod api-android.go http-proxy.go nat-tun.go service.go gen/libwg.pb.go gen/libwg_grpc.pb.go jni.c
go build -tags linux -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/$(ANDROID_PACKAGE_NAME)/cache/wireguard" -v -trimpath -o "$@" -buildmode c-shared
.DELETE_ON_ERROR:
diff --git a/tunnel/tools/libwg-go/api-android.go b/tunnel/tools/libwg-go/api-android.go
index 457359b7..2ce9c49c 100644
--- a/tunnel/tools/libwg-go/api-android.go
+++ b/tunnel/tools/libwg-go/api-android.go
@@ -80,7 +80,9 @@ func wgTurnOn(interfaceName string, tunFd int32, settings string) int32 {
Errorf: AndroidLogger{level: C.ANDROID_LOG_ERROR, tag: tag}.Printf,
}
- tun, name, err := CreateUnmonitoredTUNFromFD(int(tunFd))
+ nativeTun, name, err := CreateUnmonitoredTUNFromFD(int(tunFd))
+
+ tun, err := NewNatTun(nativeTun)
if err != nil {
unix.Close(int(tunFd))
logger.Errorf("CreateUnmonitoredTUNFromFD: %v", err)
diff --git a/tunnel/tools/libwg-go/nat-tun.go b/tunnel/tools/libwg-go/nat-tun.go
new file mode 100644
index 00000000..e9a4537d
--- /dev/null
+++ b/tunnel/tools/libwg-go/nat-tun.go
@@ -0,0 +1,500 @@
+/* SPDX-Identifier-License: MIT
+ *
+ * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved.
+ */
+
+package main
+
+import (
+ "encoding/binary"
+ "fmt"
+ "os"
+
+ "golang.zx2c4.com/go118/netip"
+
+ "golang.zx2c4.com/wireguard/tun"
+)
+
+type connection struct {
+ src netip.AddrPort
+ dst netip.AddrPort
+}
+
+func Connection(src, dst netip.AddrPort) connection {
+ return connection{
+ src: src,
+ dst: dst,
+ }
+}
+
+type natTun struct {
+ tun tun.Device
+ intAddr netip.Addr
+ extAddr netip.Addr
+ srcAddr netip.Addr
+ proxyAddr netip.Addr
+ proxyPort int
+ connections map[connection]connection
+}
+
+func NewNatTun(t tun.Device) (dev tun.Device, err error) {
+ dev = nil
+
+ extAddr, err := netip.ParseAddr("10.49.40.151")
+ if err != nil {
+ return
+ }
+
+ intAddr, err := netip.ParseAddr("10.49.40.101")
+ if err != nil {
+ return
+ }
+
+ proxyAddr, err := netip.ParseAddr("10.49.124.115")
+ if err != nil {
+ return
+ }
+
+ dev = &natTun{
+ tun: t,
+ intAddr: intAddr,
+ extAddr: extAddr,
+ proxyAddr: proxyAddr,
+ proxyPort: 8888,
+ connections: make(map[connection]connection),
+ }
+
+ err = nil
+ return
+}
+
+func (tun *natTun) addConnection(new, orig connection) {
+ // TODO use mutex
+ tun.connections[new] = orig
+}
+
+func (tun *natTun) lookupConnection(new connection) (connection, bool) {
+ c, ok := tun.connections[new]
+ return c, ok
+}
+
+func (tun *natTun) Name() (string, error) {
+ return tun.tun.Name()
+}
+
+func (tun *natTun) File() *os.File {
+ return nil
+}
+
+func (tun *natTun) Events() chan tun.Event {
+ return tun.tun.Events()
+}
+
+const (
+ IPV4_VERSION = 4
+ IPV4_HEADER_LEN = 20 // TODO support option headers, refer to https://cs.opensource.google/go/x/net/+/fe4d6282:ipv4/header.go
+ IPV4_HEADER_PROTOCOL = 9
+ IPV4_HEADER_CHECKSUM = 10
+ IPV4_HEADER_SRC_ADDR = 12
+ IPV4_HEADER_DST_ADDR = 16
+
+ IPV6_VERSION = 6
+ IPV6_HEADER_LEN = 40 // TODO support additional headers
+ IPV6_HEADER_NEXT_HEADER = 6
+ IPV6_HEADER_SRC_ADDR = 8
+ IPV6_HEADER_DST_ADDR = 24
+
+ TCP_HEADER_LEN = 40
+ TCP_HEADER_SRC_PORT = 0
+ TCP_HEADER_DST_PORT = 2
+ TCP_HEADER_CHECKSUM = 16
+
+ UDP_HEADER_SRC_PORT = 0
+ UDP_HEADER_DST_PORT = 2
+ UDP_HEADER_CHECKSUM = 6
+
+ PROTO_TCP = 6
+ PROTO_UDP = 17
+)
+
+func getUint16(header []byte, offset int) int {
+ return int(binary.BigEndian.Uint16(header[offset:offset+2]))
+}
+
+func putUint16(header []byte, offset int, value int) {
+ binary.BigEndian.PutUint16(header[offset:offset+2], uint16(value))
+}
+
+func getIPv4SrcAddr(header []byte) netip.Addr {
+ src, _ := netip.AddrFromSlice(header[IPV4_HEADER_SRC_ADDR:IPV4_HEADER_SRC_ADDR+4])
+ return src
+}
+
+func putIPv4SrcAddr(header []byte, addr netip.Addr) {
+ copy(header[IPV4_HEADER_SRC_ADDR:IPV4_HEADER_SRC_ADDR+4], addr.AsSlice())
+}
+
+func getIPv4DstAddr(header []byte) netip.Addr {
+ dst, _ := netip.AddrFromSlice(header[IPV4_HEADER_DST_ADDR:IPV4_HEADER_DST_ADDR+4])
+ return dst
+}
+
+func putIPv4DstAddr(header []byte, addr netip.Addr) {
+ copy(header[IPV4_HEADER_DST_ADDR:IPV4_HEADER_DST_ADDR+4], addr.AsSlice())
+}
+
+func getSrcAddr(header []byte, version int) netip.Addr {
+ if version == IPV4_VERSION {
+ return getIPv4SrcAddr(header)
+ } else {
+ // FIXME
+ return netip.IPv6Unspecified()
+ }
+}
+
+func getDstAddr(header []byte, version int) netip.Addr {
+ if version == IPV4_VERSION {
+ return getIPv4DstAddr(header)
+ } else {
+ // FIXME
+ return netip.IPv6Unspecified()
+ }
+}
+
+func updateIPv4Checksum(header []byte, updateFunc func (checksum int) int) {
+ checksum := getUint16(header, IPV4_HEADER_CHECKSUM)
+ putUint16(header, IPV4_HEADER_CHECKSUM, updateFunc(checksum))
+}
+
+func getTransport(header []byte, version int) (int, []byte) {
+ if version == IPV4_VERSION {
+ protocol := int(header[IPV4_HEADER_PROTOCOL])
+ return protocol, header[IPV4_HEADER_LEN:] // TODO take option headers into account
+ } else if version == IPV6_VERSION {
+ nextHeader := int(header[IPV6_HEADER_NEXT_HEADER])
+ return nextHeader, header[IPV6_HEADER_LEN:] // TODO take additional headers into account
+ } else {
+ return 0, nil
+ }
+}
+
+func updateTransportChecksum(header []byte, version int, updateFunc func (checksum int) int) {
+ transport, transportPayload := getTransport(header, version)
+
+ if transport == PROTO_UDP {
+ // UDP checksum should be calculated with by swinging the 17th carry bit around. Use no checksum.
+ //checksum := getUint16(transportPayload, UDP_HEADER_CHECKSUM)
+ //putUint16(transportPayload, UDP_HEADER_CHECKSUM, updateFunc(checksum))
+ putUint16(transportPayload, UDP_HEADER_CHECKSUM, 0)
+ } else if transport == PROTO_TCP {
+ checksum := getUint16(transportPayload, TCP_HEADER_CHECKSUM)
+ putUint16(transportPayload, TCP_HEADER_CHECKSUM, updateFunc(checksum))
+ }
+}
+
+func onesComplementUint16(value int) int {
+ return (0x10000 + ^(value & 0xffff)) & 0xffff
+}
+
+func onesSumUint16(value int) int {
+ for value > 0xffff {
+ value = (value & 0xffff) + (value >> 16)
+ }
+ return value
+}
+
+func calculateChecksum(hc int, m, mPrim int) int {
+ return onesComplementUint16(onesSumUint16(onesComplementUint16(hc) + onesComplementUint16(m) + mPrim))
+}
+
+func updateAddr(header []byte, version int, getFunc func (header []byte) netip.Addr, putFunc func (header []byte, addr netip.Addr), updateFunc func (netip.Addr) netip.Addr) {
+ if version == IPV4_VERSION {
+ addr := getFunc(header)
+ newAddr := updateFunc(addr)
+
+ if newAddr != addr {
+ putFunc(header, newAddr)
+ // TODO reduce code duplication see updateSrcAddr
+ updateIPv4Checksum(header, func (checksum int) int {
+ addr4 := addr.As4()
+ newAddr4 := newAddr.As4()
+ checksum = calculateChecksum(checksum, int(binary.BigEndian.Uint16(addr4[0:2])), int(binary.BigEndian.Uint16(newAddr4[0:2])))
+ checksum = calculateChecksum(checksum, int(binary.BigEndian.Uint16(addr4[2:4])), int(binary.BigEndian.Uint16(newAddr4[2:4])))
+ return checksum
+ })
+ updateTransportChecksum(header, version, func (checksum int) int {
+ if version == IPV4_VERSION && checksum == 0 {
+ return 0
+ } else {
+ addr4 := addr.As4()
+ newAddr4 := newAddr.As4()
+ checksum = calculateChecksum(checksum, int(binary.BigEndian.Uint16(addr4[0:2])), int(binary.BigEndian.Uint16(newAddr4[0:2])))
+ checksum = calculateChecksum(checksum, int(binary.BigEndian.Uint16(addr4[2:4])), int(binary.BigEndian.Uint16(newAddr4[2:4])))
+ return checksum
+ }
+ })
+ }
+ // } else if version == IPV6_VERSION {
+ // addr := getIPv6SrcAddr(header)
+ // putIPv6SrcAddr(header, updateFunc(addr))
+ }
+}
+
+func updateSrcAddr(header []byte, version int, updateFunc func (netip.Addr) netip.Addr ) {
+ if version == IPV4_VERSION {
+ updateAddr(header, version, getIPv4SrcAddr, putIPv4SrcAddr, updateFunc)
+ }
+}
+
+func updateDstAddr(header []byte, version int, updateFunc func (netip.Addr) netip.Addr ) {
+ if version == IPV4_VERSION {
+ updateAddr(header, version, getIPv4DstAddr, putIPv4DstAddr, updateFunc)
+ }
+
+ // } else if version == IPV6_VERSION {
+ // addr := getIPv6SrcAddr(header)
+ // putIPv6SrcAddr(header, updateFunc(addr))
+}
+
+func updateDstPort(header []byte, version int, updateFunc func (int) int ) {
+ transport, transportPayload := getTransport(header, version)
+
+ if transport == PROTO_TCP {
+ port := getUint16(transportPayload, TCP_HEADER_DST_PORT)
+ newPort := updateFunc(port)
+
+ if newPort != port {
+ putUint16(transportPayload, TCP_HEADER_DST_PORT, newPort)
+
+ updateTransportChecksum(header, version, func (checksum int) int {
+ if version == IPV4_VERSION && checksum == 0 {
+ return 0
+ } else {
+ checksum = calculateChecksum(checksum, int(port), int(newPort))
+ return checksum
+ }
+ })
+ }
+ // } else if version == IPV6_VERSION {
+ }
+}
+
+func updateSrcPort(header []byte, version int, updateFunc func (int) int ) {
+ transport, transportPayload := getTransport(header, version)
+
+ if transport == PROTO_TCP {
+ port := getUint16(transportPayload, TCP_HEADER_SRC_PORT)
+ newPort := updateFunc(port)
+
+ if newPort != port {
+ putUint16(transportPayload, TCP_HEADER_SRC_PORT, newPort)
+
+ updateTransportChecksum(header, version, func (checksum int) int {
+ if version == IPV4_VERSION && checksum == 0 {
+ return 0
+ } else {
+ checksum = calculateChecksum(checksum, int(port), int(newPort))
+ return checksum
+ }
+ })
+ }
+ // } else if version == IPV6_VERSION {
+ }
+}
+
+func getSrcPort(header []byte, version int) int {
+ transport, transportPayload := getTransport(header, version)
+
+ if transport == PROTO_TCP {
+ return getUint16(transportPayload, TCP_HEADER_SRC_PORT)
+ } else if transport == PROTO_UDP {
+ return getUint16(transportPayload, UDP_HEADER_SRC_PORT)
+ } else {
+ return -1
+ }
+}
+
+func getDstPort(header []byte, version int) int {
+ transport, transportPayload := getTransport(header, version)
+
+ if transport == PROTO_TCP {
+ return getUint16(transportPayload, TCP_HEADER_DST_PORT)
+ } else if transport == PROTO_UDP {
+ return getUint16(transportPayload, UDP_HEADER_DST_PORT)
+ } else {
+ return -1
+ }
+}
+
+func getTcpSrcPort(tcpHeader []byte) int {
+ return getUint16(tcpHeader, TCP_HEADER_SRC_PORT)
+}
+
+func putTcpSrcPort(tcpHeader []byte, port int) {
+ putUint16(tcpHeader, TCP_HEADER_SRC_PORT, port)
+}
+
+func getTcpDstPort(tcpHeader []byte) int {
+ return getUint16(tcpHeader, TCP_HEADER_DST_PORT)
+}
+
+func putTcpDstPort(tcpHeader []byte, port int) {
+ putUint16(tcpHeader, TCP_HEADER_DST_PORT, port)
+}
+
+func updateTcpSrcPort(tcpHeader []byte, updateFunc func (int) int ) {
+ port := getTcpSrcPort(tcpHeader)
+ putTcpSrcPort(tcpHeader, updateFunc(port))
+}
+
+func updateTcpDstPort(tcpHeader []byte, updateFunc func (int) int ) {
+ port := getTcpDstPort(tcpHeader)
+ putTcpDstPort(tcpHeader, updateFunc(port))
+}
+
+func getIPVersion(header []byte, len int) (int, error) {
+ version := int(header[0]) >> 4
+
+ if version == IPV4_VERSION {
+ if len >= IPV4_HEADER_LEN {
+ return version, nil
+ } else {
+ return -1, fmt.Errorf("IPv4 header length: %d < %d", len, IPV4_HEADER_LEN)
+ }
+ } else if version == IPV6_VERSION {
+ if len >= IPV6_HEADER_LEN {
+ return version, nil
+ } else {
+ return -1, fmt.Errorf("IPv6 header length: %d < %d", len, IPV6_HEADER_LEN)
+ }
+ } else {
+ return -1, fmt.Errorf("Unknown IP version: %d", version)
+ }
+}
+
+func (tun *natTun) Read(buf []byte, offset int) (int, error) {
+ len, err := tun.tun.Read(buf, offset)
+ if err == nil && len > 0 {
+ header := buf[offset:]
+ version, err := getIPVersion(header, len)
+
+ if err != nil {
+ // Ignore bad packet
+ } else if version == IPV4_VERSION {
+ isProxy := false
+
+ srcPort := getSrcPort(header, version)
+ origDstPort := 0
+ newDstPort := 0
+ var origSrcAddr netip.Addr
+ var newSrcAddr netip.Addr
+ var origDstAddr netip.Addr
+ var newDstAddr netip.Addr
+
+ updateDstPort(header, version, func(port int) int {
+ // Transparent proxy HTTP and HTTPS
+ if port == 80 || port == 443 {
+ isProxy = true
+ origDstPort = port
+ newDstPort = tun.proxyPort
+ return newDstPort
+ } else {
+ return port
+ }
+ })
+ if tun.srcAddr.IsValid() {
+ updateSrcAddr(header, version, func(addr netip.Addr) netip.Addr {
+ if isProxy {
+ origSrcAddr = addr
+ newSrcAddr = tun.srcAddr
+ return newSrcAddr
+ } else {
+ return addr
+ }
+ })
+ }
+ updateDstAddr(header, version, func(addr netip.Addr) netip.Addr {
+ if isProxy {
+ origDstAddr = addr
+ newDstAddr = tun.proxyAddr
+ return newDstAddr
+ } else {
+ return addr
+ }
+ })
+
+ if isProxy {
+ if !newSrcAddr.IsValid() {
+ origSrcAddr = getSrcAddr(header, version)
+ newSrcAddr = origSrcAddr
+ }
+ orig := Connection(netip.AddrPortFrom(origSrcAddr, uint16(srcPort)),
+ netip.AddrPortFrom(origDstAddr, uint16(origDstPort)))
+ new := Connection(netip.AddrPortFrom(newSrcAddr, uint16(srcPort)),
+ netip.AddrPortFrom(newDstAddr, uint16(newDstPort)))
+ tun.addConnection(new, orig)
+ }
+ // protocol := int(header[9])
+
+ // if protocol == PROTO_TCP && len >= (IPV4_HEADER_LEN + TCP_HEADER_LEN) {
+ // tcpHeader := header[IPV4_HEADER_LEN:]
+ // updateTcpDstPort(tcpHeader, func(port int) int { return port + 1 })
+ // }
+ //} else if version == IPV6_VERSION {
+ // updateSrcAddr(header, IPV6_HEADER_SRC_ADDR, func(netip.Addr) { return extAddr })
+ // nextHeader := int(header[6])
+
+ // if nextHeader == PROTO_TCP && len >= (IPV6_HEADER_LEN + TCP_HEADER_LEN) {
+ // tcpHeader := header[IPV6_HEADER_LEN:]
+ // updateTcpDstPort(tcpHeader, func(port int) int { return port + 1 })
+ // }
+ }
+ }
+
+ return len, err
+}
+
+func (tun *natTun) Write(buf []byte, offset int) (int, error) {
+ len := len(buf)
+
+ if len > 0 {
+ header := buf[offset:]
+ version, err := getIPVersion(header, len)
+
+ if err != nil {
+ // Ignore bad packet
+ } else if version == IPV4_VERSION || version == IPV6_VERSION {
+ srcAddr := getSrcAddr(header, version)
+ srcPort := getSrcPort(header, version)
+ dstAddr := getDstAddr(header, version)
+ dstPort := getDstPort(header, version)
+
+ src := netip.AddrPortFrom(srcAddr, uint16(srcPort))
+ dst := netip.AddrPortFrom(dstAddr, uint16(dstPort))
+ new := Connection(dst, src)
+
+ orig, ok := tun.lookupConnection(new)
+
+ if ok {
+ updateSrcAddr(header, version, func(netip.Addr) netip.Addr { return orig.dst.Addr() })
+ updateSrcPort(header, version, func(int) int { return int(orig.dst.Port()) })
+ updateDstAddr(header, version, func(netip.Addr) netip.Addr { return orig.src.Addr() })
+ updateDstPort(header, version, func(int) int { return int(orig.src.Port()) })
+ }
+ }
+ }
+
+ return tun.tun.Write(buf, offset)
+}
+
+func (tun *natTun) Flush() error {
+ return tun.tun.Flush()
+}
+
+func (tun *natTun) Close() error {
+ return tun.tun.Close()
+}
+
+func (tun *natTun) MTU() (int, error) {
+ return tun.tun.MTU()
+}