summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--conn/bind_linux.go109
-rw-r--r--conn/bind_std.go69
-rw-r--r--conn/bind_windows.go22
-rw-r--r--conn/bindtest/bindtest.go31
-rw-r--r--conn/conn.go17
-rw-r--r--device/device.go17
-rw-r--r--device/receive.go15
7 files changed, 138 insertions, 142 deletions
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index 70ea609..9eec384 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
type LinuxSocketBind struct {
- sock4 int
- sock6 int
- lastMark uint32
- closing sync.RWMutex
+ // mu guards sock4 and sock6 and the associated fds.
+ // As long as someone holds mu (read or write), the associated fds are valid.
+ mu sync.RWMutex
+ sock4 int
+ sock6 int
}
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
@@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
return nil, errors.New("invalid IP address")
}
-func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) {
+func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+
var err error
var newPort uint16
var tries int
if bind.sock4 != -1 || bind.sock6 != -1 {
- return 0, ErrBindAlreadyOpen
+ return nil, 0, ErrBindAlreadyOpen
}
originalPort := port
again:
port = originalPort
+ var sock4, sock6 int
// Attempt ipv6 bind, update port if successful.
- bind.sock6, newPort, err = create6(port)
+ sock6, newPort, err = create6(port)
if err != nil {
- if err != syscall.EAFNOSUPPORT {
- return 0, err
+ if !errors.Is(err, syscall.EAFNOSUPPORT) {
+ return nil, 0, err
}
} else {
port = newPort
}
// Attempt ipv4 bind, update port if successful.
- bind.sock4, newPort, err = create4(port)
+ sock4, newPort, err = create4(port)
if err != nil {
- if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 {
- unix.Close(bind.sock6)
+ if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
+ unix.Close(sock6)
tries++
goto again
}
- if err != syscall.EAFNOSUPPORT {
- unix.Close(bind.sock6)
- return 0, err
+ if !errors.Is(err, syscall.EAFNOSUPPORT) {
+ unix.Close(sock6)
+ return nil, 0, err
}
} else {
port = newPort
}
- if bind.sock4 == -1 && bind.sock6 == -1 {
- return 0, syscall.EAFNOSUPPORT
+ var fns []ReceiveFunc
+ if sock4 != -1 {
+ fns = append(fns, makeReceiveIPv4(sock4))
+ bind.sock4 = sock4
+ }
+ if sock6 != -1 {
+ fns = append(fns, makeReceiveIPv6(sock6))
+ bind.sock6 = sock6
+ }
+ if len(fns) == 0 {
+ return nil, 0, syscall.EAFNOSUPPORT
}
- return port, nil
+ return fns, port, nil
}
func (bind *LinuxSocketBind) SetMark(value uint32) error {
- bind.closing.RLock()
- defer bind.closing.RUnlock()
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
if bind.sock6 != -1 {
err := unix.SetsockoptInt(
@@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
}
}
- bind.lastMark = value
return nil
}
func (bind *LinuxSocketBind) Close() error {
- var err1, err2 error
- bind.closing.RLock()
+ // Take a readlock to shut down the sockets...
+ bind.mu.RLock()
if bind.sock6 != -1 {
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
}
if bind.sock4 != -1 {
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
}
- bind.closing.RUnlock()
- bind.closing.Lock()
+ bind.mu.RUnlock()
+ // ...and a write lock to close the fd.
+ // This ensures that no one else is using the fd.
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+ var err1, err2 error
if bind.sock6 != -1 {
err1 = unix.Close(bind.sock6)
bind.sock6 = -1
@@ -200,7 +217,6 @@ func (bind *LinuxSocketBind) Close() error {
err2 = unix.Close(bind.sock4)
bind.sock4 = -1
}
- bind.closing.Unlock()
if err1 != nil {
return err1
@@ -208,46 +224,29 @@ func (bind *LinuxSocketBind) Close() error {
return err2
}
-func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
- bind.closing.RLock()
- defer bind.closing.RUnlock()
-
- var end LinuxSocketEndpoint
- if bind.sock6 == -1 {
- return 0, nil, net.ErrClosed
+func makeReceiveIPv6(sock int) ReceiveFunc {
+ return func(buff []byte) (int, Endpoint, error) {
+ var end LinuxSocketEndpoint
+ n, err := receive6(sock, buff, &end)
+ return n, &end, err
}
- n, err := receive6(
- bind.sock6,
- buff,
- &end,
- )
- return n, &end, err
}
-func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
- bind.closing.RLock()
- defer bind.closing.RUnlock()
-
- var end LinuxSocketEndpoint
- if bind.sock4 == -1 {
- return 0, nil, net.ErrClosed
+func makeReceiveIPv4(sock int) ReceiveFunc {
+ return func(buff []byte) (int, Endpoint, error) {
+ var end LinuxSocketEndpoint
+ n, err := receive4(sock, buff, &end)
+ return n, &end, err
}
- n, err := receive4(
- bind.sock4,
- buff,
- &end,
- )
- return n, &end, err
}
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
- bind.closing.RLock()
- defer bind.closing.RUnlock()
-
nend, ok := end.(*LinuxSocketEndpoint)
if !ok {
return ErrWrongEndpointType
}
+ bind.mu.RLock()
+ defer bind.mu.RUnlock()
if !nend.isV6 {
if bind.sock4 == -1 {
return net.ErrClosed
diff --git a/conn/bind_std.go b/conn/bind_std.go
index f8b8a1b..5261779 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -8,6 +8,7 @@ package conn
import (
"errors"
"net"
+ "sync"
"syscall"
)
@@ -16,6 +17,7 @@ import (
// It uses the Go's net package to implement networking.
// See LinuxSocketBind for a proper implementation on the Linux platform.
type StdNetBind struct {
+ mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
@@ -81,12 +83,15 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
return conn, uaddr.Port, nil
}
-func (bind *StdNetBind) Open(uport uint16) (uint16, error) {
+func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+
var err error
var tries int
if bind.ipv4 != nil || bind.ipv6 != nil {
- return 0, ErrBindAlreadyOpen
+ return nil, 0, ErrBindAlreadyOpen
}
// Attempt to open ipv4 and ipv6 listeners on the same port.
@@ -97,7 +102,7 @@ again:
ipv4, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
- return 0, err
+ return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
@@ -109,17 +114,27 @@ again:
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
ipv4.Close()
- return 0, err
+ return nil, 0, err
}
- if ipv4 == nil && ipv6 == nil {
- return 0, syscall.EAFNOSUPPORT
+ var fns []ReceiveFunc
+ if ipv4 != nil {
+ fns = append(fns, makeReceiveFunc(ipv4, true))
+ bind.ipv4 = ipv4
}
- bind.ipv4 = ipv4
- bind.ipv6 = ipv6
- return uint16(port), nil
+ if ipv6 != nil {
+ fns = append(fns, makeReceiveFunc(ipv6, false))
+ bind.ipv6 = ipv6
+ }
+ if len(fns) == 0 {
+ return nil, 0, syscall.EAFNOSUPPORT
+ }
+ return fns, uint16(port), nil
}
func (bind *StdNetBind) Close() error {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
+
var err1, err2 error
if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
@@ -137,23 +152,14 @@ func (bind *StdNetBind) Close() error {
return err2
}
-func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
- if bind.ipv4 == nil {
- return 0, nil, syscall.EAFNOSUPPORT
+func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc {
+ return func(buff []byte) (int, Endpoint, error) {
+ n, endpoint, err := conn.ReadFromUDP(buff)
+ if isIPv4 && endpoint != nil {
+ endpoint.IP = endpoint.IP.To4()
+ }
+ return n, (*StdNetEndpoint)(endpoint), err
}
- n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
- if endpoint != nil {
- endpoint.IP = endpoint.IP.To4()
- }
- return n, (*StdNetEndpoint)(endpoint), err
-}
-
-func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
- if bind.ipv6 == nil {
- return 0, nil, syscall.EAFNOSUPPORT
- }
- n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
- return n, (*StdNetEndpoint)(endpoint), err
}
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
@@ -162,15 +168,16 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if !ok {
return ErrWrongEndpointType
}
- var conn *net.UDPConn
- var blackhole bool
- if nend.IP.To4() != nil {
- blackhole = bind.blackhole4
- conn = bind.ipv4
- } else {
+
+ bind.mu.Lock()
+ blackhole := bind.blackhole4
+ conn := bind.ipv4
+ if nend.IP.To4() == nil {
blackhole = bind.blackhole6
conn = bind.ipv6
}
+ bind.mu.Unlock()
+
if blackhole {
return nil
}
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index 1e2712e..6cabee1 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -266,7 +266,7 @@ func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sock
return sa, nil
}
-func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
+func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
bind.mu.Lock()
defer bind.mu.Unlock()
defer func() {
@@ -275,30 +275,30 @@ func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
}
}()
if atomic.LoadUint32(&bind.isOpen) != 0 {
- return 0, ErrBindAlreadyOpen
+ return nil, 0, ErrBindAlreadyOpen
}
var sa windows.Sockaddr
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
if err != nil {
- return 0, err
+ return nil, 0, err
}
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
if err != nil {
- return 0, err
+ return nil, 0, err
}
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
for i := 0; i < packetsPerRing; i++ {
err = bind.v4.InsertReceiveRequest()
if err != nil {
- return 0, err
+ return nil, 0, err
}
err = bind.v6.InsertReceiveRequest()
if err != nil {
- return 0, err
+ return nil, 0, err
}
}
atomic.StoreUint32(&bind.isOpen, 1)
- return
+ return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
}
func (bind *WinRingBind) Close() error {
@@ -395,13 +395,13 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
return n, &ep, nil
}
-func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
return bind.v4.Receive(buf, &bind.isOpen)
}
-func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) {
+func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
return bind.v6.Receive(buf, &bind.isOpen)
@@ -482,6 +482,8 @@ func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
}
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
sysconn, err := bind.ipv4.SyscallConn()
if err != nil {
return err
@@ -500,6 +502,8 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
}
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ bind.mu.Lock()
+ defer bind.mu.Unlock()
sysconn, err := bind.ipv6.SyscallConn()
if err != nil {
return err
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index ad8fa05..7d43fb3 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -65,12 +65,14 @@ func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
func (c ChannelEndpoint) SrcIP() net.IP { return nil }
-func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) {
+func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool)
+ fns = append(fns, c.makeReceiveFunc(*c.rx4))
+ fns = append(fns, c.makeReceiveFunc(*c.rx6))
if rand.Uint32()&1 == 0 {
- return uint16(c.source4), nil
+ return fns, uint16(c.source4), nil
} else {
- return uint16(c.source6), nil
+ return fns, uint16(c.source6), nil
}
}
@@ -87,21 +89,14 @@ func (c *ChannelBind) Close() error {
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
-func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) {
- select {
- case <-c.closeSignal:
- return 0, nil, net.ErrClosed
- case rx := <-*c.rx6:
- return copy(b, rx), c.target6, nil
- }
-}
-
-func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
- select {
- case <-c.closeSignal:
- return 0, nil, net.ErrClosed
- case rx := <-*c.rx4:
- return copy(b, rx), c.target4, nil
+func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
+ return func(b []byte) (n int, ep conn.Endpoint, err error) {
+ select {
+ case <-c.closeSignal:
+ return 0, nil, net.ErrClosed
+ case rx := <-ch:
+ return copy(b, rx), c.target6, nil
+ }
}
}
diff --git a/conn/conn.go b/conn/conn.go
index 6fd232f..3c7fcd0 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -12,6 +12,11 @@ import (
"strings"
)
+// A ReceiveFunc receives a single inbound packet from the network.
+// It writes the data into b. n is the length of the packet.
+// ep is the remote endpoint.
+type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
+
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
@@ -19,23 +24,17 @@ import (
type Bind interface {
// Open puts the Bind into a listening state on a given port and reports the actual
// port that it bound to. Passing zero results in a random selection.
- Open(port uint16) (actualPort uint16, err error)
+ // fns is the set of functions that will be called to receive packets.
+ Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
// Close closes the Bind listener.
+ // All fns returned by Open must return net.ErrClosed after a call to Close.
Close() error
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
- // ReceiveIPv6 reads an IPv6 UDP packet into b. It reports the number of bytes read,
- // n, the packet source address ep, and any error.
- ReceiveIPv6(b []byte) (n int, ep Endpoint, err error)
-
- // ReceiveIPv4 reads an IPv4 UDP packet into b. It reports the number of bytes read,
- // n, the packet source address ep, and any error.
- ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
-
// Send writes a packet b to address ep.
Send(b []byte, ep Endpoint) error
diff --git a/device/device.go b/device/device.go
index 1e32db6..a635e68 100644
--- a/device/device.go
+++ b/device/device.go
@@ -11,9 +11,6 @@ import (
"sync/atomic"
"time"
- "golang.org/x/net/ipv4"
- "golang.org/x/net/ipv6"
-
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
@@ -468,8 +465,9 @@ func (device *Device) BindUpdate() error {
// bind to new port
var err error
+ var recvFns []conn.ReceiveFunc
netc := &device.net
- netc.port, err = netc.bind.Open(netc.port)
+ recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil {
netc.port = 0
return err
@@ -501,11 +499,12 @@ func (device *Device) BindUpdate() error {
device.peers.RUnlock()
// start receiving routines
- device.net.stopping.Add(2)
- device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
- device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
- go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
- go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
+ device.net.stopping.Add(len(recvFns))
+ device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
+ device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
+ for _, fn := range recvFns {
+ go device.RoutineReceiveIncoming(fn)
+ }
device.log.Verbosef("UDP bind has been updated")
return nil
diff --git a/device/receive.go b/device/receive.go
index 5ddb66c..fa5c0a6 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -68,15 +68,15 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
-func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
+func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
defer func() {
- device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
+ device.log.Verbosef("Routine: receive incoming %p - stopped", recv)
device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done()
device.net.stopping.Done()
}()
- device.log.Verbosef("Routine: receive incoming IPv%d - started", IP)
+ device.log.Verbosef("Routine: receive incoming %p - started", recv)
// receive datagrams until conn is closed
@@ -90,14 +90,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
)
for {
- switch IP {
- case ipv4.Version:
- size, endpoint, err = bind.ReceiveIPv4(buffer[:])
- case ipv6.Version:
- size, endpoint, err = bind.ReceiveIPv6(buffer[:])
- default:
- panic("invalid IP version")
- }
+ size, endpoint, err = recv(buffer[:])
if err != nil {
device.PutMessageBuffer(buffer)