diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-01-29 14:54:11 +0100 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-01-29 16:21:53 +0100 |
commit | 9263014ed3f0a97800c893cb7346cc5109fc9e27 (patch) | |
tree | 2962b08013b744110780769613027cb1baac9baa /device/receive.go | |
parent | f0f27d7fd242587ccb966c6d2e074dafe5ab7349 (diff) |
device: simplify peer queue locking
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'device/receive.go')
-rw-r--r-- | device/receive.go | 84 |
1 files changed, 18 insertions, 66 deletions
diff --git a/device/receive.go b/device/receive.go index d513a21..abaf5af 100644 --- a/device/receive.go +++ b/device/receive.go @@ -174,7 +174,6 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { elem.Lock() // add to decryption queues - peer.queue.RLock() if peer.isRunning.Get() { peer.queue.inbound <- elem @@ -433,52 +432,25 @@ func (device *Device) RoutineHandshake() { func (peer *Peer) RoutineSequentialReceiver() { device := peer.device - var elem *QueueInboundElement - defer func() { device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) - peer.routines.stopping.Done() - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutInboundElement(elem) - } + peer.stopping.Done() }() - device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - for { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutInboundElement(elem) - elem = nil - } - - var elemOk bool - select { - case <-peer.routines.stop: - return - case elem, elemOk = <-peer.queue.inbound: - if !elemOk { - return - } - } - - // wait for decryption + for elem := range peer.queue.inbound { + var err error elem.Lock() if elem.packet == nil { // decryption failed - continue + goto skip } - // check for replay if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - continue + goto skip } - // update endpoint peer.SetEndpointFromPacket(elem.endpoint) - - // check if using new keypair if peer.ReceivedWithKeypair(elem.keypair) { peer.timersHandshakeComplete() peer.SendStagedPackets() @@ -489,83 +461,63 @@ func (peer *Peer) RoutineSequentialReceiver() { peer.timersAnyAuthenticatedPacketReceived() atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) - // check for keepalive - if len(elem.packet) == 0 { device.log.Verbosef("%v - Receiving keepalive packet", peer) - continue + goto skip } peer.timersDataReceived() - // verify source and strip padding - switch elem.packet[0] >> 4 { case ipv4.Version: - - // strip padding - if len(elem.packet) < ipv4.HeaderLen { - continue + goto skip } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] length := binary.BigEndian.Uint16(field) if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { - continue + goto skip } - elem.packet = elem.packet[:length] - - // verify IPv4 source - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] if device.allowedips.LookupIPv4(src) != peer { device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) - continue + goto skip } case ipv6.Version: - - // strip padding - if len(elem.packet) < ipv6.HeaderLen { - continue + goto skip } - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] length := binary.BigEndian.Uint16(field) length += ipv6.HeaderLen if int(length) > len(elem.packet) { - continue + goto skip } - elem.packet = elem.packet[:length] - - // verify IPv6 source - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] if device.allowedips.LookupIPv6(src) != peer { device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) - continue + goto skip } default: device.log.Verbosef("Packet with invalid IP version from %v", peer) - continue + goto skip } - // write to tun device - - offset := MessageTransportOffsetContent - _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) + _, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent) if err != nil && !device.isClosed.Get() { device.log.Errorf("Failed to write packet to TUN device: %v", err) } if len(peer.queue.inbound) == 0 { - err := device.tun.device.Flush() + err = device.tun.device.Flush() if err != nil { peer.device.log.Errorf("Unable to flush packets: %v", err) } } + skip: + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) } } |