diff options
author | Jordan Whited <jordan@tailscale.com> | 2023-10-02 14:43:56 -0700 |
---|---|---|
committer | Jason A. Donenfeld <Jason@zx2c4.com> | 2023-10-10 15:07:36 +0200 |
commit | 895d6c23cd60bd0522c5b6598a69ad6c5f1ab3a7 (patch) | |
tree | bde5e763ce9b81a049815d1a268101ea5af61f83 /tun | |
parent | 4201e08f1dbb521e5555d96a3b6464a860466f5f (diff) |
tun: unwind summing loop in checksumNoFold()
$ benchstat old.txt new.txt
goos: linux
goarch: amd64
pkg: golang.zx2c4.com/wireguard/tun
cpu: 12th Gen Intel(R) Core(TM) i5-12400
│ old.txt │ new.txt │
│ sec/op │ sec/op vs base │
Checksum/64-12 10.670n ± 2% 4.769n ± 0% -55.30% (p=0.000 n=10)
Checksum/128-12 19.665n ± 2% 8.032n ± 0% -59.16% (p=0.000 n=10)
Checksum/256-12 37.68n ± 1% 16.06n ± 0% -57.37% (p=0.000 n=10)
Checksum/512-12 76.61n ± 3% 32.13n ± 0% -58.06% (p=0.000 n=10)
Checksum/1024-12 160.55n ± 4% 64.25n ± 0% -59.98% (p=0.000 n=10)
Checksum/1500-12 231.05n ± 7% 94.12n ± 0% -59.26% (p=0.000 n=10)
Checksum/2048-12 309.5n ± 3% 128.5n ± 0% -58.48% (p=0.000 n=10)
Checksum/4096-12 603.8n ± 4% 257.2n ± 0% -57.41% (p=0.000 n=10)
Checksum/8192-12 1185.0n ± 3% 515.5n ± 0% -56.50% (p=0.000 n=10)
Checksum/9000-12 1328.5n ± 5% 564.8n ± 0% -57.49% (p=0.000 n=10)
Checksum/9001-12 1340.5n ± 3% 564.8n ± 0% -57.87% (p=0.000 n=10)
geomean 185.3n 77.99n -57.92%
Reviewed-by: Adrian Dewhurst <adrian@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
Diffstat (limited to 'tun')
-rw-r--r-- | tun/checksum.go | 100 | ||||
-rw-r--r-- | tun/checksum_test.go | 35 |
2 files changed, 123 insertions, 12 deletions
diff --git a/tun/checksum.go b/tun/checksum.go index f4f8471..29a8fc8 100644 --- a/tun/checksum.go +++ b/tun/checksum.go @@ -3,23 +3,99 @@ package tun import "encoding/binary" // TODO: Explore SIMD and/or other assembly optimizations. +// TODO: Test native endian loads. See RFC 1071 section 2 part B. func checksumNoFold(b []byte, initial uint64) uint64 { ac := initial - i := 0 - n := len(b) - for n >= 4 { - ac += uint64(binary.BigEndian.Uint32(b[i : i+4])) - n -= 4 - i += 4 + + for len(b) >= 128 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac += uint64(binary.BigEndian.Uint32(b[64:68])) + ac += uint64(binary.BigEndian.Uint32(b[68:72])) + ac += uint64(binary.BigEndian.Uint32(b[72:76])) + ac += uint64(binary.BigEndian.Uint32(b[76:80])) + ac += uint64(binary.BigEndian.Uint32(b[80:84])) + ac += uint64(binary.BigEndian.Uint32(b[84:88])) + ac += uint64(binary.BigEndian.Uint32(b[88:92])) + ac += uint64(binary.BigEndian.Uint32(b[92:96])) + ac += uint64(binary.BigEndian.Uint32(b[96:100])) + ac += uint64(binary.BigEndian.Uint32(b[100:104])) + ac += uint64(binary.BigEndian.Uint32(b[104:108])) + ac += uint64(binary.BigEndian.Uint32(b[108:112])) + ac += uint64(binary.BigEndian.Uint32(b[112:116])) + ac += uint64(binary.BigEndian.Uint32(b[116:120])) + ac += uint64(binary.BigEndian.Uint32(b[120:124])) + ac += uint64(binary.BigEndian.Uint32(b[124:128])) + b = b[128:] + } + if len(b) >= 64 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + b = b[64:] + } + if len(b) >= 32 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] + } + if len(b) >= 16 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + b = b[16:] } - for n >= 2 { - ac += uint64(binary.BigEndian.Uint16(b[i : i+2])) - n -= 2 - i += 2 + if len(b) >= 8 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + b = b[8:] } - if n == 1 { - ac += uint64(b[i]) << 8 + if len(b) >= 4 { + ac += uint64(binary.BigEndian.Uint32(b)) + b = b[4:] } + if len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + return ac } diff --git a/tun/checksum_test.go b/tun/checksum_test.go new file mode 100644 index 0000000..c1ccff5 --- /dev/null +++ b/tun/checksum_test.go @@ -0,0 +1,35 @@ +package tun + +import ( + "fmt" + "math/rand" + "testing" +) + +func BenchmarkChecksum(b *testing.B) { + lengths := []int{ + 64, + 128, + 256, + 512, + 1024, + 1500, + 2048, + 4096, + 8192, + 9000, + 9001, + } + + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksum(buf, 0) + } + }) + } +} |