diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-05-16 22:20:15 +0200 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2018-05-16 22:20:15 +0200 |
commit | 846d721dfd0cde953f2e9304d6ef50110de050eb (patch) | |
tree | 8de15914ab39d0aad1b50d03530b82fece54c740 | |
parent | 23eca94508d7cef0c1adbbc37c81050899ca1d60 (diff) |
Finer-grained start-stop synchronization
-rw-r--r-- | conn.go | 6 | ||||
-rw-r--r-- | device.go | 12 | ||||
-rw-r--r-- | peer.go | 9 | ||||
-rw-r--r-- | receive.go | 4 | ||||
-rw-r--r-- | send.go | 3 | ||||
-rw-r--r-- | tun.go | 4 |
6 files changed, 33 insertions, 5 deletions
@@ -12,6 +12,10 @@ import ( "net" ) +const ( + ConnRoutineNumber = 2 +) + /* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic */ type Bind interface { @@ -153,6 +157,8 @@ func (device *Device) BindUpdate() error { // start receiving routines + device.state.starting.Add(ConnRoutineNumber) + device.state.stopping.Add(ConnRoutineNumber) go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) @@ -15,6 +15,7 @@ import ( const ( DeviceRoutineNumberPerCPU = 3 + DeviceRoutineNumberAdditional = 2 ) type Device struct { @@ -25,6 +26,7 @@ type Device struct { // synchronized resources (locks acquired in order) state struct { + starting sync.WaitGroup stopping sync.WaitGroup mutex sync.Mutex changing AtomicBool @@ -297,7 +299,10 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { // start workers cpus := runtime.NumCPU() - device.state.stopping.Add(DeviceRoutineNumberPerCPU * cpus) + device.state.starting.Wait() + device.state.stopping.Wait() + device.state.stopping.Add(DeviceRoutineNumberPerCPU * cpus + DeviceRoutineNumberAdditional) + device.state.starting.Add(DeviceRoutineNumberPerCPU * cpus + DeviceRoutineNumberAdditional) for i := 0; i < cpus; i += 1 { go device.RoutineEncryption() go device.RoutineDecryption() @@ -307,6 +312,8 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() + device.state.starting.Wait() + return device } @@ -363,6 +370,9 @@ func (device *Device) Close() { if device.isClosed.Swap(true) { return } + + device.state.starting.Wait() + device.log.Info.Println("Device closing") device.state.changing.Set(true) device.state.mutex.Lock() @@ -231,20 +231,21 @@ func (peer *Peer) Stop() { // prevent simultaneous start/stop operations - peer.routines.mutex.Lock() - defer peer.routines.mutex.Unlock() - if !peer.isRunning.Swap(false) { return } + peer.routines.starting.Wait() + + peer.routines.mutex.Lock() + defer peer.routines.mutex.Unlock() + peer.device.log.Debug.Println(peer, ": Stopping...") peer.timersStop() // stop & wait for ongoing peer routines - peer.routines.starting.Wait() close(peer.routines.stop) peer.routines.stopping.Wait() @@ -124,9 +124,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { logDebug := device.log.Debug defer func() { logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + device.state.stopping.Done() }() logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - starting") + device.state.starting.Done() // receive datagrams until conn is closed @@ -257,6 +259,7 @@ func (device *Device) RoutineDecryption() { device.state.stopping.Done() }() logDebug.Println("Routine: decryption worker - started") + device.state.starting.Done() for { select { @@ -324,6 +327,7 @@ func (device *Device) RoutineHandshake() { }() logDebug.Println("Routine: handshake worker - started") + device.state.starting.Done() var elem QueueHandshakeElement var ok bool @@ -247,9 +247,11 @@ func (device *Device) RoutineReadFromTUN() { defer func() { logDebug.Println("Routine: TUN reader - stopped") + device.state.stopping.Done() }() logDebug.Println("Routine: TUN reader - started") + device.state.starting.Done() for { @@ -424,6 +426,7 @@ func (device *Device) RoutineEncryption() { }() logDebug.Println("Routine: encryption worker - started") + device.state.starting.Done() for { @@ -35,6 +35,8 @@ func (device *Device) RoutineTUNEventReader() { logInfo := device.log.Info logError := device.log.Error + device.state.starting.Done() + for event := range device.tun.device.Events() { if event&TUNEventMTUUpdate != 0 { mtu, err := device.tun.device.MTU() @@ -63,4 +65,6 @@ func (device *Device) RoutineTUNEventReader() { device.Down() } } + + device.state.stopping.Done() } |