summaryrefslogtreecommitdiffhomepage
path: root/tunnel/tools/libwg-go/conntrack.go
diff options
context:
space:
mode:
Diffstat (limited to 'tunnel/tools/libwg-go/conntrack.go')
-rw-r--r--tunnel/tools/libwg-go/conntrack.go330
1 files changed, 317 insertions, 13 deletions
diff --git a/tunnel/tools/libwg-go/conntrack.go b/tunnel/tools/libwg-go/conntrack.go
index 17289a6f..9b34b5ce 100644
--- a/tunnel/tools/libwg-go/conntrack.go
+++ b/tunnel/tools/libwg-go/conntrack.go
@@ -1,12 +1,83 @@
package main
import (
+ "container/heap"
+ "fmt"
"sync"
+ "time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/device"
)
+const (
+ TCP_TRANS_TIMEOUT_MS = 4 * 60 * 1000 // 4 minutes
+ TCP_EST_TIMEOUT_MS = 1 * 24 * 60 * 60 * 1000 // 1 day
+)
+
+// RFC 7857 section 2. TCP Session Tracking
+type State int8
+
+const (
+ Closed State = iota
+ Trans
+ Init
+ Established
+ ClientFinRcv
+ ServerFinRcv
+ ClientServerFinRcv
+)
+
+func (st State) String() string {
+ stateNames := []string{
+ "Closed",
+ "Trans",
+ "Init",
+ "Established",
+ "ClientFinRcv",
+ "ServerFinRcv",
+ "ClientServerFinRcv",
+ }
+
+ if int(st) < len(stateNames) {
+ return stateNames[st]
+ } else {
+ return "Unknown"
+ }
+}
+
+type Event int8
+
+const (
+ None Event = iota
+ ClientSyn
+ ClientRst
+ ClientFin
+ ServerSyn
+ ServerRst
+ ServerFin
+ DataPkt
+)
+
+func (ev Event) String() string {
+ eventNames := []string{
+ "None",
+ "ClientSyn",
+ "ClientRst",
+ "ClientFin",
+ "ServerSyn",
+ "ServerRst",
+ "ServerFin",
+ "DataPkt",
+ }
+
+ if int(ev) < len(eventNames) {
+ return eventNames[ev]
+ } else {
+ return "Unknown"
+ }
+}
+
type connection struct {
src netip.AddrPort
dst netip.AddrPort
@@ -19,30 +90,263 @@ func Connection(src, dst netip.AddrPort) connection {
}
}
+func (c connection) String() string {
+ return fmt.Sprintf("%v -> %v", c.src, c.dst)
+}
+
+type mapping struct {
+ orig connection
+ nat connection
+ state State
+ timeout time.Time // Priority in queue
+ index int // Index in the queue
+}
+
+func (m *mapping) String() string {
+ return fmt.Sprintf("state:%v orig:%v, nat:%v", m.state, m.orig, m.nat)
+}
+
+func (m *mapping) clientSyn(ct *Conntrack) {
+ if m.state == Closed {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = Init
+ }
+}
+
+func (m *mapping) serverSyn(ct *Conntrack) {
+ if m.state == Init {
+ ct.updateTimer(m, TCP_EST_TIMEOUT_MS)
+ m.state = Established
+ }
+}
+
+func (m *mapping) clientRst(ct *Conntrack) {
+ if m.state == Established {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = Trans
+ }
+}
+
+func (m *mapping) serverRst(ct *Conntrack) {
+ if m.state == Established {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = Trans
+ }
+}
+
+func (m *mapping) clientFin(ct *Conntrack) {
+ if m.state == Established {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = ClientFinRcv
+ } else if m.state == ServerFinRcv {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = ClientServerFinRcv
+ }
+}
+
+func (m *mapping) serverFin(ct *Conntrack) {
+ if m.state == Established {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = ServerFinRcv
+ } else if m.state == ClientFinRcv {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = ClientServerFinRcv
+ }
+}
+
+func (m *mapping) dataPkt(ct *Conntrack) {
+ if m.state == Trans {
+ ct.updateTimer(m, TCP_EST_TIMEOUT_MS)
+ m.state = Established
+ }
+}
+
+func (m *mapping) transitionTimeout(ct *Conntrack) {
+ if m.state == Trans || m.state == ClientServerFinRcv {
+ // TODO remove connection
+ m.state = Closed
+ }
+}
+
+func (m *mapping) establishedTimeout(ct *Conntrack) {
+ if m.state == Established {
+ ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS)
+ m.state = Trans
+ }
+}
+
+type PriorityQueue []*mapping
+
+func (pq PriorityQueue) Len() int { return len(pq) }
+
+func (pq PriorityQueue) Less(i, j int) bool {
+ // We want Pop to give us the lowest, not highest, priority so
+ // we use less than here.
+ return pq[i].timeout.Before(pq[j].timeout)
+}
+
+func (pq PriorityQueue) Swap(i, j int) {
+ pq[i], pq[j] = pq[j], pq[i]
+ pq[i].index = i
+ pq[j].index = j
+}
+
+func (pq *PriorityQueue) Push(x interface{}) {
+ n := len(*pq)
+ item := x.(*mapping)
+ item.index = n
+ *pq = append(*pq, item)
+}
+
+func (pq *PriorityQueue) Pop() interface{} {
+ old := *pq
+ n := len(old)
+ item := old[n-1]
+ old[n-1] = nil // avoid memory leak
+ item.index = -1 // for safety
+ *pq = old[0 : n-1]
+ return item
+}
+
+// update modifies the priority and value of an Item in the queue.
+func (pq *PriorityQueue) update(c *mapping, timeout time.Time) {
+ c.timeout = timeout
+ heap.Fix(pq, c.index)
+}
+
type Conntrack struct {
- connections map[connection]connection
- connectionsMutex sync.RWMutex
+ mappings map[connection]*mapping
+ mappingsMutex sync.RWMutex
+ timeouts PriorityQueue
+ timer *time.Timer
l *device.Logger
}
func NewConntrack(logger *device.Logger) *Conntrack {
+ timeouts := make(PriorityQueue, 0)
+ heap.Init(&timeouts)
return &Conntrack{
- connections: make(map[connection]connection),
- connectionsMutex: sync.RWMutex{},
+ mappings: make(map[connection]*mapping),
+ mappingsMutex: sync.RWMutex{},
+ timeouts: timeouts,
+ timer: nil,
l: logger,
}
}
-func (ct *Conntrack) addConnection(new, orig connection) {
- ct.connectionsMutex.Lock()
- ct.connections[new] = orig
- ct.connectionsMutex.Unlock()
+// mappingsMutex must be held
+func (ct *Conntrack) updateTimer(m *mapping, d time.Duration) {
+ if m.index == -1 {
+ m.timeout = time.Now().Add(d)
+ heap.Push(&ct.timeouts, m)
+ } else {
+ ct.timeouts.update(m, time.Now().Add(d))
+ }
+ firstDuration := ct.timeouts[0].timeout.Sub(time.Now())
+ if m.index == 0 && ct.timer != nil {
+ ct.l.Verbosef("updateTimer stop")
+ // FIXME deadlock if called from AfterFunc
+ if !ct.timer.Stop() {
+ ct.l.Verbosef("stop returned false")
+ //<-ct.timer.C
+ // TODO handle timeoute
+ }
+ ct.l.Verbosef("updateTimer reset")
+ ct.timer.Reset(firstDuration)
+ ct.l.Verbosef("updateTimer reset done")
+ } else if ct.timer == nil {
+ ct.timer = time.NewTimer(firstDuration)
+
+ go func(ch <-chan time.Time) {
+ for {
+ t := <-ch
+ ct.l.Verbosef("Timer %v", t)
+ ct.mappingsMutex.Lock()
+ ct.l.Verbosef("Timer locked")
+ now := time.Now()
+ for ct.timeouts.Len() > 0 {
+ ct.l.Verbosef("Timer loop %d", ct.timeouts.Len())
+ m := ct.timeouts[0]
+ if m.timeout.Before(now) {
+ ct.l.Verbosef("Timer Before")
+ heap.Pop(&ct.timeouts)
+ if m.state == Established {
+ ct.l.Verbosef("Timer established")
+ // FIXME use channel
+ m.establishedTimeout(ct)
+ } else {
+ ct.l.Verbosef("Timer transition")
+ m.transitionTimeout(ct)
+ }
+ if m.state == Closed {
+ ct.l.Verbosef("Timer closed")
+ delete(ct.mappings, m.orig)
+ delete(ct.mappings, m.nat)
+ }
+ } else {
+ break
+ }
+ }
+ ct.l.Verbosef("Timer loop done")
+ if ct.timeouts.Len() > 0 {
+ ct.timer.Reset(ct.timeouts[0].timeout.Sub(time.Now()))
+ }
+ ct.mappingsMutex.Unlock()
+ ct.l.Verbosef("Timer unlocked")
+ }
+ }(ct.timer.C)
+ } else {
+ ct.l.Verbosef("Timer update not needed")
+ }
+}
+
+func (ct *Conntrack) addConnection(orig, nat connection) {
+ ct.l.Verbosef("addConnection")
+ ct.mappingsMutex.Lock()
+ m := &mapping{
+ orig: orig,
+ nat: nat,
+ state: Init,
+ timeout: time.Now().Add(TCP_TRANS_TIMEOUT_MS),
+ index: -1,
+ }
+ ct.mappings[orig] = m
+ ct.mappings[nat] = m
+ heap.Push(&ct.timeouts, m)
+ ct.mappingsMutex.Unlock()
}
-func (ct *Conntrack) lookupConnection(new connection) (connection, bool) {
- ct.connectionsMutex.RLock()
- c, ok := ct.connections[new]
- ct.connectionsMutex.RUnlock()
- return c, ok
+func (ct *Conntrack) lookupConnection(c connection) (*mapping, bool) {
+ ct.l.Verbosef("lookupConnection")
+ ct.mappingsMutex.RLock()
+ m, ok := ct.mappings[c]
+ ct.mappingsMutex.RUnlock()
+ return m, ok
}
+func (ct *Conntrack) event(c connection, ev Event) {
+ ct.l.Verbosef("event")
+ ct.mappingsMutex.Lock()
+ m, ok := ct.mappings[c]
+ if !ok {
+ ct.mappingsMutex.Unlock()
+ // FIXME
+ return
+ }
+ if ev == ClientSyn {
+ m.clientSyn(ct)
+ } else if ev == ServerSyn {
+ m.serverSyn(ct)
+ } else if ev == ClientRst {
+ m.clientRst(ct)
+ } else if ev == ServerRst {
+ m.serverRst(ct)
+ } else if ev == ClientFin {
+ m.clientFin(ct)
+ } else if ev == ServerFin {
+ m.serverFin(ct)
+ } else if ev == DataPkt {
+ m.dataPkt(ct)
+ }
+ ct.mappingsMutex.Unlock()
+}