summaryrefslogtreecommitdiffhomepage
path: root/conn/sticky_linux_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'conn/sticky_linux_test.go')
-rw-r--r--conn/sticky_linux_test.go266
1 files changed, 266 insertions, 0 deletions
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
new file mode 100644
index 0000000..d2bd584
--- /dev/null
+++ b/conn/sticky_linux_test.go
@@ -0,0 +1,266 @@
+//go:build linux && !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "context"
+ "net"
+ "net/netip"
+ "runtime"
+ "testing"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
+ var buf []byte
+ if addr.Is4() {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IP,
+ Type: unix.IP_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet4Pktinfo{
+ Ifindex: ifidx,
+ Spec_dst: addr.As4(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
+ } else {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IPV6,
+ Type: unix.IPV6_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet6Pktinfo{
+ Ifindex: uint32(ifidx),
+ Addr: addr.As16(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = buf
+}
+
+func Test_setSrcControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
+ }
+ setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
+
+ control := make([]byte, stickyControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IP {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IP_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
+ t.Errorf("unexpected address: %v", info.Spec_dst)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("IPv6", func(t *testing.T) {
+ ep := &StdNetEndpoint{
+ AddrPort: netip.MustParseAddrPort("[::1]:1234"),
+ }
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
+
+ control := make([]byte, stickyControlSize)
+
+ setSrcControl(&control, ep)
+
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ if hdr.Level != unix.IPPROTO_IPV6 {
+ t.Errorf("unexpected level: %d", hdr.Level)
+ }
+ if hdr.Type != unix.IPV6_PKTINFO {
+ t.Errorf("unexpected type: %d", hdr.Type)
+ }
+ if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
+ t.Errorf("unexpected length: %d", hdr.Len)
+ }
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ if info.Addr != ep.SrcIP().As16() {
+ t.Errorf("unexpected address: %v", info.Addr)
+ }
+ if info.Ifindex != 5 {
+ t.Errorf("unexpected ifindex: %d", info.Ifindex)
+ }
+ })
+
+ t.Run("ClearOnNoSrc", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = 1
+ hdr.Type = 2
+ hdr.Len = 3
+
+ setSrcControl(&control, &StdNetEndpoint{})
+
+ if len(control) != 0 {
+ t.Errorf("unexpected control: %v", control)
+ }
+ })
+}
+
+func Test_getSrcFromControl(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ control := make([]byte, stickyControlSize)
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IPV6
+ hdr.Type = unix.IPV6_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
+ info.Ifindex = 5
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(control, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("::1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("ClearOnEmpty", func(t *testing.T) {
+ var control []byte
+ ep := &StdNetEndpoint{}
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
+
+ getSrcFromControl(control, ep)
+ if ep.SrcIP().IsValid() {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 0 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+ t.Run("Multiple", func(t *testing.T) {
+ zeroControl := make([]byte, unix.CmsgSpace(0))
+ zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
+ zeroHdr.SetLen(unix.CmsgLen(0))
+
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
+ hdr.Level = unix.IPPROTO_IP
+ hdr.Type = unix.IP_PKTINFO
+ hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
+ info.Spec_dst = [4]byte{127, 0, 0, 1}
+ info.Ifindex = 5
+
+ combined := make([]byte, 0)
+ combined = append(combined, zeroControl...)
+ combined = append(combined, control...)
+
+ ep := &StdNetEndpoint{}
+ getSrcFromControl(combined, ep)
+
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
+ }
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
+ }
+ })
+}
+
+func Test_listenConfig(t *testing.T) {
+ t.Run("IPv4", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IP_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+ t.Run("IPv6", func(t *testing.T) {
+ conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ sc, err := conn.(*net.UDPConn).SyscallConn()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if runtime.GOOS == "linux" {
+ var i int
+ sc.Control(func(fd uintptr) {
+ i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if i != 1 {
+ t.Error("IPV6_PKTINFO not set!")
+ }
+ } else {
+ t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
+ }
+ })
+}