diff options
Diffstat (limited to 'conn/sticky_linux.go')
-rw-r--r-- | conn/sticky_linux.go | 105 |
1 files changed, 49 insertions, 56 deletions
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go index 274fa38..a30ccc7 100644 --- a/conn/sticky_linux.go +++ b/conn/sticky_linux.go @@ -14,6 +14,37 @@ import ( "golang.org/x/sys/unix" ) +func (e *StdNetEndpoint) SrcIP() netip.Addr { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return netip.AddrFrom4(info.Spec_dst) + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + // TODO: set zone. in order to do so we need to check if the address is + // link local, and if it is perform a syscall to turn the ifindex into a + // zone string because netip uses string zones. + return netip.AddrFrom16(info.Addr) + } + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return info.Ifindex + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return int32(info.Ifindex) + } + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.SrcIP().String() +} + // getSrcFromControl parses the control for PKTINFO and if found updates ep with // the source information found. func getSrcFromControl(control []byte, ep *StdNetEndpoint) { @@ -35,81 +66,43 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { if hdr.Level == unix.IPPROTO_IP && hdr.Type == unix.IP_PKTINFO { - info := pktInfoFromBuf[unix.Inet4Pktinfo](data) - ep.src.Addr = netip.AddrFrom4(info.Spec_dst) - ep.src.ifidx = info.Ifindex + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) return } if hdr.Level == unix.IPPROTO_IPV6 && hdr.Type == unix.IPV6_PKTINFO { - info := pktInfoFromBuf[unix.Inet6Pktinfo](data) - ep.src.Addr = netip.AddrFrom16(info.Addr) - ep.src.ifidx = int32(info.Ifindex) + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) return } } } -// pktInfoFromBuf returns type T populated from the provided buf via copy(). It -// panics if buf is of insufficient size. -func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) { - size := int(unsafe.Sizeof(t)) - if len(buf) < size { - panic("pktInfoFromBuf: buffer too small") - } - copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf) - return t -} - // setSrcControl sets an IP{V6}_PKTINFO in control based on the source address // and source ifindex found in ep. control's len will be set to 0 in the event // that ep is a default value. func setSrcControl(control *[]byte, ep *StdNetEndpoint) { - *control = (*control)[:cap(*control)] - if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) { - *control = (*control)[:0] + if cap(*control) < len(ep.src) { return } - - if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() { - *control = (*control)[:0] - return - } - - if len(*control) < srcControlSize { - *control = (*control)[:0] - return - } - - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0])) - if ep.SrcIP().Is4() { - hdr.Level = unix.IPPROTO_IP - hdr.Type = unix.IP_PKTINFO - hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) - - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) - info.Ifindex = ep.src.ifidx - if ep.SrcIP().IsValid() { - info.Spec_dst = ep.SrcIP().As4() - } - *control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] - } else { - hdr.Level = unix.IPPROTO_IPV6 - hdr.Type = unix.IPV6_PKTINFO - hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) - - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr])) - info.Ifindex = uint32(ep.src.ifidx) - if ep.SrcIP().IsValid() { - info.Addr = ep.SrcIP().As16() - } - *control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] - } - + *control = (*control)[:0] + *control = append(*control, ep.src...) } var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) |