summaryrefslogtreecommitdiffhomepage
path: root/tun
diff options
context:
space:
mode:
authorJordan Whited <jordan@tailscale.com>2023-10-02 14:43:56 -0700
committerJason A. Donenfeld <Jason@zx2c4.com>2023-10-10 15:07:36 +0200
commit895d6c23cd60bd0522c5b6598a69ad6c5f1ab3a7 (patch)
treebde5e763ce9b81a049815d1a268101ea5af61f83 /tun
parent4201e08f1dbb521e5555d96a3b6464a860466f5f (diff)
tun: unwind summing loop in checksumNoFold()
$ benchstat old.txt new.txt goos: linux goarch: amd64 pkg: golang.zx2c4.com/wireguard/tun cpu: 12th Gen Intel(R) Core(TM) i5-12400 │ old.txt │ new.txt │ │ sec/op │ sec/op vs base │ Checksum/64-12 10.670n ± 2% 4.769n ± 0% -55.30% (p=0.000 n=10) Checksum/128-12 19.665n ± 2% 8.032n ± 0% -59.16% (p=0.000 n=10) Checksum/256-12 37.68n ± 1% 16.06n ± 0% -57.37% (p=0.000 n=10) Checksum/512-12 76.61n ± 3% 32.13n ± 0% -58.06% (p=0.000 n=10) Checksum/1024-12 160.55n ± 4% 64.25n ± 0% -59.98% (p=0.000 n=10) Checksum/1500-12 231.05n ± 7% 94.12n ± 0% -59.26% (p=0.000 n=10) Checksum/2048-12 309.5n ± 3% 128.5n ± 0% -58.48% (p=0.000 n=10) Checksum/4096-12 603.8n ± 4% 257.2n ± 0% -57.41% (p=0.000 n=10) Checksum/8192-12 1185.0n ± 3% 515.5n ± 0% -56.50% (p=0.000 n=10) Checksum/9000-12 1328.5n ± 5% 564.8n ± 0% -57.49% (p=0.000 n=10) Checksum/9001-12 1340.5n ± 3% 564.8n ± 0% -57.87% (p=0.000 n=10) geomean 185.3n 77.99n -57.92% Reviewed-by: Adrian Dewhurst <adrian@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'tun')
-rw-r--r--tun/checksum.go100
-rw-r--r--tun/checksum_test.go35
2 files changed, 123 insertions, 12 deletions
diff --git a/tun/checksum.go b/tun/checksum.go
index f4f8471..29a8fc8 100644
--- a/tun/checksum.go
+++ b/tun/checksum.go
@@ -3,23 +3,99 @@ package tun
import "encoding/binary"
// TODO: Explore SIMD and/or other assembly optimizations.
+// TODO: Test native endian loads. See RFC 1071 section 2 part B.
func checksumNoFold(b []byte, initial uint64) uint64 {
ac := initial
- i := 0
- n := len(b)
- for n >= 4 {
- ac += uint64(binary.BigEndian.Uint32(b[i : i+4]))
- n -= 4
- i += 4
+
+ for len(b) >= 128 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+ ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+ ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+ ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+ ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+ ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+ ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+ ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ ac += uint64(binary.BigEndian.Uint32(b[64:68]))
+ ac += uint64(binary.BigEndian.Uint32(b[68:72]))
+ ac += uint64(binary.BigEndian.Uint32(b[72:76]))
+ ac += uint64(binary.BigEndian.Uint32(b[76:80]))
+ ac += uint64(binary.BigEndian.Uint32(b[80:84]))
+ ac += uint64(binary.BigEndian.Uint32(b[84:88]))
+ ac += uint64(binary.BigEndian.Uint32(b[88:92]))
+ ac += uint64(binary.BigEndian.Uint32(b[92:96]))
+ ac += uint64(binary.BigEndian.Uint32(b[96:100]))
+ ac += uint64(binary.BigEndian.Uint32(b[100:104]))
+ ac += uint64(binary.BigEndian.Uint32(b[104:108]))
+ ac += uint64(binary.BigEndian.Uint32(b[108:112]))
+ ac += uint64(binary.BigEndian.Uint32(b[112:116]))
+ ac += uint64(binary.BigEndian.Uint32(b[116:120]))
+ ac += uint64(binary.BigEndian.Uint32(b[120:124]))
+ ac += uint64(binary.BigEndian.Uint32(b[124:128]))
+ b = b[128:]
+ }
+ if len(b) >= 64 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ ac += uint64(binary.BigEndian.Uint32(b[32:36]))
+ ac += uint64(binary.BigEndian.Uint32(b[36:40]))
+ ac += uint64(binary.BigEndian.Uint32(b[40:44]))
+ ac += uint64(binary.BigEndian.Uint32(b[44:48]))
+ ac += uint64(binary.BigEndian.Uint32(b[48:52]))
+ ac += uint64(binary.BigEndian.Uint32(b[52:56]))
+ ac += uint64(binary.BigEndian.Uint32(b[56:60]))
+ ac += uint64(binary.BigEndian.Uint32(b[60:64]))
+ b = b[64:]
+ }
+ if len(b) >= 32 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ ac += uint64(binary.BigEndian.Uint32(b[16:20]))
+ ac += uint64(binary.BigEndian.Uint32(b[20:24]))
+ ac += uint64(binary.BigEndian.Uint32(b[24:28]))
+ ac += uint64(binary.BigEndian.Uint32(b[28:32]))
+ b = b[32:]
+ }
+ if len(b) >= 16 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ ac += uint64(binary.BigEndian.Uint32(b[8:12]))
+ ac += uint64(binary.BigEndian.Uint32(b[12:16]))
+ b = b[16:]
}
- for n >= 2 {
- ac += uint64(binary.BigEndian.Uint16(b[i : i+2]))
- n -= 2
- i += 2
+ if len(b) >= 8 {
+ ac += uint64(binary.BigEndian.Uint32(b[:4]))
+ ac += uint64(binary.BigEndian.Uint32(b[4:8]))
+ b = b[8:]
}
- if n == 1 {
- ac += uint64(b[i]) << 8
+ if len(b) >= 4 {
+ ac += uint64(binary.BigEndian.Uint32(b))
+ b = b[4:]
}
+ if len(b) >= 2 {
+ ac += uint64(binary.BigEndian.Uint16(b))
+ b = b[2:]
+ }
+ if len(b) == 1 {
+ ac += uint64(b[0]) << 8
+ }
+
return ac
}
diff --git a/tun/checksum_test.go b/tun/checksum_test.go
new file mode 100644
index 0000000..c1ccff5
--- /dev/null
+++ b/tun/checksum_test.go
@@ -0,0 +1,35 @@
+package tun
+
+import (
+ "fmt"
+ "math/rand"
+ "testing"
+)
+
+func BenchmarkChecksum(b *testing.B) {
+ lengths := []int{
+ 64,
+ 128,
+ 256,
+ 512,
+ 1024,
+ 1500,
+ 2048,
+ 4096,
+ 8192,
+ 9000,
+ 9001,
+ }
+
+ for _, length := range lengths {
+ b.Run(fmt.Sprintf("%d", length), func(b *testing.B) {
+ buf := make([]byte, length)
+ rng := rand.New(rand.NewSource(1))
+ rng.Read(buf)
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ checksum(buf, 0)
+ }
+ })
+ }
+}