summaryrefslogtreecommitdiffhomepage
path: root/device
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-02-22 02:01:50 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-02-23 20:00:57 +0100
commita4f8e83d5d9f477554971e90e9ab85922f506ea9 (patch)
tree5249ac2dbdc8cbb6a7d2d40814b07d7d1f38ad4d /device
parentc69481f1b3b4b37b9c16f997a5d8d91367d9bfee (diff)
conn: make binds replacable
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'device')
-rw-r--r--device/device.go13
-rw-r--r--device/device_test.go5
-rw-r--r--device/peer.go9
-rw-r--r--device/sticky_default.go2
-rw-r--r--device/sticky_linux.go15
-rw-r--r--device/uapi.go3
6 files changed, 19 insertions, 28 deletions
diff --git a/device/device.go b/device/device.go
index 432549d..4b131a2 100644
--- a/device/device.go
+++ b/device/device.go
@@ -279,11 +279,12 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
return nil
}
-func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
+func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
device := new(Device)
device.state.state = uint32(deviceStateDown)
device.closed = make(chan struct{})
device.log = logger
+ device.net.bind = bind
device.tun.device = tunDevice
mtu, err := device.tun.device.MTU()
if err != nil {
@@ -302,11 +303,6 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
device.queue.encryption = newOutboundQueue()
device.queue.decryption = newInboundQueue()
- // prepare net
-
- device.net.port = 0
- device.net.bind = nil
-
// start workers
cpus := runtime.NumCPU()
@@ -414,7 +410,6 @@ func unsafeCloseBind(device *Device) error {
}
if netc.bind != nil {
err = netc.bind.Close()
- netc.bind = nil
}
netc.stopping.Wait()
return err
@@ -474,16 +469,14 @@ func (device *Device) BindUpdate() error {
// bind to new port
var err error
netc := &device.net
- netc.bind, netc.port, err = conn.CreateBind(netc.port)
+ netc.port, err = netc.bind.Open(netc.port)
if err != nil {
- netc.bind = nil
netc.port = 0
return err
}
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
- netc.bind = nil
netc.port = 0
return err
}
diff --git a/device/device_test.go b/device/device_test.go
index 7958fb9..01ee2ac 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"time"
+ "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/tun/tuntest"
)
@@ -158,7 +159,7 @@ func genTestPair(tb testing.TB) (pair testPair) {
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError
}
- p.dev = NewDevice(p.tun.TUN(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
+ p.dev = NewDevice(p.tun.TUN(), conn.NewDefaultBind(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
if err := p.dev.IpcSet(cfg[i]); err != nil {
tb.Errorf("failed to configure device %d: %v", i, err)
p.dev.Close()
@@ -332,7 +333,7 @@ func randDevice(t *testing.T) *Device {
}
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
- device := NewDevice(tun, logger)
+ device := NewDevice(tun, conn.NewDefaultBind(), logger)
device.SetPrivateKey(sk)
return device
}
diff --git a/device/peer.go b/device/peer.go
index 499888d..332f7bd 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -126,13 +126,8 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
- if peer.device.net.bind == nil {
- // Packets can leak through to SendBuffer while the device is closing.
- // When that happens, drop them silently to avoid spurious errors.
- if peer.device.isClosed() {
- return nil
- }
- return errors.New("no bind")
+ if peer.device.isClosed() {
+ return nil
}
peer.RLock()
diff --git a/device/sticky_default.go b/device/sticky_default.go
index 56da4eb..1cc52f6 100644
--- a/device/sticky_default.go
+++ b/device/sticky_default.go
@@ -1,4 +1,4 @@
-// +build !linux android
+// +build !linux
package device
diff --git a/device/sticky_linux.go b/device/sticky_linux.go
index a984f24..6193ea3 100644
--- a/device/sticky_linux.go
+++ b/device/sticky_linux.go
@@ -1,5 +1,3 @@
-// +build !android
-
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
@@ -21,11 +19,16 @@ import (
"unsafe"
"golang.org/x/sys/unix"
+
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
+ if _, ok := bind.(*conn.LinuxSocketBind); !ok {
+ return nil, nil
+ }
+
netlinkSock, err := createNetlinkRouteSocket()
if err != nil {
return nil, err
@@ -109,11 +112,11 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
pePtr.peer.Unlock()
break
}
- if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
+ if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
pePtr.peer.Unlock()
break
}
- pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
+ pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
pePtr.peer.Unlock()
}
attr = attr[attrhdr.Len:]
@@ -133,7 +136,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
peer.RUnlock()
continue
}
- nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
+ nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
if nativeEP == nil {
peer.RUnlock()
continue
@@ -176,7 +179,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
Len: 8,
Type: unix.RTA_MARK,
},
- uint32(bind.LastMark()),
+ device.net.fwmark,
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
diff --git a/device/uapi.go b/device/uapi.go
index 406880f..659af0a 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -18,7 +18,6 @@ import (
"sync/atomic"
"time"
- "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ipc"
)
@@ -331,7 +330,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "endpoint":
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
- endpoint, err := conn.CreateEndpoint(value)
+ endpoint, err := device.net.bind.ParseEndpoint(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
}