diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-03-03 04:04:41 +0100 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2019-03-03 05:00:40 +0100 |
commit | 69f0fe67b63d90e523a5a1241fb1b46c2e8dbe03 (patch) | |
tree | 1ef86da3242afde462dcadb7241bb09f499d5bd7 /device/device.go | |
parent | d435be35cac49af9367b2005d831d55e570c4b1b (diff) |
global: begin modularization
Diffstat (limited to 'device/device.go')
-rw-r--r-- | device/device.go | 396 |
1 files changed, 396 insertions, 0 deletions
diff --git a/device/device.go b/device/device.go new file mode 100644 index 0000000..d6c96d6 --- /dev/null +++ b/device/device.go @@ -0,0 +1,396 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "golang.zx2c4.com/wireguard/ratelimiter" + "golang.zx2c4.com/wireguard/tun" + "runtime" + "sync" + "sync/atomic" + "time" +) + +const ( + DeviceRoutineNumberPerCPU = 3 + DeviceRoutineNumberAdditional = 2 +) + +type Device struct { + isUp AtomicBool // device is (going) up + isClosed AtomicBool // device is closed? (acting as guard) + log *Logger + + // synchronized resources (locks acquired in order) + + state struct { + starting sync.WaitGroup + stopping sync.WaitGroup + sync.Mutex + changing AtomicBool + current bool + } + + net struct { + starting sync.WaitGroup + stopping sync.WaitGroup + sync.RWMutex + bind Bind // bind interface + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) + } + + staticIdentity struct { + sync.RWMutex + privateKey NoisePrivateKey + publicKey NoisePublicKey + } + + peers struct { + sync.RWMutex + keyMap map[NoisePublicKey]*Peer + } + + // unprotected / "self-synchronising resources" + + allowedips AllowedIPs + indexTable IndexTable + cookieChecker CookieChecker + + rate struct { + underLoadUntil atomic.Value + limiter ratelimiter.Ratelimiter + } + + pool struct { + messageBufferPool *sync.Pool + messageBufferReuseChan chan *[MaxMessageSize]byte + inboundElementPool *sync.Pool + inboundElementReuseChan chan *QueueInboundElement + outboundElementPool *sync.Pool + outboundElementReuseChan chan *QueueOutboundElement + } + + queue struct { + encryption chan *QueueOutboundElement + decryption chan *QueueInboundElement + handshake chan QueueHandshakeElement + } + + signals struct { + stop chan struct{} + } + + tun struct { + device tun.TUNDevice + mtu int32 + } +} + +/* Converts the peer into a "zombie", which remains in the peer map, + * but processes no packets and does not exists in the routing table. + * + * Must hold device.peers.Mutex + */ +func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) { + + // stop routing and processing of packets + + device.allowedips.RemoveByPeer(peer) + peer.Stop() + + // remove from peer map + + delete(device.peers.keyMap, key) +} + +func deviceUpdateState(device *Device) { + + // check if state already being updated (guard) + + if device.state.changing.Swap(true) { + return + } + + // compare to current state of device + + device.state.Lock() + + newIsUp := device.isUp.Get() + + if newIsUp == device.state.current { + device.state.changing.Set(false) + device.state.Unlock() + return + } + + // change state of device + + switch newIsUp { + case true: + if err := device.BindUpdate(); err != nil { + device.isUp.Set(false) + break + } + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Start() + if peer.persistentKeepaliveInterval > 0 { + peer.SendKeepalive() + } + } + device.peers.RUnlock() + + case false: + device.BindClose() + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Stop() + } + device.peers.RUnlock() + } + + // update state variables + + device.state.current = newIsUp + device.state.changing.Set(false) + device.state.Unlock() + + // check for state change in the mean time + + deviceUpdateState(device) +} + +func (device *Device) Up() { + + // closed device cannot be brought up + + if device.isClosed.Get() { + return + } + + device.isUp.Set(true) + deviceUpdateState(device) +} + +func (device *Device) Down() { + device.isUp.Set(false) + deviceUpdateState(device) +} + +func (device *Device) IsUnderLoad() bool { + + // check if currently under load + + now := time.Now() + underLoad := len(device.queue.handshake) >= UnderLoadQueueSize + if underLoad { + device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime)) + return true + } + + // check if recently under load + + until := device.rate.underLoadUntil.Load().(time.Time) + return until.After(now) +} + +func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { + + // lock required resources + + device.staticIdentity.Lock() + defer device.staticIdentity.Unlock() + + device.peers.Lock() + defer device.peers.Unlock() + + for _, peer := range device.peers.keyMap { + peer.handshake.mutex.RLock() + defer peer.handshake.mutex.RUnlock() + } + + // remove peers with matching public keys + + publicKey := sk.publicKey() + for key, peer := range device.peers.keyMap { + if peer.handshake.remoteStatic.Equals(publicKey) { + unsafeRemovePeer(device, peer, key) + } + } + + // update key material + + device.staticIdentity.privateKey = sk + device.staticIdentity.publicKey = publicKey + device.cookieChecker.Init(publicKey) + + // do static-static DH pre-computations + + rmKey := device.staticIdentity.privateKey.IsZero() + + for key, peer := range device.peers.keyMap { + + handshake := &peer.handshake + + if rmKey { + handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{} + } else { + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) + } + + if isZero(handshake.precomputedStaticStatic[:]) { + unsafeRemovePeer(device, peer, key) + } + } + + return nil +} + +func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device { + device := new(Device) + + device.isUp.Set(false) + device.isClosed.Set(false) + + device.log = logger + + device.tun.device = tunDevice + mtu, err := device.tun.device.MTU() + if err != nil { + logger.Error.Println("Trouble determining MTU, assuming default:", err) + mtu = DefaultMTU + } + device.tun.mtu = int32(mtu) + + device.peers.keyMap = make(map[NoisePublicKey]*Peer) + + device.rate.limiter.Init() + device.rate.underLoadUntil.Store(time.Time{}) + + device.indexTable.Init() + device.allowedips.Reset() + + device.PopulatePools() + + // create queues + + device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) + device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) + device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) + + // prepare signals + + device.signals.stop = make(chan struct{}) + + // prepare net + + device.net.port = 0 + device.net.bind = nil + + // start workers + + cpus := runtime.NumCPU() + device.state.starting.Wait() + device.state.stopping.Wait() + device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) + device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) + for i := 0; i < cpus; i += 1 { + go device.RoutineEncryption() + go device.RoutineDecryption() + go device.RoutineHandshake() + } + + go device.RoutineReadFromTUN() + go device.RoutineTUNEventReader() + + device.state.starting.Wait() + + return device +} + +func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { + device.peers.RLock() + defer device.peers.RUnlock() + + return device.peers.keyMap[pk] +} + +func (device *Device) RemovePeer(key NoisePublicKey) { + device.peers.Lock() + defer device.peers.Unlock() + + // stop peer and remove from routing + + peer, ok := device.peers.keyMap[key] + if ok { + unsafeRemovePeer(device, peer, key) + } +} + +func (device *Device) RemoveAllPeers() { + device.peers.Lock() + defer device.peers.Unlock() + + for key, peer := range device.peers.keyMap { + unsafeRemovePeer(device, peer, key) + } + + device.peers.keyMap = make(map[NoisePublicKey]*Peer) +} + +func (device *Device) FlushPacketQueues() { + for { + select { + case elem, ok := <-device.queue.decryption: + if ok { + elem.Drop() + } + case elem, ok := <-device.queue.encryption: + if ok { + elem.Drop() + } + case <-device.queue.handshake: + default: + return + } + } + +} + +func (device *Device) Close() { + if device.isClosed.Swap(true) { + return + } + + device.state.starting.Wait() + + device.log.Info.Println("Device closing") + device.state.changing.Set(true) + device.state.Lock() + defer device.state.Unlock() + + device.tun.device.Close() + device.BindClose() + + device.isUp.Set(false) + + close(device.signals.stop) + + device.RemoveAllPeers() + + device.state.stopping.Wait() + device.FlushPacketQueues() + + device.rate.limiter.Close() + + device.state.changing.Set(false) + device.log.Info.Println("Interface closed") +} + +func (device *Device) Wait() chan struct{} { + return device.signals.stop +} |