summaryrefslogtreecommitdiffhomepage
path: root/src/device.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/device.go')
-rw-r--r--src/device.go90
1 files changed, 52 insertions, 38 deletions
diff --git a/src/device.go b/src/device.go
index de96f0b..4aa90e3 100644
--- a/src/device.go
+++ b/src/device.go
@@ -1,13 +1,10 @@
package main
import (
- "errors"
- "fmt"
"net"
"runtime"
"sync"
"sync/atomic"
- "time"
)
type Device struct {
@@ -34,31 +31,45 @@ type Device struct {
queue struct {
encryption chan *QueueOutboundElement
decryption chan *QueueInboundElement
- inbound chan *QueueInboundElement
handshake chan QueueHandshakeElement
}
signal struct {
- stop chan struct{}
+ stop chan struct{} // halts all go routines
+ newUDPConn chan struct{} // a net.conn was set
}
- underLoad int32 // used as an atomic bool
+ isUp int32 // atomic bool: interface is up
+ underLoad int32 // atomic bool: device is under load
ratelimiter Ratelimiter
peers map[NoisePublicKey]*Peer
mac MACStateDevice
}
+/* Warning:
+ * The caller must hold the device mutex (write lock)
+ */
+func removePeerUnsafe(device *Device, key NoisePublicKey) {
+ peer, ok := device.peers[key]
+ if !ok {
+ return
+ }
+ peer.mutex.Lock()
+ device.routingTable.RemovePeer(peer)
+ delete(device.peers, key)
+ peer.Close()
+}
+
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
device.mutex.Lock()
defer device.mutex.Unlock()
- // check if public key is matching any peer
+ // remove peers with matching public keys
publicKey := sk.publicKey()
- for _, peer := range device.peers {
+ for key, peer := range device.peers {
h := &peer.handshake
h.mutex.RLock()
if h.remoteStatic.Equals(publicKey) {
- h.mutex.RUnlock()
- return errors.New("Private key matches public key of peer")
+ removePeerUnsafe(device, key)
}
h.mutex.RUnlock()
}
@@ -71,17 +82,19 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do DH precomputations
- isZero := device.privateKey.IsZero()
+ rmKey := device.privateKey.IsZero()
- for _, peer := range device.peers {
+ for key, peer := range device.peers {
h := &peer.handshake
h.mutex.Lock()
- if isZero {
+ if rmKey {
h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
+ if isZero(h.precomputedStaticStatic[:]) {
+ removePeerUnsafe(device, key)
+ }
}
- fmt.Println(h.precomputedStaticStatic)
h.mutex.Unlock()
}
@@ -130,11 +143,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
- device.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signals
device.signal.stop = make(chan struct{})
+ device.signal.newUDPConn = make(chan struct{}, 1)
// start workers
@@ -145,33 +158,42 @@ func NewDevice(tun TUNDevice, logLevel int) *Device {
}
go device.RoutineBusyMonitor()
- go device.RoutineMTUUpdater()
- go device.RoutineWriteToTUN()
go device.RoutineReadFromTUN()
+ go device.RoutineTUNEventReader()
go device.RoutineReceiveIncomming()
go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
return device
}
-func (device *Device) RoutineMTUUpdater() {
+func (device *Device) RoutineTUNEventReader() {
+ events := device.tun.Events()
logError := device.log.Error
- for ; ; time.Sleep(5 * time.Second) {
- // load updated MTU
-
- mtu, err := device.tun.MTU()
- if err != nil {
- logError.Println("Failed to load updated MTU of device:", err)
- continue
+ for event := range events {
+ if event&TUNEventMTUUpdate != 0 {
+ mtu, err := device.tun.MTU()
+ if err != nil {
+ logError.Println("Failed to load updated MTU of device:", err)
+ } else {
+ if mtu+MessageTransportSize > MaxMessageSize {
+ mtu = MaxMessageSize - MessageTransportSize
+ }
+ atomic.StoreInt32(&device.mtu, int32(mtu))
+ }
}
- // upper bound of mtu
+ if event&TUNEventUp != 0 {
+ println("handle 1")
+ atomic.StoreInt32(&device.isUp, AtomicTrue)
+ updateUDPConn(device)
+ println("handle 2", device.net.conn)
+ }
- if mtu+MessageTransportSize > MaxMessageSize {
- mtu = MaxMessageSize - MessageTransportSize
+ if event&TUNEventDown != 0 {
+ atomic.StoreInt32(&device.isUp, AtomicFalse)
+ closeUDPConn(device)
}
- atomic.StoreInt32(&device.mtu, int32(mtu))
}
}
@@ -184,15 +206,7 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
func (device *Device) RemovePeer(key NoisePublicKey) {
device.mutex.Lock()
defer device.mutex.Unlock()
-
- peer, ok := device.peers[key]
- if !ok {
- return
- }
- peer.mutex.Lock()
- device.routingTable.RemovePeer(peer)
- delete(device.peers, key)
- peer.Close()
+ removePeerUnsafe(device, key)
}
func (device *Device) RemoveAllPeers() {