diff options
Diffstat (limited to 'src/receive.go')
-rw-r--r-- | src/receive.go | 257 |
1 files changed, 137 insertions, 120 deletions
diff --git a/src/receive.go b/src/receive.go index 52c2718..27fdb8a 100644 --- a/src/receive.go +++ b/src/receive.go @@ -13,19 +13,20 @@ import ( ) type QueueHandshakeElement struct { - msgType uint32 - packet []byte - buffer *[MaxMessageSize]byte - source *net.UDPAddr + msgType uint32 + packet []byte + endpoint Endpoint + buffer *[MaxMessageSize]byte } type QueueInboundElement struct { - dropped int32 - mutex sync.Mutex - buffer *[MaxMessageSize]byte - packet []byte - counter uint64 - keyPair *KeyPair + dropped int32 + mutex sync.Mutex + buffer *[MaxMessageSize]byte + packet []byte + counter uint64 + keyPair *KeyPair + endpoint Endpoint } func (elem *QueueInboundElement) Drop() { @@ -92,130 +93,122 @@ func (device *Device) addToHandshakeQueue( } } -func (device *Device) RoutineReceiveIncomming() { +func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) { logDebug := device.log.Debug - logDebug.Println("Routine, receive incomming, started") + logDebug.Println("Routine, receive incomming, IP version:", IP) for { - // wait for new conn + // receive datagrams until conn is closed - logDebug.Println("Waiting for udp socket") + buffer := device.GetMessageBuffer() - select { - case <-device.signal.stop: - return + var ( + err error + size int + endpoint Endpoint + ) - case <-device.signal.newUDPConn: + for { - // fetch connection + // read next datagram - device.net.mutex.RLock() - conn := device.net.conn - device.net.mutex.RUnlock() - if conn == nil { + switch IP { + case ipv4.Version: + size, endpoint, err = bind.ReceiveIPv4(buffer[:]) + case ipv6.Version: + size, endpoint, err = bind.ReceiveIPv6(buffer[:]) + default: + return + } + + if err != nil { + break + } + + if size < MinMessageSize { continue } - logDebug.Println("Listening for inbound packets") + // check size of packet - // receive datagrams until conn is closed + packet := buffer[:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - buffer := device.GetMessageBuffer() + var okay bool - for { + switch msgType { - // read next datagram + // check if transport - size, raddr, err := conn.ReadFromUDP(buffer[:]) + case MessageTransportType: - if err != nil { - break - } + // check size - if size < MinMessageSize { + if len(packet) < MessageTransportType { continue } - // check size of packet - - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) - - var okay bool + // lookup key pair - switch msgType { - - // check if transport - - case MessageTransportType: - - // check size - - if len(packet) < MessageTransportType { - continue - } - - // lookup key pair - - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indices.Lookup(receiver) - keyPair := value.keyPair - if keyPair == nil { - continue - } + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indices.Lookup(receiver) + keyPair := value.keyPair + if keyPair == nil { + continue + } - // check key-pair expiry + // check key-pair expiry - if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { - continue - } + if keyPair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } - // create work element + // create work element - peer := value.peer - elem := &QueueInboundElement{ - packet: packet, - buffer: buffer, - keyPair: keyPair, - dropped: AtomicFalse, - } - elem.mutex.Lock() + peer := value.peer + elem := &QueueInboundElement{ + packet: packet, + buffer: buffer, + keyPair: keyPair, + dropped: AtomicFalse, + endpoint: endpoint, + } + elem.mutex.Lock() - // add to decryption queues + // add to decryption queues - device.addToDecryptionQueue(device.queue.decryption, elem) - device.addToInboundQueue(peer.queue.inbound, elem) - buffer = device.GetMessageBuffer() - continue + device.addToDecryptionQueue(device.queue.decryption, elem) + device.addToInboundQueue(peer.queue.inbound, elem) + buffer = device.GetMessageBuffer() + continue - // otherwise it is a handshake related packet + // otherwise it is a fixed size & handshake related packet - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize + case MessageInitiationType: + okay = len(packet) == MessageInitiationSize - case MessageResponseType: - okay = len(packet) == MessageResponseSize + case MessageResponseType: + okay = len(packet) == MessageResponseSize - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize - } + case MessageCookieReplyType: + okay = len(packet) == MessageCookieReplySize + } - if okay { - device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - source: raddr, - }, - ) - buffer = device.GetMessageBuffer() - } + if okay { + device.addToHandshakeQueue( + device.queue.handshake, + QueueHandshakeElement{ + msgType: msgType, + buffer: buffer, + packet: packet, + endpoint: endpoint, + }, + ) + buffer = device.GetMessageBuffer() } } } @@ -293,8 +286,6 @@ func (device *Device) RoutineHandshake() { // unmarshal packet - logDebug.Println("Process cookie reply from:", elem.source.String()) - var reply MessageCookieReply reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &reply) @@ -321,15 +312,25 @@ func (device *Device) RoutineHandshake() { return } + // endpoints destination address is the source of the datagram + + srcBytes := elem.endpoint.DstToBytes() + if device.IsUnderLoad() { - if !device.mac.CheckMAC2(elem.packet, elem.source) { + + // verify MAC2 field + + if !device.mac.CheckMAC2(elem.packet, srcBytes) { // construct cookie reply - logDebug.Println("Sending cookie reply to:", elem.source.String()) + logDebug.Println( + "Sending cookie reply to:", + elem.endpoint.DstToString(), + ) - sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" - reply, err := device.mac.CreateReply(elem.packet, sender, elem.source) + sender := binary.LittleEndian.Uint32(elem.packet[4:8]) + reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes) if err != nil { logError.Println("Failed to create cookie reply:", err) return @@ -339,17 +340,16 @@ func (device *Device) RoutineHandshake() { writer := bytes.NewBuffer(temp[:0]) binary.Write(writer, binary.LittleEndian, reply) - _, err = device.net.conn.WriteToUDP( - writer.Bytes(), - elem.source, - ) + device.net.bind.Send(writer.Bytes(), elem.endpoint) if err != nil { logDebug.Println("Failed to send cookie reply:", err) } continue } - if !device.ratelimiter.Allow(elem.source.IP) { + // check ratelimiter + + if !device.ratelimiter.Allow(elem.endpoint.DstIP()) { continue } } @@ -380,8 +380,7 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid initiation message from", - elem.source.IP.String(), - elem.source.Port, + elem.endpoint.DstToString(), ) continue } @@ -392,10 +391,9 @@ func (device *Device) RoutineHandshake() { peer.TimerAnyAuthenticatedPacketReceived() // update endpoint - // TODO: Discover destination address also, only update on change peer.mutex.Lock() - peer.endpoint = elem.source + peer.endpoint = elem.endpoint peer.mutex.Unlock() // create response @@ -418,9 +416,11 @@ func (device *Device) RoutineHandshake() { // send response - _, err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet) if err == nil { peer.TimerAnyAuthenticatedPacketTraversal() + } else { + logError.Println("Failed to send response to:", peer.String(), err) } case MessageResponseType: @@ -441,12 +441,17 @@ func (device *Device) RoutineHandshake() { if peer == nil { logInfo.Println( "Recieved invalid response message from", - elem.source.IP.String(), - elem.source.Port, + elem.endpoint.DstToString(), ) continue } + // update endpoint + + peer.mutex.Lock() + peer.endpoint = elem.endpoint + peer.mutex.Unlock() + logDebug.Println("Received handshake initation from", peer) peer.TimerEphemeralKeyCreated() @@ -515,6 +520,12 @@ func (peer *Peer) RoutineSequentialReceiver() { } kp.mutex.Unlock() + // update endpoint + + peer.mutex.Lock() + peer.endpoint = elem.endpoint + peer.mutex.Unlock() + // check for keep-alive if len(elem.packet) == 0 { @@ -546,7 +557,10 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] if device.routingTable.LookupIPv4(src) != peer { - logInfo.Println("Packet with unallowed source IP from", peer.String()) + logInfo.Println( + "IPv4 packet with unallowed source address from", + peer.String(), + ) continue } @@ -571,7 +585,10 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] if device.routingTable.LookupIPv6(src) != peer { - logInfo.Println("Packet with unallowed source IP from", peer.String()) + logInfo.Println( + "IPv6 packet with unallowed source address from", + peer.String(), + ) continue } @@ -580,7 +597,7 @@ func (peer *Peer) RoutineSequentialReceiver() { continue } - // write to tun + // write to tun device atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) _, err := device.tun.device.Write(elem.packet) |