diff options
-rw-r--r-- | src/config.go | 5 | ||||
-rw-r--r-- | src/device.go | 11 | ||||
-rw-r--r-- | src/peer.go | 4 | ||||
-rw-r--r-- | src/receive.go | 137 | ||||
-rw-r--r-- | src/send.go | 2 |
5 files changed, 115 insertions, 44 deletions
diff --git a/src/config.go b/src/config.go index 8281581..4edaa2e 100644 --- a/src/config.go +++ b/src/config.go @@ -61,8 +61,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { if peer.endpoint != nil { send("endpoint=" + peer.endpoint.String()) } - send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes)) - send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes)) + send(fmt.Sprintf("tx_bytes=%d", peer.txBytes)) + send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes)) send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) for _, ip := range device.routingTable.AllowedIPs(peer) { send("allowed_ip=" + ip.String()) @@ -73,7 +73,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { // send lines for _, line := range lines { - device.log.Debug.Println("Response:", line) _, err := socket.WriteString(line + "\n") if err != nil { return err diff --git a/src/device.go b/src/device.go index 882d587..0564068 100644 --- a/src/device.go +++ b/src/device.go @@ -31,10 +31,16 @@ type Device struct { signal struct { stop chan struct{} } - peers map[NoisePublicKey]*Peer - mac MACStateDevice + congestionState int32 // used as an atomic bool + peers map[NoisePublicKey]*Peer + mac MACStateDevice } +const ( + CongestionStateUnderLoad = iota + CongestionStateOkay +) + func (device *Device) SetPrivateKey(sk NoisePrivateKey) { device.mutex.Lock() defer device.mutex.Unlock() @@ -93,6 +99,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { go device.RoutineDecryption() go device.RoutineHandshake() } + go device.RoutineBusyMonitor() go device.RoutineReadFromTUN(tun) go device.RoutineReceiveIncomming() go device.RoutineWriteToTUN(tun) diff --git a/src/peer.go b/src/peer.go index e3c8060..fadc43f 100644 --- a/src/peer.go +++ b/src/peer.go @@ -17,8 +17,8 @@ type Peer struct { keyPairs KeyPairs handshake Handshake device *Device - tx_bytes uint64 - rx_bytes uint64 + txBytes uint64 + rxBytes uint64 time struct { lastSend time.Time // last send message lastHandshake time.Time // last completed handshake diff --git a/src/receive.go b/src/receive.go index 7b16dc5..c788dcf 100644 --- a/src/receive.go +++ b/src/receive.go @@ -72,12 +72,48 @@ func addToHandshakeQueue( } } -func (device *Device) RoutineReceiveIncomming() { +/* Routine determining the busy state of the interface + * + * TODO: prehaps nicer to do this in response to events + * TODO: more well reasoned definition of "busy" + */ +func (device *Device) RoutineBusyMonitor() { + samples := 0 + interval := time.Second + for timer := time.NewTimer(interval); ; { + + select { + case <-device.signal.stop: + return + case <-timer.C: + } + + // compute busy heuristic + + if len(device.queue.handshake) > QueueHandshakeBusySize { + samples += 1 + } else if samples > 0 { + samples -= 1 + } + samples %= 30 + busy := samples > 5 + + // update busy state + + if busy { + atomic.StoreInt32(&device.congestionState, CongestionStateUnderLoad) + } else { + atomic.StoreInt32(&device.congestionState, CongestionStateOkay) + } + + timer.Reset(interval) + } +} - debugLog := device.log.Debug - debugLog.Println("Routine, receive incomming, started") +func (device *Device) RoutineReceiveIncomming() { - errorLog := device.log.Error + logDebug := device.log.Debug + logDebug.Println("Routine, receive incomming, started") var buffer []byte @@ -122,33 +158,6 @@ func (device *Device) RoutineReceiveIncomming() { case MessageInitiationType, MessageResponseType: - // verify mac1 - - if !device.mac.CheckMAC1(packet) { - debugLog.Println("Received packet with invalid mac1") - return - } - - // check if busy, TODO: refine definition of "busy" - - busy := len(device.queue.handshake) > QueueHandshakeBusySize - if busy && !device.mac.CheckMAC2(packet, raddr) { - sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" always follows "type" - reply, err := device.CreateMessageCookieReply(packet, sender, raddr) - if err != nil { - errorLog.Println("Failed to create cookie reply:", err) - return - } - writer := bytes.NewBuffer(packet[:0]) - binary.Write(writer, binary.LittleEndian, reply) - packet = writer.Bytes() - _, err = device.net.conn.WriteToUDP(packet, raddr) - if err != nil { - debugLog.Println("Failed to send cookie reply:", err) - } - return - } - // add to handshake queue addToHandshakeQueue( @@ -173,7 +182,7 @@ func (device *Device) RoutineReceiveIncomming() { reader := bytes.NewReader(packet) err := binary.Read(reader, binary.LittleEndian, &reply) if err != nil { - debugLog.Println("Failed to decode cookie reply") + logDebug.Println("Failed to decode cookie reply") return } device.ConsumeMessageCookieReply(&reply) @@ -218,7 +227,7 @@ func (device *Device) RoutineReceiveIncomming() { default: // unknown message type - debugLog.Println("Got unknown message from:", raddr) + logDebug.Println("Got unknown message from:", raddr) } }() } @@ -285,6 +294,38 @@ func (device *Device) RoutineHandshake() { func() { + // verify mac1 + + if !device.mac.CheckMAC1(elem.packet) { + logDebug.Println("Received packet with invalid mac1") + return + } + + // verify mac2 + + busy := atomic.LoadInt32(&device.congestionState) == CongestionStateUnderLoad + + if busy && !device.mac.CheckMAC2(elem.packet, elem.source) { + sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type" + reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source) + if err != nil { + logError.Println("Failed to create cookie reply:", err) + return + } + writer := bytes.NewBuffer(elem.packet[:0]) + binary.Write(writer, binary.LittleEndian, reply) + elem.packet = writer.Bytes() + _, err = device.net.conn.WriteToUDP(elem.packet, elem.source) + if err != nil { + logDebug.Println("Failed to send cookie reply:", err) + } + return + } + + // ratelimit + + // handle messages + switch elem.msgType { case MessageInitiationType: @@ -321,12 +362,12 @@ func (device *Device) RoutineHandshake() { logError.Println("Failed to create response message:", err) return } + outElem := device.NewOutboundElement() writer := bytes.NewBuffer(outElem.data[:0]) binary.Write(writer, binary.LittleEndian, response) elem.packet = writer.Bytes() peer.mac.AddMacs(elem.packet) - device.log.Debug.Println(elem.packet) addToOutboundQueue(peer.queue.outbound, outElem) case MessageResponseType: @@ -388,7 +429,7 @@ func (peer *Peer) RoutineSequentialReceiver() { } elem.mutex.Lock() - // process IP packet + // process packet func() { if elem.IsDropped() { @@ -407,30 +448,54 @@ func (peer *Peer) RoutineSequentialReceiver() { return } - // strip padding + // verify source and strip padding switch elem.packet[0] >> 4 { case IPv4version: + + // strip padding + if len(elem.packet) < IPv4headerSize { return } + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] length := binary.BigEndian.Uint16(field) elem.packet = elem.packet[:length] + // verify IPv4 source + + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + if device.routingTable.LookupIPv4(dst) != peer { + return + } + case IPv6version: + + // strip padding + if len(elem.packet) < IPv6headerSize { return } + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] length := binary.BigEndian.Uint16(field) length += IPv6headerSize elem.packet = elem.packet[:length] + // verify IPv6 source + + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + if device.routingTable.LookupIPv6(dst) != peer { + return + } + default: device.log.Debug.Println("Receieved packet with unknown IP version") return } + + atomic.AddUint64(&peer.rxBytes, uint64(len(elem.packet))) addToInboundQueue(device.queue.inbound, elem) }() } diff --git a/src/send.go b/src/send.go index d1de44a..a02f5cb 100644 --- a/src/send.go +++ b/src/send.go @@ -329,7 +329,7 @@ func (peer *Peer) RoutineSequentialSender() { if err != nil { return } - atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet))) + atomic.AddUint64(&peer.txBytes, uint64(len(work.packet))) // shift keep-alive timer |