summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/conn.go16
-rw-r--r--src/conn_default.go35
-rw-r--r--src/conn_linux.go2
-rw-r--r--src/cookie_test.go7
-rw-r--r--src/daemon_linux.go13
-rw-r--r--src/device.go8
-rw-r--r--src/helper_test.go8
-rw-r--r--src/misc.go8
-rw-r--r--src/noise_test.go8
-rw-r--r--src/receive.go27
-rwxr-xr-xsrc/tests/netns.sh76
-rw-r--r--src/tun_linux.go1
-rw-r--r--src/uapi.go13
13 files changed, 194 insertions, 28 deletions
diff --git a/src/conn.go b/src/conn.go
index a047bb6..3cf00ab 100644
--- a/src/conn.go
+++ b/src/conn.go
@@ -15,6 +15,22 @@ type UDPBind interface {
Close() error
}
+/* An Endpoint maintains the source/destination caching for a peer
+ *
+ * dst : the remote address of a peer
+ * src : the local address from which datagrams originate going to the peer
+ *
+ */
+type UDPEndpoint interface {
+ ClearSrc() // clears the source address
+ ClearDst() // clears the destination address
+ SrcToString() string // returns the local source address (ip:port)
+ DstToString() string // returns the destination address (ip:port)
+ DstToBytes() []byte // used for mac2 cookie calculations
+ DstIP() net.IP
+ SrcIP() net.IP
+}
+
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
diff --git a/src/conn_default.go b/src/conn_default.go
index 279643e..31cab5c 100644
--- a/src/conn_default.go
+++ b/src/conn_default.go
@@ -6,6 +6,41 @@ import (
"net"
)
+/* This code is meant to be a temporary solution
+ * on platforms for which the sticky socket / source caching behavior
+ * has not yet been implemented.
+ *
+ * See conn_linux.go for an implementation on the linux platform.
+ */
+
+type Endpoint *net.UDPAddr
+
+type NativeBind *net.UDPConn
+
+func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
+
+ // listen
+
+ addr := UDPAddr{
+ Port: int(port),
+ }
+ conn, err := net.ListenUDP("udp", &addr)
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // retrieve port
+
+ laddr := conn.LocalAddr()
+ uaddr, _ = net.ResolveUDPAddr(
+ laddr.Network(),
+ laddr.String(),
+ )
+ return uaddr.Port
+}
+
+func (_ Endpoint) ClearSrc() {}
+
func SetMark(conn *net.UDPConn, value uint32) error {
return nil
}
diff --git a/src/conn_linux.go b/src/conn_linux.go
index 383ff7e..fb576b1 100644
--- a/src/conn_linux.go
+++ b/src/conn_linux.go
@@ -168,7 +168,7 @@ func (end *Endpoint) DstIP() net.IP {
}
}
-func (end *Endpoint) SrcToBytes() []byte {
+func (end *Endpoint) DstToBytes() []byte {
ptr := unsafe.Pointer(&end.src)
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
return arr[:]
diff --git a/src/cookie_test.go b/src/cookie_test.go
index 193a76e..d745fe7 100644
--- a/src/cookie_test.go
+++ b/src/cookie_test.go
@@ -1,7 +1,6 @@
package main
import (
- "net"
"testing"
)
@@ -25,7 +24,7 @@ func TestCookieMAC1(t *testing.T) {
// check mac1
- src, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4000")
+ src := []byte{192, 168, 13, 37, 10, 10, 10}
checkMAC1 := func(msg []byte) {
generator.AddMacs(msg)
@@ -128,12 +127,12 @@ func TestCookieMAC1(t *testing.T) {
msg[5] ^= 0x20
- srcBad1, _ := net.ResolveUDPAddr("udp", "192.168.13.37:4001")
+ srcBad1 := []byte{192, 168, 13, 37, 40, 01}
if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed")
}
- srcBad2, _ := net.ResolveUDPAddr("udp", "192.168.13.38:4000")
+ srcBad2 := []byte{192, 168, 13, 38, 40, 01}
if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed")
}
diff --git a/src/daemon_linux.go b/src/daemon_linux.go
index 8210f8b..e1aaede 100644
--- a/src/daemon_linux.go
+++ b/src/daemon_linux.go
@@ -2,20 +2,25 @@ package main
import (
"os"
+ "os/exec"
)
/* Daemonizes the process on linux
*
* This is done by spawning and releasing a copy with the --foreground flag
- *
- * TODO: Use env variable to spawn in background
*/
-
func Daemonize(attr *os.ProcAttr) error {
+ // I would like to use os.Executable,
+ // however this means dropping support for Go <1.8
+ path, err := exec.LookPath(os.Args[0])
+ if err != nil {
+ return err
+ }
+
argv := []string{os.Args[0], "--foreground"}
argv = append(argv, os.Args[1:]...)
process, err := os.StartProcess(
- argv[0],
+ path,
argv,
attr,
)
diff --git a/src/device.go b/src/device.go
index 429ee46..0085cee 100644
--- a/src/device.go
+++ b/src/device.go
@@ -8,8 +8,9 @@ import (
)
type Device struct {
- log *Logger // collection of loggers for levels
- idCounter uint // for assigning debug ids to peers
+ closed AtomicBool // device is closed? (acting as guard)
+ log *Logger // collection of loggers for levels
+ idCounter uint // for assigning debug ids to peers
fwMark uint32
tun struct {
device TUNDevice
@@ -203,6 +204,9 @@ func (device *Device) RemoveAllPeers() {
}
func (device *Device) Close() {
+ if device.closed.Swap(true) {
+ return
+ }
device.log.Info.Println("Closing device")
device.RemoveAllPeers()
close(device.signal.stop)
diff --git a/src/helper_test.go b/src/helper_test.go
index fc171e8..8548121 100644
--- a/src/helper_test.go
+++ b/src/helper_test.go
@@ -2,6 +2,7 @@ package main
import (
"bytes"
+ "os"
"testing"
)
@@ -15,6 +16,10 @@ type DummyTUN struct {
events chan TUNEvent
}
+func (tun *DummyTUN) File() *os.File {
+ return nil
+}
+
func (tun *DummyTUN) Name() string {
return tun.name
}
@@ -67,7 +72,8 @@ func randDevice(t *testing.T) *Device {
t.Fatal(err)
}
tun, _ := CreateDummyTUN("dummy")
- device := NewDevice(tun, LogLevelError)
+ logger := NewLogger(LogLevelError, "")
+ device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}
diff --git a/src/misc.go b/src/misc.go
index bbe0d68..b43e97e 100644
--- a/src/misc.go
+++ b/src/misc.go
@@ -21,6 +21,14 @@ func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.flag) == AtomicTrue
}
+func (a *AtomicBool) Swap(val bool) bool {
+ flag := AtomicFalse
+ if val {
+ flag = AtomicTrue
+ }
+ return atomic.SwapInt32(&a.flag, flag) == AtomicTrue
+}
+
func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse
if val {
diff --git a/src/noise_test.go b/src/noise_test.go
index 48408f9..0d7f0e9 100644
--- a/src/noise_test.go
+++ b/src/noise_test.go
@@ -117,8 +117,8 @@ func TestNoiseHandshake(t *testing.T) {
var err error
var out []byte
var nonce [12]byte
- out = key1.send.aead.Seal(out, nonce[:], testMsg, nil)
- out, err = key2.receive.aead.Open(out[:0], nonce[:], out, nil)
+ out = key1.send.Seal(out, nonce[:], testMsg, nil)
+ out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err)
assertEqual(t, out, testMsg)
}()
@@ -128,8 +128,8 @@ func TestNoiseHandshake(t *testing.T) {
var err error
var out []byte
var nonce [12]byte
- out = key2.send.aead.Seal(out, nonce[:], testMsg, nil)
- out, err = key1.receive.aead.Open(out[:0], nonce[:], out, nil)
+ out = key2.send.Seal(out, nonce[:], testMsg, nil)
+ out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err)
assertEqual(t, out, testMsg)
}()
diff --git a/src/receive.go b/src/receive.go
index ff3b7bd..b8b06f7 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -311,7 +311,10 @@ func (device *Device) RoutineHandshake() {
return
}
- srcBytes := elem.endpoint.SrcToBytes()
+ // endpoints destination address is the source of the datagram
+
+ srcBytes := elem.endpoint.DstToBytes()
+
if device.IsUnderLoad() {
// verify MAC2 field
@@ -320,8 +323,12 @@ func (device *Device) RoutineHandshake() {
// construct cookie reply
- logDebug.Println("Sending cookie reply to:", elem.endpoint.SrcToString())
- sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
+ logDebug.Println(
+ "Sending cookie reply to:",
+ elem.endpoint.DstToString(),
+ )
+
+ sender := binary.LittleEndian.Uint32(elem.packet[4:8])
reply, err := device.mac.CreateReply(elem.packet, sender, srcBytes)
if err != nil {
logError.Println("Failed to create cookie reply:", err)
@@ -555,8 +562,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer {
- logInfo.Println(src)
- logInfo.Println("Packet with unallowed source IPv4 from", peer.String())
+ logInfo.Println(
+ "IPv4 packet with unallowed source address from",
+ peer.String(),
+ )
continue
}
@@ -581,8 +590,10 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer {
- logInfo.Println(src)
- logInfo.Println("Packet with unallowed source IPv6 from", peer.String())
+ logInfo.Println(
+ "IPv6 packet with unallowed source address from",
+ peer.String(),
+ )
continue
}
@@ -591,7 +602,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue
}
- // write to tun
+ // write to tun device
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write(elem.packet)
diff --git a/src/tests/netns.sh b/src/tests/netns.sh
index b5c2f9c..22abea8 100755
--- a/src/tests/netns.sh
+++ b/src/tests/netns.sh
@@ -20,6 +20,14 @@
# wireguard peers in $ns1 and $ns2. Note that $ns0 is the endpoint for the wg1
# interfaces in $ns1 and $ns2. See https://www.wireguard.com/netns/ for further
# details on how this is accomplished.
+
+# This code is ported to the WireGuard-Go directly from the kernel project.
+#
+# Please ensure that you have installed the newest version of the WireGuard
+# tools from the WireGuard project and before running these tests as:
+#
+# ./netns.sh <path to wireguard-go>
+
set -e
exec 3>&1
@@ -27,7 +35,7 @@ export WG_HIDE_KEYS=never
netns0="wg-test-$$-0"
netns1="wg-test-$$-1"
netns2="wg-test-$$-2"
-program="../wireguard-go"
+program=$1
export LOG_LEVEL="info"
pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; }
@@ -349,4 +357,68 @@ ip1 link del veth1
ip1 link del wg1
ip2 link del wg2
-echo "done"
+# Test that Netlink/IPC is working properly by doing things that usually cause split responses
+
+n0 $program wg0
+sleep 5
+config=( "[Interface]" "PrivateKey=$(wg genkey)" "[Peer]" "PublicKey=$(wg genkey)" )
+for a in {1..255}; do
+ for b in {0..255}; do
+ config+=( "AllowedIPs=$a.$b.0.0/16,$a::$b/128" )
+ done
+done
+n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
+i=0
+for ip in $(n0 wg show wg0 allowed-ips); do
+ ((++i))
+done
+((i == 255*256*2+1))
+ip0 link del wg0
+
+n0 $program wg0
+config=( "[Interface]" "PrivateKey=$(wg genkey)" )
+for a in {1..40}; do
+ config+=( "[Peer]" "PublicKey=$(wg genkey)" )
+ for b in {1..52}; do
+ config+=( "AllowedIPs=$a.$b.0.0/16" )
+ done
+done
+n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
+i=0
+while read -r line; do
+ j=0
+ for ip in $line; do
+ ((++j))
+ done
+ ((j == 53))
+ ((++i))
+done < <(n0 wg show wg0 allowed-ips)
+((i == 40))
+ip0 link del wg0
+
+n0 $program wg0
+config=( )
+for i in {1..29}; do
+ config+=( "[Peer]" "PublicKey=$(wg genkey)" )
+done
+config+=( "[Peer]" "PublicKey=$(wg genkey)" "AllowedIPs=255.2.3.4/32,abcd::255/128" )
+n0 wg setconf wg0 <(printf '%s\n' "${config[@]}")
+n0 wg showconf wg0 > /dev/null
+ip0 link del wg0
+
+! n0 wg show doesnotexist || false
+
+declare -A objects
+while read -t 0.1 -r line 2>/dev/null || [[ $? -ne 142 ]]; do
+ [[ $line =~ .*(wg[0-9]+:\ [A-Z][a-z]+\ [0-9]+)\ .*(created|destroyed).* ]] || continue
+ objects["${BASH_REMATCH[1]}"]+="${BASH_REMATCH[2]}"
+done < /dev/kmsg
+alldeleted=1
+for object in "${!objects[@]}"; do
+ if [[ ${objects["$object"]} != *createddestroyed ]]; then
+ echo "Error: $object: merely ${objects["$object"]}" >&3
+ alldeleted=0
+ fi
+done
+[[ $alldeleted -eq 1 ]]
+pretty "" "Objects that were created were also destroyed."
diff --git a/src/tun_linux.go b/src/tun_linux.go
index 2a5b276..a728a48 100644
--- a/src/tun_linux.go
+++ b/src/tun_linux.go
@@ -57,7 +57,6 @@ type NativeTun struct {
}
func (tun *NativeTun) File() *os.File {
- println(tun.fd.Name())
return tun.fd
}
diff --git a/src/uapi.go b/src/uapi.go
index 5e40939..e1d0929 100644
--- a/src/uapi.go
+++ b/src/uapi.go
@@ -145,11 +145,22 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
}
case "fwmark":
- fwmark, err := strconv.ParseUint(value, 10, 32)
+
+ // parse fwmark field
+
+ fwmark, err := func() (uint32, error) {
+ if value == "" {
+ return 0, nil
+ }
+ mark, err := strconv.ParseUint(value, 10, 32)
+ return uint32(mark), err
+ }()
+
if err != nil {
logError.Println("Invalid fwmark", err)
return &IPCError{Code: ipcErrorInvalid}
}
+
device.net.mutex.Lock()
device.net.fwmark = uint32(fwmark)
device.net.mutex.Unlock()