diff options
Diffstat (limited to 'tunnel/tools/libwg-go/conntrack.go')
-rw-r--r-- | tunnel/tools/libwg-go/conntrack.go | 330 |
1 files changed, 317 insertions, 13 deletions
diff --git a/tunnel/tools/libwg-go/conntrack.go b/tunnel/tools/libwg-go/conntrack.go index 17289a6f..9b34b5ce 100644 --- a/tunnel/tools/libwg-go/conntrack.go +++ b/tunnel/tools/libwg-go/conntrack.go @@ -1,12 +1,83 @@ package main import ( + "container/heap" + "fmt" "sync" + "time" "golang.zx2c4.com/go118/netip" "golang.zx2c4.com/wireguard/device" ) +const ( + TCP_TRANS_TIMEOUT_MS = 4 * 60 * 1000 // 4 minutes + TCP_EST_TIMEOUT_MS = 1 * 24 * 60 * 60 * 1000 // 1 day +) + +// RFC 7857 section 2. TCP Session Tracking +type State int8 + +const ( + Closed State = iota + Trans + Init + Established + ClientFinRcv + ServerFinRcv + ClientServerFinRcv +) + +func (st State) String() string { + stateNames := []string{ + "Closed", + "Trans", + "Init", + "Established", + "ClientFinRcv", + "ServerFinRcv", + "ClientServerFinRcv", + } + + if int(st) < len(stateNames) { + return stateNames[st] + } else { + return "Unknown" + } +} + +type Event int8 + +const ( + None Event = iota + ClientSyn + ClientRst + ClientFin + ServerSyn + ServerRst + ServerFin + DataPkt +) + +func (ev Event) String() string { + eventNames := []string{ + "None", + "ClientSyn", + "ClientRst", + "ClientFin", + "ServerSyn", + "ServerRst", + "ServerFin", + "DataPkt", + } + + if int(ev) < len(eventNames) { + return eventNames[ev] + } else { + return "Unknown" + } +} + type connection struct { src netip.AddrPort dst netip.AddrPort @@ -19,30 +90,263 @@ func Connection(src, dst netip.AddrPort) connection { } } +func (c connection) String() string { + return fmt.Sprintf("%v -> %v", c.src, c.dst) +} + +type mapping struct { + orig connection + nat connection + state State + timeout time.Time // Priority in queue + index int // Index in the queue +} + +func (m *mapping) String() string { + return fmt.Sprintf("state:%v orig:%v, nat:%v", m.state, m.orig, m.nat) +} + +func (m *mapping) clientSyn(ct *Conntrack) { + if m.state == Closed { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = Init + } +} + +func (m *mapping) serverSyn(ct *Conntrack) { + if m.state == Init { + ct.updateTimer(m, TCP_EST_TIMEOUT_MS) + m.state = Established + } +} + +func (m *mapping) clientRst(ct *Conntrack) { + if m.state == Established { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = Trans + } +} + +func (m *mapping) serverRst(ct *Conntrack) { + if m.state == Established { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = Trans + } +} + +func (m *mapping) clientFin(ct *Conntrack) { + if m.state == Established { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = ClientFinRcv + } else if m.state == ServerFinRcv { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = ClientServerFinRcv + } +} + +func (m *mapping) serverFin(ct *Conntrack) { + if m.state == Established { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = ServerFinRcv + } else if m.state == ClientFinRcv { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = ClientServerFinRcv + } +} + +func (m *mapping) dataPkt(ct *Conntrack) { + if m.state == Trans { + ct.updateTimer(m, TCP_EST_TIMEOUT_MS) + m.state = Established + } +} + +func (m *mapping) transitionTimeout(ct *Conntrack) { + if m.state == Trans || m.state == ClientServerFinRcv { + // TODO remove connection + m.state = Closed + } +} + +func (m *mapping) establishedTimeout(ct *Conntrack) { + if m.state == Established { + ct.updateTimer(m, TCP_TRANS_TIMEOUT_MS) + m.state = Trans + } +} + +type PriorityQueue []*mapping + +func (pq PriorityQueue) Len() int { return len(pq) } + +func (pq PriorityQueue) Less(i, j int) bool { + // We want Pop to give us the lowest, not highest, priority so + // we use less than here. + return pq[i].timeout.Before(pq[j].timeout) +} + +func (pq PriorityQueue) Swap(i, j int) { + pq[i], pq[j] = pq[j], pq[i] + pq[i].index = i + pq[j].index = j +} + +func (pq *PriorityQueue) Push(x interface{}) { + n := len(*pq) + item := x.(*mapping) + item.index = n + *pq = append(*pq, item) +} + +func (pq *PriorityQueue) Pop() interface{} { + old := *pq + n := len(old) + item := old[n-1] + old[n-1] = nil // avoid memory leak + item.index = -1 // for safety + *pq = old[0 : n-1] + return item +} + +// update modifies the priority and value of an Item in the queue. +func (pq *PriorityQueue) update(c *mapping, timeout time.Time) { + c.timeout = timeout + heap.Fix(pq, c.index) +} + type Conntrack struct { - connections map[connection]connection - connectionsMutex sync.RWMutex + mappings map[connection]*mapping + mappingsMutex sync.RWMutex + timeouts PriorityQueue + timer *time.Timer l *device.Logger } func NewConntrack(logger *device.Logger) *Conntrack { + timeouts := make(PriorityQueue, 0) + heap.Init(&timeouts) return &Conntrack{ - connections: make(map[connection]connection), - connectionsMutex: sync.RWMutex{}, + mappings: make(map[connection]*mapping), + mappingsMutex: sync.RWMutex{}, + timeouts: timeouts, + timer: nil, l: logger, } } -func (ct *Conntrack) addConnection(new, orig connection) { - ct.connectionsMutex.Lock() - ct.connections[new] = orig - ct.connectionsMutex.Unlock() +// mappingsMutex must be held +func (ct *Conntrack) updateTimer(m *mapping, d time.Duration) { + if m.index == -1 { + m.timeout = time.Now().Add(d) + heap.Push(&ct.timeouts, m) + } else { + ct.timeouts.update(m, time.Now().Add(d)) + } + firstDuration := ct.timeouts[0].timeout.Sub(time.Now()) + if m.index == 0 && ct.timer != nil { + ct.l.Verbosef("updateTimer stop") + // FIXME deadlock if called from AfterFunc + if !ct.timer.Stop() { + ct.l.Verbosef("stop returned false") + //<-ct.timer.C + // TODO handle timeoute + } + ct.l.Verbosef("updateTimer reset") + ct.timer.Reset(firstDuration) + ct.l.Verbosef("updateTimer reset done") + } else if ct.timer == nil { + ct.timer = time.NewTimer(firstDuration) + + go func(ch <-chan time.Time) { + for { + t := <-ch + ct.l.Verbosef("Timer %v", t) + ct.mappingsMutex.Lock() + ct.l.Verbosef("Timer locked") + now := time.Now() + for ct.timeouts.Len() > 0 { + ct.l.Verbosef("Timer loop %d", ct.timeouts.Len()) + m := ct.timeouts[0] + if m.timeout.Before(now) { + ct.l.Verbosef("Timer Before") + heap.Pop(&ct.timeouts) + if m.state == Established { + ct.l.Verbosef("Timer established") + // FIXME use channel + m.establishedTimeout(ct) + } else { + ct.l.Verbosef("Timer transition") + m.transitionTimeout(ct) + } + if m.state == Closed { + ct.l.Verbosef("Timer closed") + delete(ct.mappings, m.orig) + delete(ct.mappings, m.nat) + } + } else { + break + } + } + ct.l.Verbosef("Timer loop done") + if ct.timeouts.Len() > 0 { + ct.timer.Reset(ct.timeouts[0].timeout.Sub(time.Now())) + } + ct.mappingsMutex.Unlock() + ct.l.Verbosef("Timer unlocked") + } + }(ct.timer.C) + } else { + ct.l.Verbosef("Timer update not needed") + } +} + +func (ct *Conntrack) addConnection(orig, nat connection) { + ct.l.Verbosef("addConnection") + ct.mappingsMutex.Lock() + m := &mapping{ + orig: orig, + nat: nat, + state: Init, + timeout: time.Now().Add(TCP_TRANS_TIMEOUT_MS), + index: -1, + } + ct.mappings[orig] = m + ct.mappings[nat] = m + heap.Push(&ct.timeouts, m) + ct.mappingsMutex.Unlock() } -func (ct *Conntrack) lookupConnection(new connection) (connection, bool) { - ct.connectionsMutex.RLock() - c, ok := ct.connections[new] - ct.connectionsMutex.RUnlock() - return c, ok +func (ct *Conntrack) lookupConnection(c connection) (*mapping, bool) { + ct.l.Verbosef("lookupConnection") + ct.mappingsMutex.RLock() + m, ok := ct.mappings[c] + ct.mappingsMutex.RUnlock() + return m, ok } +func (ct *Conntrack) event(c connection, ev Event) { + ct.l.Verbosef("event") + ct.mappingsMutex.Lock() + m, ok := ct.mappings[c] + if !ok { + ct.mappingsMutex.Unlock() + // FIXME + return + } + if ev == ClientSyn { + m.clientSyn(ct) + } else if ev == ServerSyn { + m.serverSyn(ct) + } else if ev == ClientRst { + m.clientRst(ct) + } else if ev == ServerRst { + m.serverRst(ct) + } else if ev == ClientFin { + m.clientFin(ct) + } else if ev == ServerFin { + m.serverFin(ct) + } else if ev == DataPkt { + m.dataPkt(ct) + } + ct.mappingsMutex.Unlock() +} |