diff options
Diffstat (limited to 'src/device.go')
-rw-r--r-- | src/device.go | 90 |
1 files changed, 52 insertions, 38 deletions
diff --git a/src/device.go b/src/device.go index de96f0b..4aa90e3 100644 --- a/src/device.go +++ b/src/device.go @@ -1,13 +1,10 @@ package main import ( - "errors" - "fmt" "net" "runtime" "sync" "sync/atomic" - "time" ) type Device struct { @@ -34,31 +31,45 @@ type Device struct { queue struct { encryption chan *QueueOutboundElement decryption chan *QueueInboundElement - inbound chan *QueueInboundElement handshake chan QueueHandshakeElement } signal struct { - stop chan struct{} + stop chan struct{} // halts all go routines + newUDPConn chan struct{} // a net.conn was set } - underLoad int32 // used as an atomic bool + isUp int32 // atomic bool: interface is up + underLoad int32 // atomic bool: device is under load ratelimiter Ratelimiter peers map[NoisePublicKey]*Peer mac MACStateDevice } +/* Warning: + * The caller must hold the device mutex (write lock) + */ +func removePeerUnsafe(device *Device, key NoisePublicKey) { + peer, ok := device.peers[key] + if !ok { + return + } + peer.mutex.Lock() + device.routingTable.RemovePeer(peer) + delete(device.peers, key) + peer.Close() +} + func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { device.mutex.Lock() defer device.mutex.Unlock() - // check if public key is matching any peer + // remove peers with matching public keys publicKey := sk.publicKey() - for _, peer := range device.peers { + for key, peer := range device.peers { h := &peer.handshake h.mutex.RLock() if h.remoteStatic.Equals(publicKey) { - h.mutex.RUnlock() - return errors.New("Private key matches public key of peer") + removePeerUnsafe(device, key) } h.mutex.RUnlock() } @@ -71,17 +82,19 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // do DH precomputations - isZero := device.privateKey.IsZero() + rmKey := device.privateKey.IsZero() - for _, peer := range device.peers { + for key, peer := range device.peers { h := &peer.handshake h.mutex.Lock() - if isZero { + if rmKey { h.precomputedStaticStatic = [NoisePublicKeySize]byte{} } else { h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic) + if isZero(h.precomputedStaticStatic[:]) { + removePeerUnsafe(device, key) + } } - fmt.Println(h.precomputedStaticStatic) h.mutex.Unlock() } @@ -130,11 +143,11 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) - device.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) // prepare signals device.signal.stop = make(chan struct{}) + device.signal.newUDPConn = make(chan struct{}, 1) // start workers @@ -145,33 +158,42 @@ func NewDevice(tun TUNDevice, logLevel int) *Device { } go device.RoutineBusyMonitor() - go device.RoutineMTUUpdater() - go device.RoutineWriteToTUN() go device.RoutineReadFromTUN() + go device.RoutineTUNEventReader() go device.RoutineReceiveIncomming() go device.ratelimiter.RoutineGarbageCollector(device.signal.stop) return device } -func (device *Device) RoutineMTUUpdater() { +func (device *Device) RoutineTUNEventReader() { + events := device.tun.Events() logError := device.log.Error - for ; ; time.Sleep(5 * time.Second) { - // load updated MTU - - mtu, err := device.tun.MTU() - if err != nil { - logError.Println("Failed to load updated MTU of device:", err) - continue + for event := range events { + if event&TUNEventMTUUpdate != 0 { + mtu, err := device.tun.MTU() + if err != nil { + logError.Println("Failed to load updated MTU of device:", err) + } else { + if mtu+MessageTransportSize > MaxMessageSize { + mtu = MaxMessageSize - MessageTransportSize + } + atomic.StoreInt32(&device.mtu, int32(mtu)) + } } - // upper bound of mtu + if event&TUNEventUp != 0 { + println("handle 1") + atomic.StoreInt32(&device.isUp, AtomicTrue) + updateUDPConn(device) + println("handle 2", device.net.conn) + } - if mtu+MessageTransportSize > MaxMessageSize { - mtu = MaxMessageSize - MessageTransportSize + if event&TUNEventDown != 0 { + atomic.StoreInt32(&device.isUp, AtomicFalse) + closeUDPConn(device) } - atomic.StoreInt32(&device.mtu, int32(mtu)) } } @@ -184,15 +206,7 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { func (device *Device) RemovePeer(key NoisePublicKey) { device.mutex.Lock() defer device.mutex.Unlock() - - peer, ok := device.peers[key] - if !ok { - return - } - peer.mutex.Lock() - device.routingTable.RemovePeer(peer) - delete(device.peers, key) - peer.Close() + removePeerUnsafe(device, key) } func (device *Device) RemoveAllPeers() { |