diff options
-rw-r--r-- | src/constants.go | 4 | ||||
-rw-r--r-- | src/device.go | 7 | ||||
-rw-r--r-- | src/ip.go | 4 | ||||
-rw-r--r-- | src/main.go | 31 | ||||
-rw-r--r-- | src/peer.go | 19 | ||||
-rw-r--r-- | src/receive.go | 24 | ||||
-rw-r--r-- | src/send.go | 69 | ||||
-rw-r--r-- | src/timers.go | 52 | ||||
-rw-r--r-- | src/trie.go | 19 |
9 files changed, 132 insertions, 97 deletions
diff --git a/src/constants.go b/src/constants.go index 0384741..6b0d414 100644 --- a/src/constants.go +++ b/src/constants.go @@ -29,6 +29,6 @@ const ( QueueInboundSize = 1024 QueueHandshakeSize = 1024 QueueHandshakeBusySize = QueueHandshakeSize / 8 - MinMessageSize = MessageTransportSize // keep-alive - MaxMessageSize = 4096 // TODO: make depend on the MTU? + MinMessageSize = MessageTransportSize // size of keep-alive + MaxMessageSize = (1 << 16) - 1 ) diff --git a/src/device.go b/src/device.go index a26cc7b..b272544 100644 --- a/src/device.go +++ b/src/device.go @@ -98,9 +98,9 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { } go device.RoutineBusyMonitor() + go device.RoutineWriteToTUN(tun) go device.RoutineReadFromTUN(tun) go device.RoutineReceiveIncomming() - go device.RoutineWriteToTUN(tun) go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) return device @@ -141,5 +141,8 @@ func (device *Device) RemoveAllPeers() { func (device *Device) Close() { device.RemoveAllPeers() close(device.signal.stop) - close(device.queue.encryption) +} + +func (device *Device) Wait() { + <-device.signal.stop } @@ -5,17 +5,13 @@ import ( ) const ( - IPv4version = 4 IPv4offsetTotalLength = 2 IPv4offsetSrc = 12 IPv4offsetDst = IPv4offsetSrc + net.IPv4len - IPv4headerSize = 20 ) const ( - IPv6version = 6 IPv6offsetPayloadLength = 4 IPv6offsetSrc = 8 IPv6offsetDst = IPv6offsetSrc + net.IPv6len - IPv6headerSize = 40 ) diff --git a/src/main.go b/src/main.go index 50140e3..dc27472 100644 --- a/src/main.go +++ b/src/main.go @@ -5,6 +5,7 @@ import ( "log" "net" "os" + "runtime" ) /* TODO: Fix logging @@ -18,6 +19,10 @@ func main() { } deviceName := os.Args[1] + // increase number of go workers (for Go <1.5) + + runtime.GOMAXPROCS(runtime.NumCPU()) + // open TUN device tun, err := CreateTUN(deviceName) @@ -31,17 +36,21 @@ func main() { // start configuration lister - socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName) - l, err := net.Listen("unix", socketPath) - if err != nil { - log.Fatal("listen error:", err) - } - - for { - conn, err := l.Accept() + go func() { + socketPath := fmt.Sprintf("/var/run/wireguard/%s.sock", deviceName) + l, err := net.Listen("unix", socketPath) if err != nil { - log.Fatal("accept error:", err) + log.Fatal("listen error:", err) } - go ipcHandle(device, conn) - } + + for { + conn, err := l.Accept() + if err != nil { + log.Fatal("accept error:", err) + } + go ipcHandle(device, conn) + } + }() + + device.Wait() } diff --git a/src/peer.go b/src/peer.go index c8dc5c0..408c605 100644 --- a/src/peer.go +++ b/src/peer.go @@ -1,7 +1,9 @@ package main import ( + "encoding/base64" "errors" + "fmt" "net" "sync" "time" @@ -38,9 +40,9 @@ type Peer struct { /* Both keep-alive timers acts as one (see timers.go) * They are kept seperate to simplify the implementation. */ - keepalivePersistent *time.Timer // set for persistent keepalives - keepaliveAcknowledgement *time.Timer // set upon recieving messages - zeroAllKeys *time.Timer // zero all key material after RejectAfterTime*3 + keepalivePersistent *time.Timer // set for persistent keepalives + keepalivePassive *time.Timer // set upon recieving messages + zeroAllKeys *time.Timer // zero all key material after RejectAfterTime*3 } queue struct { nonce chan *QueueOutboundElement // nonce / pre-handshake queue @@ -63,8 +65,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { peer.mac.Init(pk) peer.device = device + peer.timer.keepalivePassive = NewStoppedTimer() peer.timer.keepalivePersistent = NewStoppedTimer() - peer.timer.keepaliveAcknowledgement = NewStoppedTimer() peer.timer.zeroAllKeys = NewStoppedTimer() peer.flags.keepaliveWaiting = AtomicFalse @@ -115,6 +117,15 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { return peer } +func (peer *Peer) String() string { + return fmt.Sprintf( + "peer(%d %s %s)", + peer.id, + peer.endpoint.String(), + base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]), + ) +} + func (peer *Peer) Close() { close(peer.signal.stop) } diff --git a/src/receive.go b/src/receive.go index 99089a9..3e649b6 100644 --- a/src/receive.go +++ b/src/receive.go @@ -4,6 +4,8 @@ import ( "bytes" "encoding/binary" "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "net" "sync" "sync/atomic" @@ -362,7 +364,7 @@ func (device *Device) RoutineHandshake() { return } - logDebug.Println("Creating response...") + logDebug.Println("Creating response message for", peer.String()) outElem := device.NewOutboundElement() writer := bytes.NewBuffer(outElem.data[:0]) @@ -416,6 +418,8 @@ func (peer *Peer) RoutineSequentialReceiver() { var elem *QueueInboundElement device := peer.device + + logInfo := device.log.Info logDebug := device.log.Debug logDebug.Println("Routine, sequential receiver, started for peer", peer.id) @@ -450,7 +454,7 @@ func (peer *Peer) RoutineSequentialReceiver() { peer.KeepKeyFreshReceiving() - // check if confirming handshake + // check if using new key-pair kp := &peer.keyPairs kp.mutex.Lock() @@ -465,17 +469,18 @@ func (peer *Peer) RoutineSequentialReceiver() { // check for keep-alive if len(elem.packet) == 0 { + logDebug.Println("Received keep-alive from", peer.String()) return } // verify source and strip padding switch elem.packet[0] >> 4 { - case IPv4version: + case ipv4.Version: // strip padding - if len(elem.packet) < IPv4headerSize { + if len(elem.packet) < ipv4.HeaderLen { return } @@ -487,31 +492,33 @@ func (peer *Peer) RoutineSequentialReceiver() { dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] if device.routingTable.LookupIPv4(dst) != peer { + logInfo.Println("Packet with unallowed source IP from", peer.String()) return } - case IPv6version: + case ipv6.Version: // strip padding - if len(elem.packet) < IPv6headerSize { + if len(elem.packet) < ipv6.HeaderLen { return } field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] length := binary.BigEndian.Uint16(field) - length += IPv6headerSize + length += ipv6.HeaderLen elem.packet = elem.packet[:length] // verify IPv6 source dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] if device.routingTable.LookupIPv6(dst) != peer { + logInfo.Println("Packet with unallowed source IP from", peer.String()) return } default: - logDebug.Println("Receieved packet with unknown IP version") + logInfo.Println("Packet with invalid IP version from", peer.String()) return } @@ -522,6 +529,7 @@ func (peer *Peer) RoutineSequentialReceiver() { } func (device *Device) RoutineWriteToTUN(tun TUNDevice) { + logError := device.log.Error logDebug := device.log.Debug logDebug.Println("Routine, sequential tun writer, started") diff --git a/src/send.go b/src/send.go index 5ea9a8f..d8ddc82 100644 --- a/src/send.go +++ b/src/send.go @@ -3,6 +3,8 @@ package main import ( "encoding/binary" "golang.org/x/crypto/chacha20poly1305" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "net" "sync" "sync/atomic" @@ -21,28 +23,26 @@ import ( * The functions in this file occure (roughly) in the order packets are processed. */ -/* A work unit - * - * The sequential consumers will attempt to take the lock, - * workers release lock when they have completed work on the packet. +/* The sequential consumers will attempt to take the lock, + * workers release lock when they have completed work (encryption) on the packet. * * If the element is inserted into the "encryption queue", - * the content is preceeded by enough "junk" to contain the header + * the content is preceeded by enough "junk" to contain the transport header * (to allow the construction of transport messages in-place) */ type QueueOutboundElement struct { dropped int32 mutex sync.Mutex - data [MaxMessageSize]byte - packet []byte // slice of "data" (always!) - nonce uint64 // nonce for encryption - keyPair *KeyPair // key-pair for encryption - peer *Peer // related peer + data [MaxMessageSize]byte // slice holding the packet data + packet []byte // slice of "data" (always!) + nonce uint64 // nonce for encryption + keyPair *KeyPair // key-pair for encryption + peer *Peer // related peer } func (peer *Peer) FlushNonceQueue() { elems := len(peer.queue.nonce) - for i := 0; i < elems; i += 1 { + for i := 0; i < elems; i++ { select { case <-peer.queue.nonce: default: @@ -111,14 +111,18 @@ func addToEncryptionQueue( * Obs. Single instance per TUN device */ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { + if tun == nil { - // dummy return } elem := device.NewOutboundElement() - device.log.Debug.Println("Routine, TUN Reader: started") + logDebug := device.log.Debug + logError := device.log.Error + + logDebug.Println("Routine, TUN Reader: started") + for { // read packet @@ -129,12 +133,17 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { elem.packet = elem.data[MessageTransportHeaderSize:] size, err := tun.Read(elem.packet) if err != nil { - device.log.Error.Println("Failed to read packet from TUN device:", err) - continue + + // stop process + + logError.Println("Failed to read packet from TUN device:", err) + device.Close() + return } + elem.packet = elem.packet[:size] - if len(elem.packet) < IPv4headerSize { - device.log.Error.Println("Packet too short, length:", size) + if len(elem.packet) < ipv4.HeaderLen { + logError.Println("Packet too short, length:", size) continue } @@ -142,23 +151,24 @@ func (device *Device) RoutineReadFromTUN(tun TUNDevice) { var peer *Peer switch elem.packet[0] >> 4 { - case IPv4version: + case ipv4.Version: dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] peer = device.routingTable.LookupIPv4(dst) - case IPv6version: + case ipv6.Version: dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] peer = device.routingTable.LookupIPv6(dst) default: - device.log.Debug.Println("Receieved packet with unknown IP version") + logDebug.Println("Receieved packet with unknown IP version") } if peer == nil { continue } + if peer.endpoint == nil { - device.log.Debug.Println("No known endpoint for peer", peer.id) + logDebug.Println("No known endpoint for peer", peer.String()) continue } @@ -184,7 +194,7 @@ func (peer *Peer) RoutineNonce() { device := peer.device logDebug := device.log.Debug - logDebug.Println("Routine, nonce worker, started for peer", peer.id) + logDebug.Println("Routine, nonce worker, started for peer", peer.String()) func() { @@ -216,15 +226,15 @@ func (peer *Peer) RoutineNonce() { } } signalSend(peer.signal.handshakeBegin) - logDebug.Println("Waiting for key-pair, peer", peer.id) + logDebug.Println("Awaiting key-pair for", peer.String()) select { case <-peer.signal.newKeyPair: - logDebug.Println("Key-pair negotiated for peer", peer.id) + logDebug.Println("Key-pair negotiated for", peer.String()) goto NextPacket case <-peer.signal.flushNonceQueue: - logDebug.Println("Clearing queue for peer", peer.id) + logDebug.Println("Clearing queue for", peer.String()) peer.FlushNonceQueue() elem = nil goto NextPacket @@ -313,13 +323,14 @@ func (peer *Peer) RoutineSequentialSender() { device := peer.device logDebug := device.log.Debug - logDebug.Println("Routine, sequential sender, started for peer", peer.id) + logDebug.Println("Routine, sequential sender, started for", peer.String()) for { select { case <-peer.signal.stop: - logDebug.Println("Routine, sequential sender, stopped for peer", peer.id) + logDebug.Println("Routine, sequential sender, stopped for", peer.String()) return + case work := <-peer.queue.outbound: work.mutex.Lock() if work.IsDropped() { @@ -334,7 +345,7 @@ func (peer *Peer) RoutineSequentialSender() { defer peer.mutex.RUnlock() if peer.endpoint == nil { - logDebug.Println("No endpoint for peer:", peer.id) + logDebug.Println("No endpoint for", peer.String()) return } @@ -352,7 +363,7 @@ func (peer *Peer) RoutineSequentialSender() { } atomic.AddUint64(&peer.txBytes, uint64(len(work.packet))) - // reset keep-alive (passive keep-alives / acknowledgements) + // reset keep-alive peer.TimerResetKeepalive() }() diff --git a/src/timers.go b/src/timers.go index 6393955..2e5046e 100644 --- a/src/timers.go +++ b/src/timers.go @@ -50,7 +50,7 @@ func (peer *Peer) KeepKeyFreshReceiving() { * - First transport message under the "next" key */ func (peer *Peer) EventHandshakeComplete() { - peer.device.log.Debug.Println("Handshake completed") + peer.device.log.Info.Println("Negotiated new handshake for", peer.String()) peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3) signalSend(peer.signal.handshakeCompleted) } @@ -112,7 +112,7 @@ func (peer *Peer) TimerResetKeepalive() { // stop acknowledgement timer - timerStop(peer.timer.keepaliveAcknowledgement) + timerStop(peer.timer.keepalivePassive) atomic.StoreInt32(&peer.flags.keepaliveWaiting, AtomicFalse) } @@ -140,7 +140,7 @@ func (peer *Peer) RoutineTimerHandler() { device := peer.device logDebug := device.log.Debug - logDebug.Println("Routine, timer handler, started for peer", peer.id) + logDebug.Println("Routine, timer handler, started for peer", peer.String()) for { select { @@ -152,14 +152,14 @@ func (peer *Peer) RoutineTimerHandler() { case <-peer.timer.keepalivePersistent.C: - logDebug.Println("Sending persistent keep-alive to peer", peer.id) + logDebug.Println("Sending persistent keep-alive to", peer.String()) peer.SendKeepAlive() peer.TimerResetKeepalive() - case <-peer.timer.keepaliveAcknowledgement.C: + case <-peer.timer.keepalivePassive.C: - logDebug.Println("Sending passive persistent keep-alive to peer", peer.id) + logDebug.Println("Sending passive persistent keep-alive to", peer.String()) peer.SendKeepAlive() peer.TimerResetKeepalive() @@ -168,7 +168,7 @@ func (peer *Peer) RoutineTimerHandler() { case <-peer.timer.zeroAllKeys.C: - logDebug.Println("Clearing all key material for peer", peer.id) + logDebug.Println("Clearing all key material for", peer.String()) // zero out key pairs @@ -208,14 +208,12 @@ func (peer *Peer) RoutineHandshakeInitiator() { var elem *QueueOutboundElement + logInfo := device.log.Info logError := device.log.Error logDebug := device.log.Debug - logDebug.Println("Routine, handshake initator, started for peer", peer.id) + logDebug.Println("Routine, handshake initator, started for", peer.String()) - for run := true; run; { - var err error - var attempts uint - var deadline time.Time + for { // wait for signal @@ -227,15 +225,17 @@ func (peer *Peer) RoutineHandshakeInitiator() { // wait for handshake - run = func() bool { - for { + func() { + var err error + var deadline time.Time + for attempts := uint(1); ; attempts++ { // clear completed signal select { case <-peer.signal.handshakeCompleted: case <-peer.signal.stop: - return false + return default: } @@ -246,43 +246,39 @@ func (peer *Peer) RoutineHandshakeInitiator() { } elem, err = peer.BeginHandshakeInitiation() if err != nil { - logError.Println("Failed to create initiation message:", err) - break + logError.Println("Failed to create initiation message", err, "for", peer.String()) + return } // set timeout - attempts += 1 if attempts == 1 { deadline = time.Now().Add(MaxHandshakeAttemptTime) } timeout := time.NewTimer(RekeyTimeout) - logDebug.Println("Handshake initiation attempt", attempts, "queued for peer", peer.id) + logDebug.Println("Handshake initiation attempt", attempts, "queued for", peer.String()) // wait for handshake or timeout select { + case <-peer.signal.stop: - return true + return case <-peer.signal.handshakeCompleted: <-timeout.C - return true + return case <-timeout.C: - logDebug.Println("Timeout") - - // check if sufficient time for retry - if deadline.Before(time.Now().Add(RekeyTimeout)) { + logInfo.Println("Handshake negotiation timed out for", peer.String()) signalSend(peer.signal.flushNonceQueue) timerStop(peer.timer.keepalivePersistent) - timerStop(peer.timer.keepaliveAcknowledgement) - return true + timerStop(peer.timer.keepalivePassive) + return } } } - return true }() signalClear(peer.signal.handshakeBegin) diff --git a/src/trie.go b/src/trie.go index c2304b2..e81b5b6 100644 --- a/src/trie.go +++ b/src/trie.go @@ -23,7 +23,8 @@ type Trie struct { bits []byte peer *Peer - // Index of "branching" bit + // index of "branching" bit + bit_at_byte uint bit_at_shift uint } @@ -36,7 +37,7 @@ type Trie struct { func commonBits(ip1 net.IP, ip2 net.IP) uint { var i uint size := uint(len(ip1)) - for i = 0; i < size; i += 1 { + for i = 0; i < size; i++ { v := ip1[i] ^ ip2[i] if v != 0 { v >>= 1 @@ -84,7 +85,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie { return node } - // Walk recursivly + // walk recursivly node.child[0] = node.child[0].RemovePeer(p) node.child[1] = node.child[1].RemovePeer(p) @@ -93,7 +94,7 @@ func (node *Trie) RemovePeer(p *Peer) *Trie { return node } - // Remove peer & merge + // remove peer & merge node.peer = nil if node.child[0] == nil { @@ -108,7 +109,7 @@ func (node *Trie) choose(ip net.IP) byte { func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { - // At leaf + // at leaf if node == nil { return &Trie{ @@ -120,7 +121,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { } } - // Traverse deeper + // traverse deeper common := commonBits(node.bits, ip) if node.cidr <= cidr && common >= node.cidr { @@ -133,7 +134,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { return node } - // Split node + // split node newNode := &Trie{ bits: ip, @@ -145,7 +146,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { cidr = min(cidr, common) - // Check for shorter prefix + // check for shorter prefix if newNode.cidr == cidr { bit := newNode.choose(node.bits) @@ -153,7 +154,7 @@ func (node *Trie) Insert(ip net.IP, cidr uint, peer *Peer) *Trie { return newNode } - // Create new parent for node & newNode + // create new parent for node & newNode parent := &Trie{ bits: ip, |