diff options
Diffstat (limited to 'pkg/tcpip/header/checksum_test.go')
-rw-r--r-- | pkg/tcpip/header/checksum_test.go | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 86b466c1c..309403482 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -17,6 +17,8 @@ package header_test import ( + "fmt" + "math/rand" "testing" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -107,3 +109,63 @@ func TestChecksumVVWithOffset(t *testing.T) { }) } } + +func TestChecksum(t *testing.T) { + var bufSizes = []int{0, 1, 2, 3, 4, 7, 8, 15, 16, 31, 32, 63, 64, 127, 128, 255, 256, 257, 1023, 1024} + type testCase struct { + buf []byte + initial uint16 + csumOrig uint16 + csumNew uint16 + } + testCases := make([]testCase, 100000) + // Ensure same buffer generation for test consistency. + rnd := rand.New(rand.NewSource(42)) + for i := range testCases { + testCases[i].buf = make([]byte, bufSizes[i%len(bufSizes)]) + testCases[i].initial = uint16(rnd.Intn(65536)) + rnd.Read(testCases[i].buf) + } + + for i := range testCases { + testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial) + testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial) + if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want { + t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want) + } + } +} + +func BenchmarkChecksum(b *testing.B) { + var bufSizes = []int{64, 128, 256, 512, 1024, 1500, 2048, 4096, 8192, 16384, 32767, 32768, 65535, 65536} + + checkSumImpls := []struct { + fn func([]byte, uint16) uint16 + name string + }{ + {header.ChecksumOld, fmt.Sprintf("checksum_old")}, + {header.Checksum, fmt.Sprintf("checksum")}, + } + + for _, csumImpl := range checkSumImpls { + // Ensure same buffer generation for test consistency. + rnd := rand.New(rand.NewSource(42)) + for _, bufSz := range bufSizes { + b.Run(fmt.Sprintf("%s_%d", csumImpl.name, bufSz), func(b *testing.B) { + tc := struct { + buf []byte + initial uint16 + csum uint16 + }{ + buf: make([]byte, bufSz), + initial: uint16(rnd.Intn(65536)), + } + rnd.Read(tc.buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + tc.csum = csumImpl.fn(tc.buf, tc.initial) + } + }) + } + } +} |