diff options
Diffstat (limited to 'src/receive.go')
-rw-r--r-- | src/receive.go | 507 |
1 files changed, 250 insertions, 257 deletions
diff --git a/src/receive.go b/src/receive.go index fb5c51f..5f46925 100644 --- a/src/receive.go +++ b/src/receive.go @@ -111,113 +111,84 @@ func (device *Device) RoutineBusyMonitor() { func (device *Device) RoutineReceiveIncomming() { - logInfo := device.log.Info logDebug := device.log.Debug logDebug.Println("Routine, receive incomming, started") - var buffer *[MaxMessageSize]byte - for { - // check if stopped + // wait for new conn + + var conn *net.UDPConn select { + case <-device.signal.newUDPConn: + device.net.mutex.RLock() + conn = device.net.conn + device.net.mutex.RUnlock() + case <-device.signal.stop: return - default: } - // read next datagram - - if buffer == nil { - buffer = device.GetMessageBuffer() - } - - // TODO: Take writelock to sleep - device.net.mutex.RLock() - conn := device.net.conn - device.net.mutex.RUnlock() if conn == nil { - time.Sleep(time.Second) continue } - // TODO: Wait for new conn or message - conn.SetReadDeadline(time.Now().Add(time.Second)) + // receive datagrams until closed - size, raddr, err := conn.ReadFromUDP(buffer[:]) - if err != nil || size < MinMessageSize { - continue - } + buffer := device.GetMessageBuffer() - // handle packet + for { - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) + // read next datagram - func() { - switch msgType { - - case MessageInitiationType, MessageResponseType: - - // TODO: Check size early + size, raddr, err := conn.ReadFromUDP(buffer[:]) // TODO: This is broken - // add to handshake queue + if err != nil { + break + } - device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - source: raddr, - }, - ) - buffer = nil + if size < MinMessageSize { + continue + } - case MessageCookieReplyType: + // check size of packet - // TODO: Queue all the things + packet := buffer[:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - // verify and update peer cookie state + var okay bool - if len(packet) != MessageCookieReplySize { - return - } + switch msgType { - var reply MessageCookieReply - reader := bytes.NewReader(packet) - err := binary.Read(reader, binary.LittleEndian, &reply) - if err != nil { - logDebug.Println("Failed to decode cookie reply") - return - } - device.ConsumeMessageCookieReply(&reply) + // check if transport case MessageTransportType: - // lookup key pair + // check size - if len(packet) < MessageTransportSize { - return + if len(packet) < MessageTransportType { + continue } + // lookup key pair + receiver := binary.LittleEndian.Uint32( packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], ) value := device.indices.Lookup(receiver) keyPair := value.keyPair if keyPair == nil { - return + continue } // check key-pair expiry if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { - return + continue } - // add to peer queue + // create work element peer := value.peer elem := &QueueInboundElement{ @@ -233,11 +204,33 @@ func (device *Device) RoutineReceiveIncomming() { device.addToInboundQueue(device.queue.decryption, elem) device.addToInboundQueue(peer.queue.inbound, elem) buffer = nil + continue - default: - logInfo.Println("Got unknown message from:", raddr) + // otherwise it is a handshake related packet + + case MessageInitiationType: + okay = len(packet) == MessageInitiationSize + + case MessageResponseType: + okay = len(packet) == MessageResponseSize + + case MessageCookieReplyType: + okay = len(packet) == MessageCookieReplySize } - }() + + if okay { + device.addToHandshakeQueue( + device.queue.handshake, + QueueHandshakeElement{ + msgType: msgType, + buffer: buffer, + packet: packet, + source: raddr, + }, + ) + buffer = device.GetMessageBuffer() + } + } } } @@ -306,154 +299,165 @@ func (device *Device) RoutineHandshake() { return } - func() { + // handle cookie fields and ratelimiting - // verify mac1 + switch elem.msgType { - if !device.mac.CheckMAC1(elem.packet) { - logDebug.Println("Received packet with invalid mac1") + case MessageCookieReplyType: + + // verify and update peer cookie state + + var reply MessageCookieReply + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &reply) + if err != nil { + logDebug.Println("Failed to decode cookie reply") return } + device.ConsumeMessageCookieReply(&reply) + continue - // verify mac2 + case MessageInitiationType, MessageResponseType: - busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue + // check mac fields and ratelimit - if busy && !device.mac.CheckMAC2(elem.packet, elem.source) { - sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" - reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source) - if err != nil { - logError.Println("Failed to create cookie reply:", err) - return - } - // TODO: Use temp - writer := bytes.NewBuffer(elem.packet[:0]) - binary.Write(writer, binary.LittleEndian, reply) - elem.packet = writer.Bytes() - _, err = device.net.conn.WriteToUDP(elem.packet, elem.source) - if err != nil { - logDebug.Println("Failed to send cookie reply:", err) - } + if !device.mac.CheckMAC1(elem.packet) { + logDebug.Println("Received packet with invalid mac1") return } - // ratelimit - - // TODO: Only ratelimit when busy + busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue - if !device.ratelimiter.Allow(elem.source.IP) { - return + if busy { + if !device.mac.CheckMAC2(elem.packet, elem.source) { + sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" + reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source) + if err != nil { + logError.Println("Failed to create cookie reply:", err) + return + } + writer := bytes.NewBuffer(temp[:0]) + binary.Write(writer, binary.LittleEndian, reply) + _, err = device.net.conn.WriteToUDP( + writer.Bytes(), + elem.source, + ) + if err != nil { + logDebug.Println("Failed to send cookie reply:", err) + } + continue + } + if !device.ratelimiter.Allow(elem.source.IP) { + continue + } } - // handle messages + default: + logError.Println("Invalid packet ended up in the handshake queue") + continue + } - switch elem.msgType { - case MessageInitiationType: + // handle handshake initation/response content - // unmarshal + switch elem.msgType { + case MessageInitiationType: - if len(elem.packet) != MessageInitiationSize { - return - } + // unmarshal - var msg MessageInitiation - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) - if err != nil { - logError.Println("Failed to decode initiation message") - return - } + var msg MessageInitiation + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode initiation message") + continue + } - // consume initiation + // consume initiation - peer := device.ConsumeMessageInitiation(&msg) - if peer == nil { - logInfo.Println( - "Recieved invalid initiation message from", - elem.source.IP.String(), - elem.source.Port, - ) - return - } + peer := device.ConsumeMessageInitiation(&msg) + if peer == nil { + logInfo.Println( + "Recieved invalid initiation message from", + elem.source.IP.String(), + elem.source.Port, + ) + continue + } - // update timers + // update timers - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() - // update endpoint - // TODO: Add a race condition \s + // update endpoint + // TODO: Discover destination address also, only update on change - peer.mutex.Lock() - peer.endpoint = elem.source - peer.mutex.Unlock() + peer.mutex.Lock() + peer.endpoint = elem.source + peer.mutex.Unlock() - // create response + // create response - response, err := device.CreateMessageResponse(peer) - if err != nil { - logError.Println("Failed to create response message:", err) - return - } + response, err := device.CreateMessageResponse(peer) + if err != nil { + logError.Println("Failed to create response message:", err) + continue + } - peer.TimerEphemeralKeyCreated() - peer.NewKeyPair() + peer.TimerEphemeralKeyCreated() + peer.NewKeyPair() - logDebug.Println("Creating response message for", peer.String()) + logDebug.Println("Creating response message for", peer.String()) - writer := bytes.NewBuffer(temp[:0]) - binary.Write(writer, binary.LittleEndian, response) - packet := writer.Bytes() - peer.mac.AddMacs(packet) + writer := bytes.NewBuffer(temp[:0]) + binary.Write(writer, binary.LittleEndian, response) + packet := writer.Bytes() + peer.mac.AddMacs(packet) - // send response + // send response - peer.SendBuffer(packet) + _, err = peer.SendBuffer(packet) + if err == nil { peer.TimerAnyAuthenticatedPacketTraversal() + } - case MessageResponseType: + case MessageResponseType: - // unmarshal + // unmarshal - if len(elem.packet) != MessageResponseSize { - return - } - - var msg MessageResponse - reader := bytes.NewReader(elem.packet) - err := binary.Read(reader, binary.LittleEndian, &msg) - if err != nil { - logError.Println("Failed to decode response message") - return - } + var msg MessageResponse + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode response message") + continue + } - // consume response + // consume response - peer := device.ConsumeMessageResponse(&msg) - if peer == nil { - logInfo.Println( - "Recieved invalid response message from", - elem.source.IP.String(), - elem.source.Port, - ) - return - } + peer := device.ConsumeMessageResponse(&msg) + if peer == nil { + logInfo.Println( + "Recieved invalid response message from", + elem.source.IP.String(), + elem.source.Port, + ) + continue + } - // update timers + peer.TimerEphemeralKeyCreated() - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - peer.TimerHandshakeComplete() + // update timers - // derive key-pair + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() + peer.TimerHandshakeComplete() - peer.NewKeyPair() - peer.SendKeepAlive() + // derive key-pair - default: - logError.Println("Invalid message type in handshake queue") - } - }() + peer.NewKeyPair() + peer.SendKeepAlive() + } } } @@ -463,6 +467,7 @@ func (peer *Peer) RoutineSequentialReceiver() { device := peer.device logInfo := device.log.Info + logError := device.log.Error logDebug := device.log.Debug logDebug.Println("Routine, sequential receiver, started for peer", peer.id) @@ -478,116 +483,104 @@ func (peer *Peer) RoutineSequentialReceiver() { // process packet - func() { - if elem.IsDropped() { - return - } - - // check for replay - - if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { - return - } + if elem.IsDropped() { + continue + } - peer.TimerAnyAuthenticatedPacketTraversal() - peer.TimerAnyAuthenticatedPacketReceived() - peer.KeepKeyFreshReceiving() + // check for replay - // check if using new key-pair + if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) { + continue + } - kp := &peer.keyPairs - kp.mutex.Lock() - if kp.next == elem.keyPair { - peer.TimerHandshakeComplete() - kp.previous = kp.current - kp.current = kp.next - kp.next = nil - } - kp.mutex.Unlock() + peer.TimerAnyAuthenticatedPacketTraversal() + peer.TimerAnyAuthenticatedPacketReceived() + peer.KeepKeyFreshReceiving() - // check for keep-alive + // check if using new key-pair - if len(elem.packet) == 0 { - logDebug.Println("Received keep-alive from", peer.String()) - return - } - peer.TimerDataReceived() + kp := &peer.keyPairs + kp.mutex.Lock() + if kp.next == elem.keyPair { + peer.TimerHandshakeComplete() + kp.previous = kp.current + kp.current = kp.next + kp.next = nil + } + kp.mutex.Unlock() - // verify source and strip padding + // check for keep-alive - switch elem.packet[0] >> 4 { - case ipv4.Version: + if len(elem.packet) == 0 { + logDebug.Println("Received keep-alive from", peer.String()) + continue + } + peer.TimerDataReceived() - // strip padding + // verify source and strip padding - if len(elem.packet) < ipv4.HeaderLen { - return - } + switch elem.packet[0] >> 4 { + case ipv4.Version: - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - // TODO: check length of packet & NOT TOO SMALL either - elem.packet = elem.packet[:length] + // strip padding - // verify IPv4 source + if len(elem.packet) < ipv4.HeaderLen { + continue + } - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.routingTable.LookupIPv4(src) != peer { - logInfo.Println("Packet with unallowed source IP from", peer.String()) - return - } + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } - case ipv6.Version: + elem.packet = elem.packet[:length] - // strip padding + // verify IPv4 source - if len(elem.packet) < ipv6.HeaderLen { - return - } + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.routingTable.LookupIPv4(src) != peer { + logInfo.Println("Packet with unallowed source IP from", peer.String()) + continue + } - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - // TODO: check length of packet - elem.packet = elem.packet[:length] + case ipv6.Version: - // verify IPv6 source + // strip padding - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.routingTable.LookupIPv6(src) != peer { - logInfo.Println("Packet with unallowed source IP from", peer.String()) - return - } + if len(elem.packet) < ipv6.HeaderLen { + continue + } - default: - logInfo.Println("Packet with invalid IP version from", peer.String()) - return + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue } - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) - device.addToInboundQueue(device.queue.inbound, elem) + elem.packet = elem.packet[:length] - // TODO: move TUN write into per peer routine - }() - } -} + // verify IPv6 source -func (device *Device) RoutineWriteToTUN() { + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.routingTable.LookupIPv6(src) != peer { + logInfo.Println("Packet with unallowed source IP from", peer.String()) + continue + } - logError := device.log.Error - logDebug := device.log.Debug - logDebug.Println("Routine, sequential tun writer, started") + default: + logInfo.Println("Packet with invalid IP version from", peer.String()) + continue + } - for { - select { - case <-device.signal.stop: - return - case elem := <-device.queue.inbound: - _, err := device.tun.Write(elem.packet) - device.PutMessageBuffer(elem.buffer) - if err != nil { - logError.Println("Failed to write packet to TUN device:", err) - } + // write to tun + + atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + _, err := device.tun.Write(elem.packet) + device.PutMessageBuffer(elem.buffer) + if err != nil { + logError.Println("Failed to write packet to TUN device:", err) } } } |