summaryrefslogtreecommitdiffhomepage
path: root/conn/sticky_linux_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'conn/sticky_linux_test.go')
-rw-r--r--conn/sticky_linux_test.go80
1 files changed, 56 insertions, 24 deletions
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
index 0219ac3..679213a 100644
--- a/conn/sticky_linux_test.go
+++ b/conn/sticky_linux_test.go
@@ -18,13 +18,47 @@ import (
"golang.org/x/sys/unix"
)
+func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
+ var buf []byte
+ if addr.Is4() {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IP,
+ Type: unix.IP_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet4Pktinfo{
+ Ifindex: ifidx,
+ Spec_dst: addr.As4(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
+ } else {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IPV6,
+ Type: unix.IPV6_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet6Pktinfo{
+ Ifindex: uint32(ifidx),
+ Addr: addr.As16(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = buf
+}
+
func Test_setSrcControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
}
- ep.src.Addr = netip.MustParseAddr("127.0.0.1")
- ep.src.ifidx = 5
+ setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
control := make([]byte, srcControlSize)
@@ -53,8 +87,7 @@ func Test_setSrcControl(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
}
- ep.src.Addr = netip.MustParseAddr("::1")
- ep.src.ifidx = 5
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
control := make([]byte, srcControlSize)
@@ -80,7 +113,7 @@ func Test_setSrcControl(t *testing.T) {
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
- control := make([]byte, srcControlSize)
+ control := make([]byte, unix.CmsgLen(0))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
@@ -96,7 +129,7 @@ func Test_setSrcControl(t *testing.T) {
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
- control := make([]byte, srcControlSize)
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
@@ -108,15 +141,15 @@ func Test_getSrcFromControl(t *testing.T) {
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
- if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
- t.Errorf("unexpected address: %v", ep.src.Addr)
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
}
- if ep.src.ifidx != 5 {
- t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("IPv6", func(t *testing.T) {
- control := make([]byte, srcControlSize)
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
@@ -131,22 +164,21 @@ func Test_getSrcFromControl(t *testing.T) {
if ep.SrcIP() != netip.MustParseAddr("::1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
- if ep.src.ifidx != 5 {
- t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("ClearOnEmpty", func(t *testing.T) {
- control := make([]byte, srcControlSize)
+ var control []byte
ep := &StdNetEndpoint{}
- ep.src.Addr = netip.MustParseAddr("::1")
- ep.src.ifidx = 5
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
getSrcFromControl(control, ep)
if ep.SrcIP().IsValid() {
- t.Errorf("unexpected address: %v", ep.src.Addr)
+ t.Errorf("unexpected address: %v", ep.SrcIP())
}
- if ep.src.ifidx != 0 {
- t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ if ep.SrcIfidx() != 0 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("Multiple", func(t *testing.T) {
@@ -154,7 +186,7 @@ func Test_getSrcFromControl(t *testing.T) {
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
zeroHdr.SetLen(unix.CmsgLen(0))
- control := make([]byte, srcControlSize)
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
@@ -170,11 +202,11 @@ func Test_getSrcFromControl(t *testing.T) {
ep := &StdNetEndpoint{}
getSrcFromControl(combined, ep)
- if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
- t.Errorf("unexpected address: %v", ep.src.Addr)
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
}
- if ep.src.ifidx != 5 {
- t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
}