summaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2018-02-01 11:20:36 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2018-02-01 11:20:36 +0100
commit1e42b1402261d15b87b1b5871f7bc51342b46e34 (patch)
treeaef0aef2eadf6ae7142486069405f45b84d59b58 /src
parenta57c790e36439729a6af7e53ee9068898f3ac992 (diff)
parentf73d2fb2d96bc3fbc8bc4cce452e3c19689de01e (diff)
Merge branch 'timer-teardown' of git.zx2c4.com:wireguard-go into timer-teardown
Diffstat (limited to 'src')
-rw-r--r--src/conn.go19
-rw-r--r--src/device.go111
-rw-r--r--src/peer.go68
-rw-r--r--src/receive.go21
-rw-r--r--src/send.go8
-rw-r--r--src/uapi.go12
6 files changed, 173 insertions, 66 deletions
diff --git a/src/conn.go b/src/conn.go
index 1d033ff..c2f5dee 100644
--- a/src/conn.go
+++ b/src/conn.go
@@ -64,9 +64,13 @@ func unsafeCloseBind(device *Device) error {
return err
}
-/* Must hold device and net lock
- */
-func unsafeUpdateBind(device *Device) error {
+func (device *Device) BindUpdate() error {
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
+
+ netc := &device.net
+ netc.mutex.Lock()
+ defer netc.mutex.Unlock()
// close existing sockets
@@ -74,18 +78,13 @@ func unsafeUpdateBind(device *Device) error {
return err
}
- // assumption: netc.update WaitGroup should be exactly 1
-
// open new sockets
if device.isUp.Get() {
- device.log.Debug.Println("UDP bind updating")
-
// bind to new port
var err error
- netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port)
if err != nil {
netc.bind = nil
@@ -109,7 +108,7 @@ func unsafeUpdateBind(device *Device) error {
peer.mutex.Unlock()
}
- // decrease waitgroup to 0
+ // start receiving routines
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
@@ -120,7 +119,7 @@ func unsafeUpdateBind(device *Device) error {
return nil
}
-func closeBind(device *Device) error {
+func (device *Device) BindClose() error {
device.mutex.Lock()
device.net.mutex.Lock()
err := unsafeCloseBind(device)
diff --git a/src/device.go b/src/device.go
index 5f8e91b..f1c09c6 100644
--- a/src/device.go
+++ b/src/device.go
@@ -9,7 +9,7 @@ import (
)
type Device struct {
- isUp AtomicBool // device is up (TUN interface up)?
+ isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard)
log *Logger // collection of loggers for levels
idCounter uint // for assigning debug ids to peers
@@ -18,6 +18,11 @@ type Device struct {
device TUNDevice
mtu int32
}
+ state struct {
+ mutex deadlock.Mutex
+ changing AtomicBool
+ current bool
+ }
pool struct {
messageBuffers sync.Pool
}
@@ -46,37 +51,86 @@ type Device struct {
mac CookieChecker
}
-func (device *Device) Up() {
- device.mutex.Lock()
- defer device.mutex.Unlock()
+func deviceUpdateState(device *Device) {
- device.net.mutex.Lock()
- defer device.net.mutex.Unlock()
+ // check if state already being updated (guard)
- if device.isUp.Swap(true) {
+ if device.state.changing.Swap(true) {
return
}
- unsafeUpdateBind(device)
+ // compare to current state of device
+
+ device.state.mutex.Lock()
+
+ newIsUp := device.isUp.Get()
+
+ if newIsUp == device.state.current {
+ device.state.mutex.Unlock()
+ device.state.changing.Set(false)
+ return
+ }
+
+ device.state.mutex.Unlock()
+
+ // change state of device
+
+ switch newIsUp {
+ case true:
+
+ // start listener
+
+ if err := device.BindUpdate(); err != nil {
+ device.isUp.Set(false)
+ break
+ }
+
+ // start every peer
+
+ for _, peer := range device.peers {
+ peer.Start()
+ }
+
+ case false:
+
+ // stop listening
+
+ device.BindClose()
- for _, peer := range device.peers {
- peer.Start()
+ // stop every peer
+
+ for _, peer := range device.peers {
+ peer.Stop()
+ }
}
+
+ // update state variables
+ // and check for state change in the mean time
+
+ device.state.current = newIsUp
+ device.state.changing.Set(false)
+ deviceUpdateState(device)
}
-func (device *Device) Down() {
- device.mutex.Lock()
- defer device.mutex.Unlock()
+func (device *Device) Up() {
+
+ // closed device cannot be brought up
- if !device.isUp.Swap(false) {
+ if device.isClosed.Get() {
return
}
- closeBind(device)
+ device.state.mutex.Lock()
+ device.isUp.Set(true)
+ device.state.mutex.Unlock()
+ deviceUpdateState(device)
+}
- for _, peer := range device.peers {
- peer.Stop()
- }
+func (device *Device) Down() {
+ device.state.mutex.Lock()
+ device.isUp.Set(false)
+ device.state.mutex.Unlock()
+ deviceUpdateState(device)
}
/* Warning:
@@ -87,7 +141,6 @@ func removePeerUnsafe(device *Device, key NoisePublicKey) {
if !ok {
return
}
- peer.Stop()
device.routingTable.RemovePeer(peer)
delete(device.peers, key)
}
@@ -231,20 +284,30 @@ func (device *Device) RemovePeer(key NoisePublicKey) {
func (device *Device) RemoveAllPeers() {
device.mutex.Lock()
defer device.mutex.Unlock()
- for key := range device.peers {
- removePeerUnsafe(device, key)
+
+ for key, peer := range device.peers {
+ peer.Stop()
+ peer, ok := device.peers[key]
+ if !ok {
+ return
+ }
+ device.routingTable.RemovePeer(peer)
+ delete(device.peers, key)
}
}
func (device *Device) Close() {
+ device.log.Info.Println("Device closing")
if device.isClosed.Swap(true) {
return
}
- device.log.Info.Println("Closing device")
- device.RemoveAllPeers()
device.signal.stop.Broadcast()
device.tun.device.Close()
- closeBind(device)
+ device.BindClose()
+ device.isUp.Set(false)
+ println("remove")
+ device.RemoveAllPeers()
+ device.log.Info.Println("Interface closed")
}
func (device *Device) Wait() chan struct{} {
diff --git a/src/peer.go b/src/peer.go
index 3d82989..5ad4511 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -4,6 +4,7 @@ import (
"encoding/base64"
"errors"
"fmt"
+ "github.com/sasha-s/go-deadlock"
"sync"
"time"
)
@@ -14,7 +15,8 @@ const (
type Peer struct {
id uint
- mutex sync.RWMutex
+ isRunning AtomicBool
+ mutex deadlock.RWMutex
persistentKeepaliveInterval uint64
keyPairs KeyPairs
handshake Handshake
@@ -26,7 +28,7 @@ type Peer struct {
lastHandshakeNano int64 // nano seconds since epoch
}
time struct {
- mutex sync.RWMutex
+ mutex deadlock.RWMutex
lastSend time.Time // last send message
lastHandshake time.Time // last completed handshake
nextKeepalive time.Time
@@ -58,7 +60,7 @@ type Peer struct {
inbound chan *QueueInboundElement // sequential ordering of work
}
routines struct {
- mutex sync.Mutex // held when stopping / starting routines
+ mutex deadlock.Mutex // held when stopping / starting routines
starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop
stop Signal // size 0, stop all goroutines in peer
@@ -67,6 +69,14 @@ type Peer struct {
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
+
+ if device.isClosed.Get() {
+ return nil, errors.New("Device closed")
+ }
+
+ device.mutex.Lock()
+ defer device.mutex.Unlock()
+
// create peer
peer := new(Peer)
@@ -75,17 +85,17 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.mac.Init(pk)
peer.device = device
+ peer.isRunning.Set(false)
+ peer.timer.zeroAllKeys = NewTimer()
peer.timer.keepalivePersistent = NewTimer()
peer.timer.keepalivePassive = NewTimer()
- peer.timer.zeroAllKeys = NewTimer()
peer.timer.handshakeNew = NewTimer()
peer.timer.handshakeDeadline = NewTimer()
peer.timer.handshakeTimeout = NewTimer()
// assign id for debugging
- device.mutex.Lock()
peer.id = device.idCounter
device.idCounter += 1
@@ -102,7 +112,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
return nil, errors.New("Adding existing peer")
}
device.peers[pk] = peer
- device.mutex.Unlock()
// precompute DH
@@ -117,23 +126,20 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.endpoint = nil
- // prepare queuing
-
- peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
- peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
- peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
-
// prepare signaling & routines
- peer.signal.newKeyPair = NewSignal()
- peer.signal.handshakeBegin = NewSignal()
- peer.signal.handshakeCompleted = NewSignal()
- peer.signal.flushNonceQueue = NewSignal()
-
peer.routines.mutex.Lock()
peer.routines.stop = NewSignal()
peer.routines.mutex.Unlock()
+ // start peer
+
+ peer.device.state.mutex.Lock()
+ if peer.device.isUp.Get() {
+ peer.Start()
+ }
+ peer.device.state.mutex.Unlock()
+
return peer, nil
}
@@ -148,6 +154,10 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return errors.New("No known endpoint for peer")
}
+ if peer.device.net.bind == nil {
+ return errors.New("No bind")
+ }
+
return peer.device.net.bind.Send(buffer, peer.endpoint)
}
@@ -174,12 +184,26 @@ func (peer *Peer) Start() {
peer.routines.mutex.Lock()
defer peer.routines.mutex.Lock()
+ peer.device.log.Debug.Println("Starting:", peer.String())
+
// stop & wait for ungoing routines (if any)
+ peer.isRunning.Set(false)
peer.routines.stop.Broadcast()
peer.routines.starting.Wait()
peer.routines.stopping.Wait()
+ // prepare queues
+
+ peer.signal.newKeyPair = NewSignal()
+ peer.signal.handshakeBegin = NewSignal()
+ peer.signal.handshakeCompleted = NewSignal()
+ peer.signal.flushNonceQueue = NewSignal()
+
+ peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
+ peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
+ peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
+
// reset signal and start (new) routines
peer.routines.stop = NewSignal()
@@ -192,6 +216,7 @@ func (peer *Peer) Start() {
go peer.RoutineSequentialReceiver()
peer.routines.starting.Wait()
+ peer.isRunning.Set(true)
}
func (peer *Peer) Stop() {
@@ -199,13 +224,22 @@ func (peer *Peer) Stop() {
peer.routines.mutex.Lock()
defer peer.routines.mutex.Lock()
+ peer.device.log.Debug.Println("Stopping:", peer.String())
+
// stop & wait for ungoing routines (if any)
peer.routines.stop.Broadcast()
peer.routines.starting.Wait()
peer.routines.stopping.Wait()
+ // close queues
+
+ close(peer.queue.nonce)
+ close(peer.queue.outbound)
+ close(peer.queue.inbound)
+
// reset signal (to handle repeated stopping)
peer.routines.stop = NewSignal()
+ peer.isRunning.Set(false)
}
diff --git a/src/receive.go b/src/receive.go
index 0b87a3c..5ad7c4b 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -123,7 +123,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
- return
+ panic("invalid IP version")
}
if err != nil {
@@ -184,9 +184,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
// add to decryption queues
- device.addToDecryptionQueue(device.queue.decryption, elem)
- device.addToInboundQueue(peer.queue.inbound, elem)
- buffer = device.GetMessageBuffer()
+ if peer.isRunning.Get() {
+ device.addToDecryptionQueue(device.queue.decryption, elem)
+ device.addToInboundQueue(peer.queue.inbound, elem)
+ buffer = device.GetMessageBuffer()
+ }
continue
@@ -308,13 +310,20 @@ func (device *Device) RoutineHandshake() {
return
}
- // lookup peer and consume response
+ // lookup peer from index
entry := device.indices.Lookup(reply.Receiver)
+
if entry.peer == nil {
continue
}
- entry.peer.mac.ConsumeReply(&reply)
+
+ // consume reply
+
+ if peer := entry.peer; peer.isRunning.Get() {
+ peer.mac.ConsumeReply(&reply)
+ }
+
continue
case MessageInitiationType, MessageResponseType:
diff --git a/src/send.go b/src/send.go
index fa13c91..e0a546d 100644
--- a/src/send.go
+++ b/src/send.go
@@ -170,9 +170,11 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue
- peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
- addToOutboundQueue(peer.queue.nonce, elem)
- elem = device.NewOutboundElement()
+ if peer.isRunning.Get() {
+ peer.timer.handshakeDeadline.Reset(RekeyAttemptTime)
+ addToOutboundQueue(peer.queue.nonce, elem)
+ elem = device.NewOutboundElement()
+ }
}
}
diff --git a/src/uapi.go b/src/uapi.go
index f66528c..68ebe43 100644
--- a/src/uapi.go
+++ b/src/uapi.go
@@ -144,16 +144,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update port and rebind
- device.mutex.Lock()
device.net.mutex.Lock()
-
device.net.port = uint16(port)
- err = unsafeUpdateBind(device)
-
device.net.mutex.Unlock()
- device.mutex.Unlock()
- if err != nil {
+ if err := device.BindUpdate(); err != nil {
logError.Println("Failed to set listen_port:", err)
return &IPCError{Code: ipcErrorPortInUse}
}
@@ -179,6 +174,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.net.fwmark = uint32(fwmark)
device.net.mutex.Unlock()
+ if err := device.BindUpdate(); err != nil {
+ logError.Println("Failed to update fwmark:", err)
+ return &IPCError{Code: ipcErrorPortInUse}
+ }
+
case "public_key":
// switch to peer configuration
deviceConfig = false