diff options
Diffstat (limited to 'device.go')
-rw-r--r-- | device.go | 364 |
1 files changed, 364 insertions, 0 deletions
diff --git a/device.go b/device.go new file mode 100644 index 0000000..22e0990 --- /dev/null +++ b/device.go @@ -0,0 +1,364 @@ +package main + +import ( + "runtime" + "sync" + "sync/atomic" + "time" +) + +type Device struct { + isUp AtomicBool // device is (going) up + isClosed AtomicBool // device is closed? (acting as guard) + log *Logger + + // synchronized resources (locks acquired in order) + + state struct { + mutex sync.Mutex + changing AtomicBool + current bool + } + + net struct { + mutex sync.RWMutex + bind Bind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + } + + noise struct { + mutex sync.RWMutex + privateKey NoisePrivateKey + publicKey NoisePublicKey + } + + routing struct { + mutex sync.RWMutex + table RoutingTable + } + + peers struct { + mutex sync.RWMutex + keyMap map[NoisePublicKey]*Peer + } + + // unprotected / "self-synchronising resources" + + indices IndexTable + mac CookieChecker + + rate struct { + underLoadUntil atomic.Value + limiter Ratelimiter + } + + pool struct { + messageBuffers sync.Pool + } + + queue struct { + encryption chan *QueueOutboundElement + decryption chan *QueueInboundElement + handshake chan QueueHandshakeElement + } + + signal struct { + stop Signal + } + + tun struct { + device TUNDevice + mtu int32 + } +} + +/* Converts the peer into a "zombie", which remains in the peer map, + * but processes no packets and does not exists in the routing table. + * + * Must hold: + * device.peers.mutex : exclusive lock + * device.routing : exclusive lock + */ +func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { + + // stop routing and processing of packets + + device.routing.table.RemovePeer(peer) + peer.Stop() + + // remove from peer map + + delete(device.peers.keyMap, key) +} + +func deviceUpdateState(device *Device) { + + // check if state already being updated (guard) + + if device.state.changing.Swap(true) { + return + } + + // compare to current state of device + + device.state.mutex.Lock() + + newIsUp := device.isUp.Get() + + if newIsUp == device.state.current { + device.state.changing.Set(false) + device.state.mutex.Unlock() + return + } + + // change state of device + + switch newIsUp { + case true: + if err := device.BindUpdate(); err != nil { + device.isUp.Set(false) + break + } + device.peers.mutex.Lock() + for _, peer := range device.peers.keyMap { + peer.Start() + } + device.peers.mutex.Unlock() + + case false: + device.BindClose() + device.peers.mutex.Lock() + for _, peer := range device.peers.keyMap { + peer.Stop() + } + device.peers.mutex.Unlock() + } + + // update state variables + + device.state.current = newIsUp + device.state.changing.Set(false) + device.state.mutex.Unlock() + + // check for state change in the mean time + + deviceUpdateState(device) +} + +func (device *Device) Up() { + + // closed device cannot be brought up + + if device.isClosed.Get() { + return + } + + device.state.mutex.Lock() + device.isUp.Set(true) + device.state.mutex.Unlock() + deviceUpdateState(device) +} + +func (device *Device) Down() { + device.state.mutex.Lock() + device.isUp.Set(false) + device.state.mutex.Unlock() + deviceUpdateState(device) +} + +func (device *Device) IsUnderLoad() bool { + + // check if currently under load + + now := time.Now() + underLoad := len(device.queue.handshake) >= UnderLoadQueueSize + if underLoad { + device.rate.underLoadUntil.Store(now.Add(time.Second)) + return true + } + + // check if recently under load + + until := device.rate.underLoadUntil.Load().(time.Time) + return until.After(now) +} + +func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { + + // lock required resources + + device.noise.mutex.Lock() + defer device.noise.mutex.Unlock() + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for _, peer := range device.peers.keyMap { + peer.handshake.mutex.RLock() + defer peer.handshake.mutex.RUnlock() + } + + // remove peers with matching public keys + + publicKey := sk.publicKey() + for key, peer := range device.peers.keyMap { + if peer.handshake.remoteStatic.Equals(publicKey) { + unsafeRemovePeer(device, peer, key) + } + } + + // update key material + + device.noise.privateKey = sk + device.noise.publicKey = publicKey + device.mac.Init(publicKey) + + // do static-static DH pre-computations + + rmKey := device.noise.privateKey.IsZero() + + for key, peer := range device.peers.keyMap { + + hs := &peer.handshake + + if rmKey { + hs.precomputedStaticStatic = [NoisePublicKeySize]byte{} + } else { + hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic) + } + + if isZero(hs.precomputedStaticStatic[:]) { + unsafeRemovePeer(device, peer, key) + } + } + + return nil +} + +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + device.pool.messageBuffers.Put(msg) +} + +func NewDevice(tun TUNDevice, logger *Logger) *Device { + device := new(Device) + + device.isUp.Set(false) + device.isClosed.Set(false) + + device.log = logger + device.tun.device = tun + device.peers.keyMap = make(map[NoisePublicKey]*Peer) + + // initialize anti-DoS / anti-scanning features + + device.rate.limiter.Init() + device.rate.underLoadUntil.Store(time.Time{}) + + // initialize noise & crypt-key routine + + device.indices.Init() + device.routing.table.Reset() + + // setup buffer pool + + device.pool.messageBuffers = sync.Pool{ + New: func() interface{} { + return new([MaxMessageSize]byte) + }, + } + + // create queues + + device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) + device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) + device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) + + // prepare signals + + device.signal.stop = NewSignal() + + // prepare net + + device.net.port = 0 + device.net.bind = nil + + // start workers + + for i := 0; i < runtime.NumCPU(); i += 1 { + go device.RoutineEncryption() + go device.RoutineDecryption() + go device.RoutineHandshake() + } + + go device.RoutineReadFromTUN() + go device.RoutineTUNEventReader() + go device.rate.limiter.RoutineGarbageCollector(device.signal.stop) + + return device +} + +func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { + device.peers.mutex.RLock() + defer device.peers.mutex.RUnlock() + + return device.peers.keyMap[pk] +} + +func (device *Device) RemovePeer(key NoisePublicKey) { + device.noise.mutex.Lock() + defer device.noise.mutex.Unlock() + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + // stop peer and remove from routing + + peer, ok := device.peers.keyMap[key] + if ok { + unsafeRemovePeer(device, peer, key) + } +} + +func (device *Device) RemoveAllPeers() { + + device.routing.mutex.Lock() + defer device.routing.mutex.Unlock() + + device.peers.mutex.Lock() + defer device.peers.mutex.Unlock() + + for key, peer := range device.peers.keyMap { + println("rm", peer.String()) + unsafeRemovePeer(device, peer, key) + } + + device.peers.keyMap = make(map[NoisePublicKey]*Peer) +} + +func (device *Device) Close() { + device.log.Info.Println("Device closing") + if device.isClosed.Swap(true) { + return + } + device.signal.stop.Broadcast() + device.tun.device.Close() + device.BindClose() + device.isUp.Set(false) + device.RemoveAllPeers() + device.log.Info.Println("Interface closed") +} + +func (device *Device) Wait() chan struct{} { + return device.signal.stop.Wait() +} |