diff options
Diffstat (limited to 'receive.go')
-rw-r--r-- | receive.go | 89 |
1 files changed, 57 insertions, 32 deletions
@@ -31,7 +31,7 @@ type QueueInboundElement struct { buffer *[MaxMessageSize]byte packet []byte counter uint64 - keyPair *KeyPair + keyPair *Keypair endpoint Endpoint } @@ -99,6 +99,21 @@ func (device *Device) addToHandshakeQueue( } } +/* Called when a new authenticated message has been received + * + * NOTE: Not thread safe, but called by sequential receiver! + */ +func (peer *Peer) keepKeyFreshReceiving() { + if peer.timers.sentLastMinuteHandshake { + return + } + kp := peer.keyPairs.Current() + if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { + peer.timers.sentLastMinuteHandshake = true + peer.SendHandshakeInitiation(false) + } +} + /* Receives incoming datagrams for the device * * Every time the bind is updated a new routine is started for @@ -245,7 +260,7 @@ func (device *Device) RoutineDecryption() { for { select { - case <-device.signal.stop.Wait(): + case <-device.signals.stop: return case elem, ok := <-device.queue.decryption: @@ -317,7 +332,7 @@ func (device *Device) RoutineHandshake() { for { select { case elem, ok = <-device.queue.handshake: - case <-device.signal.stop.Wait(): + case <-device.signals.stop: return } @@ -441,8 +456,8 @@ func (device *Device) RoutineHandshake() { // update timers - peer.event.anyAuthenticatedPacketTraversal.Fire() - peer.event.anyAuthenticatedPacketReceived.Fire() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() // update endpoint @@ -460,10 +475,11 @@ func (device *Device) RoutineHandshake() { continue } - peer.TimerEphemeralKeyCreated() - peer.NewKeyPair() + if peer.NewKeypair() == nil { + continue + } - logDebug.Println(peer, ": Creating handshake response") + logDebug.Println(peer, ": Sending handshake response") writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, response) @@ -472,9 +488,10 @@ func (device *Device) RoutineHandshake() { // send response + peer.timers.lastSentHandshake = time.Now() err = peer.SendBuffer(packet) if err == nil { - peer.event.anyAuthenticatedPacketTraversal.Fire() + peer.timersAnyAuthenticatedPacketTraversal() } else { logError.Println(peer, ": Failed to send handshake response", err) } @@ -510,18 +527,23 @@ func (device *Device) RoutineHandshake() { logDebug.Println(peer, ": Received handshake response") - peer.TimerEphemeralKeyCreated() - // update timers - peer.event.anyAuthenticatedPacketTraversal.Fire() - peer.event.anyAuthenticatedPacketReceived.Fire() - peer.event.handshakeCompleted.Fire() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() // derive key-pair - peer.NewKeyPair() - peer.SendKeepAlive() + if peer.NewKeypair() == nil { + continue + } + + peer.timersHandshakeComplete() + peer.SendKeepalive() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } } } } @@ -569,38 +591,41 @@ func (peer *Peer) RoutineSequentialReceiver() { continue } - peer.event.anyAuthenticatedPacketTraversal.Fire() - peer.event.anyAuthenticatedPacketReceived.Fire() - peer.KeepKeyFreshReceiving() + // update endpoint + + peer.mutex.Lock() + peer.endpoint = elem.endpoint + peer.mutex.Unlock() // check if using new key-pair kp := &peer.keyPairs - kp.mutex.Lock() + kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true if kp.next == elem.keyPair { - peer.event.handshakeCompleted.Fire() - if kp.previous != nil { - device.DeleteKeyPair(kp.previous) - } + old := kp.previous kp.previous = kp.current + device.DeleteKeypair(old) kp.current = kp.next kp.next = nil + peer.timersHandshakeComplete() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } } kp.mutex.Unlock() - // update endpoint - - peer.mutex.Lock() - peer.endpoint = elem.endpoint - peer.mutex.Unlock() + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() - // check for keep-alive + // check for keepalive if len(elem.packet) == 0 { - logDebug.Println(peer, ": Received keep-alive") + logDebug.Println(peer, ": Receiving keepalive packet") continue } - peer.event.dataReceived.Fire() + peer.timersDataReceived() // verify source and strip padding |