summaryrefslogtreecommitdiffhomepage
path: root/conn
diff options
context:
space:
mode:
Diffstat (limited to 'conn')
-rw-r--r--conn/bind_std.go27
-rw-r--r--conn/sticky_default.go14
-rw-r--r--conn/sticky_linux.go105
-rw-r--r--conn/sticky_linux_test.go80
4 files changed, 128 insertions, 98 deletions
diff --git a/conn/bind_std.go b/conn/bind_std.go
index 69789b3..c701ef8 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -81,11 +81,10 @@ func NewStdNetBind() Bind {
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
- // src is the current sticky source address and interface index, if supported.
- src struct {
- netip.Addr
- ifidx int32
- }
+ // src is the current sticky source address and interface index, if
+ // supported. Typically this is a PKTINFO structure from/for control
+ // messages, see unix.PKTINFO for an example.
+ src []byte
}
var (
@@ -104,21 +103,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
}
func (e *StdNetEndpoint) ClearSrc() {
- e.src.ifidx = 0
- e.src.Addr = netip.Addr{}
+ if e.src != nil {
+ // Truncate src, no need to reallocate.
+ e.src = e.src[:0]
+ }
}
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
-func (e *StdNetEndpoint) SrcIP() netip.Addr {
- return e.src.Addr
-}
-
-func (e *StdNetEndpoint) SrcIfidx() int32 {
- return e.src.ifidx
-}
+// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx.
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
@@ -129,10 +124,6 @@ func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}
-func (e *StdNetEndpoint) SrcToString() string {
- return e.src.Addr.String()
-}
-
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
diff --git a/conn/sticky_default.go b/conn/sticky_default.go
index 05f00ea..1fa8a0c 100644
--- a/conn/sticky_default.go
+++ b/conn/sticky_default.go
@@ -7,6 +7,20 @@
package conn
+import "net/netip"
+
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return ""
+}
+
// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
// use alternatively named flags and need ports and require testing.
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go
index 274fa38..a30ccc7 100644
--- a/conn/sticky_linux.go
+++ b/conn/sticky_linux.go
@@ -14,6 +14,37 @@ import (
"golang.org/x/sys/unix"
)
+func (e *StdNetEndpoint) SrcIP() netip.Addr {
+ switch len(e.src) {
+ case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return netip.AddrFrom4(info.Spec_dst)
+ case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ // TODO: set zone. in order to do so we need to check if the address is
+ // link local, and if it is perform a syscall to turn the ifindex into a
+ // zone string because netip uses string zones.
+ return netip.AddrFrom16(info.Addr)
+ }
+ return netip.Addr{}
+}
+
+func (e *StdNetEndpoint) SrcIfidx() int32 {
+ switch len(e.src) {
+ case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
+ info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return info.Ifindex
+ case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
+ info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
+ return int32(info.Ifindex)
+ }
+ return 0
+}
+
+func (e *StdNetEndpoint) SrcToString() string {
+ return e.SrcIP().String()
+}
+
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
// the source information found.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
@@ -35,81 +66,43 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
if hdr.Level == unix.IPPROTO_IP &&
hdr.Type == unix.IP_PKTINFO {
- info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
- ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
- ep.src.ifidx = info.Ifindex
+ if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
+ ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ }
+ ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
+ hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+ copy(ep.src, hdrBuf)
+ copy(ep.src[unix.CmsgLen(0):], data)
return
}
if hdr.Level == unix.IPPROTO_IPV6 &&
hdr.Type == unix.IPV6_PKTINFO {
- info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
- ep.src.Addr = netip.AddrFrom16(info.Addr)
- ep.src.ifidx = int32(info.Ifindex)
+ if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
+ ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
+ hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
+ copy(ep.src, hdrBuf)
+ copy(ep.src[unix.CmsgLen(0):], data)
return
}
}
}
-// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
-// panics if buf is of insufficient size.
-func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
- size := int(unsafe.Sizeof(t))
- if len(buf) < size {
- panic("pktInfoFromBuf: buffer too small")
- }
- copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
- return t
-}
-
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
// and source ifindex found in ep. control's len will be set to 0 in the event
// that ep is a default value.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
- *control = (*control)[:cap(*control)]
- if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
- *control = (*control)[:0]
+ if cap(*control) < len(ep.src) {
return
}
-
- if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
- *control = (*control)[:0]
- return
- }
-
- if len(*control) < srcControlSize {
- *control = (*control)[:0]
- return
- }
-
- hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
- if ep.SrcIP().Is4() {
- hdr.Level = unix.IPPROTO_IP
- hdr.Type = unix.IP_PKTINFO
- hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
-
- info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
- info.Ifindex = ep.src.ifidx
- if ep.SrcIP().IsValid() {
- info.Spec_dst = ep.SrcIP().As4()
- }
- *control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
- } else {
- hdr.Level = unix.IPPROTO_IPV6
- hdr.Type = unix.IPV6_PKTINFO
- hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
-
- info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
- info.Ifindex = uint32(ep.src.ifidx)
- if ep.SrcIP().IsValid() {
- info.Addr = ep.SrcIP().As16()
- }
- *control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
- }
-
+ *control = (*control)[:0]
+ *control = append(*control, ep.src...)
}
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go
index 0219ac3..679213a 100644
--- a/conn/sticky_linux_test.go
+++ b/conn/sticky_linux_test.go
@@ -18,13 +18,47 @@ import (
"golang.org/x/sys/unix"
)
+func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
+ var buf []byte
+ if addr.Is4() {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IP,
+ Type: unix.IP_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet4Pktinfo{
+ Ifindex: ifidx,
+ Spec_dst: addr.As4(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
+ } else {
+ buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
+ hdr := unix.Cmsghdr{
+ Level: unix.IPPROTO_IPV6,
+ Type: unix.IPV6_PKTINFO,
+ }
+ hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
+ copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
+
+ info := unix.Inet6Pktinfo{
+ Ifindex: uint32(ifidx),
+ Addr: addr.As16(),
+ }
+ copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
+ }
+
+ ep.src = buf
+}
+
func Test_setSrcControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
}
- ep.src.Addr = netip.MustParseAddr("127.0.0.1")
- ep.src.ifidx = 5
+ setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
control := make([]byte, srcControlSize)
@@ -53,8 +87,7 @@ func Test_setSrcControl(t *testing.T) {
ep := &StdNetEndpoint{
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
}
- ep.src.Addr = netip.MustParseAddr("::1")
- ep.src.ifidx = 5
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
control := make([]byte, srcControlSize)
@@ -80,7 +113,7 @@ func Test_setSrcControl(t *testing.T) {
})
t.Run("ClearOnNoSrc", func(t *testing.T) {
- control := make([]byte, srcControlSize)
+ control := make([]byte, unix.CmsgLen(0))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = 1
hdr.Type = 2
@@ -96,7 +129,7 @@ func Test_setSrcControl(t *testing.T) {
func Test_getSrcFromControl(t *testing.T) {
t.Run("IPv4", func(t *testing.T) {
- control := make([]byte, srcControlSize)
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
@@ -108,15 +141,15 @@ func Test_getSrcFromControl(t *testing.T) {
ep := &StdNetEndpoint{}
getSrcFromControl(control, ep)
- if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
- t.Errorf("unexpected address: %v", ep.src.Addr)
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
}
- if ep.src.ifidx != 5 {
- t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("IPv6", func(t *testing.T) {
- control := make([]byte, srcControlSize)
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IPV6
hdr.Type = unix.IPV6_PKTINFO
@@ -131,22 +164,21 @@ func Test_getSrcFromControl(t *testing.T) {
if ep.SrcIP() != netip.MustParseAddr("::1") {
t.Errorf("unexpected address: %v", ep.SrcIP())
}
- if ep.src.ifidx != 5 {
- t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("ClearOnEmpty", func(t *testing.T) {
- control := make([]byte, srcControlSize)
+ var control []byte
ep := &StdNetEndpoint{}
- ep.src.Addr = netip.MustParseAddr("::1")
- ep.src.ifidx = 5
+ setSrc(ep, netip.MustParseAddr("::1"), 5)
getSrcFromControl(control, ep)
if ep.SrcIP().IsValid() {
- t.Errorf("unexpected address: %v", ep.src.Addr)
+ t.Errorf("unexpected address: %v", ep.SrcIP())
}
- if ep.src.ifidx != 0 {
- t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ if ep.SrcIfidx() != 0 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
t.Run("Multiple", func(t *testing.T) {
@@ -154,7 +186,7 @@ func Test_getSrcFromControl(t *testing.T) {
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
zeroHdr.SetLen(unix.CmsgLen(0))
- control := make([]byte, srcControlSize)
+ control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
hdr.Level = unix.IPPROTO_IP
hdr.Type = unix.IP_PKTINFO
@@ -170,11 +202,11 @@ func Test_getSrcFromControl(t *testing.T) {
ep := &StdNetEndpoint{}
getSrcFromControl(combined, ep)
- if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
- t.Errorf("unexpected address: %v", ep.src.Addr)
+ if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
+ t.Errorf("unexpected address: %v", ep.SrcIP())
}
- if ep.src.ifidx != 5 {
- t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
+ if ep.SrcIfidx() != 5 {
+ t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
}
})
}