diff options
Diffstat (limited to 'device/receive.go')
-rw-r--r-- | device/receive.go | 641 |
1 files changed, 641 insertions, 0 deletions
diff --git a/device/receive.go b/device/receive.go new file mode 100644 index 0000000..5c837c1 --- /dev/null +++ b/device/receive.go @@ -0,0 +1,641 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "bytes" + "encoding/binary" + "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "net" + "strconv" + "sync" + "sync/atomic" + "time" +) + +type QueueHandshakeElement struct { + msgType uint32 + packet []byte + endpoint Endpoint + buffer *[MaxMessageSize]byte +} + +type QueueInboundElement struct { + dropped int32 + sync.Mutex + buffer *[MaxMessageSize]byte + packet []byte + counter uint64 + keypair *Keypair + endpoint Endpoint +} + +func (elem *QueueInboundElement) Drop() { + atomic.StoreInt32(&elem.dropped, AtomicTrue) +} + +func (elem *QueueInboundElement) IsDropped() bool { + return atomic.LoadInt32(&elem.dropped) == AtomicTrue +} + +func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool { + select { + case inboundQueue <- element: + select { + case decryptionQueue <- element: + return true + default: + element.Drop() + element.Unlock() + return false + } + default: + device.PutInboundElement(element) + return false + } +} + +func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool { + select { + case queue <- element: + return true + default: + return false + } +} + +/* Called when a new authenticated message has been received + * + * NOTE: Not thread safe, but called by sequential receiver! + */ +func (peer *Peer) keepKeyFreshReceiving() { + if peer.timers.sentLastMinuteHandshake.Get() { + return + } + keypair := peer.keypairs.Current() + if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { + peer.timers.sentLastMinuteHandshake.Set(true) + peer.SendHandshakeInitiation(false) + } +} + +/* Receives incoming datagrams for the device + * + * Every time the bind is updated a new routine is started for + * IPv4 and IPv6 (separately) + */ +func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { + + logDebug := device.log.Debug + defer func() { + logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + device.net.stopping.Done() + }() + + logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started") + device.net.starting.Done() + + // receive datagrams until conn is closed + + buffer := device.GetMessageBuffer() + + var ( + err error + size int + endpoint Endpoint + ) + + for { + + // read next datagram + + switch IP { + case ipv4.Version: + size, endpoint, err = bind.ReceiveIPv4(buffer[:]) + case ipv6.Version: + size, endpoint, err = bind.ReceiveIPv6(buffer[:]) + default: + panic("invalid IP version") + } + + if err != nil { + device.PutMessageBuffer(buffer) + return + } + + if size < MinMessageSize { + continue + } + + // check size of packet + + packet := buffer[:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) + + var okay bool + + switch msgType { + + // check if transport + + case MessageTransportType: + + // check size + + if len(packet) < MessageTransportSize { + continue + } + + // lookup key pair + + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { + continue + } + + // check keypair expiry + + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } + + // create work element + peer := value.peer + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = buffer + elem.keypair = keypair + elem.dropped = AtomicFalse + elem.endpoint = endpoint + elem.counter = 0 + elem.Mutex = sync.Mutex{} + elem.Lock() + + // add to decryption queues + + if peer.isRunning.Get() { + if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { + buffer = device.GetMessageBuffer() + } + } + + continue + + // otherwise it is a fixed size & handshake related packet + + case MessageInitiationType: + okay = len(packet) == MessageInitiationSize + + case MessageResponseType: + okay = len(packet) == MessageResponseSize + + case MessageCookieReplyType: + okay = len(packet) == MessageCookieReplySize + + default: + logDebug.Println("Received message with unknown type") + } + + if okay { + if (device.addToHandshakeQueue( + device.queue.handshake, + QueueHandshakeElement{ + msgType: msgType, + buffer: buffer, + packet: packet, + endpoint: endpoint, + }, + )) { + buffer = device.GetMessageBuffer() + } + } + } +} + +func (device *Device) RoutineDecryption() { + + var nonce [chacha20poly1305.NonceSize]byte + + logDebug := device.log.Debug + defer func() { + logDebug.Println("Routine: decryption worker - stopped") + device.state.stopping.Done() + }() + logDebug.Println("Routine: decryption worker - started") + device.state.starting.Done() + + for { + select { + case <-device.signals.stop: + return + + case elem, ok := <-device.queue.decryption: + + if !ok { + return + } + + // check if dropped + + if elem.IsDropped() { + continue + } + + // split message into fields + + counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] + content := elem.packet[MessageTransportOffsetContent:] + + // expand nonce + + nonce[0x4] = counter[0x0] + nonce[0x5] = counter[0x1] + nonce[0x6] = counter[0x2] + nonce[0x7] = counter[0x3] + + nonce[0x8] = counter[0x4] + nonce[0x9] = counter[0x5] + nonce[0xa] = counter[0x6] + nonce[0xb] = counter[0x7] + + // decrypt and release to consumer + + var err error + elem.counter = binary.LittleEndian.Uint64(counter) + elem.packet, err = elem.keypair.receive.Open( + content[:0], + nonce[:], + content, + nil, + ) + if err != nil { + elem.Drop() + device.PutMessageBuffer(elem.buffer) + } + elem.Unlock() + } + } +} + +/* Handles incoming packets related to handshake + */ +func (device *Device) RoutineHandshake() { + + logInfo := device.log.Info + logError := device.log.Error + logDebug := device.log.Debug + + var elem QueueHandshakeElement + var ok bool + + defer func() { + logDebug.Println("Routine: handshake worker - stopped") + device.state.stopping.Done() + if elem.buffer != nil { + device.PutMessageBuffer(elem.buffer) + } + }() + + logDebug.Println("Routine: handshake worker - started") + device.state.starting.Done() + + for { + if elem.buffer != nil { + device.PutMessageBuffer(elem.buffer) + elem.buffer = nil + } + + select { + case elem, ok = <-device.queue.handshake: + case <-device.signals.stop: + return + } + + if !ok { + return + } + + // handle cookie fields and ratelimiting + + switch elem.msgType { + + case MessageCookieReplyType: + + // unmarshal packet + + var reply MessageCookieReply + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &reply) + if err != nil { + logDebug.Println("Failed to decode cookie reply") + return + } + + // lookup peer from index + + entry := device.indexTable.Lookup(reply.Receiver) + + if entry.peer == nil { + continue + } + + // consume reply + + if peer := entry.peer; peer.isRunning.Get() { + logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString()) + if !peer.cookieGenerator.ConsumeReply(&reply) { + logDebug.Println("Could not decrypt invalid cookie response") + } + } + + continue + + case MessageInitiationType, MessageResponseType: + + // check mac fields and maybe ratelimit + + if !device.cookieChecker.CheckMAC1(elem.packet) { + logDebug.Println("Received packet with invalid mac1") + continue + } + + // endpoints destination address is the source of the datagram + + if device.IsUnderLoad() { + + // verify MAC2 field + + if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { + device.SendHandshakeCookie(&elem) + continue + } + + // check ratelimiter + + if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { + continue + } + } + + default: + logError.Println("Invalid packet ended up in the handshake queue") + continue + } + + // handle handshake initiation/response content + + switch elem.msgType { + case MessageInitiationType: + + // unmarshal + + var msg MessageInitiation + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode initiation message") + continue + } + + // consume initiation + + peer := device.ConsumeMessageInitiation(&msg) + if peer == nil { + logInfo.Println( + "Received invalid initiation message from", + elem.endpoint.DstToString(), + ) + continue + } + + // update timers + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) + + logDebug.Println(peer, "- Received handshake initiation") + + peer.SendHandshakeResponse() + + case MessageResponseType: + + // unmarshal + + var msg MessageResponse + reader := bytes.NewReader(elem.packet) + err := binary.Read(reader, binary.LittleEndian, &msg) + if err != nil { + logError.Println("Failed to decode response message") + continue + } + + // consume response + + peer := device.ConsumeMessageResponse(&msg) + if peer == nil { + logInfo.Println( + "Received invalid response message from", + elem.endpoint.DstToString(), + ) + continue + } + + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) + + logDebug.Println(peer, "- Received handshake response") + + // update timers + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + + // derive keypair + + err = peer.BeginSymmetricSession() + + if err != nil { + logError.Println(peer, "- Failed to derive keypair:", err) + continue + } + + peer.timersSessionDerived() + peer.timersHandshakeComplete() + peer.SendKeepalive() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } + } + } +} + +func (peer *Peer) RoutineSequentialReceiver() { + + device := peer.device + logInfo := device.log.Info + logError := device.log.Error + logDebug := device.log.Debug + + var elem *QueueInboundElement + var ok bool + + defer func() { + logDebug.Println(peer, "- Routine: sequential receiver - stopped") + peer.routines.stopping.Done() + if elem != nil { + if !elem.IsDropped() { + device.PutMessageBuffer(elem.buffer) + } + device.PutInboundElement(elem) + } + }() + + logDebug.Println(peer, "- Routine: sequential receiver - started") + + peer.routines.starting.Done() + + for { + if elem != nil { + if !elem.IsDropped() { + device.PutMessageBuffer(elem.buffer) + } + device.PutInboundElement(elem) + elem = nil + } + + select { + + case <-peer.routines.stop: + return + + case elem, ok = <-peer.queue.inbound: + + if !ok { + return + } + + // wait for decryption + + elem.Lock() + + if elem.IsDropped() { + continue + } + + // check for replay + + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { + continue + } + + // update endpoint + peer.SetEndpointFromPacket(elem.endpoint) + + // check if using new keypair + if peer.ReceivedWithKeypair(elem.keypair) { + peer.timersHandshakeComplete() + select { + case peer.signals.newKeypairArrived <- struct{}{}: + default: + } + } + + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + + // check for keepalive + + if len(elem.packet) == 0 { + logDebug.Println(peer, "- Receiving keepalive packet") + continue + } + peer.timersDataReceived() + + // verify source and strip padding + + switch elem.packet[0] >> 4 { + case ipv4.Version: + + // strip padding + + if len(elem.packet) < ipv4.HeaderLen { + continue + } + + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv4 source + + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.LookupIPv4(src) != peer { + logInfo.Println( + "IPv4 packet with disallowed source address from", + peer, + ) + continue + } + + case ipv6.Version: + + // strip padding + + if len(elem.packet) < ipv6.HeaderLen { + continue + } + + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + + elem.packet = elem.packet[:length] + + // verify IPv6 source + + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.LookupIPv6(src) != peer { + logInfo.Println( + peer, + "sent packet with disallowed IPv6 source", + ) + continue + } + + default: + logInfo.Println("Packet with invalid IP version from", peer) + continue + } + + // write to tun device + + offset := MessageTransportOffsetContent + atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) + if err != nil { + logError.Println("Failed to write packet to TUN device:", err) + } + } + } +} |