From 5d37bd24e14e3fff6c1ce61e299480beb3d68c00 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sat, 21 Oct 2023 18:41:27 +0200 Subject: conn: separate gso and sticky control Android wants GSO but not sticky. Signed-off-by: Jason A. Donenfeld --- conn/bind_std.go | 2 +- conn/control_default.go | 51 --------- conn/control_linux.go | 159 --------------------------- conn/control_linux_test.go | 266 --------------------------------------------- conn/gso_default.go | 21 ++++ conn/gso_linux.go | 65 +++++++++++ conn/sticky_default.go | 42 +++++++ conn/sticky_linux.go | 112 +++++++++++++++++++ conn/sticky_linux_test.go | 266 +++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 507 insertions(+), 477 deletions(-) delete mode 100644 conn/control_default.go delete mode 100644 conn/control_linux.go delete mode 100644 conn/control_linux_test.go create mode 100644 conn/gso_default.go create mode 100644 conn/gso_linux.go create mode 100644 conn/sticky_default.go create mode 100644 conn/sticky_linux.go create mode 100644 conn/sticky_linux_test.go (limited to 'conn') diff --git a/conn/bind_std.go b/conn/bind_std.go index 5a00f34..e1bcbd1 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -65,7 +65,7 @@ func NewStdNetBind() Bind { msgs := make([]ipv6.Message, IdealBatchSize) for i := range msgs { msgs[i].Buffers = make(net.Buffers, 1) - msgs[i].OOB = make([]byte, controlSize) + msgs[i].OOB = make([]byte, stickyControlSize+gsoControlSize) } return &msgs }, diff --git a/conn/control_default.go b/conn/control_default.go deleted file mode 100644 index 9459da5..0000000 --- a/conn/control_default.go +++ /dev/null @@ -1,51 +0,0 @@ -//go:build !linux || android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import "net/netip" - -func (e *StdNetEndpoint) SrcIP() netip.Addr { - return netip.Addr{} -} - -func (e *StdNetEndpoint) SrcIfidx() int32 { - return 0 -} - -func (e *StdNetEndpoint) SrcToString() string { - return "" -} - -// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets -// {get,set}srcControl feature set, but use alternatively named flags and need -// ports and require testing. - -// getSrcFromControl parses the control for PKTINFO and if found updates ep with -// the source information found. -func getSrcFromControl(control []byte, ep *StdNetEndpoint) { -} - -// setSrcControl parses the control for PKTINFO and if found updates ep with -// the source information found. -func setSrcControl(control *[]byte, ep *StdNetEndpoint) { -} - -// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. -func getGSOSize(control []byte) (int, error) { - return 0, nil -} - -// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. -func setGSOSize(control *[]byte, gsoSize uint16) { -} - -// controlSize returns the recommended buffer size for pooling sticky and UDP -// offloading control data. -const controlSize = 0 - -const StdNetSupportsStickySockets = false diff --git a/conn/control_linux.go b/conn/control_linux.go deleted file mode 100644 index 44a94e6..0000000 --- a/conn/control_linux.go +++ /dev/null @@ -1,159 +0,0 @@ -//go:build linux && !android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "fmt" - "net/netip" - "unsafe" - - "golang.org/x/sys/unix" -) - -func (e *StdNetEndpoint) SrcIP() netip.Addr { - switch len(e.src) { - case unix.CmsgSpace(unix.SizeofInet4Pktinfo): - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) - return netip.AddrFrom4(info.Spec_dst) - case unix.CmsgSpace(unix.SizeofInet6Pktinfo): - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) - // TODO: set zone. in order to do so we need to check if the address is - // link local, and if it is perform a syscall to turn the ifindex into a - // zone string because netip uses string zones. - return netip.AddrFrom16(info.Addr) - } - return netip.Addr{} -} - -func (e *StdNetEndpoint) SrcIfidx() int32 { - switch len(e.src) { - case unix.CmsgSpace(unix.SizeofInet4Pktinfo): - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) - return info.Ifindex - case unix.CmsgSpace(unix.SizeofInet6Pktinfo): - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) - return int32(info.Ifindex) - } - return 0 -} - -func (e *StdNetEndpoint) SrcToString() string { - return e.SrcIP().String() -} - -// getSrcFromControl parses the control for PKTINFO and if found updates ep with -// the source information found. -func getSrcFromControl(control []byte, ep *StdNetEndpoint) { - ep.ClearSrc() - - var ( - hdr unix.Cmsghdr - data []byte - rem []byte = control - err error - ) - - for len(rem) > unix.SizeofCmsghdr { - hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) - if err != nil { - return - } - - if hdr.Level == unix.IPPROTO_IP && - hdr.Type == unix.IP_PKTINFO { - - if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { - ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) - } - ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] - - hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) - copy(ep.src, hdrBuf) - copy(ep.src[unix.CmsgLen(0):], data) - return - } - - if hdr.Level == unix.IPPROTO_IPV6 && - hdr.Type == unix.IPV6_PKTINFO { - - if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { - ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) - } - - ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] - - hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) - copy(ep.src, hdrBuf) - copy(ep.src[unix.CmsgLen(0):], data) - return - } - } -} - -// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address -// and source ifindex found in ep. control's len will be set to 0 in the event -// that ep is a default value. -func setSrcControl(control *[]byte, ep *StdNetEndpoint) { - if cap(*control) < len(ep.src) { - return - } - *control = (*control)[:0] - *control = append(*control, ep.src...) -} - -const ( - sizeOfGSOData = 2 -) - -// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. -func getGSOSize(control []byte) (int, error) { - var ( - hdr unix.Cmsghdr - data []byte - rem = control - err error - ) - - for len(rem) > unix.SizeofCmsghdr { - hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) - if err != nil { - return 0, fmt.Errorf("error parsing socket control message: %w", err) - } - if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { - var gso uint16 - copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) - return int(gso), nil - } - } - return 0, nil -} - -// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing -// data in control untouched. -func setGSOSize(control *[]byte, gsoSize uint16) { - existingLen := len(*control) - avail := cap(*control) - existingLen - space := unix.CmsgSpace(sizeOfGSOData) - if avail < space { - return - } - *control = (*control)[:cap(*control)] - gsoControl := (*control)[existingLen:] - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) - hdr.Level = unix.SOL_UDP - hdr.Type = unix.UDP_SEGMENT - hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) - copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) - *control = (*control)[:existingLen+space] -} - -// controlSize returns the recommended buffer size for pooling sticky and UDP -// offloading control data. -var controlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + unix.CmsgSpace(sizeOfGSOData) - -const StdNetSupportsStickySockets = true diff --git a/conn/control_linux_test.go b/conn/control_linux_test.go deleted file mode 100644 index 96f9da2..0000000 --- a/conn/control_linux_test.go +++ /dev/null @@ -1,266 +0,0 @@ -//go:build linux && !android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "context" - "net" - "net/netip" - "runtime" - "testing" - "unsafe" - - "golang.org/x/sys/unix" -) - -func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { - var buf []byte - if addr.Is4() { - buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) - hdr := unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - } - hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) - copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) - - info := unix.Inet4Pktinfo{ - Ifindex: ifidx, - Spec_dst: addr.As4(), - } - copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) - } else { - buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) - hdr := unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - } - hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) - copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) - - info := unix.Inet6Pktinfo{ - Ifindex: uint32(ifidx), - Addr: addr.As16(), - } - copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) - } - - ep.src = buf -} - -func Test_setSrcControl(t *testing.T) { - t.Run("IPv4", func(t *testing.T) { - ep := &StdNetEndpoint{ - AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), - } - setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) - - control := make([]byte, controlSize) - - setSrcControl(&control, ep) - - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - if hdr.Level != unix.IPPROTO_IP { - t.Errorf("unexpected level: %d", hdr.Level) - } - if hdr.Type != unix.IP_PKTINFO { - t.Errorf("unexpected type: %d", hdr.Type) - } - if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { - t.Errorf("unexpected length: %d", hdr.Len) - } - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) - if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { - t.Errorf("unexpected address: %v", info.Spec_dst) - } - if info.Ifindex != 5 { - t.Errorf("unexpected ifindex: %d", info.Ifindex) - } - }) - - t.Run("IPv6", func(t *testing.T) { - ep := &StdNetEndpoint{ - AddrPort: netip.MustParseAddrPort("[::1]:1234"), - } - setSrc(ep, netip.MustParseAddr("::1"), 5) - - control := make([]byte, controlSize) - - setSrcControl(&control, ep) - - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - if hdr.Level != unix.IPPROTO_IPV6 { - t.Errorf("unexpected level: %d", hdr.Level) - } - if hdr.Type != unix.IPV6_PKTINFO { - t.Errorf("unexpected type: %d", hdr.Type) - } - if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { - t.Errorf("unexpected length: %d", hdr.Len) - } - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) - if info.Addr != ep.SrcIP().As16() { - t.Errorf("unexpected address: %v", info.Addr) - } - if info.Ifindex != 5 { - t.Errorf("unexpected ifindex: %d", info.Ifindex) - } - }) - - t.Run("ClearOnNoSrc", func(t *testing.T) { - control := make([]byte, controlSize) - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - hdr.Level = 1 - hdr.Type = 2 - hdr.Len = 3 - - setSrcControl(&control, &StdNetEndpoint{}) - - if len(control) != 0 { - t.Errorf("unexpected control: %v", control) - } - }) -} - -func Test_getSrcFromControl(t *testing.T) { - t.Run("IPv4", func(t *testing.T) { - control := make([]byte, controlSize) - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - hdr.Level = unix.IPPROTO_IP - hdr.Type = unix.IP_PKTINFO - hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) - info.Spec_dst = [4]byte{127, 0, 0, 1} - info.Ifindex = 5 - - ep := &StdNetEndpoint{} - getSrcFromControl(control, ep) - - if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.SrcIP()) - } - if ep.SrcIfidx() != 5 { - t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) - } - }) - t.Run("IPv6", func(t *testing.T) { - control := make([]byte, controlSize) - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - hdr.Level = unix.IPPROTO_IPV6 - hdr.Type = unix.IPV6_PKTINFO - hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) - info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) - info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} - info.Ifindex = 5 - - ep := &StdNetEndpoint{} - getSrcFromControl(control, ep) - - if ep.SrcIP() != netip.MustParseAddr("::1") { - t.Errorf("unexpected address: %v", ep.SrcIP()) - } - if ep.SrcIfidx() != 5 { - t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) - } - }) - t.Run("ClearOnEmpty", func(t *testing.T) { - var control []byte - ep := &StdNetEndpoint{} - setSrc(ep, netip.MustParseAddr("::1"), 5) - - getSrcFromControl(control, ep) - if ep.SrcIP().IsValid() { - t.Errorf("unexpected address: %v", ep.SrcIP()) - } - if ep.SrcIfidx() != 0 { - t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) - } - }) - t.Run("Multiple", func(t *testing.T) { - zeroControl := make([]byte, unix.CmsgSpace(0)) - zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) - zeroHdr.SetLen(unix.CmsgLen(0)) - - control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) - hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) - hdr.Level = unix.IPPROTO_IP - hdr.Type = unix.IP_PKTINFO - hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) - info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) - info.Spec_dst = [4]byte{127, 0, 0, 1} - info.Ifindex = 5 - - combined := make([]byte, 0) - combined = append(combined, zeroControl...) - combined = append(combined, control...) - - ep := &StdNetEndpoint{} - getSrcFromControl(combined, ep) - - if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { - t.Errorf("unexpected address: %v", ep.SrcIP()) - } - if ep.SrcIfidx() != 5 { - t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) - } - }) -} - -func Test_listenConfig(t *testing.T) { - t.Run("IPv4", func(t *testing.T) { - conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") - if err != nil { - t.Fatal(err) - } - defer conn.Close() - sc, err := conn.(*net.UDPConn).SyscallConn() - if err != nil { - t.Fatal(err) - } - - if runtime.GOOS == "linux" { - var i int - sc.Control(func(fd uintptr) { - i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) - }) - if err != nil { - t.Fatal(err) - } - if i != 1 { - t.Error("IP_PKTINFO not set!") - } - } else { - t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) - } - }) - t.Run("IPv6", func(t *testing.T) { - conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") - if err != nil { - t.Fatal(err) - } - sc, err := conn.(*net.UDPConn).SyscallConn() - if err != nil { - t.Fatal(err) - } - - if runtime.GOOS == "linux" { - var i int - sc.Control(func(fd uintptr) { - i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) - }) - if err != nil { - t.Fatal(err) - } - if i != 1 { - t.Error("IPV6_PKTINFO not set!") - } - } else { - t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) - } - }) -} diff --git a/conn/gso_default.go b/conn/gso_default.go new file mode 100644 index 0000000..57780db --- /dev/null +++ b/conn/gso_default.go @@ -0,0 +1,21 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// gsoControlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +const gsoControlSize = 0 diff --git a/conn/gso_linux.go b/conn/gso_linux.go new file mode 100644 index 0000000..b8599ce --- /dev/null +++ b/conn/gso_linux.go @@ -0,0 +1,65 @@ +//go:build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/unix" +) + +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched. +func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// gsoControlSize returns the recommended buffer size for pooling UDP +// offloading control data. +var gsoControlSize = unix.CmsgSpace(sizeOfGSOData) diff --git a/conn/sticky_default.go b/conn/sticky_default.go new file mode 100644 index 0000000..0b21386 --- /dev/null +++ b/conn/sticky_default.go @@ -0,0 +1,42 @@ +//go:build !linux || android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net/netip" + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return "" +} + +// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets +// {get,set}srcControl feature set, but use alternatively named flags and need +// ports and require testing. + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { +} + +// setSrcControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { +} + +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +const stickyControlSize = 0 + +const StdNetSupportsStickySockets = false diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go new file mode 100644 index 0000000..8e206e9 --- /dev/null +++ b/conn/sticky_linux.go @@ -0,0 +1,112 @@ +//go:build linux && !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net/netip" + "unsafe" + + "golang.org/x/sys/unix" +) + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return netip.AddrFrom4(info.Spec_dst) + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + // TODO: set zone. in order to do so we need to check if the address is + // link local, and if it is perform a syscall to turn the ifindex into a + // zone string because netip uses string zones. + return netip.AddrFrom16(info.Addr) + } + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return info.Ifindex + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return int32(info.Ifindex) + } + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.SrcIP().String() +} + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { + ep.ClearSrc() + + var ( + hdr unix.Cmsghdr + data []byte + rem []byte = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return + } + + if hdr.Level == unix.IPPROTO_IP && + hdr.Type == unix.IP_PKTINFO { + + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) + return + } + + if hdr.Level == unix.IPPROTO_IPV6 && + hdr.Type == unix.IPV6_PKTINFO { + + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) + return + } + } +} + +// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address +// and source ifindex found in ep. control's len will be set to 0 in the event +// that ep is a default value. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { + if cap(*control) < len(ep.src) { + return + } + *control = (*control)[:0] + *control = append(*control, ep.src...) +} + +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + +const StdNetSupportsStickySockets = true diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go new file mode 100644 index 0000000..d2bd584 --- /dev/null +++ b/conn/sticky_linux_test.go @@ -0,0 +1,266 @@ +//go:build linux && !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "context" + "net" + "net/netip" + "runtime" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { + var buf []byte + if addr.Is4() { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet4Pktinfo{ + Ifindex: ifidx, + Spec_dst: addr.As4(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) + } else { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet6Pktinfo{ + Ifindex: uint32(ifidx), + Addr: addr.As16(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) + } + + ep.src = buf +} + +func Test_setSrcControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), + } + setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) + + control := make([]byte, stickyControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IP { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IP_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { + t.Errorf("unexpected address: %v", info.Spec_dst) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("IPv6", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("[::1]:1234"), + } + setSrc(ep, netip.MustParseAddr("::1"), 5) + + control := make([]byte, stickyControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IPV6 { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IPV6_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Addr != ep.SrcIP().As16() { + t.Errorf("unexpected address: %v", info.Addr) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("ClearOnNoSrc", func(t *testing.T) { + control := make([]byte, stickyControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = 1 + hdr.Type = 2 + hdr.Len = 3 + + setSrcControl(&control, &StdNetEndpoint{}) + + if len(control) != 0 { + t.Errorf("unexpected control: %v", control) + } + }) +} + +func Test_getSrcFromControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + control := make([]byte, stickyControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("IPv6", func(t *testing.T) { + control := make([]byte, stickyControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IPV6 + hdr.Type = unix.IPV6_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("::1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("ClearOnEmpty", func(t *testing.T) { + var control []byte + ep := &StdNetEndpoint{} + setSrc(ep, netip.MustParseAddr("::1"), 5) + + getSrcFromControl(control, ep) + if ep.SrcIP().IsValid() { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 0 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("Multiple", func(t *testing.T) { + zeroControl := make([]byte, unix.CmsgSpace(0)) + zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) + zeroHdr.SetLen(unix.CmsgLen(0)) + + control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + combined := make([]byte, 0) + combined = append(combined, zeroControl...) + combined = append(combined, control...) + + ep := &StdNetEndpoint{} + getSrcFromControl(combined, ep) + + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) +} + +func Test_listenConfig(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IP_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) + t.Run("IPv6", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") + if err != nil { + t.Fatal(err) + } + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IPV6_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) +} -- cgit v1.2.3