diff options
-rw-r--r-- | pkg/tcpip/header/checksum.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/header/checksum_test.go | 6 |
2 files changed, 12 insertions, 9 deletions
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go index ce57b581a..204285576 100644 --- a/pkg/tcpip/header/checksum.go +++ b/pkg/tcpip/header/checksum.go @@ -160,20 +160,23 @@ func unrolledCalculateChecksum(buf []byte, odd bool, initial uint32) (uint16, bo return ChecksumCombine(uint16(v), uint16(v>>16)), odd } -// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the -// given byte array. +// ChecksumOld calculates the checksum (as defined in RFC 1071) of the bytes in +// the given byte array. This function uses a non-optimized implementation. Its +// only retained for reference and to use as a benchmark/test. Most code should +// use the header.Checksum function. // // The initial checksum must have been computed on an even number of bytes. -func Checksum(buf []byte, initial uint16) uint16 { +func ChecksumOld(buf []byte, initial uint16) uint16 { s, _ := calculateChecksum(buf, false, uint32(initial)) return s } -// UnrolledChecksum calculates the checksum (as defined in RFC 1071) of the -// bytes in the given byte array. +// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the +// given byte array. This function uses an optimized unrolled version of the +// checksum algorithm. // // The initial checksum must have been computed on an even number of bytes. -func UnrolledChecksum(buf []byte, initial uint16) uint16 { +func Checksum(buf []byte, initial uint16) uint16 { s, _ := unrolledCalculateChecksum(buf, false, uint32(initial)) return s } diff --git a/pkg/tcpip/header/checksum_test.go b/pkg/tcpip/header/checksum_test.go index 2fbd16a65..309403482 100644 --- a/pkg/tcpip/header/checksum_test.go +++ b/pkg/tcpip/header/checksum_test.go @@ -128,8 +128,8 @@ func TestChecksum(t *testing.T) { } for i := range testCases { - testCases[i].csumOrig = header.Checksum(testCases[i].buf, testCases[i].initial) - testCases[i].csumNew = header.UnrolledChecksum(testCases[i].buf, testCases[i].initial) + testCases[i].csumOrig = header.ChecksumOld(testCases[i].buf, testCases[i].initial) + testCases[i].csumNew = header.Checksum(testCases[i].buf, testCases[i].initial) if got, want := testCases[i].csumNew, testCases[i].csumOrig; got != want { t.Fatalf("new checksum for (buf = %x, initial = %d) does not match old got: %d, want: %d", testCases[i].buf, testCases[i].initial, got, want) } @@ -143,8 +143,8 @@ func BenchmarkChecksum(b *testing.B) { fn func([]byte, uint16) uint16 name string }{ + {header.ChecksumOld, fmt.Sprintf("checksum_old")}, {header.Checksum, fmt.Sprintf("checksum")}, - {header.UnrolledChecksum, fmt.Sprintf("unrolled_checksum")}, } for _, csumImpl := range checkSumImpls { |