diff options
Diffstat (limited to 'device')
-rw-r--r-- | device/device.go | 17 | ||||
-rw-r--r-- | device/receive.go | 15 |
2 files changed, 12 insertions, 20 deletions
diff --git a/device/device.go b/device/device.go index 1e32db6..a635e68 100644 --- a/device/device.go +++ b/device/device.go @@ -11,9 +11,6 @@ import ( "sync/atomic" "time" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" @@ -468,8 +465,9 @@ func (device *Device) BindUpdate() error { // bind to new port var err error + var recvFns []conn.ReceiveFunc netc := &device.net - netc.port, err = netc.bind.Open(netc.port) + recvFns, netc.port, err = netc.bind.Open(netc.port) if err != nil { netc.port = 0 return err @@ -501,11 +499,12 @@ func (device *Device) BindUpdate() error { device.peers.RUnlock() // start receiving routines - device.net.stopping.Add(2) - device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption - device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) + device.net.stopping.Add(len(recvFns)) + device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption + device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + for _, fn := range recvFns { + go device.RoutineReceiveIncoming(fn) + } device.log.Verbosef("UDP bind has been updated") return nil diff --git a/device/receive.go b/device/receive.go index 5ddb66c..fa5c0a6 100644 --- a/device/receive.go +++ b/device/receive.go @@ -68,15 +68,15 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { +func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) { defer func() { - device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP) + device.log.Verbosef("Routine: receive incoming %p - stopped", recv) device.queue.decryption.wg.Done() device.queue.handshake.wg.Done() device.net.stopping.Done() }() - device.log.Verbosef("Routine: receive incoming IPv%d - started", IP) + device.log.Verbosef("Routine: receive incoming %p - started", recv) // receive datagrams until conn is closed @@ -90,14 +90,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { ) for { - switch IP { - case ipv4.Version: - size, endpoint, err = bind.ReceiveIPv4(buffer[:]) - case ipv6.Version: - size, endpoint, err = bind.ReceiveIPv6(buffer[:]) - default: - panic("invalid IP version") - } + size, endpoint, err = recv(buffer[:]) if err != nil { device.PutMessageBuffer(buffer) |