summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/config.go5
-rw-r--r--src/device.go11
-rw-r--r--src/peer.go4
-rw-r--r--src/receive.go137
-rw-r--r--src/send.go2
5 files changed, 115 insertions, 44 deletions
diff --git a/src/config.go b/src/config.go
index 8281581..4edaa2e 100644
--- a/src/config.go
+++ b/src/config.go
@@ -61,8 +61,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.String())
}
- send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes))
- send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes))
+ send(fmt.Sprintf("tx_bytes=%d", peer.txBytes))
+ send(fmt.Sprintf("rx_bytes=%d", peer.rxBytes))
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
for _, ip := range device.routingTable.AllowedIPs(peer) {
send("allowed_ip=" + ip.String())
@@ -73,7 +73,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
// send lines
for _, line := range lines {
- device.log.Debug.Println("Response:", line)
_, err := socket.WriteString(line + "\n")
if err != nil {
return err
diff --git a/src/device.go b/src/device.go
index 882d587..0564068 100644
--- a/src/device.go
+++ b/src/device.go
@@ -31,10 +31,16 @@ type Device struct {
signal struct {
stop chan struct{}
}
- peers map[NoisePublicKey]*Peer
- mac MACStateDevice
+ congestionState int32 // used as an atomic bool
+ peers map[NoisePublicKey]*Peer
+ mac MACStateDevice
}
+const (
+ CongestionStateUnderLoad = iota
+ CongestionStateOkay
+)
+
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
device.mutex.Lock()
defer device.mutex.Unlock()
@@ -93,6 +99,7 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
go device.RoutineDecryption()
go device.RoutineHandshake()
}
+ go device.RoutineBusyMonitor()
go device.RoutineReadFromTUN(tun)
go device.RoutineReceiveIncomming()
go device.RoutineWriteToTUN(tun)
diff --git a/src/peer.go b/src/peer.go
index e3c8060..fadc43f 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -17,8 +17,8 @@ type Peer struct {
keyPairs KeyPairs
handshake Handshake
device *Device
- tx_bytes uint64
- rx_bytes uint64
+ txBytes uint64
+ rxBytes uint64
time struct {
lastSend time.Time // last send message
lastHandshake time.Time // last completed handshake
diff --git a/src/receive.go b/src/receive.go
index 7b16dc5..c788dcf 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -72,12 +72,48 @@ func addToHandshakeQueue(
}
}
-func (device *Device) RoutineReceiveIncomming() {
+/* Routine determining the busy state of the interface
+ *
+ * TODO: prehaps nicer to do this in response to events
+ * TODO: more well reasoned definition of "busy"
+ */
+func (device *Device) RoutineBusyMonitor() {
+ samples := 0
+ interval := time.Second
+ for timer := time.NewTimer(interval); ; {
+
+ select {
+ case <-device.signal.stop:
+ return
+ case <-timer.C:
+ }
+
+ // compute busy heuristic
+
+ if len(device.queue.handshake) > QueueHandshakeBusySize {
+ samples += 1
+ } else if samples > 0 {
+ samples -= 1
+ }
+ samples %= 30
+ busy := samples > 5
+
+ // update busy state
+
+ if busy {
+ atomic.StoreInt32(&device.congestionState, CongestionStateUnderLoad)
+ } else {
+ atomic.StoreInt32(&device.congestionState, CongestionStateOkay)
+ }
+
+ timer.Reset(interval)
+ }
+}
- debugLog := device.log.Debug
- debugLog.Println("Routine, receive incomming, started")
+func (device *Device) RoutineReceiveIncomming() {
- errorLog := device.log.Error
+ logDebug := device.log.Debug
+ logDebug.Println("Routine, receive incomming, started")
var buffer []byte
@@ -122,33 +158,6 @@ func (device *Device) RoutineReceiveIncomming() {
case MessageInitiationType, MessageResponseType:
- // verify mac1
-
- if !device.mac.CheckMAC1(packet) {
- debugLog.Println("Received packet with invalid mac1")
- return
- }
-
- // check if busy, TODO: refine definition of "busy"
-
- busy := len(device.queue.handshake) > QueueHandshakeBusySize
- if busy && !device.mac.CheckMAC2(packet, raddr) {
- sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" always follows "type"
- reply, err := device.CreateMessageCookieReply(packet, sender, raddr)
- if err != nil {
- errorLog.Println("Failed to create cookie reply:", err)
- return
- }
- writer := bytes.NewBuffer(packet[:0])
- binary.Write(writer, binary.LittleEndian, reply)
- packet = writer.Bytes()
- _, err = device.net.conn.WriteToUDP(packet, raddr)
- if err != nil {
- debugLog.Println("Failed to send cookie reply:", err)
- }
- return
- }
-
// add to handshake queue
addToHandshakeQueue(
@@ -173,7 +182,7 @@ func (device *Device) RoutineReceiveIncomming() {
reader := bytes.NewReader(packet)
err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil {
- debugLog.Println("Failed to decode cookie reply")
+ logDebug.Println("Failed to decode cookie reply")
return
}
device.ConsumeMessageCookieReply(&reply)
@@ -218,7 +227,7 @@ func (device *Device) RoutineReceiveIncomming() {
default:
// unknown message type
- debugLog.Println("Got unknown message from:", raddr)
+ logDebug.Println("Got unknown message from:", raddr)
}
}()
}
@@ -285,6 +294,38 @@ func (device *Device) RoutineHandshake() {
func() {
+ // verify mac1
+
+ if !device.mac.CheckMAC1(elem.packet) {
+ logDebug.Println("Received packet with invalid mac1")
+ return
+ }
+
+ // verify mac2
+
+ busy := atomic.LoadInt32(&device.congestionState) == CongestionStateUnderLoad
+
+ if busy && !device.mac.CheckMAC2(elem.packet, elem.source) {
+ sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
+ reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
+ if err != nil {
+ logError.Println("Failed to create cookie reply:", err)
+ return
+ }
+ writer := bytes.NewBuffer(elem.packet[:0])
+ binary.Write(writer, binary.LittleEndian, reply)
+ elem.packet = writer.Bytes()
+ _, err = device.net.conn.WriteToUDP(elem.packet, elem.source)
+ if err != nil {
+ logDebug.Println("Failed to send cookie reply:", err)
+ }
+ return
+ }
+
+ // ratelimit
+
+ // handle messages
+
switch elem.msgType {
case MessageInitiationType:
@@ -321,12 +362,12 @@ func (device *Device) RoutineHandshake() {
logError.Println("Failed to create response message:", err)
return
}
+
outElem := device.NewOutboundElement()
writer := bytes.NewBuffer(outElem.data[:0])
binary.Write(writer, binary.LittleEndian, response)
elem.packet = writer.Bytes()
peer.mac.AddMacs(elem.packet)
- device.log.Debug.Println(elem.packet)
addToOutboundQueue(peer.queue.outbound, outElem)
case MessageResponseType:
@@ -388,7 +429,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
}
elem.mutex.Lock()
- // process IP packet
+ // process packet
func() {
if elem.IsDropped() {
@@ -407,30 +448,54 @@ func (peer *Peer) RoutineSequentialReceiver() {
return
}
- // strip padding
+ // verify source and strip padding
switch elem.packet[0] >> 4 {
case IPv4version:
+
+ // strip padding
+
if len(elem.packet) < IPv4headerSize {
return
}
+
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
elem.packet = elem.packet[:length]
+ // verify IPv4 source
+
+ dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ if device.routingTable.LookupIPv4(dst) != peer {
+ return
+ }
+
case IPv6version:
+
+ // strip padding
+
if len(elem.packet) < IPv6headerSize {
return
}
+
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += IPv6headerSize
elem.packet = elem.packet[:length]
+ // verify IPv6 source
+
+ dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ if device.routingTable.LookupIPv6(dst) != peer {
+ return
+ }
+
default:
device.log.Debug.Println("Receieved packet with unknown IP version")
return
}
+
+ atomic.AddUint64(&peer.rxBytes, uint64(len(elem.packet)))
addToInboundQueue(device.queue.inbound, elem)
}()
}
diff --git a/src/send.go b/src/send.go
index d1de44a..a02f5cb 100644
--- a/src/send.go
+++ b/src/send.go
@@ -329,7 +329,7 @@ func (peer *Peer) RoutineSequentialSender() {
if err != nil {
return
}
- atomic.AddUint64(&peer.tx_bytes, uint64(len(work.packet)))
+ atomic.AddUint64(&peer.txBytes, uint64(len(work.packet)))
// shift keep-alive timer