summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Makefile9
-rw-r--r--src/config.go70
-rw-r--r--src/constants.go21
-rw-r--r--src/device.go41
-rw-r--r--src/handshake.go172
-rw-r--r--src/helper_test.go64
-rw-r--r--src/ip.go7
-rw-r--r--src/macs_test.go10
-rw-r--r--src/main.go45
-rw-r--r--src/noise_protocol.go127
-rw-r--r--src/noise_test.go27
-rw-r--r--src/noise_types.go17
-rw-r--r--src/peer.go46
-rw-r--r--src/routing.go12
-rw-r--r--src/send.go215
-rw-r--r--src/trie.go31
-rw-r--r--src/tun.go2
-rw-r--r--src/tun_linux.go8
18 files changed, 694 insertions, 230 deletions
diff --git a/src/Makefile b/src/Makefile
new file mode 100644
index 0000000..4ef8199
--- /dev/null
+++ b/src/Makefile
@@ -0,0 +1,9 @@
+BINARY=wireguard-go
+
+build:
+ go build -o ${BINARY}
+
+clean:
+ if [ -f ${BINARY} ]; then rm ${BINARY}; fi
+
+.PHONY: clean
diff --git a/src/config.go b/src/config.go
index cb7e9ef..3b91d00 100644
--- a/src/config.go
+++ b/src/config.go
@@ -11,7 +11,7 @@ import (
"time"
)
-/* todo : use real error code
+/* TODO : use real error code
* Many of which will be the same
*/
const (
@@ -37,8 +37,55 @@ func (s *IPCError) ErrorCode() int {
return s.Code
}
-func ipcGetOperation(socket *bufio.ReadWriter, dev *Device) {
+func ipcGetOperation(device *Device, socket *bufio.ReadWriter) error {
+ device.mutex.RLock()
+ defer device.mutex.RUnlock()
+
+ // create lines
+
+ lines := make([]string, 0, 100)
+ send := func(line string) {
+ lines = append(lines, line)
+ }
+
+ if !device.privateKey.IsZero() {
+ send("private_key=" + device.privateKey.ToHex())
+ }
+
+ if device.address != nil {
+ send(fmt.Sprintf("listen_port=%d", device.address.Port))
+ }
+
+ for _, peer := range device.peers {
+ func() {
+ peer.mutex.RLock()
+ defer peer.mutex.RUnlock()
+ send("public_key=" + peer.handshake.remoteStatic.ToHex())
+ send("preshared_key=" + peer.handshake.presharedKey.ToHex())
+ if peer.endpoint != nil {
+ send("endpoint=" + peer.endpoint.String())
+ }
+ send(fmt.Sprintf("tx_bytes=%d", peer.tx_bytes))
+ send(fmt.Sprintf("rx_bytes=%d", peer.rx_bytes))
+ send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
+ for _, ip := range device.routingTable.AllowedIPs(peer) {
+ send("allowed_ip=" + ip.String())
+ }
+ }()
+ }
+
+ // send lines
+
+ for _, line := range lines {
+ device.log.Debug.Println("config:", line)
+ _, err := socket.WriteString(line + "\n")
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
}
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
@@ -179,7 +226,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return nil
}
-func ipcListen(dev *Device, socket io.ReadWriter) error {
+func ipcListen(device *Device, socket io.ReadWriter) error {
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
reader := bufio.NewReader(s)
@@ -187,6 +234,8 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
return bufio.NewReadWriter(reader, writer)
}(socket)
+ defer buffered.Flush()
+
for {
op, err := buffered.ReadString('\n')
if err != nil {
@@ -197,17 +246,26 @@ func ipcListen(dev *Device, socket io.ReadWriter) error {
switch op {
case "set=1\n":
- err := ipcSetOperation(dev, buffered)
+ err := ipcSetOperation(device, buffered)
if err != nil {
- fmt.Fprintf(buffered, "errno=%d\n", err.ErrorCode())
+ fmt.Fprintf(buffered, "errno=%d\n\n", err.ErrorCode())
return err
} else {
- fmt.Fprintf(buffered, "errno=0\n")
+ fmt.Fprintf(buffered, "errno=0\n\n")
}
buffered.Flush()
case "get=1\n":
+ err := ipcGetOperation(device, buffered)
+ if err != nil {
+ fmt.Fprintf(buffered, "errno=1\n\n") // fix
+ return err
+ } else {
+ fmt.Fprintf(buffered, "errno=0\n\n")
+ }
+ buffered.Flush()
+ case "\n":
default:
return errors.New("handle this please")
}
diff --git a/src/constants.go b/src/constants.go
index dc95379..e8cdd63 100644
--- a/src/constants.go
+++ b/src/constants.go
@@ -5,12 +5,17 @@ import (
)
const (
- RekeyAfterMessage = (1 << 64) - (1 << 16) - 1
- RekeyAfterTime = time.Second * 120
- RekeyAttemptTime = time.Second * 90
- RekeyTimeout = time.Second * 5
- RejectAfterTime = time.Second * 180
- RejectAfterMessage = (1 << 64) - (1 << 4) - 1
- KeepaliveTimeout = time.Second * 10
- CookieRefreshTime = time.Second * 2
+ RekeyAfterMessage = (1 << 64) - (1 << 16) - 1
+ RekeyAfterTime = time.Second * 120
+ RekeyAttemptTime = time.Second * 90
+ RekeyTimeout = time.Second * 5 // TODO: Exponential backoff
+ RejectAfterTime = time.Second * 180
+ RejectAfterMessage = (1 << 64) - (1 << 4) - 1
+ KeepaliveTimeout = time.Second * 10
+ CookieRefreshTime = time.Second * 2
+ MaxHandshakeAttempTime = time.Second * 90
+)
+
+const (
+ QueueOutboundSize = 1024
)
diff --git a/src/device.go b/src/device.go
index b3484c5..a7a5c7b 100644
--- a/src/device.go
+++ b/src/device.go
@@ -2,23 +2,26 @@ package main
import (
"net"
+ "runtime"
"sync"
)
type Device struct {
- mtu int
- fwMark uint32
- address *net.UDPAddr // UDP source address
- conn *net.UDPConn // UDP "connection"
- mutex sync.RWMutex
- privateKey NoisePrivateKey
- publicKey NoisePublicKey
- routingTable RoutingTable
- indices IndexTable
- log *Logger
- queueWorkOutbound chan *OutboundWorkQueueElement
- peers map[NoisePublicKey]*Peer
- mac MacStateDevice
+ mtu int
+ fwMark uint32
+ address *net.UDPAddr // UDP source address
+ conn *net.UDPConn // UDP "connection"
+ mutex sync.RWMutex
+ privateKey NoisePrivateKey
+ publicKey NoisePublicKey
+ routingTable RoutingTable
+ indices IndexTable
+ log *Logger
+ queue struct {
+ encryption chan *QueueOutboundElement // parallel work queue
+ }
+ peers map[NoisePublicKey]*Peer
+ mac MacStateDevice
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
@@ -41,7 +44,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) {
}
}
-func (device *Device) Init() {
+func NewDevice(tun TUNDevice) *Device {
+ device := new(Device)
+
device.mutex.Lock()
defer device.mutex.Unlock()
@@ -49,6 +54,14 @@ func (device *Device) Init() {
device.peers = make(map[NoisePublicKey]*Peer)
device.indices.Init()
device.routingTable.Reset()
+
+ // start workers
+
+ for i := 0; i < runtime.NumCPU(); i += 1 {
+ go device.RoutineEncryption()
+ }
+ go device.RoutineReadFromTUN(tun)
+ return device
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
diff --git a/src/handshake.go b/src/handshake.go
new file mode 100644
index 0000000..238c339
--- /dev/null
+++ b/src/handshake.go
@@ -0,0 +1,172 @@
+package main
+
+import (
+ "bytes"
+ "encoding/binary"
+ "net"
+ "sync/atomic"
+ "time"
+)
+
+/* Sends a keep-alive if no packets queued for peer
+ *
+ * Used by initiator of handshake and with active keep-alive
+ */
+func (peer *Peer) SendKeepAlive() bool {
+ if len(peer.queue.nonce) == 0 {
+ select {
+ case peer.queue.nonce <- []byte{}:
+ return true
+ default:
+ return false
+ }
+ }
+ return true
+}
+
+func (peer *Peer) RoutineHandshakeInitiator() {
+ var ongoing bool
+ var begun time.Time
+ var attempts uint
+ var timeout time.Timer
+
+ device := peer.device
+ work := new(QueueOutboundElement)
+ buffer := make([]byte, 0, 1024)
+
+ queueHandshakeInitiation := func() error {
+ work.mutex.Lock()
+ defer work.mutex.Unlock()
+
+ // create initiation
+
+ msg, err := device.CreateMessageInitiation(peer)
+ if err != nil {
+ return err
+ }
+
+ // create "work" element
+
+ writer := bytes.NewBuffer(buffer[:0])
+ binary.Write(writer, binary.LittleEndian, &msg)
+ work.packet = writer.Bytes()
+ peer.mac.AddMacs(work.packet)
+ peer.InsertOutbound(work)
+ return nil
+ }
+
+ for {
+ select {
+ case <-peer.signal.stopInitiator:
+ return
+
+ case <-peer.signal.newHandshake:
+ if ongoing {
+ continue
+ }
+
+ // create handshake
+
+ err := queueHandshakeInitiation()
+ if err != nil {
+ device.log.Error.Println("Failed to create initiation message:", err)
+ }
+
+ // log when we began
+
+ begun = time.Now()
+ ongoing = true
+ attempts = 0
+ timeout.Reset(RekeyTimeout)
+
+ case <-peer.timer.sendKeepalive.C:
+
+ // active keep-alives
+
+ peer.SendKeepAlive()
+
+ case <-peer.timer.handshakeTimeout.C:
+
+ // check if we can stop trying
+
+ if time.Now().Sub(begun) > MaxHandshakeAttempTime {
+ peer.signal.flushNonceQueue <- true
+ peer.timer.sendKeepalive.Stop()
+ ongoing = false
+ continue
+ }
+
+ // otherwise, try again (exponental backoff)
+
+ attempts += 1
+ err := queueHandshakeInitiation()
+ if err != nil {
+ device.log.Error.Println("Failed to create initiation message:", err)
+ }
+ peer.timer.handshakeTimeout.Reset((1 << attempts) * RekeyTimeout)
+ }
+ }
+}
+
+/* Handles packets related to handshake
+ *
+ *
+ */
+func (device *Device) HandshakeWorker(queue chan struct {
+ msg []byte
+ msgType uint32
+ addr *net.UDPAddr
+}) {
+ for {
+ elem := <-queue
+
+ switch elem.msgType {
+ case MessageInitiationType:
+ if len(elem.msg) != MessageInitiationSize {
+ continue
+ }
+
+ // check for cookie
+
+ var msg MessageInitiation
+
+ binary.Read(nil, binary.LittleEndian, &msg)
+
+ case MessageResponseType:
+ if len(elem.msg) != MessageResponseSize {
+ continue
+ }
+
+ // check for cookie
+
+ case MessageCookieReplyType:
+
+ case MessageTransportType:
+ }
+
+ }
+}
+
+func (device *Device) KeepKeyFresh(peer *Peer) {
+
+ send := func() bool {
+ peer.keyPairs.mutex.RLock()
+ defer peer.keyPairs.mutex.RUnlock()
+
+ kp := peer.keyPairs.current
+ if kp == nil {
+ return false
+ }
+
+ nonce := atomic.LoadUint64(&kp.sendNonce)
+ if nonce > RekeyAfterMessage {
+ return true
+ }
+
+ return kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime
+ }()
+
+ if send {
+
+ }
+}
diff --git a/src/helper_test.go b/src/helper_test.go
new file mode 100644
index 0000000..3a5c331
--- /dev/null
+++ b/src/helper_test.go
@@ -0,0 +1,64 @@
+package main
+
+import (
+ "bytes"
+ "testing"
+)
+
+/* Helpers for writing unit tests
+ */
+
+type DummyTUN struct {
+ name string
+ mtu uint
+ packets chan []byte
+}
+
+func (tun *DummyTUN) Name() string {
+ return tun.name
+}
+
+func (tun *DummyTUN) MTU() uint {
+ return tun.mtu
+}
+
+func (tun *DummyTUN) Write(d []byte) (int, error) {
+ tun.packets <- d
+ return len(d), nil
+}
+
+func (tun *DummyTUN) Read(d []byte) (int, error) {
+ t := <-tun.packets
+ copy(d, t)
+ return len(t), nil
+}
+
+func CreateDummyTUN(name string) (TUNDevice, error) {
+ var dummy DummyTUN
+ dummy.mtu = 1024
+ dummy.packets = make(chan []byte, 100)
+ return &dummy, nil
+}
+
+func assertNil(t *testing.T, err error) {
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func assertEqual(t *testing.T, a []byte, b []byte) {
+ if bytes.Compare(a, b) != 0 {
+ t.Fatal(a, "!=", b)
+ }
+}
+
+func randDevice(t *testing.T) *Device {
+ sk, err := newPrivateKey()
+ if err != nil {
+ t.Fatal(err)
+ }
+ tun, _ := CreateDummyTUN("dummy")
+ device := NewDevice(tun)
+ device.SetPrivateKey(sk)
+ return device
+}
diff --git a/src/ip.go b/src/ip.go
index 3137891..a9685ad 100644
--- a/src/ip.go
+++ b/src/ip.go
@@ -5,9 +5,10 @@ import (
)
const (
- IPv4version = 4
- IPv4offsetSrc = 12
- IPv4offsetDst = IPv4offsetSrc + net.IPv4len
+ IPv4version = 4
+ IPv4offsetSrc = 12
+ IPv4offsetDst = IPv4offsetSrc + net.IPv4len
+ IPv4headerSize = 20
)
const (
diff --git a/src/macs_test.go b/src/macs_test.go
index a67ccfb..fcb64ea 100644
--- a/src/macs_test.go
+++ b/src/macs_test.go
@@ -8,8 +8,8 @@ import (
)
func TestMAC1(t *testing.T) {
- dev1 := newDevice(t)
- dev2 := newDevice(t)
+ dev1 := randDevice(t)
+ dev2 := randDevice(t)
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
@@ -34,12 +34,10 @@ func TestMACs(t *testing.T) {
msg []byte,
receiver uint32,
) bool {
- var device1 Device
- device1.Init()
+ device1 := randDevice(t)
device1.SetPrivateKey(sk1)
- var device2 Device
- device2.Init()
+ device2 := randDevice(t)
device2.SetPrivateKey(sk2)
peer1 := device2.NewPeer(device1.privateKey.publicKey())
diff --git a/src/main.go b/src/main.go
index b6f6deb..7c58972 100644
--- a/src/main.go
+++ b/src/main.go
@@ -1,36 +1,30 @@
package main
import (
- "fmt"
+ "log"
+ "net"
)
+/*
+ *
+ * TODO: Fix logging
+ */
+
func main() {
- fd, err := CreateTUN("test0")
- fmt.Println(fd, err)
+ // Open TUN device
- queue := make(chan []byte, 1000)
+ // TODO: Fix capabilities
- // var device Device
+ tun, err := CreateTUN("test0")
+ log.Println(tun, err)
+ if err != nil {
+ return
+ }
- // go OutgoingRoutingWorker(&device, queue)
+ device := NewDevice(tun)
- for {
- tmp := make([]byte, 1<<16)
- n, err := fd.Read(tmp)
- if err != nil {
- break
- }
- queue <- tmp[:n]
- }
-}
+ // Start configuration lister
-/*
-import (
- "fmt"
- "log"
- "net"
-)
-func main() {
l, err := net.Listen("unix", "/var/run/wireguard/wg0.sock")
if err != nil {
log.Fatal("listen error:", err)
@@ -41,12 +35,9 @@ func main() {
if err != nil {
log.Fatal("accept error:", err)
}
-
- var dev Device
go func(conn net.Conn) {
- err := ipcListen(&dev, conn)
- fmt.Println(err)
+ err := ipcListen(device, conn)
+ log.Println(err)
}(fd)
}
}
-*/
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index e237dbe..46ceeda 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -77,7 +77,7 @@ type MessageCookieReply struct {
type Handshake struct {
state int
- mutex sync.Mutex
+ mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
presharedKey NoiseSymmetricKey // psk
@@ -205,49 +205,64 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
}
hash = mixHash(hash, msg.Static[:])
- // find peer
+ // lookup peer
peer := device.LookupPeer(peerPK)
if peer == nil {
return nil
}
handshake := &peer.handshake
- handshake.mutex.Lock()
- defer handshake.mutex.Unlock()
- // decrypt timestamp
+ // verify identity
var timestamp TAI64N
- func() {
- var key [chacha20poly1305.KeySize]byte
- chainKey, key = KDF2(
- chainKey[:],
- handshake.precomputedStaticStatic[:],
- )
- aead, _ := chacha20poly1305.New(key[:])
- _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
- }()
- if err != nil {
- return nil
- }
- hash = mixHash(hash, msg.Timestamp[:])
+ ok := func() bool {
+
+ // read lock handshake
+
+ handshake.mutex.RLock()
+ defer handshake.mutex.RUnlock()
+
+ // decrypt timestamp
+
+ func() {
+ var key [chacha20poly1305.KeySize]byte
+ chainKey, key = KDF2(
+ chainKey[:],
+ handshake.precomputedStaticStatic[:],
+ )
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
+ }()
+ if err != nil {
+ return false
+ }
+ hash = mixHash(hash, msg.Timestamp[:])
+
+ // TODO: check for flood attack
+
+ // check for replay attack
- // check for replay attack
+ return timestamp.After(handshake.lastTimestamp)
+ }()
- if !timestamp.After(handshake.lastTimestamp) {
+ if !ok {
return nil
}
- // TODO: check for flood attack
-
// update handshake state
+ handshake.mutex.Lock()
+
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp
handshake.state = HandshakeInitiationConsumed
+
+ handshake.mutex.Unlock()
+
return peer
}
@@ -320,47 +335,67 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
return nil
}
- handshake.mutex.Lock()
- defer handshake.mutex.Unlock()
- if handshake.state != HandshakeInitiationCreated {
- return nil
- }
+ var (
+ hash [blake2s.Size]byte
+ chainKey [blake2s.Size]byte
+ )
- // finish 3-way DH
+ ok := func() bool {
- hash := mixHash(handshake.hash, msg.Ephemeral[:])
- chainKey := handshake.chainKey
+ // read lock handshake
- func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
- chainKey = mixKey(chainKey, ss[:])
- ss = device.privateKey.sharedSecret(msg.Ephemeral)
- chainKey = mixKey(chainKey, ss[:])
- }()
+ handshake.mutex.RLock()
+ defer handshake.mutex.RUnlock()
- // add preshared key (psk)
+ if handshake.state != HandshakeInitiationCreated {
+ return false
+ }
- var tau [blake2s.Size]byte
- var key [chacha20poly1305.KeySize]byte
- chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
- hash = mixHash(hash, tau[:])
+ // finish 3-way DH
- // authenticate
+ hash = mixHash(handshake.hash, msg.Ephemeral[:])
+ chainKey = handshake.chainKey
- aead, _ := chacha20poly1305.New(key[:])
- _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
- if err != nil {
+ func() {
+ ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ chainKey = mixKey(chainKey, ss[:])
+ ss = device.privateKey.sharedSecret(msg.Ephemeral)
+ chainKey = mixKey(chainKey, ss[:])
+ }()
+
+ // add preshared key (psk)
+
+ var tau [blake2s.Size]byte
+ var key [chacha20poly1305.KeySize]byte
+ chainKey, tau, key = KDF3(chainKey[:], handshake.presharedKey[:])
+ hash = mixHash(hash, tau[:])
+
+ // authenticate
+
+ aead, _ := chacha20poly1305.New(key[:])
+ _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
+ if err != nil {
+ return false
+ }
+ hash = mixHash(hash, msg.Empty[:])
+ return true
+ }()
+
+ if !ok {
return nil
}
- hash = mixHash(hash, msg.Empty[:])
// update handshake state
+ handshake.mutex.Lock()
+
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed
+ handshake.mutex.Unlock()
+
return lookup.peer
}
diff --git a/src/noise_test.go b/src/noise_test.go
index dab603b..02f6bf3 100644
--- a/src/noise_test.go
+++ b/src/noise_test.go
@@ -6,29 +6,6 @@ import (
"testing"
)
-func assertNil(t *testing.T, err error) {
- if err != nil {
- t.Fatal(err)
- }
-}
-
-func assertEqual(t *testing.T, a []byte, b []byte) {
- if bytes.Compare(a, b) != 0 {
- t.Fatal(a, "!=", b)
- }
-}
-
-func newDevice(t *testing.T) *Device {
- var device Device
- sk, err := newPrivateKey()
- if err != nil {
- t.Fatal(err)
- }
- device.Init()
- device.SetPrivateKey(sk)
- return &device
-}
-
func TestCurveWrappers(t *testing.T) {
sk1, err := newPrivateKey()
assertNil(t, err)
@@ -49,8 +26,8 @@ func TestCurveWrappers(t *testing.T) {
func TestNoiseHandshake(t *testing.T) {
- dev1 := newDevice(t)
- dev2 := newDevice(t)
+ dev1 := randDevice(t)
+ dev2 := randDevice(t)
peer1 := dev2.NewPeer(dev1.privateKey.publicKey())
peer2 := dev1.NewPeer(dev2.privateKey.publicKey())
diff --git a/src/noise_types.go b/src/noise_types.go
index 5508f9a..5ebc130 100644
--- a/src/noise_types.go
+++ b/src/noise_types.go
@@ -3,18 +3,18 @@ package main
import (
"encoding/hex"
"errors"
+ "golang.org/x/crypto/chacha20poly1305"
)
const (
- NoisePublicKeySize = 32
- NoisePrivateKeySize = 32
- NoiseSymmetricKeySize = 32
+ NoisePublicKeySize = 32
+ NoisePrivateKeySize = 32
)
type (
NoisePublicKey [NoisePublicKeySize]byte
NoisePrivateKey [NoisePrivateKeySize]byte
- NoiseSymmetricKey [NoiseSymmetricKeySize]byte
+ NoiseSymmetricKey [chacha20poly1305.KeySize]byte
NoiseNonce uint64 // padded to 12-bytes
)
@@ -30,6 +30,15 @@ func loadExactHex(dst []byte, src string) error {
return nil
}
+func (key NoisePrivateKey) IsZero() bool {
+ for _, b := range key[:] {
+ if b != 0 {
+ return false
+ }
+ }
+ return true
+}
+
func (key *NoisePrivateKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}
diff --git a/src/peer.go b/src/peer.go
index e192b12..21cad9d 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -7,9 +7,7 @@ import (
"time"
)
-const (
- OutboundQueueSize = 64
-)
+const ()
type Peer struct {
mutex sync.RWMutex
@@ -18,10 +16,26 @@ type Peer struct {
keyPairs KeyPairs
handshake Handshake
device *Device
- queueInbound chan []byte
- queueOutbound chan *OutboundWorkQueueElement
- queueOutboundRouting chan []byte
- mac MacStatePeer
+ tx_bytes uint64
+ rx_bytes uint64
+ time struct {
+ lastSend time.Time // last send message
+ }
+ signal struct {
+ newHandshake chan bool
+ flushNonceQueue chan bool // empty queued packets
+ stopSending chan bool // stop sending pipeline
+ stopInitiator chan bool // stop initiator timer
+ }
+ timer struct {
+ sendKeepalive time.Timer
+ handshakeTimeout time.Timer
+ }
+ queue struct {
+ nonce chan []byte // nonce / pre-handshake queue
+ outbound chan *QueueOutboundElement // sequential ordering of work
+ }
+ mac MacStatePeer
}
func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
@@ -33,7 +47,8 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.device = device
peer.keyPairs.Init()
peer.mac.Init(pk)
- peer.queueOutbound = make(chan *OutboundWorkQueueElement, OutboundQueueSize)
+ peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
+ peer.queue.nonce = make(chan []byte, QueueOutboundSize)
// map public key
@@ -54,5 +69,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
handshake.mutex.Unlock()
peer.mutex.Unlock()
+ // start workers
+
+ peer.signal.stopSending = make(chan bool, 1)
+ peer.signal.stopInitiator = make(chan bool, 1)
+ peer.signal.newHandshake = make(chan bool, 1)
+ peer.signal.flushNonceQueue = make(chan bool, 1)
+
+ go peer.RoutineNonce()
+ go peer.RoutineHandshakeInitiator()
+
return &peer
}
+
+func (peer *Peer) Close() {
+ peer.signal.stopSending <- true
+ peer.signal.stopInitiator <- true
+}
diff --git a/src/routing.go b/src/routing.go
index 4189c25..6a5e1f3 100644
--- a/src/routing.go
+++ b/src/routing.go
@@ -12,9 +12,20 @@ type RoutingTable struct {
mutex sync.RWMutex
}
+func (table *RoutingTable) AllowedIPs(peer *Peer) []net.IPNet {
+ table.mutex.RLock()
+ defer table.mutex.RUnlock()
+
+ allowed := make([]net.IPNet, 10)
+ table.IPv4.AllowedIPs(peer, allowed)
+ table.IPv6.AllowedIPs(peer, allowed)
+ return allowed
+}
+
func (table *RoutingTable) Reset() {
table.mutex.Lock()
defer table.mutex.Unlock()
+
table.IPv4 = nil
table.IPv6 = nil
}
@@ -22,6 +33,7 @@ func (table *RoutingTable) Reset() {
func (table *RoutingTable) RemovePeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
+
table.IPv4 = table.IPv4.RemovePeer(peer)
table.IPv6 = table.IPv6.RemovePeer(peer)
}
diff --git a/src/send.go b/src/send.go
index f58d311..4ff75db 100644
--- a/src/send.go
+++ b/src/send.go
@@ -5,107 +5,159 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"net"
"sync"
- "time"
)
/* Handles outbound flow
*
* 1. TUN queue
- * 2. Routing
- * 3. Per peer queuing
- * 4. (work queuing)
+ * 2. Routing (sequential)
+ * 3. Nonce assignment (sequential)
+ * 4. Encryption (parallel)
+ * 5. Transmission (sequential)
*
+ * The order of packets (per peer) is maintained.
+ * The functions in this file occure (roughly) in the order packets are processed.
*/
-type OutboundWorkQueueElement struct {
- wg sync.WaitGroup
+/* A work unit
+ *
+ * The sequential consumers will attempt to take the lock,
+ * workers release lock when they have completed work on the packet.
+ */
+type QueueOutboundElement struct {
+ mutex sync.Mutex
packet []byte
nonce uint64
keyPair *KeyPair
}
-func (peer *Peer) HandshakeWorker(handshakeQueue []byte) {
-
+func (peer *Peer) FlushNonceQueue() {
+ elems := len(peer.queue.nonce)
+ for i := 0; i < elems; i += 1 {
+ select {
+ case <-peer.queue.nonce:
+ default:
+ return
+ }
+ }
}
-func (device *Device) SendPacket(packet []byte) {
+func (peer *Peer) InsertOutbound(elem *QueueOutboundElement) {
+ for {
+ select {
+ case peer.queue.outbound <- elem:
+ default:
+ select {
+ case <-peer.queue.outbound:
+ default:
+ }
+ }
+ }
+}
- // lookup peer
+/* Reads packets from the TUN and inserts
+ * into nonce queue for peer
+ *
+ * Obs. Single instance per TUN device
+ */
+func (device *Device) RoutineReadFromTUN(tun TUNDevice) {
+ for {
+ // read packet
- var peer *Peer
- switch packet[0] >> 4 {
- case IPv4version:
- dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.routingTable.LookupIPv4(dst)
+ packet := make([]byte, 1<<16) // TODO: Fix & avoid dynamic allocation
+ size, err := tun.Read(packet)
+ if err != nil {
+ device.log.Error.Println("Failed to read packet from TUN device:", err)
+ continue
+ }
+ packet = packet[:size]
+ if len(packet) < IPv4headerSize {
+ device.log.Error.Println("Packet too short, length:", len(packet))
+ continue
+ }
- case IPv6version:
- dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.routingTable.LookupIPv6(dst)
+ device.log.Debug.Println("New packet on TUN:", packet) // TODO: Slow debugging, remove.
- default:
- device.log.Debug.Println("receieved packet with unknown IP version")
- return
- }
+ // lookup peer
- if peer == nil {
- return
- }
+ var peer *Peer
+ switch packet[0] >> 4 {
+ case IPv4version:
+ dst := packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
+ peer = device.routingTable.LookupIPv4(dst)
- // insert into peer queue
+ case IPv6version:
+ dst := packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
+ peer = device.routingTable.LookupIPv6(dst)
- for {
- select {
- case peer.queueOutboundRouting <- packet:
default:
+ device.log.Debug.Println("Receieved packet with unknown IP version")
+ return
+ }
+
+ if peer == nil {
+ device.log.Debug.Println("No peer configured for IP")
+ return
+ }
+
+ // insert into nonce/pre-handshake queue
+
+ for {
select {
- case <-peer.queueOutboundRouting:
+ case peer.queue.nonce <- packet:
default:
+ select {
+ case <-peer.queue.nonce:
+ default:
+ }
+ continue
}
- continue
+ break
}
- break
}
}
-/* Go routine
+/* Queues packets when there is no handshake.
+ * Then assigns nonces to packets sequentially
+ * and creates "work" structs for workers
*
+ * TODO: Avoid dynamic allocation of work queue elements
*
- * 1. waits for handshake.
- * 2. assigns key pair & nonce
- * 3. inserts to working queue
- *
- * TODO: avoid dynamic allocation of work queue elements
+ * Obs. A single instance per peer
*/
-func (peer *Peer) RoutineOutboundNonceWorker() {
+func (peer *Peer) RoutineNonce() {
var packet []byte
var keyPair *KeyPair
- var flushTimer time.Timer
for {
// wait for packet
if packet == nil {
- packet = <-peer.queueOutboundRouting
+ select {
+ case packet = <-peer.queue.nonce:
+ case <-peer.signal.stopSending:
+ close(peer.queue.outbound)
+ return
+ }
}
// wait for key pair
for keyPair == nil {
- flushTimer.Reset(time.Second * 10)
- // TODO: Handshake or NOP
+ peer.signal.newHandshake <- true
select {
case <-peer.keyPairs.newKeyPair:
keyPair = peer.keyPairs.Current()
continue
- case <-flushTimer.C:
- size := len(peer.queueOutboundRouting)
- for i := 0; i < size; i += 1 {
- <-peer.queueOutboundRouting
- }
+ case <-peer.signal.flushNonceQueue:
+ peer.FlushNonceQueue()
packet = nil
+ continue
+ case <-peer.signal.stopSending:
+ close(peer.queue.outbound)
+ return
}
- break
}
// process current packet
@@ -114,14 +166,13 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
// create work element
- work := new(OutboundWorkQueueElement)
- work.wg.Add(1)
+ work := new(QueueOutboundElement) // TODO: profile, maybe use pool
work.keyPair = keyPair
work.packet = packet
work.nonce = keyPair.sendNonce
+ work.mutex.Lock()
packet = nil
- peer.queueOutbound <- work
keyPair.sendNonce += 1
// drop packets until there is space
@@ -129,46 +180,36 @@ func (peer *Peer) RoutineOutboundNonceWorker() {
func() {
for {
select {
- case peer.device.queueWorkOutbound <- work:
+ case peer.device.queue.encryption <- work:
return
default:
- drop := <-peer.device.queueWorkOutbound
+ drop := <-peer.device.queue.encryption
drop.packet = nil
- drop.wg.Done()
+ drop.mutex.Unlock()
}
}
}()
+ peer.queue.outbound <- work
}
}
}
-/* Go routine
- *
- * sequentially reads packets from queue and sends to endpoint
+/* Encrypts the elements in the queue
+ * and marks them for sequential consumption (by releasing the mutex)
*
+ * Obs. One instance per core
*/
-func (peer *Peer) RoutineSequential() {
- for work := range peer.queueOutbound {
- work.wg.Wait()
- if work.packet == nil {
- continue
- }
- if peer.endpoint == nil {
- continue
- }
- peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
- }
-}
-
-func (device *Device) RoutineEncryptionWorker() {
+func (device *Device) RoutineEncryption() {
var nonce [chacha20poly1305.NonceSize]byte
- for work := range device.queueWorkOutbound {
+ for work := range device.queue.encryption {
+
// pad packet
padding := device.mtu - len(work.packet)
if padding < 0 {
+ // drop
work.packet = nil
- work.wg.Done()
+ work.mutex.Unlock()
}
for n := 0; n < padding; n += 1 {
work.packet = append(work.packet, 0)
@@ -183,6 +224,30 @@ func (device *Device) RoutineEncryptionWorker() {
work.packet,
nil,
)
- work.wg.Done()
+ work.mutex.Unlock()
+ }
+}
+
+/* Sequentially reads packets from queue and sends to endpoint
+ *
+ * Obs. Single instance per peer.
+ * The routine terminates then the outbound queue is closed.
+ */
+func (peer *Peer) RoutineSequential() {
+ for work := range peer.queue.outbound {
+ work.mutex.Lock()
+ func() {
+ peer.mutex.RLock()
+ defer peer.mutex.RUnlock()
+ if work.packet == nil {
+ return
+ }
+ if peer.endpoint == nil {
+ return
+ }
+ peer.device.conn.WriteToUDP(work.packet, peer.endpoint)
+ peer.timer.sendKeepalive.Reset(peer.persistentKeepaliveInterval)
+ }()
+ work.mutex.Unlock()
}
}
diff --git a/src/trie.go b/src/trie.go
index 746c1b4..4049167 100644
--- a/src/trie.go
+++ b/src/trie.go
@@ -1,15 +1,20 @@
package main
import (
+ "errors"
"net"
)
/* Binary trie
*
+ * The net.IPs used here are not formatted the
+ * same way as those created by the "net" functions.
+ * Here the IPs are slices of either 4 or 16 byte (not always 16)
+ *
* Syncronization done seperatly
* See: routing.go
*
- * Todo: Better commenting
+ * TODO: Better commenting
*/
type Trie struct {
@@ -24,7 +29,7 @@ type Trie struct {
}
/* Finds length of matching prefix
- * Maybe there is a faster way
+ * TODO: Make faster
*
* Assumption: len(ip1) == len(ip2)
*/
@@ -189,3 +194,25 @@ func (node *Trie) Count() uint {
r := node.child[1].Count()
return l + r
}
+
+func (node *Trie) AllowedIPs(p *Peer, results []net.IPNet) {
+ if node.peer == p {
+ var mask net.IPNet
+ mask.Mask = net.CIDRMask(int(node.cidr), len(node.bits)*8)
+ if len(node.bits) == net.IPv4len {
+ mask.IP = net.IPv4(
+ node.bits[0],
+ node.bits[1],
+ node.bits[2],
+ node.bits[3],
+ )
+ } else if len(node.bits) == net.IPv6len {
+ mask.IP = node.bits
+ } else {
+ panic(errors.New("bug: unexpected address length"))
+ }
+ results = append(results, mask)
+ }
+ node.child[0].AllowedIPs(p, results)
+ node.child[1].AllowedIPs(p, results)
+}
diff --git a/src/tun.go b/src/tun.go
index 1a8bb82..594754a 100644
--- a/src/tun.go
+++ b/src/tun.go
@@ -1,6 +1,6 @@
package main
-type TUN interface {
+type TUNDevice interface {
Read([]byte) (int, error)
Write([]byte) (int, error)
Name() string
diff --git a/src/tun_linux.go b/src/tun_linux.go
index d545dfa..cbbcb70 100644
--- a/src/tun_linux.go
+++ b/src/tun_linux.go
@@ -9,9 +9,7 @@ import (
"unsafe"
)
-/* Platform dependent functions for interacting with
- * TUN devices on linux systems
- *
+/* Implementation of the TUN device interface for linux
*/
const CloneDevicePath = "/dev/net/tun"
@@ -45,7 +43,7 @@ func (tun *NativeTun) Read(d []byte) (int, error) {
return tun.fd.Read(d)
}
-func CreateTUN(name string) (TUN, error) {
+func CreateTUN(name string) (TUNDevice, error) {
// Open clone device
fd, err := os.OpenFile(CloneDevicePath, os.O_RDWR, 0)
if err != nil {
@@ -53,7 +51,7 @@ func CreateTUN(name string) (TUN, error) {
}
// Prepare ifreq struct
- var ifr [18]byte
+ var ifr [128]byte
var flags uint16 = IFF_TUN | IFF_NO_PI
nameBytes := []byte(name)
if len(nameBytes) >= IFNAMSIZ {