diff options
-rw-r--r-- | src/device.go | 15 | ||||
-rw-r--r-- | src/main.go | 11 | ||||
-rw-r--r-- | src/misc.go | 37 | ||||
-rw-r--r-- | src/ratelimiter.go | 4 | ||||
-rw-r--r-- | src/receive.go | 22 | ||||
-rw-r--r-- | src/send.go | 23 | ||||
-rw-r--r-- | src/signal.go | 10 | ||||
-rw-r--r-- | src/timers.go | 2 | ||||
-rw-r--r-- | src/trie.go | 12 |
9 files changed, 68 insertions, 68 deletions
diff --git a/src/device.go b/src/device.go index a1ce802..a3461ad 100644 --- a/src/device.go +++ b/src/device.go @@ -37,7 +37,7 @@ type Device struct { handshake chan QueueHandshakeElement } signal struct { - stop chan struct{} + stop Signal } underLoadUntil atomic.Value ratelimiter Ratelimiter @@ -129,7 +129,6 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { func NewDevice(tun TUNDevice, logger *Logger) *Device { device := new(Device) - device.mutex.Lock() defer device.mutex.Unlock() @@ -160,7 +159,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { // prepare signals - device.signal.stop = make(chan struct{}) + device.signal.stop = NewSignal() // prepare net @@ -174,9 +173,11 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device { go device.RoutineDecryption() go device.RoutineHandshake() } + go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) + return device } @@ -210,11 +211,11 @@ func (device *Device) Close() { } device.log.Info.Println("Closing device") device.RemoveAllPeers() - close(device.signal.stop) - closeBind(device) + device.signal.stop.Broadcast() device.tun.device.Close() + closeBind(device) } -func (device *Device) WaitChannel() chan struct{} { - return device.signal.stop +func (device *Device) Wait() chan struct{} { + return device.signal.stop.Wait() } diff --git a/src/main.go b/src/main.go index e43176c..8bca78c 100644 --- a/src/main.go +++ b/src/main.go @@ -8,6 +8,10 @@ import ( "strconv" ) +import _ "net/http/pprof" +import "net/http" +import "log" + const ( ExitSetupSuccess = 0 ExitSetupFailed = 1 @@ -25,6 +29,10 @@ func printUsage() { func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() + // parse arguments var foreground bool @@ -160,7 +168,6 @@ func main() { errs := make(chan error) term := make(chan os.Signal) - wait := device.WaitChannel() uapi, err := UAPIListen(interfaceName, fileUAPI) @@ -183,9 +190,9 @@ func main() { signal.Notify(term, os.Interrupt) select { - case <-wait: case <-term: case <-errs: + case <-device.Wait(): } // clean up diff --git a/src/misc.go b/src/misc.go index b43e97e..80e33f6 100644 --- a/src/misc.go +++ b/src/misc.go @@ -2,12 +2,10 @@ package main import ( "sync/atomic" - "time" ) -/* We use int32 as atomic bools - * (since booleans are not natively supported by sync/atomic) - */ +/* Atomic Boolean */ + const ( AtomicFalse = int32(iota) AtomicTrue @@ -37,6 +35,8 @@ func (a *AtomicBool) Set(val bool) { atomic.StoreInt32(&a.flag, flag) } +/* Integer manipulation */ + func toInt32(n uint32) int32 { mask := uint32(1 << 31) return int32(-(n & mask) + (n & ^mask)) @@ -55,32 +55,3 @@ func minUint64(a uint64, b uint64) uint64 { } return a } - -func signalSend(c chan struct{}) { - select { - case c <- struct{}{}: - default: - } -} - -func signalClear(c chan struct{}) { - select { - case <-c: - default: - } -} - -func timerStop(timer *time.Timer) { - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } -} - -func NewStoppedTimer() *time.Timer { - timer := time.NewTimer(time.Hour) - timerStop(timer) - return timer -} diff --git a/src/ratelimiter.go b/src/ratelimiter.go index 4f8227e..6e5f005 100644 --- a/src/ratelimiter.go +++ b/src/ratelimiter.go @@ -66,11 +66,11 @@ func (rate *Ratelimiter) GarbageCollectEntries() { rate.mutex.Unlock() } -func (rate *Ratelimiter) RoutineGarbageCollector(stop chan struct{}) { +func (rate *Ratelimiter) RoutineGarbageCollector(stop Signal) { timer := time.NewTimer(time.Second) for { select { - case <-stop: + case <-stop.Wait(): return case <-timer.C: rate.GarbageCollectEntries() diff --git a/src/receive.go b/src/receive.go index fd1993e..f650cc9 100644 --- a/src/receive.go +++ b/src/receive.go @@ -93,6 +93,11 @@ func (device *Device) addToHandshakeQueue( } } +/* Receives incoming datagrams for the device + * + * Every time the bind is updated a new routine is started for + * IPv4 and IPv6 (separately) + */ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { logDebug := device.log.Debug @@ -182,6 +187,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { device.addToDecryptionQueue(device.queue.decryption, elem) device.addToInboundQueue(peer.queue.inbound, elem) buffer = device.GetMessageBuffer() + continue // otherwise it is a fixed size & handshake related packet @@ -220,7 +226,7 @@ func (device *Device) RoutineDecryption() { for { select { - case <-device.signal.stop: + case <-device.signal.stop.Wait(): logDebug.Println("Routine, decryption worker, stopped") return @@ -256,7 +262,7 @@ func (device *Device) RoutineDecryption() { } } -/* Handles incomming packets related to handshake +/* Handles incoming packets related to handshake */ func (device *Device) RoutineHandshake() { @@ -271,7 +277,7 @@ func (device *Device) RoutineHandshake() { for { select { case elem = <-device.queue.handshake: - case <-device.signal.stop: + case <-device.signal.stop.Wait(): return } @@ -356,7 +362,7 @@ func (device *Device) RoutineHandshake() { continue } - // handle handshake initation/response content + // handle handshake initiation/response content switch elem.msgType { case MessageInitiationType: @@ -376,7 +382,7 @@ func (device *Device) RoutineHandshake() { peer := device.ConsumeMessageInitiation(&msg) if peer == nil { logInfo.Println( - "Recieved invalid initiation message from", + "Received invalid initiation message from", elem.endpoint.DstToString(), ) continue @@ -449,7 +455,7 @@ func (device *Device) RoutineHandshake() { peer.endpoint = elem.endpoint peer.mutex.Unlock() - logDebug.Println("Received handshake initation from", peer) + logDebug.Println("Received handshake initiation from", peer) peer.TimerEphemeralKeyCreated() @@ -556,7 +562,7 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] if device.routingTable.LookupIPv4(src) != peer { logInfo.Println( - "IPv4 packet with unallowed source address from", + "IPv4 packet with disallowed source address from", peer.String(), ) continue @@ -584,7 +590,7 @@ func (peer *Peer) RoutineSequentialReceiver() { src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] if device.routingTable.LookupIPv6(src) != peer { logInfo.Println( - "IPv6 packet with unallowed source address from", + "IPv6 packet with disallowed source address from", peer.String(), ) continue diff --git a/src/send.go b/src/send.go index 35a4a6e..2919f2e 100644 --- a/src/send.go +++ b/src/send.go @@ -11,7 +11,7 @@ import ( "time" ) -/* Handles outbound flow +/* Outbound flow * * 1. TUN queue * 2. Routing (sequential) @@ -19,17 +19,22 @@ import ( * 4. Encryption (parallel) * 5. Transmission (sequential) * - * The order of packets (per peer) is maintained. - * The functions in this file occure (roughly) in the order packets are processed. - */ - -/* The sequential consumers will attempt to take the lock, + * The functions in this file occur (roughly) in the order in + * which the packets are processed. + * + * Locking, Producers and Consumers + * + * The order of packets (per peer) must be maintained, + * but encryption of packets happen out-of-order: + * + * The sequential consumers will attempt to take the lock, * workers release lock when they have completed work (encryption) on the packet. * * If the element is inserted into the "encryption queue", - * the content is preceeded by enough "junk" to contain the transport header + * the content is preceded by enough "junk" to contain the transport header * (to allow the construction of transport messages in-place) */ + type QueueOutboundElement struct { dropped int32 mutex sync.Mutex @@ -155,7 +160,7 @@ func (device *Device) RoutineReadFromTUN() { peer = device.routingTable.LookupIPv6(dst) default: - logDebug.Println("Receieved packet with unknown IP version") + logDebug.Println("Received packet with unknown IP version") } if peer == nil { @@ -249,7 +254,7 @@ func (device *Device) RoutineEncryption() { // fetch next element select { - case <-device.signal.stop: + case <-device.signal.stop.Wait(): logDebug.Println("Routine, encryption worker, stopped") return diff --git a/src/signal.go b/src/signal.go index 96b21bb..2cefad4 100644 --- a/src/signal.go +++ b/src/signal.go @@ -20,6 +20,8 @@ func (s *Signal) Enable() { s.enabled.Set(true) } +/* Unblock exactly one listener + */ func (s *Signal) Send() { if s.enabled.Get() { select { @@ -29,6 +31,8 @@ func (s *Signal) Send() { } } +/* Clear the signal if already fired + */ func (s Signal) Clear() { select { case <-s.C: @@ -36,10 +40,14 @@ func (s Signal) Clear() { } } +/* Unblocks all listeners (forever) + */ func (s Signal) Broadcast() { - close(s.C) // unblocks all selectors + close(s.C) } +/* Wait for the signal + */ func (s Signal) Wait() chan struct{} { return s.C } diff --git a/src/timers.go b/src/timers.go index 64aeca8..ee47393 100644 --- a/src/timers.go +++ b/src/timers.go @@ -27,7 +27,7 @@ func (peer *Peer) KeepKeyFreshSending() { /* Called when a new authenticated message has been received
*
- * NOTE: Not thread safe (called by sequential receiver)
+ * NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) KeepKeyFreshReceiving() {
if peer.timer.sendLastMinuteHandshake {
diff --git a/src/trie.go b/src/trie.go index 38fcd4a..405ffc3 100644 --- a/src/trie.go +++ b/src/trie.go @@ -11,10 +11,8 @@ import ( * same way as those created by the "net" functions. * Here the IPs are slices of either 4 or 16 byte (not always 16) * - * Syncronization done seperatly + * Synchronization done separately * See: routing.go - * - * TODO: Better commenting */ type Trie struct { @@ -30,7 +28,11 @@ type Trie struct { } /* Finds length of matching prefix - * TODO: Make faster + * + * TODO: Only use during insertion (xor + prefix mask for lookup) + * Check out + * prefix_matches(struct allowedips_node *node, const u8 *key, u8 bits) + * https://git.zx2c4.com/WireGuard/commit/?h=jd/precomputed-prefix-match * * Assumption: * len(ip1) == len(ip2) @@ -88,7 +90,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie { return node } - // walk recursivly + // walk recursively node.child[0] = node.child[0].RemovePeer(p) node.child[1] = node.child[1].RemovePeer(p) |