summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/conn.go2
-rw-r--r--src/device.go44
-rw-r--r--src/peer.go45
-rw-r--r--src/timer.go6
-rw-r--r--src/timers.go150
-rw-r--r--src/tun.go16
-rw-r--r--src/uapi.go2
7 files changed, 102 insertions, 163 deletions
diff --git a/src/conn.go b/src/conn.go
index ddb7ed1..6d292d3 100644
--- a/src/conn.go
+++ b/src/conn.go
@@ -82,7 +82,7 @@ func updateBind(device *Device) error {
// open new sockets
- if device.isUp.Get() {
+ if device.tun.isUp.Get() {
device.log.Debug.Println("UDP bind updating")
diff --git a/src/device.go b/src/device.go
index f4a087c..a3461ad 100644
--- a/src/device.go
+++ b/src/device.go
@@ -8,13 +8,13 @@ import (
)
type Device struct {
- isUp AtomicBool // device is up (TUN interface up)?
- isClosed AtomicBool // device is closed? (acting as guard)
+ closed AtomicBool // device is closed? (acting as guard)
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
fwMark uint32
tun struct {
device TUNDevice
+ isUp AtomicBool
mtu int32
}
pool struct {
@@ -45,28 +45,6 @@ type Device struct {
mac CookieChecker
}
-func (device *Device) Up() {
- device.mutex.Lock()
- defer device.mutex.Unlock()
-
- device.isUp.Set(true)
- updateBind(device)
- for _, peer := range device.peers {
- peer.Start()
- }
-}
-
-func (device *Device) Down() {
- device.mutex.Lock()
- defer device.mutex.Unlock()
-
- device.isUp.Set(false)
- closeBind(device)
- for _, peer := range device.peers {
- peer.Stop()
- }
-}
-
/* Warning:
* The caller must hold the device mutex (write lock)
*/
@@ -76,9 +54,9 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) {
return
}
peer.mutex.Lock()
- peer.Stop()
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
+ peer.Close()
}
func (device *Device) IsUnderLoad() bool {
@@ -120,7 +98,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
device.publicKey = publicKey
device.mac.Init(publicKey)
- // do DH pre-computations
+ // do DH precomputations
rmKey := device.privateKey.IsZero()
@@ -154,12 +132,10 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
device.mutex.Lock()
defer device.mutex.Unlock()
- device.isUp.Set(false)
- device.isClosed.Set(false)
-
device.log = logger
device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun
+ device.tun.isUp.Set(false)
device.indices.Init()
device.ratelimiter.Init()
@@ -220,13 +196,17 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
func (device *Device) RemoveAllPeers() {
device.mutex.Lock()
defer device.mutex.Unlock()
- for key := range device.peers {
- removePeerUnsafe(device, key)
+
+ for key, peer := range device.peers {
+ peer.mutex.Lock()
+ delete(device.peers, key)
+ peer.Close()
+ peer.mutex.Unlock()
}
}
func (device *Device) Close() {
- if device.isClosed.Swap(true) {
+ if device.closed.Swap(true) {
return
}
device.log.Info.Println("Closing device")
diff --git a/src/peer.go b/src/peer.go
index 7c6ad47..f582556 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -34,15 +34,15 @@ type Peer struct {
flushNonceQueue Signal // size 1, empty queued packets
messageSend Signal // size 1, message was send to peer
messageReceived Signal // size 1, authenticated message recv
- stop Signal // size 0, stop all goroutines in peer
+ stop Signal // size 0, stop all goroutines
}
timer struct {
// state related to WireGuard timers
keepalivePersistent Timer // set for persistent keepalives
keepalivePassive Timer // set upon recieving messages
+ newHandshake Timer // begin a new handshake (stale)
zeroAllKeys Timer // zero all key material
- handshakeNew Timer // begin a new handshake (stale)
handshakeDeadline Timer // complete handshake timeout
handshakeTimeout Timer // current handshake message timeout
@@ -69,8 +69,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.timer.keepalivePersistent = NewTimer()
peer.timer.keepalivePassive = NewTimer()
+ peer.timer.newHandshake = NewTimer()
peer.timer.zeroAllKeys = NewTimer()
- peer.timer.handshakeNew = NewTimer()
peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer()
@@ -116,29 +116,32 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// prepare signaling & routines
+ peer.signal.stop = NewSignal()
peer.signal.newKeyPair = NewSignal()
peer.signal.handshakeBegin = NewSignal()
peer.signal.handshakeCompleted = NewSignal()
peer.signal.flushNonceQueue = NewSignal()
+ go peer.RoutineNonce()
+ go peer.RoutineTimerHandler()
+ go peer.RoutineSequentialSender()
+ go peer.RoutineSequentialReceiver()
+
return peer, nil
}
func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.mutex.RLock()
defer peer.device.net.mutex.RUnlock()
-
peer.mutex.RLock()
defer peer.mutex.RUnlock()
-
if peer.endpoint == nil {
return errors.New("No known endpoint for peer")
}
-
return peer.device.net.bind.Send(buffer, peer.endpoint)
}
-/* Returns a short string identifier for logging
+/* Returns a short string identification for logging
*/
func (peer *Peer) String() string {
if peer.endpoint == nil {
@@ -156,32 +159,6 @@ func (peer *Peer) String() string {
)
}
-/* Starts all routines for a given peer
- *
- * Requires that the caller holds the exclusive peer lock!
- */
-func unsafePeerStart(peer *Peer) {
- peer.signal.stop.Broadcast()
- peer.signal.stop = NewSignal()
-
- var wait sync.WaitGroup
-
- wait.Add(1)
-
- go peer.RoutineNonce()
- go peer.RoutineTimerHandler(&wait)
- go peer.RoutineSequentialSender()
- go peer.RoutineSequentialReceiver()
-
- wait.Wait()
-}
-
-func (peer *Peer) Start() {
- peer.mutex.Lock()
- unsafePeerStart(peer)
- peer.mutex.Unlock()
-}
-
-func (peer *Peer) Stop() {
+func (peer *Peer) Close() {
peer.signal.stop.Broadcast()
}
diff --git a/src/timer.go b/src/timer.go
index f00ca49..3def253 100644
--- a/src/timer.go
+++ b/src/timer.go
@@ -43,6 +43,12 @@ func (t *Timer) Reset(dur time.Duration) {
t.Start(dur)
}
+func (t *Timer) Push(dur time.Duration) {
+ if t.pending.Get() {
+ t.Reset(dur)
+ }
+}
+
func (t *Timer) Wait() <-chan time.Time {
return t.timer.C
}
diff --git a/src/timers.go b/src/timers.go
index f2fed30..ee47393 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -4,17 +4,10 @@ import (
"bytes"
"encoding/binary"
"math/rand"
- "sync"
"sync/atomic"
"time"
)
-/* NOTE:
- * Notion of validity
- *
- *
- */
-
/* Called when a new authenticated message has been send
*
*/
@@ -51,25 +44,25 @@ func (peer *Peer) KeepKeyFreshReceiving() {
send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTimeReceiving
if send {
// do a last minute attempt at initiating a new handshake
- peer.timer.sendLastMinuteHandshake = true
peer.signal.handshakeBegin.Send()
+ peer.timer.sendLastMinuteHandshake = true
}
}
/* Queues a keep-alive if no packets are queued for peer
*/
func (peer *Peer) SendKeepAlive() bool {
- if len(peer.queue.nonce) != 0 {
- return false
- }
elem := peer.device.NewOutboundElement()
elem.packet = nil
- select {
- case peer.queue.nonce <- elem:
- return true
- default:
- return false
+ if len(peer.queue.nonce) == 0 {
+ select {
+ case peer.queue.nonce <- elem:
+ return true
+ default:
+ return false
+ }
}
+ return true
}
/* Event:
@@ -77,7 +70,9 @@ func (peer *Peer) SendKeepAlive() bool {
*/
func (peer *Peer) TimerDataSent() {
peer.timer.keepalivePassive.Stop()
- peer.timer.handshakeNew.Start(NewHandshakeTime)
+ if peer.timer.newHandshake.Pending() {
+ peer.timer.newHandshake.Reset(NewHandshakeTime)
+ }
}
/* Event:
@@ -96,7 +91,7 @@ func (peer *Peer) TimerDataReceived() {
* Any (authenticated) packet received
*/
func (peer *Peer) TimerAnyAuthenticatedPacketReceived() {
- peer.timer.handshakeNew.Stop()
+ peer.timer.newHandshake.Stop()
}
/* Event:
@@ -120,6 +115,10 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
* - First transport message under the "next" key
*/
func (peer *Peer) TimerHandshakeComplete() {
+ atomic.StoreInt64(
+ &peer.stats.lastHandshakeNano,
+ time.Now().UnixNano(),
+ )
peer.signal.handshakeCompleted.Send()
peer.device.log.Info.Println("Negotiated new handshake for", peer.String())
}
@@ -140,75 +139,13 @@ func (peer *Peer) TimerEphemeralKeyCreated() {
peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
}
-/* Sends a new handshake initiation message to the peer (endpoint)
- */
-func (peer *Peer) sendNewHandshake() error {
-
- // temporarily disable the handshake complete signal
-
- peer.signal.handshakeCompleted.Disable()
-
- // create initiation message
-
- msg, err := peer.device.CreateMessageInitiation(peer)
- if err != nil {
- return err
- }
-
- // marshal handshake message
-
- var buff [MessageInitiationSize]byte
- writer := bytes.NewBuffer(buff[:0])
- binary.Write(writer, binary.LittleEndian, msg)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
-
- // send to endpoint
-
- peer.TimerAnyAuthenticatedPacketTraversal()
-
- err = peer.SendBuffer(packet)
- if err == nil {
- peer.signal.handshakeCompleted.Enable()
- }
-
- // set timeout
-
- jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
-
- peer.timer.keepalivePassive.Stop()
- peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
-
- return err
-}
-
-func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {
+func (peer *Peer) RoutineTimerHandler() {
device := peer.device
logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, timer handler, started for peer", peer.String())
- // reset all timers
-
- peer.timer.keepalivePassive.Stop()
- peer.timer.handshakeDeadline.Stop()
- peer.timer.handshakeTimeout.Stop()
- peer.timer.handshakeNew.Stop()
- peer.timer.zeroAllKeys.Stop()
-
- interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
- if interval > 0 {
- duration := time.Duration(interval) * time.Second
- peer.timer.keepalivePersistent.Reset(duration)
- }
-
- // signal that timers are reset
-
- ready.Done()
-
- // handle timer events
-
for {
select {
@@ -221,7 +158,6 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
logDebug.Println("Sending keep-alive to", peer.String())
- peer.timer.keepalivePassive.Stop()
peer.SendKeepAlive()
}
@@ -232,8 +168,8 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {
peer.SendKeepAlive()
if peer.timer.needAnotherKeepalive {
- peer.timer.needAnotherKeepalive = false
peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
+ peer.timer.needAnotherKeepalive = false
}
// clear key material timer
@@ -277,7 +213,7 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {
// handshake timers
- case <-peer.timer.handshakeNew.Wait():
+ case <-peer.timer.newHandshake.Wait():
logInfo.Println("Retrying handshake with", peer.String())
peer.signal.handshakeBegin.Send()
@@ -332,16 +268,48 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {
logInfo.Println(
"Handshake completed for:", peer.String())
- atomic.StoreInt64(
- &peer.stats.lastHandshakeNano,
- time.Now().UnixNano(),
- )
-
peer.timer.handshakeTimeout.Stop()
peer.timer.handshakeDeadline.Stop()
peer.signal.handshakeBegin.Enable()
-
- peer.timer.sendLastMinuteHandshake = false
}
}
}
+
+/* Sends a new handshake initiation message to the peer (endpoint)
+ */
+func (peer *Peer) sendNewHandshake() error {
+
+ // temporarily disable the handshake complete signal
+
+ peer.signal.handshakeCompleted.Disable()
+
+ // create initiation message
+
+ msg, err := peer.device.CreateMessageInitiation(peer)
+ if err != nil {
+ return err
+ }
+
+ // marshal handshake message
+
+ var buff [MessageInitiationSize]byte
+ writer := bytes.NewBuffer(buff[:0])
+ binary.Write(writer, binary.LittleEndian, msg)
+ packet := writer.Bytes()
+ peer.mac.AddMacs(packet)
+
+ // send to endpoint
+
+ err = peer.SendBuffer(packet)
+ if err == nil {
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.signal.handshakeCompleted.Enable()
+ }
+
+ // set timeout
+
+ jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
+ peer.timer.handshakeTimeout.Reset(RekeyTimeout + jitter)
+
+ return err
+}
diff --git a/src/tun.go b/src/tun.go
index 024f0f0..54253b4 100644
--- a/src/tun.go
+++ b/src/tun.go
@@ -46,13 +46,21 @@ func (device *Device) RoutineTUNEventReader() {
}
if event&TUNEventUp != 0 {
- logInfo.Println("Interface set up")
- device.Up()
+ if !device.tun.isUp.Get() {
+ // begin listening for incomming datagrams
+ logInfo.Println("Interface set up")
+ device.tun.isUp.Set(true)
+ updateBind(device)
+ }
}
if event&TUNEventDown != 0 {
- logInfo.Println("Interface set down")
- device.Up()
+ if device.tun.isUp.Get() {
+ // stop listening for incomming datagrams
+ logInfo.Println("Interface set down")
+ device.tun.isUp.Set(false)
+ closeBind(device)
+ }
}
}
}
diff --git a/src/uapi.go b/src/uapi.go
index a67bff1..155f483 100644
--- a/src/uapi.go
+++ b/src/uapi.go
@@ -296,7 +296,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to get tun device status:", err)
return &IPCError{Code: ipcErrorIO}
}
- if device.isUp.Get() && !dummy {
+ if device.tun.isUp.Get() && !dummy {
peer.SendKeepAlive()
}
}