diff options
-rw-r--r-- | src/conn.go | 11 | ||||
-rw-r--r-- | src/device.go | 23 | ||||
-rw-r--r-- | src/peer.go | 62 | ||||
-rw-r--r-- | src/receive.go | 2 | ||||
-rw-r--r-- | src/send.go | 11 | ||||
-rw-r--r-- | src/timers.go | 18 | ||||
-rw-r--r-- | src/tun.go | 6 | ||||
-rw-r--r-- | src/uapi.go | 16 |
8 files changed, 102 insertions, 47 deletions
diff --git a/src/conn.go b/src/conn.go index ddb7ed1..1d033ff 100644 --- a/src/conn.go +++ b/src/conn.go @@ -64,13 +64,9 @@ func unsafeCloseBind(device *Device) error { return err } -func updateBind(device *Device) error { - device.mutex.Lock() - defer device.mutex.Unlock() - - netc := &device.net - netc.mutex.Lock() - defer netc.mutex.Unlock() +/* Must hold device and net lock + */ +func unsafeUpdateBind(device *Device) error { // close existing sockets @@ -89,6 +85,7 @@ func updateBind(device *Device) error { // bind to new port var err error + netc := &device.net netc.bind, netc.port, err = CreateBind(netc.port) if err != nil { netc.bind = nil diff --git a/src/device.go b/src/device.go index f4a087c..5f8e91b 100644 --- a/src/device.go +++ b/src/device.go @@ -1,6 +1,7 @@ package main import ( + "github.com/sasha-s/go-deadlock" "runtime" "sync" "sync/atomic" @@ -21,12 +22,12 @@ type Device struct { messageBuffers sync.Pool } net struct { - mutex sync.RWMutex + mutex deadlock.RWMutex bind Bind // bind interface port uint16 // listening port fwmark uint32 // mark value (0 = disabled) } - mutex sync.RWMutex + mutex deadlock.RWMutex privateKey NoisePrivateKey publicKey NoisePublicKey routingTable RoutingTable @@ -49,8 +50,15 @@ func (device *Device) Up() { device.mutex.Lock() defer device.mutex.Unlock() - device.isUp.Set(true) - updateBind(device) + device.net.mutex.Lock() + defer device.net.mutex.Unlock() + + if device.isUp.Swap(true) { + return + } + + unsafeUpdateBind(device) + for _, peer := range device.peers { peer.Start() } @@ -60,8 +68,12 @@ func (device *Device) Down() { device.mutex.Lock() defer device.mutex.Unlock() - device.isUp.Set(false) + if !device.isUp.Swap(false) { + return + } + closeBind(device) + for _, peer := range device.peers { peer.Stop() } @@ -75,7 +87,6 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) { if !ok { return } - peer.mutex.Lock() peer.Stop() device.routingTable.RemovePeer(peer) delete(device.peers, key) diff --git a/src/peer.go b/src/peer.go index 7c6ad47..3d82989 100644 --- a/src/peer.go +++ b/src/peer.go @@ -8,6 +8,10 @@ import ( "time" ) +const ( + PeerRoutineNumber = 4 +) + type Peer struct { id uint mutex sync.RWMutex @@ -34,7 +38,6 @@ type Peer struct { flushNonceQueue Signal // size 1, empty queued packets messageSend Signal // size 1, message was send to peer messageReceived Signal // size 1, authenticated message recv - stop Signal // size 0, stop all goroutines in peer } timer struct { // state related to WireGuard timers @@ -54,6 +57,12 @@ type Peer struct { outbound chan *QueueOutboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work } + routines struct { + mutex sync.Mutex // held when stopping / starting routines + starting sync.WaitGroup // routines pending start + stopping sync.WaitGroup // routines pending stop + stop Signal // size 0, stop all goroutines in peer + } mac CookieGenerator } @@ -121,6 +130,10 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.signal.handshakeCompleted = NewSignal() peer.signal.flushNonceQueue = NewSignal() + peer.routines.mutex.Lock() + peer.routines.stop = NewSignal() + peer.routines.mutex.Unlock() + return peer, nil } @@ -156,32 +169,43 @@ func (peer *Peer) String() string { ) } -/* Starts all routines for a given peer - * - * Requires that the caller holds the exclusive peer lock! - */ -func unsafePeerStart(peer *Peer) { - peer.signal.stop.Broadcast() - peer.signal.stop = NewSignal() +func (peer *Peer) Start() { + + peer.routines.mutex.Lock() + defer peer.routines.mutex.Lock() + + // stop & wait for ungoing routines (if any) + + peer.routines.stop.Broadcast() + peer.routines.starting.Wait() + peer.routines.stopping.Wait() - var wait sync.WaitGroup + // reset signal and start (new) routines - wait.Add(1) + peer.routines.stop = NewSignal() + peer.routines.starting.Add(PeerRoutineNumber) + peer.routines.stopping.Add(PeerRoutineNumber) go peer.RoutineNonce() - go peer.RoutineTimerHandler(&wait) + go peer.RoutineTimerHandler() go peer.RoutineSequentialSender() go peer.RoutineSequentialReceiver() - wait.Wait() -} - -func (peer *Peer) Start() { - peer.mutex.Lock() - unsafePeerStart(peer) - peer.mutex.Unlock() + peer.routines.starting.Wait() } func (peer *Peer) Stop() { - peer.signal.stop.Broadcast() + + peer.routines.mutex.Lock() + defer peer.routines.mutex.Lock() + + // stop & wait for ungoing routines (if any) + + peer.routines.stop.Broadcast() + peer.routines.starting.Wait() + peer.routines.stopping.Wait() + + // reset signal (to handle repeated stopping) + + peer.routines.stop = NewSignal() } diff --git a/src/receive.go b/src/receive.go index dbd2813..e6e8481 100644 --- a/src/receive.go +++ b/src/receive.go @@ -497,7 +497,7 @@ func (peer *Peer) RoutineSequentialReceiver() { select { - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id) return diff --git a/src/send.go b/src/send.go index 9537f5e..fa13c91 100644 --- a/src/send.go +++ b/src/send.go @@ -192,7 +192,7 @@ func (peer *Peer) RoutineNonce() { for { NextPacket: select { - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): return case elem := <-peer.queue.nonce: @@ -217,7 +217,7 @@ func (peer *Peer) RoutineNonce() { logDebug.Println("Clearing queue for", peer.String()) peer.FlushNonceQueue() goto NextPacket - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): return } } @@ -309,15 +309,20 @@ func (device *Device) RoutineEncryption() { * The routine terminates then the outbound queue is closed. */ func (peer *Peer) RoutineSequentialSender() { + + defer peer.routines.stopping.Done() + device := peer.device logDebug := device.log.Debug logDebug.Println("Routine, sequential sender, started for", peer.String()) + peer.routines.starting.Done() + for { select { - case <-peer.signal.stop.Wait(): + case <-peer.routines.stop.Wait(): logDebug.Println( "Routine, sequential sender, stopped for", peer.String()) return diff --git a/src/timers.go b/src/timers.go index f2fed30..f1ed9c5 100644 --- a/src/timers.go +++ b/src/timers.go @@ -4,7 +4,6 @@ import ( "bytes"
"encoding/binary"
"math/rand"
- "sync"
"sync/atomic"
"time"
)
@@ -182,7 +181,10 @@ func (peer *Peer) sendNewHandshake() error { return err
}
-func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) {
+func (peer *Peer) RoutineTimerHandler() {
+
+ defer peer.routines.stopping.Done()
+
device := peer.device
logInfo := device.log.Info
@@ -203,15 +205,20 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { peer.timer.keepalivePersistent.Reset(duration)
}
- // signal that timers are reset
+ // signal synchronised setup complete
- ready.Done()
+ peer.routines.starting.Done()
// handle timer events
for {
select {
+ /* stopping */
+
+ case <-peer.routines.stop.Wait():
+ return
+
/* timers */
// keep-alive
@@ -312,9 +319,6 @@ func (peer *Peer) RoutineTimerHandler(ready *sync.WaitGroup) { /* signals */
- case <-peer.signal.stop.Wait():
- return
-
case <-peer.signal.handshakeBegin.Wait():
peer.signal.handshakeBegin.Disable()
@@ -45,14 +45,14 @@ func (device *Device) RoutineTUNEventReader() { } } - if event&TUNEventUp != 0 { + if event&TUNEventUp != 0 && !device.isUp.Get() { logInfo.Println("Interface set up") device.Up() } - if event&TUNEventDown != 0 { + if event&TUNEventDown != 0 && device.isUp.Get() { logInfo.Println("Interface set down") - device.Up() + device.Down() } } } diff --git a/src/uapi.go b/src/uapi.go index a67bff1..f66528c 100644 --- a/src/uapi.go +++ b/src/uapi.go @@ -133,13 +133,27 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { device.SetPrivateKey(sk) case "listen_port": + + // parse port number + port, err := strconv.ParseUint(value, 10, 16) if err != nil { logError.Println("Failed to parse listen_port:", err) return &IPCError{Code: ipcErrorInvalid} } + + // update port and rebind + + device.mutex.Lock() + device.net.mutex.Lock() + device.net.port = uint16(port) - if err := updateBind(device); err != nil { + err = unsafeUpdateBind(device) + + device.net.mutex.Unlock() + device.mutex.Unlock() + + if err != nil { logError.Println("Failed to set listen_port:", err) return &IPCError{Code: ipcErrorPortInUse} } |