diff options
Diffstat (limited to 'device/peer.go')
-rw-r--r-- | device/peer.go | 270 |
1 files changed, 270 insertions, 0 deletions
diff --git a/device/peer.go b/device/peer.go new file mode 100644 index 0000000..af3ef9d --- /dev/null +++ b/device/peer.go @@ -0,0 +1,270 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "encoding/base64" + "errors" + "fmt" + "sync" + "time" +) + +const ( + PeerRoutineNumber = 3 +) + +type Peer struct { + isRunning AtomicBool + sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer + keypairs Keypairs + handshake Handshake + device *Device + endpoint Endpoint + persistentKeepaliveInterval uint16 + + // This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly + stats struct { + txBytes uint64 // bytes send to peer (endpoint) + rxBytes uint64 // bytes received from peer + lastHandshakeNano int64 // nano seconds since epoch + } + + timers struct { + retransmitHandshake *Timer + sendKeepalive *Timer + newHandshake *Timer + zeroKeyMaterial *Timer + persistentKeepalive *Timer + handshakeAttempts uint32 + needAnotherKeepalive AtomicBool + sentLastMinuteHandshake AtomicBool + } + + signals struct { + newKeypairArrived chan struct{} + flushNonceQueue chan struct{} + } + + queue struct { + nonce chan *QueueOutboundElement // nonce / pre-handshake queue + outbound chan *QueueOutboundElement // sequential ordering of work + inbound chan *QueueInboundElement // sequential ordering of work + packetInNonceQueueIsAwaitingKey AtomicBool + } + + routines struct { + sync.Mutex // held when stopping / starting routines + starting sync.WaitGroup // routines pending start + stopping sync.WaitGroup // routines pending stop + stop chan struct{} // size 0, stop all go routines in peer + } + + cookieGenerator CookieGenerator +} + +func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { + + if device.isClosed.Get() { + return nil, errors.New("device closed") + } + + // lock resources + + device.staticIdentity.RLock() + defer device.staticIdentity.RUnlock() + + device.peers.Lock() + defer device.peers.Unlock() + + // check if over limit + + if len(device.peers.keyMap) >= MaxPeers { + return nil, errors.New("too many peers") + } + + // create peer + + peer := new(Peer) + peer.Lock() + defer peer.Unlock() + + peer.cookieGenerator.Init(pk) + peer.device = device + peer.isRunning.Set(false) + + // map public key + + _, ok := device.peers.keyMap[pk] + if ok { + return nil, errors.New("adding existing peer") + } + device.peers.keyMap[pk] = peer + + // pre-compute DH + + handshake := &peer.handshake + handshake.mutex.Lock() + handshake.remoteStatic = pk + handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) + handshake.mutex.Unlock() + + // reset endpoint + + peer.endpoint = nil + + // start peer + + if peer.device.isUp.Get() { + peer.Start() + } + + return peer, nil +} + +func (peer *Peer) SendBuffer(buffer []byte) error { + peer.device.net.RLock() + defer peer.device.net.RUnlock() + + if peer.device.net.bind == nil { + return errors.New("no bind") + } + + peer.RLock() + defer peer.RUnlock() + + if peer.endpoint == nil { + return errors.New("no known endpoint for peer") + } + + return peer.device.net.bind.Send(buffer, peer.endpoint) +} + +func (peer *Peer) String() string { + base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) + abbreviatedKey := "invalid" + if len(base64Key) == 44 { + abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43] + } + return fmt.Sprintf("peer(%s)", abbreviatedKey) +} + +func (peer *Peer) Start() { + + // should never start a peer on a closed device + + if peer.device.isClosed.Get() { + return + } + + // prevent simultaneous start/stop operations + + peer.routines.Lock() + defer peer.routines.Unlock() + + if peer.isRunning.Get() { + return + } + + device := peer.device + device.log.Debug.Println(peer, "- Starting...") + + // reset routine state + + peer.routines.starting.Wait() + peer.routines.stopping.Wait() + peer.routines.stop = make(chan struct{}) + peer.routines.starting.Add(PeerRoutineNumber) + peer.routines.stopping.Add(PeerRoutineNumber) + + // prepare queues + + peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) + peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + + peer.timersInit() + peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + peer.signals.newKeypairArrived = make(chan struct{}, 1) + peer.signals.flushNonceQueue = make(chan struct{}, 1) + + // wait for routines to start + + go peer.RoutineNonce() + go peer.RoutineSequentialSender() + go peer.RoutineSequentialReceiver() + + peer.routines.starting.Wait() + peer.isRunning.Set(true) +} + +func (peer *Peer) ZeroAndFlushAll() { + device := peer.device + + // clear key pairs + + keypairs := &peer.keypairs + keypairs.Lock() + device.DeleteKeypair(keypairs.previous) + device.DeleteKeypair(keypairs.current) + device.DeleteKeypair(keypairs.next) + keypairs.previous = nil + keypairs.current = nil + keypairs.next = nil + keypairs.Unlock() + + // clear handshake state + + handshake := &peer.handshake + handshake.mutex.Lock() + device.indexTable.Delete(handshake.localIndex) + handshake.Clear() + handshake.mutex.Unlock() + + peer.FlushNonceQueue() +} + +func (peer *Peer) Stop() { + + // prevent simultaneous start/stop operations + + if !peer.isRunning.Swap(false) { + return + } + + peer.routines.starting.Wait() + + peer.routines.Lock() + defer peer.routines.Unlock() + + peer.device.log.Debug.Println(peer, "- Stopping...") + + peer.timersStop() + + // stop & wait for ongoing peer routines + + close(peer.routines.stop) + peer.routines.stopping.Wait() + + // close queues + + close(peer.queue.nonce) + close(peer.queue.outbound) + close(peer.queue.inbound) + + peer.ZeroAndFlushAll() +} + +var roamingDisabled bool + +func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { + if roamingDisabled { + return + } + peer.Lock() + peer.endpoint = endpoint + peer.Unlock() +} |