summaryrefslogtreecommitdiffhomepage
path: root/device
diff options
context:
space:
mode:
Diffstat (limited to 'device')
-rw-r--r--device/device.go279
-rw-r--r--device/device_test.go2
-rw-r--r--device/devicestate_string.go26
-rw-r--r--device/peer.go8
-rw-r--r--device/receive.go2
-rw-r--r--device/send.go4
-rw-r--r--device/timers.go2
-rw-r--r--device/uapi.go4
8 files changed, 188 insertions, 139 deletions
diff --git a/device/device.go b/device/device.go
index 9ea7c24..c637e38 100644
--- a/device/device.go
+++ b/device/device.go
@@ -21,17 +21,26 @@ import (
)
type Device struct {
- isUp AtomicBool // device is (going) up
- isClosed AtomicBool // device is closed? (acting as guard)
- log *Logger
+ log *Logger
// synchronized resources (locks acquired in order)
state struct {
+ // state holds the device's state. It is accessed atomically.
+ // Use the device.deviceState method to read it.
+ // If state.mu is (r)locked, state is the current state of the device.
+ // Without state.mu (r)locked, state is either the current state
+ // of the device or the intended future state of the device.
+ // For example, while executing a call to Up, state will be deviceStateUp.
+ // There is no guarantee that that intended future state of the device
+ // will become the actual state; Up can fail.
+ // The device can also change state multiple times between time of check and time of use.
+ // Unsynchronized uses of state must therefore be advisory/best-effort only.
+ state uint32 // actually a deviceState, but typed uint32 for conveniene
+ // stopping blocks until all inputs to Device have been closed.
stopping sync.WaitGroup
- sync.Mutex
- changing AtomicBool
- current bool
+ // mu protects state changes.
+ mu sync.Mutex
}
net struct {
@@ -87,6 +96,43 @@ type Device struct {
closed chan struct{}
}
+// deviceState represents the state of a Device.
+// There are four states: new, down, up, closed.
+// However, state new should never be observable.
+// Transitions:
+//
+// new -> down -----+
+// ↑↓ ↓
+// up -> closed
+//
+type deviceState uint32
+
+//go:generate stringer -type deviceState -trimprefix=deviceState
+const (
+ deviceStateNew deviceState = iota
+ deviceStateDown
+ deviceStateUp
+ deviceStateClosed
+)
+
+// deviceState returns device.state.state as a deviceState
+// See those docs for how to interpret this value.
+func (device *Device) deviceState() deviceState {
+ return deviceState(atomic.LoadUint32(&device.state.state))
+}
+
+// isClosed reports whether the device is closed (or is closing).
+// See device.state.state comments for how to interpret this value.
+func (device *Device) isClosed() bool {
+ return device.deviceState() == deviceStateClosed
+}
+
+// isUp reports whether the device is up (or is attempting to come up).
+// See device.state.state comments for how to interpret this value.
+func (device *Device) isUp() bool {
+ return device.deviceState() == deviceStateUp
+}
+
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
// An outboundQueue is ref-counted using its wg field.
// An outboundQueue created with newOutboundQueue has one reference.
@@ -154,91 +200,82 @@ func newHandshakeQueue() *handshakeQueue {
* Must hold device.peers.Mutex
*/
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
-
// stop routing and processing of packets
-
device.allowedips.RemoveByPeer(peer)
peer.Stop()
// remove from peer map
-
delete(device.peers.keyMap, key)
device.peers.empty.Set(len(device.peers.keyMap) == 0)
}
-func deviceUpdateState(device *Device) {
-
- // check if state already being updated (guard)
-
- if device.state.changing.Swap(true) {
+// changeState attempts to change the device state to match want.
+func (device *Device) changeState(want deviceState) {
+ device.state.mu.Lock()
+ defer device.state.mu.Unlock()
+ old := device.deviceState()
+ if old == deviceStateClosed {
+ // once closed, always closed
+ device.log.Verbosef("Interface closed, ignored requested state %s", want)
return
}
-
- // compare to current state of device
-
- device.state.Lock()
-
- newIsUp := device.isUp.Get()
-
- if newIsUp == device.state.current {
- device.state.changing.Set(false)
- device.state.Unlock()
+ switch want {
+ case old:
+ device.log.Verbosef("Interface already in state %s", want)
return
- }
-
- // change state of device
-
- switch newIsUp {
- case true:
- if err := device.BindUpdate(); err != nil {
- device.log.Errorf("Unable to update bind: %v", err)
- device.isUp.Set(false)
+ case deviceStateUp:
+ atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
+ if ok := device.upLocked(); ok {
break
}
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Start()
- if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
- peer.SendKeepalive()
- }
- }
- device.peers.RUnlock()
-
- case false:
- device.BindClose()
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Stop()
- }
- device.peers.RUnlock()
+ fallthrough // up failed; bring the device all the way back down
+ case deviceStateDown:
+ atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
+ device.downLocked()
}
+ device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
+}
- // update state variables
-
- device.state.current = newIsUp
- device.state.changing.Set(false)
- device.state.Unlock()
-
- // check for state change in the mean time
+// upLocked attempts to bring the device up and reports whether it succeeded.
+// The caller must hold device.state.mu and is responsible for updating device.state.state.
+func (device *Device) upLocked() bool {
+ if err := device.BindUpdate(); err != nil {
+ device.log.Errorf("Unable to update bind: %v", err)
+ return false
+ }
- deviceUpdateState(device)
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Start()
+ if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
+ peer.SendKeepalive()
+ }
+ }
+ device.peers.RUnlock()
+ return true
}
-func (device *Device) Up() {
-
- // closed device cannot be brought up
+// downLocked attempts to bring the device down.
+// The caller must hold device.state.mu and is responsible for updating device.state.state.
+func (device *Device) downLocked() {
+ err := device.BindClose()
+ if err != nil {
+ device.log.Errorf("Bind close failed: %v", err)
+ }
- if device.isClosed.Get() {
- return
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Stop()
}
+ device.peers.RUnlock()
+}
- device.isUp.Set(true)
- deviceUpdateState(device)
+func (device *Device) Up() {
+ device.changeState(deviceStateUp)
}
func (device *Device) Down() {
- device.isUp.Set(false)
- deviceUpdateState(device)
+ device.changeState(deviceStateDown)
}
func (device *Device) IsUnderLoad() bool {
@@ -310,6 +347,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
device := new(Device)
+ device.state.state = uint32(deviceStateDown)
device.closed = make(chan struct{})
device.log = logger
device.tun.device = tunDevice
@@ -382,19 +420,16 @@ func (device *Device) RemoveAllPeers() {
}
func (device *Device) Close() {
- if device.isClosed.Swap(true) {
+ device.state.mu.Lock()
+ defer device.state.mu.Unlock()
+ if device.isClosed() {
return
}
-
+ atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
device.log.Verbosef("Device closing")
- device.state.changing.Set(true)
- device.state.Lock()
- defer device.state.Unlock()
device.tun.device.Close()
- device.BindClose()
-
- device.isUp.Set(false)
+ device.downLocked()
// Remove peers before closing queues,
// because peers assume that queues are active.
@@ -410,8 +445,7 @@ func (device *Device) Close() {
device.rate.limiter.Close()
- device.state.changing.Set(false)
- device.log.Verbosef("Interface closed")
+ device.log.Verbosef("Device closed")
close(device.closed)
}
@@ -420,7 +454,7 @@ func (device *Device) Wait() chan struct{} {
}
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
- if device.isClosed.Get() {
+ if !device.isUp() {
return
}
@@ -457,27 +491,23 @@ func (device *Device) Bind() conn.Bind {
}
func (device *Device) BindSetMark(mark uint32) error {
-
device.net.Lock()
defer device.net.Unlock()
// check if modified
-
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
-
device.net.fwmark = mark
- if device.isUp.Get() && device.net.bind != nil {
+ if device.isUp() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
-
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
@@ -492,70 +522,63 @@ func (device *Device) BindSetMark(mark uint32) error {
}
func (device *Device) BindUpdate() error {
-
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
-
if err := unsafeCloseBind(device); err != nil {
return err
}
// open new sockets
+ if !device.isUp() {
+ return nil
+ }
- if device.isUp.Get() {
-
- // bind to new port
+ // bind to new port
+ var err error
+ netc := &device.net
+ netc.bind, netc.port, err = conn.CreateBind(netc.port)
+ if err != nil {
+ netc.bind = nil
+ netc.port = 0
+ return err
+ }
+ netc.netlinkCancel, err = device.startRouteListener(netc.bind)
+ if err != nil {
+ netc.bind.Close()
+ netc.bind = nil
+ netc.port = 0
+ return err
+ }
- var err error
- netc := &device.net
- netc.bind, netc.port, err = conn.CreateBind(netc.port)
- if err != nil {
- netc.bind = nil
- netc.port = 0
- return err
- }
- netc.netlinkCancel, err = device.startRouteListener(netc.bind)
+ // set fwmark
+ if netc.fwmark != 0 {
+ err = netc.bind.SetMark(netc.fwmark)
if err != nil {
- netc.bind.Close()
- netc.bind = nil
- netc.port = 0
return err
}
+ }
- // set fwmark
-
- if netc.fwmark != 0 {
- err = netc.bind.SetMark(netc.fwmark)
- if err != nil {
- return err
- }
- }
-
- // clear cached source addresses
-
- device.peers.RLock()
- for _, peer := range device.peers.keyMap {
- peer.Lock()
- defer peer.Unlock()
- if peer.endpoint != nil {
- peer.endpoint.ClearSrc()
- }
+ // clear cached source addresses
+ device.peers.RLock()
+ for _, peer := range device.peers.keyMap {
+ peer.Lock()
+ defer peer.Unlock()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
}
- device.peers.RUnlock()
-
- // start receiving routines
-
- device.net.stopping.Add(2)
- device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
- device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
- go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
- go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
-
- device.log.Verbosef("UDP bind has been updated")
}
+ device.peers.RUnlock()
+
+ // start receiving routines
+ device.net.stopping.Add(2)
+ device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
+ device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
+ go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
+ device.log.Verbosef("UDP bind has been updated")
return nil
}
diff --git a/device/device_test.go b/device/device_test.go
index 50e3dbc..56ecd17 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -172,7 +172,7 @@ NextAttempt:
// The device might still not be up, e.g. due to an error
// in RoutineTUNEventReader's call to dev.Up that got swallowed.
// Assume it's due to a transient error (port in use), and retry.
- if !p.dev.isUp.Get() {
+ if !p.dev.isUp() {
tb.Logf("device %d did not come up, trying again", i)
p.dev.Close()
continue NextAttempt
diff --git a/device/devicestate_string.go b/device/devicestate_string.go
new file mode 100644
index 0000000..e8f16b0
--- /dev/null
+++ b/device/devicestate_string.go
@@ -0,0 +1,26 @@
+// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.
+
+package device
+
+import "strconv"
+
+func _() {
+ // An "invalid array index" compiler error signifies that the constant values have changed.
+ // Re-run the stringer command to generate them again.
+ var x [1]struct{}
+ _ = x[deviceStateNew-0]
+ _ = x[deviceStateDown-1]
+ _ = x[deviceStateUp-2]
+ _ = x[deviceStateClosed-3]
+}
+
+const _deviceState_name = "NewDownUpClosed"
+
+var _deviceState_index = [...]uint8{0, 3, 7, 9, 15}
+
+func (i deviceState) String() string {
+ if i >= deviceState(len(_deviceState_index)-1) {
+ return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
+ }
+ return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
+}
diff --git a/device/peer.go b/device/peer.go
index 0bf19fd..abe8a08 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -62,7 +62,7 @@ type Peer struct {
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
- if device.isClosed.Get() {
+ if device.isClosed() {
return nil, errors.New("device closed")
}
@@ -107,7 +107,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
device.peers.empty.Set(false)
// start peer
- if peer.device.isUp.Get() {
+ if peer.device.isUp() {
peer.Start()
}
@@ -121,7 +121,7 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
if peer.device.net.bind == nil {
// Packets can leak through to SendBuffer while the device is closing.
// When that happens, drop them silently to avoid spurious errors.
- if peer.device.isClosed.Get() {
+ if peer.device.isClosed() {
return nil
}
return errors.New("no bind")
@@ -152,7 +152,7 @@ func (peer *Peer) String() string {
func (peer *Peer) Start() {
// should never start a peer on a closed device
- if peer.device.isClosed.Get() {
+ if peer.device.isClosed() {
return
}
diff --git a/device/receive.go b/device/receive.go
index 21d9dbc..c6a28f7 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -474,7 +474,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
}
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
- if err != nil && !device.isClosed.Get() {
+ if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packet to TUN device: %v", err)
}
if len(peer.queue.inbound) == 0 {
diff --git a/device/send.go b/device/send.go
index b9bcb33..982fec0 100644
--- a/device/send.go
+++ b/device/send.go
@@ -225,7 +225,7 @@ func (device *Device) RoutineReadFromTUN() {
size, err := device.tun.device.Read(elem.buffer[:], offset)
if err != nil {
- if !device.isClosed.Get() {
+ if !device.isClosed() {
device.log.Errorf("Failed to read packet from TUN device: %v", err)
device.Close()
}
@@ -291,7 +291,7 @@ func (peer *Peer) StagePacket(elem *QueueOutboundElement) {
func (peer *Peer) SendStagedPackets() {
top:
- if len(peer.queue.staged) == 0 || !peer.device.isUp.Get() {
+ if len(peer.queue.staged) == 0 || !peer.device.isUp() {
return
}
diff --git a/device/timers.go b/device/timers.go
index 1ea91c7..f740cf0 100644
--- a/device/timers.go
+++ b/device/timers.go
@@ -73,7 +73,7 @@ func (timer *Timer) IsPending() bool {
}
func (peer *Peer) timersActive() bool {
- return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && !peer.device.peers.empty.Get()
+ return peer.isRunning.Get() && peer.device != nil && peer.device.isUp() && !peer.device.peers.empty.Get()
}
func expiredRetransmitHandshake(peer *Peer) {
diff --git a/device/uapi.go b/device/uapi.go
index 3af37e7..406880f 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -258,7 +258,7 @@ type ipcSetPeer struct {
}
func (peer *ipcSetPeer) handlePostConfig() {
- if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp.Get() {
+ if peer.Peer != nil && !peer.dummy && peer.Peer.device.isUp() {
peer.SendStagedPackets()
}
}
@@ -354,7 +354,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
if err != nil {
return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err)
}
- if device.isUp.Get() && !peer.dummy {
+ if device.isUp() && !peer.dummy {
peer.SendKeepalive()
}
}