summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/device.go15
-rw-r--r--src/main.go11
-rw-r--r--src/misc.go37
-rw-r--r--src/ratelimiter.go4
-rw-r--r--src/receive.go22
-rw-r--r--src/send.go23
-rw-r--r--src/signal.go10
-rw-r--r--src/timers.go2
-rw-r--r--src/trie.go12
9 files changed, 68 insertions, 68 deletions
diff --git a/src/device.go b/src/device.go
index a1ce802..a3461ad 100644
--- a/src/device.go
+++ b/src/device.go
@@ -37,7 +37,7 @@ type Device struct {
handshake chan QueueHandshakeElement
}
signal struct {
- stop chan struct{}
+ stop Signal
}
underLoadUntil atomic.Value
ratelimiter Ratelimiter
@@ -129,7 +129,6 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
func NewDevice(tun TUNDevice, logger *Logger) *Device {
device := new(Device)
-
device.mutex.Lock()
defer device.mutex.Unlock()
@@ -160,7 +159,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
// prepare signals
- device.signal.stop = make(chan struct{})
+ device.signal.stop = NewSignal()
// prepare net
@@ -174,9 +173,11 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
go device.RoutineDecryption()
go device.RoutineHandshake()
}
+
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
+
return device
}
@@ -210,11 +211,11 @@ func (device *Device) Close() {
}
device.log.Info.Println("Closing device")
device.RemoveAllPeers()
- close(device.signal.stop)
- closeBind(device)
+ device.signal.stop.Broadcast()
device.tun.device.Close()
+ closeBind(device)
}
-func (device *Device) WaitChannel() chan struct{} {
- return device.signal.stop
+func (device *Device) Wait() chan struct{} {
+ return device.signal.stop.Wait()
}
diff --git a/src/main.go b/src/main.go
index e43176c..8bca78c 100644
--- a/src/main.go
+++ b/src/main.go
@@ -8,6 +8,10 @@ import (
"strconv"
)
+import _ "net/http/pprof"
+import "net/http"
+import "log"
+
const (
ExitSetupSuccess = 0
ExitSetupFailed = 1
@@ -25,6 +29,10 @@ func printUsage() {
func main() {
+ go func() {
+ log.Println(http.ListenAndServe("localhost:6060", nil))
+ }()
+
// parse arguments
var foreground bool
@@ -160,7 +168,6 @@ func main() {
errs := make(chan error)
term := make(chan os.Signal)
- wait := device.WaitChannel()
uapi, err := UAPIListen(interfaceName, fileUAPI)
@@ -183,9 +190,9 @@ func main() {
signal.Notify(term, os.Interrupt)
select {
- case <-wait:
case <-term:
case <-errs:
+ case <-device.Wait():
}
// clean up
diff --git a/src/misc.go b/src/misc.go
index b43e97e..80e33f6 100644
--- a/src/misc.go
+++ b/src/misc.go
@@ -2,12 +2,10 @@ package main
import (
"sync/atomic"
- "time"
)
-/* We use int32 as atomic bools
- * (since booleans are not natively supported by sync/atomic)
- */
+/* Atomic Boolean */
+
const (
AtomicFalse = int32(iota)
AtomicTrue
@@ -37,6 +35,8 @@ func (a *AtomicBool) Set(val bool) {
atomic.StoreInt32(&a.flag, flag)
}
+/* Integer manipulation */
+
func toInt32(n uint32) int32 {
mask := uint32(1 << 31)
return int32(-(n & mask) + (n & ^mask))
@@ -55,32 +55,3 @@ func minUint64(a uint64, b uint64) uint64 {
}
return a
}
-
-func signalSend(c chan struct{}) {
- select {
- case c <- struct{}{}:
- default:
- }
-}
-
-func signalClear(c chan struct{}) {
- select {
- case <-c:
- default:
- }
-}
-
-func timerStop(timer *time.Timer) {
- if !timer.Stop() {
- select {
- case <-timer.C:
- default:
- }
- }
-}
-
-func NewStoppedTimer() *time.Timer {
- timer := time.NewTimer(time.Hour)
- timerStop(timer)
- return timer
-}
diff --git a/src/ratelimiter.go b/src/ratelimiter.go
index 4f8227e..6e5f005 100644
--- a/src/ratelimiter.go
+++ b/src/ratelimiter.go
@@ -66,11 +66,11 @@ func (rate *Ratelimiter) GarbageCollectEntries() {
rate.mutex.Unlock()
}
-func (rate *Ratelimiter) RoutineGarbageCollector(stop chan struct{}) {
+func (rate *Ratelimiter) RoutineGarbageCollector(stop Signal) {
timer := time.NewTimer(time.Second)
for {
select {
- case <-stop:
+ case <-stop.Wait():
return
case <-timer.C:
rate.GarbageCollectEntries()
diff --git a/src/receive.go b/src/receive.go
index fd1993e..f650cc9 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -93,6 +93,11 @@ func (device *Device) addToHandshakeQueue(
}
}
+/* Receives incoming datagrams for the device
+ *
+ * Every time the bind is updated a new routine is started for
+ * IPv4 and IPv6 (separately)
+ */
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
logDebug := device.log.Debug
@@ -182,6 +187,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
device.addToDecryptionQueue(device.queue.decryption, elem)
device.addToInboundQueue(peer.queue.inbound, elem)
buffer = device.GetMessageBuffer()
+
continue
// otherwise it is a fixed size & handshake related packet
@@ -220,7 +226,7 @@ func (device *Device) RoutineDecryption() {
for {
select {
- case <-device.signal.stop:
+ case <-device.signal.stop.Wait():
logDebug.Println("Routine, decryption worker, stopped")
return
@@ -256,7 +262,7 @@ func (device *Device) RoutineDecryption() {
}
}
-/* Handles incomming packets related to handshake
+/* Handles incoming packets related to handshake
*/
func (device *Device) RoutineHandshake() {
@@ -271,7 +277,7 @@ func (device *Device) RoutineHandshake() {
for {
select {
case elem = <-device.queue.handshake:
- case <-device.signal.stop:
+ case <-device.signal.stop.Wait():
return
}
@@ -356,7 +362,7 @@ func (device *Device) RoutineHandshake() {
continue
}
- // handle handshake initation/response content
+ // handle handshake initiation/response content
switch elem.msgType {
case MessageInitiationType:
@@ -376,7 +382,7 @@ func (device *Device) RoutineHandshake() {
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
logInfo.Println(
- "Recieved invalid initiation message from",
+ "Received invalid initiation message from",
elem.endpoint.DstToString(),
)
continue
@@ -449,7 +455,7 @@ func (device *Device) RoutineHandshake() {
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
- logDebug.Println("Received handshake initation from", peer)
+ logDebug.Println("Received handshake initiation from", peer)
peer.TimerEphemeralKeyCreated()
@@ -556,7 +562,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer {
logInfo.Println(
- "IPv4 packet with unallowed source address from",
+ "IPv4 packet with disallowed source address from",
peer.String(),
)
continue
@@ -584,7 +590,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer {
logInfo.Println(
- "IPv6 packet with unallowed source address from",
+ "IPv6 packet with disallowed source address from",
peer.String(),
)
continue
diff --git a/src/send.go b/src/send.go
index 35a4a6e..2919f2e 100644
--- a/src/send.go
+++ b/src/send.go
@@ -11,7 +11,7 @@ import (
"time"
)
-/* Handles outbound flow
+/* Outbound flow
*
* 1. TUN queue
* 2. Routing (sequential)
@@ -19,17 +19,22 @@ import (
* 4. Encryption (parallel)
* 5. Transmission (sequential)
*
- * The order of packets (per peer) is maintained.
- * The functions in this file occure (roughly) in the order packets are processed.
- */
-
-/* The sequential consumers will attempt to take the lock,
+ * The functions in this file occur (roughly) in the order in
+ * which the packets are processed.
+ *
+ * Locking, Producers and Consumers
+ *
+ * The order of packets (per peer) must be maintained,
+ * but encryption of packets happen out-of-order:
+ *
+ * The sequential consumers will attempt to take the lock,
* workers release lock when they have completed work (encryption) on the packet.
*
* If the element is inserted into the "encryption queue",
- * the content is preceeded by enough "junk" to contain the transport header
+ * the content is preceded by enough "junk" to contain the transport header
* (to allow the construction of transport messages in-place)
*/
+
type QueueOutboundElement struct {
dropped int32
mutex sync.Mutex
@@ -155,7 +160,7 @@ func (device *Device) RoutineReadFromTUN() {
peer = device.routingTable.LookupIPv6(dst)
default:
- logDebug.Println("Receieved packet with unknown IP version")
+ logDebug.Println("Received packet with unknown IP version")
}
if peer == nil {
@@ -249,7 +254,7 @@ func (device *Device) RoutineEncryption() {
// fetch next element
select {
- case <-device.signal.stop:
+ case <-device.signal.stop.Wait():
logDebug.Println("Routine, encryption worker, stopped")
return
diff --git a/src/signal.go b/src/signal.go
index 96b21bb..2cefad4 100644
--- a/src/signal.go
+++ b/src/signal.go
@@ -20,6 +20,8 @@ func (s *Signal) Enable() {
s.enabled.Set(true)
}
+/* Unblock exactly one listener
+ */
func (s *Signal) Send() {
if s.enabled.Get() {
select {
@@ -29,6 +31,8 @@ func (s *Signal) Send() {
}
}
+/* Clear the signal if already fired
+ */
func (s Signal) Clear() {
select {
case <-s.C:
@@ -36,10 +40,14 @@ func (s Signal) Clear() {
}
}
+/* Unblocks all listeners (forever)
+ */
func (s Signal) Broadcast() {
- close(s.C) // unblocks all selectors
+ close(s.C)
}
+/* Wait for the signal
+ */
func (s Signal) Wait() chan struct{} {
return s.C
}
diff --git a/src/timers.go b/src/timers.go
index 64aeca8..ee47393 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -27,7 +27,7 @@ func (peer *Peer) KeepKeyFreshSending() {
/* Called when a new authenticated message has been received
*
- * NOTE: Not thread safe (called by sequential receiver)
+ * NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) KeepKeyFreshReceiving() {
if peer.timer.sendLastMinuteHandshake {
diff --git a/src/trie.go b/src/trie.go
index 38fcd4a..405ffc3 100644
--- a/src/trie.go
+++ b/src/trie.go
@@ -11,10 +11,8 @@ import (
* same way as those created by the "net" functions.
* Here the IPs are slices of either 4 or 16 byte (not always 16)
*
- * Syncronization done seperatly
+ * Synchronization done separately
* See: routing.go
- *
- * TODO: Better commenting
*/
type Trie struct {
@@ -30,7 +28,11 @@ type Trie struct {
}
/* Finds length of matching prefix
- * TODO: Make faster
+ *
+ * TODO: Only use during insertion (xor + prefix mask for lookup)
+ * Check out
+ * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits)
+ * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match
*
* Assumption:
* len(ip1) == len(ip2)
@@ -88,7 +90,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie {
return node
}
- // walk recursivly
+ // walk recursively
node.child[0] = node.child[0].RemovePeer(p)
node.child[1] = node.child[1].RemovePeer(p)