summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2018-02-02 16:40:14 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2018-02-02 16:40:14 +0100
commit029410b118f079d77fa448cf56a97b949faee126 (patch)
tree5c9ecf509601b3abffe36094b3b228b87b7d8b92
parent1e42b1402261d15b87b1b5871f7bc51342b46e34 (diff)
Rework of entire locking system
Locking on the Device instance is now much more fined-grained, seperating out the fields into "resources" st. most common interactions only require a small number.
-rw-r--r--src/conn.go17
-rw-r--r--src/device.go309
-rw-r--r--src/noise_helpers.go7
-rw-r--r--src/noise_protocol.go23
-rw-r--r--src/peer.go63
-rw-r--r--src/receive.go14
-rw-r--r--src/send.go8
-rw-r--r--src/timers.go4
-rw-r--r--src/tun_linux.go4
-rw-r--r--src/uapi.go146
10 files changed, 371 insertions, 224 deletions
diff --git a/src/conn.go b/src/conn.go
index c2f5dee..fb30ec2 100644
--- a/src/conn.go
+++ b/src/conn.go
@@ -65,12 +65,12 @@ func unsafeCloseBind(device *Device) error {
}
func (device *Device) BindUpdate() error {
- device.mutex.Lock()
- defer device.mutex.Unlock()
- netc := &device.net
- netc.mutex.Lock()
- defer netc.mutex.Unlock()
+ device.net.mutex.Lock()
+ defer device.net.mutex.Unlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
// close existing sockets
@@ -85,6 +85,7 @@ func (device *Device) BindUpdate() error {
// bind to new port
var err error
+ netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port)
if err != nil {
netc.bind = nil
@@ -100,12 +101,12 @@ func (device *Device) BindUpdate() error {
// clear cached source addresses
- for _, peer := range device.peers {
+ for _, peer := range device.peers.keyMap {
peer.mutex.Lock()
+ defer peer.mutex.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
- peer.mutex.Unlock()
}
// start receiving routines
@@ -120,10 +121,8 @@ func (device *Device) BindUpdate() error {
}
func (device *Device) BindClose() error {
- device.mutex.Lock()
device.net.mutex.Lock()
err := unsafeCloseBind(device)
device.net.mutex.Unlock()
- device.mutex.Unlock()
return err
}
diff --git a/src/device.go b/src/device.go
index f1c09c6..0317b60 100644
--- a/src/device.go
+++ b/src/device.go
@@ -9,106 +9,170 @@ import (
)
type Device struct {
- isUp AtomicBool // device is (going) up
- isClosed AtomicBool // device is closed? (acting as guard)
- log *Logger // collection of loggers for levels
- idCounter uint // for assigning debug ids to peers
- fwMark uint32
- tun struct {
- device TUNDevice
- mtu int32
- }
+ isUp AtomicBool // device is (going) up
+ isClosed AtomicBool // device is closed? (acting as guard)
+ log *Logger
+
+ // synchronized resources (locks acquired in order)
+
state struct {
mutex deadlock.Mutex
changing AtomicBool
current bool
}
- pool struct {
- messageBuffers sync.Pool
- }
+
net struct {
mutex deadlock.RWMutex
bind Bind // bind interface
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
}
- mutex deadlock.RWMutex
- privateKey NoisePrivateKey
- publicKey NoisePublicKey
- routingTable RoutingTable
- indices IndexTable
- queue struct {
+
+ noise struct {
+ mutex deadlock.RWMutex
+ privateKey NoisePrivateKey
+ publicKey NoisePublicKey
+ }
+
+ routing struct {
+ mutex deadlock.RWMutex
+ table RoutingTable
+ }
+
+ peers struct {
+ mutex deadlock.RWMutex
+ keyMap map[NoisePublicKey]*Peer
+ }
+
+ // unprotected / "self-synchronising resources"
+
+ indices IndexTable
+ mac CookieChecker
+
+ rate struct {
+ underLoadUntil atomic.Value
+ limiter Ratelimiter
+ }
+
+ pool struct {
+ messageBuffers sync.Pool
+ }
+
+ queue struct {
encryption chan *QueueOutboundElement
decryption chan *QueueInboundElement
handshake chan QueueHandshakeElement
}
+
signal struct {
stop Signal
}
- underLoadUntil atomic.Value
- ratelimiter Ratelimiter
- peers map[NoisePublicKey]*Peer
- mac CookieChecker
+
+ tun struct {
+ device TUNDevice
+ mtu int32
+ }
}
-func deviceUpdateState(device *Device) {
+/* Converts the peer into a "zombie", which remains in the peer map,
+ * but processes no packets and does not exists in the routing table.
+ *
+ * Must hold:
+ * device.peers.mutex : exclusive lock
+ * device.routing : exclusive lock
+ */
+func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
- // check if state already being updated (guard)
+ // stop routing and processing of packets
- if device.state.changing.Swap(true) {
- return
+ device.routing.table.RemovePeer(peer)
+ peer.Stop()
+
+ // clean index table
+
+ kp := &peer.keyPairs
+ kp.mutex.Lock()
+
+ if kp.previous != nil {
+ device.indices.Delete(kp.previous.localIndex)
}
- // compare to current state of device
+ if kp.current != nil {
+ device.indices.Delete(kp.current.localIndex)
+ }
- device.state.mutex.Lock()
+ if kp.next != nil {
+ device.indices.Delete(kp.next.localIndex)
+ }
- newIsUp := device.isUp.Get()
+ kp.previous = nil
+ kp.current = nil
+ kp.next = nil
+ kp.mutex.Unlock()
- if newIsUp == device.state.current {
- device.state.mutex.Unlock()
- device.state.changing.Set(false)
+ // remove from peer map
+
+ delete(device.peers.keyMap, key)
+}
+
+func deviceUpdateState(device *Device) {
+
+ // check if state already being updated (guard)
+
+ if device.state.changing.Swap(true) {
return
}
- device.state.mutex.Unlock()
+ func() {
- // change state of device
+ // compare to current state of device
- switch newIsUp {
- case true:
+ device.state.mutex.Lock()
+ defer device.state.mutex.Unlock()
- // start listener
+ newIsUp := device.isUp.Get()
- if err := device.BindUpdate(); err != nil {
- device.isUp.Set(false)
- break
+ if newIsUp == device.state.current {
+ device.state.changing.Set(false)
+ return
}
- // start every peer
+ // change state of device
- for _, peer := range device.peers {
- peer.Start()
- }
+ switch newIsUp {
+ case true:
+ if err := device.BindUpdate(); err != nil {
+ device.isUp.Set(false)
+ break
+ }
- case false:
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
- // stop listening
+ for _, peer := range device.peers.keyMap {
+ peer.Start()
+ }
- device.BindClose()
+ case false:
+ device.BindClose()
- // stop every peer
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
- for _, peer := range device.peers {
- peer.Stop()
+ for _, peer := range device.peers.keyMap {
+ println("stopping peer")
+ peer.Stop()
+ }
}
- }
- // update state variables
- // and check for state change in the mean time
+ // update state variables
+
+ device.state.current = newIsUp
+ device.state.changing.Set(false)
+ }()
+
+ // check for state change in the mean time
- device.state.current = newIsUp
- device.state.changing.Set(false)
deviceUpdateState(device)
}
@@ -133,18 +197,6 @@ func (device *Device) Down() {
deviceUpdateState(device)
}
-/* Warning:
- * The caller must hold the device mutex (write lock)
- */
-func removePeerUnsafe(device *Device, key NoisePublicKey) {
- peer, ok := device.peers[key]
- if !ok {
- return
- }
- device.routingTable.RemovePeer(peer)
- delete(device.peers, key)
-}
-
func (device *Device) IsUnderLoad() bool {
// check if currently under load
@@ -152,54 +204,66 @@ func (device *Device) IsUnderLoad() bool {
now := time.Now()
underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
if underLoad {
- device.underLoadUntil.Store(now.Add(time.Second))
+ device.rate.underLoadUntil.Store(now.Add(time.Second))
return true
}
// check if recently under load
- until := device.underLoadUntil.Load().(time.Time)
+ until := device.rate.underLoadUntil.Load().(time.Time)
return until.After(now)
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
- device.mutex.Lock()
- defer device.mutex.Unlock()
+
+ // lock required resources
+
+ device.noise.mutex.Lock()
+ defer device.noise.mutex.Unlock()
+
+ device.routing.mutex.Lock()
+ defer device.routing.mutex.Unlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
+
+ for _, peer := range device.peers.keyMap {
+ peer.handshake.mutex.RLock()
+ defer peer.handshake.mutex.RUnlock()
+ }
// remove peers with matching public keys
publicKey := sk.publicKey()
- for key, peer := range device.peers {
- h := &peer.handshake
- h.mutex.RLock()
- if h.remoteStatic.Equals(publicKey) {
- removePeerUnsafe(device, key)
+ for key, peer := range device.peers.keyMap {
+ if peer.handshake.remoteStatic.Equals(publicKey) {
+ unsafeRemovePeer(device, peer, key)
}
- h.mutex.RUnlock()
}
// update key material
- device.privateKey = sk
- device.publicKey = publicKey
+ device.noise.privateKey = sk
+ device.noise.publicKey = publicKey
device.mac.Init(publicKey)
- // do DH pre-computations
+ // do static-static DH pre-computations
+
+ rmKey := device.noise.privateKey.IsZero()
- rmKey := device.privateKey.IsZero()
+ for key, peer := range device.peers.keyMap {
+
+ hs := &peer.handshake
- for key, peer := range device.peers {
- h := &peer.handshake
- h.mutex.Lock()
if rmKey {
- h.precomputedStaticStatic = [NoisePublicKeySize]byte{}
+ hs.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
- h.precomputedStaticStatic = device.privateKey.sharedSecret(h.remoteStatic)
- if isZero(h.precomputedStaticStatic[:]) {
- removePeerUnsafe(device, key)
- }
+ hs.precomputedStaticStatic = device.noise.privateKey.sharedSecret(hs.remoteStatic)
+ }
+
+ if isZero(hs.precomputedStaticStatic[:]) {
+ unsafeRemovePeer(device, peer, key)
}
- h.mutex.Unlock()
}
return nil
@@ -215,21 +279,23 @@ func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
func NewDevice(tun TUNDevice, logger *Logger) *Device {
device := new(Device)
- device.mutex.Lock()
- defer device.mutex.Unlock()
device.isUp.Set(false)
device.isClosed.Set(false)
device.log = logger
- device.peers = make(map[NoisePublicKey]*Peer)
device.tun.device = tun
+ device.peers.keyMap = make(map[NoisePublicKey]*Peer)
- device.indices.Init()
- device.ratelimiter.Init()
+ // initialize anti-DoS / anti-scanning features
+
+ device.rate.limiter.Init()
+ device.rate.underLoadUntil.Store(time.Time{})
- device.routingTable.Reset()
- device.underLoadUntil.Store(time.Time{})
+ // initialize noise & crypt-key routine
+
+ device.indices.Init()
+ device.routing.table.Reset()
// setup buffer pool
@@ -264,36 +330,50 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
- go device.ratelimiter.RoutineGarbageCollector(device.signal.stop)
+ go device.rate.limiter.RoutineGarbageCollector(device.signal.stop)
return device
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
- device.mutex.RLock()
- defer device.mutex.RUnlock()
- return device.peers[pk]
+ device.peers.mutex.RLock()
+ defer device.peers.mutex.RUnlock()
+
+ return device.peers.keyMap[pk]
}
func (device *Device) RemovePeer(key NoisePublicKey) {
- device.mutex.Lock()
- defer device.mutex.Unlock()
- removePeerUnsafe(device, key)
+ device.noise.mutex.Lock()
+ defer device.noise.mutex.Unlock()
+
+ device.routing.mutex.Lock()
+ defer device.routing.mutex.Unlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
+
+ // stop peer and remove from routing
+
+ peer, ok := device.peers.keyMap[key]
+ if ok {
+ unsafeRemovePeer(device, peer, key)
+ }
}
func (device *Device) RemoveAllPeers() {
- device.mutex.Lock()
- defer device.mutex.Unlock()
- for key, peer := range device.peers {
- peer.Stop()
- peer, ok := device.peers[key]
- if !ok {
- return
- }
- device.routingTable.RemovePeer(peer)
- delete(device.peers, key)
+ device.routing.mutex.Lock()
+ defer device.routing.mutex.Unlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
+
+ for key, peer := range device.peers.keyMap {
+ println("rm", peer.String())
+ unsafeRemovePeer(device, peer, key)
}
+
+ device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}
func (device *Device) Close() {
@@ -305,7 +385,6 @@ func (device *Device) Close() {
device.tun.device.Close()
device.BindClose()
device.isUp.Set(false)
- println("remove")
device.RemoveAllPeers()
device.log.Info.Println("Interface closed")
}
diff --git a/src/noise_helpers.go b/src/noise_helpers.go
index 24302c0..1e2de5f 100644
--- a/src/noise_helpers.go
+++ b/src/noise_helpers.go
@@ -3,6 +3,7 @@ package main
import (
"crypto/hmac"
"crypto/rand"
+ "crypto/subtle"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/curve25519"
"hash"
@@ -58,11 +59,11 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
}
func isZero(val []byte) bool {
- var acc byte
+ acc := 1
for _, b := range val {
- acc |= b
+ acc &= subtle.ConstantTimeByteEq(b, 0)
}
- return acc == 0
+ return acc == 1
}
func setZero(arr []byte) {
diff --git a/src/noise_protocol.go b/src/noise_protocol.go
index 2f9e1d5..d620a0d 100644
--- a/src/noise_protocol.go
+++ b/src/noise_protocol.go
@@ -137,6 +137,10 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
+
+ device.noise.mutex.Lock()
+ defer device.noise.mutex.Unlock()
+
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
@@ -187,7 +191,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
ss[:],
)
aead, _ := chacha20poly1305.New(key[:])
- aead.Seal(msg.Static[:0], ZeroNonce[:], device.publicKey[:], handshake.hash[:])
+ aead.Seal(msg.Static[:0], ZeroNonce[:], device.noise.publicKey[:], handshake.hash[:])
}()
handshake.mixHash(msg.Static[:])
@@ -212,16 +216,19 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
}
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
- if msg.Type != MessageInitiationType {
- return nil
- }
-
var (
hash [blake2s.Size]byte
chainKey [blake2s.Size]byte
)
- mixHash(&hash, &InitialHash, device.publicKey[:])
+ if msg.Type != MessageInitiationType {
+ return nil
+ }
+
+ device.noise.mutex.RLock()
+ defer device.noise.mutex.RUnlock()
+
+ mixHash(&hash, &InitialHash, device.noise.publicKey[:])
mixHash(&hash, &hash, msg.Ephemeral[:])
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
@@ -231,7 +238,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
var peerPK NoisePublicKey
func() {
var key [chacha20poly1305.KeySize]byte
- ss := device.privateKey.sharedSecret(msg.Ephemeral)
+ ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
@@ -407,7 +414,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
}()
func() {
- ss := device.privateKey.sharedSecret(msg.Ephemeral)
+ ss := device.noise.privateKey.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()
diff --git a/src/peer.go b/src/peer.go
index 5ad4511..3b8f7cc 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -14,7 +14,6 @@ const (
)
type Peer struct {
- id uint
isRunning AtomicBool
mutex deadlock.RWMutex
persistentKeepaliveInterval uint64
@@ -22,17 +21,20 @@ type Peer struct {
handshake Handshake
device *Device
endpoint Endpoint
- stats struct {
+
+ stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch
}
+
time struct {
mutex deadlock.RWMutex
lastSend time.Time // last send message
lastHandshake time.Time // last completed handshake
nextKeepalive time.Time
}
+
signal struct {
newKeyPair Signal // size 1, new key pair was generated
handshakeCompleted Signal // size 1, handshake completed
@@ -41,7 +43,9 @@ type Peer struct {
messageSend Signal // size 1, message was send to peer
messageReceived Signal // size 1, authenticated message recv
}
+
timer struct {
+
// state related to WireGuard timers
keepalivePersistent Timer // set for persistent keepalives
@@ -54,17 +58,20 @@ type Peer struct {
sendLastMinuteHandshake bool
needAnotherKeepalive bool
}
+
queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work
}
+
routines struct {
mutex deadlock.Mutex // held when stopping / starting routines
starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop
stop Signal // size 0, stop all goroutines in peer
}
+
mac CookieGenerator
}
@@ -74,8 +81,22 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return nil, errors.New("Device closed")
}
- device.mutex.Lock()
- defer device.mutex.Unlock()
+ // lock resources
+
+ device.state.mutex.Lock()
+ defer device.state.mutex.Unlock()
+
+ device.noise.mutex.RLock()
+ defer device.noise.mutex.RUnlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
+
+ // check if over limit
+
+ if len(device.peers.keyMap) >= MaxPeers {
+ return nil, errors.New("Too many peers")
+ }
// create peer
@@ -94,32 +115,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer()
- // assign id for debugging
-
- peer.id = device.idCounter
- device.idCounter += 1
-
- // check if over limit
-
- if len(device.peers) >= MaxPeers {
- return nil, errors.New("Too many peers")
- }
-
// map public key
- _, ok := device.peers[pk]
+ _, ok := device.peers.keyMap[pk]
if ok {
return nil, errors.New("Adding existing peer")
}
- device.peers[pk] = peer
+ device.peers.keyMap[pk] = peer
// precompute DH
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.remoteStatic = pk
- handshake.precomputedStaticStatic =
- device.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic = device.noise.privateKey.sharedSecret(pk)
handshake.mutex.Unlock()
// reset endpoint
@@ -134,11 +143,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// start peer
- peer.device.state.mutex.Lock()
if peer.device.isUp.Get() {
peer.Start()
}
- peer.device.state.mutex.Unlock()
return peer, nil
}
@@ -166,14 +173,12 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
func (peer *Peer) String() string {
if peer.endpoint == nil {
return fmt.Sprintf(
- "peer(%d unknown %s)",
- peer.id,
+ "peer(unknown %s)",
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
return fmt.Sprintf(
- "peer(%d %s %s)",
- peer.id,
+ "peer(%s %s)",
peer.endpoint.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
@@ -181,8 +186,12 @@ func (peer *Peer) String() string {
func (peer *Peer) Start() {
+ if peer.device.isClosed.Get() {
+ return
+ }
+
peer.routines.mutex.Lock()
- defer peer.routines.mutex.Lock()
+ defer peer.routines.mutex.Unlock()
peer.device.log.Debug.Println("Starting:", peer.String())
@@ -222,7 +231,7 @@ func (peer *Peer) Start() {
func (peer *Peer) Stop() {
peer.routines.mutex.Lock()
- defer peer.routines.mutex.Lock()
+ defer peer.routines.mutex.Unlock()
peer.device.log.Debug.Println("Stopping:", peer.String())
diff --git a/src/receive.go b/src/receive.go
index 5ad7c4b..1f44df2 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -372,7 +372,7 @@ func (device *Device) RoutineHandshake() {
// check ratelimiter
- if !device.ratelimiter.Allow(elem.endpoint.DstIP()) {
+ if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
continue
}
}
@@ -495,19 +495,23 @@ func (device *Device) RoutineHandshake() {
func (peer *Peer) RoutineSequentialReceiver() {
+ defer peer.routines.stopping.Done()
+
device := peer.device
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
- logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
+ logDebug.Println("Routine, sequential receiver, started for peer", peer.String())
+
+ peer.routines.starting.Done()
for {
select {
case <-peer.routines.stop.Wait():
- logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
+ logDebug.Println("Routine, sequential receiver, stopped for peer", peer.String())
return
case elem := <-peer.queue.inbound:
@@ -581,7 +585,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.routingTable.LookupIPv4(src) != peer {
+ if device.routing.table.LookupIPv4(src) != peer {
logInfo.Println(
"IPv4 packet with disallowed source address from",
peer.String(),
@@ -609,7 +613,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.routingTable.LookupIPv6(src) != peer {
+ if device.routing.table.LookupIPv6(src) != peer {
logInfo.Println(
"IPv6 packet with disallowed source address from",
peer.String(),
diff --git a/src/send.go b/src/send.go
index e0a546d..7488d3a 100644
--- a/src/send.go
+++ b/src/send.go
@@ -151,14 +151,14 @@ func (device *Device) RoutineReadFromTUN() {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
- peer = device.routingTable.LookupIPv4(dst)
+ peer = device.routing.table.LookupIPv4(dst)
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
- peer = device.routingTable.LookupIPv6(dst)
+ peer = device.routing.table.LookupIPv6(dst)
default:
logDebug.Println("Received packet with unknown IP version")
@@ -187,10 +187,14 @@ func (device *Device) RoutineReadFromTUN() {
func (peer *Peer) RoutineNonce() {
var keyPair *KeyPair
+ defer peer.routines.stopping.Done()
+
device := peer.device
logDebug := device.log.Debug
logDebug.Println("Routine, nonce worker, started for peer", peer.String())
+ peer.routines.starting.Done()
+
for {
NextPacket:
select {
diff --git a/src/timers.go b/src/timers.go
index f1ed9c5..2ef105e 100644
--- a/src/timers.go
+++ b/src/timers.go
@@ -303,7 +303,7 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake()
if err != nil {
logInfo.Println(
- "Failed to send handshake to peer:", peer.String())
+ "Failed to send handshake to peer:", peer.String(), "(", err, ")")
}
case <-peer.timer.handshakeDeadline.Wait():
@@ -326,7 +326,7 @@ func (peer *Peer) RoutineTimerHandler() {
err := peer.sendNewHandshake()
if err != nil {
logInfo.Println(
- "Failed to send handshake to peer:", peer.String())
+ "Failed to send handshake to peer:", peer.String(), "(", err, ")")
}
peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
diff --git a/src/tun_linux.go b/src/tun_linux.go
index daa2462..9756169 100644
--- a/src/tun_linux.go
+++ b/src/tun_linux.go
@@ -313,7 +313,7 @@ func CreateTUNFromFile(name string, fd *os.File) (TUNDevice, error) {
}
go device.RoutineNetlinkListener()
- go device.RoutineHackListener() // cross namespace
+ // go device.RoutineHackListener() // cross namespace
// set default MTU
@@ -369,7 +369,7 @@ func CreateTUN(name string) (TUNDevice, error) {
}
go device.RoutineNetlinkListener()
- go device.RoutineHackListener() // cross namespace
+ // go device.RoutineHackListener() // cross namespace
// set default MTU
diff --git a/src/uapi.go b/src/uapi.go
index 68ebe43..caaa498 100644
--- a/src/uapi.go
+++ b/src/uapi.go
@@ -25,32 +25,51 @@ func (s *IPCError) ErrorCode() int64 {
func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
- // create lines
+ device.log.Debug.Println("UAPI: Processing get operation")
- device.mutex.RLock()
- device.net.mutex.RLock()
+ // create lines
lines := make([]string, 0, 100)
send := func(line string) {
lines = append(lines, line)
}
- if !device.privateKey.IsZero() {
- send("private_key=" + device.privateKey.ToHex())
- }
+ func() {
- if device.net.port != 0 {
- send(fmt.Sprintf("listen_port=%d", device.net.port))
- }
+ // lock required resources
- if device.net.fwmark != 0 {
- send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
- }
+ device.net.mutex.RLock()
+ defer device.net.mutex.RUnlock()
+
+ device.noise.mutex.RLock()
+ defer device.noise.mutex.RUnlock()
+
+ device.routing.mutex.RLock()
+ defer device.routing.mutex.RUnlock()
+
+ device.peers.mutex.Lock()
+ defer device.peers.mutex.Unlock()
+
+ // serialize device related values
+
+ if !device.noise.privateKey.IsZero() {
+ send("private_key=" + device.noise.privateKey.ToHex())
+ }
+
+ if device.net.port != 0 {
+ send(fmt.Sprintf("listen_port=%d", device.net.port))
+ }
+
+ if device.net.fwmark != 0 {
+ send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
+ }
- for _, peer := range device.peers {
- func() {
+ // serialize each peer state
+
+ for _, peer := range device.peers.keyMap {
peer.mutex.RLock()
defer peer.mutex.RUnlock()
+
send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
if peer.endpoint != nil {
@@ -69,16 +88,14 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
atomic.LoadUint64(&peer.persistentKeepaliveInterval),
))
- for _, ip := range device.routingTable.AllowedIPs(peer) {
+ for _, ip := range device.routing.table.AllowedIPs(peer) {
send("allowed_ip=" + ip.String())
}
- }()
- }
- device.net.mutex.RUnlock()
- device.mutex.RUnlock()
+ }
+ }()
- // send lines
+ // send lines (does not require resource locks)
for _, line := range lines {
_, err := socket.WriteString(line + "\n")
@@ -94,7 +111,6 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
scanner := bufio.NewScanner(socket)
- logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
@@ -130,6 +146,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalid}
}
+ logDebug.Println("UAPI: Updating device private key")
device.SetPrivateKey(sk)
case "listen_port":
@@ -144,6 +161,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update port and rebind
+ logDebug.Println("UAPI: Updating listen port")
+
device.net.mutex.Lock()
device.net.port = uint16(port)
device.net.mutex.Unlock()
@@ -170,6 +189,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
+ logDebug.Println("UAPI: Updating fwmark")
+
device.net.mutex.Lock()
device.net.fwmark = uint32(fwmark)
device.net.mutex.Unlock()
@@ -181,6 +202,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "public_key":
// switch to peer configuration
+ logDebug.Println("UAPI: Transition to peer configuration")
deviceConfig = false
case "replace_peers":
@@ -188,6 +210,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
+ logDebug.Println("UAPI: Removing all peers")
device.RemoveAllPeers()
default:
@@ -203,43 +226,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
switch key {
case "public_key":
- var pubKey NoisePublicKey
- err := pubKey.FromHex(value)
+ var publicKey NoisePublicKey
+ err := publicKey.FromHex(value)
if err != nil {
logError.Println("Failed to get peer by public_key:", err)
return &IPCError{Code: ipcErrorInvalid}
}
- // check if public key of peer equal to device
+ // ignore peer with public key of device
- device.mutex.RLock()
- if device.publicKey.Equals(pubKey) {
-
- // create dummy instance (not added to device)
+ device.noise.mutex.RLock()
+ equals := device.noise.publicKey.Equals(publicKey)
+ device.noise.mutex.RUnlock()
+ if equals {
peer = &Peer{}
dummy = true
- device.mutex.RUnlock()
- logInfo.Println("Ignoring peer with public key of device")
+ }
- } else {
+ // find peer referenced
- // find peer referenced
+ peer = device.LookupPeer(publicKey)
- peer, _ = device.peers[pubKey]
- device.mutex.RUnlock()
- if peer == nil {
- peer, err = device.NewPeer(pubKey)
- if err != nil {
- logError.Println("Failed to create new peer:", err)
- return &IPCError{Code: ipcErrorInvalid}
- }
+ if peer == nil {
+ peer, err = device.NewPeer(publicKey)
+ if err != nil {
+ logError.Println("Failed to create new peer:", err)
+ return &IPCError{Code: ipcErrorInvalid}
}
- peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
- dummy = false
-
+ logDebug.Println("UAPI: Created new peer:", peer.String())
}
+ peer.mutex.Lock()
+ peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
+ peer.mutex.Unlock()
+
case "remove":
// remove currently selected peer from device
@@ -249,7 +270,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid}
}
if !dummy {
- logDebug.Println("Removing", peer.String())
+ logDebug.Println("UAPI: Removing peer:", peer.String())
device.RemovePeer(peer.handshake.remoteStatic)
}
peer = &Peer{}
@@ -259,9 +280,12 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update PSK
- peer.mutex.Lock()
+ logDebug.Println("UAPI: Updating pre-shared key for peer:", peer.String())
+
+ peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
- peer.mutex.Unlock()
+ peer.handshake.mutex.Unlock()
+
if err != nil {
logError.Println("Failed to set preshared_key:", err)
return &IPCError{Code: ipcErrorInvalid}
@@ -271,6 +295,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// set endpoint destination
+ logDebug.Println("UAPI: Updating endpoint for peer:", peer.String())
+
err := func() error {
peer.mutex.Lock()
defer peer.mutex.Unlock()
@@ -292,6 +318,8 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update keep-alive interval
+ logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer.String())
+
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
logError.Println("Failed to set persistent_keepalive_interval:", err)
@@ -316,25 +344,41 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "replace_allowed_ips":
+
+ logDebug.Println("UAPI: Removing all allowed IPs for peer:", peer.String())
+
if value != "true" {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid}
}
- if !dummy {
- device.routingTable.RemovePeer(peer)
+
+ if dummy {
+ continue
}
+ device.routing.mutex.Lock()
+ device.routing.table.RemovePeer(peer)
+ device.routing.mutex.Unlock()
+
case "allowed_ip":
+
+ logDebug.Println("UAPI: Adding allowed_ip to peer:", peer.String())
+
_, network, err := net.ParseCIDR(value)
if err != nil {
logError.Println("Failed to set allowed_ip:", err)
return &IPCError{Code: ipcErrorInvalid}
}
- ones, _ := network.Mask.Size()
- if !dummy {
- device.routingTable.Insert(network.IP, uint(ones), peer)
+
+ if dummy {
+ continue
}
+ ones, _ := network.Mask.Size()
+ device.routing.mutex.Lock()
+ device.routing.table.Insert(network.IP, uint(ones), peer)
+ device.routing.mutex.Unlock()
+
default:
logError.Println("Invalid UAPI key (peer configuration):", key)
return &IPCError{Code: ipcErrorInvalid}