diff options
Diffstat (limited to 'device/device.go')
-rw-r--r-- | device/device.go | 146 |
1 files changed, 136 insertions, 10 deletions
diff --git a/device/device.go b/device/device.go index 8c08f1c..a9fedea 100644 --- a/device/device.go +++ b/device/device.go @@ -11,15 +11,14 @@ import ( "sync/atomic" "time" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" + "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" ) -const ( - DeviceRoutineNumberPerCPU = 3 - DeviceRoutineNumberAdditional = 2 -) - type Device struct { isUp AtomicBool // device is (going) up isClosed AtomicBool // device is closed? (acting as guard) @@ -39,9 +38,10 @@ type Device struct { starting sync.WaitGroup stopping sync.WaitGroup sync.RWMutex - bind Bind // bind interface - port uint16 // listening port - fwmark uint32 // mark value (0 = disabled) + bind conn.Bind // bind interface + netlinkCancel *rwcancel.RWCancel + port uint16 // listening port + fwmark uint32 // mark value (0 = disabled) } staticIdentity struct { @@ -299,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { cpus := runtime.NumCPU() device.state.starting.Wait() device.state.stopping.Wait() - device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) - device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional) for i := 0; i < cpus; i += 1 { + device.state.starting.Add(3) + device.state.stopping.Add(3) go device.RoutineEncryption() go device.RoutineDecryption() go device.RoutineHandshake() } + device.state.starting.Add(2) + device.state.stopping.Add(2) go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() @@ -413,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { } device.peers.RUnlock() } + +func unsafeCloseBind(device *Device) error { + var err error + netc := &device.net + if netc.netlinkCancel != nil { + netc.netlinkCancel.Cancel() + } + if netc.bind != nil { + err = netc.bind.Close() + netc.bind = nil + } + netc.stopping.Wait() + return err +} + +func (device *Device) BindSetMark(mark uint32) error { + + device.net.Lock() + defer device.net.Unlock() + + // check if modified + + if device.net.fwmark == mark { + return nil + } + + // update fwmark on existing bind + + device.net.fwmark = mark + if device.isUp.Get() && device.net.bind != nil { + if err := device.net.bind.SetMark(mark); err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + return nil +} + +func (device *Device) BindUpdate() error { + + device.net.Lock() + defer device.net.Unlock() + + // close existing sockets + + if err := unsafeCloseBind(device); err != nil { + return err + } + + // open new sockets + + if device.isUp.Get() { + + // bind to new port + + var err error + netc := &device.net + netc.bind, netc.port, err = conn.CreateBind(netc.port) + if err != nil { + netc.bind = nil + netc.port = 0 + return err + } + netc.netlinkCancel, err = device.startRouteListener(netc.bind) + if err != nil { + netc.bind.Close() + netc.bind = nil + netc.port = 0 + return err + } + + // set fwmark + + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) + if err != nil { + return err + } + } + + // clear cached source addresses + + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Lock() + defer peer.Unlock() + if peer.endpoint != nil { + peer.endpoint.ClearSrc() + } + } + device.peers.RUnlock() + + // start receiving routines + + device.net.starting.Add(2) + device.net.stopping.Add(2) + go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) + go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + device.net.starting.Wait() + + device.log.Debug.Println("UDP bind has been updated") + } + + return nil +} + +func (device *Device) BindClose() error { + device.net.Lock() + err := unsafeCloseBind(device) + device.net.Unlock() + return err +} |