summaryrefslogtreecommitdiffhomepage
path: root/conn/sticky_linux.go
blob: bf1783912a002bee553a5a176e1733b64105a6c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
 */

package conn

import (
	"net/netip"
	"unsafe"

	"golang.org/x/sys/unix"
)

// getSrcFromControl scans the socket control messages in control for the first
// IP_PKTINFO or IPV6_PKTINFO message and, if one is found, records its address
// and interface index as ep's source. ep's source is cleared up front, so it
// remains cleared when control holds no usable PKTINFO message or is malformed.
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
	ep.ClearSrc()

	rem := control
	for len(rem) > unix.SizeofCmsghdr {
		// Parse from rem, not control: re-parsing control would return the
		// same first message on every iteration and never terminate when that
		// message is not a PKTINFO.
		hdr, data, next, err := unix.ParseOneSocketControlMessage(rem)
		if err != nil {
			return
		}
		rem = next

		switch {
		case hdr.Level == unix.IPPROTO_IP && hdr.Type == unix.IP_PKTINFO:
			// pktInfoFromBuf panics on short input; treat a truncated
			// payload as if no PKTINFO were present.
			if len(data) < unix.SizeofInet4Pktinfo {
				return
			}
			info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
			ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
			ep.src.ifidx = info.Ifindex
			return

		case hdr.Level == unix.IPPROTO_IPV6 && hdr.Type == unix.IPV6_PKTINFO:
			if len(data) < unix.SizeofInet6Pktinfo {
				return
			}
			info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
			ep.src.Addr = netip.AddrFrom16(info.Addr)
			ep.src.ifidx = int32(info.Ifindex)
			return
		}
	}
}

// pktInfoFromBuf decodes a T by copying the leading unsafe.Sizeof(T) bytes of
// buf over a zero T, using the host's native memory layout. It panics when buf
// is shorter than a T, so callers must check the length beforehand.
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
	raw := unsafe.Slice((*byte)(unsafe.Pointer(&t)), unsafe.Sizeof(t))
	if len(buf) < len(raw) {
		panic("pktInfoFromBuf: buffer too small")
	}
	copy(raw, buf)
	return t
}

// setSrcControl writes a single control message into *control that carries the
// source address and interface index cached in ep. It writes IP_PKTINFO when
// ep's source address is IPv4, and IPV6_PKTINFO otherwise (including the
// interface-index-only case). It is the send-side counterpart of
// getSrcFromControl.
//
// *control is first re-sliced to its full capacity. On return its length is
// either the size of the message written, or 0 when nothing was written. That
// happens when ep has neither a source address nor an interface index, or
// when the capacity is below srcControlSize.
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
	*control = (*control)[:cap(*control)]

	src := ep.SrcIP()
	if (ep.src.ifidx == 0 && !src.IsValid()) || len(*control) < srcControlSize {
		*control = (*control)[:0]
		return
	}

	payload := unix.SizeofInet6Pktinfo
	if src.Is4() {
		payload = unix.SizeofInet4Pktinfo
	}
	msgLen := unix.CmsgLen(payload)

	// Prefer CMSG_SPACE so that a control message appended after this one
	// starts on an aligned boundary, as cmsg(3) requires. If the buffer cannot
	// hold the trailing padding, fall back to CMSG_LEN. That is the length this
	// function always produced before, and Linux accepts it for a lone message.
	size := unix.CmsgSpace(payload)
	if size > len(*control) {
		size = msgLen
	}

	// Re-slicing to capacity exposes whatever the backing array held before;
	// clear the region so unwritten fields and padding carry no stale bytes.
	buf := (*control)[:size]
	for i := range buf {
		buf[i] = 0
	}

	hdr := (*unix.Cmsghdr)(unsafe.Pointer(&buf[0]))
	hdr.SetLen(msgLen)
	if src.Is4() {
		hdr.Level = unix.IPPROTO_IP
		hdr.Type = unix.IP_PKTINFO

		info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&buf[unix.SizeofCmsghdr]))
		info.Ifindex = ep.src.ifidx
		// Is4 implies src is valid, so As4 is safe here.
		info.Spec_dst = src.As4()
	} else {
		hdr.Level = unix.IPPROTO_IPV6
		hdr.Type = unix.IPV6_PKTINFO

		info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&buf[unix.SizeofCmsghdr]))
		info.Ifindex = uint32(ep.src.ifidx)
		if src.IsValid() {
			info.Addr = src.As16()
		}
	}

	*control = buf
}

// srcControlSize is the minimum control-buffer capacity setSrcControl needs
// before it will write a PKTINFO message. It is sized for the larger of the two
// payloads (IPV6_PKTINFO) and uses CMSG_SPACE, not CMSG_LEN. That leaves room
// for the trailing alignment padding, so any message appended afterwards
// starts on an aligned boundary.
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)