diff options
author | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-02-22 02:01:50 +0100 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2021-02-23 20:00:57 +0100 |
commit | a4f8e83d5d9f477554971e90e9ab85922f506ea9 (patch) | |
tree | 5249ac2dbdc8cbb6a7d2d40814b07d7d1f38ad4d /device | |
parent | c69481f1b3b4b37b9c16f997a5d8d91367d9bfee (diff) |
conn: make binds replacable
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'device')
-rw-r--r-- | device/device.go | 13 | ||||
-rw-r--r-- | device/device_test.go | 5 | ||||
-rw-r--r-- | device/peer.go | 9 | ||||
-rw-r--r-- | device/sticky_default.go | 2 | ||||
-rw-r--r-- | device/sticky_linux.go | 15 | ||||
-rw-r--r-- | device/uapi.go | 3 |
6 files changed, 19 insertions, 28 deletions
diff --git a/device/device.go b/device/device.go index 432549d..4b131a2 100644 --- a/device/device.go +++ b/device/device.go @@ -279,11 +279,12 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { return nil } -func NewDevice(tunDevice tun.Device, logger *Logger) *Device { +func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { device := new(Device) device.state.state = uint32(deviceStateDown) device.closed = make(chan struct{}) device.log = logger + device.net.bind = bind device.tun.device = tunDevice mtu, err := device.tun.device.MTU() if err != nil { @@ -302,11 +303,6 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device { device.queue.encryption = newOutboundQueue() device.queue.decryption = newInboundQueue() - // prepare net - - device.net.port = 0 - device.net.bind = nil - // start workers cpus := runtime.NumCPU() @@ -414,7 +410,6 @@ func unsafeCloseBind(device *Device) error { } if netc.bind != nil { err = netc.bind.Close() - netc.bind = nil } netc.stopping.Wait() return err @@ -474,16 +469,14 @@ func (device *Device) BindUpdate() error { // bind to new port var err error netc := &device.net - netc.bind, netc.port, err = conn.CreateBind(netc.port) + netc.port, err = netc.bind.Open(netc.port) if err != nil { - netc.bind = nil netc.port = 0 return err } netc.netlinkCancel, err = device.startRouteListener(netc.bind) if err != nil { netc.bind.Close() - netc.bind = nil netc.port = 0 return err } diff --git a/device/device_test.go b/device/device_test.go index 7958fb9..01ee2ac 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -21,6 +21,7 @@ import ( "testing" "time" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/tun/tuntest" ) @@ -158,7 +159,7 @@ func genTestPair(tb testing.TB) (pair testPair) { if _, ok := tb.(*testing.B); ok && !testing.Verbose() { level = LogLevelError } - p.dev = NewDevice(p.tun.TUN(), NewLogger(level, fmt.Sprintf("dev%d: ", i))) + p.dev = NewDevice(p.tun.TUN(), conn.NewDefaultBind(), NewLogger(level, fmt.Sprintf("dev%d: ", i))) if err := p.dev.IpcSet(cfg[i]); err != nil { tb.Errorf("failed to configure device %d: %v", i, err) p.dev.Close() @@ -332,7 +333,7 @@ func randDevice(t *testing.T) *Device { } tun := newDummyTUN("dummy") logger := NewLogger(LogLevelError, "") - device := NewDevice(tun, logger) + device := NewDevice(tun, conn.NewDefaultBind(), logger) device.SetPrivateKey(sk) return device } diff --git a/device/peer.go b/device/peer.go index 499888d..332f7bd 100644 --- a/device/peer.go +++ b/device/peer.go @@ -126,13 +126,8 @@ func (peer *Peer) SendBuffer(buffer []byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() - if peer.device.net.bind == nil { - // Packets can leak through to SendBuffer while the device is closing. - // When that happens, drop them silently to avoid spurious errors. - if peer.device.isClosed() { - return nil - } - return errors.New("no bind") + if peer.device.isClosed() { + return nil } peer.RLock() diff --git a/device/sticky_default.go b/device/sticky_default.go index 56da4eb..1cc52f6 100644 --- a/device/sticky_default.go +++ b/device/sticky_default.go @@ -1,4 +1,4 @@ -// +build !linux android +// +build !linux package device diff --git a/device/sticky_linux.go b/device/sticky_linux.go index a984f24..6193ea3 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -1,5 +1,3 @@ -// +build !android - /* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. @@ -21,11 +19,16 @@ import ( "unsafe" "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + if _, ok := bind.(*conn.LinuxSocketBind); !ok { + return nil, nil + } + netlinkSock, err := createNetlinkRouteSocket() if err != nil { return nil, err @@ -109,11 +112,11 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl pePtr.peer.Unlock() break } - if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx { + if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx { pePtr.peer.Unlock() break } - pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc() + pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc() pePtr.peer.Unlock() } attr = attr[attrhdr.Len:] @@ -133,7 +136,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl peer.RUnlock() continue } - nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint) + nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint) if nativeEP == nil { peer.RUnlock() continue @@ -176,7 +179,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl Len: 8, Type: unix.RTA_MARK, }, - uint32(bind.LastMark()), + device.net.fwmark, } nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) reqPeerLock.Lock() diff --git a/device/uapi.go b/device/uapi.go index 406880f..659af0a 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -18,7 +18,6 @@ import ( "sync/atomic" "time" - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ipc" ) @@ -331,7 +330,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error case "endpoint": device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer) - endpoint, err := conn.CreateEndpoint(value) + endpoint, err := device.net.bind.ParseEndpoint(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) } |