summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2018-05-07 22:27:03 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2018-05-10 16:08:03 +0200
commit233f079a9479279d2aab68f4accb139ee87ad664 (patch)
tree338dfb681ffafbb53b81d353aa5612866ff935f5
parent375dcbd4aefc8054700dcb072a5e74a9ed7e9d39 (diff)
Rewrite timers and related state machines
-rw-r--r--constants.go27
-rw-r--r--device.go10
-rw-r--r--event.go43
-rw-r--r--index.go2
-rw-r--r--keypair.go14
-rw-r--r--main.go15
-rw-r--r--noise-protocol.go33
-rw-r--r--noise_test.go4
-rw-r--r--peer.go78
-rw-r--r--receive.go89
-rw-r--r--send.go134
-rw-r--r--signal.go71
-rw-r--r--timers.go476
-rw-r--r--uapi.go11
14 files changed, 429 insertions, 578 deletions
diff --git a/constants.go b/constants.go
index 04b75d7..01af1bb 100644
--- a/constants.go
+++ b/constants.go
@@ -12,21 +12,18 @@ import (
/* Specification constants */
const (
- RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
- RejectAfterMessages = (1 << 64) - (1 << 4) - 1
- RekeyAfterTime = time.Second * 120
- RekeyAttemptTime = time.Second * 90
- RekeyTimeout = time.Second * 5
- RejectAfterTime = time.Second * 180
- KeepaliveTimeout = time.Second * 10
- CookieRefreshTime = time.Second * 120
- HandshakeInitationRate = time.Second / 20
- PaddingMultiple = 16
-)
-
-const (
- RekeyAfterTimeReceiving = RejectAfterTime - KeepaliveTimeout - RekeyTimeout
- NewHandshakeTime = KeepaliveTimeout + RekeyTimeout // upon failure to acknowledge transport message
+ RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
+ RejectAfterMessages = (1 << 64) - (1 << 4) - 1
+ RekeyAfterTime = time.Second * 120
+ RekeyAttemptTime = time.Second * 90
+ RekeyTimeout = time.Second * 5
+ MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */
+ RekeyTimeoutJitterMaxMs = 334
+ RejectAfterTime = time.Second * 180
+ KeepaliveTimeout = time.Second * 10
+ CookieRefreshTime = time.Second * 120
+ HandshakeInitationRate = time.Second / 20
+ PaddingMultiple = 16
)
/* Implementation specific constants */
diff --git a/device.go b/device.go
index c714b21..e127b5b 100644
--- a/device.go
+++ b/device.go
@@ -74,8 +74,8 @@ type Device struct {
handshake chan QueueHandshakeElement
}
- signal struct {
- stop Signal
+ signals struct {
+ stop chan struct{}
}
tun struct {
@@ -302,7 +302,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
// prepare signals
- device.signal.stop = NewSignal()
+ device.signals.stop = make(chan struct{}, 1)
// prepare net
@@ -400,7 +400,7 @@ func (device *Device) Close() {
device.isUp.Set(false)
- device.signal.stop.Broadcast()
+ close(device.signals.stop)
device.state.stopping.Wait()
device.FlushPacketQueues()
@@ -413,5 +413,5 @@ func (device *Device) Close() {
}
func (device *Device) Wait() chan struct{} {
- return device.signal.stop.Wait()
+ return device.signals.stop
}
diff --git a/event.go b/event.go
deleted file mode 100644
index 6235ba4..0000000
--- a/event.go
+++ /dev/null
@@ -1,43 +0,0 @@
-package main
-
-import (
- "sync/atomic"
- "time"
-)
-
-type Event struct {
- guard int32
- next time.Time
- interval time.Duration
- C chan struct{}
-}
-
-func newEvent(interval time.Duration) *Event {
- return &Event{
- guard: 0,
- next: time.Now(),
- interval: interval,
- C: make(chan struct{}, 1),
- }
-}
-
-func (e *Event) Clear() {
- select {
- case <-e.C:
- default:
- }
-}
-
-func (e *Event) Fire() {
- if e == nil || atomic.SwapInt32(&e.guard, 1) != 0 {
- return
- }
- if now := time.Now(); now.After(e.next) {
- select {
- case e.C <- struct{}{}:
- default:
- }
- e.next = now.Add(e.interval)
- }
- atomic.StoreInt32(&e.guard, 0)
-}
diff --git a/index.go b/index.go
index c309f23..4a78d55 100644
--- a/index.go
+++ b/index.go
@@ -18,7 +18,7 @@ import (
type IndexTableEntry struct {
peer *Peer
handshake *Handshake
- keyPair *KeyPair
+ keyPair *Keypair
}
type IndexTable struct {
diff --git a/keypair.go b/keypair.go
index eaf30b2..07a183d 100644
--- a/keypair.go
+++ b/keypair.go
@@ -18,7 +18,7 @@ import (
* we plan to resolve this issue; whenever Go allows us to do so.
*/
-type KeyPair struct {
+type Keypair struct {
sendNonce uint64
send cipher.AEAD
receive cipher.AEAD
@@ -29,20 +29,20 @@ type KeyPair struct {
remoteIndex uint32
}
-type KeyPairs struct {
+type Keypairs struct {
mutex sync.RWMutex
- current *KeyPair
- previous *KeyPair
- next *KeyPair // not yet "confirmed by transport"
+ current *Keypair
+ previous *Keypair
+ next *Keypair // not yet "confirmed by transport"
}
-func (kp *KeyPairs) Current() *KeyPair {
+func (kp *Keypairs) Current() *Keypair {
kp.mutex.RLock()
defer kp.mutex.RUnlock()
return kp.current
}
-func (device *Device) DeleteKeyPair(key *KeyPair) {
+func (device *Device) DeleteKeypair(key *Keypair) {
if key != nil {
device.indices.Delete(key.localIndex)
}
diff --git a/main.go b/main.go
index ecfbc50..5001bc4 100644
--- a/main.go
+++ b/main.go
@@ -30,6 +30,8 @@ func printUsage() {
}
func warning() {
+ shouldQuit := false
+
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W This is alpha software. It will very likely not G")
@@ -37,6 +39,8 @@ func warning() {
fmt.Fprintln(os.Stderr, "W horribly wrong. You have been warned. Proceed G")
fmt.Fprintln(os.Stderr, "W at your own risk. G")
if runtime.GOOS == "linux" {
+ shouldQuit = os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
+
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W Furthermore, you are running this software on a G")
fmt.Fprintln(os.Stderr, "W Linux kernel, which is probably unnecessary and G")
@@ -46,9 +50,20 @@ func warning() {
fmt.Fprintln(os.Stderr, "W program. For more information on installing the G")
fmt.Fprintln(os.Stderr, "W kernel module, please visit: G")
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
+ if shouldQuit {
+ fmt.Fprintln(os.Stderr, "W G")
+ fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
+ fmt.Fprintln(os.Stderr, "W the sage advice here, please first export this G")
+ fmt.Fprintln(os.Stderr, "W environment variable: G")
+ fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
+ }
}
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
+
+ if shouldQuit {
+ os.Exit(1)
+ }
}
func main() {
diff --git a/noise-protocol.go b/noise-protocol.go
index 35e95ef..3abbe4b 100644
--- a/noise-protocol.go
+++ b/noise-protocol.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0
*
- * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2015-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
package main
@@ -488,7 +488,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
/* Derives a new key-pair from the current handshake state
*
*/
-func (peer *Peer) NewKeyPair() *KeyPair {
+func (peer *Peer) NewKeypair() *Keypair {
device := peer.device
handshake := &peer.handshake
handshake.mutex.Lock()
@@ -528,7 +528,7 @@ func (peer *Peer) NewKeyPair() *KeyPair {
// create AEAD instances
- keyPair := new(KeyPair)
+ keyPair := new(Keypair)
keyPair.send, _ = chacha20poly1305.New(sendKey[:])
keyPair.receive, _ = chacha20poly1305.New(recvKey[:])
@@ -559,24 +559,27 @@ func (peer *Peer) NewKeyPair() *KeyPair {
kp := &peer.keyPairs
kp.mutex.Lock()
- if isInitiator {
- if kp.previous != nil {
- device.DeleteKeyPair(kp.previous)
- kp.previous = nil
- }
+ peer.timersSessionDerived()
+
+ previous := kp.previous
+ next := kp.next
+ current := kp.current
- if kp.next != nil {
- kp.previous = kp.next
- kp.next = keyPair
+ if isInitiator {
+ if next != nil {
+ kp.next = nil
+ kp.previous = next
+ device.DeleteKeypair(current)
} else {
- kp.previous = kp.current
- kp.current = keyPair
- peer.event.newKeyPair.Fire()
+ kp.previous = current
}
-
+ device.DeleteKeypair(previous)
+ kp.current = keyPair
} else {
kp.next = keyPair
+ device.DeleteKeypair(next)
kp.previous = nil
+ device.DeleteKeypair(previous)
}
kp.mutex.Unlock()
diff --git a/noise_test.go b/noise_test.go
index 958a4ef..37bfb94 100644
--- a/noise_test.go
+++ b/noise_test.go
@@ -102,8 +102,8 @@ func TestNoiseHandshake(t *testing.T) {
t.Log("deriving keys")
- key1 := peer1.NewKeyPair()
- key2 := peer2.NewKeyPair()
+ key1 := peer1.NewKeypair()
+ key2 := peer2.NewKeypair()
if key1 == nil {
t.Fatal("failed to dervice key-pair for peer 1")
diff --git a/peer.go b/peer.go
index 739c8fb..242729e 100644
--- a/peer.go
+++ b/peer.go
@@ -14,14 +14,13 @@ import (
)
const (
- PeerRoutineNumber = 4
- EventInterval = 10 * time.Millisecond
+ PeerRoutineNumber = 3
)
type Peer struct {
isRunning AtomicBool
mutex sync.RWMutex
- keyPairs KeyPairs
+ keyPairs Keypairs
handshake Handshake
device *Device
endpoint Endpoint
@@ -34,34 +33,28 @@ type Peer struct {
lastHandshakeNano int64 // nano seconds since epoch
}
- time struct {
- mutex sync.RWMutex
- lastSend time.Time // last send message
- lastHandshake time.Time // last completed handshake
- nextKeepalive time.Time
+ timers struct {
+ retransmitHandshake *Timer
+ sendKeepalive *Timer
+ newHandshake *Timer
+ zeroKeyMaterial *Timer
+ persistentKeepalive *Timer
+ handshakeAttempts uint
+ needAnotherKeepalive bool
+ sentLastMinuteHandshake bool
+ lastSentHandshake time.Time
}
- event struct {
- dataSent *Event
- dataReceived *Event
- anyAuthenticatedPacketReceived *Event
- anyAuthenticatedPacketTraversal *Event
- handshakeCompleted *Event
- handshakePushDeadline *Event
- handshakeBegin *Event
- ephemeralKeyCreated *Event
- newKeyPair *Event
- flushNonceQueue *Event
- }
-
- timer struct {
- sendLastMinuteHandshake AtomicBool
+ signals struct {
+ newKeypairArrived chan struct{}
+ flushNonceQueue chan struct{}
}
queue struct {
- nonce chan *QueueOutboundElement // nonce / pre-handshake queue
- outbound chan *QueueOutboundElement // sequential ordering of work
- inbound chan *QueueInboundElement // sequential ordering of work
+ nonce chan *QueueOutboundElement // nonce / pre-handshake queue
+ outbound chan *QueueOutboundElement // sequential ordering of work
+ inbound chan *QueueInboundElement // sequential ordering of work
+ packetInNonceQueueIsAwaitingKey bool
}
routines struct {
@@ -188,6 +181,8 @@ func (peer *Peer) Start() {
peer.routines.starting.Wait()
peer.routines.stopping.Wait()
peer.routines.stop = make(chan struct{})
+ peer.routines.starting.Add(PeerRoutineNumber)
+ peer.routines.stopping.Add(PeerRoutineNumber)
// prepare queues
@@ -195,28 +190,13 @@ func (peer *Peer) Start() {
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
- // events
-
- peer.event.dataSent = newEvent(EventInterval)
- peer.event.dataReceived = newEvent(EventInterval)
- peer.event.anyAuthenticatedPacketReceived = newEvent(EventInterval)
- peer.event.anyAuthenticatedPacketTraversal = newEvent(EventInterval)
- peer.event.handshakeCompleted = newEvent(EventInterval)
- peer.event.handshakePushDeadline = newEvent(EventInterval)
- peer.event.handshakeBegin = newEvent(EventInterval)
- peer.event.ephemeralKeyCreated = newEvent(EventInterval)
- peer.event.newKeyPair = newEvent(EventInterval)
- peer.event.flushNonceQueue = newEvent(EventInterval)
-
- peer.isRunning.Set(true)
+ peer.timersInit()
+ peer.signals.newKeypairArrived = make(chan struct{}, 1)
+ peer.signals.flushNonceQueue = make(chan struct{}, 1)
// wait for routines to start
- peer.routines.starting.Add(PeerRoutineNumber)
- peer.routines.stopping.Add(PeerRoutineNumber)
-
go peer.RoutineNonce()
- go peer.RoutineTimerHandler()
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
@@ -238,6 +218,8 @@ func (peer *Peer) Stop() {
device := peer.device
device.log.Debug.Println(peer, ": Stopping...")
+ peer.timersStop()
+
// stop & wait for ongoing peer routines
peer.routines.starting.Wait()
@@ -255,9 +237,9 @@ func (peer *Peer) Stop() {
kp := &peer.keyPairs
kp.mutex.Lock()
- device.DeleteKeyPair(kp.previous)
- device.DeleteKeyPair(kp.current)
- device.DeleteKeyPair(kp.next)
+ device.DeleteKeypair(kp.previous)
+ device.DeleteKeypair(kp.current)
+ device.DeleteKeypair(kp.next)
kp.previous = nil
kp.current = nil
@@ -271,4 +253,6 @@ func (peer *Peer) Stop() {
device.indices.Delete(hs.localIndex)
hs.Clear()
hs.mutex.Unlock()
+
+ peer.FlushNonceQueue()
}
diff --git a/receive.go b/receive.go
index 1cf77b2..0f22a3f 100644
--- a/receive.go
+++ b/receive.go
@@ -31,7 +31,7 @@ type QueueInboundElement struct {
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
- keyPair *KeyPair
+ keyPair *Keypair
endpoint Endpoint
}
@@ -99,6 +99,21 @@ func (device *Device) addToHandshakeQueue(
}
}
+/* Called when a new authenticated message has been received
+ *
+ * NOTE: Not thread safe, but called by sequential receiver!
+ */
+func (peer *Peer) keepKeyFreshReceiving() {
+ if peer.timers.sentLastMinuteHandshake {
+ return
+ }
+ kp := peer.keyPairs.Current()
+ if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
+ peer.timers.sentLastMinuteHandshake = true
+ peer.SendHandshakeInitiation(false)
+ }
+}
+
/* Receives incoming datagrams for the device
*
* Every time the bind is updated a new routine is started for
@@ -245,7 +260,7 @@ func (device *Device) RoutineDecryption() {
for {
select {
- case <-device.signal.stop.Wait():
+ case <-device.signals.stop:
return
case elem, ok := <-device.queue.decryption:
@@ -317,7 +332,7 @@ func (device *Device) RoutineHandshake() {
for {
select {
case elem, ok = <-device.queue.handshake:
- case <-device.signal.stop.Wait():
+ case <-device.signals.stop:
return
}
@@ -441,8 +456,8 @@ func (device *Device) RoutineHandshake() {
// update timers
- peer.event.anyAuthenticatedPacketTraversal.Fire()
- peer.event.anyAuthenticatedPacketReceived.Fire()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
// update endpoint
@@ -460,10 +475,11 @@ func (device *Device) RoutineHandshake() {
continue
}
- peer.TimerEphemeralKeyCreated()
- peer.NewKeyPair()
+ if peer.NewKeypair() == nil {
+ continue
+ }
- logDebug.Println(peer, ": Creating handshake response")
+ logDebug.Println(peer, ": Sending handshake response")
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, response)
@@ -472,9 +488,10 @@ func (device *Device) RoutineHandshake() {
// send response
+ peer.timers.lastSentHandshake = time.Now()
err = peer.SendBuffer(packet)
if err == nil {
- peer.event.anyAuthenticatedPacketTraversal.Fire()
+ peer.timersAnyAuthenticatedPacketTraversal()
} else {
logError.Println(peer, ": Failed to send handshake response", err)
}
@@ -510,18 +527,23 @@ func (device *Device) RoutineHandshake() {
logDebug.Println(peer, ": Received handshake response")
- peer.TimerEphemeralKeyCreated()
-
// update timers
- peer.event.anyAuthenticatedPacketTraversal.Fire()
- peer.event.anyAuthenticatedPacketReceived.Fire()
- peer.event.handshakeCompleted.Fire()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
// derive key-pair
- peer.NewKeyPair()
- peer.SendKeepAlive()
+ if peer.NewKeypair() == nil {
+ continue
+ }
+
+ peer.timersHandshakeComplete()
+ peer.SendKeepalive()
+ select {
+ case peer.signals.newKeypairArrived <- struct{}{}:
+ default:
+ }
}
}
}
@@ -569,38 +591,41 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue
}
- peer.event.anyAuthenticatedPacketTraversal.Fire()
- peer.event.anyAuthenticatedPacketReceived.Fire()
- peer.KeepKeyFreshReceiving()
+ // update endpoint
+
+ peer.mutex.Lock()
+ peer.endpoint = elem.endpoint
+ peer.mutex.Unlock()
// check if using new key-pair
kp := &peer.keyPairs
- kp.mutex.Lock()
+ kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true
if kp.next == elem.keyPair {
- peer.event.handshakeCompleted.Fire()
- if kp.previous != nil {
- device.DeleteKeyPair(kp.previous)
- }
+ old := kp.previous
kp.previous = kp.current
+ device.DeleteKeypair(old)
kp.current = kp.next
kp.next = nil
+ peer.timersHandshakeComplete()
+ select {
+ case peer.signals.newKeypairArrived <- struct{}{}:
+ default:
+ }
}
kp.mutex.Unlock()
- // update endpoint
-
- peer.mutex.Lock()
- peer.endpoint = elem.endpoint
- peer.mutex.Unlock()
+ peer.keepKeyFreshReceiving()
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersAnyAuthenticatedPacketReceived()
- // check for keep-alive
+ // check for keepalive
if len(elem.packet) == 0 {
- logDebug.Println(peer, ": Received keep-alive")
+ logDebug.Println(peer, ": Receiving keepalive packet")
continue
}
- peer.event.dataReceived.Fire()
+ peer.timersDataReceived()
// verify source and strip padding
diff --git a/send.go b/send.go
index ddebb99..1b35e27 100644
--- a/send.go
+++ b/send.go
@@ -6,6 +6,7 @@
package main
import (
+ "bytes"
"encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
@@ -46,21 +47,10 @@ type QueueOutboundElement struct {
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
- keyPair *KeyPair // key-pair for encryption
+ keyPair *Keypair // key-pair for encryption
peer *Peer // related peer
}
-func (peer *Peer) flushNonceQueue() {
- elems := len(peer.queue.nonce)
- for i := 0; i < elems; i++ {
- select {
- case <-peer.queue.nonce:
- default:
- return
- }
- }
-}
-
func (device *Device) NewOutboundElement() *QueueOutboundElement {
return &QueueOutboundElement{
dropped: AtomicFalse,
@@ -114,6 +104,73 @@ func addToEncryptionQueue(
}
}
+/* Queues a keepalive if no packets are queued for peer
+ */
+func (peer *Peer) SendKeepalive() bool {
+ if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey {
+ return false
+ }
+ elem := peer.device.NewOutboundElement()
+ elem.packet = nil
+ select {
+ case peer.queue.nonce <- elem:
+ peer.device.log.Debug.Println(peer, ": Sending keepalive packet")
+ return true
+ default:
+ return false
+ }
+}
+
+/* Sends a new handshake initiation message to the peer (endpoint)
+ */
+func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
+ if !isRetry {
+ peer.timers.handshakeAttempts = 0
+ }
+
+ if time.Now().Sub(peer.timers.lastSentHandshake) < RekeyTimeout {
+ return nil
+ }
+ peer.timers.lastSentHandshake = time.Now() //TODO: locking for this variable?
+
+ // create initiation message
+
+ msg, err := peer.device.CreateMessageInitiation(peer)
+ if err != nil {
+ return err
+ }
+
+ peer.device.log.Debug.Println(peer, ": Sending handshake initiation")
+
+ // marshal handshake message
+
+ var buff [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buff[:0])
+ binary.Write(writer, binary.LittleEndian, msg)
+ packet := writer.Bytes()
+ peer.mac.AddMacs(packet)
+
+ // send to endpoint
+
+ peer.timersAnyAuthenticatedPacketTraversal()
+ peer.timersHandshakeInitiated()
+ return peer.SendBuffer(packet)
+}
+
+/* Called when a new authenticated message has been send
+ *
+ */
+func (peer *Peer) keepKeyFreshSending() {
+ kp := peer.keyPairs.Current()
+ if kp == nil {
+ return
+ }
+ nonce := atomic.LoadUint64(&kp.sendNonce)
+ if nonce > RekeyAfterMessages || (kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime) {
+ peer.SendHandshakeInitiation(false)
+ }
+}
+
/* Reads packets from the TUN and inserts
* into nonce queue for peer
*
@@ -180,13 +237,22 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue
if peer.isRunning.Get() {
- peer.event.handshakePushDeadline.Fire()
+ if peer.queue.packetInNonceQueueIsAwaitingKey {
+ peer.SendHandshakeInitiation(false)
+ }
addToOutboundQueue(peer.queue.nonce, elem)
elem = device.NewOutboundElement()
}
}
}
+func (peer *Peer) FlushNonceQueue() {
+ select {
+ case peer.signals.flushNonceQueue <- struct{}{}:
+ default:
+ }
+}
+
/* Queues packets when there is no handshake.
* Then assigns nonces to packets sequentially
* and creates "work" structs for workers
@@ -194,13 +260,14 @@ func (device *Device) RoutineReadFromTUN() {
* Obs. A single instance per peer
*/
func (peer *Peer) RoutineNonce() {
- var keyPair *KeyPair
+ var keyPair *Keypair
device := peer.device
logDebug := device.log.Debug
defer func() {
logDebug.Println(peer, ": Routine: nonce worker - stopped")
+ peer.queue.packetInNonceQueueIsAwaitingKey = false
peer.routines.stopping.Done()
}()
@@ -209,8 +276,7 @@ func (peer *Peer) RoutineNonce() {
for {
NextPacket:
-
- peer.event.flushNonceQueue.Clear()
+ peer.queue.packetInNonceQueueIsAwaitingKey = false
select {
case <-peer.routines.stop:
@@ -225,34 +291,48 @@ func (peer *Peer) RoutineNonce() {
// wait for key pair
for {
-
- peer.event.newKeyPair.Clear()
-
keyPair = peer.keyPairs.Current()
if keyPair != nil && keyPair.sendNonce < RejectAfterMessages {
if time.Now().Sub(keyPair.created) < RejectAfterTime {
break
}
}
+ peer.queue.packetInNonceQueueIsAwaitingKey = true
- peer.event.handshakeBegin.Fire()
+ select {
+ case <-peer.signals.newKeypairArrived:
+ default:
+ }
+
+ peer.SendHandshakeInitiation(false)
logDebug.Println(peer, ": Awaiting key-pair")
select {
- case <-peer.event.newKeyPair.C:
+ case <-peer.signals.newKeypairArrived:
logDebug.Println(peer, ": Obtained awaited key-pair")
- case <-peer.event.flushNonceQueue.C:
- goto NextPacket
+ case <-peer.signals.flushNonceQueue:
+ for {
+ select {
+ case <-peer.queue.nonce:
+ default:
+ goto NextPacket
+ }
+ }
case <-peer.routines.stop:
return
}
}
+ peer.queue.packetInNonceQueueIsAwaitingKey = false
// populate work element
elem.peer = peer
elem.nonce = atomic.AddUint64(&keyPair.sendNonce, 1) - 1
+ // double check in case of race condition added by future code
+ if elem.nonce >= RejectAfterMessages {
+ goto NextPacket
+ }
elem.keyPair = keyPair
elem.dropped = AtomicFalse
elem.mutex.Lock()
@@ -288,7 +368,7 @@ func (device *Device) RoutineEncryption() {
// fetch next element
select {
- case <-device.signal.stop.Wait():
+ case <-device.signals.stop:
return
case elem, ok := <-device.queue.encryption:
@@ -389,11 +469,11 @@ func (peer *Peer) RoutineSequentialSender() {
// update timers
- peer.event.anyAuthenticatedPacketTraversal.Fire()
+ peer.timersAnyAuthenticatedPacketTraversal()
if len(elem.packet) != MessageKeepaliveSize {
- peer.event.dataSent.Fire()
+ peer.timersDataSent()
}
- peer.KeepKeyFreshSending()
+ peer.keepKeyFreshSending()
}
}
}
diff --git a/signal.go b/signal.go
deleted file mode 100644
index 606da52..0000000
--- a/signal.go
+++ /dev/null
@@ -1,71 +0,0 @@
-/* SPDX-License-Identifier: GPL-2.0
- *
- * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
- */
-
-package main
-
-func signalSend(s chan<- struct{}) {
- select {
- case s <- struct{}{}:
- default:
- }
-}
-
-type Signal struct {
- enabled AtomicBool
- C chan struct{}
-}
-
-func NewSignal() (s Signal) {
- s.C = make(chan struct{}, 1)
- s.Enable()
- return
-}
-
-func (s *Signal) Close() {
- close(s.C)
-}
-
-func (s *Signal) Disable() {
- s.enabled.Set(false)
- s.Clear()
-}
-
-func (s *Signal) Enable() {
- s.enabled.Set(true)
-}
-
-/* Unblock exactly one listener
- */
-func (s *Signal) Send() {
- if s.enabled.Get() {
- select {
- case s.C <- struct{}{}:
- default:
- }
- }
-}
-
-/* Clear the signal if already fired
- */
-func (s Signal) Clear() {
- select {
- case <-s.C:
- default:
- }
-}
-
-/* Unblocks all listeners (forever)
- */
-func (s Signal) Broadcast() {
- if s.enabled.Get() {
- close(s.C)
- }
-}
-
-/* Wait for the signal
- */
-func (s Signal) Wait() chan struct{} {
- return s.C
-}
diff --git a/timers.go b/timers.go
index 38c9b46..5c72efd 100644
--- a/timers.go
+++ b/timers.go
@@ -1,355 +1,221 @@
/* SPDX-License-Identifier: GPL-2.0
*
- * Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ * Copyright (C) 2015-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
+ *
+ * This is based heavily on timers.c from the kernel implementation.
*/
package main
import (
- "bytes"
- "encoding/binary"
"math/rand"
"sync/atomic"
"time"
)
-/* NOTE:
- * Notion of validity
+/* This Timer structure and related functions should roughly copy the interface of
+ * the Linux kernel's struct timer_list.
*/
-/* Called when a new authenticated message has been send
- *
- */
-func (peer *Peer) KeepKeyFreshSending() {
- kp := peer.keyPairs.Current()
- if kp == nil {
- return
- }
- nonce := atomic.LoadUint64(&kp.sendNonce)
- if nonce > RekeyAfterMessages {
- peer.event.handshakeBegin.Fire()
- }
- if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
- peer.event.handshakeBegin.Fire()
- }
+type Timer struct {
+ timer *time.Timer
+ isPending bool
}
-/* Called when a new authenticated message has been received
- *
- * NOTE: Not thread safe, but called by sequential receiver!
- */
-func (peer *Peer) KeepKeyFreshReceiving() {
- if peer.timer.sendLastMinuteHandshake.Get() {
- return
- }
- kp := peer.keyPairs.Current()
- if kp == nil {
- return
- }
- if !kp.isInitiator {
- return
- }
- nonce := atomic.LoadUint64(&kp.sendNonce)
- send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
- if send {
- // do a last minute attempt at initiating a new handshake
- peer.timer.sendLastMinuteHandshake.Set(true)
- peer.event.handshakeBegin.Fire()
- }
+func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
+ timer := &Timer{}
+ timer.timer = time.AfterFunc(time.Hour, func() {
+ timer.isPending = false
+ expirationFunction(peer)
+ })
+ timer.timer.Stop()
+ return timer
}
-/* Queues a keep-alive if no packets are queued for peer
- */
-func (peer *Peer) SendKeepAlive() bool {
- if len(peer.queue.nonce) != 0 {
- return false
- }
- elem := peer.device.NewOutboundElement()
- elem.packet = nil
- select {
- case peer.queue.nonce <- elem:
- return true
- default:
- return false
- }
+func (timer *Timer) Mod(d time.Duration) {
+ timer.isPending = true
+ timer.timer.Reset(d)
}
-/* Called after successfully completing a handshake.
- * i.e. after:
- *
- * - Valid handshake response
- * - First transport message under the "next" key
- */
-// peer.device.log.Info.Println(peer, ": New handshake completed")
-
-/* Event:
- * An ephemeral key is generated
- *
- * i.e. after:
- *
- * CreateMessageInitiation
- * CreateMessageResponse
- *
- * Action:
- * Schedule the deletion of all key material
- * upon failure to complete a handshake
- */
-func (peer *Peer) TimerEphemeralKeyCreated() {
- peer.event.ephemeralKeyCreated.Fire()
- // peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
+func (timer *Timer) Del() {
+ timer.isPending = false
+ timer.timer.Stop()
}
-/* Sends a new handshake initiation message to the peer (endpoint)
- */
-func (peer *Peer) sendNewHandshake() error {
-
- // create initiation message
-
- msg, err := peer.device.CreateMessageInitiation(peer)
- if err != nil {
- return err
- }
+func (peer *Peer) timersActive() bool {
+ return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0
+}
- // marshal handshake message
+func expiredRetransmitHandshake(peer *Peer) {
+ if peer.timers.handshakeAttempts > MaxTimerHandshakes {
+ peer.device.log.Debug.Printf("%s: Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2)
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
- binary.Write(writer, binary.LittleEndian, msg)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
+ if peer.timersActive() {
+ peer.timers.sendKeepalive.Del()
+ }
- // send to endpoint
+ /* We drop all packets without a keypair and don't try again,
+ * if we try unsuccessfully for too long to make a handshake.
+ */
+ peer.FlushNonceQueue()
- peer.event.anyAuthenticatedPacketTraversal.Fire()
+ /* We set a timer for destroying any residue that might be left
+ * of a partial exchange.
+ */
+ if peer.timersActive() && !peer.timers.zeroKeyMaterial.isPending {
+ peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
+ }
+ } else {
+ peer.timers.handshakeAttempts++
+ peer.device.log.Debug.Printf("%s: Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts+1)
+
+ /* We clear the endpoint address src address, in case this is the cause of trouble. */
+ peer.mutex.Lock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
+ peer.mutex.Unlock()
- return peer.SendBuffer(packet)
+ peer.SendHandshakeInitiation(true)
+ }
}
-func newTimer() *time.Timer {
- timer := time.NewTimer(time.Hour)
- timer.Stop()
- return timer
+func expiredSendKeepalive(peer *Peer) {
+ peer.SendKeepalive()
+ if peer.timers.needAnotherKeepalive {
+ peer.timers.needAnotherKeepalive = false
+ if peer.timersActive() {
+ peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
+ }
+ }
}
-func (peer *Peer) RoutineTimerHandler() {
-
- device := peer.device
-
- logInfo := device.log.Info
- logDebug := device.log.Debug
-
- defer func() {
- logDebug.Println(peer, ": Routine: timer handler - stopped")
- peer.routines.stopping.Done()
- }()
-
- logDebug.Println(peer, ": Routine: timer handler - started")
-
- // reset all timers
-
- enableHandshake := true
- pendingHandshakeNew := false
- pendingKeepalivePassive := false
- needAnotherKeepalive := false
-
- timerKeepalivePassive := newTimer()
- timerHandshakeDeadline := newTimer()
- timerHandshakeTimeout := newTimer()
- timerHandshakeNew := newTimer()
- timerZeroAllKeys := newTimer()
- timerKeepalivePersistent := newTimer()
-
- interval := peer.persistentKeepaliveInterval
- if interval > 0 {
- duration := time.Duration(interval) * time.Second
- timerKeepalivePersistent.Reset(duration)
+func expiredNewHandshake(peer *Peer) {
+ peer.device.log.Debug.Printf("%s: Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
+ /* We clear the endpoint address src address, in case this is the cause of trouble. */
+ peer.mutex.Lock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
}
+ peer.mutex.Unlock()
+ peer.SendHandshakeInitiation(false)
- // signal synchronised setup complete
-
- peer.routines.starting.Done()
-
- // handle timer events
-
- for {
- select {
-
- /* stopping */
-
- case <-peer.routines.stop:
- return
-
- /* events */
-
- case <-peer.event.dataSent.C:
- timerKeepalivePassive.Stop()
- if !pendingHandshakeNew {
- timerHandshakeNew.Reset(NewHandshakeTime)
- }
-
- case <-peer.event.dataReceived.C:
- if pendingKeepalivePassive {
- needAnotherKeepalive = true
- } else {
- timerKeepalivePassive.Reset(KeepaliveTimeout)
- }
-
- case <-peer.event.anyAuthenticatedPacketTraversal.C:
- interval := peer.persistentKeepaliveInterval
- if interval > 0 {
- duration := time.Duration(interval) * time.Second
- timerKeepalivePersistent.Reset(duration)
- }
-
- case <-peer.event.handshakeBegin.C:
-
- if !enableHandshake {
- continue
- }
-
- logDebug.Println(peer, ": Event, Handshake Begin")
-
- err := peer.sendNewHandshake()
-
- // set timeout
-
- jitter := time.Millisecond * time.Duration(rand.Int31n(334))
- timerKeepalivePassive.Stop()
- timerHandshakeTimeout.Reset(RekeyTimeout + jitter)
-
- if err != nil {
- logInfo.Println(peer, ": Failed to send handshake initiation", err)
- } else {
- logDebug.Println(peer, ": Send handshake initiation (initial)")
- }
-
- timerHandshakeDeadline.Reset(RekeyAttemptTime)
-
- // disable further handshakes
-
- peer.event.handshakeBegin.Clear()
- enableHandshake = false
-
- case <-peer.event.handshakeCompleted.C:
-
- logInfo.Println(peer, ": Handshake completed")
-
- atomic.StoreInt64(
- &peer.stats.lastHandshakeNano,
- time.Now().UnixNano(),
- )
-
- timerHandshakeTimeout.Stop()
- timerHandshakeDeadline.Stop()
- peer.timer.sendLastMinuteHandshake.Set(false)
-
- // allow further handshakes
-
- peer.event.handshakeBegin.Clear()
- enableHandshake = true
-
- /* timers */
-
- case <-timerKeepalivePersistent.C:
-
- interval := peer.persistentKeepaliveInterval
- if interval > 0 {
- logDebug.Println(peer, ": Send keep-alive (persistent)")
- timerKeepalivePassive.Stop()
- peer.SendKeepAlive()
- }
-
- case <-timerKeepalivePassive.C:
-
- logDebug.Println(peer, ": Send keep-alive (passive)")
-
- peer.SendKeepAlive()
-
- if needAnotherKeepalive {
- timerKeepalivePassive.Reset(KeepaliveTimeout)
- needAnotherKeepalive = false
- }
-
- case <-timerZeroAllKeys.C:
-
- logDebug.Println(peer, ": Clear all key-material (timer event)")
-
- hs := &peer.handshake
- hs.mutex.Lock()
-
- kp := &peer.keyPairs
- kp.mutex.Lock()
-
- // remove key-pairs
-
- if kp.previous != nil {
- device.DeleteKeyPair(kp.previous)
- kp.previous = nil
- }
- if kp.current != nil {
- device.DeleteKeyPair(kp.current)
- kp.current = nil
- }
- if kp.next != nil {
- device.DeleteKeyPair(kp.next)
- kp.next = nil
- }
- kp.mutex.Unlock()
-
- // zero out handshake
-
- device.indices.Delete(hs.localIndex)
- hs.Clear()
- hs.mutex.Unlock()
-
- case <-timerHandshakeTimeout.C:
-
- // allow new handshake to be send
+}
- enableHandshake = true
+func expiredZeroKeyMaterial(peer *Peer) {
+ peer.device.log.Debug.Printf(":%s Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
- // clear source (in case this is causing problems)
+ hs := &peer.handshake
+ hs.mutex.Lock()
- peer.mutex.Lock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
- peer.mutex.Unlock()
+ kp := &peer.keyPairs
+ kp.mutex.Lock()
- // send new handshake
+ if kp.previous != nil {
+ peer.device.DeleteKeypair(kp.previous)
+ kp.previous = nil
+ }
+ if kp.current != nil {
+ peer.device.DeleteKeypair(kp.current)
+ kp.current = nil
+ }
+ if kp.next != nil {
+ peer.device.DeleteKeypair(kp.next)
+ kp.next = nil
+ }
+ kp.mutex.Unlock()
- err := peer.sendNewHandshake()
+ peer.device.indices.Delete(hs.localIndex)
+ hs.Clear()
+ hs.mutex.Unlock()
+}
- // set timeout
+func expiredPersistentKeepalive(peer *Peer) {
+ if peer.persistentKeepaliveInterval > 0 {
+ if peer.timersActive() {
+ peer.timers.sendKeepalive.Del()
+ }
+ peer.SendKeepalive()
+ }
+}
- jitter := time.Millisecond * time.Duration(rand.Int31n(334))
- timerKeepalivePassive.Stop()
- timerHandshakeTimeout.Reset(RekeyTimeout + jitter)
+/* Should be called after an authenticated data packet is sent. */
+func (peer *Peer) timersDataSent() {
+ if peer.timersActive() {
+ peer.timers.sendKeepalive.Del()
+ }
- if err != nil {
- logInfo.Println(peer, ": Failed to send handshake initiation", err)
- } else {
- logDebug.Println(peer, ": Send handshake initiation (subsequent)")
- }
+ if peer.timersActive() && !peer.timers.newHandshake.isPending {
+ peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout)
+ }
+}
- // disable further handshakes
+/* Should be called after an authenticated data packet is received. */
+func (peer *Peer) timersDataReceived() {
+ if peer.timersActive() {
+ if !peer.timers.sendKeepalive.isPending {
+ peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
+ } else {
+ peer.timers.needAnotherKeepalive = true
+ }
+ }
+}
- peer.event.handshakeBegin.Clear()
- enableHandshake = false
+/* Should be called after any type of authenticated packet is received -- keepalive or data. */
+func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
+ if peer.timersActive() {
+ peer.timers.newHandshake.Del()
+ }
+}
- case <-timerHandshakeDeadline.C:
+/* Should be called after a handshake initiation message is sent. */
+func (peer *Peer) timersHandshakeInitiated() {
+ if peer.timersActive() {
+ peer.timers.sendKeepalive.Del()
+ peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
+ }
+}
- // clear all queued packets and stop keep-alive
+/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
+func (peer *Peer) timersHandshakeComplete() {
+ if peer.timersActive() {
+ peer.timers.retransmitHandshake.Del()
+ }
+ peer.timers.handshakeAttempts = 0
+ peer.timers.sentLastMinuteHandshake = false
+ atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
+}
- logInfo.Println(peer, ": Handshake negotiation timed-out")
+/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
+func (peer *Peer) timersSessionDerived() {
+ if peer.timersActive() {
+ peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
+ }
+}
- peer.flushNonceQueue()
- peer.event.flushNonceQueue.Fire()
+/* Should be called before a packet with authentication -- data, keepalive, either handshake -- is sent, or after one is received. */
+func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
+ if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
+ peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
+ }
+}
- // renable further handshakes
+func (peer *Peer) timersInit() {
+ peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake)
+ peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive)
+ peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
+ peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
+ peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
+ peer.timers.handshakeAttempts = 0
+ peer.timers.sentLastMinuteHandshake = false
+ peer.timers.needAnotherKeepalive = false
+ peer.timers.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
+}
- peer.event.handshakeBegin.Clear()
- enableHandshake = true
- }
- }
+func (peer *Peer) timersStop() {
+ peer.timers.retransmitHandshake.Del()
+ peer.timers.sendKeepalive.Del()
+ peer.timers.newHandshake.Del()
+ peer.timers.zeroKeyMaterial.Del()
+ peer.timers.persistentKeepalive.Del()
}
diff --git a/uapi.go b/uapi.go
index 54d9bae..4b2038b 100644
--- a/uapi.go
+++ b/uapi.go
@@ -256,8 +256,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logDebug.Println("UAPI: Created new peer:", peer)
}
- peer.event.handshakePushDeadline.Fire()
-
case "remove":
// remove currently selected peer from device
@@ -288,8 +286,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
- peer.event.handshakePushDeadline.Fire()
-
case "endpoint":
// set endpoint destination
@@ -304,7 +300,6 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return err
}
peer.endpoint = endpoint
- peer.event.handshakePushDeadline.Fire()
return nil
}()
@@ -315,7 +310,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "persistent_keepalive_interval":
- // update keep-alive interval
+ // update persistent keepalive interval
logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer)
@@ -328,7 +323,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
old := peer.persistentKeepaliveInterval
peer.persistentKeepaliveInterval = uint16(secs)
- // send immediate keep-alive
+ // send immediate keepalive if we're turning it on and before it wasn't on
if old == 0 && secs != 0 {
if err != nil {
@@ -336,7 +331,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorIO}
}
if device.isUp.Get() && !dummy {
- peer.SendKeepAlive()
+ peer.SendKeepalive()
}
}