diff options
Diffstat (limited to 'device/peer.go')
-rw-r--r-- | device/peer.go | 57 |
1 files changed, 13 insertions, 44 deletions
diff --git a/device/peer.go b/device/peer.go index b385519..76f9a96 100644 --- a/device/peer.go +++ b/device/peer.go @@ -25,6 +25,7 @@ type Peer struct { endpoint conn.Endpoint persistentKeepaliveInterval uint32 // accessed atomically firstTrieEntry *trieEntry + stopping sync.WaitGroup // routines pending stop // These fields are accessed with atomic operations, which must be // 64-bit aligned even on 32-bit platforms. Go guarantees that an @@ -53,14 +54,8 @@ type Peer struct { queue struct { sync.RWMutex staged chan *QueueOutboundElement // staged packets before a handshake is available - outbound chan *QueueOutboundElement // sequential ordering of work - inbound chan *QueueInboundElement // sequential ordering of work - } - - routines struct { - sync.Mutex // held when stopping routines - stopping sync.WaitGroup // routines pending stop - stop chan struct{} // size 0, stop all go routines in peer + outbound chan *QueueOutboundElement // sequential ordering of udp transmission + inbound chan *QueueInboundElement // sequential ordering of tun writing } cookieGenerator CookieGenerator @@ -72,7 +67,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { } // lock resources - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -80,13 +74,11 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { defer device.peers.Unlock() // check if over limit - if len(device.peers.keyMap) >= MaxPeers { return nil, errors.New("too many peers") } // create peer - peer := new(Peer) peer.Lock() defer peer.Unlock() @@ -95,14 +87,12 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { peer.device = device // map public key - _, ok := device.peers.keyMap[pk] if ok { return nil, errors.New("adding existing peer") } // pre-compute DH - handshake := &peer.handshake handshake.mutex.Lock() handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) @@ -110,16 +100,13 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { handshake.mutex.Unlock() // reset endpoint - peer.endpoint = nil // add - device.peers.keyMap[pk] = peer device.peers.empty.Set(false) // start peer - if peer.device.isUp.Get() { peer.Start() } @@ -164,17 +151,14 @@ func (peer *Peer) String() string { } func (peer *Peer) Start() { - // should never start a peer on a closed device - if peer.device.isClosed.Get() { return } // prevent simultaneous start/stop operations - - peer.routines.Lock() - defer peer.routines.Unlock() + peer.queue.Lock() + defer peer.queue.Unlock() if peer.isRunning.Get() { return @@ -184,23 +168,19 @@ func (peer *Peer) Start() { device.log.Verbosef("%v - Starting...", peer) // reset routine state - - peer.routines.stopping.Wait() - peer.routines.stop = make(chan struct{}) - peer.routines.stopping.Add(1) + peer.stopping.Wait() + peer.stopping.Add(2) // prepare queues - peer.queue.Lock() - peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) - peer.queue.Unlock() + if peer.queue.staged == nil { + peer.queue.staged = make(chan *QueueOutboundElement, QueueStagedSize) + } peer.timersInit() peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) - // wait for routines to start - go peer.RoutineSequentialSender() go peer.RoutineSequentialReceiver() @@ -254,31 +234,20 @@ func (peer *Peer) ExpireCurrentKeypairs() { } func (peer *Peer) Stop() { - - // prevent simultaneous start/stop operations + peer.queue.Lock() + defer peer.queue.Unlock() if !peer.isRunning.Swap(false) { return } - peer.routines.Lock() - defer peer.routines.Unlock() - peer.device.log.Verbosef("%v - Stopping...", peer) peer.timersStop() - // stop & wait for ongoing peer routines - - close(peer.routines.stop) - peer.routines.stopping.Wait() - - // close queues - - peer.queue.Lock() close(peer.queue.inbound) close(peer.queue.outbound) - peer.queue.Unlock() + peer.stopping.Wait() peer.ZeroAndFlushAll() } |