diff options
Diffstat (limited to 'device/timers.go')
-rw-r--r-- | device/timers.go | 227 |
1 files changed, 227 insertions, 0 deletions
diff --git a/device/timers.go b/device/timers.go new file mode 100644 index 0000000..5f28fcc --- /dev/null +++ b/device/timers.go @@ -0,0 +1,227 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * + * This is based heavily on timers.c from the kernel implementation. + */ + +package device + +import ( + "math/rand" + "sync" + "sync/atomic" + "time" +) + +/* This Timer structure and related functions should roughly copy the interface of + * the Linux kernel's struct timer_list. + */ + +type Timer struct { + *time.Timer + modifyingLock sync.RWMutex + runningLock sync.Mutex + isPending bool +} + +func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { + timer := &Timer{} + timer.Timer = time.AfterFunc(time.Hour, func() { + timer.runningLock.Lock() + + timer.modifyingLock.Lock() + if !timer.isPending { + timer.modifyingLock.Unlock() + timer.runningLock.Unlock() + return + } + timer.isPending = false + timer.modifyingLock.Unlock() + + expirationFunction(peer) + timer.runningLock.Unlock() + }) + timer.Stop() + return timer +} + +func (timer *Timer) Mod(d time.Duration) { + timer.modifyingLock.Lock() + timer.isPending = true + timer.Reset(d) + timer.modifyingLock.Unlock() +} + +func (timer *Timer) Del() { + timer.modifyingLock.Lock() + timer.isPending = false + timer.Stop() + timer.modifyingLock.Unlock() +} + +func (timer *Timer) DelSync() { + timer.Del() + timer.runningLock.Lock() + timer.Del() + timer.runningLock.Unlock() +} + +func (timer *Timer) IsPending() bool { + timer.modifyingLock.RLock() + defer timer.modifyingLock.RUnlock() + return timer.isPending +} + +func (peer *Peer) timersActive() bool { + return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0 +} + +func expiredRetransmitHandshake(peer *Peer) { + if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { + peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2) + + if peer.timersActive() { + peer.timers.sendKeepalive.Del() + } + + /* We drop all packets without a keypair and don't try again, + * if we try unsuccessfully for too long to make a handshake. + */ + peer.FlushNonceQueue() + + /* We set a timer for destroying any residue that might be left + * of a partial exchange. + */ + if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() { + peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) + } + } else { + atomic.AddUint32(&peer.timers.handshakeAttempts, 1) + peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) + + /* We clear the endpoint address src address, in case this is the cause of trouble. */ + peer.Lock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + peer.Unlock() + + peer.SendHandshakeInitiation(true) + } +} + +func expiredSendKeepalive(peer *Peer) { + peer.SendKeepalive() + if peer.timers.needAnotherKeepalive.Get() { + peer.timers.needAnotherKeepalive.Set(false) + if peer.timersActive() { + peer.timers.sendKeepalive.Mod(KeepaliveTimeout) + } + } +} + +func expiredNewHandshake(peer *Peer) { + peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) + /* We clear the endpoint address src address, in case this is the cause of trouble. */ + peer.Lock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + peer.Unlock() + peer.SendHandshakeInitiation(false) + +} + +func expiredZeroKeyMaterial(peer *Peer) { + peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) + peer.ZeroAndFlushAll() +} + +func expiredPersistentKeepalive(peer *Peer) { + if peer.persistentKeepaliveInterval > 0 { + peer.SendKeepalive() + } +} + +/* Should be called after an authenticated data packet is sent. */ +func (peer *Peer) timersDataSent() { + if peer.timersActive() && !peer.timers.newHandshake.IsPending() { + peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout) + } +} + +/* Should be called after an authenticated data packet is received. */ +func (peer *Peer) timersDataReceived() { + if peer.timersActive() { + if !peer.timers.sendKeepalive.IsPending() { + peer.timers.sendKeepalive.Mod(KeepaliveTimeout) + } else { + peer.timers.needAnotherKeepalive.Set(true) + } + } +} + +/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */ +func (peer *Peer) timersAnyAuthenticatedPacketSent() { + if peer.timersActive() { + peer.timers.sendKeepalive.Del() + } +} + +/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */ +func (peer *Peer) timersAnyAuthenticatedPacketReceived() { + if peer.timersActive() { + peer.timers.newHandshake.Del() + } +} + +/* Should be called after a handshake initiation message is sent. */ +func (peer *Peer) timersHandshakeInitiated() { + if peer.timersActive() { + peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) + } +} + +/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */ +func (peer *Peer) timersHandshakeComplete() { + if peer.timersActive() { + peer.timers.retransmitHandshake.Del() + } + atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.sentLastMinuteHandshake.Set(false) + atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) +} + +/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ +func (peer *Peer) timersSessionDerived() { + if peer.timersActive() { + peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) + } +} + +/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ +func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { + if peer.persistentKeepaliveInterval > 0 && peer.timersActive() { + peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second) + } +} + +func (peer *Peer) timersInit() { + peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake) + peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive) + peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) + peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) + peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) + atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.sentLastMinuteHandshake.Set(false) + peer.timers.needAnotherKeepalive.Set(false) +} + +func (peer *Peer) timersStop() { + peer.timers.retransmitHandshake.DelSync() + peer.timers.sendKeepalive.DelSync() + peer.timers.newHandshake.DelSync() + peer.timers.zeroKeyMaterial.DelSync() + peer.timers.persistentKeepalive.DelSync() +} |