From 99421e9c97978d2c5fcdcfe96bcc122bd2fa3045 Mon Sep 17 00:00:00 2001 From: Mikael Magnusson Date: Sat, 19 Feb 2022 10:11:16 +0100 Subject: WIP not working nat fsm --- tunnel/tools/libwg-go/conntrack.go | 330 ++++++++++++++++++++++++++++++++++-- tunnel/tools/libwg-go/http-proxy.go | 6 +- tunnel/tools/libwg-go/nat-tun.go | 277 ++++++++++++++++++++---------- 3 files changed, 504 insertions(+), 109 deletions(-) diff --git a/tunnel/tools/libwg-go/conntrack.go b/tunnel/tools/libwg-go/conntrack.go index 17289a6f..9b34b5ce 100644 --- a/tunnel/tools/libwg-go/conntrack.go +++ b/tunnel/tools/libwg-go/conntrack.go @@ -1,12 +1,83 @@ package main import ( + "container/heap" + "fmt" "sync" + "time" "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/device" ) +const ( + TCP_TRANS_TIMEOUT_MS = 4 * 60 * 1000 // 4 minutes + TCP_EST_TIMEOUT_MS = 1 * 24 * 60 * 60 * 1000 // 1 day +) + +// RFC 7857 section 2. TCP Session Tracking +type State int8 + +const ( + Closed State = iota + Trans + Init + Established + ClientFinRcv + ServerFinRcv + ClientServerFinRcv +) + +func (st State) String() string { + stateNames := []string{ + "Closed", + "Trans", + "Init", + "Established", + "ClientFinRcv", + "ServerFinRcv", + "ClientServerFinRcv", + } + + if int(st) < len(stateNames) { + return stateNames[st] + } else { + return "Unknown" + } +} + +type Event int8 + +const ( + None Event = iota + ClientSyn + ClientRst + ClientFin + ServerSyn + ServerRst + ServerFin + DataPkt +) + +func (ev Event) String() string { + eventNames := []string{ + "None", + "ClientSyn", + "ClientRst", + "ClientFin", + "ServerSyn", + "ServerRst", + "ServerFin", + "DataPkt", + } + + if int(ev) < len(eventNames) { + return eventNames[ev] + } else { + return "Unknown" + } +} + type connection struct { src netip.AddrPort dst netip.AddrPort @@ -19,30 +90,263 @@ func Connection(src, dst netip.AddrPort) connection { } } +func (c connection) String() string { + return fmt.Sprintf("%v -> %v", c.src, c.dst) +} + +type mapping struct { + orig connection + nat connection + state State + timeout time.Time // Priority in queue + index int // Index in the queue +} + +func (m *mapping) String() string { + return fmt.Sprintf("state:%v orig:%v, nat:%v", m.state, m.orig, m.nat) +} + +func (m *mapping) clientSyn(ct *Conntrack) { + if m.state == Closed { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = Init + } +} + +func (m *mapping) serverSyn(ct *Conntrack) { + if m.state == Init { + ct.updateTimer(m, TCP_EST_TIMEOUT_MS) + m.state = Established + } +} + +func (m *mapping) clientRst(ct *Conntrack) { + if m.state == Established { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = Trans + } +} + +func (m *mapping) serverRst(ct *Conntrack) { + if m.state == Established { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = Trans + } +} + +func (m *mapping) clientFin(ct *Conntrack) { + if m.state == Established { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = ClientFinRcv + } else if m.state == ServerFinRcv { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = ClientServerFinRcv + } +} + +func (m *mapping) serverFin(ct *Conntrack) { + if m.state == Established { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = ServerFinRcv + } else if m.state == ClientFinRcv { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = ClientServerFinRcv + } +} + +func (m *mapping) dataPkt(ct *Conntrack) { + if m.state == Trans { + ct.updateTimer(m, TCP_EST_TIMEOUT_MS) + m.state = Established + } +} + +func (m *mapping) transitionTimeout(ct *Conntrack) { + if m.state == Trans || m.state == ClientServerFinRcv { + // TODO remove connection + m.state = Closed + } +} + +func (m *mapping) establishedTimeout(ct *Conntrack) { + if m.state == Established { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = Trans + } +} + +type PriorityQueue []*mapping + +func (pq PriorityQueue) Len() int { return len(pq) } + +func (pq PriorityQueue) Less(i, j int) bool { + // We want Pop to give us the lowest, not highest, priority so + // we use less than here. + return pq[i].timeout.Before(pq[j].timeout) +} + +func (pq PriorityQueue) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] + pq[i].index = i + pq[j].index = j +} + +func (pq *PriorityQueue) Push(x interface{}) { + n := len(*pq) + item := x.(*mapping) + item.index = n + *pq = append(*pq, item) +} + +func (pq *PriorityQueue) Pop() interface{} { + old := *pq + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.index = -1 // for safety + *pq = old[0 : n-1] + return item +} + +// update modifies the priority and value of an Item in the queue. +func (pq *PriorityQueue) update(c *mapping, timeout time.Time) { + c.timeout = timeout + heap.Fix(pq, c.index) +} + type Conntrack struct { - connections map[connection]connection - connectionsMutex sync.RWMutex + mappings map[connection]*mapping + mappingsMutex sync.RWMutex + timeouts PriorityQueue + timer *time.Timer l *device.Logger } func NewConntrack(logger *device.Logger) *Conntrack { + timeouts := make(PriorityQueue, 0) + heap.Init(&timeouts) return &Conntrack{ - connections: make(map[connection]connection), - connectionsMutex: sync.RWMutex{}, + mappings: make(map[connection]*mapping), + mappingsMutex: sync.RWMutex{}, + timeouts: timeouts, + timer: nil, l: logger, } } -func (ct *Conntrack) addConnection(new, orig connection) { - ct.connectionsMutex.Lock() - ct.connections[new] = orig - ct.connectionsMutex.Unlock() +// mappingsMutex must be held +func (ct *Conntrack) updateTimer(m *mapping, d time.Duration) { + if m.index == -1 { + m.timeout = time.Now().Add(d) + heap.Push(&ct.timeouts, m) + } else { + ct.timeouts.update(m, time.Now().Add(d)) + } + firstDuration := ct.timeouts[0].timeout.Sub(time.Now()) + if m.index == 0 && ct.timer != nil { + ct.l.Verbosef("updateTimer stop") + // FIXME deadlock if called from AfterFunc + if !ct.timer.Stop() { + ct.l.Verbosef("stop returned false") + //<-ct.timer.C + // TODO handle timeoute + } + ct.l.Verbosef("updateTimer reset") + ct.timer.Reset(firstDuration) + ct.l.Verbosef("updateTimer reset done") + } else if ct.timer == nil { + ct.timer = time.NewTimer(firstDuration) + + go func(ch <-chan time.Time) { + for { + t := <-ch + ct.l.Verbosef("Timer %v", t) + ct.mappingsMutex.Lock() + ct.l.Verbosef("Timer locked") + now := time.Now() + for ct.timeouts.Len() > 0 { + ct.l.Verbosef("Timer loop %d", ct.timeouts.Len()) + m := ct.timeouts[0] + if m.timeout.Before(now) { + ct.l.Verbosef("Timer Before") + heap.Pop(&ct.timeouts) + if m.state == Established { + ct.l.Verbosef("Timer established") + // FIXME use channel + m.establishedTimeout(ct) + } else { + ct.l.Verbosef("Timer transition") + m.transitionTimeout(ct) + } + if m.state == Closed { + ct.l.Verbosef("Timer closed") + delete(ct.mappings, m.orig) + delete(ct.mappings, m.nat) + } + } else { + break + } + } + ct.l.Verbosef("Timer loop done") + if ct.timeouts.Len() > 0 { + ct.timer.Reset(ct.timeouts[0].timeout.Sub(time.Now())) + } + ct.mappingsMutex.Unlock() + ct.l.Verbosef("Timer unlocked") + } + }(ct.timer.C) + } else { + ct.l.Verbosef("Timer update not needed") + } +} + +func (ct *Conntrack) addConnection(orig, nat connection) { + ct.l.Verbosef("addConnection") + ct.mappingsMutex.Lock() + m := &mapping{ + orig: orig, + nat: nat, + state: Init, + timeout: time.Now().Add(TCP_TRANS_TIMEOUT_MS), + index: -1, + } + ct.mappings[orig] = m + ct.mappings[nat] = m + heap.Push(&ct.timeouts, m) + ct.mappingsMutex.Unlock() } -func (ct *Conntrack) lookupConnection(new connection) (connection, bool) { - ct.connectionsMutex.RLock() - c, ok := ct.connections[new] - ct.connectionsMutex.RUnlock() - return c, ok +func (ct *Conntrack) lookupConnection(c connection) (*mapping, bool) { + ct.l.Verbosef("lookupConnection") + ct.mappingsMutex.RLock() + m, ok := ct.mappings[c] + ct.mappingsMutex.RUnlock() + return m, ok } +func (ct *Conntrack) event(c connection, ev Event) { + ct.l.Verbosef("event") + ct.mappingsMutex.Lock() + m, ok := ct.mappings[c] + if !ok { + ct.mappingsMutex.Unlock() + // FIXME + return + } + if ev == ClientSyn { + m.clientSyn(ct) + } else if ev == ServerSyn { + m.serverSyn(ct) + } else if ev == ClientRst { + m.clientRst(ct) + } else if ev == ServerRst { + m.serverRst(ct) + } else if ev == ClientFin { + m.clientFin(ct) + } else if ev == ServerFin { + m.serverFin(ct) + } else if ev == DataPkt { + m.dataPkt(ct) + } + ct.mappingsMutex.Unlock() +} diff --git a/tunnel/tools/libwg-go/http-proxy.go b/tunnel/tools/libwg-go/http-proxy.go index 7e2bd32f..844ffae5 100644 --- a/tunnel/tools/libwg-go/http-proxy.go +++ b/tunnel/tools/libwg-go/http-proxy.go @@ -604,11 +604,11 @@ func (h *HttpHandler) addConnToProxyMap(c net.Conn) bool { } newConn := connection{src: local, dst: remote} - oldConn, ok := h.p.conntrack.lookupConnection(newConn) + m, ok := h.p.conntrack.lookupConnection(newConn) if ok { - local = oldConn.src - remote = oldConn.dst + local = m.orig.src + remote = m.orig.dst h.logger.Verbosef("Before NAT: %v -> %v", local, remote) } else if remote.Addr().IsLoopback() { h.logger.Verbosef("Loopback request") diff --git a/tunnel/tools/libwg-go/nat-tun.go b/tunnel/tools/libwg-go/nat-tun.go index 75181ba0..6b37a9d8 100644 --- a/tunnel/tools/libwg-go/nat-tun.go +++ b/tunnel/tools/libwg-go/nat-tun.go @@ -6,6 +6,7 @@ package main // TODO debug IPv6 NAT +// TODO implement TCP session tracking according to RFC 7857, import ( "encoding/binary" @@ -57,17 +58,19 @@ func (tun *natTun) addTranslation(ipv6 bool, dstPort int, proxyPort int) { if ipv6 { ipVersion = IPV6_VERSION - //srcStr = "fe80::1" - //proxyStr = "fe80::2" - srcStr = "fdba:4b51:1606:a61d::1" - proxyStr = "fdba:4b51:1606:a61d::2" - // loop = false - // proxyStr = "2001:470:de6f:2fff::f780" - // proxyPort = 8888 + //srcStr = "fdba:4b51:1606:a61d::1" + //proxyStr = "fdba:4b51:1606:a61d::2" + loop = false + proxyStr = "2001:470:de6f:2fff::f780" + proxyPort = 8888 } else { ipVersion = IPV4_VERSION - srcStr = "169.254.0.1" - proxyStr = "169.254.0.2" + //srcStr = "169.254.0.1" + //proxyStr = "169.254.0.2" + loop = false + proxyStr = "10.49.32.1" + proxyPort = 8888 + } if srcStr != "" { @@ -132,8 +135,13 @@ const ( TCP_HEADER_LEN = 40 TCP_HEADER_SRC_PORT = 0 TCP_HEADER_DST_PORT = 2 + TCP_HEADER_FLAGS = 12 TCP_HEADER_CHECKSUM = 16 + TCP_FLAGS_FIN = 0x01 + TCP_FLAGS_SYN = 0x02 + TCP_FLAGS_RST = 0x04 + UDP_HEADER_SRC_PORT = 0 UDP_HEADER_DST_PORT = 2 UDP_HEADER_CHECKSUM = 6 @@ -429,30 +437,110 @@ func getIPVersion(header []byte, len int) (int, error) { } } +func testFlag(flags int, flag int) bool { + return (flags & flag) != 0 +} + func (tun *natTun) Read(buf []byte, offset int) (int, error) { len, err := tun.tun.Read(buf, offset) - if err == nil && len > 0 { - header := buf[offset:] - version, err := getIPVersion(header, len) + if err != nil || len <= 0 { + return len, err + } + + header := buf[offset:] + version, err := getIPVersion(header, len) + + if err != nil { + return len, nil + } - if err != nil { - // Ignore bad packet - } else if version == IPV4_VERSION || version == IPV6_VERSION { - origDstPort := getDstPort(header, version) + if version == IPV4_VERSION || version == IPV6_VERSION { + transport, transportPayload := getTransport(header, version) + if (transport != PROTO_TCP) { + return len, nil + } + + flags := getUint16(transportPayload, TCP_HEADER_FLAGS) & 0x1f + srcAddr := getSrcAddr(header, version) + srcPort := getSrcPort(header, version) + dstAddr := getDstAddr(header, version) + dstPort := getDstPort(header, version) + + src := netip.AddrPortFrom(srcAddr, uint16(srcPort)) + dst := netip.AddrPortFrom(dstAddr, uint16(dstPort)) + conn := Connection(src, dst) + + m, ok := tun.conntrack.lookupConnection(conn) + if ok { + tun.l.Verbosef("Found mapping: [%v] %v", conn, m) + var event Event + var newConn connection + if m.orig == conn { + // Client + newConn = m.nat + tun.l.Verbosef("Client: %v", newConn) + if testFlag(flags, TCP_FLAGS_SYN) { + event = ClientSyn + } else if testFlag(flags, TCP_FLAGS_FIN) { + event = ClientFin + } else if testFlag(flags, TCP_FLAGS_RST) { + event = ClientRst + } + } else { + // Server + newConn = m.orig + tun.l.Verbosef("Server: %v", newConn) + if testFlag(flags, TCP_FLAGS_SYN) { + event = ServerSyn + } else if testFlag(flags, TCP_FLAGS_FIN) { + event = ServerFin + } else if testFlag(flags, TCP_FLAGS_RST) { + event = ServerRst + } else { + event = DataPkt + } + } + + if event != None { + tun.l.Verbosef("Event: %v", event) + tun.conntrack.event(conn, event) + } + + if m.state == Closed { + // TODO log + return 0, fmt.Errorf("Mapping closed") + } + + updateSrcAddr(header, version, func(netip.Addr) netip.Addr { return newConn.dst.Addr() }) + updateSrcPort(header, version, func(int) int { return int(newConn.dst.Port()) }) + updateDstAddr(header, version, func(netip.Addr) netip.Addr { return newConn.src.Addr() }) + updateDstPort(header, version, func(int) int { return int(newConn.src.Port()) }) + // Write back to tunnel + writeBuf := buf[:offset+len] + written, err := tun.Write(writeBuf, offset) + tun.l.Verbosef("Written: %v %v", written, err) + if err != nil { + return 0, err + } else if written != len { + return 0, fmt.Errorf("NAT rev buffer partly written %v != %v", written, len) + } else { + return 0, nil + } + } else { + if !testFlag(flags, TCP_FLAGS_SYN) { + return len, nil + } for _, entry := range(tun.translations) { if version != entry.ipVersion { continue } - if origDstPort != entry.dstPort { + if dstPort != entry.dstPort { continue } - - srcPort := getSrcPort(header, version) + newDstPort := 0 - var origSrcAddr netip.Addr - var origDstAddr netip.Addr var newSrcAddr netip.Addr var newDstAddr netip.Addr @@ -463,34 +551,34 @@ func (tun *natTun) Read(buf []byte, offset int) (int, error) { if entry.srcAddr.IsValid() { updateSrcAddr(header, version, func(addr netip.Addr) netip.Addr { - origSrcAddr = addr newSrcAddr = entry.srcAddr return newSrcAddr }) } else { - origSrcAddr = getSrcAddr(header, version) - newSrcAddr = origSrcAddr + newSrcAddr = srcAddr } updateDstAddr(header, version, func(addr netip.Addr) netip.Addr { - origDstAddr = addr newDstAddr = entry.proxyAddr return newDstAddr }) - orig := Connection(netip.AddrPortFrom(origSrcAddr, uint16(srcPort)), - netip.AddrPortFrom(origDstAddr, uint16(origDstPort))) - new := Connection(netip.AddrPortFrom(newSrcAddr, uint16(srcPort)), - netip.AddrPortFrom(newDstAddr, uint16(newDstPort))) - tun.conntrack.addConnection(new, orig) + orig := Connection(netip.AddrPortFrom(srcAddr, uint16(srcPort)), + netip.AddrPortFrom(dstAddr, uint16(dstPort))) + nat := Connection(netip.AddrPortFrom(newDstAddr, uint16(newDstPort)), + netip.AddrPortFrom(newSrcAddr, uint16(srcPort))) + tun.l.Verbosef("New conn: %v -> %v", orig, nat) + + tun.conntrack.addConnection(orig, nat) if !entry.loop { return len, nil } - + // Write back to tunnel writeBuf := buf[:offset+len] written, err := tun.Write(writeBuf, offset) + tun.l.Verbosef("Written: %v %v", written, err) if err != nil { return 0, err } else if written != len { @@ -499,7 +587,27 @@ func (tun *natTun) Read(buf []byte, offset int) (int, error) { return 0, nil } } + } + } + + return len, err +} +func (tun *natTun) Write(buf []byte, offset int) (int, error) { + len := len(buf) + + if len <= 0 { + return len, nil + } + + header := buf[offset:] + version, err := getIPVersion(header, len) + + if err == nil && (version == IPV4_VERSION || version == IPV6_VERSION) { + transport, transportPayload := getTransport(header, version) + if (transport == PROTO_TCP) { + flags := getUint16(transportPayload, TCP_HEADER_FLAGS) & 0x1f + // TODO remove unused code srcAddr := getSrcAddr(header, version) srcPort := getSrcPort(header, version) dstAddr := getDstAddr(header, version) @@ -507,74 +615,57 @@ func (tun *natTun) Read(buf []byte, offset int) (int, error) { src := netip.AddrPortFrom(srcAddr, uint16(srcPort)) dst := netip.AddrPortFrom(dstAddr, uint16(dstPort)) - new := Connection(dst, src) + conn := Connection(src, dst) - orig, ok := tun.conntrack.lookupConnection(new) + m, ok := tun.conntrack.lookupConnection(conn) if ok { - updateSrcAddr(header, version, func(netip.Addr) netip.Addr { return orig.dst.Addr() }) - updateSrcPort(header, version, func(int) int { return int(orig.dst.Port()) }) - updateDstAddr(header, version, func(netip.Addr) netip.Addr { return orig.src.Addr() }) - updateDstPort(header, version, func(int) int { return int(orig.src.Port()) }) - // Write back to tunnel - writeBuf := buf[:offset+len] - written, err := tun.Write(writeBuf, offset) - if err != nil { - return 0, err - } else if written != len { - return 0, fmt.Errorf("NAT rev buffer partly written %v != %v", written, len) + tun.l.Verbosef("Write: Found mapping: [%v] %v", conn, m) + var event Event + var newConn connection + + if m.orig == conn { + // Client + newConn = m.nat + tun.l.Verbosef("Client: %v", newConn) + if testFlag(flags, TCP_FLAGS_SYN) { + event = ClientSyn + } else if testFlag(flags, TCP_FLAGS_FIN) { + event = ClientFin + } else if testFlag(flags, TCP_FLAGS_RST) { + event = ClientRst + } } else { - return 0, nil + // Server + newConn = m.orig + tun.l.Verbosef("Server: %v", newConn) + if testFlag(flags, TCP_FLAGS_SYN) { + event = ServerSyn + } else if testFlag(flags, TCP_FLAGS_FIN) { + event = ServerFin + } else if testFlag(flags, TCP_FLAGS_RST) { + event = ServerRst + } else { + event = DataPkt + } + } + + if event != None { + tun.l.Verbosef("Event: %v", event) + tun.conntrack.event(conn, event) } - } - // protocol := int(header[9]) - - // if protocol == PROTO_TCP && len >= (IPV4_HEADER_LEN + TCP_HEADER_LEN) { - // tcpHeader := header[IPV4_HEADER_LEN:] - // updateTcpDstPort(tcpHeader, func(port int) int { return port + 1 }) - // } - //} else if version == IPV6_VERSION { - // updateSrcAddr(header, IPV6_HEADER_SRC_ADDR, func(netip.Addr) { return extAddr }) - // nextHeader := int(header[6]) - - // if nextHeader == PROTO_TCP && len >= (IPV6_HEADER_LEN + TCP_HEADER_LEN) { - // tcpHeader := header[IPV6_HEADER_LEN:] - // updateTcpDstPort(tcpHeader, func(port int) int { return port + 1 }) - // } - } - } - return len, err -} - -func (tun *natTun) Write(buf []byte, offset int) (int, error) { - len := len(buf) + if m.state == Closed { + // TODO log + return 0, fmt.Errorf("Mapping closed") + } - if len > 0 { - header := buf[offset:] - version, err := getIPVersion(header, len) - - if err != nil { - // Ignore bad packet - } else if version == IPV4_VERSION || version == IPV6_VERSION { - srcAddr := getSrcAddr(header, version) - srcPort := getSrcPort(header, version) - dstAddr := getDstAddr(header, version) - dstPort := getDstPort(header, version) - - src := netip.AddrPortFrom(srcAddr, uint16(srcPort)) - dst := netip.AddrPortFrom(dstAddr, uint16(dstPort)) - new := Connection(dst, src) - - orig, ok := tun.conntrack.lookupConnection(new) - - if ok { - updateSrcAddr(header, version, func(netip.Addr) netip.Addr { return orig.dst.Addr() }) - updateSrcPort(header, version, func(int) int { return int(orig.dst.Port()) }) - updateDstAddr(header, version, func(netip.Addr) netip.Addr { return orig.src.Addr() }) - updateDstPort(header, version, func(int) int { return int(orig.src.Port()) }) - } - } - } + updateSrcAddr(header, version, func(netip.Addr) netip.Addr { return newConn.dst.Addr() }) + updateSrcPort(header, version, func(int) int { return int(newConn.dst.Port()) }) + updateDstAddr(header, version, func(netip.Addr) netip.Addr { return newConn.src.Addr() }) + updateDstPort(header, version, func(int) int { return int(newConn.src.Port()) }) + } + } + } return tun.tun.Write(buf, offset) } -- cgit v1.2.3