summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMikael Magnusson <mikma@users.sourceforge.net>2022-02-19 10:11:16 +0100
committerMikael Magnusson <mikma@users.sourceforge.net>2022-03-18 22:07:58 +0100
commit99421e9c97978d2c5fcdcfe96bcc122bd2fa3045 (patch)
treee2ddc50015b02cfba37e3ae59eeaeb781d64a2fd
parent90d21ec6d360af0689f90a13d37a71f50f5b5d21 (diff)
WIP not working nat fsmhttp-proxy-fsm
-rw-r--r--tunnel/tools/libwg-go/conntrack.go330
-rw-r--r--tunnel/tools/libwg-go/http-proxy.go6
-rw-r--r--tunnel/tools/libwg-go/nat-tun.go277
3 files changed, 504 insertions, 109 deletions
diff --git a/tunnel/tools/libwg-go/conntrack.go b/tunnel/tools/libwg-go/conntrack.go
index 17289a6f..9b34b5ce 100644
--- a/tunnel/tools/libwg-go/conntrack.go
+++ b/tunnel/tools/libwg-go/conntrack.go
@@ -1,12 +1,83 @@
package main
import (
+ "container/heap"
+ "fmt"
"sync"
+ "time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/device"
)
+const (
+ TCP_TRANS_TIMEOUT_MS = 4 * 60 * 1000 // 4 minutes
+ TCP_EST_TIMEOUT_MS = 1 * 24 * 60 * 60 * 1000 // 1 day
+)
+
+// RFC 7857 section 2. TCP Session Tracking
+type State int8
+
+const (
+ Closed State = iota
+ Trans
+ Init
+ Established
+ ClientFinRcv
+ ServerFinRcv
+ ClientServerFinRcv
+)
+
+func (st State) String() string {
+ stateNames := []string{
+ "Closed",
+ "Trans",
+ "Init",
+ "Established",
+ "ClientFinRcv",
+ "ServerFinRcv",
+ "ClientServerFinRcv",
+ }
+
+ if int(st) < len(stateNames) {
+ return stateNames[st]
+ } else {
+ return "Unknown"
+ }
+}
+
+type Event int8
+
+const (
+ None Event = iota
+ ClientSyn
+ ClientRst
+ ClientFin
+ ServerSyn
+ ServerRst
+ ServerFin
+ DataPkt
+)
+
+func (ev Event) String() string {
+ eventNames := []string{
+ "None",
+ "ClientSyn",
+ "ClientRst",
+ "ClientFin",
+ "ServerSyn",
+ "ServerRst",
+ "ServerFin",
+ "DataPkt",
+ }
+
+ if int(ev) < len(eventNames) {
+ return eventNames[ev]
+ } else {
+ return "Unknown"
+ }
+}
+
type connection struct {
src netip.AddrPort
dst netip.AddrPort
@@ -19,30 +90,263 @@ func Connection(src, dst netip.AddrPort) connection {
}
}
+func (c connection) String() string {
+ return fmt.Sprintf("%v -> %v", c.src, c.dst)
+}
+
+type mapping struct {
+ orig connection
+ nat connection
+ state State
+ timeout time.Time // Priority in queue
+ index int // Index in the queue
+}
+
+func (m *mapping) String() string {
+ return fmt.Sprintf("state:%v orig:%v, nat:%v", m.state, m.orig, m.nat)
+}
+
+func (m *mapping) clientSyn(ct *Conntrack) {
+ if m.state == Closed {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = Init
+ }
+}
+
+func (m *mapping) serverSyn(ct *Conntrack) {
+ if m.state == Init {
+ ct.updateTimer(m, TCP_EST_TIMEOUT_MS)
+ m.state = Established
+ }
+}
+
+func (m *mapping) clientRst(ct *Conntrack) {
+ if m.state == Established {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = Trans
+ }
+}
+
+func (m *mapping) serverRst(ct *Conntrack) {
+ if m.state == Established {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = Trans
+ }
+}
+
+func (m *mapping) clientFin(ct *Conntrack) {
+ if m.state == Established {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = ClientFinRcv
+ } else if m.state == ServerFinRcv {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = ClientServerFinRcv
+ }
+}
+
+func (m *mapping) serverFin(ct *Conntrack) {
+ if m.state == Established {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = ServerFinRcv
+ } else if m.state == ClientFinRcv {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = ClientServerFinRcv
+ }
+}
+
+func (m *mapping) dataPkt(ct *Conntrack) {
+ if m.state == Trans {
+ ct.updateTimer(m, TCP_EST_TIMEOUT_MS)
+ m.state = Established
+ }
+}
+
+func (m *mapping) transitionTimeout(ct *Conntrack) {
+ if m.state == Trans || m.state == ClientServerFinRcv {
+ // TODO remove connection
+ m.state = Closed
+ }
+}
+
+func (m *mapping) establishedTimeout(ct *Conntrack) {
+ if m.state == Established {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = Trans
+ }
+}
+
+type PriorityQueue []*mapping
+
+func (pq PriorityQueue) Len() int { return len(pq) }
+
+func (pq PriorityQueue) Less(i, j int) bool {
+ // We want Pop to give us the lowest, not highest, priority so
+ // we use less than here.
+ return pq[i].timeout.Before(pq[j].timeout)
+}
+
+func (pq PriorityQueue) Swap(i, j int) {
+ pq[i], pq[j] = pq[j], pq[i]
+ pq[i].index = i
+ pq[j].index = j
+}
+
+func (pq *PriorityQueue) Push(x interface{}) {
+ n := len(*pq)
+ item := x.(*mapping)
+ item.index = n
+ *pq = append(*pq, item)
+}
+
+func (pq *PriorityQueue) Pop() interface{} {
+ old := *pq
+ n := len(old)
+ item := old[n-1]
+ old[n-1] = nil // avoid memory leak
+ item.index = -1 // for safety
+ *pq = old[0 : n-1]
+ return item
+}
+
+// update modifies the priority and value of an Item in the queue.
+func (pq *PriorityQueue) update(c *mapping, timeout time.Time) {
+ c.timeout = timeout
+ heap.Fix(pq, c.index)
+}
+
type Conntrack struct {
- connections map[connection]connection
- connectionsMutex sync.RWMutex
+ mappings map[connection]*mapping
+ mappingsMutex sync.RWMutex
+ timeouts PriorityQueue
+ timer *time.Timer
l *device.Logger
}
func NewConntrack(logger *device.Logger) *Conntrack {
+ timeouts := make(PriorityQueue, 0)
+ heap.Init(&timeouts)
return &Conntrack{
- connections: make(map[connection]connection),
- connectionsMutex: sync.RWMutex{},
+ mappings: make(map[connection]*mapping),
+ mappingsMutex: sync.RWMutex{},
+ timeouts: timeouts,
+ timer: nil,
l: logger,
}
}
-func (ct *Conntrack) addConnection(new, orig connection) {
- ct.connectionsMutex.Lock()
- ct.connections[new] = orig
- ct.connectionsMutex.Unlock()
+// mappingsMutex must be held
+func (ct *Conntrack) updateTimer(m *mapping, d time.Duration) {
+ if m.index == -1 {
+ m.timeout = time.Now().Add(d)
+ heap.Push(&ct.timeouts, m)
+ } else {
+ ct.timeouts.update(m, time.Now().Add(d))
+ }
+ firstDuration := ct.timeouts[0].timeout.Sub(time.Now())
+ if m.index == 0 && ct.timer != nil {
+ ct.l.Verbosef("updateTimer stop")
+ // FIXME deadlock if called from AfterFunc
+ if !ct.timer.Stop() {
+ ct.l.Verbosef("stop returned false")
+ //<-ct.timer.C
+ // TODO handle timeoute
+ }
+ ct.l.Verbosef("updateTimer reset")
+ ct.timer.Reset(firstDuration)
+ ct.l.Verbosef("updateTimer reset done")
+ } else if ct.timer == nil {
+ ct.timer = time.NewTimer(firstDuration)
+
+ go func(ch <-chan time.Time) {
+ for {
+ t := <-ch
+ ct.l.Verbosef("Timer %v", t)
+ ct.mappingsMutex.Lock()
+ ct.l.Verbosef("Timer locked")
+ now := time.Now()
+ for ct.timeouts.Len() > 0 {
+ ct.l.Verbosef("Timer loop %d", ct.timeouts.Len())
+ m := ct.timeouts[0]
+ if m.timeout.Before(now) {
+ ct.l.Verbosef("Timer Before")
+ heap.Pop(&ct.timeouts)
+ if m.state == Established {
+ ct.l.Verbosef("Timer established")
+ // FIXME use channel
+ m.establishedTimeout(ct)
+ } else {
+ ct.l.Verbosef("Timer transition")
+ m.transitionTimeout(ct)
+ }
+ if m.state == Closed {
+ ct.l.Verbosef("Timer closed")
+ delete(ct.mappings, m.orig)
+ delete(ct.mappings, m.nat)
+ }
+ } else {
+ break
+ }
+ }
+ ct.l.Verbosef("Timer loop done")
+ if ct.timeouts.Len() > 0 {
+ ct.timer.Reset(ct.timeouts[0].timeout.Sub(time.Now()))
+ }
+ ct.mappingsMutex.Unlock()
+ ct.l.Verbosef("Timer unlocked")
+ }
+ }(ct.timer.C)
+ } else {
+ ct.l.Verbosef("Timer update not needed")
+ }
+}
+
+func (ct *Conntrack) addConnection(orig, nat connection) {
+ ct.l.Verbosef("addConnection")
+ ct.mappingsMutex.Lock()
+ m := &mapping{
+ orig: orig,
+ nat: nat,
+ state: Init,
+ timeout: time.Now().Add(TCP_TRANS_TIMEOUT_MS),
+ index: -1,
+ }
+ ct.mappings[orig] = m
+ ct.mappings[nat] = m
+ heap.Push(&ct.timeouts, m)
+ ct.mappingsMutex.Unlock()
}
-func (ct *Conntrack) lookupConnection(new connection) (connection, bool) {
- ct.connectionsMutex.RLock()
- c, ok := ct.connections[new]
- ct.connectionsMutex.RUnlock()
- return c, ok
+func (ct *Conntrack) lookupConnection(c connection) (*mapping, bool) {
+ ct.l.Verbosef("lookupConnection")
+ ct.mappingsMutex.RLock()
+ m, ok := ct.mappings[c]
+ ct.mappingsMutex.RUnlock()
+ return m, ok
}
+func (ct *Conntrack) event(c connection, ev Event) {
+ ct.l.Verbosef("event")
+ ct.mappingsMutex.Lock()
+ m, ok := ct.mappings[c]
+ if !ok {
+ ct.mappingsMutex.Unlock()
+ // FIXME
+ return
+ }
+ if ev == ClientSyn {
+ m.clientSyn(ct)
+ } else if ev == ServerSyn {
+ m.serverSyn(ct)
+ } else if ev == ClientRst {
+ m.clientRst(ct)
+ } else if ev == ServerRst {
+ m.serverRst(ct)
+ } else if ev == ClientFin {
+ m.clientFin(ct)
+ } else if ev == ServerFin {
+ m.serverFin(ct)
+ } else if ev == DataPkt {
+ m.dataPkt(ct)
+ }
+ ct.mappingsMutex.Unlock()
+}
diff --git a/tunnel/tools/libwg-go/http-proxy.go b/tunnel/tools/libwg-go/http-proxy.go
index 7e2bd32f..844ffae5 100644
--- a/tunnel/tools/libwg-go/http-proxy.go
+++ b/tunnel/tools/libwg-go/http-proxy.go
@@ -604,11 +604,11 @@ func (h *HttpHandler) addConnToProxyMap(c net.Conn) bool {
}
newConn := connection{src: local, dst: remote}
- oldConn, ok := h.p.conntrack.lookupConnection(newConn)
+ m, ok := h.p.conntrack.lookupConnection(newConn)
if ok {
- local = oldConn.src
- remote = oldConn.dst
+ local = m.orig.src
+ remote = m.orig.dst
h.logger.Verbosef("Before NAT: %v -> %v", local, remote)
} else if remote.Addr().IsLoopback() {
h.logger.Verbosef("Loopback request")
diff --git a/tunnel/tools/libwg-go/nat-tun.go b/tunnel/tools/libwg-go/nat-tun.go
index 75181ba0..6b37a9d8 100644
--- a/tunnel/tools/libwg-go/nat-tun.go
+++ b/tunnel/tools/libwg-go/nat-tun.go
@@ -6,6 +6,7 @@
package main
// TODO debug IPv6 NAT
+// TODO implement TCP session tracking according to RFC 7857,
import (
"encoding/binary"
@@ -57,17 +58,19 @@ func (tun *natTun) addTranslation(ipv6 bool, dstPort int, proxyPort int) {
if ipv6 {
ipVersion = IPV6_VERSION
- //srcStr = "fe80::1"
- //proxyStr = "fe80::2"
- srcStr = "fdba:4b51:1606:a61d::1"
- proxyStr = "fdba:4b51:1606:a61d::2"
- // loop = false
- // proxyStr = "2001:470:de6f:2fff::f780"
- // proxyPort = 8888
+ //srcStr = "fdba:4b51:1606:a61d::1"
+ //proxyStr = "fdba:4b51:1606:a61d::2"
+ loop = false
+ proxyStr = "2001:470:de6f:2fff::f780"
+ proxyPort = 8888
} else {
ipVersion = IPV4_VERSION
- srcStr = "169.254.0.1"
- proxyStr = "169.254.0.2"
+ //srcStr = "169.254.0.1"
+ //proxyStr = "169.254.0.2"
+ loop = false
+ proxyStr = "10.49.32.1"
+ proxyPort = 8888
+
}
if srcStr != "" {
@@ -132,8 +135,13 @@ const (
TCP_HEADER_LEN = 40
TCP_HEADER_SRC_PORT = 0
TCP_HEADER_DST_PORT = 2
+ TCP_HEADER_FLAGS = 12
TCP_HEADER_CHECKSUM = 16
+ TCP_FLAGS_FIN = 0x01
+ TCP_FLAGS_SYN = 0x02
+ TCP_FLAGS_RST = 0x04
+
UDP_HEADER_SRC_PORT = 0
UDP_HEADER_DST_PORT = 2
UDP_HEADER_CHECKSUM = 6
@@ -429,30 +437,110 @@ func getIPVersion(header []byte, len int) (int, error) {
}
}
+func testFlag(flags int, flag int) bool {
+ return (flags & flag) != 0
+}
+
func (tun *natTun) Read(buf []byte, offset int) (int, error) {
len, err := tun.tun.Read(buf, offset)
- if err == nil && len > 0 {
- header := buf[offset:]
- version, err := getIPVersion(header, len)
+ if err != nil || len <= 0 {
+ return len, err
+ }
+
+ header := buf[offset:]
+ version, err := getIPVersion(header, len)
+
+ if err != nil {
+ return len, nil
+ }
- if err != nil {
- // Ignore bad packet
- } else if version == IPV4_VERSION || version == IPV6_VERSION {
- origDstPort := getDstPort(header, version)
+ if version == IPV4_VERSION || version == IPV6_VERSION {
+ transport, transportPayload := getTransport(header, version)
+ if (transport != PROTO_TCP) {
+ return len, nil
+ }
+
+ flags := getUint16(transportPayload, TCP_HEADER_FLAGS) & 0x1f
+ srcAddr := getSrcAddr(header, version)
+ srcPort := getSrcPort(header, version)
+ dstAddr := getDstAddr(header, version)
+ dstPort := getDstPort(header, version)
+
+ src := netip.AddrPortFrom(srcAddr, uint16(srcPort))
+ dst := netip.AddrPortFrom(dstAddr, uint16(dstPort))
+ conn := Connection(src, dst)
+
+ m, ok := tun.conntrack.lookupConnection(conn)
+ if ok {
+ tun.l.Verbosef("Found mapping: [%v] %v", conn, m)
+ var event Event
+ var newConn connection
+ if m.orig == conn {
+ // Client
+ newConn = m.nat
+ tun.l.Verbosef("Client: %v", newConn)
+ if testFlag(flags, TCP_FLAGS_SYN) {
+ event = ClientSyn
+ } else if testFlag(flags, TCP_FLAGS_FIN) {
+ event = ClientFin
+ } else if testFlag(flags, TCP_FLAGS_RST) {
+ event = ClientRst
+ }
+ } else {
+ // Server
+ newConn = m.orig
+ tun.l.Verbosef("Server: %v", newConn)
+ if testFlag(flags, TCP_FLAGS_SYN) {
+ event = ServerSyn
+ } else if testFlag(flags, TCP_FLAGS_FIN) {
+ event = ServerFin
+ } else if testFlag(flags, TCP_FLAGS_RST) {
+ event = ServerRst
+ } else {
+ event = DataPkt
+ }
+ }
+
+ if event != None {
+ tun.l.Verbosef("Event: %v", event)
+ tun.conntrack.event(conn, event)
+ }
+
+ if m.state == Closed {
+ // TODO log
+ return 0, fmt.Errorf("Mapping closed")
+ }
+
+ updateSrcAddr(header, version, func(netip.Addr) netip.Addr { return newConn.dst.Addr() })
+ updateSrcPort(header, version, func(int) int { return int(newConn.dst.Port()) })
+ updateDstAddr(header, version, func(netip.Addr) netip.Addr { return newConn.src.Addr() })
+ updateDstPort(header, version, func(int) int { return int(newConn.src.Port()) })
+ // Write back to tunnel
+ writeBuf := buf[:offset+len]
+ written, err := tun.Write(writeBuf, offset)
+ tun.l.Verbosef("Written: %v %v", written, err)
+ if err != nil {
+ return 0, err
+ } else if written != len {
+ return 0, fmt.Errorf("NAT rev buffer partly written %v != %v", written, len)
+ } else {
+ return 0, nil
+ }
+ } else {
+ if !testFlag(flags, TCP_FLAGS_SYN) {
+ return len, nil
+ }
for _, entry := range(tun.translations) {
if version != entry.ipVersion {
continue
}
- if origDstPort != entry.dstPort {
+ if dstPort != entry.dstPort {
continue
}
-
- srcPort := getSrcPort(header, version)
+
newDstPort := 0
- var origSrcAddr netip.Addr
- var origDstAddr netip.Addr
var newSrcAddr netip.Addr
var newDstAddr netip.Addr
@@ -463,34 +551,34 @@ func (tun *natTun) Read(buf []byte, offset int) (int, error) {
if entry.srcAddr.IsValid() {
updateSrcAddr(header, version, func(addr netip.Addr) netip.Addr {
- origSrcAddr = addr
newSrcAddr = entry.srcAddr
return newSrcAddr
})
} else {
- origSrcAddr = getSrcAddr(header, version)
- newSrcAddr = origSrcAddr
+ newSrcAddr = srcAddr
}
updateDstAddr(header, version, func(addr netip.Addr) netip.Addr {
- origDstAddr = addr
newDstAddr = entry.proxyAddr
return newDstAddr
})
- orig := Connection(netip.AddrPortFrom(origSrcAddr, uint16(srcPort)),
- netip.AddrPortFrom(origDstAddr, uint16(origDstPort)))
- new := Connection(netip.AddrPortFrom(newSrcAddr, uint16(srcPort)),
- netip.AddrPortFrom(newDstAddr, uint16(newDstPort)))
- tun.conntrack.addConnection(new, orig)
+ orig := Connection(netip.AddrPortFrom(srcAddr, uint16(srcPort)),
+ netip.AddrPortFrom(dstAddr, uint16(dstPort)))
+ nat := Connection(netip.AddrPortFrom(newDstAddr, uint16(newDstPort)),
+ netip.AddrPortFrom(newSrcAddr, uint16(srcPort)))
+ tun.l.Verbosef("New conn: %v -> %v", orig, nat)
+
+ tun.conntrack.addConnection(orig, nat)
if !entry.loop {
return len, nil
}
-
+
// Write back to tunnel
writeBuf := buf[:offset+len]
written, err := tun.Write(writeBuf, offset)
+ tun.l.Verbosef("Written: %v %v", written, err)
if err != nil {
return 0, err
} else if written != len {
@@ -499,7 +587,27 @@ func (tun *natTun) Read(buf []byte, offset int) (int, error) {
return 0, nil
}
}
+ }
+ }
+
+ return len, err
+}
+func (tun *natTun) Write(buf []byte, offset int) (int, error) {
+ len := len(buf)
+
+ if len <= 0 {
+ return len, nil
+ }
+
+ header := buf[offset:]
+ version, err := getIPVersion(header, len)
+
+ if err == nil && (version == IPV4_VERSION || version == IPV6_VERSION) {
+ transport, transportPayload := getTransport(header, version)
+ if (transport == PROTO_TCP) {
+ flags := getUint16(transportPayload, TCP_HEADER_FLAGS) & 0x1f
+ // TODO remove unused code
srcAddr := getSrcAddr(header, version)
srcPort := getSrcPort(header, version)
dstAddr := getDstAddr(header, version)
@@ -507,74 +615,57 @@ func (tun *natTun) Read(buf []byte, offset int) (int, error) {
src := netip.AddrPortFrom(srcAddr, uint16(srcPort))
dst := netip.AddrPortFrom(dstAddr, uint16(dstPort))
- new := Connection(dst, src)
+ conn := Connection(src, dst)
- orig, ok := tun.conntrack.lookupConnection(new)
+ m, ok := tun.conntrack.lookupConnection(conn)
if ok {
- updateSrcAddr(header, version, func(netip.Addr) netip.Addr { return orig.dst.Addr() })
- updateSrcPort(header, version, func(int) int { return int(orig.dst.Port()) })
- updateDstAddr(header, version, func(netip.Addr) netip.Addr { return orig.src.Addr() })
- updateDstPort(header, version, func(int) int { return int(orig.src.Port()) })
- // Write back to tunnel
- writeBuf := buf[:offset+len]
- written, err := tun.Write(writeBuf, offset)
- if err != nil {
- return 0, err
- } else if written != len {
- return 0, fmt.Errorf("NAT rev buffer partly written %v != %v", written, len)
+ tun.l.Verbosef("Write: Found mapping: [%v] %v", conn, m)
+ var event Event
+ var newConn connection
+
+ if m.orig == conn {
+ // Client
+ newConn = m.nat
+ tun.l.Verbosef("Client: %v", newConn)
+ if testFlag(flags, TCP_FLAGS_SYN) {
+ event = ClientSyn
+ } else if testFlag(flags, TCP_FLAGS_FIN) {
+ event = ClientFin
+ } else if testFlag(flags, TCP_FLAGS_RST) {
+ event = ClientRst
+ }
} else {
- return 0, nil
+ // Server
+ newConn = m.orig
+ tun.l.Verbosef("Server: %v", newConn)
+ if testFlag(flags, TCP_FLAGS_SYN) {
+ event = ServerSyn
+ } else if testFlag(flags, TCP_FLAGS_FIN) {
+ event = ServerFin
+ } else if testFlag(flags, TCP_FLAGS_RST) {
+ event = ServerRst
+ } else {
+ event = DataPkt
+ }
+ }
+
+ if event != None {
+ tun.l.Verbosef("Event: %v", event)
+ tun.conntrack.event(conn, event)
}
- }
- // protocol := int(header[9])
-
- // if protocol == PROTO_TCP && len >= (IPV4_HEADER_LEN + TCP_HEADER_LEN) {
- // tcpHeader := header[IPV4_HEADER_LEN:]
- // updateTcpDstPort(tcpHeader, func(port int) int { return port + 1 })
- // }
- //} else if version == IPV6_VERSION {
- // updateSrcAddr(header, IPV6_HEADER_SRC_ADDR, func(netip.Addr) { return extAddr })
- // nextHeader := int(header[6])
-
- // if nextHeader == PROTO_TCP && len >= (IPV6_HEADER_LEN + TCP_HEADER_LEN) {
- // tcpHeader := header[IPV6_HEADER_LEN:]
- // updateTcpDstPort(tcpHeader, func(port int) int { return port + 1 })
- // }
- }
- }
- return len, err
-}
-
-func (tun *natTun) Write(buf []byte, offset int) (int, error) {
- len := len(buf)
+ if m.state == Closed {
+ // TODO log
+ return 0, fmt.Errorf("Mapping closed")
+ }
- if len > 0 {
- header := buf[offset:]
- version, err := getIPVersion(header, len)
-
- if err != nil {
- // Ignore bad packet
- } else if version == IPV4_VERSION || version == IPV6_VERSION {
- srcAddr := getSrcAddr(header, version)
- srcPort := getSrcPort(header, version)
- dstAddr := getDstAddr(header, version)
- dstPort := getDstPort(header, version)
-
- src := netip.AddrPortFrom(srcAddr, uint16(srcPort))
- dst := netip.AddrPortFrom(dstAddr, uint16(dstPort))
- new := Connection(dst, src)
-
- orig, ok := tun.conntrack.lookupConnection(new)
-
- if ok {
- updateSrcAddr(header, version, func(netip.Addr) netip.Addr { return orig.dst.Addr() })
- updateSrcPort(header, version, func(int) int { return int(orig.dst.Port()) })
- updateDstAddr(header, version, func(netip.Addr) netip.Addr { return orig.src.Addr() })
- updateDstPort(header, version, func(int) int { return int(orig.src.Port()) })
- }
- }
- }
+ updateSrcAddr(header, version, func(netip.Addr) netip.Addr { return newConn.dst.Addr() })
+ updateSrcPort(header, version, func(int) int { return int(newConn.dst.Port()) })
+ updateDstAddr(header, version, func(netip.Addr) netip.Addr { return newConn.src.Addr() })
+ updateDstPort(header, version, func(int) int { return int(newConn.src.Port()) })
+ }
+ }
+ }
return tun.tun.Write(buf, offset)
}