summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-08-07 15:25:04 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-08-07 15:25:04 +0200
commitcba1d6585ab9b12ae3e0897db85675ba452c3f09 (patch)
tree13d0975bf53a107c2760c833fd07f36d860a338a
parent8c34c4cbb3780c433148966a004f5a51aace0f64 (diff)
Number of fixes in response to code review
This version cannot complete a handshake. The program will panic upon receiving any message on the UDP socket.
-rw-r--r--src/config.go102
-rw-r--r--src/constants.go20
-rw-r--r--src/daemon_linux.go2
-rw-r--r--src/device.go90
-rw-r--r--src/macs.go19
-rw-r--r--src/peer.go19
-rw-r--r--src/receive.go507
-rw-r--r--src/send.go3
-rw-r--r--src/timers.go74
-rw-r--r--src/tun.go11
-rw-r--r--src/tun_linux.go23
-rw-r--r--src/uapi_linux.go31
12 files changed, 504 insertions, 397 deletions
diff --git a/src/config.go b/src/config.go
index e2d7f20..d952a3a 100644
--- a/src/config.go
+++ b/src/config.go
@@ -84,13 +84,47 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return nil
}
+func updateUDPConn(device *Device) error {
+ var err error
+ netc := &device.net
+ netc.mutex.Lock()
+
+ // close existing connection
+
+ if netc.conn != nil {
+ netc.conn.Close()
+ netc.conn = nil
+ }
+
+ // open new existing connection
+
+ conn, err := net.ListenUDP("udp", netc.addr)
+ if err == nil {
+ netc.conn = conn
+ signalSend(device.signal.newUDPConn)
+ }
+
+ netc.mutex.Unlock()
+ return err
+}
+
+func closeUDPConn(device *Device) {
+ device.net.mutex.Lock()
+ device.net.conn = nil
+ device.net.mutex.Unlock()
+ println("send signal")
+ signalSend(device.signal.newUDPConn)
+}
+
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
scanner := bufio.NewScanner(socket)
+ logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
var peer *Peer
+ dummy := false
deviceConfig := true
for scanner.Scan() {
@@ -135,17 +169,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
netc := &device.net
netc.mutex.Lock()
if netc.addr.Port != int(port) {
- if netc.conn != nil {
- netc.conn.Close()
- }
netc.addr.Port = int(port)
- netc.conn, err = net.ListenUDP("udp", netc.addr)
}
netc.mutex.Unlock()
- if err != nil {
- logError.Println("Failed to create UDP listener:", err)
- return &IPCError{Code: ipcErrorIO}
- }
+ updateUDPConn(device)
+
// TODO: Clear source address of all peers
case "fwmark":
@@ -189,17 +217,30 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.mutex.RLock()
if device.publicKey.Equals(pubKey) {
+
+ // create dummy instance
+
+ peer = &Peer{}
+ dummy = true
device.mutex.RUnlock()
- logError.Println("Public key of peer matches private key of device")
- return &IPCError{Code: ipcErrorInvalid}
- }
+ logInfo.Println("Ignoring peer with public key of device")
+
+ } else {
+
+ // find peer referenced
- // find peer referenced
+ peer, _ = device.peers[pubKey]
+ device.mutex.RUnlock()
+ if peer == nil {
+ peer, err = device.NewPeer(pubKey)
+ if err != nil {
+ logError.Println("Failed to create new peer:", err)
+ return &IPCError{Code: ipcErrorInvalid}
+ }
+ }
+ signalSend(peer.signal.handshakeReset)
+ dummy = false
- peer, _ = device.peers[pubKey]
- device.mutex.RUnlock()
- if peer == nil {
- peer = device.NewPeer(pubKey)
}
case "remove":
@@ -207,16 +248,17 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
- device.RemovePeer(peer.handshake.remoteStatic)
- logDebug.Println("Removing", peer.String())
- peer = nil
+ if !dummy {
+ logDebug.Println("Removing", peer.String())
+ device.RemovePeer(peer.handshake.remoteStatic)
+ }
+ peer = &Peer{}
+ dummy = true
case "preshared_key":
- err := func() error {
- peer.mutex.Lock()
- defer peer.mutex.Unlock()
- return peer.handshake.presharedKey.FromHex(value)
- }()
+ peer.mutex.Lock()
+ err := peer.handshake.presharedKey.FromHex(value)
+ peer.mutex.Unlock()
if err != nil {
logError.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalid}
@@ -232,6 +274,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
peer.mutex.Lock()
peer.endpoint = addr
peer.mutex.Unlock()
+ signalSend(peer.signal.handshakeReset)
case "persistent_keepalive_interval":
@@ -251,12 +294,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// send immediate keep-alive
if old == 0 && secs != 0 {
- up, err := device.tun.IsUp()
if err != nil {
logError.Println("Failed to get tun device status:", err)
return &IPCError{Code: ipcErrorIO}
}
- if up {
+ if atomic.LoadInt32(&device.isUp) == AtomicTrue && !dummy {
peer.SendKeepAlive()
}
}
@@ -266,7 +308,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
- device.routingTable.RemovePeer(peer)
+ if !dummy {
+ device.routingTable.RemovePeer(peer)
+ }
case "allowed_ip":
_, network, err := net.ParseCIDR(value)
@@ -275,7 +319,9 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
ones, _ := network.Mask.Size()
- device.routingTable.Insert(network.IP, uint(ones), peer)
+ if !dummy {
+ device.routingTable.Insert(network.IP, uint(ones), peer)
+ }
default:
logError.Println("Invalid UAPI key (peer configuration):", key)
diff --git a/src/constants.go b/src/constants.go
index f09ded6..37603e8 100644
--- a/src/constants.go
+++ b/src/constants.go
@@ -7,16 +7,15 @@ import (
/* Specification constants */
const (
- RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
- RejectAfterMessages = (1 << 64) - (1 << 4) - 1
- RekeyAfterTime = time.Second * 120
- RekeyAttemptTime = time.Second * 90
- RekeyTimeout = time.Second * 5
- RejectAfterTime = time.Second * 180
- KeepaliveTimeout = time.Second * 10
- CookieRefreshTime = time.Second * 120
- MaxHandshakeAttemptTime = time.Second * 90
- PaddingMultiple = 16
+ RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
+ RejectAfterMessages = (1 << 64) - (1 << 4) - 1
+ RekeyAfterTime = time.Second * 120
+ RekeyAttemptTime = time.Second * 90
+ RekeyTimeout = time.Second * 5
+ RejectAfterTime = time.Second * 180
+ KeepaliveTimeout = time.Second * 10
+ CookieRefreshTime = time.Second * 120
+ PaddingMultiple = 16
)
const (
@@ -33,4 +32,5 @@ const (
QueueHandshakeBusySize = QueueHandshakeSize / 8
MinMessageSize = MessageTransportSize // size of keep-alive
MaxMessageSize = ((1 << 16) - 1) + MessageTransportHeaderSize
+ MaxPeers = 1 << 16
)
diff --git a/src/daemon_linux.go b/src/daemon_linux.go
index 809c176..730f89e 100644
--- a/src/daemon_linux.go
+++ b/src/daemon_linux.go
@@ -7,6 +7,8 @@ import (
/* Daemonizes the process on linux
*
* This is done by spawning and releasing a copy with the --foreground flag
+ *
+ * TODO: Use env variable to spawn in background
*/
func Daemonize() error {
diff --git a/src/device.go b/src/device.go
index de96f0b..4aa90e3 100644
--- a/src/device.go
+++ b/src/device.go
@@ -1,13 +1,10 @@
package main
import (
- "errors"
- "fmt"
"net"
"runtime"
"sync"
"sync/atomic"
- "time"
)
type Device struct {
@@ -34,31 +31,45 @@ type Device struct {
queue struct {
encryption chan *QueueOutboundElement
decryption chan *QueueInboundElement
- inbound chan *QueueInboundElement
handshake chan QueueHandshakeElement
}
signal struct {
- stop chan struct{}
+ stop chan struct{} // halts all go routines
+ newUDPConn chan struct{} // a net.conn was set
}
- underLoad int32 // used as an atomic bool
+ isUp int32 // atomic bool: interface is up
+ underLoad int32 // atomic bool: device is under load
ratelimiter Ratelimiter
peers map[NoisePublicKey]*Peer
mac MACStateDevice
}
+/* Warning:
+ * The caller must hold the device mutex (write lock)
+ */
+func removePeerUnsafe(device *Device, key NoisePublicKey) {
+ peer, ok := device.peers[key]
+ if !ok {
+ return
+ }
+ peer.mutex.Lock()
+ device.routingTable.RemovePeer(peer)
+ delete(device.peers, key)
+ peer.Close()
+}
+
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
device.mutex.Lock()
defer device.mutex.Unlock()
- // check if public key is matching any peer
+ // remove peers with matching public keys
publicKey := sk.publicKey()
- for _, peer := range device.peers {
+ for key, peer := range device.peers {
h := &peer.handshake
h.mutex.RLock()
if h.remoteStatic.Equals(publicKey) {
- h.mutex.RUnlock()
- return errors.New("Private key matches public key of peer")
+ removePeerUnsafe(device, key)
}
h.mutex.RUnlock()
}
@@ -71,17 +82,19 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do DH precomputations
- isZero := device.privateKey.IsZero()
+ rmKey := device.privateKey.IsZero()
- for _, peer := range device.peers {
+ for key, peer := range device.peers {
h := &peer.handshake
h.mutex.Lock()
- if isZero {
+ if rmKey {
h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+ if isZero(h.precomputedStaticStatic[:]) {
+ removePeerUnsafe(device, key)
+ }
}
- fmt.Println(h.precomputedStaticStatic)
h.mutex.Unlock()
}
@@ -130,11 +143,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
- device.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signals
device.signal.stop = make(chan struct{})
+ device.signal.newUDPConn = make(chan struct{}, 1)
// start workers
@@ -145,33 +158,42 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
}
go device.RoutineBusyMonitor()
- go device.RoutineMTUUpdater()
- go device.RoutineWriteToTUN()
go device.RoutineReadFromTUN()
+ go device.RoutineTUNEventReader()
go device.RoutineReceiveIncomming()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
return device
}
-func (device *Device) RoutineMTUUpdater() {
+func (device *Device) RoutineTUNEventReader() {
+ events := device.tun.Events()
logError := device.log.Error
- for ; ; time.Sleep(5 * time.Second) {
- // load updated MTU
-
- mtu, err := device.tun.MTU()
- if err != nil {
- logError.Println("Failed to load updated MTU of device:", err)
- continue
+ for event := range events {
+ if event&TUNEventMTUUpdate != 0 {
+ mtu, err := device.tun.MTU()
+ if err != nil {
+ logError.Println("Failed to load updated MTU of device:", err)
+ } else {
+ if mtu+MessageTransportSize > MaxMessageSize {
+ mtu = MaxMessageSize - MessageTransportSize
+ }
+ atomic.StoreInt32(&device.mtu, int32(mtu))
+ }
}
- // upper bound of mtu
+ if event&TUNEventUp != 0 {
+ println("handle 1")
+ atomic.StoreInt32(&device.isUp, AtomicTrue)
+ updateUDPConn(device)
+ println("handle 2", device.net.conn)
+ }
- if mtu+MessageTransportSize > MaxMessageSize {
- mtu = MaxMessageSize - MessageTransportSize
+ if event&TUNEventDown != 0 {
+ atomic.StoreInt32(&device.isUp, AtomicFalse)
+ closeUDPConn(device)
}
- atomic.StoreInt32(&device.mtu, int32(mtu))
}
}
@@ -184,15 +206,7 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
func (device *Device) RemovePeer(key NoisePublicKey) {
device.mutex.Lock()
defer device.mutex.Unlock()
-
- peer, ok := device.peers[key]
- if !ok {
- return
- }
- peer.mutex.Lock()
- device.routingTable.RemovePeer(peer)
- delete(device.peers, key)
- peer.Close()
+ removePeerUnsafe(device, key)
}
func (device *Device) RemoveAllPeers() {
diff --git a/src/macs.go b/src/macs.go
index beb5f76..d55e18f 100644
--- a/src/macs.go
+++ b/src/macs.go
@@ -18,12 +18,13 @@ type MACStateDevice struct {
}
type MACStatePeer struct {
- mutex sync.RWMutex
- cookieSet time.Time
- cookie [blake2s.Size128]byte
- lastMAC1 [blake2s.Size128]byte // TODO: Check if set
- keyMAC1 [blake2s.Size]byte
- keyMAC2 [blake2s.Size]byte
+ mutex sync.RWMutex
+ cookieSet time.Time
+ cookie [blake2s.Size128]byte
+ lastMAC1Set bool
+ lastMAC1 [blake2s.Size128]byte
+ keyMAC1 [blake2s.Size]byte
+ keyMAC2 [blake2s.Size]byte
}
/* Methods for verifing MAC fields
@@ -184,6 +185,10 @@ func (device *Device) ConsumeMessageCookieReply(msg *MessageCookieReply) bool {
state.mutex.Lock()
defer state.mutex.Unlock()
+ if !state.lastMAC1Set {
+ return false
+ }
+
_, err := XChaCha20Poly1305Decrypt(
cookie[:0],
&msg.Nonce,
@@ -246,7 +251,7 @@ func (state *MACStatePeer) AddMacs(msg []byte) {
mac.Sum(mac1[:0])
}()
copy(state.lastMAC1[:], mac1)
- // TODO: Set lastMac flag
+ state.lastMAC1Set = true
// set mac2
diff --git a/src/peer.go b/src/peer.go
index 9136959..02aac3b 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -9,16 +9,14 @@ import (
"time"
)
-const ()
-
type Peer struct {
id uint
mutex sync.RWMutex
- endpoint *net.UDPAddr
persistentKeepaliveInterval uint64
keyPairs KeyPairs
handshake Handshake
device *Device
+ endpoint *net.UDPAddr
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
@@ -34,6 +32,7 @@ type Peer struct {
newKeyPair chan struct{} // (size 1) : a new key pair was generated
handshakeBegin chan struct{} // (size 1) : request that a new handshake be started ("queue handshake")
handshakeCompleted chan struct{} // (size 1) : handshake completed
+ handshakeReset chan struct{} // (size 1) : reset handshake negotiation state
flushNonceQueue chan struct{} // (size 1) : empty queued packets
messageSend chan struct{} // (size 1) : a message was send to the peer
messageReceived chan struct{} // (size 1) : an authenticated message was received
@@ -44,6 +43,7 @@ type Peer struct {
keepalivePassive *time.Timer // set upon recieving messages
newHandshake *time.Timer // begin a new handshake (after Keepalive + RekeyTimeout)
zeroAllKeys *time.Timer // zero all key material (after RejectAfterTime*3)
+ handshakeDeadline *time.Timer // Current handshake must be completed
pendingKeepalivePassive bool
pendingNewHandshake bool
@@ -59,7 +59,7 @@ type Peer struct {
mac MACStatePeer
}
-func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
+func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// create peer
peer := new(Peer)
@@ -80,11 +80,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.id = device.idCounter
device.idCounter += 1
+ // check if over limit
+
+ if len(device.peers) >= MaxPeers {
+ return nil, errors.New("Too many peers")
+ }
+
// map public key
_, ok := device.peers[pk]
if ok {
- panic(errors.New("bug: adding existing peer"))
+ return nil, errors.New("Adding existing peer")
}
device.peers[pk] = peer
device.mutex.Unlock()
@@ -108,6 +114,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
peer.signal.stop = make(chan struct{})
peer.signal.newKeyPair = make(chan struct{}, 1)
peer.signal.handshakeBegin = make(chan struct{}, 1)
+ peer.signal.handshakeReset = make(chan struct{}, 1)
peer.signal.handshakeCompleted = make(chan struct{}, 1)
peer.signal.flushNonceQueue = make(chan struct{}, 1)
@@ -117,7 +124,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) *Peer {
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
- return peer
+ return peer, nil
}
func (peer *Peer) String() string {
diff --git a/src/receive.go b/src/receive.go
index fb5c51f..5f46925 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -111,113 +111,84 @@ func (device *Device) RoutineBusyMonitor() {
func (device *Device) RoutineReceiveIncomming() {
- logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, started")
- var buffer *[MaxMessageSize]byte
-
for {
- // check if stopped
+ // wait for new conn
+
+ var conn *net.UDPConn
select {
+ case <-device.signal.newUDPConn:
+ device.net.mutex.RLock()
+ conn = device.net.conn
+ device.net.mutex.RUnlock()
+
case <-device.signal.stop:
return
- default:
}
- // read next datagram
-
- if buffer == nil {
- buffer = device.GetMessageBuffer()
- }
-
- // TODO: Take writelock to sleep
- device.net.mutex.RLock()
- conn := device.net.conn
- device.net.mutex.RUnlock()
if conn == nil {
- time.Sleep(time.Second)
continue
}
- // TODO: Wait for new conn or message
- conn.SetReadDeadline(time.Now().Add(time.Second))
+ // receive datagrams until closed
- size, raddr, err := conn.ReadFromUDP(buffer[:])
- if err != nil || size < MinMessageSize {
- continue
- }
+ buffer := device.GetMessageBuffer()
- // handle packet
+ for {
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ // read next datagram
- func() {
- switch msgType {
-
- case MessageInitiationType, MessageResponseType:
-
- // TODO: Check size early
+ size, raddr, err := conn.ReadFromUDP(buffer[:]) // TODO: This is broken
- // add to handshake queue
+ if err != nil {
+ break
+ }
- device.addToHandshakeQueue(
- device.queue.handshake,
- QueueHandshakeElement{
- msgType: msgType,
- buffer: buffer,
- packet: packet,
- source: raddr,
- },
- )
- buffer = nil
+ if size < MinMessageSize {
+ continue
+ }
- case MessageCookieReplyType:
+ // check size of packet
- // TODO: Queue all the things
+ packet := buffer[:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- // verify and update peer cookie state
+ var okay bool
- if len(packet) != MessageCookieReplySize {
- return
- }
+ switch msgType {
- var reply MessageCookieReply
- reader := bytes.NewReader(packet)
- err := binary.Read(reader, binary.LittleEndian, &reply)
- if err != nil {
- logDebug.Println("Failed to decode cookie reply")
- return
- }
- device.ConsumeMessageCookieReply(&reply)
+ // check if transport
case MessageTransportType:
- // lookup key pair
+ // check size
- if len(packet) < MessageTransportSize {
- return
+ if len(packet) < MessageTransportType {
+ continue
}
+ // lookup key pair
+
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indices.Lookup(receiver)
keyPair := value.keyPair
if keyPair == nil {
- return
+ continue
}
// check key-pair expiry
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
- return
+ continue
}
- // add to peer queue
+ // create work element
peer := value.peer
elem := &QueueInboundElement{
@@ -233,11 +204,33 @@ func (device *Device) RoutineReceiveIncomming() {
device.addToInboundQueue(device.queue.decryption, elem)
device.addToInboundQueue(peer.queue.inbound, elem)
buffer = nil
+ continue
- default:
- logInfo.Println("Got unknown message from:", raddr)
+ // otherwise it is a handshake related packet
+
+ case MessageInitiationType:
+ okay = len(packet) == MessageInitiationSize
+
+ case MessageResponseType:
+ okay = len(packet) == MessageResponseSize
+
+ case MessageCookieReplyType:
+ okay = len(packet) == MessageCookieReplySize
}
- }()
+
+ if okay {
+ device.addToHandshakeQueue(
+ device.queue.handshake,
+ QueueHandshakeElement{
+ msgType: msgType,
+ buffer: buffer,
+ packet: packet,
+ source: raddr,
+ },
+ )
+ buffer = device.GetMessageBuffer()
+ }
+ }
}
}
@@ -306,154 +299,165 @@ func (device *Device) RoutineHandshake() {
return
}
- func() {
+ // handle cookie fields and ratelimiting
- // verify mac1
+ switch elem.msgType {
- if !device.mac.CheckMAC1(elem.packet) {
- logDebug.Println("Received packet with invalid mac1")
+ case MessageCookieReplyType:
+
+ // verify and update peer cookie state
+
+ var reply MessageCookieReply
+ reader := bytes.NewReader(elem.packet)
+ err := binary.Read(reader, binary.LittleEndian, &reply)
+ if err != nil {
+ logDebug.Println("Failed to decode cookie reply")
return
}
+ device.ConsumeMessageCookieReply(&reply)
+ continue
- // verify mac2
+ case MessageInitiationType, MessageResponseType:
- busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue
+ // check mac fields and ratelimit
- if busy && !device.mac.CheckMAC2(elem.packet, elem.source) {
- sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
- reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
- if err != nil {
- logError.Println("Failed to create cookie reply:", err)
- return
- }
- // TODO: Use temp
- writer := bytes.NewBuffer(elem.packet[:0])
- binary.Write(writer, binary.LittleEndian, reply)
- elem.packet = writer.Bytes()
- _, err = device.net.conn.WriteToUDP(elem.packet, elem.source)
- if err != nil {
- logDebug.Println("Failed to send cookie reply:", err)
- }
+ if !device.mac.CheckMAC1(elem.packet) {
+ logDebug.Println("Received packet with invalid mac1")
return
}
- // ratelimit
-
- // TODO: Only ratelimit when busy
+ busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue
- if !device.ratelimiter.Allow(elem.source.IP) {
- return
+ if busy {
+ if !device.mac.CheckMAC2(elem.packet, elem.source) {
+ sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
+ reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
+ if err != nil {
+ logError.Println("Failed to create cookie reply:", err)
+ return
+ }
+ writer := bytes.NewBuffer(temp[:0])
+ binary.Write(writer, binary.LittleEndian, reply)
+ _, err = device.net.conn.WriteToUDP(
+ writer.Bytes(),
+ elem.source,
+ )
+ if err != nil {
+ logDebug.Println("Failed to send cookie reply:", err)
+ }
+ continue
+ }
+ if !device.ratelimiter.Allow(elem.source.IP) {
+ continue
+ }
}
- // handle messages
+ default:
+ logError.Println("Invalid packet ended up in the handshake queue")
+ continue
+ }
- switch elem.msgType {
- case MessageInitiationType:
+ // handle handshake initation/response content
- // unmarshal
+ switch elem.msgType {
+ case MessageInitiationType:
- if len(elem.packet) != MessageInitiationSize {
- return
- }
+ // unmarshal
- var msg MessageInitiation
- reader := bytes.NewReader(elem.packet)
- err := binary.Read(reader, binary.LittleEndian, &msg)
- if err != nil {
- logError.Println("Failed to decode initiation message")
- return
- }
+ var msg MessageInitiation
+ reader := bytes.NewReader(elem.packet)
+ err := binary.Read(reader, binary.LittleEndian, &msg)
+ if err != nil {
+ logError.Println("Failed to decode initiation message")
+ continue
+ }
- // consume initiation
+ // consume initiation
- peer := device.ConsumeMessageInitiation(&msg)
- if peer == nil {
- logInfo.Println(
- "Recieved invalid initiation message from",
- elem.source.IP.String(),
- elem.source.Port,
- )
- return
- }
+ peer := device.ConsumeMessageInitiation(&msg)
+ if peer == nil {
+ logInfo.Println(
+ "Recieved invalid initiation message from",
+ elem.source.IP.String(),
+ elem.source.Port,
+ )
+ continue
+ }
- // update timers
+ // update timers
- peer.TimerAnyAuthenticatedPacketTraversal()
- peer.TimerAnyAuthenticatedPacketReceived()
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.TimerAnyAuthenticatedPacketReceived()
- // update endpoint
- // TODO: Add a race condition \s
+ // update endpoint
+ // TODO: Discover destination address also, only update on change
- peer.mutex.Lock()
- peer.endpoint = elem.source
- peer.mutex.Unlock()
+ peer.mutex.Lock()
+ peer.endpoint = elem.source
+ peer.mutex.Unlock()
- // create response
+ // create response
- response, err := device.CreateMessageResponse(peer)
- if err != nil {
- logError.Println("Failed to create response message:", err)
- return
- }
+ response, err := device.CreateMessageResponse(peer)
+ if err != nil {
+ logError.Println("Failed to create response message:", err)
+ continue
+ }
- peer.TimerEphemeralKeyCreated()
- peer.NewKeyPair()
+ peer.TimerEphemeralKeyCreated()
+ peer.NewKeyPair()
- logDebug.Println("Creating response message for", peer.String())
+ logDebug.Println("Creating response message for", peer.String())
- writer := bytes.NewBuffer(temp[:0])
- binary.Write(writer, binary.LittleEndian, response)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
+ writer := bytes.NewBuffer(temp[:0])
+ binary.Write(writer, binary.LittleEndian, response)
+ packet := writer.Bytes()
+ peer.mac.AddMacs(packet)
- // send response
+ // send response
- peer.SendBuffer(packet)
+ _, err = peer.SendBuffer(packet)
+ if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
+ }
- case MessageResponseType:
+ case MessageResponseType:
- // unmarshal
+ // unmarshal
- if len(elem.packet) != MessageResponseSize {
- return
- }
-
- var msg MessageResponse
- reader := bytes.NewReader(elem.packet)
- err := binary.Read(reader, binary.LittleEndian, &msg)
- if err != nil {
- logError.Println("Failed to decode response message")
- return
- }
+ var msg MessageResponse
+ reader := bytes.NewReader(elem.packet)
+ err := binary.Read(reader, binary.LittleEndian, &msg)
+ if err != nil {
+ logError.Println("Failed to decode response message")
+ continue
+ }
- // consume response
+ // consume response
- peer := device.ConsumeMessageResponse(&msg)
- if peer == nil {
- logInfo.Println(
- "Recieved invalid response message from",
- elem.source.IP.String(),
- elem.source.Port,
- )
- return
- }
+ peer := device.ConsumeMessageResponse(&msg)
+ if peer == nil {
+ logInfo.Println(
+ "Recieved invalid response message from",
+ elem.source.IP.String(),
+ elem.source.Port,
+ )
+ continue
+ }
- // update timers
+ peer.TimerEphemeralKeyCreated()
- peer.TimerAnyAuthenticatedPacketTraversal()
- peer.TimerAnyAuthenticatedPacketReceived()
- peer.TimerHandshakeComplete()
+ // update timers
- // derive key-pair
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.TimerAnyAuthenticatedPacketReceived()
+ peer.TimerHandshakeComplete()
- peer.NewKeyPair()
- peer.SendKeepAlive()
+ // derive key-pair
- default:
- logError.Println("Invalid message type in handshake queue")
- }
- }()
+ peer.NewKeyPair()
+ peer.SendKeepAlive()
+ }
}
}
@@ -463,6 +467,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device
logInfo := device.log.Info
+ logError := device.log.Error
logDebug := device.log.Debug
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
@@ -478,116 +483,104 @@ func (peer *Peer) RoutineSequentialReceiver() {
// process packet
- func() {
- if elem.IsDropped() {
- return
- }
-
- // check for replay
-
- if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
- return
- }
+ if elem.IsDropped() {
+ continue
+ }
- peer.TimerAnyAuthenticatedPacketTraversal()
- peer.TimerAnyAuthenticatedPacketReceived()
- peer.KeepKeyFreshReceiving()
+ // check for replay
- // check if using new key-pair
+ if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
+ continue
+ }
- kp := &peer.keyPairs
- kp.mutex.Lock()
- if kp.next == elem.keyPair {
- peer.TimerHandshakeComplete()
- kp.previous = kp.current
- kp.current = kp.next
- kp.next = nil
- }
- kp.mutex.Unlock()
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.TimerAnyAuthenticatedPacketReceived()
+ peer.KeepKeyFreshReceiving()
- // check for keep-alive
+ // check if using new key-pair
- if len(elem.packet) == 0 {
- logDebug.Println("Received keep-alive from", peer.String())
- return
- }
- peer.TimerDataReceived()
+ kp := &peer.keyPairs
+ kp.mutex.Lock()
+ if kp.next == elem.keyPair {
+ peer.TimerHandshakeComplete()
+ kp.previous = kp.current
+ kp.current = kp.next
+ kp.next = nil
+ }
+ kp.mutex.Unlock()
- // verify source and strip padding
+ // check for keep-alive
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
+ if len(elem.packet) == 0 {
+ logDebug.Println("Received keep-alive from", peer.String())
+ continue
+ }
+ peer.TimerDataReceived()
- // strip padding
+ // verify source and strip padding
- if len(elem.packet) < ipv4.HeaderLen {
- return
- }
+ switch elem.packet[0] >> 4 {
+ case ipv4.Version:
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- // TODO: check length of packet & NOT TOO SMALL either
- elem.packet = elem.packet[:length]
+ // strip padding
- // verify IPv4 source
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.routingTable.LookupIPv4(src) != peer {
- logInfo.Println("Packet with unallowed source IP from", peer.String())
- return
- }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
- case ipv6.Version:
+ elem.packet = elem.packet[:length]
- // strip padding
+ // verify IPv4 source
- if len(elem.packet) < ipv6.HeaderLen {
- return
- }
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.routingTable.LookupIPv4(src) != peer {
+ logInfo.Println("Packet with unallowed source IP from", peer.String())
+ continue
+ }
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- // TODO: check length of packet
- elem.packet = elem.packet[:length]
+ case ipv6.Version:
- // verify IPv6 source
+ // strip padding
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.routingTable.LookupIPv6(src) != peer {
- logInfo.Println("Packet with unallowed source IP from", peer.String())
- return
- }
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
- default:
- logInfo.Println("Packet with invalid IP version from", peer.String())
- return
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
}
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
- device.addToInboundQueue(device.queue.inbound, elem)
+ elem.packet = elem.packet[:length]
- // TODO: move TUN write into per peer routine
- }()
- }
-}
+ // verify IPv6 source
-func (device *Device) RoutineWriteToTUN() {
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.routingTable.LookupIPv6(src) != peer {
+ logInfo.Println("Packet with unallowed source IP from", peer.String())
+ continue
+ }
- logError := device.log.Error
- logDebug := device.log.Debug
- logDebug.Println("Routine, sequential tun writer, started")
+ default:
+ logInfo.Println("Packet with invalid IP version from", peer.String())
+ continue
+ }
- for {
- select {
- case <-device.signal.stop:
- return
- case elem := <-device.queue.inbound:
- _, err := device.tun.Write(elem.packet)
- device.PutMessageBuffer(elem.buffer)
- if err != nil {
- logError.Println("Failed to write packet to TUN device:", err)
- }
+ // write to tun
+
+ atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ _, err := device.tun.Write(elem.packet)
+ device.PutMessageBuffer(elem.buffer)
+ if err != nil {
+ logError.Println("Failed to write packet to TUN device:", err)
}
}
}
diff --git a/src/send.go b/src/send.go
index fc35732..cf1f018 100644
--- a/src/send.go
+++ b/src/send.go
@@ -168,8 +168,6 @@ func (device *Device) RoutineReadFromTUN() {
continue
}
- println(size, err)
-
elem.packet = elem.packet[:size]
// lookup peer
@@ -210,6 +208,7 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue
+ signalSend(peer.signal.handshakeReset)
addToOutboundQueue(peer.queue.nonce, elem)
elem = nil
diff --git a/src/timers.go b/src/timers.go
index 1be85f0..ab2e7ad 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"golang.org/x/crypto/blake2s"
+ "math/rand"
"sync/atomic"
"time"
)
@@ -16,12 +17,11 @@ func (peer *Peer) KeepKeyFreshSending() {
if kp == nil {
return
}
- if !kp.isInitiator {
- return
- }
nonce := atomic.LoadUint64(&kp.sendNonce)
- send := nonce > RekeyAfterMessages || time.Now().Sub(kp.created) > RekeyAfterTime
- if send {
+ if nonce > RekeyAfterMessages {
+ signalSend(peer.signal.handshakeBegin)
+ }
+ if kp.isInitiator && time.Now().Sub(kp.created) > RekeyAfterTime {
signalSend(peer.signal.handshakeBegin)
}
}
@@ -30,6 +30,7 @@ func (peer *Peer) KeepKeyFreshSending() {
*
*/
func (peer *Peer) KeepKeyFreshReceiving() {
+ // TODO: Add a guard, clear on handshake complete (clear in TimerHandshakeComplete)
kp := peer.keyPairs.Current()
if kp == nil {
return
@@ -108,7 +109,6 @@ func (peer *Peer) TimerAnyAuthenticatedPacketTraversal() {
* - First transport message under the "next" key
*/
func (peer *Peer) TimerHandshakeComplete() {
- timerStop(peer.timer.zeroAllKeys)
atomic.StoreInt64(
&peer.stats.lastHandshakeNano,
time.Now().UnixNano(),
@@ -129,10 +129,7 @@ func (peer *Peer) TimerHandshakeComplete() {
* upon failure to complete a handshake
*/
func (peer *Peer) TimerEphemeralKeyCreated() {
- if !peer.timer.pendingZeroAllKeys {
- peer.timer.pendingZeroAllKeys = true
- peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
- }
+ peer.timer.zeroAllKeys.Reset(RejectAfterTime * 3)
}
func (peer *Peer) RoutineTimerHandler() {
@@ -154,19 +151,19 @@ func (peer *Peer) RoutineTimerHandler() {
interval := atomic.LoadUint64(&peer.persistentKeepaliveInterval)
if interval > 0 {
- logDebug.Println("Sending persistent keep-alive to", peer.String())
+ logDebug.Println("Sending keep-alive to", peer.String())
peer.SendKeepAlive()
}
case <-peer.timer.keepalivePassive.C:
- logDebug.Println("Sending passive keep-alive to", peer.String())
+ logDebug.Println("Sending keep-alive to", peer.String())
peer.SendKeepAlive()
if peer.timer.needAnotherKeepalive {
peer.timer.keepalivePassive.Reset(KeepaliveTimeout)
- peer.timer.needAnotherKeepalive = true
+ peer.timer.needAnotherKeepalive = false
}
// unresponsive session
@@ -189,8 +186,6 @@ func (peer *Peer) RoutineTimerHandler() {
kp := &peer.keyPairs
kp.mutex.Lock()
- peer.timer.pendingZeroAllKeys = false
-
// unmap indecies
indices.mutex.Lock()
@@ -251,40 +246,41 @@ func (peer *Peer) RoutineHandshakeInitiator() {
return
}
- // wait for handshake
+ // set deadline
+
+ BeginHandshakes:
- deadline := time.Now().Add(MaxHandshakeAttemptTime)
+ signalClear(peer.signal.handshakeReset)
+ deadline := time.NewTimer(RekeyAttemptTime)
+
+ AttemptHandshakes:
- Loop:
for attempts := uint(1); ; attempts++ {
- // clear completed signal
+ // check if deadline reached
select {
- case <-peer.signal.handshakeCompleted:
+ case <-deadline.C:
+ logInfo.Println("Handshake negotiation timed out for:", peer.String())
+ signalSend(peer.signal.flushNonceQueue)
+ timerStop(peer.timer.keepalivePersistent)
+ break
case <-peer.signal.stop:
return
default:
}
- // check if sufficient time for retry
-
- if deadline.Before(time.Now().Add(RekeyTimeout)) {
- logInfo.Println("Handshake negotiation timed out for", peer.String())
- signalSend(peer.signal.flushNonceQueue)
- timerStop(peer.timer.keepalivePersistent)
- timerStop(peer.timer.keepalivePassive)
- break Loop
- }
+ signalClear(peer.signal.handshakeCompleted)
// create initiation message
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
logError.Println("Failed to create handshake initiation message:", err)
- break Loop
+ break AttemptHandshakes
}
- peer.TimerEphemeralKeyCreated()
+
+ jitter := time.Millisecond * time.Duration(rand.Uint32()%334)
// marshal and send
@@ -299,14 +295,14 @@ func (peer *Peer) RoutineHandshakeInitiator() {
"Failed to send handshake initiation message to",
peer.String(), ":", err,
)
- continue
+ break
}
peer.TimerAnyAuthenticatedPacketTraversal()
- // set timeout
+ // set handshake timeout
- timeout := time.NewTimer(RekeyTimeout)
+ timeout := time.NewTimer(RekeyTimeout + jitter)
logDebug.Println(
"Handshake initiation attempt",
attempts, "sent to", peer.String(),
@@ -321,15 +317,19 @@ func (peer *Peer) RoutineHandshakeInitiator() {
case <-peer.signal.handshakeCompleted:
<-timeout.C
- break Loop
+ break AttemptHandshakes
+
+ case <-peer.signal.handshakeReset:
+ <-timeout.C
+ goto BeginHandshakes
case <-timeout.C:
+ // TODO: Clear source address for peer
continue
-
}
}
- // allow new signal to be set
+ // clear signal set in the meantime
signalClear(peer.signal.handshakeBegin)
}
diff --git a/src/tun.go b/src/tun.go
index d782bd5..1c4c281 100644
--- a/src/tun.go
+++ b/src/tun.go
@@ -6,10 +6,19 @@ package main
const DefaultMTU = 1420
+type TUNEvent int
+
+const (
+ TUNEventUp = 1 << iota
+ TUNEventDown
+ TUNEventMTUUpdate
+)
+
type TUNDevice interface {
Read([]byte) (int, error) // read a packet from the device (without any additional headers)
Write([]byte) (int, error) // writes a packet to the device (without any additional headers)
- IsUp() (bool, error) // is the interface up?
MTU() (int, error) // returns the MTU of the device
Name() string // returns the current name
+ Events() chan TUNEvent // returns a constant channel of events related to the device
+ Close() error // stops the device and closes the event channel
}
diff --git a/src/tun_linux.go b/src/tun_linux.go
index d0e2f47..34f746a 100644
--- a/src/tun_linux.go
+++ b/src/tun_linux.go
@@ -16,11 +16,12 @@ import (
const CloneDevicePath = "/dev/net/tun"
type NativeTun struct {
- fd *os.File
- name string
+ fd *os.File
+ name string
+ events chan TUNEvent
}
-func (tun *NativeTun) IsUp() (bool, error) {
+func (tun *NativeTun) isUp() (bool, error) {
inter, err := net.InterfaceByName(tun.name)
return inter.Flags&net.FlagUp != 0, err
}
@@ -111,6 +112,14 @@ func (tun *NativeTun) Read(d []byte) (int, error) {
return tun.fd.Read(d)
}
+func (tun *NativeTun) Events() chan TUNEvent {
+ return tun.events
+}
+
+func (tun *NativeTun) Close() error {
+ return nil
+}
+
func CreateTUN(name string) (TUNDevice, error) {
// open clone device
@@ -146,10 +155,14 @@ func CreateTUN(name string) (TUNDevice, error) {
newName := string(ifr[:])
newName = newName[:strings.Index(newName, "\000")]
device := &NativeTun{
- fd: fd,
- name: newName,
+ fd: fd,
+ name: newName,
+ events: make(chan TUNEvent, 5),
}
+ // TODO: Wait for device to be upped
+ device.events <- TUNEventUp
+
// set default MTU
err = device.setMTU(DefaultMTU)
diff --git a/src/uapi_linux.go b/src/uapi_linux.go
index d6d78e7..fd56b5a 100644
--- a/src/uapi_linux.go
+++ b/src/uapi_linux.go
@@ -7,7 +7,6 @@ import (
"net"
"os"
"path"
- "time"
)
const (
@@ -26,9 +25,10 @@ const (
*/
type UAPIListener struct {
- listener net.Listener // unix socket listener
- connNew chan net.Conn
- connErr chan error
+ listener net.Listener // unix socket listener
+ connNew chan net.Conn
+ connErr chan error
+ inotifyFd int
}
func (l *UAPIListener) Accept() (net.Conn, error) {
@@ -106,9 +106,28 @@ func NewUAPIListener(name string) (net.Listener, error) {
// watch for deletion of socket
+ uapi.inotifyFd, err = unix.InotifyInit()
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = unix.InotifyAddWatch(
+ uapi.inotifyFd,
+ socketPath,
+ unix.IN_ATTRIB|
+ unix.IN_DELETE|
+ unix.IN_DELETE_SELF,
+ )
+
+ if err != nil {
+ return nil, err
+ }
+
go func(l *UAPIListener) {
- for ; ; time.Sleep(time.Second) {
- if _, err := os.Stat(socketPath); os.IsNotExist(err) {
+ var buff [4096]byte
+ for {
+ unix.Read(uapi.inotifyFd, buff[:])
+ if _, err := os.Lstat(socketPath); os.IsNotExist(err) {
l.connErr <- err
return
}