diff options
Diffstat (limited to 'src/receive.go')
-rw-r--r-- | src/receive.go | 150 |
1 files changed, 71 insertions, 79 deletions
diff --git a/src/receive.go b/src/receive.go index 5afbf7f..50789a1 100644 --- a/src/receive.go +++ b/src/receive.go @@ -31,17 +31,39 @@ type QueueInboundElement struct { func (elem *QueueInboundElement) Drop() { atomic.StoreUint32(&elem.state, ElementStateDropped) - elem.mutex.Unlock() +} + +func (elem *QueueInboundElement) IsDropped() bool { + return atomic.LoadUint32(&elem.state) == ElementStateDropped +} + +func addToInboundQueue( + queue chan *QueueInboundElement, + element *QueueInboundElement, +) { + for { + select { + case queue <- element: + return + default: + select { + case old := <-queue: + old.Drop() + default: + } + } + } } func (device *Device) RoutineReceiveIncomming() { - var packet []byte debugLog := device.log.Debug debugLog.Println("Routine, receive incomming, started") errorLog := device.log.Error + var buffer []byte // unsliced buffer + for { // check if stopped @@ -54,28 +76,28 @@ func (device *Device) RoutineReceiveIncomming() { // read next datagram - if packet == nil { - packet = make([]byte, 1<<16) + if buffer == nil { + buffer = make([]byte, MaxMessageSize) } device.net.mutex.RLock() conn := device.net.conn device.net.mutex.RUnlock() + if conn == nil { + time.Sleep(time.Second) + continue + } conn.SetReadDeadline(time.Now().Add(time.Second)) - size, raddr, err := conn.ReadFromUDP(packet) - if err != nil { - continue - } - if size < MinMessageSize { + size, raddr, err := conn.ReadFromUDP(buffer) + if err != nil || size < MinMessageSize { continue } // handle packet - packet = packet[:size] - debugLog.Println("GOT:", packet) + packet := buffer[:size] msgType := binary.LittleEndian.Uint32(packet[:4]) func() { @@ -112,6 +134,7 @@ func (device *Device) RoutineReceiveIncomming() { // add to handshake queue + buffer = nil device.queue.handshake <- QueueHandshakeElement{ msgType: msgType, packet: packet, @@ -137,8 +160,6 @@ func (device *Device) RoutineReceiveIncomming() { case MessageTransportType: - debugLog.Println("DEBUG: Got transport") - // lookup key pair if len(packet) < MessageTransportSize { @@ -169,42 +190,15 @@ func (device *Device) RoutineReceiveIncomming() { work.state = ElementStateOkay work.mutex.Lock() - // add to parallel decryption queue - - func() { - for { - select { - case device.queue.decryption <- work: - return - default: - select { - case elem := <-device.queue.decryption: - elem.Drop() - default: - } - } - } - }() - - // add to sequential inbound queue - - func() { - for { - select { - case peer.queue.inbound <- work: - break - default: - select { - case elem := <-peer.queue.inbound: - elem.Drop() - default: - } - } - } - }() + // add to decryption queues + + addToInboundQueue(device.queue.decryption, work) + addToInboundQueue(peer.queue.inbound, work) + buffer = nil default: // unknown message type + debugLog.Println("Got unknown message from:", raddr) } }() } @@ -214,6 +208,9 @@ func (device *Device) RoutineDecryption() { var elem *QueueInboundElement var nonce [chacha20poly1305.NonceSize]byte + logDebug := device.log.Debug + logDebug.Println("Routine, decryption, started for device") + for { select { case elem = <-device.queue.decryption: @@ -223,31 +220,25 @@ func (device *Device) RoutineDecryption() { // check if dropped - state := atomic.LoadUint32(&elem.state) - if state != ElementStateOkay { + if elem.IsDropped() { + elem.mutex.Unlock() continue } // split message into fields - counter := binary.LittleEndian.Uint64( - elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent], - ) + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] content := elem.packet[MessageTransportOffsetContent:] // decrypt with key-pair var err error - binary.LittleEndian.PutUint64(nonce[4:], counter) - elem.packet, err = elem.keyPair.recv.Open(elem.packet[:0], nonce[:], content, nil) + copy(nonce[4:], counter) + elem.counter = binary.LittleEndian.Uint64(counter) + elem.packet, err = elem.keyPair.receive.Open(elem.packet[:0], nonce[:], content, nil) if err != nil { elem.Drop() - continue } - - // release to consumer - - elem.counter = counter elem.mutex.Unlock() } } @@ -261,6 +252,7 @@ func (device *Device) RoutineHandshake() { logInfo := device.log.Info logError := device.log.Error logDebug := device.log.Debug + logDebug.Println("Routine, handshake routine, started for device") var elem QueueHandshakeElement @@ -332,13 +324,15 @@ func (device *Device) RoutineHandshake() { } sendSignal(peer.signal.handshakeCompleted) logDebug.Println("Recieved valid response message for peer", peer.id) - peer.NewKeyPair() + kp := peer.NewKeyPair() + if kp == nil { + logDebug.Println("Failed to derieve key-pair") + } peer.SendKeepAlive() default: device.log.Error.Println("Invalid message type in handshake queue") } - }() } } @@ -348,7 +342,6 @@ func (peer *Peer) RoutineSequentialReceiver() { device := peer.device logDebug := device.log.Debug - logDebug.Println("Routine, sequential receiver, started for peer", peer.id) for { @@ -359,20 +352,15 @@ func (peer *Peer) RoutineSequentialReceiver() { return case elem = <-peer.queue.inbound: } - elem.mutex.Lock() - // check if dropped - - logDebug.Println("MESSSAGE:", elem) - - state := atomic.LoadUint32(&elem.state) - if state != ElementStateOkay { + elem.mutex.Lock() + if elem.IsDropped() { continue } // check for replay - // strip padding + // update timers // check for keep-alive @@ -380,26 +368,30 @@ func (peer *Peer) RoutineSequentialReceiver() { continue } + // strip padding + // insert into inbound TUN queue device.queue.inbound <- elem.packet - } + // update key material + } } func (device *Device) RoutineWriteToTUN(tun TUNDevice) { - for { - var packet []byte + logError := device.log.Error + logDebug := device.log.Debug + logDebug.Println("Routine, sequential tun writer, started") + for { select { case <-device.signal.stop: - case packet = <-device.queue.inbound: - } - - size, err := tun.Write(packet) - device.log.Debug.Println("DEBUG:", size, err) - if err != nil { - + return + case packet := <-device.queue.inbound: + _, err := tun.Write(packet) + if err != nil { + logError.Println("Failed to write packet to TUN device:", err) + } } } } |