From 6252de0db9331cbc20074e9d40165266b5816148 Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Fri, 15 Jan 2021 14:32:34 -0800 Subject: device: split IpcSetOperation into parts The goal of this change is to make the structure of IpcSetOperation easier to follow. IpcSetOperation contains a small state machine: It starts by configuring the device, then shifts to configuring one peer at a time. Having the code all in one giant method obscured that structure. Split out the parts into helper functions and encapsulate the peer state. This makes the overall structure more apparent. Signed-off-by: Josh Bleecher Snyder --- device/uapi.go | 402 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 198 insertions(+), 204 deletions(-) diff --git a/device/uapi.go b/device/uapi.go index 7f50869..7d180bb 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -41,6 +41,8 @@ func ipcErrorf(code int64, msg string, args ...interface{}) *IPCError { return &IPCError{code: code, err: fmt.Errorf(msg, args...)} } +// IpcGetOperation implements the WireGuard configuration protocol "get" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. func (device *Device) IpcGetOperation(w io.Writer) error { lines := make([]string, 0, 100) send := func(line string) { @@ -116,6 +118,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error { return nil } +// IpcSetOperation implements the WireGuard configuration protocol "set" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. func (device *Device) IpcSetOperation(r io.Reader) (err error) { defer func() { if err != nil { @@ -123,20 +127,14 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { } }() - logDebug := device.log.Debug - - var peer *Peer - dummy := false - createdNewPeer := false + peer := new(ipcSetPeer) deviceConfig := true scanner := bufio.NewScanner(r) for scanner.Scan() { - - // parse line - line := scanner.Text() if line == "" { + // Blank line means terminate operation. return nil } parts := strings.Split(line, "=") @@ -146,243 +144,239 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { key := parts[0] value := parts[1] - /* device configuration */ - - if deviceConfig { - - switch key { - case "private_key": - var sk NoisePrivateKey - err := sk.FromMaybeZeroHex(value) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) - } - logDebug.Println("UAPI: Updating private key") - device.SetPrivateKey(sk) - - case "listen_port": - - // parse port number - - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) - } - - // update port and rebind - - logDebug.Println("UAPI: Updating listen port") - - device.net.Lock() - device.net.port = uint16(port) - device.net.Unlock() - - if err := device.BindUpdate(); err != nil { - return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) - } - - case "fwmark": - - // parse fwmark field - - fwmark, err := func() (uint32, error) { - if value == "" { - return 0, nil - } - mark, err := strconv.ParseUint(value, 10, 32) - return uint32(mark), err - }() - - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) - } - - logDebug.Println("UAPI: Updating fwmark") - - if err := device.BindSetMark(uint32(fwmark)); err != nil { - return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) - } - - case "public_key": - // switch to peer configuration - logDebug.Println("UAPI: Transition to peer configuration") + if key == "public_key" { + if deviceConfig { + device.log.Debug.Println("UAPI: Transition to peer configuration") deviceConfig = false - - case "replace_peers": - if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) - } - logDebug.Println("UAPI: Removing all peers") - device.RemoveAllPeers() - - default: - return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) } + // Load/create the peer we are now configuring. + err := device.handlePublicKeyLine(peer, value) + if err != nil { + return err + } + continue } - /* peer configuration */ - - if !deviceConfig { - - switch key { - - case "public_key": - var publicKey NoisePublicKey - err := publicKey.FromHex(value) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) - } - - // ignore peer with public key of device - - device.staticIdentity.RLock() - dummy = device.staticIdentity.publicKey.Equals(publicKey) - device.staticIdentity.RUnlock() - - if dummy { - peer = &Peer{} - } else { - peer = device.LookupPeer(publicKey) - } - - createdNewPeer = peer == nil - if createdNewPeer { - peer, err = device.NewPeer(publicKey) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) - } - logDebug.Println(peer, "- UAPI: Created") - } - - case "update_only": - - // allow disabling of creation + var err error + if deviceConfig { + err = device.handleDeviceLine(key, value) + } else { + err = device.handlePeerLine(peer, key, value) + } + if err != nil { + return err + } + } - if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) - } - if createdNewPeer && !dummy { - device.RemovePeer(peer.handshake.remoteStatic) - peer = &Peer{} - dummy = true - } + return scanner.Err() +} - case "remove": +func (device *Device) handleDeviceLine(key, value string) error { + switch key { + case "private_key": + var sk NoisePrivateKey + err := sk.FromMaybeZeroHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) + } + device.log.Debug.Println("UAPI: Updating private key") + device.SetPrivateKey(sk) - // remove currently selected peer from device + case "listen_port": + port, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) + } - if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) - } - if !dummy { - logDebug.Println(peer, "- UAPI: Removing") - device.RemovePeer(peer.handshake.remoteStatic) - } - peer = &Peer{} - dummy = true + // update port and rebind + device.log.Debug.Println("UAPI: Updating listen port") - case "preshared_key": + device.net.Lock() + device.net.port = uint16(port) + device.net.Unlock() - // update PSK + if err := device.BindUpdate(); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) + } - logDebug.Println(peer, "- UAPI: Updating preshared key") + case "fwmark": + // parse fwmark field + fwmark, err := func() (uint32, error) { + if value == "" { + return 0, nil + } + mark, err := strconv.ParseUint(value, 10, 32) + return uint32(mark), err + }() - peer.handshake.mutex.Lock() - err := peer.handshake.presharedKey.FromHex(value) - peer.handshake.mutex.Unlock() + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) + } - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) - } + device.log.Debug.Println("UAPI: Updating fwmark") - case "endpoint": + if err := device.BindSetMark(uint32(fwmark)); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) + } - // set endpoint destination + case "replace_peers": + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) + } + device.log.Debug.Println("UAPI: Removing all peers") + device.RemoveAllPeers() - logDebug.Println(peer, "- UAPI: Updating endpoint") + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) + } - err := func() error { - peer.Lock() - defer peer.Unlock() - endpoint, err := conn.CreateEndpoint(value) - if err != nil { - return err - } - peer.endpoint = endpoint - return nil - }() + return nil +} - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) - } +// An ipcSetPeer is the current state of an IPC set operation on a peer. +type ipcSetPeer struct { + *Peer // Peer is the current peer being operated on + dummy bool // dummy reports whether this peer is a temporary, placeholder peer + created bool // new reports whether this is a newly created peer +} - case "persistent_keepalive_interval": +func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error { + // Load/create the peer we are configuring. + var publicKey NoisePublicKey + err := publicKey.FromHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) + } - // update persistent keepalive interval + // Ignore peer with the same public key as this device. + device.staticIdentity.RLock() + peer.dummy = device.staticIdentity.publicKey.Equals(publicKey) + device.staticIdentity.RUnlock() - logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval") + if peer.dummy { + peer.Peer = &Peer{} + } else { + peer.Peer = device.LookupPeer(publicKey) + } - secs, err := strconv.ParseUint(value, 10, 16) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) - } + peer.created = peer.Peer == nil + if peer.created { + peer.Peer, err = device.NewPeer(publicKey) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) + } + device.log.Debug.Println(peer, "- UAPI: Created") + } + return nil +} - old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) +func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error { + switch key { + case "update_only": + // allow disabling of creation + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) + } + if peer.created && !peer.dummy { + device.RemovePeer(peer.handshake.remoteStatic) + peer.Peer = &Peer{} + peer.dummy = true + } - // send immediate keepalive if we're turning it on and before it wasn't on + case "remove": + // remove currently selected peer from device + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) + } + if !peer.dummy { + device.log.Debug.Println(peer, "- UAPI: Removing") + device.RemovePeer(peer.handshake.remoteStatic) + } + peer.Peer = &Peer{} + peer.dummy = true - if old == 0 && secs != 0 { - if err != nil { - return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err) - } - if device.isUp.Get() && !dummy { - peer.SendKeepalive() - } - } + case "preshared_key": + device.log.Debug.Println(peer, "- UAPI: Updating preshared key") - case "replace_allowed_ips": + peer.handshake.mutex.Lock() + err := peer.handshake.presharedKey.FromHex(value) + peer.handshake.mutex.Unlock() - logDebug.Println(peer, "- UAPI: Removing all allowedips") + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) + } - if value != "true" { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) - } + case "endpoint": + device.log.Debug.Println(peer, "- UAPI: Updating endpoint") - if dummy { - continue - } + err := func() error { + peer.Lock() + defer peer.Unlock() + endpoint, err := conn.CreateEndpoint(value) + if err != nil { + return err + } + peer.endpoint = endpoint + return nil + }() - device.allowedips.RemoveByPeer(peer) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) + } - case "allowed_ip": + case "persistent_keepalive_interval": + device.log.Debug.Println(peer, "- UAPI: Updating persistent keepalive interval") - logDebug.Println(peer, "- UAPI: Adding allowedip") + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) + } - _, network, err := net.ParseCIDR(value) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) - } + old := atomic.SwapUint32(&peer.persistentKeepaliveInterval, uint32(secs)) - if dummy { - continue - } + // Send immediate keepalive if we're turning it on and before it wasn't on. + if old == 0 && secs != 0 { + if err != nil { + return ipcErrorf(ipc.IpcErrorIO, "failed to get tun device status: %w", err) + } + if device.isUp.Get() && !peer.dummy { + peer.SendKeepalive() + } + } - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint(ones), peer) + case "replace_allowed_ips": + device.log.Debug.Println(peer, "- UAPI: Removing all allowedips") + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) + } + if peer.dummy { + return nil + } + device.allowedips.RemoveByPeer(peer.Peer) - case "protocol_version": + case "allowed_ip": + device.log.Debug.Println(peer, "- UAPI: Adding allowedip") - if value != "1" { - return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) - } + _, network, err := net.ParseCIDR(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) + } + if peer.dummy { + return nil + } + ones, _ := network.Mask.Size() + device.allowedips.Insert(network.IP, uint(ones), peer.Peer) - default: - return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) - } + case "protocol_version": + if value != "1" { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) } + + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) } - return scanner.Err() + return nil } func (device *Device) IpcGet() (string, error) { -- cgit v1.2.3