diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Makefile | 9 | ||||
-rw-r--r-- | src/config.go | 70 | ||||
-rw-r--r-- | src/constants.go | 21 | ||||
-rw-r--r-- | src/device.go | 41 | ||||
-rw-r--r-- | src/handshake.go | 172 | ||||
-rw-r--r-- | src/helper_test.go | 64 | ||||
-rw-r--r-- | src/ip.go | 7 | ||||
-rw-r--r-- | src/macs_test.go | 10 | ||||
-rw-r--r-- | src/main.go | 45 | ||||
-rw-r--r-- | src/noise_protocol.go | 127 | ||||
-rw-r--r-- | src/noise_test.go | 27 | ||||
-rw-r--r-- | src/noise_types.go | 17 | ||||
-rw-r--r-- | src/peer.go | 46 | ||||
-rw-r--r-- | src/routing.go | 12 | ||||
-rw-r--r-- | src/send.go | 215 | ||||
-rw-r--r-- | src/trie.go | 31 | ||||
-rw-r--r-- | src/tun.go | 2 | ||||
-rw-r--r-- | src/tun_linux.go | 8 |
18 files changed, 694 insertions, 230 deletions
diff --git a/src/Makefile b/src/Makefile new file mode 100644 index 0000000..4ef8199 --- /dev/null +++ b/src/Makefile @@ -0,0 +1,9 @@ +BINARY=wireguard-go + +build: + go build -o ${BINARY} + +clean: + if [ -f ${BINARY} ]; then rm ${BINARY}; fi + +.PHONY: clean diff --git a/src/config.go b/src/config.go index cb7e9ef..3b91d00 100644 --- a/src/config.go +++ b/src/config.go @@ -11,7 +11,7 @@ import ( "time" ) -/* todo : use real error code +/* TODO : use real error code * Many of which will be the same */ const ( @@ -37,8 +37,55 @@ func (s *IPCError) ErrorCode() int { return s.Code } -func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) { +func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error { + device.mutex.RLock() + defer device.mutex.RUnlock() + + // create lines + + lines := make([]string, 0, 100) + send := func(line string) { + lines = append(lines, line) + } + + if !device.privateKey.IsZero() { + send("private_key=" + device.privateKey.ToHex()) + } + + if device.address != nil { + send(fmt.Sprintf("listen_port=%d", device.address.Port)) + } + + for _, peer := range device.peers { + func() { + peer.mutex.RLock() + defer peer.mutex.RUnlock() + send("public_key=" + peer.handshake.remoteStatic.ToHex()) + send("preshared_key=" + peer.handshake.presharedKey.ToHex()) + if peer.endpoint != nil { + send("endpoint=" + peer.endpoint.String()) + } + send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes)) + send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes)) + send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) + for _, ip := range device.routingTable.AllowedIPs(peer) { + send("allowed_ip=" + ip.String()) + } + }() + } + + // send lines + + for _, line := range lines { + device.log.Debug.Println("config:", line) + _, err := socket.WriteString(line + "\n") + if err != nil { + return err + } + } + + return nil } func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { @@ -179,7 +226,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError { return nil } -func ipcListen(dev *Device, socket io.ReadWriter) error { +func ipcListen(device *Device, socket io.ReadWriter) error { buffered := func(s io.ReadWriter) *bufio.ReadWriter { reader := bufio.NewReader(s) @@ -187,6 +234,8 @@ func ipcListen(dev *Device, socket io.ReadWriter) error { return bufio.NewReadWriter(reader, writer) }(socket) + defer buffered.Flush() + for { op, err := buffered.ReadString('\n') if err != nil { @@ -197,17 +246,26 @@ func ipcListen(dev *Device, socket io.ReadWriter) error { switch op { case "set=1\n": - err := ipcSetOperation(dev, buffered) + err := ipcSetOperation(device, buffered) if err != nil { - fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode()) + fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode()) return err } else { - fmt.Fprintf(buffered, "errno=0\n") + fmt.Fprintf(buffered, "errno=0\n\n") } buffered.Flush() case "get=1\n": + err := ipcGetOperation(device, buffered) + if err != nil { + fmt.Fprintf(buffered, "errno=1\n\n") // fix + return err + } else { + fmt.Fprintf(buffered, "errno=0\n\n") + } + buffered.Flush() + case "\n": default: return errors.New("handle this please") } diff --git a/src/constants.go b/src/constants.go index dc95379..e8cdd63 100644 --- a/src/constants.go +++ b/src/constants.go @@ -5,12 +5,17 @@ import ( ) const ( - RekeyAfterMessage = (1 << 64) - (1 << 16) - 1 - RekeyAfterTime = time.Second * 120 - RekeyAttemptTime = time.Second * 90 - RekeyTimeout = time.Second * 5 - RejectAfterTime = time.Second * 180 - RejectAfterMessage = (1 << 64) - (1 << 4) - 1 - KeepaliveTimeout = time.Second * 10 - CookieRefreshTime = time.Second * 2 + RekeyAfterMessage = (1 << 64) - (1 << 16) - 1 + RekeyAfterTime = time.Second * 120 + RekeyAttemptTime = time.Second * 90 + RekeyTimeout = time.Second * 5 // TODO: Exponential backoff + RejectAfterTime = time.Second * 180 + RejectAfterMessage = (1 << 64) - (1 << 4) - 1 + KeepaliveTimeout = time.Second * 10 + CookieRefreshTime = time.Second * 2 + MaxHandshakeAttempTime = time.Second * 90 +) + +const ( + QueueOutboundSize = 1024 ) diff --git a/src/device.go b/src/device.go index b3484c5..a7a5c7b 100644 --- a/src/device.go +++ b/src/device.go @@ -2,23 +2,26 @@ package main import ( "net" + "runtime" "sync" ) type Device struct { - mtu int - fwMark uint32 - address *net.UDPAddr // UDP source address - conn *net.UDPConn // UDP "connection" - mutex sync.RWMutex - privateKey NoisePrivateKey - publicKey NoisePublicKey - routingTable RoutingTable - indices IndexTable - log *Logger - queueWorkOutbound chan *OutboundWorkQueueElement - peers map[NoisePublicKey]*Peer - mac MacStateDevice + mtu int + fwMark uint32 + address *net.UDPAddr // UDP source address + conn *net.UDPConn // UDP "connection" + mutex sync.RWMutex + privateKey NoisePrivateKey + publicKey NoisePublicKey + routingTable RoutingTable + indices IndexTable + log *Logger + queue struct { + encryption chan *QueueOutboundElement // parallel work queue + } + peers map[NoisePublicKey]*Peer + mac MacStateDevice } func (device *Device) SetPrivateKey(sk NoisePrivateKey) { @@ -41,7 +44,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) { } } -func (device *Device) Init() { +func NewDevice(tun TUNDevice) *Device { + device := new(Device) + device.mutex.Lock() defer device.mutex.Unlock() @@ -49,6 +54,14 @@ func (device *Device) Init() { device.peers = make(map[NoisePublicKey]*Peer) device.indices.Init() device.routingTable.Reset() + + // start workers + + for i := 0; i < runtime.NumCPU(); i += 1 { + go device.RoutineEncryption() + } + go device.RoutineReadFromTUN(tun) + return device } func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { diff --git a/src/handshake.go b/src/handshake.go new file mode 100644 index 0000000..238c339 --- /dev/null +++ b/src/handshake.go @@ -0,0 +1,172 @@ +package main + +import ( + "bytes" + "encoding/binary" + "net" + "sync/atomic" + "time" +) + +/* Sends a keep-alive if no packets queued for peer + * + * Used by initiator of handshake and with active keep-alive + */ +func (peer *Peer) SendKeepAlive() bool { + if len(peer.queue.nonce) == 0 { + select { + case peer.queue.nonce <- []byte{}: + return true + default: + return false + } + } + return true +} + +func (peer *Peer) RoutineHandshakeInitiator() { + var ongoing bool + var begun time.Time + var attempts uint + var timeout time.Timer + + device := peer.device + work := new(QueueOutboundElement) + buffer := make([]byte, 0, 1024) + + queueHandshakeInitiation := func() error { + work.mutex.Lock() + defer work.mutex.Unlock() + + // create initiation + + msg, err := device.CreateMessageInitiation(peer) + if err != nil { + return err + } + + // create "work" element + + writer := bytes.NewBuffer(buffer[:0]) + binary.Write(writer, binary.LittleEndian, &msg) + work.packet = writer.Bytes() + peer.mac.AddMacs(work.packet) + peer.InsertOutbound(work) + return nil + } + + for { + select { + case <-peer.signal.stopInitiator: + return + + case <-peer.signal.newHandshake: + if ongoing { + continue + } + + // create handshake + + err := queueHandshakeInitiation() + if err != nil { + device.log.Error.Println("Failed to create initiation message:", err) + } + + // log when we began + + begun = time.Now() + ongoing = true + attempts = 0 + timeout.Reset(RekeyTimeout) + + case <-peer.timer.sendKeepalive.C: + + // active keep-alives + + peer.SendKeepAlive() + + case <-peer.timer.handshakeTimeout.C: + + // check if we can stop trying + + if time.Now().Sub(begun) > MaxHandshakeAttempTime { + peer.signal.flushNonceQueue <- true + peer.timer.sendKeepalive.Stop() + ongoing = false + continue + } + + // otherwise, try again (exponental backoff) + + attempts += 1 + err := queueHandshakeInitiation() + if err != nil { + device.log.Error.Println("Failed to create initiation message:", err) + } + peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout) + } + } +} + +/* Handles packets related to handshake + * + * + */ +func (device *Device) HandshakeWorker(queue chan struct { + msg []byte + msgType uint32 + addr *net.UDPAddr +}) { + for { + elem := <-queue + + switch elem.msgType { + case MessageInitiationType: + if len(elem.msg) != MessageInitiationSize { + continue + } + + // check for cookie + + var msg MessageInitiation + + binary.Read(nil, binary.LittleEndian, &msg) + + case MessageResponseType: + if len(elem.msg) != MessageResponseSize { + continue + } + + // check for cookie + + case MessageCookieReplyType: + + case MessageTransportType: + } + + } +} + +func (device *Device) KeepKeyFresh(peer *Peer) { + + send := func() bool { + peer.keyPairs.mutex.RLock() + defer peer.keyPairs.mutex.RUnlock() + + kp := peer.keyPairs.current + if kp == nil { + return false + } + + nonce := atomic.LoadUint64(&kp.sendNonce) + if nonce > RekeyAfterMessage { + return true + } + + return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime + }() + + if send { + + } +} diff --git a/src/helper_test.go b/src/helper_test.go new file mode 100644 index 0000000..3a5c331 --- /dev/null +++ b/src/helper_test.go @@ -0,0 +1,64 @@ +package main + +import ( + "bytes" + "testing" +) + +/* Helpers for writing unit tests + */ + +type DummyTUN struct { + name string + mtu uint + packets chan []byte +} + +func (tun *DummyTUN) Name() string { + return tun.name +} + +func (tun *DummyTUN) MTU() uint { + return tun.mtu +} + +func (tun *DummyTUN) Write(d []byte) (int, error) { + tun.packets <- d + return len(d), nil +} + +func (tun *DummyTUN) Read(d []byte) (int, error) { + t := <-tun.packets + copy(d, t) + return len(t), nil +} + +func CreateDummyTUN(name string) (TUNDevice, error) { + var dummy DummyTUN + dummy.mtu = 1024 + dummy.packets = make(chan []byte, 100) + return &dummy, nil +} + +func assertNil(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } +} + +func assertEqual(t *testing.T, a []byte, b []byte) { + if bytes.Compare(a, b) != 0 { + t.Fatal(a, "!=", b) + } +} + +func randDevice(t *testing.T) *Device { + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + tun, _ := CreateDummyTUN("dummy") + device := NewDevice(tun) + device.SetPrivateKey(sk) + return device +} @@ -5,9 +5,10 @@ import ( ) const ( - IPv4version = 4 - IPv4offsetSrc = 12 - IPv4offsetDst = IPv4offsetSrc + net.IPv4len + IPv4version = 4 + IPv4offsetSrc = 12 + IPv4offsetDst = IPv4offsetSrc + net.IPv4len + IPv4headerSize = 20 ) const ( diff --git a/src/macs_test.go b/src/macs_test.go index a67ccfb..fcb64ea 100644 --- a/src/macs_test.go +++ b/src/macs_test.go @@ -8,8 +8,8 @@ import ( ) func TestMAC1(t *testing.T) { - dev1 := newDevice(t) - dev2 := newDevice(t) + dev1 := randDevice(t) + dev2 := randDevice(t) peer1 := dev2.NewPeer(dev1.privateKey.publicKey()) peer2 := dev1.NewPeer(dev2.privateKey.publicKey()) @@ -34,12 +34,10 @@ func TestMACs(t *testing.T) { msg []byte, receiver uint32, ) bool { - var device1 Device - device1.Init() + device1 := randDevice(t) device1.SetPrivateKey(sk1) - var device2 Device - device2.Init() + device2 := randDevice(t) device2.SetPrivateKey(sk2) peer1 := device2.NewPeer(device1.privateKey.publicKey()) diff --git a/src/main.go b/src/main.go index b6f6deb..7c58972 100644 --- a/src/main.go +++ b/src/main.go @@ -1,36 +1,30 @@ package main import ( - "fmt" + "log" + "net" ) +/* + * + * TODO: Fix logging + */ + func main() { - fd, err := CreateTUN("test0") - fmt.Println(fd, err) + // Open TUN device - queue := make(chan []byte, 1000) + // TODO: Fix capabilities - // var device Device + tun, err := CreateTUN("test0") + log.Println(tun, err) + if err != nil { + return + } - // go OutgoingRoutingWorker(&device, queue) + device := NewDevice(tun) - for { - tmp := make([]byte, 1<<16) - n, err := fd.Read(tmp) - if err != nil { - break - } - queue <- tmp[:n] - } -} + // Start configuration lister -/* -import ( - "fmt" - "log" - "net" -) -func main() { l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock") if err != nil { log.Fatal("listen error:", err) @@ -41,12 +35,9 @@ func main() { if err != nil { log.Fatal("accept error:", err) } - - var dev Device go func(conn net.Conn) { - err := ipcListen(&dev, conn) - fmt.Println(err) + err := ipcListen(device, conn) + log.Println(err) }(fd) } } -*/ diff --git a/src/noise_protocol.go b/src/noise_protocol.go index e237dbe..46ceeda 100644 --- a/src/noise_protocol.go +++ b/src/noise_protocol.go @@ -77,7 +77,7 @@ type MessageCookieReply struct { type Handshake struct { state int - mutex sync.Mutex + mutex sync.RWMutex hash [blake2s.Size]byte // hash value chainKey [blake2s.Size]byte // chain key presharedKey NoiseSymmetricKey // psk @@ -205,49 +205,64 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { } hash = mixHash(hash, msg.Static[:]) - // find peer + // lookup peer peer := device.LookupPeer(peerPK) if peer == nil { return nil } handshake := &peer.handshake - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - // decrypt timestamp + // verify identity var timestamp TAI64N - func() { - var key [chacha20poly1305.KeySize]byte - chainKey, key = KDF2( - chainKey[:], - handshake.precomputedStaticStatic[:], - ) - aead, _ := chacha20poly1305.New(key[:]) - _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) - }() - if err != nil { - return nil - } - hash = mixHash(hash, msg.Timestamp[:]) + ok := func() bool { + + // read lock handshake + + handshake.mutex.RLock() + defer handshake.mutex.RUnlock() + + // decrypt timestamp + + func() { + var key [chacha20poly1305.KeySize]byte + chainKey, key = KDF2( + chainKey[:], + handshake.precomputedStaticStatic[:], + ) + aead, _ := chacha20poly1305.New(key[:]) + _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) + }() + if err != nil { + return false + } + hash = mixHash(hash, msg.Timestamp[:]) + + // TODO: check for flood attack + + // check for replay attack - // check for replay attack + return timestamp.After(handshake.lastTimestamp) + }() - if !timestamp.After(handshake.lastTimestamp) { + if !ok { return nil } - // TODO: check for flood attack - // update handshake state + handshake.mutex.Lock() + handshake.hash = hash handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender handshake.remoteEphemeral = msg.Ephemeral handshake.lastTimestamp = timestamp handshake.state = HandshakeInitiationConsumed + + handshake.mutex.Unlock() + return peer } @@ -320,47 +335,67 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { return nil } - handshake.mutex.Lock() - defer handshake.mutex.Unlock() - if handshake.state != HandshakeInitiationCreated { - return nil - } + var ( + hash [blake2s.Size]byte + chainKey [blake2s.Size]byte + ) - // finish 3-way DH + ok := func() bool { - hash := mixHash(handshake.hash, msg.Ephemeral[:]) - chainKey := handshake.chainKey + // read lock handshake - func() { - ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) - chainKey = mixKey(chainKey, ss[:]) - ss = device.privateKey.sharedSecret(msg.Ephemeral) - chainKey = mixKey(chainKey, ss[:]) - }() + handshake.mutex.RLock() + defer handshake.mutex.RUnlock() - // add preshared key (psk) + if handshake.state != HandshakeInitiationCreated { + return false + } - var tau [blake2s.Size]byte - var key [chacha20poly1305.KeySize]byte - chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) - hash = mixHash(hash, tau[:]) + // finish 3-way DH - // authenticate + hash = mixHash(handshake.hash, msg.Ephemeral[:]) + chainKey = handshake.chainKey - aead, _ := chacha20poly1305.New(key[:]) - _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) - if err != nil { + func() { + ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + chainKey = mixKey(chainKey, ss[:]) + ss = device.privateKey.sharedSecret(msg.Ephemeral) + chainKey = mixKey(chainKey, ss[:]) + }() + + // add preshared key (psk) + + var tau [blake2s.Size]byte + var key [chacha20poly1305.KeySize]byte + chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:]) + hash = mixHash(hash, tau[:]) + + // authenticate + + aead, _ := chacha20poly1305.New(key[:]) + _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + if err != nil { + return false + } + hash = mixHash(hash, msg.Empty[:]) + return true + }() + + if !ok { return nil } - hash = mixHash(hash, msg.Empty[:]) // update handshake state + handshake.mutex.Lock() + handshake.hash = hash handshake.chainKey = chainKey handshake.remoteIndex = msg.Sender handshake.state = HandshakeResponseConsumed + handshake.mutex.Unlock() + return lookup.peer } diff --git a/src/noise_test.go b/src/noise_test.go index dab603b..02f6bf3 100644 --- a/src/noise_test.go +++ b/src/noise_test.go @@ -6,29 +6,6 @@ import ( "testing" ) -func assertNil(t *testing.T, err error) { - if err != nil { - t.Fatal(err) - } -} - -func assertEqual(t *testing.T, a []byte, b []byte) { - if bytes.Compare(a, b) != 0 { - t.Fatal(a, "!=", b) - } -} - -func newDevice(t *testing.T) *Device { - var device Device - sk, err := newPrivateKey() - if err != nil { - t.Fatal(err) - } - device.Init() - device.SetPrivateKey(sk) - return &device -} - func TestCurveWrappers(t *testing.T) { sk1, err := newPrivateKey() assertNil(t, err) @@ -49,8 +26,8 @@ func TestCurveWrappers(t *testing.T) { func TestNoiseHandshake(t *testing.T) { - dev1 := newDevice(t) - dev2 := newDevice(t) + dev1 := randDevice(t) + dev2 := randDevice(t) peer1 := dev2.NewPeer(dev1.privateKey.publicKey()) peer2 := dev1.NewPeer(dev2.privateKey.publicKey()) diff --git a/src/noise_types.go b/src/noise_types.go index 5508f9a..5ebc130 100644 --- a/src/noise_types.go +++ b/src/noise_types.go @@ -3,18 +3,18 @@ package main import ( "encoding/hex" "errors" + "golang.org/x/crypto/chacha20poly1305" ) const ( - NoisePublicKeySize = 32 - NoisePrivateKeySize = 32 - NoiseSymmetricKeySize = 32 + NoisePublicKeySize = 32 + NoisePrivateKeySize = 32 ) type ( NoisePublicKey [NoisePublicKeySize]byte NoisePrivateKey [NoisePrivateKeySize]byte - NoiseSymmetricKey [NoiseSymmetricKeySize]byte + NoiseSymmetricKey [chacha20poly1305.KeySize]byte NoiseNonce uint64 // padded to 12-bytes ) @@ -30,6 +30,15 @@ func loadExactHex(dst []byte, src string) error { return nil } +func (key NoisePrivateKey) IsZero() bool { + for _, b := range key[:] { + if b != 0 { + return false + } + } + return true +} + func (key *NoisePrivateKey) FromHex(src string) error { return loadExactHex(key[:], src) } diff --git a/src/peer.go b/src/peer.go index e192b12..21cad9d 100644 --- a/src/peer.go +++ b/src/peer.go @@ -7,9 +7,7 @@ import ( "time" ) -const ( - OutboundQueueSize = 64 -) +const () type Peer struct { mutex sync.RWMutex @@ -18,10 +16,26 @@ type Peer struct { keyPairs KeyPairs handshake Handshake device *Device - queueInbound chan []byte - queueOutbound chan *OutboundWorkQueueElement - queueOutboundRouting chan []byte - mac MacStatePeer + tx_bytes uint64 + rx_bytes uint64 + time struct { + lastSend time.Time // last send message + } + signal struct { + newHandshake chan bool + flushNonceQueue chan bool // empty queued packets + stopSending chan bool // stop sending pipeline + stopInitiator chan bool // stop initiator timer + } + timer struct { + sendKeepalive time.Timer + handshakeTimeout time.Timer + } + queue struct { + nonce chan []byte // nonce / pre-handshake queue + outbound chan *QueueOutboundElement // sequential ordering of work + } + mac MacStatePeer } func (device *Device) NewPeer(pk NoisePublicKey) *Peer { @@ -33,7 +47,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { peer.device = device peer.keyPairs.Init() peer.mac.Init(pk) - peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize) + peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.nonce = make(chan []byte, QueueOutboundSize) // map public key @@ -54,5 +69,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer { handshake.mutex.Unlock() peer.mutex.Unlock() + // start workers + + peer.signal.stopSending = make(chan bool, 1) + peer.signal.stopInitiator = make(chan bool, 1) + peer.signal.newHandshake = make(chan bool, 1) + peer.signal.flushNonceQueue = make(chan bool, 1) + + go peer.RoutineNonce() + go peer.RoutineHandshakeInitiator() + return &peer } + +func (peer *Peer) Close() { + peer.signal.stopSending <- true + peer.signal.stopInitiator <- true +} diff --git a/src/routing.go b/src/routing.go index 4189c25..6a5e1f3 100644 --- a/src/routing.go +++ b/src/routing.go @@ -12,9 +12,20 @@ type RoutingTable struct { mutex sync.RWMutex } +func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet { + table.mutex.RLock() + defer table.mutex.RUnlock() + + allowed := make([]net.IPNet, 10) + table.IPv4.AllowedIPs(peer, allowed) + table.IPv6.AllowedIPs(peer, allowed) + return allowed +} + func (table *RoutingTable) Reset() { table.mutex.Lock() defer table.mutex.Unlock() + table.IPv4 = nil table.IPv6 = nil } @@ -22,6 +33,7 @@ func (table *RoutingTable) Reset() { func (table *RoutingTable) RemovePeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() + table.IPv4 = table.IPv4.RemovePeer(peer) table.IPv6 = table.IPv6.RemovePeer(peer) } diff --git a/src/send.go b/src/send.go index f58d311..4ff75db 100644 --- a/src/send.go +++ b/src/send.go @@ -5,107 +5,159 @@ import ( "golang.org/x/crypto/chacha20poly1305" "net" "sync" - "time" ) /* Handles outbound flow * * 1. TUN queue - * 2. Routing - * 3. Per peer queuing - * 4. (work queuing) + * 2. Routing (sequential) + * 3. Nonce assignment (sequential) + * 4. Encryption (parallel) + * 5. Transmission (sequential) * + * The order of packets (per peer) is maintained. + * The functions in this file occure (roughly) in the order packets are processed. */ -type OutboundWorkQueueElement struct { - wg sync.WaitGroup +/* A work unit + * + * The sequential consumers will attempt to take the lock, + * workers release lock when they have completed work on the packet. + */ +type QueueOutboundElement struct { + mutex sync.Mutex packet []byte nonce uint64 keyPair *KeyPair } -func (peer *Peer) HandshakeWorker(handshakeQueue []byte) { - +func (peer *Peer) FlushNonceQueue() { + elems := len(peer.queue.nonce) + for i := 0; i < elems; i += 1 { + select { + case <-peer.queue.nonce: + default: + return + } + } } -func (device *Device) SendPacket(packet []byte) { +func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) { + for { + select { + case peer.queue.outbound <- elem: + default: + select { + case <-peer.queue.outbound: + default: + } + } + } +} - // lookup peer +/* Reads packets from the TUN and inserts + * into nonce queue for peer + * + * Obs. Single instance per TUN device + */ +func (device *Device) RoutineReadFromTUN(tun TUNDevice) { + for { + // read packet - var peer *Peer - switch packet[0] >> 4 { - case IPv4version: - dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.routingTable.LookupIPv4(dst) + packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation + size, err := tun.Read(packet) + if err != nil { + device.log.Error.Println("Failed to read packet from TUN device:", err) + continue + } + packet = packet[:size] + if len(packet) < IPv4headerSize { + device.log.Error.Println("Packet too short, length:", len(packet)) + continue + } - case IPv6version: - dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.routingTable.LookupIPv6(dst) + device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove. - default: - device.log.Debug.Println("receieved packet with unknown IP version") - return - } + // lookup peer - if peer == nil { - return - } + var peer *Peer + switch packet[0] >> 4 { + case IPv4version: + dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.routingTable.LookupIPv4(dst) - // insert into peer queue + case IPv6version: + dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.routingTable.LookupIPv6(dst) - for { - select { - case peer.queueOutboundRouting <- packet: default: + device.log.Debug.Println("Receieved packet with unknown IP version") + return + } + + if peer == nil { + device.log.Debug.Println("No peer configured for IP") + return + } + + // insert into nonce/pre-handshake queue + + for { select { - case <-peer.queueOutboundRouting: + case peer.queue.nonce <- packet: default: + select { + case <-peer.queue.nonce: + default: + } + continue } - continue + break } - break } } -/* Go routine +/* Queues packets when there is no handshake. + * Then assigns nonces to packets sequentially + * and creates "work" structs for workers * + * TODO: Avoid dynamic allocation of work queue elements * - * 1. waits for handshake. - * 2. assigns key pair & nonce - * 3. inserts to working queue - * - * TODO: avoid dynamic allocation of work queue elements + * Obs. A single instance per peer */ -func (peer *Peer) RoutineOutboundNonceWorker() { +func (peer *Peer) RoutineNonce() { var packet []byte var keyPair *KeyPair - var flushTimer time.Timer for { // wait for packet if packet == nil { - packet = <-peer.queueOutboundRouting + select { + case packet = <-peer.queue.nonce: + case <-peer.signal.stopSending: + close(peer.queue.outbound) + return + } } // wait for key pair for keyPair == nil { - flushTimer.Reset(time.Second * 10) - // TODO: Handshake or NOP + peer.signal.newHandshake <- true select { case <-peer.keyPairs.newKeyPair: keyPair = peer.keyPairs.Current() continue - case <-flushTimer.C: - size := len(peer.queueOutboundRouting) - for i := 0; i < size; i += 1 { - <-peer.queueOutboundRouting - } + case <-peer.signal.flushNonceQueue: + peer.FlushNonceQueue() packet = nil + continue + case <-peer.signal.stopSending: + close(peer.queue.outbound) + return } - break } // process current packet @@ -114,14 +166,13 @@ func (peer *Peer) RoutineOutboundNonceWorker() { // create work element - work := new(OutboundWorkQueueElement) - work.wg.Add(1) + work := new(QueueOutboundElement) // TODO: profile, maybe use pool work.keyPair = keyPair work.packet = packet work.nonce = keyPair.sendNonce + work.mutex.Lock() packet = nil - peer.queueOutbound <- work keyPair.sendNonce += 1 // drop packets until there is space @@ -129,46 +180,36 @@ func (peer *Peer) RoutineOutboundNonceWorker() { func() { for { select { - case peer.device.queueWorkOutbound <- work: + case peer.device.queue.encryption <- work: return default: - drop := <-peer.device.queueWorkOutbound + drop := <-peer.device.queue.encryption drop.packet = nil - drop.wg.Done() + drop.mutex.Unlock() } } }() + peer.queue.outbound <- work } } } -/* Go routine - * - * sequentially reads packets from queue and sends to endpoint +/* Encrypts the elements in the queue + * and marks them for sequential consumption (by releasing the mutex) * + * Obs. One instance per core */ -func (peer *Peer) RoutineSequential() { - for work := range peer.queueOutbound { - work.wg.Wait() - if work.packet == nil { - continue - } - if peer.endpoint == nil { - continue - } - peer.device.conn.WriteToUDP(work.packet, peer.endpoint) - } -} - -func (device *Device) RoutineEncryptionWorker() { +func (device *Device) RoutineEncryption() { var nonce [chacha20poly1305.NonceSize]byte - for work := range device.queueWorkOutbound { + for work := range device.queue.encryption { + // pad packet padding := device.mtu - len(work.packet) if padding < 0 { + // drop work.packet = nil - work.wg.Done() + work.mutex.Unlock() } for n := 0; n < padding; n += 1 { work.packet = append(work.packet, 0) @@ -183,6 +224,30 @@ func (device *Device) RoutineEncryptionWorker() { work.packet, nil, ) - work.wg.Done() + work.mutex.Unlock() + } +} + +/* Sequentially reads packets from queue and sends to endpoint + * + * Obs. Single instance per peer. + * The routine terminates then the outbound queue is closed. + */ +func (peer *Peer) RoutineSequential() { + for work := range peer.queue.outbound { + work.mutex.Lock() + func() { + peer.mutex.RLock() + defer peer.mutex.RUnlock() + if work.packet == nil { + return + } + if peer.endpoint == nil { + return + } + peer.device.conn.WriteToUDP(work.packet, peer.endpoint) + peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval) + }() + work.mutex.Unlock() } } diff --git a/src/trie.go b/src/trie.go index 746c1b4..4049167 100644 --- a/src/trie.go +++ b/src/trie.go @@ -1,15 +1,20 @@ package main import ( + "errors" "net" ) /* Binary trie * + * The net.IPs used here are not formatted the + * same way as those created by the "net" functions. + * Here the IPs are slices of either 4 or 16 byte (not always 16) + * * Syncronization done seperatly * See: routing.go * - * Todo: Better commenting + * TODO: Better commenting */ type Trie struct { @@ -24,7 +29,7 @@ type Trie struct { } /* Finds length of matching prefix - * Maybe there is a faster way + * TODO: Make faster * * Assumption: len(ip1) == len(ip2) */ @@ -189,3 +194,25 @@ func (node *Trie) Count() uint { r := node.child[1].Count() return l + r } + +func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) { + if node.peer == p { + var mask net.IPNet + mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8) + if len(node.bits) == net.IPv4len { + mask.IP = net.IPv4( + node.bits[0], + node.bits[1], + node.bits[2], + node.bits[3], + ) + } else if len(node.bits) == net.IPv6len { + mask.IP = node.bits + } else { + panic(errors.New("bug: unexpected address length")) + } + results = append(results, mask) + } + node.child[0].AllowedIPs(p, results) + node.child[1].AllowedIPs(p, results) +} @@ -1,6 +1,6 @@ package main -type TUN interface { +type TUNDevice interface { Read([]byte) (int, error) Write([]byte) (int, error) Name() string diff --git a/src/tun_linux.go b/src/tun_linux.go index d545dfa..cbbcb70 100644 --- a/src/tun_linux.go +++ b/src/tun_linux.go @@ -9,9 +9,7 @@ import ( "unsafe" ) -/* Platform dependent functions for interacting with - * TUN devices on linux systems - * +/* Implementation of the TUN device interface for linux */ const CloneDevicePath = "/dev/net/tun" @@ -45,7 +43,7 @@ func (tun *NativeTun) Read(d []byte) (int, error) { return tun.fd.Read(d) } -func CreateTUN(name string) (TUN, error) { +func CreateTUN(name string) (TUNDevice, error) { // Open clone device fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0) if err != nil { @@ -53,7 +51,7 @@ func CreateTUN(name string) (TUN, error) { } // Prepare ifreq struct - var ifr [18]byte + var ifr [128]byte var flags uint16 = IFF_TUN | IFF_NO_PI nameBytes := []byte(name) if len(nameBytes) >= IFNAMSIZ { |