diff options
-rw-r--r-- | src/conn.go | 2 | ||||
-rw-r--r-- | src/device.go | 44 | ||||
-rw-r--r-- | src/peer.go | 45 | ||||
-rw-r--r-- | src/timer.go | 6 | ||||
-rw-r--r-- | src/timers.go | 150 | ||||
-rw-r--r-- | src/tun.go | 16 | ||||
-rw-r--r-- | src/uapi.go | 2 |
7 files changed, 163 insertions, 102 deletions
diff --git a/src/conn.go b/src/conn.go index 6d292d3..ddb7ed1 100644 --- a/src/conn.go +++ b/src/conn.go @@ -82,7 +82,7 @@ func updateBind(device *Device) error { // open new sockets - if device.tun.isUp.Get() { + if device.isUp.Get() { device.log.Debug.Println("UDP bind updating") diff --git a/src/device.go b/src/device.go index a3461ad..f4a087c 100644 --- a/src/device.go +++ b/src/device.go @@ -8,13 +8,13 @@ import ( ) type Device struct { - closed AtomicBool // device is closed? (acting as guard) + isUp AtomicBool // device is up (TUN interface up)? + isClosed AtomicBool // device is closed? (acting as guard) log *Logger // collection of loggers for levels idCounter uint // for assigning debug ids to peers fwMark uint32 tun struct { device TUNDevice - isUp AtomicBool mtu int32 } pool struct { @@ -45,6 +45,28 @@ type Device struct { mac CookieChecker } +func (device *Device) Up() { + device.mutex.Lock() + defer device.mutex.Unlock() + + device.isUp.Set(true) + updateBind(device) + for _, peer := range device.peers { + peer.Start() + } +} + +func (device *Device) Down() { + device.mutex.Lock() + defer device.mutex.Unlock() + + device.isUp.Set(false) + closeBind(device) + for _, peer := range device.peers { + peer.Stop() + } +} + /* Warning: * The caller must hold the device mutex (write lock) */ @@ -54,9 +76,9 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) { return } peer.mutex.Lock() + peer.Stop() device.routingTable.RemovePeer(peer) delete(device.peers, key) - peer.Close() } func (device *Device) IsUnderLoad() bool { @@ -98,7 +120,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { device.publicKey = publicKey device.mac.Init(publicKey) - // do DH precomputations + // do DH pre-computations rmKey := device.privateKey.IsZero() @@ -132,10 +154,12 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { device.mutex.Lock() defer device.mutex.Unlock() + device.isUp.Set(false) + device.isClosed.Set(false) + device.log = logger device.peers = make(map[NoisePublicKey]*Peer) device.tun.device = tun - device.tun.isUp.Set(false) device.indices.Init() device.ratelimiter.Init() @@ -196,17 +220,13 @@ func (device *Device) RemovePeer(key NoisePublicKey) { func (device *Device) RemoveAllPeers() { device.mutex.Lock() defer device.mutex.Unlock() - - for key, peer := range device.peers { - peer.mutex.Lock() - delete(device.peers, key) - peer.Close() - peer.mutex.Unlock() + for key := range device.peers { + removePeerUnsafe(device, key) } } func (device *Device) Close() { - if device.closed.Swap(true) { + if device.isClosed.Swap(true) { return } device.log.Info.Println("Closing device") diff --git a/src/peer.go b/src/peer.go index f582556..7c6ad47 100644 --- a/src/peer.go +++ b/src/peer.go @@ -34,15 +34,15 @@ type Peer struct { flushNonceQueue Signal // size 1, empty queued packets messageSend Signal // size 1, message was send to peer messageReceived Signal // size 1, authenticated message recv - stop Signal // size 0, stop all goroutines + stop Signal // size 0, stop all goroutines in peer } timer struct { // state related to WireGuard timers keepalivePersistent Timer // set for persistent keepalives keepalivePassive Timer // set upon recieving messages - newHandshake Timer // begin a new handshake (stale) zeroAllKeys Timer // zero all key material + handshakeNew Timer // begin a new handshake (stale) handshakeDeadline Timer // complete handshake timeout handshakeTimeout Timer // current handshake message timeout @@ -69,8 +69,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.timer.keepalivePersistent = NewTimer() peer.timer.keepalivePassive = NewTimer() - peer.timer.newHandshake = NewTimer() peer.timer.zeroAllKeys = NewTimer() + peer.timer.handshakeNew = NewTimer() peer.timer.handshakeDeadline = NewTimer() peer.timer.handshakeTimeout = NewTimer() @@ -116,32 +116,29 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { // prepare signaling & routines - peer.signal.stop = NewSignal() peer.signal.newKeyPair = NewSignal() peer.signal.handshakeBegin = NewSignal() peer.signal.handshakeCompleted = NewSignal() peer.signal.flushNonceQueue = NewSignal() - go peer.RoutineNonce() - go peer.RoutineTimerHandler() - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() - return peer, nil } func (peer *Peer) SendBuffer(buffer []byte) error { peer.device.net.mutex.RLock() defer peer.device.net.mutex.RUnlock() + peer.mutex.RLock() defer peer.mutex.RUnlock() + if peer.endpoint == nil { return errors.New("No known endpoint for peer") } + return peer.device.net.bind.Send(buffer, peer.endpoint) } -/* Returns a short string identification for logging +/* Returns a short string identifier for logging */ func (peer *Peer) String() string { if peer.endpoint == nil { @@ -159,6 +156,32 @@ func (peer *Peer) String() string { ) } -func (peer *Peer) Close() { +/* Starts all routines for a given peer + * + * Requires that the caller holds the exclusive peer lock! + */ +func unsafePeerStart(peer *Peer) { + peer.signal.stop.Broadcast() + peer.signal.stop = NewSignal() + + var wait sync.WaitGroup + + wait.Add(1) + + go peer.RoutineNonce() + go peer.RoutineTimerHandler(&wait) + go peer.RoutineSequentialSender() + go peer.RoutineSequentialReceiver() + + wait.Wait() +} + +func (peer *Peer) Start() { + peer.mutex.Lock() + unsafePeerStart(peer) + peer.mutex.Unlock() +} + +func (peer *Peer) Stop() { peer.signal.stop.Broadcast() } diff --git a/src/timer.go b/src/timer.go index 3def253..f00ca49 100644 --- a/src/timer.go +++ b/src/timer.go @@ -43,12 +43,6 @@ func (t *Timer) Reset(dur time.Duration) { t.Start(dur) } -func (t *Timer) Push(dur time.Duration) { - if t.pending.Get() { - t.Reset(dur) - } -} - func (t *Timer) Wait() <-chan time.Time { return t.timer.C } diff --git a/src/timers.go b/src/timers.go index ee47393..f2fed30 100644 --- a/src/timers.go +++ b/src/timers.go @@ -4,10 +4,17 @@ import ( "bytes"
"encoding/binary"
"math/rand"
+ "sync"
"sync/atomic"
"time"
)
+/* NOTE:
+ * Notion of validity
+ *
+ *
+ */
+
/* Called when a new authenticated message has been send
*
*/
@@ -44,25 +51,25 @@ func (peer *Peer) KeepKeyFreshReceiving() { send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send {
// do a last minute attempt at initiating a new handshake
- peer.signal.handshakeBegin.Send()
peer.timer.sendLastMinuteHandshake = true
+ peer.signal.handshakeBegin.Send()
}
}
/* Queues a keep-alive if no packets are queued for peer
*/
func (peer *Peer) SendKeepAlive() bool {
+ if len(peer.queue.nonce) != 0 {
+ return false
+ }
elem := peer.device.NewOutboundElement()
elem.packet = nil
- if len(peer.queue.nonce) == 0 {
- select {
- case peer.queue.nonce <- elem:
- return true
- default:
- return false
- }
+ select {
+ case peer.queue.nonce <- elem:
+ return true
+ default:
+ return false
}
- return true
}
/* Event:
@@ -70,9 +77,7 @@ func (peer *Peer) SendKeepAlive() bool { */
func (peer *Peer) TimerDataSent() {
peer.timer.keepalivePassive.Stop()
- if peer.timer.newHandshake.Pending() {
- peer.timer.newHandshake.Reset(NewHandshakeTime)
- }
+ peer.timer.handshakeNew.Start(NewHandshakeTime)
}
/* Event:
@@ -91,7 +96,7 @@ func (peer *Peer) TimerDataReceived() { * Any (authenticated) packet received
*/
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
- peer.timer.newHandshake.Stop()
+ peer.timer.handshakeNew.Stop()
}
/* Event:
@@ -115,10 +120,6 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() { * - First transport message under the "next" key
*/
func (peer *Peer) TimerHandshakeComplete() {
- atomic.StoreInt64(
- &peer.stats.lastHandshakeNano,
- time.Now().UnixNano(),
- )
peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
}
@@ -139,13 +140,75 @@ func (peer *Peer) TimerEphemeralKeyCreated() { peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
}
-func (peer *Peer) RoutineTimerHandler() {
+/* Sends a new handshake initiation message to the peer (endpoint)
+ */
+func (peer *Peer) sendNewHandshake() error {
+
+ // temporarily disable the handshake complete signal
+
+ peer.signal.handshakeCompleted.Disable()
+
+ // create initiation message
+
+ msg, err := peer.device.CreateMessageInitiation(peer)
+ if err != nil {
+ return err
+ }
+
+ // marshal handshake message
+
+ var buff [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buff[:0])
+ binary.Write(writer, binary.LittleEndian, msg)
+ packet := writer.Bytes()
+ peer.mac.AddMacs(packet)
+
+ // send to endpoint
+
+ peer.TimerAnyAuthenticatedPacketTraversal()
+
+ err = peer.SendBuffer(packet)
+ if err == nil {
+ peer.signal.handshakeCompleted.Enable()
+ }
+
+ // set timeout
+
+ jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
+
+ peer.timer.keepalivePassive.Stop()
+ peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
+
+ return err
+}
+
+func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {
device := peer.device
logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String())
+ // reset all timers
+
+ peer.timer.keepalivePassive.Stop()
+ peer.timer.handshakeDeadline.Stop()
+ peer.timer.handshakeTimeout.Stop()
+ peer.timer.handshakeNew.Stop()
+ peer.timer.zeroAllKeys.Stop()
+
+ interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
+ if interval > 0 {
+ duration := time.Duration(interval) * time.Second
+ peer.timer.keepalivePersistent.Reset(duration)
+ }
+
+ // signal that timers are reset
+
+ ready.Done()
+
+ // handle timer events
+
for {
select {
@@ -158,6 +221,7 @@ func (peer *Peer) RoutineTimerHandler() { interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
logDebug.Println("Sending keep-alive to", peer.String())
+ peer.timer.keepalivePassive.Stop()
peer.SendKeepAlive()
}
@@ -168,8 +232,8 @@ func (peer *Peer) RoutineTimerHandler() { peer.SendKeepAlive()
if peer.timer.needAnotherKeepalive {
- peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
peer.timer.needAnotherKeepalive = false
+ peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
}
// clear key material timer
@@ -213,7 +277,7 @@ func (peer *Peer) RoutineTimerHandler() { // handshake timers
- case <-peer.timer.newHandshake.Wait():
+ case <-peer.timer.handshakeNew.Wait():
logInfo.Println("Retrying handshake with", peer.String())
peer.signal.handshakeBegin.Send()
@@ -268,48 +332,16 @@ func (peer *Peer) RoutineTimerHandler() { logInfo.Println(
"Handshake completed for:", peer.String())
+ atomic.StoreInt64(
+ &peer.stats.lastHandshakeNano,
+ time.Now().UnixNano(),
+ )
+
peer.timer.handshakeTimeout.Stop()
peer.timer.handshakeDeadline.Stop()
peer.signal.handshakeBegin.Enable()
- }
- }
-}
-
-/* Sends a new handshake initiation message to the peer (endpoint)
- */
-func (peer *Peer) sendNewHandshake() error {
-
- // temporarily disable the handshake complete signal
-
- peer.signal.handshakeCompleted.Disable()
-
- // create initiation message
- msg, err := peer.device.CreateMessageInitiation(peer)
- if err != nil {
- return err
- }
-
- // marshal handshake message
-
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
- binary.Write(writer, binary.LittleEndian, msg)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
-
- // send to endpoint
-
- err = peer.SendBuffer(packet)
- if err == nil {
- peer.TimerAnyAuthenticatedPacketTraversal()
- peer.signal.handshakeCompleted.Enable()
+ peer.timer.sendLastMinuteHandshake = false
+ }
}
-
- // set timeout
-
- jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
- peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
-
- return err
}
@@ -46,21 +46,13 @@ func (device *Device) RoutineTUNEventReader() { } if event&TUNEventUp != 0 { - if !device.tun.isUp.Get() { - // begin listening for incomming datagrams - logInfo.Println("Interface set up") - device.tun.isUp.Set(true) - updateBind(device) - } + logInfo.Println("Interface set up") + device.Up() } if event&TUNEventDown != 0 { - if device.tun.isUp.Get() { - // stop listening for incomming datagrams - logInfo.Println("Interface set down") - device.tun.isUp.Set(false) - closeBind(device) - } + logInfo.Println("Interface set down") + device.Up() } } } diff --git a/src/uapi.go b/src/uapi.go index 155f483..a67bff1 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -296,7 +296,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { logError.Println("Failed to get tun device status:", err) return &IPCError{Code: ipcErrorIO} } - if device.tun.isUp.Get() && !dummy { + if device.isUp.Get() && !dummy { peer.SendKeepAlive() } } |