diff options
-rw-r--r-- | src/noise_protocol.go | 3 | ||||
-rw-r--r-- | src/peer.go | 59 | ||||
-rw-r--r-- | src/receive.go | 3 | ||||
-rw-r--r-- | src/send.go | 19 | ||||
-rw-r--r-- | src/signal.go | 45 | ||||
-rw-r--r-- | src/timer.go | 65 | ||||
-rw-r--r-- | src/timers.go | 214 | ||||
-rw-r--r-- | src/uapi.go | 4 |
8 files changed, 249 insertions, 163 deletions
diff --git a/src/noise_protocol.go b/src/noise_protocol.go index 9e5fdd8..2f9e1d5 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -532,7 +532,6 @@ func (peer *Peer) NewKeyPair() *KeyPair { kp := &peer.keyPairs kp.mutex.Lock() - // TODO: Adapt kernel behavior noise.c:161 if isInitiator { if kp.previous != nil { device.DeleteKeyPair(kp.previous) @@ -545,7 +544,7 @@ func (peer *Peer) NewKeyPair() *KeyPair { } else { kp.previous = kp.current kp.current = keyPair - signalSend(peer.signal.newKeyPair) // TODO: This more places (after confirming the key) + peer.signal.newKeyPair.Send() } } else { diff --git a/src/peer.go b/src/peer.go index f3eb6c2..f582556 100644 --- a/src/peer.go +++ b/src/peer.go @@ -28,30 +28,26 @@ type Peer struct { nextKeepalive time.Time } signal struct { - newKeyPair chan struct{} // (size 1) : a new key pair was generated - handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake") - handshakeCompleted chan struct{} // (size 1) : handshake completed - handshakeReset chan struct{} // (size 1) : reset handshake negotiation state - flushNonceQueue chan struct{} // (size 1) : empty queued packets - messageSend chan struct{} // (size 1) : a message was send to the peer - messageReceived chan struct{} // (size 1) : an authenticated message was received - stop chan struct{} // (size 0) : close to stop all goroutines for peer + newKeyPair Signal // size 1, new key pair was generated + handshakeCompleted Signal // size 1, handshake completed + handshakeBegin Signal // size 1, begin new handshake begin + flushNonceQueue Signal // size 1, empty queued packets + messageSend Signal // size 1, message was send to peer + messageReceived Signal // size 1, authenticated message recv + stop Signal // size 0, stop all goroutines } timer struct { // state related to WireGuard timers - keepalivePersistent *time.Timer // set for persistent keepalives - keepalivePassive *time.Timer // set upon recieving messages - newHandshake *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout) - zeroAllKeys *time.Timer // zero all key material (after RejectAfterTime*3) - handshakeDeadline *time.Timer // Current handshake must be completed + keepalivePersistent Timer // set for persistent keepalives + keepalivePassive Timer // set upon recieving messages + newHandshake Timer // begin a new handshake (stale) + zeroAllKeys Timer // zero all key material + handshakeDeadline Timer // complete handshake timeout + handshakeTimeout Timer // current handshake message timeout - pendingKeepalivePassive bool - pendingNewHandshake bool - pendingZeroAllKeys bool - - needAnotherKeepalive bool sendLastMinuteHandshake bool + needAnotherKeepalive bool } queue struct { nonce chan *QueueOutboundElement // nonce / pre-handshake queue @@ -71,10 +67,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.mac.Init(pk) peer.device = device - peer.timer.keepalivePersistent = NewStoppedTimer() - peer.timer.keepalivePassive = NewStoppedTimer() - peer.timer.newHandshake = NewStoppedTimer() - peer.timer.zeroAllKeys = NewStoppedTimer() + peer.timer.keepalivePersistent = NewTimer() + peer.timer.keepalivePassive = NewTimer() + peer.timer.newHandshake = NewTimer() + peer.timer.zeroAllKeys = NewTimer() + peer.timer.handshakeDeadline = NewTimer() + peer.timer.handshakeTimeout = NewTimer() // assign id for debugging @@ -102,7 +100,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake := &peer.handshake handshake.mutex.Lock() handshake.remoteStatic = pk - handshake.precomputedStaticStatic = device.privateKey.sharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic = + device.privateKey.sharedSecret(handshake.remoteStatic) handshake.mutex.Unlock() // reset endpoint @@ -117,16 +116,14 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // prepare signaling & routines - peer.signal.stop = make(chan struct{}) - peer.signal.newKeyPair = make(chan struct{}, 1) - peer.signal.handshakeBegin = make(chan struct{}, 1) - peer.signal.handshakeReset = make(chan struct{}, 1) - peer.signal.handshakeCompleted = make(chan struct{}, 1) - peer.signal.flushNonceQueue = make(chan struct{}, 1) + peer.signal.stop = NewSignal() + peer.signal.newKeyPair = NewSignal() + peer.signal.handshakeBegin = NewSignal() + peer.signal.handshakeCompleted = NewSignal() + peer.signal.flushNonceQueue = NewSignal() go peer.RoutineNonce() go peer.RoutineTimerHandler() - go peer.RoutineHandshakeInitiator() go peer.RoutineSequentialSender() go peer.RoutineSequentialReceiver() @@ -163,5 +160,5 @@ func (peer *Peer) String() string { } func (peer *Peer) Close() { - close(peer.signal.stop) + peer.signal.stop.Broadcast() } diff --git a/src/receive.go b/src/receive.go index 0b0efbf..7d493b0 100644 --- a/src/receive.go +++ b/src/receive.go @@ -482,7 +482,8 @@ func (peer *Peer) RoutineSequentialReceiver() { for { select { - case <-peer.signal.stop: + + case <-peer.signal.stop.Wait(): logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id) return diff --git a/src/send.go b/src/send.go index 52872f6..35a4a6e 100644 --- a/src/send.go +++ b/src/send.go @@ -164,7 +164,7 @@ func (device *Device) RoutineReadFromTUN() { // insert into nonce/pre-handshake queue - signalSend(peer.signal.handshakeReset) + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) addToOutboundQueue(peer.queue.nonce, elem) elem = device.NewOutboundElement() } @@ -186,7 +186,7 @@ func (peer *Peer) RoutineNonce() { for { NextPacket: select { - case <-peer.signal.stop: + case <-peer.signal.stop.Wait(): return case elem := <-peer.queue.nonce: @@ -201,16 +201,17 @@ func (peer *Peer) RoutineNonce() { } } - signalSend(peer.signal.handshakeBegin) + peer.signal.handshakeBegin.Send() + logDebug.Println("Awaiting key-pair for", peer.String()) select { - case <-peer.signal.newKeyPair: - case <-peer.signal.flushNonceQueue: + case <-peer.signal.newKeyPair.Wait(): + case <-peer.signal.flushNonceQueue.Wait(): logDebug.Println("Clearing queue for", peer.String()) peer.FlushNonceQueue() goto NextPacket - case <-peer.signal.stop: + case <-peer.signal.stop.Wait(): return } } @@ -309,8 +310,10 @@ func (peer *Peer) RoutineSequentialSender() { for { select { - case <-peer.signal.stop: - logDebug.Println("Routine, sequential sender, stopped for", peer.String()) + + case <-peer.signal.stop.Wait(): + logDebug.Println( + "Routine, sequential sender, stopped for", peer.String()) return case elem := <-peer.queue.outbound: diff --git a/src/signal.go b/src/signal.go new file mode 100644 index 0000000..96b21bb --- /dev/null +++ b/src/signal.go @@ -0,0 +1,45 @@ +package main + +type Signal struct { + enabled AtomicBool + C chan struct{} +} + +func NewSignal() (s Signal) { + s.C = make(chan struct{}, 1) + s.Enable() + return +} + +func (s *Signal) Disable() { + s.enabled.Set(false) + s.Clear() +} + +func (s *Signal) Enable() { + s.enabled.Set(true) +} + +func (s *Signal) Send() { + if s.enabled.Get() { + select { + case s.C <- struct{}{}: + default: + } + } +} + +func (s Signal) Clear() { + select { + case <-s.C: + default: + } +} + +func (s Signal) Broadcast() { + close(s.C) // unblocks all selectors +} + +func (s Signal) Wait() chan struct{} { + return s.C +} diff --git a/src/timer.go b/src/timer.go new file mode 100644 index 0000000..3def253 --- /dev/null +++ b/src/timer.go @@ -0,0 +1,65 @@ +package main + +import ( + "time" +) + +type Timer struct { + pending AtomicBool + timer *time.Timer +} + +/* Starts the timer if not already pending + */ +func (t *Timer) Start(dur time.Duration) bool { + set := t.pending.Swap(true) + if !set { + t.timer.Reset(dur) + return true + } + return false +} + +/* Stops the timer + */ +func (t *Timer) Stop() { + set := t.pending.Swap(true) + if set { + t.timer.Stop() + select { + case <-t.timer.C: + default: + } + } + t.pending.Set(false) +} + +func (t *Timer) Pending() bool { + return t.pending.Get() +} + +func (t *Timer) Reset(dur time.Duration) { + t.pending.Set(false) + t.Start(dur) +} + +func (t *Timer) Push(dur time.Duration) { + if t.pending.Get() { + t.Reset(dur) + } +} + +func (t *Timer) Wait() <-chan time.Time { + return t.timer.C +} + +func NewTimer() (t Timer) { + t.pending.Set(false) + t.timer = time.NewTimer(0) + t.timer.Stop() + select { + case <-t.timer.C: + default: + } + return +} diff --git a/src/timers.go b/src/timers.go index 5848b2a..64aeca8 100644 --- a/src/timers.go +++ b/src/timers.go @@ -18,10 +18,10 @@ func (peer *Peer) KeepKeyFreshSending() { }
nonce := atomic.LoadUint64(&kp.sendNonce)
if nonce > RekeyAfterMessages {
- signalSend(peer.signal.handshakeBegin)
+ peer.signal.handshakeBegin.Send()
}
if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
- signalSend(peer.signal.handshakeBegin)
+ peer.signal.handshakeBegin.Send()
}
}
@@ -44,7 +44,7 @@ func (peer *Peer) KeepKeyFreshReceiving() { send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send {
// do a last minute attempt at initiating a new handshake
- signalSend(peer.signal.handshakeBegin)
+ peer.signal.handshakeBegin.Send()
peer.timer.sendLastMinuteHandshake = true
}
}
@@ -69,34 +69,36 @@ func (peer *Peer) SendKeepAlive() bool { * Sent non-empty (authenticated) transport message
*/
func (peer *Peer) TimerDataSent() {
- timerStop(peer.timer.keepalivePassive)
- if !peer.timer.pendingNewHandshake {
- peer.timer.pendingNewHandshake = true
+ peer.timer.keepalivePassive.Stop()
+ if peer.timer.newHandshake.Pending() {
peer.timer.newHandshake.Reset(NewHandshakeTime)
}
}
/* Event:
* Received non-empty (authenticated) transport message
+ *
+ * Action:
+ * Set a timer to confirm the message using a keep-alive (if not already set)
*/
func (peer *Peer) TimerDataReceived() {
- if peer.timer.pendingKeepalivePassive {
+ if !peer.timer.keepalivePassive.Start(KeepaliveTimeout) {
peer.timer.needAnotherKeepalive = true
- return
}
- peer.timer.pendingKeepalivePassive = false
- peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
}
/* Event:
* Any (authenticated) packet received
*/
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
- timerStop(peer.timer.newHandshake)
+ peer.timer.newHandshake.Stop()
}
/* Event:
* Any authenticated packet send / received.
+ *
+ * Action:
+ * Push persistent keep-alive into the future
*/
func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
@@ -117,7 +119,7 @@ func (peer *Peer) TimerHandshakeComplete() { &peer.stats.lastHandshakeNano,
time.Now().UnixNano(),
)
- signalSend(peer.signal.handshakeCompleted)
+ peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
}
@@ -129,7 +131,8 @@ func (peer *Peer) TimerHandshakeComplete() { * CreateMessageInitiation
* CreateMessageResponse
*
- * Schedules the deletion of all key material
+ * Action:
+ * Schedule the deletion of all key material
* upon failure to complete a handshake
*/
func (peer *Peer) TimerEphemeralKeyCreated() {
@@ -139,18 +142,18 @@ func (peer *Peer) TimerEphemeralKeyCreated() { func (peer *Peer) RoutineTimerHandler() {
device := peer.device
+ logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String())
for {
select {
- case <-peer.signal.stop:
- return
+ /* timers */
- // keep-alives
+ // keep-alive
- case <-peer.timer.keepalivePersistent.C:
+ case <-peer.timer.keepalivePersistent.Wait():
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
@@ -158,7 +161,7 @@ func (peer *Peer) RoutineTimerHandler() { peer.SendKeepAlive()
}
- case <-peer.timer.keepalivePassive.C:
+ case <-peer.timer.keepalivePassive.Wait():
logDebug.Println("Sending keep-alive to", peer.String())
@@ -169,17 +172,9 @@ func (peer *Peer) RoutineTimerHandler() { peer.timer.needAnotherKeepalive = false
}
- // unresponsive session
+ // clear key material timer
- case <-peer.timer.newHandshake.C:
-
- logDebug.Println("Retrying handshake with", peer.String(), "due to lack of reply")
-
- signalSend(peer.signal.handshakeBegin)
-
- // clear key material
-
- case <-peer.timer.zeroAllKeys.C:
+ case <-peer.timer.zeroAllKeys.Wait():
logDebug.Println("Clearing all key material for", peer.String())
@@ -215,125 +210,106 @@ func (peer *Peer) RoutineTimerHandler() { setZero(hs.chainKey[:])
setZero(hs.hash[:])
hs.mutex.Unlock()
- }
- }
-}
-/* This is the state machine for handshake initiation
- *
- * Associated with this routine is the signal "handshakeBegin"
- * The routine will read from the "handshakeBegin" channel
- * at most every RekeyTimeout seconds
- */
-func (peer *Peer) RoutineHandshakeInitiator() {
- device := peer.device
+ // handshake timers
- logInfo := device.log.Info
- logError := device.log.Error
- logDebug := device.log.Debug
- logDebug.Println("Routine, handshake initiator, started for", peer.String())
+ case <-peer.timer.newHandshake.Wait():
+ logInfo.Println("Retrying handshake with", peer.String())
+ peer.signal.handshakeBegin.Send()
- var temp [256]byte
+ case <-peer.timer.handshakeTimeout.Wait():
- for {
+ // clear source (in case this is causing problems)
- // wait for signal
+ peer.mutex.Lock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ peer.mutex.Unlock()
- select {
- case <-peer.signal.handshakeBegin:
- case <-peer.signal.stop:
- return
- }
+ // send new handshake
- // set deadline
+ err := peer.sendNewHandshake()
+ if err != nil {
+ logInfo.Println(
+ "Failed to send handshake to peer:", peer.String())
+ }
- BeginHandshakes:
+ case <-peer.timer.handshakeDeadline.Wait():
- signalClear(peer.signal.handshakeReset)
- deadline := time.NewTimer(RekeyAttemptTime)
+ // clear all queued packets and stop keep-alive
- AttemptHandshakes:
+ logInfo.Println(
+ "Handshake negotiation timed out for:", peer.String())
- for attempts := uint(1); ; attempts++ {
+ peer.signal.flushNonceQueue.Send()
+ peer.timer.keepalivePersistent.Stop()
+ peer.signal.handshakeBegin.Enable()
- // check if deadline reached
+ /* signals */
- select {
- case <-deadline.C:
- logInfo.Println("Handshake negotiation timed out for:", peer.String())
- signalSend(peer.signal.flushNonceQueue)
- timerStop(peer.timer.keepalivePersistent)
- break
- case <-peer.signal.stop:
- return
- default:
- }
+ case <-peer.signal.stop.Wait():
+ return
- signalClear(peer.signal.handshakeCompleted)
+ case <-peer.signal.handshakeBegin.Wait():
- // create initiation message
+ peer.signal.handshakeBegin.Disable()
- msg, err := peer.device.CreateMessageInitiation(peer)
+ err := peer.sendNewHandshake()
if err != nil {
- logError.Println("Failed to create handshake initiation message:", err)
- break AttemptHandshakes
+ logInfo.Println(
+ "Failed to send handshake to peer:", peer.String())
}
- // marshal handshake message
-
- writer := bytes.NewBuffer(temp[:0])
- binary.Write(writer, binary.LittleEndian, msg)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
-
- // send to endpoint
-
- err = peer.SendBuffer(packet)
- jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
- timeout := time.NewTimer(RekeyTimeout + jitter)
- if err == nil {
- peer.TimerAnyAuthenticatedPacketTraversal()
- logDebug.Println(
- "Handshake initiation attempt",
- attempts, "sent to", peer.String(),
- )
- } else {
- logError.Println(
- "Failed to send handshake initiation message to",
- peer.String(), ":", err,
- )
- }
+ peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
- // wait for handshake or timeout
+ case <-peer.signal.handshakeCompleted.Wait():
- select {
+ logInfo.Println(
+ "Handshake completed for:", peer.String())
- case <-peer.signal.stop:
- return
+ peer.timer.handshakeTimeout.Stop()
+ peer.timer.handshakeDeadline.Stop()
+ peer.signal.handshakeBegin.Enable()
+ }
+ }
+}
- case <-peer.signal.handshakeCompleted:
- <-timeout.C
- peer.timer.sendLastMinuteHandshake = false
- break AttemptHandshakes
+/* Sends a new handshake initiation message to the peer (endpoint)
+ */
+func (peer *Peer) sendNewHandshake() error {
- case <-peer.signal.handshakeReset:
- <-timeout.C
- goto BeginHandshakes
+ // temporarily disable the handshake complete signal
- case <-timeout.C:
+ peer.signal.handshakeCompleted.Disable()
- // clear source address of peer
+ // create initiation message
- peer.mutex.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.mutex.Unlock()
- }
- }
+ msg, err := peer.device.CreateMessageInitiation(peer)
+ if err != nil {
+ return err
+ }
+
+ // marshal handshake message
- // clear signal set in the meantime
+ var buff [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buff[:0])
+ binary.Write(writer, binary.LittleEndian, msg)
+ packet := writer.Bytes()
+ peer.mac.AddMacs(packet)
- signalClear(peer.signal.handshakeBegin)
+ // send to endpoint
+
+ err = peer.SendBuffer(packet)
+ if err == nil {
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.signal.handshakeCompleted.Enable()
}
+
+ // set timeout
+
+ jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
+ peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
+
+ return err
}
diff --git a/src/uapi.go b/src/uapi.go index 7ab3c4a..155f483 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -221,7 +221,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return &IPCError{Code: ipcErrorInvalid} } } - signalSend(peer.signal.handshakeReset) + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) dummy = false } @@ -265,7 +265,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return err } peer.endpoint = endpoint - signalSend(peer.signal.handshakeReset) + peer.timer.handshakeDeadline.Reset(RekeyAttemptTime) return nil }() |