diff options
Diffstat (limited to 'conn/sticky_linux_test.go')
-rw-r--r-- | conn/sticky_linux_test.go | 80 |
1 files changed, 56 insertions, 24 deletions
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go index 0219ac3..679213a 100644 --- a/conn/sticky_linux_test.go +++ b/conn/sticky_linux_test.go @@ -18,13 +18,47 @@ import ( "golang.org/x/sys/unix" ) +func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { + var buf []byte + if addr.Is4() { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet4Pktinfo{ + Ifindex: ifidx, + Spec_dst: addr.As4(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) + } else { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet6Pktinfo{ + Ifindex: uint32(ifidx), + Addr: addr.As16(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) + } + + ep.src = buf +} + func Test_setSrcControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), } - ep.src.Addr = netip.MustParseAddr("127.0.0.1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) control := make([]byte, srcControlSize) @@ -53,8 +87,7 @@ func Test_setSrcControl(t *testing.T) { ep := &StdNetEndpoint{ AddrPort: netip.MustParseAddrPort("[::1]:1234"), } - ep.src.Addr = netip.MustParseAddr("::1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("::1"), 5) control := make([]byte, srcControlSize) @@ -80,7 +113,7 @@ func Test_setSrcControl(t *testing.T) { }) t.Run("ClearOnNoSrc", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, unix.CmsgLen(0)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = 1 hdr.Type = 2 @@ -96,7 +129,7 @@ func Test_setSrcControl(t *testing.T) { func Test_getSrcFromControl(t *testing.T) { t.Run("IPv4", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO @@ -108,15 +141,15 @@ func Test_getSrcFromControl(t *testing.T) { ep := &StdNetEndpoint{} getSrcFromControl(control, ep) - if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.src.Addr) + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("IPv6", func(t *testing.T) { - control := make([]byte, srcControlSize) + control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IPV6 hdr.Type = unix.IPV6_PKTINFO @@ -131,22 +164,21 @@ func Test_getSrcFromControl(t *testing.T) { if ep.SrcIP() != netip.MustParseAddr("::1") { t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("ClearOnEmpty", func(t *testing.T) { - control := make([]byte, srcControlSize) + var control []byte ep := &StdNetEndpoint{} - ep.src.Addr = netip.MustParseAddr("::1") - ep.src.ifidx = 5 + setSrc(ep, netip.MustParseAddr("::1"), 5) getSrcFromControl(control, ep) if ep.SrcIP().IsValid() { - t.Errorf("unexpected address: %v", ep.src.Addr) + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 0 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 0 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) t.Run("Multiple", func(t *testing.T) { @@ -154,7 +186,7 @@ func Test_getSrcFromControl(t *testing.T) { zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) zeroHdr.SetLen(unix.CmsgLen(0)) - control := make([]byte, srcControlSize) + control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) hdr.Level = unix.IPPROTO_IP hdr.Type = unix.IP_PKTINFO @@ -170,11 +202,11 @@ func Test_getSrcFromControl(t *testing.T) { ep := &StdNetEndpoint{} getSrcFromControl(combined, ep) - if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.src.Addr) + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) } - if ep.src.ifidx != 5 { - t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) } }) } |