diff options
author | Jordan Whited <jordan@tailscale.com> | 2023-03-14 20:28:07 -0700 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2023-03-16 17:45:41 +0100 |
commit | 07a1e55270bd34ee526ad328d597fa01a8e17619 (patch) | |
tree | 4f4ca5ed6665af1e2432f9b382d274bc968847a3 /conn | |
parent | fff53afca779078061128a5a10c31e67a7919d35 (diff) |
conn: fix getSrcFromControl() iteration
We only expect a single control message in the normal case, but this
would loop infinitely if there were more.
Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'conn')
-rw-r--r-- | conn/sticky_linux.go | 2 | ||||
-rw-r--r-- | conn/sticky_linux_test.go | 28 |
2 files changed, 29 insertions, 1 deletions
diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go index 342e739..278eb19 100644 --- a/conn/sticky_linux.go +++ b/conn/sticky_linux.go @@ -25,7 +25,7 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) { ) for len(rem) > unix.SizeofCmsghdr { - hdr, data, rem, err = unix.ParseOneSocketControlMessage(control) + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) if err != nil { return } diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go index 672b67e..503c342 100644 --- a/conn/sticky_linux_test.go +++ b/conn/sticky_linux_test.go @@ -150,6 +150,34 @@ func Test_getSrcFromControl(t *testing.T) { t.Errorf("unexpected ifindex: %d", ep.src.ifidx) } }) + t.Run("Multiple", func(t *testing.T) { + zeroControl := make([]byte, unix.CmsgSpace(0)) + zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) + zeroHdr.SetLen(unix.CmsgLen(0)) + + control := make([]byte, srcControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + combined := make([]byte, 0) + combined = append(combined, zeroControl...) + combined = append(combined, control...) + + ep := &StdNetEndpoint{} + getSrcFromControl(combined, ep) + + if ep.src.Addr != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.src.Addr) + } + if ep.src.ifidx != 5 { + t.Errorf("unexpected ifindex: %d", ep.src.ifidx) + } + }) } func Test_listenConfig(t *testing.T) { |