diff options
Diffstat (limited to 'device/receive.go')
-rw-r--r-- | device/receive.go | 320 |
1 files changed, 178 insertions, 142 deletions
diff --git a/device/receive.go b/device/receive.go index 03fcf00..aee7864 100644 --- a/device/receive.go +++ b/device/receive.go @@ -66,7 +66,7 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { +func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { recvName := recv.PrettyName() defer func() { device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) @@ -79,20 +79,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { // receive datagrams until conn is closed - buffer := device.GetMessageBuffer() - var ( + buffsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) + buffs = make([][]byte, maxBatchSize) err error - size int - endpoint conn.Endpoint + sizes = make([]int, maxBatchSize) + count int + endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int + elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize) ) - for { - size, endpoint, err = recv(buffer[:]) + for i := range buffsArrs { + buffsArrs[i] = device.GetMessageBuffer() + buffs[i] = buffsArrs[i][:] + } + + defer func() { + for i := 0; i < maxBatchSize; i++ { + if buffsArrs[i] != nil { + device.PutMessageBuffer(buffsArrs[i]) + } + } + }() + for { + count, err = recv(buffs, sizes, endpoints) if err != nil { - device.PutMessageBuffer(buffer) if errors.Is(err, net.ErrClosed) { return } @@ -103,101 +116,122 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { if deathSpiral < 10 { deathSpiral++ time.Sleep(time.Second / 3) - buffer = device.GetMessageBuffer() continue } return } deathSpiral = 0 - if size < MinMessageSize { - continue - } + // handle each packet in the batch + for i, size := range sizes[:count] { + if size < MinMessageSize { + continue + } - // check size of packet + // check size of packet - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) + packet := buffsArrs[i][:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - var okay bool + switch msgType { - switch msgType { + // check if transport - // check if transport + case MessageTransportType: - case MessageTransportType: + // check size - // check size + if len(packet) < MessageTransportSize { + continue + } - if len(packet) < MessageTransportSize { - continue - } + // lookup key pair - // lookup key pair + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { + continue + } - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indexTable.Lookup(receiver) - keypair := value.keypair - if keypair == nil { - continue - } + // check keypair expiry - // check keypair expiry + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } - if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + // create work element + peer := value.peer + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = buffsArrs[i] + elem.keypair = keypair + elem.endpoint = endpoints[i] + elem.counter = 0 + elem.Mutex = sync.Mutex{} + elem.Lock() + + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetInboundElementsSlice() + elemsByPeer[peer] = elemsForPeer + } + *elemsForPeer = append(*elemsForPeer, elem) + buffsArrs[i] = device.GetMessageBuffer() + buffs[i] = buffsArrs[i][:] continue - } - - // create work element - peer := value.peer - elem := device.GetInboundElement() - elem.packet = packet - elem.buffer = buffer - elem.keypair = keypair - elem.endpoint = endpoint - elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() - // add to decryption queues - if peer.isRunning.Load() { - peer.queue.inbound.c <- elem - device.queue.decryption.c <- elem - buffer = device.GetMessageBuffer() - } else { - device.PutInboundElement(elem) - } - continue + // otherwise it is a fixed size & handshake related packet - // otherwise it is a fixed size & handshake related packet - - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize + case MessageInitiationType: + if len(packet) != MessageInitiationSize { + continue + } - case MessageResponseType: - okay = len(packet) == MessageResponseSize + case MessageResponseType: + if len(packet) != MessageResponseSize { + continue + } - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize + case MessageCookieReplyType: + if len(packet) != MessageCookieReplySize { + continue + } - default: - device.log.Verbosef("Received message with unknown type") - } + default: + device.log.Verbosef("Received message with unknown type") + continue + } - if okay { select { case device.queue.handshake.c <- QueueHandshakeElement{ msgType: msgType, - buffer: buffer, + buffer: buffsArrs[i], packet: packet, - endpoint: endpoint, + endpoint: endpoints[i], }: - buffer = device.GetMessageBuffer() + buffsArrs[i] = device.GetMessageBuffer() + buffs[i] = buffsArrs[i][:] default: } } + for peer, elems := range elemsByPeer { + if peer.isRunning.Load() { + peer.queue.inbound.c <- elems + for _, elem := range *elems { + device.queue.decryption.c <- elem + } + } else { + for _, elem := range *elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsSlice(elems) + } + delete(elemsByPeer, peer) + } } } @@ -393,7 +427,7 @@ func (device *Device) RoutineHandshake(id int) { } } -func (peer *Peer) RoutineSequentialReceiver() { +func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device := peer.device defer func() { device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) @@ -401,89 +435,91 @@ func (peer *Peer) RoutineSequentialReceiver() { }() device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - for elem := range peer.queue.inbound.c { - if elem == nil { + buffs := make([][]byte, 0, maxBatchSize) + + for elems := range peer.queue.inbound.c { + if elems == nil { return } - var err error - elem.Lock() - if elem.packet == nil { - // decryption failed - goto skip - } + for _, elem := range *elems { + elem.Lock() + if elem.packet == nil { + // decryption failed + continue + } - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - goto skip - } + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { + continue + } - peer.SetEndpointFromPacket(elem.endpoint) - if peer.ReceivedWithKeypair(elem.keypair) { - peer.timersHandshakeComplete() - peer.SendStagedPackets() - } + peer.SetEndpointFromPacket(elem.endpoint) + if peer.ReceivedWithKeypair(elem.keypair) { + peer.timersHandshakeComplete() + peer.SendStagedPackets() + } + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize)) + if len(elem.packet) == 0 { + device.log.Verbosef("%v - Receiving keepalive packet", peer) + continue + } + peer.timersDataReceived() - if len(elem.packet) == 0 { - device.log.Verbosef("%v - Receiving keepalive packet", peer) - goto skip - } - peer.timersDataReceived() + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) + continue + } - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - goto skip - } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - goto skip - } - elem.packet = elem.packet[:length] - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.Lookup(src) != peer { - device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) - goto skip - } + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) + continue + } - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { - goto skip - } - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - goto skip - } - elem.packet = elem.packet[:length] - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.Lookup(src) != peer { - device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) - goto skip + default: + device.log.Verbosef("Packet with invalid IP version from %v", peer) + continue } - default: - device.log.Verbosef("Packet with invalid IP version from %v", peer) - goto skip + buffs = append(buffs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } - - _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) - if err != nil && !device.isClosed() { - device.log.Errorf("Failed to write packet to TUN device: %v", err) - } - if len(peer.queue.inbound.c) == 0 { - err = device.tun.device.Flush() - if err != nil { - peer.device.log.Errorf("Unable to flush packets: %v", err) + if len(buffs) > 0 { + _, err := device.tun.device.Write(buffs, MessageTransportOffsetContent) + if err != nil && !device.isClosed() { + device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - skip: - device.PutMessageBuffer(elem.buffer) - device.PutInboundElement(elem) + for _, elem := range *elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + buffs = buffs[:0] + device.PutInboundElementsSlice(elems) } } |